Added delta tracking to documents.

All saves on exisiting items do set / unset operations only on changed fields.
* Note lists and dicts generally do set operations for things like pop() del[key]
  As there is no easy map to unset and explicitly matches the new list / dict

fixes #18
This commit is contained in:
Ross Lawley 2011-06-10 17:22:05 +01:00
parent ea35fb1c54
commit 0ed79a839d
9 changed files with 552 additions and 65 deletions

View File

@ -5,6 +5,7 @@ Changelog
Changes in dev
==============
- Added delta tracking now only sets / unsets explicitly changed fields
- Fixed saving so sets updated values rather than overwrites
- Added ComplexDateTimeField - Handles datetimes correctly with microseconds
- Added ComplexBaseField - for improved flexibility and performance

View File

@ -18,10 +18,21 @@ attribute syntax::
Saving and deleting documents
=============================
To save the document to the database, call the
:meth:`~mongoengine.Document.save` method. If the document does not exist in
the database, it will be created. If it does already exist, it will be
updated.
MongoEngine tracks changes to documents to provide efficient saving. To save
the document to the database, call the :meth:`~mongoengine.Document.save` method.
If the document does not exist in the database, it will be created. If it does
already exist, then any changes will be updated atomically. For example::
>>> page = Page(title="Test Page")
>>> page.save() # Performs an insert
>>> page.title = "My Page"
>>> page.save() # Performs an atomic set on the title field.
.. note::
Changes to documents are tracked and on the whole perform `set` operations.
* ``list_field.pop(0)`` - *sets* the resulting list
* ``del(list_field)`` - *unsets* whole list
To delete a document, call the :meth:`~mongoengine.Document.delete` method.
Note that this will only work if the document exists in the database and has a

View File

