From 2b02e3e96875e455afb8cda1574a1d1763d45cb0 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste <bleme@pm.me>
Date: Tue, 12 Feb 2019 18:21:39 +0100
Subject: [PATCH] fix: manage simple nested field (non list)

---
 djangoldp/serializers.py | 92 +++++++++++++++++++++++++++++++++-------
 1 file changed, 76 insertions(+), 16 deletions(-)

diff --git a/djangoldp/serializers.py b/djangoldp/serializers.py
index 82c031d8..4f8ead62 100644
--- a/djangoldp/serializers.py
+++ b/djangoldp/serializers.py
@@ -1,3 +1,4 @@
+from collections import OrderedDict, Mapping
 from urllib import parse
 
 from django.core.exceptions import ImproperlyConfigured
@@ -5,8 +6,13 @@ from django.core.urlresolvers import get_resolver, resolve, get_script_prefix, R
 from django.utils.datastructures import MultiValueDictKeyError
 from django.utils.encoding import uri_to_iri
 from guardian.shortcuts import get_perms
+from django.core.exceptions import ValidationError as DjangoValidationError
+from rest_framework.exceptions import ValidationError
+from rest_framework.fields import SkipField
+from rest_framework.fields import get_error_detail, set_value
 from rest_framework.relations import HyperlinkedRelatedField, ManyRelatedField, MANY_RELATION_KWARGS
 from rest_framework.serializers import HyperlinkedModelSerializer, ListSerializer
+from rest_framework.settings import api_settings
 from rest_framework.utils.field_mapping import get_nested_relation_kwargs
 from rest_framework.utils.serializer_helpers import ReturnDict
 
@@ -40,16 +46,16 @@ class LDListMixin:
             obj = next(filter(lambda o: part_id in o['@id'], object_list))
             list = super().get_value(obj);
             try:
-                list= list['ldp:contains']
+                list = list['ldp:contains']
             except KeyError:
                 pass
 
             if isinstance(list, dict):
                 list = [list]
 
-            ret=[]
+            ret = []
             for item in list:
-                full_item=None
+                full_item = None
                 try:
                     full_item = next(filter(lambda o: item['@id'] == o['@id'], object_list))
                 except StopIteration:
@@ -77,7 +83,7 @@ class ContainerSerializer(LDListMixin, ListSerializer):
     def to_internal_value(self, data):
         try:
             return super().to_internal_value(data['@id'])
-        except:
+        except (KeyError, TypeError):
             return super().to_internal_value(data)
 
 
@@ -115,7 +121,7 @@ class JsonLdRelatedField(JsonLdField):
     def to_internal_value(self, data):
         try:
             return super().to_internal_value(data['@id'])
-        except KeyError:
+        except (KeyError, TypeError):
             return super().to_internal_value(data)
 
     @classmethod
@@ -211,6 +217,38 @@ class LDPSerializer(HyperlinkedModelSerializer):
 
             def to_internal_value(self, data):
                 if self.url_field_name in data:
+                    if not isinstance(data, Mapping):
+                        message = self.error_messages['invalid'].format(
+                            datatype=type(data).__name__
+                        )
+                        raise ValidationError({
+                            api_settings.NON_FIELD_ERRORS_KEY: [message]
+                        }, code='invalid')
+
+                    ret = OrderedDict()
+                    errors = OrderedDict()
+                    fields = list(filter(lambda x: x.field_name in data, self._writable_fields))
+
+                    for field in fields:
+                        validate_method = getattr(self, 'validate_' + field.field_name, None)
+                        primitive_value = field.get_value(data)
+                        try:
+                            validated_value = field.run_validation(primitive_value)
+                            if validate_method is not None:
+                                validated_value = validate_method(validated_value)
+                        except ValidationError as exc:
+                            errors[field.field_name] = exc.detail
+                        except DjangoValidationError as exc:
+                            errors[field.field_name] = get_error_detail(exc)
+                        except SkipField:
+                            pass
+                        else:
+                            set_value(ret, field.source_attrs, validated_value)
+
+                    if errors:
+                        raise ValidationError(errors)
+
+
                     uri = data[self.url_field_name]
                     http_prefix = uri.startswith(('http:', 'https:'))
 
@@ -222,17 +260,14 @@ class LDPSerializer(HyperlinkedModelSerializer):
 
                     try:
                         match = resolve(uri_to_iri(uri))
-                        instance = self.Meta.model.objects.get(pk=match.kwargs['pk'])
-                        for key in self.data:
-                            if not key in data:
-                                data[key] = getattr(instance, key)
+                        ret['pk'] = match.kwargs['pk']
                     except Resolver404:
                         pass
 
-                return super().to_internal_value(data)
+                    return ret
+                else:
+                    return super().to_internal_value(data)
 
-            def get_value(self, dictionary):
-                return super().get_value(dictionary)
 
         kwargs = get_nested_relation_kwargs(relation_info)
         kwargs['read_only'] = False
@@ -249,7 +284,25 @@ class LDPSerializer(HyperlinkedModelSerializer):
         return ContainerSerializer(*args, **kwargs)
 
     def get_value(self, dictionary):
-        return super().get_value(dictionary)
+        try:
+            object_list = dictionary["@graph"]
+            view_name = '{}-list'.format(self.parent.Meta.model._meta.object_name.lower())
+            part_id = '/{}'.format(get_resolver().reverse_dict[view_name][0][0][0],
+                                   self.parent.instance.pk)
+            obj = next(filter(lambda o: part_id in o[self.url_field_name], object_list))
+            item = super().get_value(obj)
+            full_item = None
+            try:
+                full_item = next(filter(lambda o: item['@id'] == o['@id'], object_list))
+            except StopIteration:
+                pass
+            if full_item is None:
+                return item
+            else:
+                return full_item
+
+        except KeyError:
+            return super().get_value(dictionary)
 
     def create(self, validated_data):
         return self.internal_create(validated_data, model=self.Meta.model)
@@ -262,7 +315,7 @@ class LDPSerializer(HyperlinkedModelSerializer):
 
         instance = model.objects.create(**validated_data)
 
-        self.save_or_update_nested(instance, nested_fields)
+        self.save_or_update_nested_list(instance, nested_fields)
 
         return instance
 
@@ -273,14 +326,21 @@ class LDPSerializer(HyperlinkedModelSerializer):
             nested_fields.append((field_name, validated_data.pop(field_name)))
 
         for attr, value in validated_data.items():
+            if isinstance(value, dict):
+                manager = getattr(instance, attr)
+                if 'pk' in value:
+                    oldObj = manager._meta.model.objects.get(pk=value['pk'])
+                    value = self.update(instance=oldObj, validated_data=value)
+                else:
+                    value = self.internal_create(validated_data=value, model=manager._meta.model)
             setattr(instance, attr, value)
         instance.save()
 
-        self.save_or_update_nested(instance, nested_fields)
+        self.save_or_update_nested_list(instance, nested_fields)
 
         return instance
 
-    def save_or_update_nested(self, instance, nested_fields):
+    def save_or_update_nested_list(self, instance, nested_fields):
         for (field_name, data) in nested_fields:
             try:
                 getattr(instance, field_name).clear()
-- 
GitLab