diff --git a/.travis.yml b/.travis.yml index ec66ad00..9d00539e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,18 +13,27 @@ env: - TOXENV=py27-django18-drf31 - TOXENV=py27-django18-drf32 - TOXENV=py27-django18-drf33 + - TOXENV=py27-django18-drf34 - TOXENV=py33-django18-drf31 - TOXENV=py33-django18-drf32 - TOXENV=py33-django18-drf33 + - TOXENV=py33-django18-drf34 - TOXENV=py34-django18-drf31 - TOXENV=py34-django18-drf32 - TOXENV=py34-django18-drf33 + - TOXENV=py34-django18-drf34 - TOXENV=py27-django19-drf31 - TOXENV=py27-django19-drf32 - TOXENV=py27-django19-drf33 + - TOXENV=py27-django19-drf34 - TOXENV=py34-django19-drf31 - TOXENV=py34-django19-drf32 - TOXENV=py34-django19-drf33 + - TOXENV=py34-django19-drf34 - TOXENV=py35-django19-drf31 - TOXENV=py35-django19-drf32 - TOXENV=py35-django19-drf33 + - TOXENV=py35-django19-drf34 + - TOXENV=py27-django110-drf34 + - TOXENV=py34-django110-drf34 + - TOXENV=py35-django110-drf34 diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 83332566..581d6a81 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -87,20 +87,18 @@ def extract_relationships(fields, resource, resource_instance): continue source = field.source - try: - relation_instance_or_manager = getattr(resource_instance, source) - except AttributeError: - # if the field is not defined on the model then we check the serializer - # and if no value is there we skip over the field completely - serializer_method = getattr(field.parent, source, None) - if serializer_method and hasattr(serializer_method, '__call__'): - relation_instance_or_manager = serializer_method(resource_instance) - else: - continue - + serializer_method = getattr(field.parent, source, None) relation_type = utils.get_related_resource_type(field) if isinstance(field, relations.HyperlinkedIdentityField): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue + # special case for HyperlinkedIdentityField relation_data = list() @@ -124,6 +122,14 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, ResourceRelatedField): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue + # special case for ResourceRelatedField relation_data = { 'data': resource.get(field_name) @@ -138,8 +144,15 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, (relations.PrimaryKeyRelatedField, relations.HyperlinkedRelatedField)): - relation_id = relation_instance_or_manager.pk if resource.get(field_name) else None + try: + relation = getattr(resource_instance, '%s_id' % field.source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation = serializer_method(resource_instance).pk + else: + continue + relation_id = relation if resource.get(field_name) else None relation_data = { 'data': ( OrderedDict([('type', relation_type), ('id', encoding.force_text(relation_id))]) @@ -154,6 +167,13 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, relations.ManyRelatedField): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue if isinstance(field.child_relation, ResourceRelatedField): # special case for ResourceRelatedField @@ -194,6 +214,14 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, ListSerializer): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue + relation_data = list() serializer_data = resource.get(field_name) @@ -211,6 +239,14 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, ModelSerializer): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue + relation_model = field.Meta.model relation_type = utils.format_resource_type(relation_model.__name__) @@ -429,12 +465,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): serializer = getattr(serializer_data, 'serializer', None) - # Build a list of included resources - include_resources_param = request.query_params.get('include') if request else None - if include_resources_param: - included_resources = include_resources_param.split(',') - else: - included_resources = utils.get_default_included_resources_from_serializer(serializer) + included_resources = utils.get_included_resources(request, serializer) if serializer is not None: @@ -472,7 +503,6 @@ def render(self, data, accepted_media_type=None, renderer_context=None): if included: json_api_included.extend(included) - # Make sure we render data in a specific order render_data = OrderedDict() diff --git a/rest_framework_json_api/serializers.py b/rest_framework_json_api/serializers.py index 24a73015..917ae98c 100644 --- a/rest_framework_json_api/serializers.py +++ b/rest_framework_json_api/serializers.py @@ -6,7 +6,7 @@ from rest_framework_json_api.relations import ResourceRelatedField from rest_framework_json_api.utils import ( get_resource_type_from_model, get_resource_type_from_instance, - get_resource_type_from_serializer, get_included_serializers) + get_resource_type_from_serializer, get_included_serializers, get_included_resources) class ResourceIdentifierObjectSerializer(BaseSerializer): @@ -90,14 +90,12 @@ def validate_path(serializer_class, field_path, path): validate_path(this_included_serializer, new_included_field_path, path) if request and view: - include_resources_param = request.query_params.get('include') if request else None - if include_resources_param: - included_resources = include_resources_param.split(',') - for included_field_name in included_resources: - included_field_path = included_field_name.split('.') - this_serializer_class = view.get_serializer_class() - # lets validate the current path - validate_path(this_serializer_class, included_field_path, included_field_name) + included_resources = get_included_resources(request) + for included_field_name in included_resources: + included_field_path = included_field_name.split('.') + this_serializer_class = view.get_serializer_class() + # lets validate the current path + validate_path(this_serializer_class, included_field_path, included_field_name) super(IncludedResourcesValidationMixin, self).__init__(*args, **kwargs) diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index c532e78e..b09f5c68 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -235,6 +235,15 @@ def get_resource_type_from_serializer(serializer): return get_resource_type_from_model(serializer.Meta.model) +def get_included_resources(request, serializer=None): + """ Build a list of included resources. """ + include_resources_param = request.query_params.get('include') if request else None + if include_resources_param: + return include_resources_param.split(',') + else: + return get_default_included_resources_from_serializer(serializer) + + def get_default_included_resources_from_serializer(serializer): try: return list(serializer.JSONAPIMeta.included_resources) diff --git a/rest_framework_json_api/views.py b/rest_framework_json_api/views.py index 9d4a649d..77bf9b91 100644 --- a/rest_framework_json_api/views.py +++ b/rest_framework_json_api/views.py @@ -4,7 +4,17 @@ from django.db.models import Model from django.db.models.query import QuerySet from django.db.models.manager import Manager -from rest_framework import generics +if django.VERSION < (1, 9): + from django.db.models.fields.related import ( + ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor, + ManyRelatedObjectsDescriptor as ManyToManyDescriptor, + ) +else: + from django.db.models.fields.related_descriptors import ( + ForwardManyToOneDescriptor, + ManyToManyDescriptor, + ) +from rest_framework import generics, viewsets from rest_framework.response import Response from rest_framework.exceptions import NotFound, MethodNotAllowed from rest_framework.reverse import reverse @@ -12,7 +22,43 @@ from rest_framework_json_api.exceptions import Conflict from rest_framework_json_api.serializers import ResourceIdentifierObjectSerializer -from rest_framework_json_api.utils import get_resource_type_from_instance, OrderedDict, Hyperlink +from rest_framework_json_api.utils import ( + get_resource_type_from_instance, + OrderedDict, + Hyperlink, + get_included_resources, +) + + +class ModelViewSet(viewsets.ModelViewSet): + def get_queryset(self, *args, **kwargs): + qs = super().get_queryset(*args, **kwargs) + included_resources = get_included_resources(self.request) + + for included in included_resources: + included_model = None + levels = included.split('.') + level_model = qs.model + for level in levels: + if not hasattr(level_model, level): + break + field = getattr(level_model, level) + field_class = field.__class__ + if not ( + issubclass(field_class, ForwardManyToOneDescriptor) + or issubclass(field_class, ManyToManyDescriptor) + ): + break + + if level == levels[-1]: + included_model = field + else: + level_model = field.get_queryset().model + + if included_model is not None: + qs = qs.prefetch_related(included.replace('.', '__')) + + return qs class RelationshipView(generics.GenericAPIView): diff --git a/tox.ini b/tox.ini index 438981dd..0e1c4062 100644 --- a/tox.ini +++ b/tox.ini @@ -2,16 +2,19 @@ envlist = py{27,33,34}-django17-drf{31,32}, py{27,33,34}-django18-drf{31,32,33}, - py{27,34,35}-django19-drf{31,32,33}, + py{27,34,35}-django19-drf{31,32,33,34}, + py{27,34,35}-django110-drf{34}, [testenv] deps = django17: Django>=1.7,<1.8 django18: Django>=1.8,<1.9 django19: Django>=1.9,<1.10 + django110: Django>=1.10,<1.11 drf31: djangorestframework>=3.1,<3.2 drf32: djangorestframework>=3.2,<3.3 drf33: djangorestframework>=3.3,<3.4 + drf34: djangorestframework>=3.4,<3.5 -r{toxinidir}/requirements-development.txt setenv = @@ -20,4 +23,3 @@ setenv = commands = py.test --basetemp={envtmpdir} -