Merge pull request #1765 from kushalmitruka/master
fixed : `pull` not working for EmbeddedDocumentListField, only working for ListFields #1534
This commit is contained in:
commit
317e844886
@ -314,11 +314,17 @@ def update(_doc_cls=None, **update):
|
|||||||
field_classes = [c.__class__ for c in cleaned_fields]
|
field_classes = [c.__class__ for c in cleaned_fields]
|
||||||
field_classes.reverse()
|
field_classes.reverse()
|
||||||
ListField = _import_class('ListField')
|
ListField = _import_class('ListField')
|
||||||
if ListField in field_classes:
|
EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
|
||||||
# Join all fields via dot notation to the last ListField
|
if ListField in field_classes or EmbeddedDocumentListField in field_classes:
|
||||||
|
# Join all fields via dot notation to the last ListField or EmbeddedDocumentListField
|
||||||
# Then process as normal
|
# Then process as normal
|
||||||
|
if ListField in field_classes:
|
||||||
|
_check_field = ListField
|
||||||
|
else:
|
||||||
|
_check_field = EmbeddedDocumentListField
|
||||||
|
|
||||||
last_listField = len(
|
last_listField = len(
|
||||||
cleaned_fields) - field_classes.index(ListField)
|
cleaned_fields) - field_classes.index(_check_field)
|
||||||
key = '.'.join(parts[:last_listField])
|
key = '.'.join(parts[:last_listField])
|
||||||
parts = parts[last_listField:]
|
parts = parts[last_listField:]
|
||||||
parts.insert(0, key)
|
parts.insert(0, key)
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from bson.son import SON
|
||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
from mongoengine.queryset import Q, transform
|
from mongoengine.queryset import Q, transform
|
||||||
|
|
||||||
@ -258,7 +260,31 @@ class TransformTest(unittest.TestCase):
|
|||||||
events = Event.objects(location__within=box)
|
events = Event.objects(location__within=box)
|
||||||
with self.assertRaises(InvalidQueryError):
|
with self.assertRaises(InvalidQueryError):
|
||||||
events.count()
|
events.count()
|
||||||
|
|
||||||
|
def test_update_pull_for_list_fields(self):
|
||||||
|
"""
|
||||||
|
Test added to check pull operation in update for
|
||||||
|
EmbeddedDocumentListField which is inside a EmbeddedDocumentField
|
||||||
|
"""
|
||||||
|
class Word(EmbeddedDocument):
|
||||||
|
word = StringField()
|
||||||
|
index = IntField()
|
||||||
|
|
||||||
|
class SubDoc(EmbeddedDocument):
|
||||||
|
heading = ListField(StringField())
|
||||||
|
text = EmbeddedDocumentListField(Word)
|
||||||
|
|
||||||
|
class MainDoc(Document):
|
||||||
|
title = StringField()
|
||||||
|
content = EmbeddedDocumentField(SubDoc)
|
||||||
|
|
||||||
|
word = Word(word='abc', index=1)
|
||||||
|
update = transform.update(MainDoc, pull__content__text=word)
|
||||||
|
self.assertEqual(update, {'$pull': {'content.text': SON([('word', u'abc'), ('index', 1)])}})
|
||||||
|
|
||||||
|
update = transform.update(MainDoc, pull__content__heading='xyz')
|
||||||
|
self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user