From 110f8fcbc3f077f8997bc21c5a6efb2055e9ba81 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste <bleme@pm.me>
Date: Mon, 19 Oct 2020 14:16:45 +0200
Subject: [PATCH] feature: cache per user / vary parameter

---
 djangoldp/serializers.py | 52 ++++++++++++++++++++++++----------------
 1 file changed, 32 insertions(+), 20 deletions(-)

diff --git a/djangoldp/serializers.py b/djangoldp/serializers.py
index 3dad3bed..e57f3a5a 100644
--- a/djangoldp/serializers.py
+++ b/djangoldp/serializers.py
@@ -41,22 +41,30 @@ class InMemoryCache:
         self.cache = {
         }
 
-    def has(self, cache_key):
-        if cache_key in self.cache:
-            if time.time() - self.cache[cache_key]['time'] < self.max_age:
+    def has(self, cache_key, vary):
+        if cache_key in self.cache and vary in self.cache[cache_key]:
+            if time.time() - self.cache[cache_key][vary]['time'] < self.max_age:
                 return True
             else:
-                self.invalidate(cache_key)
+                self.invalidate(cache_key, vary)
         return False
 
-    def get(self, cache_key):
-        return self.cache[cache_key]['value']
+    def get(self, cache_key, vary):
+        if self.has(cache_key, vary):
+            return self.cache[cache_key][vary]['value']
+        else:
+            return None
 
-    def set(self, cache_key, value):
-        self.cache[cache_key] = {'time': time.time(), 'value': value}
+    def set(self, cache_key, vary, value):
+        if cache_key not in self.cache:
+            self.cache[cache_key] = {}
+        self.cache[cache_key][vary] = {'time': time.time(), 'value': value}
 
-    def invalidate(self, cache_key):
-        self.cache.pop(cache_key, None)
+    def invalidate(self, cache_key, vary=None):
+        if vary is None:
+            self.cache.pop(cache_key, None)
+        else:
+            self.cache[cache_key].pop(vary, None)
 
 
 class LDListMixin:
@@ -99,6 +107,8 @@ class LDListMixin:
 
         parent_model = None
 
+        cache_vary = str(self.context['request'].user)
+
         if isinstance(value, QuerySet):
             value = list(value)
 
@@ -107,8 +117,8 @@ class LDListMixin:
                 self.id = '{}{}{}'.format(settings.BASE_URL, Model.resource(parent_model), self.id)
 
             cache_key = self.id
-            if self.with_cache and self.to_representation_cache.has(cache_key):
-                return self.to_representation_cache.get(cache_key)
+            if self.with_cache and self.to_representation_cache.has(cache_key, cache_vary):
+                return self.to_representation_cache.get(cache_key, cache_vary)
 
             filtered_values = value
             container_permissions = Model.get_permissions(child_model, self.context, ['view', 'add'])
@@ -124,8 +134,8 @@ class LDListMixin:
                 self.id = '{}{}{}'.format(settings.BASE_URL, Model.resource(parent_model), self.id)
 
             cache_key = self.id
-            if self.with_cache and self.to_representation_cache.has(cache_key):
-                return self.to_representation_cache.get(cache_key)
+            if self.with_cache and self.to_representation_cache.has(cache_key, cache_vary):
+                return self.to_representation_cache.get(cache_key, cache_vary)
 
             # remove objects from the list which I don't have permission to view
             filtered_values = list(
@@ -135,13 +145,13 @@ class LDListMixin:
             container_permissions.extend(
                 Model.get_permissions(parent_model, self.context, ['view']))
 
-        self.to_representation_cache.set(self.id, {'@id': self.id,
+        self.to_representation_cache.set(self.id, cache_vary, {'@id': self.id,
                                                    '@type': 'ldp:Container',
                                                    'ldp:contains': super().to_representation(filtered_values),
                                                    'permissions': container_permissions
                                                    })
 
-        return self.to_representation_cache.get(self.id)
+        return self.to_representation_cache.get(self.id, cache_vary)
 
     def get_attribute(self, instance):
         parent_id_field = self.parent.fields[self.parent.url_field_name]
@@ -324,12 +334,13 @@ class LDPSerializer(HyperlinkedModelSerializer):
         if self.context['request'].method == 'GET' and Model.is_external(obj):
             return {'@id': obj.urlid}
 
+        cache_vary = str(self.context['request'].user)
         if self.with_cache and hasattr(obj, 'urlid'):
-            if self.to_representation_cache.has(obj.urlid):
-                data = self.to_representation_cache.get(obj.urlid)
+            if self.to_representation_cache.has(obj.urlid, cache_vary):
+                data = self.to_representation_cache.get(obj.urlid, cache_vary)
             else:
                 data = super().to_representation(obj)
-                self.to_representation_cache.set(obj.urlid, data)
+                self.to_representation_cache.set(obj.urlid, cache_vary, data)
         else:
             data = super().to_representation(obj)
 
@@ -594,7 +605,8 @@ class LDPSerializer(HyperlinkedModelSerializer):
     def create(self, validated_data):
         with transaction.atomic():
             instance = self.internal_create(validated_data, model=self.Meta.model)
-            LDListMixin.to_representation_cache.invalidate('{}{}'.format(settings.BASE_URL, Model.resource(self.Meta.model)))
+            LDListMixin.to_representation_cache.invalidate(
+                '{}{}'.format(settings.BASE_URL, Model.resource(self.Meta.model)))
             self.attach_related_object(instance, validated_data)
 
         return instance
-- 
GitLab