From 0c743eccf9586ae17c613653b5ba25ca0db470f3 Mon Sep 17 00:00:00 2001
From: Calum Mackervoy <c.mackervoy@gmail.com>
Date: Wed, 8 Apr 2020 20:52:36 +0100
Subject: [PATCH] container views return local objects only

---
 djangoldp/tests/tests_get.py     | 6 ++++++
 djangoldp/tests/tests_sources.py | 2 ++
 djangoldp/views.py               | 5 ++++-
 3 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/djangoldp/tests/tests_get.py b/djangoldp/tests/tests_get.py
index e69f119f..adda7c7b 100644
--- a/djangoldp/tests/tests_get.py
+++ b/djangoldp/tests/tests_get.py
@@ -23,9 +23,12 @@ class TestGET(APITestCase):
 
     def test_get_container(self):
         Post.objects.create(content="content")
+        # federated object - should not be returned in the container view
+        Post.objects.create(content="federated", urlid="https://external.com/posts/1/")
         response = self.client.get('/posts/', content_type='application/ld+json')
         self.assertEqual(response.status_code, 200)
         self.assertIn('permissions', response.data)
+        self.assertEquals(1, len(response.data['ldp:contains']))
         self.assertEquals(2, len(response.data['permissions']))  # read and add
 
         Invoice.objects.create(title="content")
@@ -38,6 +41,7 @@ class TestGET(APITestCase):
         Post.objects.all().delete()
         response = self.client.get('/posts/', content_type='application/ld+json')
         self.assertEqual(response.status_code, 200)
+        self.assertEquals(0, len(response.data['ldp:contains']))
 
     def test_get_filtered_fields(self):
         skill = Skill.objects.create(title="Java", obligatoire="ok", slug="1")
@@ -79,6 +83,8 @@ class TestGET(APITestCase):
     def test_get_nested(self):
         invoice = Invoice.objects.create(title="invoice")
         batch = Batch.objects.create(invoice=invoice, title="batch")
+        distant_batch = Batch.objects.create(invoice=invoice, title="distant", urlid="https://external.com/batch/1/")
         response = self.client.get('/invoices/{}/batches/'.format(invoice.pk), content_type='application/ld+json')
         self.assertEqual(response.status_code, 200)
         self.assertEquals(response.data['@id'], 'http://happy-dev.fr/invoices/{}/batches/'.format(invoice.pk))
+        self.assertEquals(len(response.data['ldp:contains']), 2)
diff --git a/djangoldp/tests/tests_sources.py b/djangoldp/tests/tests_sources.py
index 94b58df8..374d9821 100644
--- a/djangoldp/tests/tests_sources.py
+++ b/djangoldp/tests/tests_sources.py
@@ -17,8 +17,10 @@ class TestSource(APITestCase):
         response = self.client.get('/sources/{}/'.format(source.federation), content_type='application/ld+json')
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.data['@id'], 'http://happy-dev.fr/sources/source_name/')
+        self.assertEqual(len(response.data['ldp:contains']), 1)
 
     def test_get_empty_resource(self):
         response = self.client.get('/sources/{}/'.format('unknown'), content_type='application/ld+json')
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.data['@id'], 'http://happy-dev.fr/sources/unknown/')
+        self.assertEqual(len(response.data['ldp:contains']), 0)
diff --git a/djangoldp/views.py b/djangoldp/views.py
index 16153b9f..720f4cef 100644
--- a/djangoldp/views.py
+++ b/djangoldp/views.py
@@ -19,6 +19,7 @@ from rest_framework.viewsets import ModelViewSet
 from djangoldp.endpoints.webfinger import WebFingerEndpoint, WebFingerError
 from djangoldp.models import LDPSource, Model
 from djangoldp.permissions import LDPPermissions
+from djangoldp.filters import LocalObjectOnContainerPathBackend
 
 
 get_user_model()._meta.rdf_context = {"get_full_name": "rdfs:label"}
@@ -114,6 +115,7 @@ class LDPViewSet(LDPViewSetGenerator):
     renderer_classes = (JSONLDRenderer,)
     parser_classes = (JSONLDParser,)
     authentication_classes = (NoCSRFAuthentication,)
+    filter_backends = [LocalObjectOnContainerPathBackend]
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -122,7 +124,7 @@ class LDPViewSet(LDPViewSetGenerator):
         if self.permission_classes:
             for p in self.permission_classes:
                 if hasattr(p, 'filter_class') and p.filter_class:
-                    self.filter_backends = p.filter_class
+                    self.filter_backends.append(p.filter_class)
 
         self.serializer_class = self.build_read_serializer()
         self.write_serializer_class = self.build_write_serializer()
@@ -301,6 +303,7 @@ class LDPNestedViewSet(LDPViewSet):
 class LDPSourceViewSet(LDPViewSet):
     model = LDPSource
     federation = None
+    filter_backends = []
 
     def get_queryset(self, *args, **kwargs):
         return super().get_queryset(*args, **kwargs).filter(federation=self.kwargs['federation'])
-- 
GitLab