Merge remote-tracking branch 'origin/pr/242'
Conflicts: tests/test_document.py
This commit is contained in:
commit
f970d5878a
@ -1121,6 +1121,22 @@ class BaseDocument(object):
|
|||||||
key not in self._changed_fields):
|
key not in self._changed_fields):
|
||||||
self._changed_fields.append(key)
|
self._changed_fields.append(key)
|
||||||
|
|
||||||
|
def _clear_changed_fields(self):
|
||||||
|
self._changed_fields = []
|
||||||
|
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
|
||||||
|
for field_name, field in self._fields.iteritems():
|
||||||
|
if (isinstance(field, ComplexBaseField) and
|
||||||
|
isinstance(field.field, EmbeddedDocumentField)):
|
||||||
|
field_value = getattr(self, field_name, None)
|
||||||
|
if field_value:
|
||||||
|
for idx in (field_value if isinstance(field_value, dict)
|
||||||
|
else xrange(len(field_value))):
|
||||||
|
field_value[idx]._clear_changed_fields()
|
||||||
|
elif isinstance(field, EmbeddedDocumentField):
|
||||||
|
field_value = getattr(self, field_name, None)
|
||||||
|
if field_value:
|
||||||
|
field_value._clear_changed_fields()
|
||||||
|
|
||||||
def _get_changed_fields(self, key='', inspected=None):
|
def _get_changed_fields(self, key='', inspected=None):
|
||||||
"""Returns a list of all fields that have explicitly been changed.
|
"""Returns a list of all fields that have explicitly been changed.
|
||||||
"""
|
"""
|
||||||
|
@ -269,7 +269,7 @@ class Document(BaseDocument):
|
|||||||
if id_field not in self._meta.get('shard_key', []):
|
if id_field not in self._meta.get('shard_key', []):
|
||||||
self[id_field] = self._fields[id_field].to_python(object_id)
|
self[id_field] = self._fields[id_field].to_python(object_id)
|
||||||
|
|
||||||
self._changed_fields = []
|
self._clear_changed_fields()
|
||||||
self._created = False
|
self._created = False
|
||||||
signals.post_save.send(self.__class__, document=self, created=created)
|
signals.post_save.send(self.__class__, document=self, created=created)
|
||||||
return self
|
return self
|
||||||
|
@ -3370,7 +3370,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
) ]), "1,2")
|
) ]), "1,2")
|
||||||
|
|
||||||
def test_data_contains_idfield(self):
|
def test_data_contains_id_field(self):
|
||||||
"""Ensure that asking for _data returns 'id'
|
"""Ensure that asking for _data returns 'id'
|
||||||
"""
|
"""
|
||||||
class Person(Document):
|
class Person(Document):
|
||||||
@ -3383,6 +3383,47 @@ class DocumentTest(unittest.TestCase):
|
|||||||
self.assertTrue('_id' in person._data.keys())
|
self.assertTrue('_id' in person._data.keys())
|
||||||
self.assertEqual(person._data.get('_id'), person.id)
|
self.assertEqual(person._data.get('_id'), person.id)
|
||||||
|
|
||||||
|
def test_complex_nesting_document_and_embedded_document(self):
|
||||||
|
|
||||||
|
class Macro(EmbeddedDocument):
|
||||||
|
value = DynamicField(default="UNDEFINED")
|
||||||
|
|
||||||
|
class Parameter(EmbeddedDocument):
|
||||||
|
macros = MapField(EmbeddedDocumentField(Macro))
|
||||||
|
|
||||||
|
def expand(self):
|
||||||
|
self.macros["test"] = Macro()
|
||||||
|
|
||||||
|
class Node(Document):
|
||||||
|
parameters = MapField(EmbeddedDocumentField(Parameter))
|
||||||
|
|
||||||
|
def expand(self):
|
||||||
|
self.flattened_parameter = {}
|
||||||
|
for parameter_name, parameter in self.parameters.iteritems():
|
||||||
|
parameter.expand()
|
||||||
|
|
||||||
|
class System(Document):
|
||||||
|
name = StringField(required=True)
|
||||||
|
nodes = MapField(ReferenceField(Node, dbref=False))
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
for node_name, node in self.nodes.iteritems():
|
||||||
|
node.expand()
|
||||||
|
node.save(*args, **kwargs)
|
||||||
|
super(System, self).save(*args, **kwargs)
|
||||||
|
|
||||||
|
System.drop_collection()
|
||||||
|
Node.drop_collection()
|
||||||
|
|
||||||
|
system = System(name="system")
|
||||||
|
system.nodes["node"] = Node()
|
||||||
|
system.save()
|
||||||
|
system.nodes["node"].parameters["param"] = Parameter()
|
||||||
|
system.save()
|
||||||
|
|
||||||
|
system = System.objects.first()
|
||||||
|
self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value)
|
||||||
|
|
||||||
|
|
||||||
class ValidatorErrorTest(unittest.TestCase):
|
class ValidatorErrorTest(unittest.TestCase):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user