@ -4,6 +4,7 @@ from queryset import DO_NOTHING
from mongoengine import signals
import weakref
import sys
import pymongo
import pymongo.objectid
@ -86,16 +87,19 @@ class BaseField(object):
# Allow callable default values
if callable(value):
value = value()
# Convert lists / values so we can watch for any changes on them
if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
value = BaseList(value, instance=instance, name=self.name)
elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, instance=instance, name=self.name)
return value
def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document.
"""
key = self.name
instance._data[key] = value
# If the field set is in the _present_fields list add it so we can track
if hasattr(instance, '_present_fields') and key and key not in instance._present_fields:
instance._present_fields.append(self.name)
instance._data[self.name] = value
instance._mark_as_changed(self.name)
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
@ -173,21 +177,27 @@ class ComplexBaseField(BaseField):
db = _get_db()
dbref = {}
collections = {}
for k, v in value_list.items():
dbref[k] = v
for k,v in value_list.items():
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
# direct reference (DBRef)
collections.setdefault(v.collection, []).append((k, v))
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
collections.setdefault(v.collection, []).append((k,v))
elif isinstance(v, (dict, pymongo.son.SON)):
if '_ref' in v:
# generic reference
collection = get_document(v['_cls'])._meta['collection']
collections.setdefault(collection, []).append((k, v))
collections.setdefault(collection, []).append((k,v))
else:
# Use BaseDict so can watch any changes
dbref[k] = BaseDict(v, instance=instance, name=self.name)
else:
dbref[k] = v
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = {}
for k, v in dbrefs:
for k,v in dbrefs:
if isinstance(v, (pymongo.dbref.DBRef)):
# direct reference (DBRef), has no _cls information
id_map[v.id] = (k, None)
@ -203,7 +213,9 @@ class ComplexBaseField(BaseField):
dbref[key] = doc_cls._from_son(ref)
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
dbref = BaseList([v for k,v in sorted(dbref.items(), key=itemgetter(0))], instance=instance, name=self.name)
else:
dbref = BaseDict(dbref, instance=instance, name=self.name)
instance._data[self.name] = dbref
return super(ComplexBaseField, self).__get__(instance, owner)
@ -714,7 +726,7 @@ class BaseDocument(object):
self._meta.get('allow_inheritance', True) == False):
data['_cls'] = self._class_name
data['_types'] = self._superclasses.keys() + [self._class_name]
if data.has_key('_id') and data['_id'] is None:
if '_id' in data and data['_id'] is None:
del data['_id']
return data
@ -751,9 +763,71 @@ class BaseDocument(object):
else field.to_python(value))
obj = cls(**data)
obj._present_fields = present_fields
obj._changed_fields = []
return obj
def _mark_as_changed(self, key):
"""Marks a key as explicitly changed by the user
"""
if not key:
return
if hasattr(self, '_changed_fields') and key not in self._changed_fields:
self._changed_fields.append(key)
def _get_changed_fields(self, key=''):
"""Returns a list of all fields that have explicitly been changed.
"""
from mongoengine import EmbeddedDocument
_changed_fields = []
_changed_fields += getattr(self, '_changed_fields', [])
for field_name in self._fields:
key = '%s.' % field_name
field = getattr(self, field_name, None)
if isinstance(field, EmbeddedDocument): # Grab all embedded fields that have been changed
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k]
elif isinstance(field, (list, tuple)): # Loop list fields as they contain documents
for index, value in enumerate(field):
if not hasattr(value, '_get_changed_fields'):
continue
list_key = "%s%s." % (key, index)
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k]
return _changed_fields
def _delta(self):
"""Returns the delta (set, unset) of the changes for a document.
Gets any values that have been explicitly changed.
"""
# Handles cases where not loaded from_son but has _id
doc = self.to_mongo()
set_fields = self._get_changed_fields()
set_data = {}
unset_data = {}
if hasattr(self, '_changed_fields'):
set_data = {}
# Fetch each set item from its path
for path in set_fields:
parts = path.split('.')
d = doc
for p in parts:
if hasattr(d, '__getattr__'):
d = getattr(p, d)
elif p.isdigit():
d = d[int(p)]
else:
d = d.get(p)
set_data[path] = d
else:
set_data = doc
if '_id' in set_data:
del(set_data['_id'])
for k,v in set_data.items():
if not v:
del(set_data[k])
unset_data[k] = 1
return set_data, unset_data
def __eq__(self, other):
if isinstance(other, self.__class__) and hasattr(other, 'id'):
if self.id == other.id:
@ -764,13 +838,112 @@ class BaseDocument(object):
return not self.__eq__(other)
def __hash__(self):
""" For list, dic key """
""" For list, dict key """
if self.pk is None:
# For new object
return super(BaseDocument,self).__hash__()
else:
return hash(self.pk)
class BaseList(list):
"""A special list so we can watch any changes
"""
def __init__(self, list_items, instance, name):
self.instance = weakref.proxy(instance)
self.name = name
super(BaseList, self).__init__(list_items)
def __setitem__(self, *args, **kwargs):
if hasattr(self, 'instance') and hasattr(self, 'name'):
self.instance._mark_as_changed(self.name)
super(BaseDict, self).__setitem__(*args, **kwargs)
def __delitem__(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
super(BaseList, self).__delitem__(*args, **kwargs)
def __delete__(self, *args, **kwargs):
if hasattr(self, 'instance') and hasattr(self, 'name'):
import ipdb; ipdb.set_trace()
self.instance._mark_as_changed(self.name)
delattr(self, 'instance')
delattr(self, 'name')
super(BaseDict, self).__delete__(*args, **kwargs)
def append(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
return super(BaseList, self).append(*args, **kwargs)
def extend(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
return super(BaseList, self).extend(*args, **kwargs)
def insert(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
return super(BaseList, self).insert(*args, **kwargs)
def pop(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
return super(BaseList, self).pop(*args, **kwargs)
def remove(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
return super(BaseList, self).remove(*args, **kwargs)
def reverse(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
return super(BaseList, self).reverse(*args, **kwargs)
def sort(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
return super(BaseList, self).sort(*args, **kwargs)
class BaseDict(dict):
"""A special dict so we can watch any changes
"""
def __init__(self, dict_items, instance, name):
self.instance = weakref.proxy(instance)
self.name = name
super(BaseDict, self).__init__(dict_items)
def __setitem__(self, *args, **kwargs):
if hasattr(self, 'instance') and hasattr(self, 'name'):
self.instance._mark_as_changed(self.name)
super(BaseDict, self).__setitem__(*args, **kwargs)
def __setattr__(self, *args, **kwargs):
if hasattr(self, 'instance') and hasattr(self, 'name'):
self.instance._mark_as_changed(self.name)
super(BaseDict, self).__setattr__(*args, **kwargs)
def __delete__(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
super(BaseDict, self).__delete__(*args, **kwargs)
def __delitem__(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
super(BaseDict, self).__delitem__(*args, **kwargs)
def __delattr__(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
super(BaseDict, self).__delattr__(*args, **kwargs)
def clear(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
super(BaseDict, self).clear(*args, **kwargs)
def pop(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
super(BaseDict, self).clear(*args, **kwargs)
def popitem(self, *args, **kwargs):
self.instance._mark_as_changed(self.name)
super(BaseDict, self).clear(*args, **kwargs)
if sys.version_info < (2, 5):
# Prior to Python 2.5, Exception was an old-style class
import types

View File

@ -1,12 +1,11 @@
from mongoengine import signals
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
ValidationError)
ValidationError, BaseDict, BaseList)
from queryset import OperationError
from connection import _get_db
import pymongo
__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError']
@ -19,6 +18,18 @@ class EmbeddedDocument(BaseDocument):
__metaclass__ = DocumentMetaclass
def __delattr__(self, *args, **kwargs):
"""Handle deletions of fields"""
field_name = args[0]
if field_name in self._fields:
default = self._fields[field_name].default
if callable(default):
default = default()
setattr(self, field_name, default)
else:
super(EmbeddedDocument, self).__delattr__(*args, **kwargs)
class Document(BaseDocument):
"""The base class used for defining the structure and properties of
@ -59,7 +70,6 @@ class Document(BaseDocument):
disabled by either setting types to False on the specific index or
by setting index_types to False on the meta dictionary for the document.
"""
__metaclass__ = TopLevelDocumentMetaclass
def save(self, safe=True, force_insert=False, validate=True, write_options=None):
@ -95,18 +105,15 @@ class Document(BaseDocument):
collection = self.__class__.objects._collection
if force_insert:
object_id = collection.insert(doc, safe=safe, **write_options)
elif '_id' in doc:
# Perform a set rather than a save - this will only save set fields
object_id = doc.pop('_id')
collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options)
# Find and unset any fields explicitly set to None
if hasattr(self, '_present_fields'):
removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id'])
if created:
object_id = collection.save(doc, safe=safe, **write_options)
else:
object_id = doc['_id']
updates, removals = self._delta()
if updates:
collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options)
if removals:
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
else:
object_id = collection.save(doc, safe=safe, **write_options)
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
if u'duplicate key' in unicode(err):
@ -114,7 +121,7 @@ class Document(BaseDocument):
raise OperationError(message % unicode(err))
id_field = self._meta['id_field']
self[id_field] = self._fields[id_field].to_python(object_id)
self._changed_fields = []
signals.post_save.send(self, created=created)
def delete(self, safe=False):
@ -135,14 +142,6 @@ class Document(BaseDocument):
signals.post_delete.send(self)
@classmethod
def register_delete_rule(cls, document_cls, field_name, rule):
"""This method registers the delete rules to apply when removing this
object.
"""
cls._meta['delete_rules'][(document_cls, field_name)] = rule
def reload(self):
"""Reloads all attributes from the database.
@ -151,7 +150,29 @@ class Document(BaseDocument):
id_field = self._meta['id_field']
obj = self.__class__.objects(**{id_field: self[id_field]}).first()
for field in self._fields:
setattr(self, field, obj[field])
setattr(self, field, self._reload(field, obj[field]))
self._changed_fields = []
def _reload(self, key, value):
"""Used by :meth:`~mongoengine.Document.reload` to ensure the
correct instance is linked to self.
"""
if isinstance(value, BaseDict):
value = [(k, self._reload(k,v)) for k,v in value.items()]
value = BaseDict(value, instance=self, name=key)
elif isinstance(value, BaseList):
value = [self._reload(key, v) for v in value]
value = BaseList(value, instance=self, name=key)
elif isinstance(value, EmbeddedDocument):
value._changed_fields = []
return value
@classmethod
def register_delete_rule(cls, document_cls, field_name, rule):
"""This method registers the delete rules to apply when removing this
object.
"""
cls._meta['delete_rules'][(document_cls, field_name)] = rule
@classmethod
def drop_collection(cls):

