diff --git a/djangoldp/serializers.py b/djangoldp/serializers.py index da4632611caaa5efb7bb040a461bba7cc8153733..adfc85b3eb77dbc94b5232b05abbc07069e74cc6 100644 --- a/djangoldp/serializers.py +++ b/djangoldp/serializers.py @@ -17,13 +17,9 @@ from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import get_nested_relation_kwargs from rest_framework.utils.serializer_helpers import ReturnDict +from djangoldp.fields import LDPUrlField, IdURLField from djangoldp.models import Model -from rest_framework.serializers import HyperlinkedModelSerializer, ListSerializer, ModelSerializer -from rest_framework.utils.field_mapping import get_nested_relation_kwargs -from rest_framework.utils.serializer_helpers import ReturnDict - -from djangoldp.fields import LDPUrlField, IdURLField class LDListMixin: def to_internal_value(self, data): @@ -50,11 +46,17 @@ class LDListMixin: def get_value(self, dictionary): try: object_list = dictionary["@graph"] - container_id = Model.container_id(self.parent.instance) - obj = next(filter(lambda o: container_id in o['@id'], object_list)) + if self.parent.instance is None: + obj = next(filter( + lambda o: not hasattr(o, self.parent.url_field_name) or "./" in o[self.parent.url_field_name], + object_list)) + else: + container_id = Model.container_id(self.parent.instance) + obj = next(filter(lambda o: container_id in o[self.parent.url_field_name], object_list)) list = super().get_value(obj) try: - list = next(filter(lambda o: list['@id'] == o['@id'], object_list)) + list = next( + filter(lambda o: list[self.parent.url_field_name] == o[self.parent.url_field_name], object_list)) except (KeyError, TypeError): pass @@ -63,6 +65,9 @@ class LDListMixin: except (KeyError, TypeError): pass + if list is empty: + return [] + if isinstance(list, dict): list = [list] @@ -70,7 +75,9 @@ class LDListMixin: for item in list: full_item = None try: - full_item = next(filter(lambda o: item['@id'] == o['@id'], object_list)) + full_item = next(filter( + lambda o: self.parent.url_field_name in o and item[self.parent.url_field_name] == o[ + self.parent.url_field_name], object_list)) except StopIteration: pass if full_item is None: @@ -95,7 +102,7 @@ class ContainerSerializer(LDListMixin, ListSerializer): def to_internal_value(self, data): try: - return super().to_internal_value(data['@id']) + return super().to_internal_value(data[self.parent.url_field_name]) except (KeyError, TypeError): return super().to_internal_value(data) @@ -133,7 +140,7 @@ class JsonLdRelatedField(JsonLdField): def to_internal_value(self, data): try: - return super().to_internal_value(data['@id']) + return super().to_internal_value(data[self.parent.url_field_name]) except (KeyError, TypeError): return super().to_internal_value(data) @@ -157,7 +164,7 @@ class JsonLdIdentityField(JsonLdField): def to_internal_value(self, data): try: - return super().to_internal_value(data['@id']) + return super().to_internal_value(data[self.parent.url_field_name]) except KeyError: return super().to_internal_value(data) @@ -169,8 +176,7 @@ class LDPSerializer(HyperlinkedModelSerializer): url_field_name = "@id" serializer_related_field = JsonLdRelatedField serializer_url_field = JsonLdIdentityField - ModelSerializer.serializer_field_mapping [LDPUrlField] = IdURLField - + ModelSerializer.serializer_field_mapping[LDPUrlField] = IdURLField @property def data(self): @@ -209,9 +215,15 @@ class LDPSerializer(HyperlinkedModelSerializer): def get_value(self, dictionary): try: object_list = dictionary["@graph"] - resource_id = Model.resource_id(self.parent.instance) - obj = next(filter(lambda o: resource_id in o['@id'], object_list)) - return super().get_value(obj) + if self.parent.instance is None: + obj = next(filter( + lambda o: not hasattr(o, self.parent.url_field_name) or "./" in o[self.url_field_name], + object_list)) + return super().get_value(obj) + else: + resource_id = Model.resource_id(self.parent.instance) + obj = next(filter(lambda o: resource_id in o[self.parent.url_field_name], object_list)) + return super().get_value(obj) except KeyError: return super().get_value(dictionary) @@ -306,7 +318,7 @@ class LDPSerializer(HyperlinkedModelSerializer): if item is empty: return empty try: - full_item = next(filter(lambda o: item['@id'] == o['@id'], object_list)) + full_item = next(filter(lambda o: item[self.url_field_name] == o[self.url_field_name], object_list)) except StopIteration: pass if full_item is None: diff --git a/djangoldp/tests/runner.py b/djangoldp/tests/runner.py index 47a4548bacecc8571777f1d7400a573ff64558c2..2d3f2e433db6d47c54426ebbcc944bd01286123b 100644 --- a/djangoldp/tests/runner.py +++ b/djangoldp/tests/runner.py @@ -30,7 +30,8 @@ failures = test_runner.run_tests([ 'djangoldp.tests.tests_save', 'djangoldp.tests.tests_user_permissions', 'djangoldp.tests.tests_anonymous_permissions', - 'djangoldp.tests.tests_update']) + 'djangoldp.tests.tests_update', +]) if failures: sys.exit(failures) diff --git a/djangoldp/tests/tests_save.py b/djangoldp/tests/tests_save.py index 4cf4f0d48b0079bb07a30fd14d24448619b62071..36f724bb19de040e21eac8d4204fc5a0ea39b5c7 100644 --- a/djangoldp/tests/tests_save.py +++ b/djangoldp/tests/tests_save.py @@ -15,7 +15,7 @@ class Save(TestCase): "ldp:contains": [ {"@id": "https://happy-dev.fr/skills/{}/".format(skill1.pk)}, {"@id": "https://happy-dev.fr/skills/{}/".format(skill2.pk), "title": "skill2 UP"}, - {"title": "skill3 NEW", "obligatoire":"obligatoire"}, + {"title": "skill3 NEW", "obligatoire": "obligatoire"}, ]} } @@ -29,9 +29,49 @@ class Save(TestCase): self.assertEquals(result.title, "job test") self.assertIs(result.skills.count(), 3) - self.assertEquals(result.skills.all()[0].title, "skill1") # no change + self.assertEquals(result.skills.all()[0].title, "skill1") # no change self.assertEquals(result.skills.all()[1].title, "skill2 UP") # title updated - self.assertEquals(result.skills.all()[2].title, "skill3 NEW") # creation on the fly + self.assertEquals(result.skills.all()[2].title, "skill3 NEW") # creation on the fly + + def test_save_m2m_graph_simple(self): + job = {"@graph": [ + {"title": "job test", + }, + ]} + + meta_args = {'model': JobOffer, 'depth': 1, 'fields': ("@id", "title", "skills")} + + meta_class = type('Meta', (), meta_args) + serializer_class = type(LDPSerializer)('JobOfferSerializer', (LDPSerializer,), {'Meta': meta_class}) + serializer = serializer_class(data=job) + serializer.is_valid() + result = serializer.save() + + self.assertEquals(result.title, "job test") + self.assertIs(result.skills.count(), 0) + + def test_save_m2m_graph_with_nested(self): + skill1 = Skill.objects.create(title="skill1", obligatoire="obligatoire") + skill2 = Skill.objects.create(title="skill2", obligatoire="obligatoire") + + job = {"@graph": [ + {"title": "job test", + "skills": {"@id": "_.123"} + }, + {"@id": "_.123", "title": "skill3 NEW", "obligatoire": "obligatoire"}, + ]} + + meta_args = {'model': JobOffer, 'depth': 1, 'fields': ("@id", "title", "skills")} + + meta_class = type('Meta', (), meta_args) + serializer_class = type(LDPSerializer)('JobOfferSerializer', (LDPSerializer,), {'Meta': meta_class}) + serializer = serializer_class(data=job) + serializer.is_valid() + result = serializer.save() + + self.assertEquals(result.title, "job test") + self.assertIs(result.skills.count(), 1) + self.assertEquals(result.skills.all()[0].title, "skill3 NEW") # creation on the fly def test_save_without_nested_fields(self): skill1 = Skill.objects.create(title="skill1", obligatoire="obligatoire") @@ -70,4 +110,3 @@ class Save(TestCase): self.assertIs(result.joboffer_set.count(), 1) self.assertEquals(result.joboffer_set.get(), job) self.assertIs(result.joboffer_set.get().skills.count(), 1) -