diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 1c66c927..af3625fb 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -86,20 +86,18 @@ def extract_relationships(fields, resource, resource_instance): continue source = field.source - try: - relation_instance_or_manager = getattr(resource_instance, source) - except AttributeError: - # if the field is not defined on the model then we check the serializer - # and if no value is there we skip over the field completely - serializer_method = getattr(field.parent, source, None) - if serializer_method and hasattr(serializer_method, '__call__'): - relation_instance_or_manager = serializer_method(resource_instance) - else: - continue - + serializer_method = getattr(field.parent, source, None) relation_type = utils.get_related_resource_type(field) if isinstance(field, relations.HyperlinkedIdentityField): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue + # special case for HyperlinkedIdentityField relation_data = list() @@ -123,6 +121,14 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, ResourceRelatedField): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue + # special case for ResourceRelatedField relation_data = { 'data': resource.get(field_name) @@ -137,8 +143,15 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, (relations.PrimaryKeyRelatedField, relations.HyperlinkedRelatedField)): - relation_id = relation_instance_or_manager.pk if resource.get(field_name) else None + try: + relation = getattr(resource_instance, '%s_id' % field.source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation = serializer_method(resource_instance).pk + else: + continue + relation_id = relation if resource.get(field_name) else None relation_data = { 'data': ( OrderedDict([('type', relation_type), ('id', encoding.force_text(relation_id))]) @@ -153,6 +166,13 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, relations.ManyRelatedField): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue if isinstance(field.child_relation, ResourceRelatedField): # special case for ResourceRelatedField @@ -193,6 +213,14 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, ListSerializer): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue + relation_data = list() serializer_data = resource.get(field_name) @@ -210,6 +238,14 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, ModelSerializer): + try: + relation_instance_or_manager = getattr(resource_instance, source) + except AttributeError: + if serializer_method and hasattr(serializer_method, '__call__'): + relation_instance_or_manager = serializer_method(resource_instance) + else: + continue + relation_model = field.Meta.model relation_type = utils.format_resource_type(relation_model.__name__) @@ -415,11 +451,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): if resource_name == 'errors': return self.render_errors(data, accepted_media_type, renderer_context) - include_resources_param = request.query_params.get('include') if request else None - if include_resources_param: - included_resources = include_resources_param.split(',') - else: - included_resources = list() + included_resources = utils.get_included_resources(request) json_api_data = data json_api_included = list() diff --git a/rest_framework_json_api/serializers.py b/rest_framework_json_api/serializers.py index 953c4437..15672e51 100644 --- a/rest_framework_json_api/serializers.py +++ b/rest_framework_json_api/serializers.py @@ -89,14 +89,12 @@ def validate_path(serializer_class, field_path, path): validate_path(this_included_serializer, new_included_field_path, path) if request and view: - include_resources_param = request.query_params.get('include') if request else None - if include_resources_param: - included_resources = include_resources_param.split(',') - for included_field_name in included_resources: - included_field_path = included_field_name.split('.') - this_serializer_class = view.get_serializer_class() - # lets validate the current path - validate_path(this_serializer_class, included_field_path, included_field_name) + included_resources = utils.get_included_resources(request) + for included_field_name in included_resources: + included_field_path = included_field_name.split('.') + this_serializer_class = view.get_serializer_class() + # lets validate the current path + validate_path(this_serializer_class, included_field_path, included_field_name) super(IncludedResourcesValidationMixin, self).__init__(*args, **kwargs) diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index 261640c6..c2754944 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -232,6 +232,15 @@ def get_resource_type_from_serializer(serializer): return get_resource_type_from_model(serializer.Meta.model) +def get_included_resources(request): + included_args = list() + if request: + include_resources_param = request.query_params.get('include') + if include_resources_param: + included_args = include_resources_param.split(',') + return included_args + + def get_included_serializers(serializer): included_serializers = copy.copy(getattr(serializer, 'included_serializers', dict())) diff --git a/rest_framework_json_api/views.py b/rest_framework_json_api/views.py index 4b6e631a..26cebbec 100644 --- a/rest_framework_json_api/views.py +++ b/rest_framework_json_api/views.py @@ -4,7 +4,11 @@ from django.db.models import Model from django.db.models.query import QuerySet from django.db.models.manager import Manager -from rest_framework import generics +from django.db.models.fields.related_descriptors import ( + ForwardManyToOneDescriptor, + ManyToManyDescriptor, +) +from rest_framework import generics, viewsets from rest_framework.response import Response from rest_framework.exceptions import NotFound, MethodNotAllowed from rest_framework.reverse import reverse @@ -12,7 +16,43 @@ from rest_framework_json_api.exceptions import Conflict from rest_framework_json_api.serializers import ResourceIdentifierObjectSerializer -from rest_framework_json_api.utils import get_resource_type_from_instance, OrderedDict, Hyperlink +from rest_framework_json_api.utils import ( + get_resource_type_from_instance, + OrderedDict, + Hyperlink, + get_included_resources, +) + + +class ModelViewSet(viewsets.ModelViewSet): + def get_queryset(self, *args, **kwargs): + qs = super().get_queryset(*args, **kwargs) + included_resources = get_included_resources(self.request) + + for included in included_resources: + included_model = None + levels = included.split('.') + level_model = qs.model + for level in levels: + if not hasattr(level_model, level): + break + field = getattr(level_model, level) + field_class = field.__class__ + if not ( + issubclass(field_class, ForwardManyToOneDescriptor) + or issubclass(field_class, ManyToManyDescriptor) + ): + break + + if level == levels[-1]: + included_model = field + else: + level_model = field.get_queryset().model + + if included_model is not None: + qs = qs.prefetch_related(included.replace('.', '__')) + + return qs class RelationshipView(generics.GenericAPIView):