diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index a8e852cf..5b0bd4ca 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -12,7 +12,6 @@ from . import utils - class JSONRenderer(renderers.JSONRenderer): """ Render a JSON response per the JSON API spec: @@ -244,12 +243,24 @@ def extract_included(fields, resource, resource_instance, included_resources): if not isinstance(field, (relations.RelatedField, relations.ManyRelatedField, BaseSerializer)): continue - try: - included_resources.remove(field_name) - except ValueError: - # Skip fields not in requested included resources + # Find included resourcs for this field + matches = [ + key for key in included_resources + if key.split('.')[0] == field_name + ] + + # Skip fields without included resources + if not matches: continue + # Remove included resources and collect nested included resources + nested_included_resources = [] + for key in matches: + included_resources.remove(key) + nested_resource = '.'.join(key.split('.')[1:]) + if nested_resource: + nested_included_resources.append(nested_resource) + try: relation_instance_or_manager = getattr(resource_instance, field_name) except AttributeError: @@ -262,9 +273,6 @@ def extract_included(fields, resource, resource_instance, included_resources): serializer_method = getattr(current_serializer, field.source) relation_instance_or_manager = serializer_method(resource_instance) - new_included_resources = [key.replace('%s.' % field_name, '', 1) - for key in included_resources - if field_name == key.split('.')[0]] serializer_data = resource.get(field_name) if isinstance(field, relations.ManyRelatedField): @@ -297,7 +305,7 @@ def extract_included(fields, resource, resource_instance, included_resources): ) included_data.extend( JSONRenderer.extract_included( - serializer_fields, serializer_resource, nested_resource_instance, new_included_resources + serializer_fields, serializer_resource, nested_resource_instance, nested_included_resources ) ) @@ -315,7 +323,7 @@ def extract_included(fields, resource, resource_instance, included_resources): ) included_data.extend( JSONRenderer.extract_included( - serializer_fields, serializer_data, relation_instance_or_manager, new_included_resources + serializer_fields, serializer_data, relation_instance_or_manager, nested_included_resources ) ) @@ -497,3 +505,4 @@ def render(self, data, accepted_media_type=None, renderer_context=None): return super(JSONRenderer, self).render( render_data, accepted_media_type, renderer_context ) +