From 0dd01bda016e44aca102d4998bf7c1a0a89739e9 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 20 Aug 2013 15:54:42 +0000 Subject: [PATCH] Fixed "$pull" semantics for nested ListFields (#447) --- AUTHORS | 3 +- docs/changelog.rst | 1 + mongoengine/common.py | 5 ++- mongoengine/fields.py | 4 ++ mongoengine/queryset/transform.py | 25 ++++++++++- tests/queryset/queryset.py | 71 ++++++++++++++++++++++++++++--- 6 files changed, 99 insertions(+), 10 deletions(-) diff --git a/AUTHORS b/AUTHORS index a5b73c7b..452ba370 100644 --- a/AUTHORS +++ b/AUTHORS @@ -176,4 +176,5 @@ that much better: * Thom Knowles (https://github.com/fleat) * Paul (https://github.com/squamous) * Olivier Cortès (https://github.com/Karmak23) - * crazyzubr (https://github.com/crazyzubr) \ No newline at end of file + * crazyzubr (https://github.com/crazyzubr) + * FrankSomething (https://github.com/FrankSomething) \ No newline at end of file diff --git a/docs/changelog.rst b/docs/changelog.rst index 27754293..6a0258cc 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8.4 ================ +- Fixed "$pull" semantics for nested ListFields (#447) - Allow fields to be named the same as query operators (#445) - Updated field filter logic - can now exclude subclass fields (#443) - Fixed dereference issue with embedded listfield referencefields (#439) diff --git a/mongoengine/common.py b/mongoengine/common.py index 20d51387..6303231e 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -23,8 +23,9 @@ def _import_class(cls_name): field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', 'FileField', 'GenericReferenceField', 'GenericEmbeddedDocumentField', 'GeoPointField', - 'PointField', 'LineStringField', 'PolygonField', - 'ReferenceField', 'StringField', 'ComplexBaseField') + 'PointField', 'LineStringField', 'ListField', + 'PolygonField', 'ReferenceField', 'StringField', + 'ComplexBaseField') queryset_classes = ('OperationError',) deref_classes = ('DeReference',) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index c1fc1a76..419f2ef7 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -780,6 +780,10 @@ class DictField(ComplexBaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) + + if hasattr(self.field, 'field'): + return self.field.prepare_query_value(op, value) + return super(DictField, self).prepare_query_value(op, value) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index d82f33db..2ee7e386 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -182,6 +182,7 @@ def update(_doc_cls=None, **update): parts = [] cleaned_fields = [] + appended_sub_field = False for field in fields: append_field = True if isinstance(field, basestring): @@ -193,10 +194,17 @@ def update(_doc_cls=None, **update): else: parts.append(field.db_field) if append_field: + appended_sub_field = False cleaned_fields.append(field) + if hasattr(field, 'field'): + cleaned_fields.append(field.field) + appended_sub_field = True # Convert value to proper value - field = cleaned_fields[-1] + if appended_sub_field: + field = cleaned_fields[-2] + else: + field = cleaned_fields[-1] if op in (None, 'set', 'push', 'pull'): if field.required or value is not None: @@ -223,11 +231,24 @@ def update(_doc_cls=None, **update): if 'pull' in op and '.' in key: # Dot operators don't work on pull operations - # it uses nested dict syntax + # unless they point to a list field + # Otherwise it uses nested dict syntax if op == 'pullAll': raise InvalidQueryError("pullAll operations only support " "a single field depth") + # Look for the last list field and use dot notation until there + field_classes = [c.__class__ for c in cleaned_fields] + field_classes.reverse() + ListField = _import_class('ListField') + if ListField in field_classes: + # Join all fields via dot notation to the last ListField + # Then process as normal + last_listField = len(cleaned_fields) - field_classes.index(ListField) + key = ".".join(parts[:last_listField]) + parts = parts[last_listField:] + parts.insert(0, key) + parts.reverse() for key in parts: value = {key: value} diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 7f641351..b4bcf2a7 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -1497,9 +1497,6 @@ class QuerySetTest(unittest.TestCase): def test_pull_nested(self): - class User(Document): - name = StringField() - class Collaborator(EmbeddedDocument): user = StringField() @@ -1514,8 +1511,7 @@ class QuerySetTest(unittest.TestCase): Site.drop_collection() c = Collaborator(user='Esteban') - s = Site(name="test", collaborators=[c]) - s.save() + s = Site(name="test", collaborators=[c]).save() Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') self.assertEqual(Site.objects.first().collaborators, []) @@ -1525,6 +1521,71 @@ class QuerySetTest(unittest.TestCase): self.assertRaises(InvalidQueryError, pull_all) + def test_pull_from_nested_embedded(self): + + class User(EmbeddedDocument): + name = StringField() + + def __unicode__(self): + return '%s' % self.name + + class Collaborator(EmbeddedDocument): + helpful = ListField(EmbeddedDocumentField(User)) + unhelpful = ListField(EmbeddedDocumentField(User)) + + class Site(Document): + name = StringField(max_length=75, unique=True, required=True) + collaborators = EmbeddedDocumentField(Collaborator) + + + Site.drop_collection() + + c = User(name='Esteban') + f = User(name='Frank') + s = Site(name="test", collaborators=Collaborator(helpful=[c], unhelpful=[f])).save() + + Site.objects(id=s.id).update_one(pull__collaborators__helpful=c) + self.assertEqual(Site.objects.first().collaborators['helpful'], []) + + Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'name': 'Frank'}) + self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) + + def pull_all(): + Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__name=['Ross']) + + self.assertRaises(InvalidQueryError, pull_all) + + def test_pull_from_nested_mapfield(self): + + class Collaborator(EmbeddedDocument): + user = StringField() + + def __unicode__(self): + return '%s' % self.user + + class Site(Document): + name = StringField(max_length=75, unique=True, required=True) + collaborators = MapField(ListField(EmbeddedDocumentField(Collaborator))) + + + Site.drop_collection() + + c = Collaborator(user='Esteban') + f = Collaborator(user='Frank') + s = Site(name="test", collaborators={'helpful':[c],'unhelpful':[f]}) + s.save() + + Site.objects(id=s.id).update_one(pull__collaborators__helpful__user='Esteban') + self.assertEqual(Site.objects.first().collaborators['helpful'], []) + + Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'user':'Frank'}) + self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) + + def pull_all(): + Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__user=['Ross']) + + self.assertRaises(InvalidQueryError, pull_all) + def test_update_one_pop_generic_reference(self): class BlogTag(Document):