Fixed "$pull" semantics for nested ListFields (#447)

This commit is contained in:
Ross Lawley
2013-08-20 15:54:42 +00:00
parent a707598042
commit 0dd01bda01
6 changed files with 99 additions and 10 deletions

View File

@@ -23,8 +23,9 @@ def _import_class(cls_name):
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'FileField', 'GenericReferenceField',
'GenericEmbeddedDocumentField', 'GeoPointField',
'PointField', 'LineStringField', 'PolygonField',
'ReferenceField', 'StringField', 'ComplexBaseField')
'PointField', 'LineStringField', 'ListField',
'PolygonField', 'ReferenceField', 'StringField',
'ComplexBaseField')
queryset_classes = ('OperationError',)
deref_classes = ('DeReference',)

View File

@@ -780,6 +780,10 @@ class DictField(ComplexBaseField):
if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value)
if hasattr(self.field, 'field'):
return self.field.prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value)

View File

@@ -182,6 +182,7 @@ def update(_doc_cls=None, **update):
parts = []
cleaned_fields = []
appended_sub_field = False
for field in fields:
append_field = True
if isinstance(field, basestring):
@@ -193,10 +194,17 @@ def update(_doc_cls=None, **update):
else:
parts.append(field.db_field)
if append_field:
appended_sub_field = False
cleaned_fields.append(field)
if hasattr(field, 'field'):
cleaned_fields.append(field.field)
appended_sub_field = True
# Convert value to proper value
field = cleaned_fields[-1]
if appended_sub_field:
field = cleaned_fields[-2]
else:
field = cleaned_fields[-1]
if op in (None, 'set', 'push', 'pull'):
if field.required or value is not None:
@@ -223,11 +231,24 @@ def update(_doc_cls=None, **update):
if 'pull' in op and '.' in key:
# Dot operators don't work on pull operations
# it uses nested dict syntax
# unless they point to a list field
# Otherwise it uses nested dict syntax
if op == 'pullAll':
raise InvalidQueryError("pullAll operations only support "
"a single field depth")
# Look for the last list field and use dot notation until there
field_classes = [c.__class__ for c in cleaned_fields]
field_classes.reverse()
ListField = _import_class('ListField')
if ListField in field_classes:
# Join all fields via dot notation to the last ListField
# Then process as normal
last_listField = len(cleaned_fields) - field_classes.index(ListField)
key = ".".join(parts[:last_listField])
parts = parts[last_listField:]
parts.insert(0, key)
parts.reverse()
for key in parts:
value = {key: value}