Fixed "$pull" semantics for nested ListFields (#447)

This commit is contained in:
Ross Lawley 2013-08-20 15:54:42 +00:00
parent a707598042
commit 0dd01bda01
6 changed files with 99 additions and 10 deletions

View File

@ -176,4 +176,5 @@ that much better:
* Thom Knowles (https://github.com/fleat) * Thom Knowles (https://github.com/fleat)
* Paul (https://github.com/squamous) * Paul (https://github.com/squamous)
* Olivier Cortès (https://github.com/Karmak23) * Olivier Cortès (https://github.com/Karmak23)
* crazyzubr (https://github.com/crazyzubr) * crazyzubr (https://github.com/crazyzubr)
* FrankSomething (https://github.com/FrankSomething)

View File

@ -4,6 +4,7 @@ Changelog
Changes in 0.8.4 Changes in 0.8.4
================ ================
- Fixed "$pull" semantics for nested ListFields (#447)
- Allow fields to be named the same as query operators (#445) - Allow fields to be named the same as query operators (#445)
- Updated field filter logic - can now exclude subclass fields (#443) - Updated field filter logic - can now exclude subclass fields (#443)
- Fixed dereference issue with embedded listfield referencefields (#439) - Fixed dereference issue with embedded listfield referencefields (#439)

View File

@ -23,8 +23,9 @@ def _import_class(cls_name):
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'FileField', 'GenericReferenceField', 'FileField', 'GenericReferenceField',
'GenericEmbeddedDocumentField', 'GeoPointField', 'GenericEmbeddedDocumentField', 'GeoPointField',
'PointField', 'LineStringField', 'PolygonField', 'PointField', 'LineStringField', 'ListField',
'ReferenceField', 'StringField', 'ComplexBaseField') 'PolygonField', 'ReferenceField', 'StringField',
'ComplexBaseField')
queryset_classes = ('OperationError',) queryset_classes = ('OperationError',)
deref_classes = ('DeReference',) deref_classes = ('DeReference',)

View File

@ -780,6 +780,10 @@ class DictField(ComplexBaseField):
if op in match_operators and isinstance(value, basestring): if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value) return StringField().prepare_query_value(op, value)
if hasattr(self.field, 'field'):
return self.field.prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value) return super(DictField, self).prepare_query_value(op, value)

View File

@ -182,6 +182,7 @@ def update(_doc_cls=None, **update):
parts = [] parts = []
cleaned_fields = [] cleaned_fields = []
appended_sub_field = False
for field in fields: for field in fields:
append_field = True append_field = True
if isinstance(field, basestring): if isinstance(field, basestring):
@ -193,10 +194,17 @@ def update(_doc_cls=None, **update):
else: else:
parts.append(field.db_field) parts.append(field.db_field)
if append_field: if append_field:
appended_sub_field = False
cleaned_fields.append(field) cleaned_fields.append(field)
if hasattr(field, 'field'):
cleaned_fields.append(field.field)
appended_sub_field = True
# Convert value to proper value # Convert value to proper value
field = cleaned_fields[-1] if appended_sub_field:
field = cleaned_fields[-2]
else:
field = cleaned_fields[-1]
if op in (None, 'set', 'push', 'pull'): if op in (None, 'set', 'push', 'pull'):
if field.required or value is not None: if field.required or value is not None:
@ -223,11 +231,24 @@ def update(_doc_cls=None, **update):
if 'pull' in op and '.' in key: if 'pull' in op and '.' in key:
# Dot operators don't work on pull operations # Dot operators don't work on pull operations
# it uses nested dict syntax # unless they point to a list field
# Otherwise it uses nested dict syntax
if op == 'pullAll': if op == 'pullAll':
raise InvalidQueryError("pullAll operations only support " raise InvalidQueryError("pullAll operations only support "
"a single field depth") "a single field depth")
# Look for the last list field and use dot notation until there
field_classes = [c.__class__ for c in cleaned_fields]
field_classes.reverse()
ListField = _import_class('ListField')
if ListField in field_classes:
# Join all fields via dot notation to the last ListField
# Then process as normal
last_listField = len(cleaned_fields) - field_classes.index(ListField)
key = ".".join(parts[:last_listField])
parts = parts[last_listField:]
parts.insert(0, key)
parts.reverse() parts.reverse()
for key in parts: for key in parts:
value = {key: value} value = {key: value}

View File

@ -1497,9 +1497,6 @@ class QuerySetTest(unittest.TestCase):
def test_pull_nested(self): def test_pull_nested(self):
class User(Document):
name = StringField()
class Collaborator(EmbeddedDocument): class Collaborator(EmbeddedDocument):
user = StringField() user = StringField()
@ -1514,8 +1511,7 @@ class QuerySetTest(unittest.TestCase):
Site.drop_collection() Site.drop_collection()
c = Collaborator(user='Esteban') c = Collaborator(user='Esteban')
s = Site(name="test", collaborators=[c]) s = Site(name="test", collaborators=[c]).save()
s.save()
Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban')
self.assertEqual(Site.objects.first().collaborators, []) self.assertEqual(Site.objects.first().collaborators, [])
@ -1525,6 +1521,71 @@ class QuerySetTest(unittest.TestCase):
self.assertRaises(InvalidQueryError, pull_all) self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_embedded(self):
class User(EmbeddedDocument):
name = StringField()
def __unicode__(self):
return '%s' % self.name
class Collaborator(EmbeddedDocument):
helpful = ListField(EmbeddedDocumentField(User))
unhelpful = ListField(EmbeddedDocumentField(User))
class Site(Document):
name = StringField(max_length=75, unique=True, required=True)
collaborators = EmbeddedDocumentField(Collaborator)
Site.drop_collection()
c = User(name='Esteban')
f = User(name='Frank')
s = Site(name="test", collaborators=Collaborator(helpful=[c], unhelpful=[f])).save()
Site.objects(id=s.id).update_one(pull__collaborators__helpful=c)
self.assertEqual(Site.objects.first().collaborators['helpful'], [])
Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'name': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all():
Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__name=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_mapfield(self):
class Collaborator(EmbeddedDocument):
user = StringField()
def __unicode__(self):
return '%s' % self.user
class Site(Document):
name = StringField(max_length=75, unique=True, required=True)
collaborators = MapField(ListField(EmbeddedDocumentField(Collaborator)))
Site.drop_collection()
c = Collaborator(user='Esteban')
f = Collaborator(user='Frank')
s = Site(name="test", collaborators={'helpful':[c],'unhelpful':[f]})
s.save()
Site.objects(id=s.id).update_one(pull__collaborators__helpful__user='Esteban')
self.assertEqual(Site.objects.first().collaborators['helpful'], [])
Site.objects(id=s.id).update_one(pull__collaborators__unhelpful={'user':'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all():
Site.objects(id=s.id).update_one(pull_all__collaborators__helpful__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_update_one_pop_generic_reference(self): def test_update_one_pop_generic_reference(self):
class BlogTag(Document): class BlogTag(Document):