Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,27 @@ env:
- TOXENV=py27-django18-drf31
- TOXENV=py27-django18-drf32
- TOXENV=py27-django18-drf33
- TOXENV=py27-django18-drf34
- TOXENV=py33-django18-drf31
- TOXENV=py33-django18-drf32
- TOXENV=py33-django18-drf33
- TOXENV=py33-django18-drf34
- TOXENV=py34-django18-drf31
- TOXENV=py34-django18-drf32
- TOXENV=py34-django18-drf33
- TOXENV=py34-django18-drf34
- TOXENV=py27-django19-drf31
- TOXENV=py27-django19-drf32
- TOXENV=py27-django19-drf33
- TOXENV=py27-django19-drf34
- TOXENV=py34-django19-drf31
- TOXENV=py34-django19-drf32
- TOXENV=py34-django19-drf33
- TOXENV=py34-django19-drf34
- TOXENV=py35-django19-drf31
- TOXENV=py35-django19-drf32
- TOXENV=py35-django19-drf33
- TOXENV=py35-django19-drf34
- TOXENV=py27-django110-drf34
- TOXENV=py34-django110-drf34
- TOXENV=py35-django110-drf34
68 changes: 49 additions & 19 deletions rest_framework_json_api/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,18 @@ def extract_relationships(fields, resource, resource_instance):
continue

source = field.source
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
# if the field is not defined on the model then we check the serializer
# and if no value is there we skip over the field completely
serializer_method = getattr(field.parent, source, None)
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

serializer_method = getattr(field.parent, source, None)
relation_type = utils.get_related_resource_type(field)

if isinstance(field, relations.HyperlinkedIdentityField):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

# special case for HyperlinkedIdentityField
relation_data = list()

Expand All @@ -124,6 +122,14 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, ResourceRelatedField):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

# special case for ResourceRelatedField
relation_data = {
'data': resource.get(field_name)
Expand All @@ -138,8 +144,15 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, (relations.PrimaryKeyRelatedField, relations.HyperlinkedRelatedField)):
relation_id = relation_instance_or_manager.pk if resource.get(field_name) else None
try:
relation = getattr(resource_instance, '%s_id' % field.source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation = serializer_method(resource_instance).pk
else:
continue

relation_id = relation if resource.get(field_name) else None
relation_data = {
'data': (
OrderedDict([('type', relation_type), ('id', encoding.force_text(relation_id))])
Expand All @@ -154,6 +167,13 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, relations.ManyRelatedField):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

if isinstance(field.child_relation, ResourceRelatedField):
# special case for ResourceRelatedField
Expand Down Expand Up @@ -194,6 +214,14 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, ListSerializer):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

relation_data = list()

serializer_data = resource.get(field_name)
Expand All @@ -211,6 +239,14 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, ModelSerializer):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

relation_model = field.Meta.model
relation_type = utils.format_resource_type(relation_model.__name__)

Expand Down Expand Up @@ -429,12 +465,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None):

serializer = getattr(serializer_data, 'serializer', None)

# Build a list of included resources
include_resources_param = request.query_params.get('include') if request else None
if include_resources_param:
included_resources = include_resources_param.split(',')
else:
included_resources = utils.get_default_included_resources_from_serializer(serializer)
included_resources = utils.get_included_resources(request, serializer)

if serializer is not None:

Expand Down Expand Up @@ -472,7 +503,6 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
if included:
json_api_included.extend(included)


# Make sure we render data in a specific order
render_data = OrderedDict()

Expand Down
16 changes: 7 additions & 9 deletions rest_framework_json_api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rest_framework_json_api.relations import ResourceRelatedField
from rest_framework_json_api.utils import (
get_resource_type_from_model, get_resource_type_from_instance,
get_resource_type_from_serializer, get_included_serializers)
get_resource_type_from_serializer, get_included_serializers, get_included_resources)


class ResourceIdentifierObjectSerializer(BaseSerializer):
Expand Down Expand Up @@ -90,14 +90,12 @@ def validate_path(serializer_class, field_path, path):
validate_path(this_included_serializer, new_included_field_path, path)

