|  | 
|  | 1 | +from collections import Iterable | 
|  | 2 | + | 
| 1 | 3 | from django.core.exceptions import ImproperlyConfigured | 
| 2 | 4 | from django.db.models import Model | 
| 3 | 5 | from django.db.models.fields.related_descriptors import ( | 
| @@ -98,6 +100,54 @@ def get_queryset(self, *args, **kwargs): | 
| 98 | 100 |         return qs | 
| 99 | 101 | 
 | 
| 100 | 102 | 
 | 
|  | 103 | +class RelatedMixin(object): | 
|  | 104 | +    """ | 
|  | 105 | +    This mixin handles all related entities, whose Serializers are declared in "related_serializers" | 
|  | 106 | +    """ | 
|  | 107 | +    related_serializers = {} | 
|  | 108 | +    field_name_mapping = {} | 
|  | 109 | + | 
|  | 110 | +    def retrieve_related(self, request, *args, **kwargs): | 
|  | 111 | +        serializer_kwargs = {} | 
|  | 112 | +        instance = self.get_related_instance() | 
|  | 113 | + | 
|  | 114 | +        if hasattr(instance, 'all'): | 
|  | 115 | +            instance = instance.all() | 
|  | 116 | + | 
|  | 117 | +        if callable(instance): | 
|  | 118 | +            instance = instance() | 
|  | 119 | + | 
|  | 120 | +        if instance is None: | 
|  | 121 | +            return Response(data=None) | 
|  | 122 | + | 
|  | 123 | +        if isinstance(instance, Iterable): | 
|  | 124 | +            serializer_kwargs['many'] = True | 
|  | 125 | + | 
|  | 126 | +        serializer = self.get_serializer(instance, **serializer_kwargs) | 
|  | 127 | +        return Response(serializer.data) | 
|  | 128 | + | 
|  | 129 | +    def get_serializer_class(self): | 
|  | 130 | +        if 'related_field' in self.kwargs: | 
|  | 131 | +            field_name = self.get_related_field_name() | 
|  | 132 | +            _class = self.related_serializers.get(field_name, None) | 
|  | 133 | +            if _class is None: | 
|  | 134 | +                raise NotFound | 
|  | 135 | +            return _class | 
|  | 136 | +        return super(RelatedMixin, self).get_serializer_class() | 
|  | 137 | + | 
|  | 138 | +    def get_related_field_name(self): | 
|  | 139 | +        field_name = self.kwargs['related_field'] | 
|  | 140 | +        if field_name in self.field_name_mapping: | 
|  | 141 | +            return self.field_name_mapping[field_name] | 
|  | 142 | +        return field_name | 
|  | 143 | + | 
|  | 144 | +    def get_related_instance(self): | 
|  | 145 | +        try: | 
|  | 146 | +            return getattr(self.get_object(), self.get_related_field_name()) | 
|  | 147 | +        except AttributeError: | 
|  | 148 | +            raise NotFound | 
|  | 149 | + | 
|  | 150 | + | 
| 101 | 151 | class ModelViewSet(AutoPrefetchMixin, PrefetchForIncludesHelperMixin, viewsets.ModelViewSet): | 
| 102 | 152 |     pass | 
| 103 | 153 | 
 | 
|  | 
0 commit comments