Cleaned up dereferencing

Dereferencing now respects max_depth, so should be more performant.
Reload is chainable and can be passed a max_depth for dereferencing
Added an Observer for ComplexBaseFields.

Refs #324 #323 #289
Closes #320
This commit is contained in:
Ross Lawley 2011-11-25 08:28:20 -08:00
parent 5e553ffaf7
commit 83fff80b0f
6 changed files with 122 additions and 97 deletions

View File

@ -5,6 +5,7 @@ Changelog
Changes in dev Changes in dev
============== ==============
- Fixed dereferencing - max_depth now taken into account
- Fixed document mutation saving issue - Fixed document mutation saving issue
- Fixed positional operator when replacing embedded documents - Fixed positional operator when replacing embedded documents
- Added Non-Django Style choices back (you can have either) - Added Non-Django Style choices back (you can have either)

View File

@ -155,9 +155,11 @@ class BaseField(object):
# Convert lists / values so we can watch for any changes on them # Convert lists / values so we can watch for any changes on them
if isinstance(value, (list, tuple)) and not isinstance(value, BaseList): if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
value = BaseList(value, instance=instance, name=self.name) observer = DataObserver(instance, self.name)
value = BaseList(value, observer)
elif isinstance(value, dict) and not isinstance(value, BaseDict): elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, instance=instance, name=self.name) observer = DataObserver(instance, self.name)
value = BaseDict(value, observer)
return value return value
def __set__(self, instance, value): def __set__(self, instance, value):
@ -237,7 +239,7 @@ class ComplexBaseField(BaseField):
from dereference import dereference from dereference import dereference
instance._data[self.name] = dereference( instance._data[self.name] = dereference(
instance._data.get(self.name), max_depth=1, instance=instance, name=self.name, get=True instance._data.get(self.name), max_depth=1, instance=instance, name=self.name
) )
return super(ComplexBaseField, self).__get__(instance, owner) return super(ComplexBaseField, self).__get__(instance, owner)
@ -780,9 +782,11 @@ class BaseDocument(object):
# Convert lists / values so we can watch for any changes on them # Convert lists / values so we can watch for any changes on them
if isinstance(value, (list, tuple)) and not isinstance(value, BaseList): if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
value = BaseList(value, instance=self, name=name) observer = DataObserver(self, name)
value = BaseList(value, observer)
elif isinstance(value, dict) and not isinstance(value, BaseDict): elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, instance=self, name=name) observer = DataObserver(self, name)
value = BaseDict(value, observer)
return value return value
@ -1122,102 +1126,113 @@ class BaseDocument(object):
return hash(self.pk) return hash(self.pk)
class DataObserver(object):
__slots__ = ["instance", "name"]
def __init__(self, instance, name):
self.instance = instance
self.name = name
def updated(self):
self.instance._mark_as_changed(self.name)
class BaseList(list): class BaseList(list):
"""A special list so we can watch any changes """A special list so we can watch any changes
""" """
def __init__(self, list_items, instance, name): def __init__(self, list_items, observer):
self.instance = instance self.observer = observer
self.name = name
super(BaseList, self).__init__(list_items) super(BaseList, self).__init__(list_items)
def __setitem__(self, *args, **kwargs): def __setitem__(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
super(BaseList, self).__setitem__(*args, **kwargs) super(BaseList, self).__setitem__(*args, **kwargs)
def __delitem__(self, *args, **kwargs): def __delitem__(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
super(BaseList, self).__delitem__(*args, **kwargs) super(BaseList, self).__delitem__(*args, **kwargs)
def __getstate__(self):
self.observer = None
return self
def __setstate__(self, state):
self = state
def append(self, *args, **kwargs): def append(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
return super(BaseList, self).append(*args, **kwargs) return super(BaseList, self).append(*args, **kwargs)
def extend(self, *args, **kwargs): def extend(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
return super(BaseList, self).extend(*args, **kwargs) return super(BaseList, self).extend(*args, **kwargs)
def insert(self, *args, **kwargs): def insert(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
return super(BaseList, self).insert(*args, **kwargs) return super(BaseList, self).insert(*args, **kwargs)
def pop(self, *args, **kwargs): def pop(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
return super(BaseList, self).pop(*args, **kwargs) return super(BaseList, self).pop(*args, **kwargs)
def remove(self, *args, **kwargs): def remove(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
return super(BaseList, self).remove(*args, **kwargs) return super(BaseList, self).remove(*args, **kwargs)
def reverse(self, *args, **kwargs): def reverse(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
return super(BaseList, self).reverse(*args, **kwargs) return super(BaseList, self).reverse(*args, **kwargs)
def sort(self, *args, **kwargs): def sort(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
return super(BaseList, self).sort(*args, **kwargs) return super(BaseList, self).sort(*args, **kwargs)
def _mark_as_changed(self):
"""Marks a list as changed if has an instance and a name"""
if hasattr(self, 'instance') and hasattr(self, 'name'):
self.instance._mark_as_changed(self.name)
class BaseDict(dict): class BaseDict(dict):
"""A special dict so we can watch any changes """A special dict so we can watch any changes
""" """
def __init__(self, dict_items, instance, name): def __init__(self, dict_items, observer):
self.instance = instance self.observer = observer
self.name = name
super(BaseDict, self).__init__(dict_items) super(BaseDict, self).__init__(dict_items)
def __setitem__(self, *args, **kwargs): def __setitem__(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
super(BaseDict, self).__setitem__(*args, **kwargs) super(BaseDict, self).__setitem__(*args, **kwargs)
def __setattr__(self, *args, **kwargs):
self._mark_as_changed()
super(BaseDict, self).__setattr__(*args, **kwargs)
def __delete__(self, *args, **kwargs): def __delete__(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
super(BaseDict, self).__delete__(*args, **kwargs) super(BaseDict, self).__delete__(*args, **kwargs)
def __delitem__(self, *args, **kwargs): def __delitem__(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
super(BaseDict, self).__delitem__(*args, **kwargs) super(BaseDict, self).__delitem__(*args, **kwargs)
def __delattr__(self, *args, **kwargs): def __delattr__(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
super(BaseDict, self).__delattr__(*args, **kwargs) super(BaseDict, self).__delattr__(*args, **kwargs)
def __getstate__(self):
self.observer = None
return self
def __setstate__(self, state):
self = state
def clear(self, *args, **kwargs): def clear(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
super(BaseDict, self).clear(*args, **kwargs) super(BaseDict, self).clear(*args, **kwargs)
def pop(self, *args, **kwargs): def pop(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
super(BaseDict, self).clear(*args, **kwargs) super(BaseDict, self).clear(*args, **kwargs)
def popitem(self, *args, **kwargs): def popitem(self, *args, **kwargs):
self._mark_as_changed() self.observer.updated()
super(BaseDict, self).clear(*args, **kwargs) super(BaseDict, self).clear(*args, **kwargs)
def _mark_as_changed(self):
"""Marks a dict as changed if has an instance and a name"""
if hasattr(self, 'instance') and hasattr(self, 'name'):
self.instance._mark_as_changed(self.name)
if sys.version_info < (2, 5): if sys.version_info < (2, 5):
# Prior to Python 2.5, Exception was an old-style class # Prior to Python 2.5, Exception was an old-style class

View File

@ -1,6 +1,7 @@
import pymongo import pymongo
from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass from base import (BaseDict, BaseList, DataObserver,
TopLevelDocumentMetaclass, get_document)
from fields import ReferenceField from fields import ReferenceField
from connection import get_db from connection import get_db
from queryset import QuerySet from queryset import QuerySet
@ -9,7 +10,7 @@ from document import Document
class DeReference(object): class DeReference(object):
def __call__(self, items, max_depth=1, instance=None, name=None, get=False): def __call__(self, items, max_depth=1, instance=None, name=None):
""" """
Cheaply dereferences the items to a set depth. Cheaply dereferences the items to a set depth.
Also handles the convertion of complex data types. Also handles the convertion of complex data types.
@ -43,7 +44,7 @@ class DeReference(object):
self.reference_map = self._find_references(items) self.reference_map = self._find_references(items)
self.object_map = self._fetch_objects(doc_type=doc_type) self.object_map = self._fetch_objects(doc_type=doc_type)
return self._attach_objects(items, 0, instance, name, get) return self._attach_objects(items, 0, instance, name)
def _find_references(self, items, depth=0): def _find_references(self, items, depth=0):
""" """
@ -53,7 +54,7 @@ class DeReference(object):
:param depth: The current depth of recursion :param depth: The current depth of recursion
""" """
reference_map = {} reference_map = {}
if not items: if not items or depth >= self.max_depth:
return reference_map return reference_map
# Determine the iterator to use # Determine the iterator to use
@ -63,6 +64,7 @@ class DeReference(object):
iterator = items.iteritems() iterator = items.iteritems()
# Recursively find dbreferences # Recursively find dbreferences
depth += 1
for k, item in iterator: for k, item in iterator:
if hasattr(item, '_fields'): if hasattr(item, '_fields'):
for field_name, field in item._fields.iteritems(): for field_name, field in item._fields.iteritems():
@ -82,11 +84,11 @@ class DeReference(object):
reference_map.setdefault(item.collection, []).append(item.id) reference_map.setdefault(item.collection, []).append(item.id)
elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item: elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item:
reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id) reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id)
elif isinstance(item, (dict, list, tuple)) and depth <= self.max_depth: elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
references = self._find_references(item, depth) references = self._find_references(item, depth - 1)
for key, refs in references.iteritems(): for key, refs in references.iteritems():
reference_map.setdefault(key, []).extend(refs) reference_map.setdefault(key, []).extend(refs)
depth += 1
return reference_map return reference_map
def _fetch_objects(self, doc_type=None): def _fetch_objects(self, doc_type=None):
@ -110,7 +112,7 @@ class DeReference(object):
object_map[doc.id] = doc object_map[doc.id] = doc
return object_map return object_map
def _attach_objects(self, items, depth=0, instance=None, name=None, get=False): def _attach_objects(self, items, depth=0, instance=None, name=None):
""" """
Recursively finds all db references to be dereferenced Recursively finds all db references to be dereferenced
@ -120,25 +122,24 @@ class DeReference(object):
:class:`~mongoengine.base.ComplexBaseField` :class:`~mongoengine.base.ComplexBaseField`
:param name: The name of the field, used for tracking changes by :param name: The name of the field, used for tracking changes by
:class:`~mongoengine.base.ComplexBaseField` :class:`~mongoengine.base.ComplexBaseField`
:param get: A boolean determining if being called by __get__
""" """
if not items: if not items:
if isinstance(items, (BaseDict, BaseList)): if isinstance(items, (BaseDict, BaseList)):
return items return items
if instance: if instance:
observer = DataObserver(instance, name)
if isinstance(items, dict): if isinstance(items, dict):
return BaseDict(items, instance=instance, name=name) return BaseDict(items, observer)
else: else:
return BaseList(items, instance=instance, name=name) return BaseList(items, observer)
if isinstance(items, (dict, pymongo.son.SON)): if isinstance(items, (dict, pymongo.son.SON)):
if '_ref' in items: if '_ref' in items:
return self.object_map.get(items['_ref'].id, items) return self.object_map.get(items['_ref'].id, items)
elif '_types' in items and '_cls' in items: elif '_types' in items and '_cls' in items:
doc = get_document(items['_cls'])._from_son(items) doc = get_document(items['_cls'])._from_son(items)
if not get: doc._data = self._attach_objects(doc._data, depth, doc, name)
doc._data = self._attach_objects(doc._data, depth, doc, name, get)
return doc return doc
if not hasattr(items, 'items'): if not hasattr(items, 'items'):
@ -150,6 +151,7 @@ class DeReference(object):
iterator = items.iteritems() iterator = items.iteritems()
data = {} data = {}
depth += 1
for k, v in iterator: for k, v in iterator:
if is_list: if is_list:
data.append(v) data.append(v)
@ -165,19 +167,20 @@ class DeReference(object):
data[k]._data[field_name] = self.object_map.get(v.id, v) data[k]._data[field_name] = self.object_map.get(v.id, v)
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v)
elif isinstance(v, dict) and depth < self.max_depth: elif isinstance(v, dict) and depth <= self.max_depth:
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) data[k]._data[field_name] = self._attach_objects(v, depth - 1, instance=instance, name=name)
elif isinstance(v, (list, tuple)): elif isinstance(v, (list, tuple)) and depth <= self.max_depth:
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) data[k]._data[field_name] = self._attach_objects(v, depth - 1, instance=instance, name=name)
elif isinstance(v, (dict, list, tuple)) and depth < self.max_depth: elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
data[k] = self._attach_objects(v, depth, instance=instance, name=name, get=get) data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name)
elif hasattr(v, 'id'): elif hasattr(v, 'id'):
data[k] = self.object_map.get(v.id, v) data[k] = self.object_map.get(v.id, v)
if instance and name: if instance and name:
observer = DataObserver(instance, name)
if is_list: if is_list:
return BaseList(data, instance=instance, name=name) return BaseList(data, observer)
return BaseDict(data, instance=instance, name=name) return BaseDict(data, observer)
depth += 1 depth += 1
return data return data

View File

@ -1,14 +1,13 @@
from mongoengine import signals from mongoengine import signals
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
ValidationError, BaseDict, BaseList, BaseDynamicField) BaseDict, BaseList, DataObserver)
from queryset import OperationError from queryset import OperationError
from connection import get_db from connection import get_db
import pymongo import pymongo
__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', __all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument',
'DynamicEmbeddedDocument', 'ValidationError', 'OperationError', 'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError']
'InvalidCollectionError']
class InvalidCollectionError(Exception): class InvalidCollectionError(Exception):
@ -250,20 +249,23 @@ class Document(BaseDocument):
self._data = dereference(self._data, max_depth) self._data = dereference(self._data, max_depth)
return self return self
def reload(self): def reload(self, max_depth=1):
"""Reloads all attributes from the database. """Reloads all attributes from the database.
.. versionadded:: 0.1.2 .. versionadded:: 0.1.2
.. versionchanged:: 0.6 Now chainable
""" """
id_field = self._meta['id_field'] id_field = self._meta['id_field']
obj = self.__class__.objects(**{id_field: self[id_field]}).first() obj = self.__class__.objects(
**{id_field: self[id_field]}
).first().select_related(max_depth=max_depth)
for field in self._fields: for field in self._fields:
setattr(self, field, self._reload(field, obj[field])) setattr(self, field, self._reload(field, obj[field]))
if self._dynamic: if self._dynamic:
for name in self._dynamic_fields.keys(): for name in self._dynamic_fields.keys():
setattr(self, name, self._reload(name, obj._data[name])) setattr(self, name, self._reload(name, obj._data[name]))
self._changed_fields = [] self._changed_fields = obj._changed_fields
return obj
def _reload(self, key, value): def _reload(self, key, value):
"""Used by :meth:`~mongoengine.Document.reload` to ensure the """Used by :meth:`~mongoengine.Document.reload` to ensure the
@ -271,10 +273,12 @@ class Document(BaseDocument):
""" """
if isinstance(value, BaseDict): if isinstance(value, BaseDict):
value = [(k, self._reload(k, v)) for k, v in value.items()] value = [(k, self._reload(k, v)) for k, v in value.items()]
value = BaseDict(value, instance=self, name=key) observer = DataObserver(self, key)
value = BaseDict(value, observer)
elif isinstance(value, BaseList): elif isinstance(value, BaseList):
value = [self._reload(key, v) for v in value] value = [self._reload(key, v) for v in value]
value = BaseList(value, instance=self, name=key) observer = DataObserver(self, key)
value = BaseList(value, observer)
elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)): elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)):
value._changed_fields = [] value._changed_fields = []
return value return value

View File

@ -1675,6 +1675,8 @@ class QuerySet(object):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
from dereference import dereference from dereference import dereference
# Make select related work the same for querysets
max_depth += 1
return dereference(self, max_depth=max_depth) return dereference(self, max_depth=max_depth)

View File

@ -1069,7 +1069,7 @@ class DocumentTest(unittest.TestCase):
doc.embedded_field = embedded_1 doc.embedded_field = embedded_1
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
doc.list_field.append(1) doc.list_field.append(1)
doc.dict_field['woot'] = "woot" doc.dict_field['woot'] = "woot"
doc.embedded_field.list_field.append(1) doc.embedded_field.list_field.append(1)
@ -1080,7 +1080,7 @@ class DocumentTest(unittest.TestCase):
'embedded_field.dict_field']) 'embedded_field.dict_field'])
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc._get_changed_fields(), []) self.assertEquals(doc._get_changed_fields(), [])
self.assertEquals(len(doc.list_field), 4) self.assertEquals(len(doc.list_field), 4)
self.assertEquals(len(doc.dict_field), 2) self.assertEquals(len(doc.dict_field), 2)
@ -1502,14 +1502,14 @@ class DocumentTest(unittest.TestCase):
self.assertEquals(doc._delta(), ({'embedded_field': embedded_delta}, {})) self.assertEquals(doc._delta(), ({'embedded_field': embedded_delta}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
doc.embedded_field.dict_field = {} doc.embedded_field.dict_field = {}
self.assertEquals(doc._get_changed_fields(), ['embedded_field.dict_field']) self.assertEquals(doc._get_changed_fields(), ['embedded_field.dict_field'])
self.assertEquals(doc.embedded_field._delta(), ({}, {'dict_field': 1})) self.assertEquals(doc.embedded_field._delta(), ({}, {'dict_field': 1}))
self.assertEquals(doc._delta(), ({}, {'embedded_field.dict_field': 1})) self.assertEquals(doc._delta(), ({}, {'embedded_field.dict_field': 1}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.dict_field, {}) self.assertEquals(doc.embedded_field.dict_field, {})
doc.embedded_field.list_field = [] doc.embedded_field.list_field = []
@ -1517,7 +1517,7 @@ class DocumentTest(unittest.TestCase):
self.assertEquals(doc.embedded_field._delta(), ({}, {'list_field': 1})) self.assertEquals(doc.embedded_field._delta(), ({}, {'list_field': 1}))
self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field': 1})) self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field': 1}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field, []) self.assertEquals(doc.embedded_field.list_field, [])
embedded_2 = Embedded() embedded_2 = Embedded()
@ -1550,7 +1550,7 @@ class DocumentTest(unittest.TestCase):
}] }]
}, {})) }, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[0], '1') self.assertEquals(doc.embedded_field.list_field[0], '1')
self.assertEquals(doc.embedded_field.list_field[1], 2) self.assertEquals(doc.embedded_field.list_field[1], 2)
@ -1562,7 +1562,7 @@ class DocumentTest(unittest.TestCase):
self.assertEquals(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) self.assertEquals(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {}))
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {})) self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world') self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world')
# Test multiple assignments # Test multiple assignments
@ -1587,40 +1587,40 @@ class DocumentTest(unittest.TestCase):
'dict_field': {'hello': 'world'}} 'dict_field': {'hello': 'world'}}
]}, {})) ]}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world') self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world')
# Test list native methods # Test list native methods
doc.embedded_field.list_field[2].list_field.pop(0) doc.embedded_field.list_field[2].list_field.pop(0)
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {})) self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
doc.embedded_field.list_field[2].list_field.append(1) doc.embedded_field.list_field[2].list_field.append(1)
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {})) self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1])
doc.embedded_field.list_field[2].list_field.sort() doc.embedded_field.list_field[2].list_field.sort()
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}])
del(doc.embedded_field.list_field[2].list_field[2]['hello']) del(doc.embedded_field.list_field[2].list_field[2]['hello'])
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
del(doc.embedded_field.list_field[2].list_field) del(doc.embedded_field.list_field[2].list_field)
self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
doc.dict_field['Embedded'] = embedded_1 doc.dict_field['Embedded'] = embedded_1
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
doc.dict_field['Embedded'].string_field = 'Hello World' doc.dict_field['Embedded'].string_field = 'Hello World'
self.assertEquals(doc._get_changed_fields(), ['dict_field.Embedded.string_field']) self.assertEquals(doc._get_changed_fields(), ['dict_field.Embedded.string_field'])
@ -1684,7 +1684,7 @@ class DocumentTest(unittest.TestCase):
doc.dict_field = {'hello': 'world'} doc.dict_field = {'hello': 'world'}
doc.list_field = ['1', 2, {'hello': 'world'}] doc.list_field = ['1', 2, {'hello': 'world'}]
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.string_field, 'hello') self.assertEquals(doc.string_field, 'hello')
self.assertEquals(doc.int_field, 1) self.assertEquals(doc.int_field, 1)
@ -1735,14 +1735,14 @@ class DocumentTest(unittest.TestCase):
self.assertEquals(doc._delta(), ({'db_embedded_field': embedded_delta}, {})) self.assertEquals(doc._delta(), ({'db_embedded_field': embedded_delta}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
doc.embedded_field.dict_field = {} doc.embedded_field.dict_field = {}
self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_dict_field']) self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_dict_field'])
self.assertEquals(doc.embedded_field._delta(), ({}, {'db_dict_field': 1})) self.assertEquals(doc.embedded_field._delta(), ({}, {'db_dict_field': 1}))
self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_dict_field': 1})) self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_dict_field': 1}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.dict_field, {}) self.assertEquals(doc.embedded_field.dict_field, {})
doc.embedded_field.list_field = [] doc.embedded_field.list_field = []
@ -1750,7 +1750,7 @@ class DocumentTest(unittest.TestCase):
self.assertEquals(doc.embedded_field._delta(), ({}, {'db_list_field': 1})) self.assertEquals(doc.embedded_field._delta(), ({}, {'db_list_field': 1}))
self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field': 1})) self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field': 1}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field, []) self.assertEquals(doc.embedded_field.list_field, [])
embedded_2 = Embedded() embedded_2 = Embedded()
@ -1783,7 +1783,7 @@ class DocumentTest(unittest.TestCase):
}] }]
}, {})) }, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[0], '1') self.assertEquals(doc.embedded_field.list_field[0], '1')
self.assertEquals(doc.embedded_field.list_field[1], 2) self.assertEquals(doc.embedded_field.list_field[1], 2)
@ -1795,7 +1795,7 @@ class DocumentTest(unittest.TestCase):
self.assertEquals(doc.embedded_field._delta(), ({'db_list_field.2.db_string_field': 'world'}, {})) self.assertEquals(doc.embedded_field._delta(), ({'db_list_field.2.db_string_field': 'world'}, {}))
self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, {})) self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world') self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world')
# Test multiple assignments # Test multiple assignments
@ -1820,30 +1820,30 @@ class DocumentTest(unittest.TestCase):
'db_dict_field': {'hello': 'world'}} 'db_dict_field': {'hello': 'world'}}
]}, {})) ]}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world') self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world')
# Test list native methods # Test list native methods
doc.embedded_field.list_field[2].list_field.pop(0) doc.embedded_field.list_field[2].list_field.pop(0)
self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}]}, {})) self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}]}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
doc.embedded_field.list_field[2].list_field.append(1) doc.embedded_field.list_field[2].list_field.append(1)
self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}, 1]}, {})) self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}, 1]}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1])
doc.embedded_field.list_field[2].list_field.sort() doc.embedded_field.list_field[2].list_field.sort()
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}])
del(doc.embedded_field.list_field[2].list_field[2]['hello']) del(doc.embedded_field.list_field[2].list_field[2]['hello'])
self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [1, 2, {}]}, {})) self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [1, 2, {}]}, {}))
doc.save() doc.save()
doc.reload() doc = doc.reload(10)
del(doc.embedded_field.list_field[2].list_field) del(doc.embedded_field.list_field[2].list_field)
self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1})) self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1}))
@ -2344,7 +2344,7 @@ class DocumentTest(unittest.TestCase):
resurrected.string = "Two" resurrected.string = "Two"
resurrected.save() resurrected.save()
pickle_doc.reload() pickle_doc = pickle_doc.reload()
self.assertEquals(resurrected, pickle_doc) self.assertEquals(resurrected, pickle_doc)
def test_throw_invalid_document_error(self): def test_throw_invalid_document_error(self):