diff --git a/docs/changelog.rst b/docs/changelog.rst index ecd7ef57..54efb4ff 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added delta tracking now only sets / unsets explicitly changed fields - Fixed saving so sets updated values rather than overwrites - Added ComplexDateTimeField - Handles datetimes correctly with microseconds - Added ComplexBaseField - for improved flexibility and performance diff --git a/docs/guide/document-instances.rst b/docs/guide/document-instances.rst index 7b5d165b..aeed7cdb 100644 --- a/docs/guide/document-instances.rst +++ b/docs/guide/document-instances.rst @@ -18,10 +18,21 @@ attribute syntax:: Saving and deleting documents ============================= -To save the document to the database, call the -:meth:`~mongoengine.Document.save` method. If the document does not exist in -the database, it will be created. If it does already exist, it will be -updated. +MongoEngine tracks changes to documents to provide efficient saving. To save +the document to the database, call the :meth:`~mongoengine.Document.save` method. +If the document does not exist in the database, it will be created. If it does +already exist, then any changes will be updated atomically. For example:: + + >>> page = Page(title="Test Page") + >>> page.save() # Performs an insert + >>> page.title = "My Page" + >>> page.save() # Performs an atomic set on the title field. + +.. note:: + Changes to documents are tracked and on the whole perform `set` operations. + + * ``list_field.pop(0)`` - *sets* the resulting list + * ``del(list_field)`` - *unsets* whole list To delete a document, call the :meth:`~mongoengine.Document.delete` method. Note that this will only work if the document exists in the database and has a diff --git a/mongoengine/base.py b/mongoengine/base.py index 592a6784..292184ef 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -4,6 +4,7 @@ from queryset import DO_NOTHING from mongoengine import signals +import weakref import sys import pymongo import pymongo.objectid @@ -86,16 +87,19 @@ class BaseField(object): # Allow callable default values if callable(value): value = value() + + # Convert lists / values so we can watch for any changes on them + if isinstance(value, (list, tuple)) and not isinstance(value, BaseList): + value = BaseList(value, instance=instance, name=self.name) + elif isinstance(value, dict) and not isinstance(value, BaseDict): + value = BaseDict(value, instance=instance, name=self.name) return value def __set__(self, instance, value): """Descriptor for assigning a value to a field in a document. """ - key = self.name - instance._data[key] = value - # If the field set is in the _present_fields list add it so we can track - if hasattr(instance, '_present_fields') and key and key not in instance._present_fields: - instance._present_fields.append(self.name) + instance._data[self.name] = value + instance._mark_as_changed(self.name) def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. @@ -173,21 +177,27 @@ class ComplexBaseField(BaseField): db = _get_db() dbref = {} collections = {} - for k, v in value_list.items(): - dbref[k] = v + for k,v in value_list.items(): + # Save any DBRefs if isinstance(v, (pymongo.dbref.DBRef)): # direct reference (DBRef) - collections.setdefault(v.collection, []).append((k, v)) - elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: - # generic reference - collection = get_document(v['_cls'])._meta['collection'] - collections.setdefault(collection, []).append((k, v)) + collections.setdefault(v.collection, []).append((k,v)) + elif isinstance(v, (dict, pymongo.son.SON)): + if '_ref' in v: + # generic reference + collection = get_document(v['_cls'])._meta['collection'] + collections.setdefault(collection, []).append((k,v)) + else: + # Use BaseDict so can watch any changes + dbref[k] = BaseDict(v, instance=instance, name=self.name) + else: + dbref[k] = v # For each collection get the references for collection, dbrefs in collections.items(): id_map = {} - for k, v in dbrefs: + for k,v in dbrefs: if isinstance(v, (pymongo.dbref.DBRef)): # direct reference (DBRef), has no _cls information id_map[v.id] = (k, None) @@ -203,7 +213,9 @@ class ComplexBaseField(BaseField): dbref[key] = doc_cls._from_son(ref) if is_list: - dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + dbref = BaseList([v for k,v in sorted(dbref.items(), key=itemgetter(0))], instance=instance, name=self.name) + else: + dbref = BaseDict(dbref, instance=instance, name=self.name) instance._data[self.name] = dbref return super(ComplexBaseField, self).__get__(instance, owner) @@ -304,7 +316,7 @@ class ComplexBaseField(BaseField): if hasattr(value, 'iteritems'): [self.field.validate(v) for k,v in value.iteritems()] else: - [self.field.validate(v) for v in value] + [self.field.validate(v) for v in value] except Exception, err: raise ValidationError('Invalid %s item (%s)' % ( self.field.__class__.__name__, str(v))) @@ -714,7 +726,7 @@ class BaseDocument(object): self._meta.get('allow_inheritance', True) == False): data['_cls'] = self._class_name data['_types'] = self._superclasses.keys() + [self._class_name] - if data.has_key('_id') and data['_id'] is None: + if '_id' in data and data['_id'] is None: del data['_id'] return data @@ -751,9 +763,71 @@ class BaseDocument(object): else field.to_python(value)) obj = cls(**data) - obj._present_fields = present_fields + obj._changed_fields = [] return obj + def _mark_as_changed(self, key): + """Marks a key as explicitly changed by the user + """ + if not key: + return + if hasattr(self, '_changed_fields') and key not in self._changed_fields: + self._changed_fields.append(key) + + def _get_changed_fields(self, key=''): + """Returns a list of all fields that have explicitly been changed. + """ + from mongoengine import EmbeddedDocument + _changed_fields = [] + _changed_fields += getattr(self, '_changed_fields', []) + + for field_name in self._fields: + key = '%s.' % field_name + field = getattr(self, field_name, None) + if isinstance(field, EmbeddedDocument): # Grab all embedded fields that have been changed + _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k] + elif isinstance(field, (list, tuple)): # Loop list fields as they contain documents + for index, value in enumerate(field): + if not hasattr(value, '_get_changed_fields'): + continue + list_key = "%s%s." % (key, index) + _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k] + return _changed_fields + + def _delta(self): + """Returns the delta (set, unset) of the changes for a document. + Gets any values that have been explicitly changed. + """ + # Handles cases where not loaded from_son but has _id + doc = self.to_mongo() + set_fields = self._get_changed_fields() + set_data = {} + unset_data = {} + if hasattr(self, '_changed_fields'): + set_data = {} + # Fetch each set item from its path + for path in set_fields: + parts = path.split('.') + d = doc + for p in parts: + if hasattr(d, '__getattr__'): + d = getattr(p, d) + elif p.isdigit(): + d = d[int(p)] + else: + d = d.get(p) + set_data[path] = d + else: + set_data = doc + if '_id' in set_data: + del(set_data['_id']) + + for k,v in set_data.items(): + if not v: + del(set_data[k]) + unset_data[k] = 1 + return set_data, unset_data + def __eq__(self, other): if isinstance(other, self.__class__) and hasattr(other, 'id'): if self.id == other.id: @@ -764,13 +838,112 @@ class BaseDocument(object): return not self.__eq__(other) def __hash__(self): - """ For list, dic key """ + """ For list, dict key """ if self.pk is None: # For new object return super(BaseDocument,self).__hash__() else: return hash(self.pk) + +class BaseList(list): + """A special list so we can watch any changes + """ + + def __init__(self, list_items, instance, name): + self.instance = weakref.proxy(instance) + self.name = name + super(BaseList, self).__init__(list_items) + + def __setitem__(self, *args, **kwargs): + if hasattr(self, 'instance') and hasattr(self, 'name'): + self.instance._mark_as_changed(self.name) + super(BaseDict, self).__setitem__(*args, **kwargs) + + def __delitem__(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + super(BaseList, self).__delitem__(*args, **kwargs) + + def __delete__(self, *args, **kwargs): + if hasattr(self, 'instance') and hasattr(self, 'name'): + import ipdb; ipdb.set_trace() + self.instance._mark_as_changed(self.name) + delattr(self, 'instance') + delattr(self, 'name') + super(BaseDict, self).__delete__(*args, **kwargs) + + def append(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + return super(BaseList, self).append(*args, **kwargs) + + def extend(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + return super(BaseList, self).extend(*args, **kwargs) + + def insert(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + return super(BaseList, self).insert(*args, **kwargs) + + def pop(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + return super(BaseList, self).pop(*args, **kwargs) + + def remove(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + return super(BaseList, self).remove(*args, **kwargs) + + def reverse(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + return super(BaseList, self).reverse(*args, **kwargs) + + def sort(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + return super(BaseList, self).sort(*args, **kwargs) + + +class BaseDict(dict): + """A special dict so we can watch any changes + """ + + def __init__(self, dict_items, instance, name): + self.instance = weakref.proxy(instance) + self.name = name + super(BaseDict, self).__init__(dict_items) + + def __setitem__(self, *args, **kwargs): + if hasattr(self, 'instance') and hasattr(self, 'name'): + self.instance._mark_as_changed(self.name) + super(BaseDict, self).__setitem__(*args, **kwargs) + + def __setattr__(self, *args, **kwargs): + if hasattr(self, 'instance') and hasattr(self, 'name'): + self.instance._mark_as_changed(self.name) + super(BaseDict, self).__setattr__(*args, **kwargs) + + def __delete__(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + super(BaseDict, self).__delete__(*args, **kwargs) + + def __delitem__(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + super(BaseDict, self).__delitem__(*args, **kwargs) + + def __delattr__(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + super(BaseDict, self).__delattr__(*args, **kwargs) + + def clear(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + super(BaseDict, self).clear(*args, **kwargs) + + def pop(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + super(BaseDict, self).clear(*args, **kwargs) + + def popitem(self, *args, **kwargs): + self.instance._mark_as_changed(self.name) + super(BaseDict, self).clear(*args, **kwargs) + if sys.version_info < (2, 5): # Prior to Python 2.5, Exception was an old-style class import types diff --git a/mongoengine/document.py b/mongoengine/document.py index e25bea06..2f40eec7 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,12 +1,11 @@ from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, - ValidationError) + ValidationError, BaseDict, BaseList) from queryset import OperationError from connection import _get_db import pymongo - __all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError'] @@ -19,6 +18,18 @@ class EmbeddedDocument(BaseDocument): __metaclass__ = DocumentMetaclass + def __delattr__(self, *args, **kwargs): + """Handle deletions of fields""" + field_name = args[0] + if field_name in self._fields: + default = self._fields[field_name].default + if callable(default): + default = default() + setattr(self, field_name, default) + else: + super(EmbeddedDocument, self).__delattr__(*args, **kwargs) + + class Document(BaseDocument): """The base class used for defining the structure and properties of @@ -59,7 +70,6 @@ class Document(BaseDocument): disabled by either setting types to False on the specific index or by setting index_types to False on the meta dictionary for the document. """ - __metaclass__ = TopLevelDocumentMetaclass def save(self, safe=True, force_insert=False, validate=True, write_options=None): @@ -95,18 +105,15 @@ class Document(BaseDocument): collection = self.__class__.objects._collection if force_insert: object_id = collection.insert(doc, safe=safe, **write_options) - elif '_id' in doc: - # Perform a set rather than a save - this will only save set fields - object_id = doc.pop('_id') - collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options) - - # Find and unset any fields explicitly set to None - if hasattr(self, '_present_fields'): - removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id']) - if removals: - collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) - else: + if created: object_id = collection.save(doc, safe=safe, **write_options) + else: + object_id = doc['_id'] + updates, removals = self._delta() + if updates: + collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options) + if removals: + collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' if u'duplicate key' in unicode(err): @@ -114,7 +121,7 @@ class Document(BaseDocument): raise OperationError(message % unicode(err)) id_field = self._meta['id_field'] self[id_field] = self._fields[id_field].to_python(object_id) - + self._changed_fields = [] signals.post_save.send(self, created=created) def delete(self, safe=False): @@ -135,14 +142,6 @@ class Document(BaseDocument): signals.post_delete.send(self) - @classmethod - def register_delete_rule(cls, document_cls, field_name, rule): - """This method registers the delete rules to apply when removing this - object. - """ - cls._meta['delete_rules'][(document_cls, field_name)] = rule - - def reload(self): """Reloads all attributes from the database. @@ -151,7 +150,29 @@ class Document(BaseDocument): id_field = self._meta['id_field'] obj = self.__class__.objects(**{id_field: self[id_field]}).first() for field in self._fields: - setattr(self, field, obj[field]) + setattr(self, field, self._reload(field, obj[field])) + self._changed_fields = [] + + def _reload(self, key, value): + """Used by :meth:`~mongoengine.Document.reload` to ensure the + correct instance is linked to self. + """ + if isinstance(value, BaseDict): + value = [(k, self._reload(k,v)) for k,v in value.items()] + value = BaseDict(value, instance=self, name=key) + elif isinstance(value, BaseList): + value = [self._reload(key, v) for v in value] + value = BaseList(value, instance=self, name=key) + elif isinstance(value, EmbeddedDocument): + value._changed_fields = [] + return value + + @classmethod + def register_delete_rule(cls, document_cls, field_name, rule): + """This method registers the delete rules to apply when removing this + object. + """ + cls._meta['delete_rules'][(document_cls, field_name)] = rule @classmethod def drop_collection(cls): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 967ce834..eeb4c2c0 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -347,9 +347,9 @@ class ComplexDateTimeField(StringField): return datetime.datetime.now() return self._convert_from_string(data) - def __set__(self, obj, val): - data = self._convert_from_datetime(val) - return super(ComplexDateTimeField, self).__set__(obj, data) + def __set__(self, instance, value): + value = self._convert_from_datetime(value) + return super(ComplexDateTimeField, self).__set__(instance, value) def validate(self, value): if not isinstance(value, datetime.datetime): @@ -686,11 +686,13 @@ class GridFSProxy(object): .. versionadded:: 0.4 """ - def __init__(self, grid_id=None): + def __init__(self, grid_id=None, key=None, instance=None): self.fs = gridfs.GridFS(_get_db()) # Filesystem instance self.newfile = None # Used for partial writes self.grid_id = grid_id # Store GridFS id for file self.gridout = None + self.key = key + self.instance = instance def __getattr__(self, name): obj = self.get() @@ -723,6 +725,7 @@ class GridFSProxy(object): raise GridFSError('This document already has a file. Either delete ' 'it or call replace to overwrite it') self.grid_id = self.fs.put(file_obj, **kwargs) + self._mark_as_changed() def write(self, string): if self.grid_id: @@ -750,6 +753,12 @@ class GridFSProxy(object): self.fs.delete(self.grid_id) self.grid_id = None self.gridout = None + self._mark_as_changed() + + def _mark_as_changed(self): + """Inform the instance that `self.key` has been changed""" + if self.instance: + self.instance._mark_as_changed(self.key) def replace(self, file_obj, **kwargs): self.delete() @@ -777,10 +786,14 @@ class FileField(BaseField): grid_file = instance._data.get(self.name) self.grid_file = grid_file if self.grid_file: + if not self.grid_file.key: + self.grid_file.key = self.name + self.grid_file.instance = instance return self.grid_file - return GridFSProxy() + return GridFSProxy(key=self.name, instance=instance) def __set__(self, instance, value): + key = self.name if isinstance(value, file) or isinstance(value, str): # using "FileField() = file/string" notation grid_file = instance._data.get(self.name) @@ -794,10 +807,12 @@ class FileField(BaseField): grid_file.put(value) else: # Create a new proxy object as we don't already have one - instance._data[self.name] = GridFSProxy() - instance._data[self.name].put(value) + instance._data[key] = GridFSProxy(key=key, instance=instance) + instance._data[key].put(value) else: - instance._data[self.name] = value + instance._data[key] = value + + instance._mark_as_changed(key) def to_mongo(self, value): # Store the GridFS file id in MongoDB diff --git a/tests/dereference.py b/tests/dereference.py index 68792721..4040d5bd 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -281,9 +281,7 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 1) - - for k, m in group_obj.members.iteritems(): - self.assertTrue('User' in m.__class__.__name__) + self.assertEqual(group_obj.members, {}) UserA.drop_collection() UserB.drop_collection() diff --git a/tests/django_tests.py b/tests/django_tests.py index 6be1ea25..ee8084ce 100644 --- a/tests/django_tests.py +++ b/tests/django_tests.py @@ -1,4 +1,3 @@ - # -*- coding: utf-8 -*- import unittest diff --git a/tests/document.py b/tests/document.py index f0af8f2d..4c890800 100644 --- a/tests/document.py +++ b/tests/document.py @@ -2,6 +2,7 @@ import unittest from datetime import datetime import pymongo import pickle +import weakref from mongoengine import * from mongoengine.base import BaseField @@ -11,6 +12,7 @@ from mongoengine.connection import _get_db class PickleEmbedded(EmbeddedDocument): date = DateTimeField(default=datetime.now) + class PickleTest(Document): number = IntField() string = StringField() @@ -717,6 +719,47 @@ class DocumentTest(unittest.TestCase): self.assertEqual(person.name, "Mr Test User") self.assertEqual(person.age, 21) + def test_reload_referencing(self): + """Ensures reloading updates weakrefs correctly + """ + class Embedded(EmbeddedDocument): + dict_field = DictField() + list_field = ListField() + + class Doc(Document): + dict_field = DictField() + list_field = ListField() + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection + doc = Doc() + doc.dict_field = {'hello': 'world'} + doc.list_field = ['1', 2, {'hello': 'world'}] + + embedded_1 = Embedded() + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + doc.save() + + doc.reload() + doc.list_field.append(1) + doc.dict_field['woot'] = "woot" + doc.embedded_field.list_field.append(1) + doc.embedded_field.dict_field['woot'] = "woot" + + self.assertEquals(doc._get_changed_fields(), [ + 'list_field', 'dict_field', 'embedded_field.list_field', + 'embedded_field.dict_field']) + doc.save() + + doc.reload() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(len(doc.list_field), 4) + self.assertEquals(len(doc.dict_field), 2) + self.assertEquals(len(doc.embedded_field.list_field), 4) + self.assertEquals(len(doc.embedded_field.dict_field), 2) + def test_dictionary_access(self): """Ensure that dictionary-style field access works properly. """ @@ -873,6 +916,197 @@ class DocumentTest(unittest.TestCase): self.assertEqual(person.name, None) self.assertEqual(person.age, None) + def test_delta(self): + + class Doc(Document): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + + Doc.drop_collection + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEquals(doc._delta(), ({'string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEquals(doc._delta(), ({'int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEquals(doc._delta(), ({'dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEquals(doc._delta(), ({'list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc._unset_fields = [] + doc.dict_field = {} + self.assertEquals(doc._delta(), ({}, {'dict_field': 1})) + + doc._changed_fields = [] + doc._unset_fields = {} + doc.list_field = [] + self.assertEquals(doc._delta(), ({}, {'list_field': 1})) + + def test_delta_recursive(self): + + class Embedded(EmbeddedDocument): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + + class Doc(Document): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(doc._delta(), ({}, {})) + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + embedded_delta = { + '_types': ['Embedded'], + '_cls': 'Embedded', + 'string_field': 'hello', + 'int_field': 1, + 'dict_field': {'hello': 'world'}, + 'list_field': ['1', 2, {'hello': 'world'}] + } + self.assertEquals(doc.embedded_field._delta(), (embedded_delta, {})) + self.assertEquals(doc._delta(), ({'embedded_field': embedded_delta}, {})) + + doc.save() + doc.reload() + + doc.embedded_field.dict_field = {} + self.assertEquals(doc.embedded_field._delta(), ({}, {'dict_field': 1})) + self.assertEquals(doc._delta(), ({}, {'embedded_field.dict_field': 1})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.dict_field, {}) + + doc.embedded_field.list_field = [] + self.assertEquals(doc.embedded_field._delta(), ({}, {'list_field': 1})) + self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field': 1})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field, []) + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = ['1', 2, {'hello': 'world'}] + + doc.embedded_field.list_field = ['1', 2, embedded_2] + self.assertEquals(doc.embedded_field._delta(), ({ + 'list_field': ['1', 2, { + '_cls': 'Embedded', + '_types': ['Embedded'], + 'string_field': 'hello', + 'dict_field': {'hello': 'world'}, + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + + self.assertEquals(doc._delta(), ({ + 'embedded_field.list_field': ['1', 2, { + '_cls': 'Embedded', + '_types': ['Embedded'], + 'string_field': 'hello', + 'dict_field': {'hello': 'world'}, + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + doc.save() + doc.reload() + + self.assertEquals(doc.embedded_field.list_field[0], '1') + self.assertEquals(doc.embedded_field.list_field[1], 2) + for k in doc.embedded_field.list_field[2]._fields: + self.assertEquals(doc.embedded_field.list_field[2][k], embedded_2[k]) + + doc.embedded_field.list_field[2].string_field = 'world' + self.assertEquals(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) + self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world') + + # Test list native methods + doc.embedded_field.list_field[2].list_field.pop(0) + self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {})) + doc.save() + doc.reload() + + doc.embedded_field.list_field[2].list_field.append(1) + self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) + + doc.embedded_field.list_field[2].list_field.sort() + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) + + del(doc.embedded_field.list_field[2].list_field[2]['hello']) + self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) + doc.save() + doc.reload() + + del(doc.embedded_field.list_field[2].list_field) + self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) + + def test_save_only_changed_fields(self): + """Ensure save only sets / unsets changed fields + """ + + # Create person object and save it to the database + person = self.Person(name='Test User', age=30) + person.save() + person.reload() + + same_person = self.Person.objects.get() + + person.age = 21 + same_person.name = 'User' + + person.save() + same_person.save() + + person = self.Person.objects.get() + self.assertEquals(person.name, 'User') + self.assertEquals(person.age, 21) + def test_delete(self): """Ensure that document may be deleted using the delete method. """ @@ -978,12 +1212,19 @@ class DocumentTest(unittest.TestCase): promoted_employee.details.position = 'Senior Developer' promoted_employee.save() - collection = self.db[self.Person._meta['collection']] - employee_obj = collection.find_one({'name': 'Test Employee'}) - self.assertEqual(employee_obj['name'], 'Test Employee') - self.assertEqual(employee_obj['age'], 50) + promoted_employee.reload() + self.assertEqual(promoted_employee.name, 'Test Employee') + self.assertEqual(promoted_employee.age, 50) # Ensure that the 'details' embedded object saved correctly - self.assertEqual(employee_obj['details']['position'], 'Senior Developer') + self.assertEqual(promoted_employee.details.position, 'Senior Developer') + + # Test removal + promoted_employee.details = None + promoted_employee.save() + + promoted_employee.reload() + self.assertEqual(promoted_employee.details, None) + def test_save_reference(self): """Ensure that a document reference field may be saved in the database. diff --git a/tests/fields.py b/tests/fields.py index 531167c8..79cd519c 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -843,6 +843,7 @@ class FieldTest(unittest.TestCase): name = StringField() children = ListField(EmbeddedDocumentField('self')) + Tree.drop_collection tree = Tree(name="Tree") first_child = TreeNode(name="Child 1") @@ -853,15 +854,42 @@ class FieldTest(unittest.TestCase): third_child = TreeNode(name="Child 3") first_child.children.append(third_child) - tree.save() - tree_obj = Tree.objects.first() self.assertEqual(len(tree.children), 1) self.assertEqual(tree.children[0].name, first_child.name) self.assertEqual(tree.children[0].children[0].name, second_child.name) self.assertEqual(tree.children[0].children[1].name, third_child.name) + # Test updating + tree.children[0].name = 'I am Child 1' + tree.children[0].children[0].name = 'I am Child 2' + tree.children[0].children[1].name = 'I am Child 3' + tree.save() + + self.assertEqual(tree.children[0].name, 'I am Child 1') + self.assertEqual(tree.children[0].children[0].name, 'I am Child 2') + self.assertEqual(tree.children[0].children[1].name, 'I am Child 3') + + # Test removal + self.assertEqual(len(tree.children[0].children), 2) + del(tree.children[0].children[1]) + + tree.save() + self.assertEqual(len(tree.children[0].children), 1) + + tree.children[0].children.pop(0) + tree.save() + self.assertEqual(len(tree.children[0].children), 0) + self.assertEqual(tree.children[0].children, []) + + tree.children[0].children.insert(0, third_child) + tree.children[0].children.insert(0, second_child) + tree.save() + self.assertEqual(len(tree.children[0].children), 2) + self.assertEqual(tree.children[0].children[0].name, second_child.name) + self.assertEqual(tree.children[0].children[1].name, third_child.name) + def test_undefined_reference(self): """Ensure that ReferenceFields may reference undefined Documents. """