View File

@ -347,9 +347,9 @@ class ComplexDateTimeField(StringField):
return datetime.datetime.now()
return self._convert_from_string(data)
def __set__(self, obj, val):
data = self._convert_from_datetime(val)
return super(ComplexDateTimeField, self).__set__(obj, data)
def __set__(self, instance, value):
value = self._convert_from_datetime(value)
return super(ComplexDateTimeField, self).__set__(instance, value)
def validate(self, value):
if not isinstance(value, datetime.datetime):
@ -686,11 +686,13 @@ class GridFSProxy(object):
.. versionadded:: 0.4
"""
def __init__(self, grid_id=None):
def __init__(self, grid_id=None, key=None, instance=None):
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
self.newfile = None # Used for partial writes
self.grid_id = grid_id # Store GridFS id for file
self.gridout = None
self.key = key
self.instance = instance
def __getattr__(self, name):
obj = self.get()
@ -723,6 +725,7 @@ class GridFSProxy(object):
raise GridFSError('This document already has a file. Either delete '
'it or call replace to overwrite it')
self.grid_id = self.fs.put(file_obj, **kwargs)
self._mark_as_changed()
def write(self, string):
if self.grid_id:
@ -750,6 +753,12 @@ class GridFSProxy(object):
self.fs.delete(self.grid_id)
self.grid_id = None
self.gridout = None
self._mark_as_changed()
def _mark_as_changed(self):
"""Inform the instance that `self.key` has been changed"""
if self.instance:
self.instance._mark_as_changed(self.key)
def replace(self, file_obj, **kwargs):
self.delete()
@ -777,10 +786,14 @@ class FileField(BaseField):
grid_file = instance._data.get(self.name)
self.grid_file = grid_file
if self.grid_file:
if not self.grid_file.key:
self.grid_file.key = self.name
self.grid_file.instance = instance
return self.grid_file
return GridFSProxy()
return GridFSProxy(key=self.name, instance=instance)
def __set__(self, instance, value):
key = self.name
if isinstance(value, file) or isinstance(value, str):
# using "FileField() = file/string" notation
grid_file = instance._data.get(self.name)
@ -794,10 +807,12 @@ class FileField(BaseField):
grid_file.put(value)
else:
# Create a new proxy object as we don't already have one
instance._data[self.name] = GridFSProxy()
instance._data[self.name].put(value)
instance._data[key] = GridFSProxy(key=key, instance=instance)
instance._data[key].put(value)
else:
instance._data[self.name] = value
instance._data[key] = value
instance._mark_as_changed(key)
def to_mongo(self, value):
# Store the GridFS file id in MongoDB

View File

@ -281,9 +281,7 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members]
self.assertEqual(q, 1)
for k, m in group_obj.members.iteritems():
self.assertTrue('User' in m.__class__.__name__)
self.assertEqual(group_obj.members, {})
UserA.drop_collection()
UserB.drop_collection()

View File

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import unittest

View File

@ -2,6 +2,7 @@ import unittest
from datetime import datetime
import pymongo
import pickle
import weakref
from mongoengine import *
from mongoengine.base import BaseField
@ -11,6 +12,7 @@ from mongoengine.connection import _get_db
class PickleEmbedded(EmbeddedDocument):
date = DateTimeField(default=datetime.now)
class PickleTest(Document):
number = IntField()
string = StringField()
@ -717,6 +719,47 @@ class DocumentTest(unittest.TestCase):
self.assertEqual(person.name, "Mr Test User")
self.assertEqual(person.age, 21)
def test_reload_referencing(self):
"""Ensures reloading updates weakrefs correctly
"""
class Embedded(EmbeddedDocument):
dict_field = DictField()
list_field = ListField()
class Doc(Document):
dict_field = DictField()
list_field = ListField()
embedded_field = EmbeddedDocumentField(Embedded)
Doc.drop_collection
doc = Doc()
doc.dict_field = {'hello': 'world'}
doc.list_field = ['1', 2, {'hello': 'world'}]
embedded_1 = Embedded()
embedded_1.dict_field = {'hello': 'world'}
embedded_1.list_field = ['1', 2, {'hello': 'world'}]
doc.embedded_field = embedded_1
doc.save()
doc.reload()
doc.list_field.append(1)
doc.dict_field['woot'] = "woot"
doc.embedded_field.list_field.append(1)
doc.embedded_field.dict_field['woot'] = "woot"
self.assertEquals(doc._get_changed_fields(), [
'list_field', 'dict_field', 'embedded_field.list_field',
'embedded_field.dict_field'])
doc.save()
doc.reload()
self.assertEquals(doc._get_changed_fields(), [])
self.assertEquals(len(doc.list_field), 4)
self.assertEquals(len(doc.dict_field), 2)
self.assertEquals(len(doc.embedded_field.list_field), 4)
self.assertEquals(len(doc.embedded_field.dict_field), 2)
def test_dictionary_access(self):
"""Ensure that dictionary-style field access works properly.
"""
@ -873,6 +916,197 @@ class DocumentTest(unittest.TestCase):
self.assertEqual(person.name, None)
self.assertEqual(person.age, None)
def test_delta(self):
class Doc(Document):
string_field = StringField()
int_field = IntField()
dict_field = DictField()
list_field = ListField()
Doc.drop_collection
doc = Doc()
doc.save()
doc = Doc.objects.first()
self.assertEquals(doc._get_changed_fields(), [])
self.assertEquals(doc._delta(), ({}, {}))
doc.string_field = 'hello'
self.assertEquals(doc._delta(), ({'string_field': 'hello'}, {}))
doc._changed_fields = []
doc.int_field = 1
self.assertEquals(doc._delta(), ({'int_field': 1}, {}))
doc._changed_fields = []
dict_value = {'hello': 'world', 'ping': 'pong'}
doc.dict_field = dict_value
self.assertEquals(doc._delta(), ({'dict_field': dict_value}, {}))
doc._changed_fields = []
list_value = ['1', 2, {'hello': 'world'}]
doc.list_field = list_value
self.assertEquals(doc._delta(), ({'list_field': list_value}, {}))
# Test unsetting
doc._changed_fields = []
doc._unset_fields = []
doc.dict_field = {}
self.assertEquals(doc._delta(), ({}, {'dict_field': 1}))
doc._changed_fields = []
doc._unset_fields = {}
doc.list_field = []
self.assertEquals(doc._delta(), ({}, {'list_field': 1}))
def test_delta_recursive(self):
class Embedded(EmbeddedDocument):
string_field = StringField()
int_field = IntField()
dict_field = DictField()
list_field = ListField()
class Doc(Document):
string_field = StringField()
int_field = IntField()
dict_field = DictField()
list_field = ListField()
embedded_field = EmbeddedDocumentField(Embedded)
Doc.drop_collection
doc = Doc()
doc.save()
doc = Doc.objects.first()
self.assertEquals(doc._get_changed_fields(), [])
self.assertEquals(doc._delta(), ({}, {}))
embedded_1 = Embedded()
embedded_1.string_field = 'hello'
embedded_1.int_field = 1
embedded_1.dict_field = {'hello': 'world'}
embedded_1.list_field = ['1', 2, {'hello': 'world'}]
doc.embedded_field = embedded_1
embedded_delta = {
'_types': ['Embedded'],
'_cls': 'Embedded',
'string_field': 'hello',
'int_field': 1,
'dict_field': {'hello': 'world'},
'list_field': ['1', 2, {'hello': 'world'}]
}
self.assertEquals(doc.embedded_field._delta(), (embedded_delta, {}))
self.assertEquals(doc._delta(), ({'embedded_field': embedded_delta}, {}))
doc.save()
doc.reload()
doc.embedded_field.dict_field = {}
self.assertEquals(doc.embedded_field._delta(), ({}, {'dict_field': 1}))
self.assertEquals(doc._delta(), ({}, {'embedded_field.dict_field': 1}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.dict_field, {})
doc.embedded_field.list_field = []
self.assertEquals(doc.embedded_field._delta(), ({}, {'list_field': 1}))
self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field': 1}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field, [])
embedded_2 = Embedded()
embedded_2.string_field = 'hello'
embedded_2.int_field = 1
embedded_2.dict_field = {'hello': 'world'}
embedded_2.list_field = ['1', 2, {'hello': 'world'}]
doc.embedded_field.list_field = ['1', 2, embedded_2]
self.assertEquals(doc.embedded_field._delta(), ({
'list_field': ['1', 2, {
'_cls': 'Embedded',
'_types': ['Embedded'],
'string_field': 'hello',
'dict_field': {'hello': 'world'},
'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}],
}]
}, {}))
self.assertEquals(doc._delta(), ({
'embedded_field.list_field': ['1', 2, {
'_cls': 'Embedded',
'_types': ['Embedded'],
'string_field': 'hello',
'dict_field': {'hello': 'world'},
'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}],
}]
}, {}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field[0], '1')
self.assertEquals(doc.embedded_field.list_field[1], 2)
for k in doc.embedded_field.list_field[2]._fields:
self.assertEquals(doc.embedded_field.list_field[2][k], embedded_2[k])
doc.embedded_field.list_field[2].string_field = 'world'
self.assertEquals(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {}))
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world')
# Test list native methods
doc.embedded_field.list_field[2].list_field.pop(0)
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {}))
doc.save()
doc.reload()
doc.embedded_field.list_field[2].list_field.append(1)
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1])
doc.embedded_field.list_field[2].list_field.sort()
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}])
del(doc.embedded_field.list_field[2].list_field[2]['hello'])
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {}))
doc.save()
doc.reload()
del(doc.embedded_field.list_field[2].list_field)
self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1}))
def test_save_only_changed_fields(self):
"""Ensure save only sets / unsets changed fields
"""
# Create person object and save it to the database
person = self.Person(name='Test User', age=30)
person.save()
person.reload()
same_person = self.Person.objects.get()
person.age = 21
same_person.name = 'User'
person.save()
same_person.save()
person = self.Person.objects.get()
self.assertEquals(person.name, 'User')
self.assertEquals(person.age, 21)
def test_delete(self):
"""Ensure that document may be deleted using the delete method.
"""
@ -978,12 +1212,19 @@ class DocumentTest(unittest.TestCase):
promoted_employee.details.position = 'Senior Developer'
promoted_employee.save()
collection = self.db[self.Person._meta['collection']]
employee_obj = collection.find_one({'name': 'Test Employee'})
self.assertEqual(employee_obj['name'], 'Test Employee')
self.assertEqual(employee_obj['age'], 50)
promoted_employee.reload()
self.assertEqual(promoted_employee.name, 'Test Employee')
self.assertEqual(promoted_employee.age, 50)
# Ensure that the 'details' embedded object saved correctly
self.assertEqual(employee_obj['details']['position'], 'Senior Developer')
self.assertEqual(promoted_employee.details.position, 'Senior Developer')
# Test removal
promoted_employee.details = None
promoted_employee.save()
promoted_employee.reload()
self.assertEqual(promoted_employee.details, None)
def test_save_reference(self):
"""Ensure that a document reference field may be saved in the database.

