Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 49 additions & 17 deletions rest_framework_json_api/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,18 @@ def extract_relationships(fields, resource, resource_instance):
continue

source = field.source
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
# if the field is not defined on the model then we check the serializer
# and if no value is there we skip over the field completely
serializer_method = getattr(field.parent, source, None)
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

serializer_method = getattr(field.parent, source, None)
relation_type = utils.get_related_resource_type(field)

if isinstance(field, relations.HyperlinkedIdentityField):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

# special case for HyperlinkedIdentityField
relation_data = list()

Expand All @@ -123,6 +121,14 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, ResourceRelatedField):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

# special case for ResourceRelatedField
relation_data = {
'data': resource.get(field_name)
Expand All @@ -137,8 +143,15 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, (relations.PrimaryKeyRelatedField, relations.HyperlinkedRelatedField)):
relation_id = relation_instance_or_manager.pk if resource.get(field_name) else None
try:
relation = getattr(resource_instance, '%s_id' % field.source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation = serializer_method(resource_instance).pk
else:
continue

relation_id = relation if resource.get(field_name) else None
relation_data = {
'data': (
OrderedDict([('type', relation_type), ('id', encoding.force_text(relation_id))])
Expand All @@ -153,6 +166,13 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, relations.ManyRelatedField):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

if isinstance(field.child_relation, ResourceRelatedField):
# special case for ResourceRelatedField
Expand Down Expand Up @@ -193,6 +213,14 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, ListSerializer):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

relation_data = list()

serializer_data = resource.get(field_name)
Expand All @@ -210,6 +238,14 @@ def extract_relationships(fields, resource, resource_instance):
continue

if isinstance(field, ModelSerializer):
try:
relation_instance_or_manager = getattr(resource_instance, source)
except AttributeError:
if serializer_method and hasattr(serializer_method, '__call__'):
relation_instance_or_manager = serializer_method(resource_instance)
else:
continue

relation_model = field.Meta.model
relation_type = utils.format_resource_type(relation_model.__name__)

Expand Down Expand Up @@ -415,11 +451,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
if resource_name == 'errors':
return self.render_errors(data, accepted_media_type, renderer_context)

include_resources_param = request.query_params.get('include') if request else None
if include_resources_param:
included_resources = include_resources_param.split(',')
else:
included_resources = list()
included_resources = utils.get_included_resources(request)

json_api_data = data
json_api_included = list()
Expand Down
14 changes: 6 additions & 8 deletions rest_framework_json_api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,12 @@ def validate_path(serializer_class, field_path, path):
validate_path(this_included_serializer, new_included_field_path, path)

if request and view:
include_resources_param = request.query_params.get('include') if request else None
if include_resources_param:
included_resources = include_resources_param.split(',')
for included_field_name in included_resources:
included_field_path = included_field_name.split('.')
this_serializer_class = view.get_serializer_class()
# lets validate the current path
validate_path(this_serializer_class, included_field_path, included_field_name)
included_resources = utils.get_included_resources(request)
for included_field_name in included_resources:
included_field_path = included_field_name.split('.')
this_serializer_class = view.get_serializer_class()
# lets validate the current path
validate_path(this_serializer_class, included_field_path, included_field_name)

super(IncludedResourcesValidationMixin, self).__init__(*args, **kwargs)

Expand Down
9 changes: 9 additions & 0 deletions rest_framework_json_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ def get_resource_type_from_serializer(serializer):
return get_resource_type_from_model(serializer.Meta.model)


def get_included_resources(request):
included_args = list()
if request:
include_resources_param = request.query_params.get('include')
if include_resources_param:
included_args = include_resources_param.split(',')
return included_args


def get_included_serializers(serializer):
included_serializers = copy.copy(getattr(serializer, 'included_serializers', dict()))

Expand Down
44 changes: 42 additions & 2 deletions rest_framework_json_api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,55 @@
from django.db.models import Model
from django.db.models.query import QuerySet
from django.db.models.manager import Manager
from rest_framework import generics
from django.db.models.fields.related_descriptors import (
ForwardManyToOneDescriptor,
ManyToManyDescriptor,
)
from rest_framework import generics, viewsets
from rest_framework.response import Response
from rest_framework.exceptions import NotFound, MethodNotAllowed
from rest_framework.reverse import reverse
from rest_framework.serializers import Serializer

from rest_framework_json_api.exceptions import Conflict
from rest_framework_json_api.serializers import ResourceIdentifierObjectSerializer
from rest_framework_json_api.utils import get_resource_type_from_instance, OrderedDict, Hyperlink
from rest_framework_json_api.utils import (
get_resource_type_from_instance,
OrderedDict,
Hyperlink,
get_included_resources,
)


class ModelViewSet(viewsets.ModelViewSet):
def get_queryset(self, *args, **kwargs):
qs = super().get_queryset(*args, **kwargs)
included_resources = get_included_resources(self.request)

for included in included_resources:
included_model = None
levels = included.split('.')
level_model = qs.model
for level in levels:
if not hasattr(level_model, level):
break
field = getattr(level_model, level)
field_class = field.__class__
if not (
issubclass(field_class, ForwardManyToOneDescriptor)
or issubclass(field_class, ManyToManyDescriptor)
):
break

if level == levels[-1]:
included_model = field
else:
level_model = field.get_queryset().model

if included_model is not None:
qs = qs.prefetch_related(included.replace('.', '__'))

return qs


class RelationshipView(generics.GenericAPIView):
Expand Down