Fixed "$pull" semantics for nested ListFields (#447)
This commit is contained in:
		| @@ -23,8 +23,9 @@ def _import_class(cls_name): | ||||
|     field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', | ||||
|                      'FileField', 'GenericReferenceField', | ||||
|                      'GenericEmbeddedDocumentField', 'GeoPointField', | ||||
|                      'PointField', 'LineStringField', 'PolygonField', | ||||
|                      'ReferenceField', 'StringField', 'ComplexBaseField') | ||||
|                      'PointField', 'LineStringField', 'ListField', | ||||
|                      'PolygonField', 'ReferenceField', 'StringField', | ||||
|                      'ComplexBaseField') | ||||
|     queryset_classes = ('OperationError',) | ||||
|     deref_classes = ('DeReference',) | ||||
|  | ||||
|   | ||||
| @@ -780,6 +780,10 @@ class DictField(ComplexBaseField): | ||||
|  | ||||
|         if op in match_operators and isinstance(value, basestring): | ||||
|             return StringField().prepare_query_value(op, value) | ||||
|  | ||||
|         if hasattr(self.field, 'field'): | ||||
|             return self.field.prepare_query_value(op, value) | ||||
|  | ||||
|         return super(DictField, self).prepare_query_value(op, value) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -182,6 +182,7 @@ def update(_doc_cls=None, **update): | ||||
|             parts = [] | ||||
|  | ||||
|             cleaned_fields = [] | ||||
|             appended_sub_field = False | ||||
|             for field in fields: | ||||
|                 append_field = True | ||||
|                 if isinstance(field, basestring): | ||||
| @@ -193,10 +194,17 @@ def update(_doc_cls=None, **update): | ||||
|                 else: | ||||
|                     parts.append(field.db_field) | ||||
|                 if append_field: | ||||
|                     appended_sub_field = False | ||||
|                     cleaned_fields.append(field) | ||||
|                     if hasattr(field, 'field'): | ||||
|                         cleaned_fields.append(field.field) | ||||
|                         appended_sub_field = True | ||||
|  | ||||
|             # Convert value to proper value | ||||
|             field = cleaned_fields[-1] | ||||
|             if appended_sub_field: | ||||
|                 field = cleaned_fields[-2] | ||||
|             else: | ||||
|                 field = cleaned_fields[-1] | ||||
|  | ||||
|             if op in (None, 'set', 'push', 'pull'): | ||||
|                 if field.required or value is not None: | ||||
| @@ -223,11 +231,24 @@ def update(_doc_cls=None, **update): | ||||
|  | ||||
|         if 'pull' in op and '.' in key: | ||||
|             # Dot operators don't work on pull operations | ||||
|             # it uses nested dict syntax | ||||
|             # unless they point to a list field | ||||
|             # Otherwise it uses nested dict syntax | ||||
|             if op == 'pullAll': | ||||
|                 raise InvalidQueryError("pullAll operations only support " | ||||
|                                         "a single field depth") | ||||
|  | ||||
|             # Look for the last list field and use dot notation until there | ||||
|             field_classes = [c.__class__ for c in cleaned_fields] | ||||
|             field_classes.reverse() | ||||
|             ListField = _import_class('ListField') | ||||
|             if ListField in field_classes: | ||||
|                 # Join all fields via dot notation to the last ListField | ||||
|                 # Then process as normal | ||||
|                 last_listField = len(cleaned_fields) - field_classes.index(ListField) | ||||
|                 key = ".".join(parts[:last_listField]) | ||||
|                 parts = parts[last_listField:] | ||||
|                 parts.insert(0, key) | ||||
|  | ||||
|             parts.reverse() | ||||
|             for key in parts: | ||||
|                 value = {key: value} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user