View File

@ -843,6 +843,7 @@ class FieldTest(unittest.TestCase):
name = StringField()
children = ListField(EmbeddedDocumentField('self'))
Tree.drop_collection
tree = Tree(name="Tree")
first_child = TreeNode(name="Child 1")
@ -853,15 +854,42 @@ class FieldTest(unittest.TestCase):
third_child = TreeNode(name="Child 3")
first_child.children.append(third_child)
tree.save()
tree_obj = Tree.objects.first()
self.assertEqual(len(tree.children), 1)
self.assertEqual(tree.children[0].name, first_child.name)
self.assertEqual(tree.children[0].children[0].name, second_child.name)
self.assertEqual(tree.children[0].children[1].name, third_child.name)
# Test updating
tree.children[0].name = 'I am Child 1'
tree.children[0].children[0].name = 'I am Child 2'
tree.children[0].children[1].name = 'I am Child 3'
tree.save()
self.assertEqual(tree.children[0].name, 'I am Child 1')
self.assertEqual(tree.children[0].children[0].name, 'I am Child 2')
self.assertEqual(tree.children[0].children[1].name, 'I am Child 3')
# Test removal
self.assertEqual(len(tree.children[0].children), 2)
del(tree.children[0].children[1])
tree.save()
self.assertEqual(len(tree.children[0].children), 1)
tree.children[0].children.pop(0)
tree.save()
self.assertEqual(len(tree.children[0].children), 0)
self.assertEqual(tree.children[0].children, [])
tree.children[0].children.insert(0, third_child)
tree.children[0].children.insert(0, second_child)
tree.save()
self.assertEqual(len(tree.children[0].children), 2)
self.assertEqual(tree.children[0].children[0].name, second_child.name)
self.assertEqual(tree.children[0].children[1].name, third_child.name)
def test_undefined_reference(self):
"""Ensure that ReferenceFields may reference undefined Documents.
"""