def get_prefetch_fields(model, serializer, depth, prepend_string=''):
    '''
    Build the set of relation lookups to hand to ``queryset.prefetch_related`` so that
    joined resources are fetched up front (to speed up nested serialization).
    The original author reports speed-ups of up to a factor of 2 for ModelViewSet and LDPViewSet.

    :param model: the model to be analysed
    :param serializer: an LDPSerializer instance. Its field names decide which to-many relations are
        followed. The same field names are reused for every nested model (see the TODO below)
    :param depth: recursion budget (should be set to the configured depth of the ViewSet). A serialized
        to-many relation is descended into while ``depth >= 0``, so ``depth=0`` still collects the
        immediate foreign keys of each serialized to-many relation
    :param prepend_string: lookup prefix (e.g. ``'team__'``). Should be left at the default by callers
    :return: set of lookup strings: every to-one relation of the model, every serialized to-many relation,
        and (recursively, until the depth is exhausted) the same for the models behind those to-many relations
    '''
    # TODO: dynamically generating serializer fields is necessary to retrieve many-to-many fields at depth > 0,
    #  but the _all_ default has issues detecting reverse many-to-many fields
    # meta_args = {'model': model, 'depth': 0, 'fields': Model.get_meta(model, 'serializer_fields', '__all__')}
    # meta_class = type('Meta', (), meta_args)
    # serializer = (type(LDPSerializer)('TestSerializer', (LDPSerializer,), {'Meta': meta_class}))()

    # the serializer is the same object at every recursion level, so its field names are computed once
    # here instead of being rebuilt on each recursive call
    serializer_fields = set(serializer.get_fields())
    return _collect_prefetch_fields(model, serializer_fields, depth, prepend_string)


def _collect_prefetch_fields(model, serializer_fields, depth, prefix):
    '''Recursive worker for get_prefetch_fields: collects the prefixed relation lookups for one model.'''
    fields = set()

    # we are only interested in foreign keys (and many-to-many relationships)
    relations = model_meta.get_field_info(model).relations
    for field_name, relation_info in relations.items():
        lookup = prefix + field_name

        # foreign keys can be added without fuss
        if not relation_info.to_many:
            fields.add(lookup)
            continue

        # to-many relations are only of interest when they are actually serialized
        if field_name not in serializer_fields:
            continue
        fields.add(lookup)

        # and they should also have their immediate foreign keys prefetched if depth not reached
        if depth >= 0:
            fields |= _collect_prefetch_fields(
                relation_info.related_model, serializer_fields, depth - 1, lookup + '__')

    return fields
class LDPViewSet(TestCase):
    '''Checks the lookups produced by get_prefetch_fields for the User, Circle and Project test models.'''

    # serializer field lists and the prefetch lookups expected for them at depth 0
    user_serializer_fields = [
        '@id', 'username', 'first_name', 'last_name', 'email',
        'userprofile', 'conversation_set', 'circle_set', 'projects',
    ]
    user_expected_fields = {
        'userprofile', 'conversation_set', 'circle_set', 'projects',
        'circle_set__owner',
        'conversation_set__author_user', 'conversation_set__peer_user',
    }
    project_serializer_fields = ['@id', 'description', 'team']
    project_expected_fields = {'team', 'team__userprofile'}

    def _get_serializer(self, model, depth, fields):
        '''Create an instance of a throwaway LDPSerializer subclass whose Meta carries the given settings.'''
        meta_class = type('Meta', (), {'model': model, 'depth': depth, 'fields': fields})
        serializer_class = type(LDPSerializer)('TestSerializer', (LDPSerializer,), {'Meta': meta_class})
        return serializer_class()

    def _assert_prefetch_fields(self, model, serializer_fields, expected_fields, depth=0):
        '''Run get_prefetch_fields for the model and compare the result with the expected lookups.'''
        serializer = self._get_serializer(model, depth, serializer_fields)
        self.assertEqual(expected_fields, get_prefetch_fields(model, serializer, depth))

    # User: to-one relations, serialized to-many relations and their immediate foreign keys
    def test_get_prefetch_fields_user(self):
        self._assert_prefetch_fields(User, self.user_serializer_fields, self.user_expected_fields)

    # Circle: foreign key, reverse relation and many-to-many field
    def test_get_prefetch_fields_circle(self):
        self._assert_prefetch_fields(
            Circle,
            ['@id', 'name', 'description', 'owner', 'members', 'team'],
            {'owner', 'members', 'team', 'members__user', 'members__circle', 'team__userprofile'},
        )

    # Project: a single many-to-many field
    def test_get_prefetch_fields_project(self):
        self._assert_prefetch_fields(Project, self.project_serializer_fields, self.project_expected_fields)

    # TODO: dynamically generating serializer fields is necessary to retrieve many-to-many fields at depth > 0,
    #  but the _all_ default has issues detecting reverse many-to-many fields
    '''def test_get_prefetch_fields_depth_1(self):
        model = Project
        depth = 2
        serializer_fields = self.project_serializer_fields
        user_expected = set(['team__' + x for x in self.user_expected_fields])
        expected_fields = self.project_expected_fields.union(user_expected)
        serializer = self._get_serializer(model, depth, serializer_fields)
        result = get_prefetch_fields(model, serializer, depth)
        self.assertEqual(expected_fields, result)'''
def get_queryset(self, *args, **kwargs):
    '''
    Return the base queryset for this view with related resources prefetched.

    The queryset is built from ``self.model`` when it is set, otherwise from the parent class.
    The prefetch lookups are computed on first use and stored in ``self.prefetch_fields``,
    so a value already assigned to that attribute is used as-is.
    '''
    if self.model:
        queryset = self.model.objects.all()
    else:
        # zero-argument super(), as in __init__: the previous super(LDPView, self) named a class
        # other than this one
        queryset = super().get_queryset(*args, **kwargs)

    if self.prefetch_fields is None:
        # self.model may be unset when the queryset came from the parent class; analyse the
        # queryset's own model in that case instead of passing None further down
        model = self.model or queryset.model
        depth = getattr(self, 'depth', Model.get_meta(model, 'depth', 0))
        self.prefetch_fields = get_prefetch_fields(model, self.get_serializer(), depth)

    return queryset.prefetch_related(*self.prefetch_fields)