diff --git a/example/tests/integration/test_meta.py b/example/tests/integration/test_meta.py index af69a910..d854a34b 100644 --- a/example/tests/integration/test_meta.py +++ b/example/tests/integration/test_meta.py @@ -7,7 +7,39 @@ pytestmark = pytest.mark.django_db -def test_top_level_meta(blog, client): +def test_top_level_meta_for_list_view(blog, client): + + expected = { + "data": [{ + "type": "blogs", + "id": "1", + "attributes": { + "name": blog.name + }, + "meta": { + "copyright": datetime.now().year + }, + }], + 'links': { + 'first': 'http://testserver/blogs?page=1', + 'last': 'http://testserver/blogs?page=1', + 'next': None, + 'prev': None + }, + 'meta': { + 'pagination': {'count': 1, 'page': 1, 'pages': 1}, + 'apiDocs': '/docs/api/blogs' + } + } + + response = client.get(reverse("blog-list")) + content_dump = redump_json(response.content) + expected_dump = dump_json(expected) + + assert content_dump == expected_dump + + +def test_top_level_meta_for_detail_view(blog, client): expected = { "data": { diff --git a/example/tests/test_views.py b/example/tests/test_views.py index c8221077..a76df044 100644 --- a/example/tests/test_views.py +++ b/example/tests/test_views.py @@ -1,13 +1,18 @@ import json +from django.test import RequestFactory from django.utils import timezone from rest_framework.reverse import reverse from rest_framework.test import APITestCase +from rest_framework.test import force_authenticate from rest_framework_json_api.utils import format_relation_name from example.models import Blog, Entry, Comment, Author +from .. import views +from . import TestBase + class TestRelationshipView(APITestCase): def setUp(self): @@ -184,3 +189,33 @@ def test_delete_to_many_relationship_with_change(self): } response = self.client.delete(url, data=json.dumps(request_data), content_type='application/vnd.api+json') assert response.status_code == 200, response.content.decode() + + +class TestValidationErrorResponses(TestBase): + def test_if_returns_error_on_empty_post(self): + view = views.BlogViewSet.as_view({'post': 'create'}) + response = self._get_create_response("{}", view) + self.assertEqual(400, response.status_code) + expected = [{'detail': 'Received document does not contain primary data', 'status': '400', 'source': {'pointer': '/data'}}] + self.assertEqual(expected, response.data) + + def test_if_returns_error_on_missing_form_data_post(self): + view = views.BlogViewSet.as_view({'post': 'create'}) + response = self._get_create_response('{"data":{"attributes":{},"type":"blogs"}}', view) + self.assertEqual(400, response.status_code) + expected = [{'status': '400', 'detail': 'This field is required.', 'source': {'pointer': '/data/attributes/name'}}] + self.assertEqual(expected, response.data) + + def test_if_returns_error_on_bad_endpoint_name(self): + view = views.BlogViewSet.as_view({'post': 'create'}) + response = self._get_create_response('{"data":{"attributes":{},"type":"bad"}}', view) + self.assertEqual(409, response.status_code) + expected = [{'detail': "The resource object's type (bad) is not the type that constitute the collection represented by the endpoint (blogs).", 'source': {'pointer': '/data'}, 'status': '409'}] + self.assertEqual(expected, response.data) + + def _get_create_response(self, data, view): + factory = RequestFactory() + request = factory.post('/', data, content_type='application/vnd.api+json') + user = self.create_user('user', 'pass') + force_authenticate(request, user) + return view(request) diff --git a/example/views.py b/example/views.py index c8dddc50..2ce22160 100644 --- a/example/views.py +++ b/example/views.py @@ -1,15 +1,30 @@ +from rest_framework import exceptions from rest_framework import viewsets from rest_framework_json_api.views import RelationshipView from example.models import Blog, Entry, Author, Comment from example.serializers import ( BlogSerializer, EntrySerializer, AuthorSerializer, CommentSerializer) +from rest_framework_json_api.utils import format_drf_errors + +HTTP_422_UNPROCESSABLE_ENTITY = 422 + class BlogViewSet(viewsets.ModelViewSet): queryset = Blog.objects.all() serializer_class = BlogSerializer +class BlogCustomViewSet(viewsets.ModelViewSet): + queryset = Blog.objects.all() + serializer_class = BlogSerializer + + def handle_exception(self, exc): + if isinstance(exc, exceptions.ValidationError): + exc.status_code = HTTP_422_UNPROCESSABLE_ENTITY + return format_drf_errors(super(BlogCustomViewSet, self).handle_exception(exc), self.get_exception_handler_context(), exc) + + class EntryViewSet(viewsets.ModelViewSet): queryset = Entry.objects.all() resource_name = 'posts' diff --git a/rest_framework_json_api/exceptions.py b/rest_framework_json_api/exceptions.py index 935fecdb..c581bda2 100644 --- a/rest_framework_json_api/exceptions.py +++ b/rest_framework_json_api/exceptions.py @@ -1,9 +1,7 @@ -import inspect -from django.utils import six, encoding from django.utils.translation import ugettext_lazy as _ from rest_framework import status, exceptions -from rest_framework_json_api.utils import format_value +from rest_framework_json_api import utils def exception_handler(exc, context): @@ -18,63 +16,9 @@ def exception_handler(exc, context): if not response: return response - - errors = [] - # handle generic errors. ValidationError('test') in a view for example - if isinstance(response.data, list): - for message in response.data: - errors.append({ - 'detail': message, - 'source': { - 'pointer': '/data', - }, - 'status': encoding.force_text(response.status_code), - }) - # handle all errors thrown from serializers - else: - for field, error in response.data.items(): - field = format_value(field) - pointer = '/data/attributes/{}'.format(field) - # see if they passed a dictionary to ValidationError manually - if isinstance(error, dict): - errors.append(error) - elif isinstance(error, six.string_types): - classes = inspect.getmembers(exceptions, inspect.isclass) - # DRF sets the `field` to 'detail' for its own exceptions - if isinstance(exc, tuple(x[1] for x in classes)): - pointer = '/data' - errors.append({ - 'detail': error, - 'source': { - 'pointer': pointer, - }, - 'status': encoding.force_text(response.status_code), - }) - elif isinstance(error, list): - for message in error: - errors.append({ - 'detail': message, - 'source': { - 'pointer': pointer, - }, - 'status': encoding.force_text(response.status_code), - }) - else: - errors.append({ - 'detail': error, - 'source': { - 'pointer': pointer, - }, - 'status': encoding.force_text(response.status_code), - }) - - - context['view'].resource_name = 'errors' - response.data = errors - return response + return utils.format_drf_errors(response, context, exc) class Conflict(exceptions.APIException): status_code = status.HTTP_409_CONFLICT default_detail = _('Conflict.') - diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index d8a4e67a..231028c7 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -3,13 +3,16 @@ """ import copy from collections import OrderedDict +import inspect import inflection from django.conf import settings +from django.utils import encoding from django.utils import six from django.utils.module_loading import import_string as import_class_from_dotted_path from django.utils.translation import ugettext_lazy as _ from rest_framework.exceptions import APIException +from rest_framework import exceptions try: from rest_framework.serializers import ManyRelatedField @@ -249,3 +252,58 @@ def __new__(self, url, name): return ret is_hyperlink = True + + +def format_drf_errors(response, context, exc): + errors = [] + # handle generic errors. ValidationError('test') in a view for example + if isinstance(response.data, list): + for message in response.data: + errors.append({ + 'detail': message, + 'source': { + 'pointer': '/data', + }, + 'status': encoding.force_text(response.status_code), + }) + # handle all errors thrown from serializers + else: + for field, error in response.data.items(): + field = format_value(field) + pointer = '/data/attributes/{}'.format(field) + # see if they passed a dictionary to ValidationError manually + if isinstance(error, dict): + errors.append(error) + elif isinstance(error, six.string_types): + classes = inspect.getmembers(exceptions, inspect.isclass) + # DRF sets the `field` to 'detail' for its own exceptions + if isinstance(exc, tuple(x[1] for x in classes)): + pointer = '/data' + errors.append({ + 'detail': error, + 'source': { + 'pointer': pointer, + }, + 'status': encoding.force_text(response.status_code), + }) + elif isinstance(error, list): + for message in error: + errors.append({ + 'detail': message, + 'source': { + 'pointer': pointer, + }, + 'status': encoding.force_text(response.status_code), + }) + else: + errors.append({ + 'detail': error, + 'source': { + 'pointer': pointer, + }, + 'status': encoding.force_text(response.status_code), + }) + + context['view'].resource_name = 'errors' + response.data = errors + return response