if request and view:
include_resources_param = request.query_params.get('include') if request else None
if include_resources_param:
included_resources = include_resources_param.split(',')
for included_field_name in included_resources:
included_field_path = included_field_name.split('.')
this_serializer_class = view.get_serializer_class()
# lets validate the current path
validate_path(this_serializer_class, included_field_path, included_field_name)
included_resources = get_included_resources(request)
for included_field_name in included_resources:
included_field_path = included_field_name.split('.')
this_serializer_class = view.get_serializer_class()
# lets validate the current path
validate_path(this_serializer_class, included_field_path, included_field_name)

super(IncludedResourcesValidationMixin, self).__init__(*args, **kwargs)

Expand Down
9 changes: 9 additions & 0 deletions rest_framework_json_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,15 @@ def get_resource_type_from_serializer(serializer):
return get_resource_type_from_model(serializer.Meta.model)


def get_included_resources(request, serializer=None):
""" Build a list of included resources. """
include_resources_param = request.query_params.get('include') if request else None
if include_resources_param:
return include_resources_param.split(',')
else:
return get_default_included_resources_from_serializer(serializer)


def get_default_included_resources_from_serializer(serializer):
try:
return list(serializer.JSONAPIMeta.included_resources)
Expand Down
50 changes: 48 additions & 2 deletions rest_framework_json_api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,61 @@
from django.db.models import Model
from django.db.models.query import QuerySet
from django.db.models.manager import Manager
from rest_framework import generics
if django.VERSION < (1, 9):
from django.db.models.fields.related import (
ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor,
ManyRelatedObjectsDescriptor as ManyToManyDescriptor,
)
else:
from django.db.models.fields.related_descriptors import (
ForwardManyToOneDescriptor,
ManyToManyDescriptor,
)
from rest_framework import generics, viewsets
from rest_framework.response import Response
from rest_framework.exceptions import NotFound, MethodNotAllowed
from rest_framework.reverse import reverse
from rest_framework.serializers import Serializer

from rest_framework_json_api.exceptions import Conflict
from rest_framework_json_api.serializers import ResourceIdentifierObjectSerializer
from rest_framework_json_api.utils import get_resource_type_from_instance, OrderedDict, Hyperlink
from rest_framework_json_api.utils import (
get_resource_type_from_instance,
OrderedDict,
Hyperlink,
get_included_resources,
)


class ModelViewSet(viewsets.ModelViewSet):
def get_queryset(self, *args, **kwargs):
qs = super().get_queryset(*args, **kwargs)
included_resources = get_included_resources(self.request)

for included in included_resources:
included_model = None
levels = included.split('.')
level_model = qs.model
for level in levels:
if not hasattr(level_model, level):
break
field = getattr(level_model, level)
field_class = field.__class__
if not (
issubclass(field_class, ForwardManyToOneDescriptor)
or issubclass(field_class, ManyToManyDescriptor)
):
break

if level == levels[-1]:
included_model = field
else:
level_model = field.get_queryset().model

if included_model is not None:
qs = qs.prefetch_related(included.replace('.', '__'))

return qs


class RelationshipView(generics.GenericAPIView):
Expand Down
6 changes: 4 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
envlist =
py{27,33,34}-django17-drf{31,32},
py{27,33,34}-django18-drf{31,32,33},
py{27,34,35}-django19-drf{31,32,33},
py{27,34,35}-django19-drf{31,32,33,34},
py{27,34,35}-django110-drf{34},

[testenv]
deps =
django17: Django>=1.7,<1.8
django18: Django>=1.8,<1.9
django19: Django>=1.9,<1.10
django110: Django>=1.10,<1.11
drf31: djangorestframework>=3.1,<3.2
drf32: djangorestframework>=3.2,<3.3
drf33: djangorestframework>=3.3,<3.4
drf34: djangorestframework>=3.4,<3.5
-r{toxinidir}/requirements-development.txt

setenv =
Expand All @@ -20,4 +23,3 @@ setenv =

commands =
py.test --basetemp={envtmpdir}