Added delta tracking to documents.
All saves on exisiting items do set / unset operations only on changed fields. * Note lists and dicts generally do set operations for things like pop() del[key] As there is no easy map to unset and explicitly matches the new list / dict fixes #18
This commit is contained in:
		| @@ -5,6 +5,7 @@ Changelog | |||||||
| Changes in dev | Changes in dev | ||||||
| ============== | ============== | ||||||
|  |  | ||||||
|  | - Added delta tracking now only sets / unsets explicitly changed fields | ||||||
| - Fixed saving so sets updated values rather than overwrites | - Fixed saving so sets updated values rather than overwrites | ||||||
| - Added ComplexDateTimeField - Handles datetimes correctly with microseconds | - Added ComplexDateTimeField - Handles datetimes correctly with microseconds | ||||||
| - Added ComplexBaseField - for improved flexibility and performance | - Added ComplexBaseField - for improved flexibility and performance | ||||||
|   | |||||||
| @@ -18,10 +18,21 @@ attribute syntax:: | |||||||
|  |  | ||||||
| Saving and deleting documents | Saving and deleting documents | ||||||
| ============================= | ============================= | ||||||
| To save the document to the database, call the | MongoEngine tracks changes to documents to provide efficient saving.  To save  | ||||||
| :meth:`~mongoengine.Document.save` method. If the document does not exist in | the document to the database, call the :meth:`~mongoengine.Document.save` method. | ||||||
| the database, it will be created. If it does already exist, it will be | If the document does not exist in the database, it will be created. If it does  | ||||||
| updated. | already exist, then any changes will be updated atomically.  For example:: | ||||||
|  |  | ||||||
|  |     >>> page = Page(title="Test Page") | ||||||
|  |     >>> page.save()  # Performs an insert | ||||||
|  |     >>> page.title = "My Page" | ||||||
|  |     >>> page.save()  # Performs an atomic set on the title field. | ||||||
|  |  | ||||||
|  | .. note:: | ||||||
|  |     Changes to documents are tracked and on the whole perform `set` operations. | ||||||
|  |  | ||||||
|  |     * ``list_field.pop(0)`` - *sets* the resulting list | ||||||
|  |     * ``del(list_field)``   - *unsets* whole list | ||||||
|  |  | ||||||
| To delete a document, call the :meth:`~mongoengine.Document.delete` method. | To delete a document, call the :meth:`~mongoengine.Document.delete` method. | ||||||
| Note that this will only work if the document exists in the database and has a | Note that this will only work if the document exists in the database and has a | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ from queryset import DO_NOTHING | |||||||
|  |  | ||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
|  |  | ||||||
|  | import weakref | ||||||
| import sys | import sys | ||||||
| import pymongo | import pymongo | ||||||
| import pymongo.objectid | import pymongo.objectid | ||||||
| @@ -86,16 +87,19 @@ class BaseField(object): | |||||||
|             # Allow callable default values |             # Allow callable default values | ||||||
|             if callable(value): |             if callable(value): | ||||||
|                 value = value() |                 value = value() | ||||||
|  |  | ||||||
|  |         # Convert lists / values so we can watch for any changes on them | ||||||
|  |         if isinstance(value, (list, tuple)) and not isinstance(value, BaseList): | ||||||
|  |             value = BaseList(value, instance=instance, name=self.name) | ||||||
|  |         elif isinstance(value, dict) and not isinstance(value, BaseDict): | ||||||
|  |             value = BaseDict(value, instance=instance, name=self.name) | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def __set__(self, instance, value): |     def __set__(self, instance, value): | ||||||
|         """Descriptor for assigning a value to a field in a document. |         """Descriptor for assigning a value to a field in a document. | ||||||
|         """ |         """ | ||||||
|         key = self.name |         instance._data[self.name] = value | ||||||
|         instance._data[key] = value |         instance._mark_as_changed(self.name) | ||||||
|         # If the field set is in the _present_fields list add it so we can track |  | ||||||
|         if hasattr(instance, '_present_fields') and key and key not in instance._present_fields: |  | ||||||
|             instance._present_fields.append(self.name) |  | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         """Convert a MongoDB-compatible type to a Python type. |         """Convert a MongoDB-compatible type to a Python type. | ||||||
| @@ -173,21 +177,27 @@ class ComplexBaseField(BaseField): | |||||||
|         db = _get_db() |         db = _get_db() | ||||||
|         dbref = {} |         dbref = {} | ||||||
|         collections = {} |         collections = {} | ||||||
|         for k, v in value_list.items(): |         for k,v in value_list.items(): | ||||||
|             dbref[k] = v |  | ||||||
|             # Save any DBRefs |             # Save any DBRefs | ||||||
|             if isinstance(v, (pymongo.dbref.DBRef)): |             if isinstance(v, (pymongo.dbref.DBRef)): | ||||||
|                 # direct reference (DBRef) |                 # direct reference (DBRef) | ||||||
|                 collections.setdefault(v.collection, []).append((k, v)) |                 collections.setdefault(v.collection, []).append((k,v)) | ||||||
|             elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: |             elif isinstance(v, (dict, pymongo.son.SON)): | ||||||
|  |                 if '_ref' in v: | ||||||
|                     # generic reference |                     # generic reference | ||||||
|                     collection = get_document(v['_cls'])._meta['collection'] |                     collection = get_document(v['_cls'])._meta['collection'] | ||||||
|                 collections.setdefault(collection, []).append((k, v)) |                     collections.setdefault(collection, []).append((k,v)) | ||||||
|  |                 else: | ||||||
|  |                     # Use BaseDict so can watch any changes | ||||||
|  |                     dbref[k] = BaseDict(v, instance=instance, name=self.name) | ||||||
|  |             else: | ||||||
|  |                 dbref[k] = v | ||||||
|  |  | ||||||
|         # For each collection get the references |         # For each collection get the references | ||||||
|         for collection, dbrefs in collections.items(): |         for collection, dbrefs in collections.items(): | ||||||
|             id_map = {} |             id_map = {} | ||||||
|             for k, v in dbrefs: |             for k,v in dbrefs: | ||||||
|                 if isinstance(v, (pymongo.dbref.DBRef)): |                 if isinstance(v, (pymongo.dbref.DBRef)): | ||||||
|                     # direct reference (DBRef), has no _cls information |                     # direct reference (DBRef), has no _cls information | ||||||
|                     id_map[v.id] = (k, None) |                     id_map[v.id] = (k, None) | ||||||
| @@ -203,7 +213,9 @@ class ComplexBaseField(BaseField): | |||||||
|                 dbref[key] = doc_cls._from_son(ref) |                 dbref[key] = doc_cls._from_son(ref) | ||||||
|  |  | ||||||
|         if is_list: |         if is_list: | ||||||
|             dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] |             dbref = BaseList([v for k,v in sorted(dbref.items(), key=itemgetter(0))], instance=instance, name=self.name) | ||||||
|  |         else: | ||||||
|  |             dbref = BaseDict(dbref, instance=instance, name=self.name) | ||||||
|         instance._data[self.name] = dbref |         instance._data[self.name] = dbref | ||||||
|         return super(ComplexBaseField, self).__get__(instance, owner) |         return super(ComplexBaseField, self).__get__(instance, owner) | ||||||
|  |  | ||||||
| @@ -714,7 +726,7 @@ class BaseDocument(object): | |||||||
|                 self._meta.get('allow_inheritance', True) == False): |                 self._meta.get('allow_inheritance', True) == False): | ||||||
|             data['_cls'] = self._class_name |             data['_cls'] = self._class_name | ||||||
|             data['_types'] = self._superclasses.keys() + [self._class_name] |             data['_types'] = self._superclasses.keys() + [self._class_name] | ||||||
|         if data.has_key('_id') and data['_id'] is None: |         if '_id' in data and data['_id'] is None: | ||||||
|             del data['_id'] |             del data['_id'] | ||||||
|         return data |         return data | ||||||
|  |  | ||||||
| @@ -751,9 +763,71 @@ class BaseDocument(object): | |||||||
|                                     else field.to_python(value)) |                                     else field.to_python(value)) | ||||||
|  |  | ||||||
|         obj = cls(**data) |         obj = cls(**data) | ||||||
|         obj._present_fields = present_fields |         obj._changed_fields = [] | ||||||
|         return obj |         return obj | ||||||
|  |  | ||||||
|  |     def _mark_as_changed(self, key): | ||||||
|  |         """Marks a key as explicitly changed by the user | ||||||
|  |         """ | ||||||
|  |         if not key: | ||||||
|  |             return | ||||||
|  |         if hasattr(self, '_changed_fields') and key not in self._changed_fields: | ||||||
|  |             self._changed_fields.append(key) | ||||||
|  |  | ||||||
|  |     def _get_changed_fields(self, key=''): | ||||||
|  |         """Returns a list of all fields that have explicitly been changed. | ||||||
|  |         """ | ||||||
|  |         from mongoengine import EmbeddedDocument | ||||||
|  |         _changed_fields = [] | ||||||
|  |         _changed_fields += getattr(self, '_changed_fields', []) | ||||||
|  |  | ||||||
|  |         for field_name in self._fields: | ||||||
|  |             key = '%s.' % field_name | ||||||
|  |             field = getattr(self, field_name, None) | ||||||
|  |             if isinstance(field, EmbeddedDocument):  # Grab all embedded fields that have been changed | ||||||
|  |                 _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k] | ||||||
|  |             elif isinstance(field, (list, tuple)):  # Loop list fields as they contain documents | ||||||
|  |                 for index, value in enumerate(field): | ||||||
|  |                     if not hasattr(value, '_get_changed_fields'): | ||||||
|  |                         continue | ||||||
|  |                     list_key = "%s%s." % (key, index) | ||||||
|  |                     _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k] | ||||||
|  |         return _changed_fields | ||||||
|  |  | ||||||
|  |     def _delta(self): | ||||||
|  |         """Returns the delta (set, unset) of the changes for a document. | ||||||
|  |         Gets any values that have been explicitly changed. | ||||||
|  |         """ | ||||||
|  |         # Handles cases where not loaded from_son but has _id | ||||||
|  |         doc = self.to_mongo() | ||||||
|  |         set_fields = self._get_changed_fields() | ||||||
|  |         set_data = {} | ||||||
|  |         unset_data = {} | ||||||
|  |         if hasattr(self, '_changed_fields'): | ||||||
|  |             set_data = {} | ||||||
|  |             # Fetch each set item from its path | ||||||
|  |             for path in set_fields: | ||||||
|  |                 parts = path.split('.') | ||||||
|  |                 d = doc | ||||||
|  |                 for p in parts: | ||||||
|  |                     if hasattr(d, '__getattr__'): | ||||||
|  |                         d = getattr(p, d) | ||||||
|  |                     elif p.isdigit(): | ||||||
|  |                         d = d[int(p)] | ||||||
|  |                     else: | ||||||
|  |                         d = d.get(p) | ||||||
|  |                 set_data[path] = d | ||||||
|  |         else: | ||||||
|  |             set_data = doc | ||||||
|  |             if '_id' in set_data: | ||||||
|  |                 del(set_data['_id']) | ||||||
|  |  | ||||||
|  |         for k,v in set_data.items(): | ||||||
|  |             if not v: | ||||||
|  |                 del(set_data[k]) | ||||||
|  |                 unset_data[k] = 1 | ||||||
|  |         return set_data, unset_data | ||||||
|  |  | ||||||
|     def __eq__(self, other): |     def __eq__(self, other): | ||||||
|         if isinstance(other, self.__class__) and hasattr(other, 'id'): |         if isinstance(other, self.__class__) and hasattr(other, 'id'): | ||||||
|             if self.id == other.id: |             if self.id == other.id: | ||||||
| @@ -764,13 +838,112 @@ class BaseDocument(object): | |||||||
|         return not self.__eq__(other) |         return not self.__eq__(other) | ||||||
|  |  | ||||||
|     def __hash__(self): |     def __hash__(self): | ||||||
|         """ For list, dic key  """ |         """ For list, dict key  """ | ||||||
|         if self.pk is None: |         if self.pk is None: | ||||||
|             # For new object |             # For new object | ||||||
|             return super(BaseDocument,self).__hash__() |             return super(BaseDocument,self).__hash__() | ||||||
|         else: |         else: | ||||||
|             return hash(self.pk) |             return hash(self.pk) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BaseList(list): | ||||||
|  |     """A special list so we can watch any changes | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, list_items, instance, name): | ||||||
|  |         self.instance = weakref.proxy(instance) | ||||||
|  |         self.name = name | ||||||
|  |         super(BaseList, self).__init__(list_items) | ||||||
|  |  | ||||||
|  |     def __setitem__(self, *args, **kwargs): | ||||||
|  |         if hasattr(self, 'instance') and hasattr(self, 'name'): | ||||||
|  |             self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseDict, self).__setitem__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def __delitem__(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseList, self).__delitem__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def __delete__(self, *args, **kwargs): | ||||||
|  |         if hasattr(self, 'instance') and hasattr(self, 'name'): | ||||||
|  |             import ipdb; ipdb.set_trace() | ||||||
|  |             self.instance._mark_as_changed(self.name) | ||||||
|  |             delattr(self, 'instance') | ||||||
|  |             delattr(self, 'name') | ||||||
|  |         super(BaseDict, self).__delete__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def append(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         return super(BaseList, self).append(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def extend(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         return super(BaseList, self).extend(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def insert(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         return super(BaseList, self).insert(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def pop(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         return super(BaseList, self).pop(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def remove(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         return super(BaseList, self).remove(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def reverse(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         return super(BaseList, self).reverse(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def sort(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         return super(BaseList, self).sort(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BaseDict(dict): | ||||||
|  |     """A special dict so we can watch any changes | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, dict_items, instance, name): | ||||||
|  |         self.instance = weakref.proxy(instance) | ||||||
|  |         self.name = name | ||||||
|  |         super(BaseDict, self).__init__(dict_items) | ||||||
|  |  | ||||||
|  |     def __setitem__(self, *args, **kwargs): | ||||||
|  |         if hasattr(self, 'instance') and hasattr(self, 'name'): | ||||||
|  |             self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseDict, self).__setitem__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def __setattr__(self, *args, **kwargs): | ||||||
|  |         if hasattr(self, 'instance') and hasattr(self, 'name'): | ||||||
|  |             self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseDict, self).__setattr__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def __delete__(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseDict, self).__delete__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def __delitem__(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseDict, self).__delitem__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def __delattr__(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseDict, self).__delattr__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def clear(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseDict, self).clear(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def pop(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseDict, self).clear(*args, **kwargs) | ||||||
|  |  | ||||||
|  |     def popitem(self, *args, **kwargs): | ||||||
|  |         self.instance._mark_as_changed(self.name) | ||||||
|  |         super(BaseDict, self).clear(*args, **kwargs) | ||||||
|  |  | ||||||
| if sys.version_info < (2, 5): | if sys.version_info < (2, 5): | ||||||
|     # Prior to Python 2.5, Exception was an old-style class |     # Prior to Python 2.5, Exception was an old-style class | ||||||
|     import types |     import types | ||||||
|   | |||||||
| @@ -1,12 +1,11 @@ | |||||||
| from mongoengine import signals | from mongoengine import signals | ||||||
| from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, | from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, | ||||||
|                   ValidationError) |                   ValidationError, BaseDict, BaseList) | ||||||
| from queryset import OperationError | from queryset import OperationError | ||||||
| from connection import _get_db | from connection import _get_db | ||||||
|  |  | ||||||
| import pymongo | import pymongo | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError'] | __all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError'] | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -19,6 +18,18 @@ class EmbeddedDocument(BaseDocument): | |||||||
|  |  | ||||||
|     __metaclass__ = DocumentMetaclass |     __metaclass__ = DocumentMetaclass | ||||||
|  |  | ||||||
|  |     def __delattr__(self, *args, **kwargs): | ||||||
|  |         """Handle deletions of fields""" | ||||||
|  |         field_name = args[0] | ||||||
|  |         if field_name in self._fields: | ||||||
|  |             default = self._fields[field_name].default | ||||||
|  |             if callable(default): | ||||||
|  |                 default = default() | ||||||
|  |             setattr(self, field_name, default) | ||||||
|  |         else: | ||||||
|  |             super(EmbeddedDocument, self).__delattr__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Document(BaseDocument): | class Document(BaseDocument): | ||||||
|     """The base class used for defining the structure and properties of |     """The base class used for defining the structure and properties of | ||||||
| @@ -59,7 +70,6 @@ class Document(BaseDocument): | |||||||
|     disabled by either setting types to False on the specific index or |     disabled by either setting types to False on the specific index or | ||||||
|     by setting index_types to False on the meta dictionary for the document. |     by setting index_types to False on the meta dictionary for the document. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     __metaclass__ = TopLevelDocumentMetaclass |     __metaclass__ = TopLevelDocumentMetaclass | ||||||
|  |  | ||||||
|     def save(self, safe=True, force_insert=False, validate=True, write_options=None): |     def save(self, safe=True, force_insert=False, validate=True, write_options=None): | ||||||
| @@ -95,18 +105,15 @@ class Document(BaseDocument): | |||||||
|             collection = self.__class__.objects._collection |             collection = self.__class__.objects._collection | ||||||
|             if force_insert: |             if force_insert: | ||||||
|                 object_id = collection.insert(doc, safe=safe, **write_options) |                 object_id = collection.insert(doc, safe=safe, **write_options) | ||||||
|             elif '_id' in doc: |             if created: | ||||||
|                 # Perform a set rather than a save - this will only save set fields |                 object_id = collection.save(doc, safe=safe, **write_options) | ||||||
|                 object_id = doc.pop('_id') |             else: | ||||||
|                 collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options) |                 object_id = doc['_id'] | ||||||
|  |                 updates, removals = self._delta() | ||||||
|                 # Find and unset any fields explicitly set to None |                 if updates: | ||||||
|                 if hasattr(self, '_present_fields'): |                     collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options) | ||||||
|                     removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id']) |  | ||||||
|                 if removals: |                 if removals: | ||||||
|                     collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) |                     collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) | ||||||
|             else: |  | ||||||
|                 object_id = collection.save(doc, safe=safe, **write_options) |  | ||||||
|         except pymongo.errors.OperationFailure, err: |         except pymongo.errors.OperationFailure, err: | ||||||
|             message = 'Could not save document (%s)' |             message = 'Could not save document (%s)' | ||||||
|             if u'duplicate key' in unicode(err): |             if u'duplicate key' in unicode(err): | ||||||
| @@ -114,7 +121,7 @@ class Document(BaseDocument): | |||||||
|             raise OperationError(message % unicode(err)) |             raise OperationError(message % unicode(err)) | ||||||
|         id_field = self._meta['id_field'] |         id_field = self._meta['id_field'] | ||||||
|         self[id_field] = self._fields[id_field].to_python(object_id) |         self[id_field] = self._fields[id_field].to_python(object_id) | ||||||
|  |         self._changed_fields = [] | ||||||
|         signals.post_save.send(self, created=created) |         signals.post_save.send(self, created=created) | ||||||
|  |  | ||||||
|     def delete(self, safe=False): |     def delete(self, safe=False): | ||||||
| @@ -135,14 +142,6 @@ class Document(BaseDocument): | |||||||
|  |  | ||||||
|         signals.post_delete.send(self) |         signals.post_delete.send(self) | ||||||
|  |  | ||||||
|     @classmethod |  | ||||||
|     def register_delete_rule(cls, document_cls, field_name, rule): |  | ||||||
|         """This method registers the delete rules to apply when removing this |  | ||||||
|         object. |  | ||||||
|         """ |  | ||||||
|         cls._meta['delete_rules'][(document_cls, field_name)] = rule |  | ||||||
|  |  | ||||||
|  |  | ||||||
|     def reload(self): |     def reload(self): | ||||||
|         """Reloads all attributes from the database. |         """Reloads all attributes from the database. | ||||||
|  |  | ||||||
| @@ -151,7 +150,29 @@ class Document(BaseDocument): | |||||||
|         id_field = self._meta['id_field'] |         id_field = self._meta['id_field'] | ||||||
|         obj = self.__class__.objects(**{id_field: self[id_field]}).first() |         obj = self.__class__.objects(**{id_field: self[id_field]}).first() | ||||||
|         for field in self._fields: |         for field in self._fields: | ||||||
|             setattr(self, field, obj[field]) |             setattr(self, field, self._reload(field, obj[field])) | ||||||
|  |         self._changed_fields = [] | ||||||
|  |  | ||||||
|  |     def _reload(self, key, value): | ||||||
|  |         """Used by :meth:`~mongoengine.Document.reload` to ensure the | ||||||
|  |         correct instance is linked to self. | ||||||
|  |         """ | ||||||
|  |         if isinstance(value, BaseDict): | ||||||
|  |             value = [(k, self._reload(k,v)) for k,v in value.items()] | ||||||
|  |             value = BaseDict(value, instance=self, name=key) | ||||||
|  |         elif isinstance(value, BaseList): | ||||||
|  |             value = [self._reload(key, v) for v in value] | ||||||
|  |             value = BaseList(value, instance=self, name=key) | ||||||
|  |         elif isinstance(value, EmbeddedDocument): | ||||||
|  |             value._changed_fields = [] | ||||||
|  |         return value | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def register_delete_rule(cls, document_cls, field_name, rule): | ||||||
|  |         """This method registers the delete rules to apply when removing this | ||||||
|  |         object. | ||||||
|  |         """ | ||||||
|  |         cls._meta['delete_rules'][(document_cls, field_name)] = rule | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def drop_collection(cls): |     def drop_collection(cls): | ||||||
|   | |||||||
| @@ -347,9 +347,9 @@ class ComplexDateTimeField(StringField): | |||||||
|             return datetime.datetime.now() |             return datetime.datetime.now() | ||||||
|         return self._convert_from_string(data) |         return self._convert_from_string(data) | ||||||
|  |  | ||||||
|     def __set__(self, obj, val): |     def __set__(self, instance, value): | ||||||
|         data = self._convert_from_datetime(val) |         value = self._convert_from_datetime(value) | ||||||
|         return super(ComplexDateTimeField, self).__set__(obj, data) |         return super(ComplexDateTimeField, self).__set__(instance, value) | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|         if not isinstance(value, datetime.datetime): |         if not isinstance(value, datetime.datetime): | ||||||
| @@ -686,11 +686,13 @@ class GridFSProxy(object): | |||||||
|     .. versionadded:: 0.4 |     .. versionadded:: 0.4 | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, grid_id=None): |     def __init__(self, grid_id=None, key=None, instance=None): | ||||||
|         self.fs = gridfs.GridFS(_get_db())  # Filesystem instance |         self.fs = gridfs.GridFS(_get_db())  # Filesystem instance | ||||||
|         self.newfile = None                 # Used for partial writes |         self.newfile = None                 # Used for partial writes | ||||||
|         self.grid_id = grid_id              # Store GridFS id for file |         self.grid_id = grid_id              # Store GridFS id for file | ||||||
|         self.gridout = None |         self.gridout = None | ||||||
|  |         self.key = key | ||||||
|  |         self.instance = instance | ||||||
|  |  | ||||||
|     def __getattr__(self, name): |     def __getattr__(self, name): | ||||||
|         obj = self.get() |         obj = self.get() | ||||||
| @@ -723,6 +725,7 @@ class GridFSProxy(object): | |||||||
|             raise GridFSError('This document already has a file. Either delete ' |             raise GridFSError('This document already has a file. Either delete ' | ||||||
|                               'it or call replace to overwrite it') |                               'it or call replace to overwrite it') | ||||||
|         self.grid_id = self.fs.put(file_obj, **kwargs) |         self.grid_id = self.fs.put(file_obj, **kwargs) | ||||||
|  |         self._mark_as_changed() | ||||||
|  |  | ||||||
|     def write(self, string): |     def write(self, string): | ||||||
|         if self.grid_id: |         if self.grid_id: | ||||||
| @@ -750,6 +753,12 @@ class GridFSProxy(object): | |||||||
|         self.fs.delete(self.grid_id) |         self.fs.delete(self.grid_id) | ||||||
|         self.grid_id = None |         self.grid_id = None | ||||||
|         self.gridout = None |         self.gridout = None | ||||||
|  |         self._mark_as_changed() | ||||||
|  |  | ||||||
|  |     def _mark_as_changed(self): | ||||||
|  |         """Inform the instance that `self.key` has been changed""" | ||||||
|  |         if self.instance: | ||||||
|  |             self.instance._mark_as_changed(self.key) | ||||||
|  |  | ||||||
|     def replace(self, file_obj, **kwargs): |     def replace(self, file_obj, **kwargs): | ||||||
|         self.delete() |         self.delete() | ||||||
| @@ -777,10 +786,14 @@ class FileField(BaseField): | |||||||
|         grid_file = instance._data.get(self.name) |         grid_file = instance._data.get(self.name) | ||||||
|         self.grid_file = grid_file |         self.grid_file = grid_file | ||||||
|         if self.grid_file: |         if self.grid_file: | ||||||
|  |             if not self.grid_file.key: | ||||||
|  |                 self.grid_file.key = self.name | ||||||
|  |                 self.grid_file.instance = instance | ||||||
|             return self.grid_file |             return self.grid_file | ||||||
|         return GridFSProxy() |         return GridFSProxy(key=self.name, instance=instance) | ||||||
|  |  | ||||||
|     def __set__(self, instance, value): |     def __set__(self, instance, value): | ||||||
|  |         key = self.name | ||||||
|         if isinstance(value, file) or isinstance(value, str): |         if isinstance(value, file) or isinstance(value, str): | ||||||
|             # using "FileField() = file/string" notation |             # using "FileField() = file/string" notation | ||||||
|             grid_file = instance._data.get(self.name) |             grid_file = instance._data.get(self.name) | ||||||
| @@ -794,10 +807,12 @@ class FileField(BaseField): | |||||||
|                 grid_file.put(value) |                 grid_file.put(value) | ||||||
|             else: |             else: | ||||||
|                 # Create a new proxy object as we don't already have one |                 # Create a new proxy object as we don't already have one | ||||||
|                 instance._data[self.name] = GridFSProxy() |                 instance._data[key] = GridFSProxy(key=key, instance=instance) | ||||||
|                 instance._data[self.name].put(value) |                 instance._data[key].put(value) | ||||||
|         else: |         else: | ||||||
|             instance._data[self.name] = value |             instance._data[key] = value | ||||||
|  |  | ||||||
|  |         instance._mark_as_changed(key) | ||||||
|  |  | ||||||
|     def to_mongo(self, value): |     def to_mongo(self, value): | ||||||
|         # Store the GridFS file id in MongoDB |         # Store the GridFS file id in MongoDB | ||||||
|   | |||||||
| @@ -281,9 +281,7 @@ class FieldTest(unittest.TestCase): | |||||||
|  |  | ||||||
|             [m for m in group_obj.members] |             [m for m in group_obj.members] | ||||||
|             self.assertEqual(q, 1) |             self.assertEqual(q, 1) | ||||||
|  |             self.assertEqual(group_obj.members, {}) | ||||||
|             for k, m in group_obj.members.iteritems(): |  | ||||||
|                 self.assertTrue('User' in m.__class__.__name__) |  | ||||||
|  |  | ||||||
|         UserA.drop_collection() |         UserA.drop_collection() | ||||||
|         UserB.drop_collection() |         UserB.drop_collection() | ||||||
|   | |||||||
| @@ -1,4 +1,3 @@ | |||||||
|  |  | ||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
|  |  | ||||||
| import unittest | import unittest | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ import unittest | |||||||
| from datetime import datetime | from datetime import datetime | ||||||
| import pymongo | import pymongo | ||||||
| import pickle | import pickle | ||||||
|  | import weakref | ||||||
|  |  | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
| from mongoengine.base import BaseField | from mongoengine.base import BaseField | ||||||
| @@ -11,6 +12,7 @@ from mongoengine.connection import _get_db | |||||||
| class PickleEmbedded(EmbeddedDocument): | class PickleEmbedded(EmbeddedDocument): | ||||||
|     date = DateTimeField(default=datetime.now) |     date = DateTimeField(default=datetime.now) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PickleTest(Document): | class PickleTest(Document): | ||||||
|     number = IntField() |     number = IntField() | ||||||
|     string = StringField() |     string = StringField() | ||||||
| @@ -717,6 +719,47 @@ class DocumentTest(unittest.TestCase): | |||||||
|         self.assertEqual(person.name, "Mr Test User") |         self.assertEqual(person.name, "Mr Test User") | ||||||
|         self.assertEqual(person.age, 21) |         self.assertEqual(person.age, 21) | ||||||
|  |  | ||||||
|  |     def test_reload_referencing(self): | ||||||
|  |         """Ensures reloading updates weakrefs correctly | ||||||
|  |         """ | ||||||
|  |         class Embedded(EmbeddedDocument): | ||||||
|  |             dict_field = DictField() | ||||||
|  |             list_field = ListField() | ||||||
|  |  | ||||||
|  |         class Doc(Document): | ||||||
|  |             dict_field = DictField() | ||||||
|  |             list_field = ListField() | ||||||
|  |             embedded_field = EmbeddedDocumentField(Embedded) | ||||||
|  |  | ||||||
|  |         Doc.drop_collection | ||||||
|  |         doc = Doc() | ||||||
|  |         doc.dict_field = {'hello': 'world'} | ||||||
|  |         doc.list_field = ['1', 2, {'hello': 'world'}] | ||||||
|  |  | ||||||
|  |         embedded_1 = Embedded() | ||||||
|  |         embedded_1.dict_field = {'hello': 'world'} | ||||||
|  |         embedded_1.list_field = ['1', 2, {'hello': 'world'}] | ||||||
|  |         doc.embedded_field = embedded_1 | ||||||
|  |         doc.save() | ||||||
|  |  | ||||||
|  |         doc.reload() | ||||||
|  |         doc.list_field.append(1) | ||||||
|  |         doc.dict_field['woot'] = "woot" | ||||||
|  |         doc.embedded_field.list_field.append(1) | ||||||
|  |         doc.embedded_field.dict_field['woot'] = "woot" | ||||||
|  |  | ||||||
|  |         self.assertEquals(doc._get_changed_fields(), [ | ||||||
|  |             'list_field', 'dict_field', 'embedded_field.list_field', | ||||||
|  |             'embedded_field.dict_field']) | ||||||
|  |         doc.save() | ||||||
|  |  | ||||||
|  |         doc.reload() | ||||||
|  |         self.assertEquals(doc._get_changed_fields(), []) | ||||||
|  |         self.assertEquals(len(doc.list_field), 4) | ||||||
|  |         self.assertEquals(len(doc.dict_field), 2) | ||||||
|  |         self.assertEquals(len(doc.embedded_field.list_field), 4) | ||||||
|  |         self.assertEquals(len(doc.embedded_field.dict_field), 2) | ||||||
|  |  | ||||||
|     def test_dictionary_access(self): |     def test_dictionary_access(self): | ||||||
|         """Ensure that dictionary-style field access works properly. |         """Ensure that dictionary-style field access works properly. | ||||||
|         """ |         """ | ||||||
| @@ -873,6 +916,197 @@ class DocumentTest(unittest.TestCase): | |||||||
|         self.assertEqual(person.name, None) |         self.assertEqual(person.name, None) | ||||||
|         self.assertEqual(person.age, None) |         self.assertEqual(person.age, None) | ||||||
|  |  | ||||||
|  |     def test_delta(self): | ||||||
|  |  | ||||||
|  |         class Doc(Document): | ||||||
|  |             string_field = StringField() | ||||||
|  |             int_field = IntField() | ||||||
|  |             dict_field = DictField() | ||||||
|  |             list_field = ListField() | ||||||
|  |  | ||||||
|  |         Doc.drop_collection | ||||||
|  |         doc = Doc() | ||||||
|  |         doc.save() | ||||||
|  |  | ||||||
|  |         doc = Doc.objects.first() | ||||||
|  |         self.assertEquals(doc._get_changed_fields(), []) | ||||||
|  |         self.assertEquals(doc._delta(), ({}, {})) | ||||||
|  |  | ||||||
|  |         doc.string_field = 'hello' | ||||||
|  |         self.assertEquals(doc._delta(), ({'string_field': 'hello'}, {})) | ||||||
|  |  | ||||||
|  |         doc._changed_fields = [] | ||||||
|  |         doc.int_field = 1 | ||||||
|  |         self.assertEquals(doc._delta(), ({'int_field': 1}, {})) | ||||||
|  |  | ||||||
|  |         doc._changed_fields = [] | ||||||
|  |         dict_value = {'hello': 'world', 'ping': 'pong'} | ||||||
|  |         doc.dict_field = dict_value | ||||||
|  |         self.assertEquals(doc._delta(), ({'dict_field': dict_value}, {})) | ||||||
|  |  | ||||||
|  |         doc._changed_fields = [] | ||||||
|  |         list_value = ['1', 2, {'hello': 'world'}] | ||||||
|  |         doc.list_field = list_value | ||||||
|  |         self.assertEquals(doc._delta(), ({'list_field': list_value}, {})) | ||||||
|  |  | ||||||
|  |         # Test unsetting | ||||||
|  |         doc._changed_fields = [] | ||||||
|  |         doc._unset_fields = [] | ||||||
|  |         doc.dict_field = {} | ||||||
|  |         self.assertEquals(doc._delta(), ({}, {'dict_field': 1})) | ||||||
|  |  | ||||||
|  |         doc._changed_fields = [] | ||||||
|  |         doc._unset_fields = {} | ||||||
|  |         doc.list_field = [] | ||||||
|  |         self.assertEquals(doc._delta(), ({}, {'list_field': 1})) | ||||||
|  |  | ||||||
|  |     def test_delta_recursive(self): | ||||||
|  |  | ||||||
|  |         class Embedded(EmbeddedDocument): | ||||||
|  |             string_field = StringField() | ||||||
|  |             int_field = IntField() | ||||||
|  |             dict_field = DictField() | ||||||
|  |             list_field = ListField() | ||||||
|  |  | ||||||
|  |         class Doc(Document): | ||||||
|  |             string_field = StringField() | ||||||
|  |             int_field = IntField() | ||||||
|  |             dict_field = DictField() | ||||||
|  |             list_field = ListField() | ||||||
|  |             embedded_field = EmbeddedDocumentField(Embedded) | ||||||
|  |  | ||||||
|  |         Doc.drop_collection | ||||||
|  |         doc = Doc() | ||||||
|  |         doc.save() | ||||||
|  |  | ||||||
|  |         doc = Doc.objects.first() | ||||||
|  |         self.assertEquals(doc._get_changed_fields(), []) | ||||||
|  |         self.assertEquals(doc._delta(), ({}, {})) | ||||||
|  |  | ||||||
|  |         embedded_1 = Embedded() | ||||||
|  |         embedded_1.string_field = 'hello' | ||||||
|  |         embedded_1.int_field = 1 | ||||||
|  |         embedded_1.dict_field = {'hello': 'world'} | ||||||
|  |         embedded_1.list_field = ['1', 2, {'hello': 'world'}] | ||||||
|  |         doc.embedded_field = embedded_1 | ||||||
|  |  | ||||||
|  |         embedded_delta = { | ||||||
|  |             '_types': ['Embedded'], | ||||||
|  |             '_cls': 'Embedded', | ||||||
|  |             'string_field': 'hello', | ||||||
|  |             'int_field': 1, | ||||||
|  |             'dict_field': {'hello': 'world'}, | ||||||
|  |             'list_field': ['1', 2, {'hello': 'world'}] | ||||||
|  |         } | ||||||
|  |         self.assertEquals(doc.embedded_field._delta(), (embedded_delta, {})) | ||||||
|  |         self.assertEquals(doc._delta(), ({'embedded_field': embedded_delta}, {})) | ||||||
|  |  | ||||||
|  |         doc.save() | ||||||
|  |         doc.reload() | ||||||
|  |  | ||||||
|  |         doc.embedded_field.dict_field = {} | ||||||
|  |         self.assertEquals(doc.embedded_field._delta(), ({}, {'dict_field': 1})) | ||||||
|  |         self.assertEquals(doc._delta(), ({}, {'embedded_field.dict_field': 1})) | ||||||
|  |         doc.save() | ||||||
|  |         doc.reload() | ||||||
|  |         self.assertEquals(doc.embedded_field.dict_field, {}) | ||||||
|  |  | ||||||
|  |         doc.embedded_field.list_field = [] | ||||||
|  |         self.assertEquals(doc.embedded_field._delta(), ({}, {'list_field': 1})) | ||||||
|  |         self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field': 1})) | ||||||
|  |         doc.save() | ||||||
|  |         doc.reload() | ||||||
|  |         self.assertEquals(doc.embedded_field.list_field, []) | ||||||
|  |  | ||||||
|  |         embedded_2 = Embedded() | ||||||
|  |         embedded_2.string_field = 'hello' | ||||||
|  |         embedded_2.int_field = 1 | ||||||
|  |         embedded_2.dict_field = {'hello': 'world'} | ||||||
|  |         embedded_2.list_field = ['1', 2, {'hello': 'world'}] | ||||||
|  |  | ||||||
|  |         doc.embedded_field.list_field = ['1', 2, embedded_2] | ||||||
|  |         self.assertEquals(doc.embedded_field._delta(), ({ | ||||||
|  |             'list_field': ['1', 2, { | ||||||
|  |                 '_cls': 'Embedded', | ||||||
|  |                 '_types': ['Embedded'], | ||||||
|  |                 'string_field': 'hello', | ||||||
|  |                 'dict_field': {'hello': 'world'}, | ||||||
|  |                 'int_field': 1, | ||||||
|  |                 'list_field': ['1', 2, {'hello': 'world'}], | ||||||
|  |             }] | ||||||
|  |         }, {})) | ||||||
|  |  | ||||||
|  |         self.assertEquals(doc._delta(), ({ | ||||||
|  |             'embedded_field.list_field': ['1', 2, { | ||||||
|  |                 '_cls': 'Embedded', | ||||||
|  |                  '_types': ['Embedded'], | ||||||
|  |                  'string_field': 'hello', | ||||||
|  |                  'dict_field': {'hello': 'world'}, | ||||||
|  |                  'int_field': 1, | ||||||
|  |                  'list_field': ['1', 2, {'hello': 'world'}], | ||||||
|  |             }] | ||||||
|  |         }, {})) | ||||||
|  |         doc.save() | ||||||
|  |         doc.reload() | ||||||
|  |  | ||||||
|  |         self.assertEquals(doc.embedded_field.list_field[0], '1') | ||||||
|  |         self.assertEquals(doc.embedded_field.list_field[1], 2) | ||||||
|  |         for k in doc.embedded_field.list_field[2]._fields: | ||||||
|  |             self.assertEquals(doc.embedded_field.list_field[2][k], embedded_2[k]) | ||||||
|  |  | ||||||
|  |         doc.embedded_field.list_field[2].string_field = 'world' | ||||||
|  |         self.assertEquals(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) | ||||||
|  |         self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {})) | ||||||
|  |         doc.save() | ||||||
|  |         doc.reload() | ||||||
|  |         self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world') | ||||||
|  |  | ||||||
|  |         # Test list native methods | ||||||
|  |         doc.embedded_field.list_field[2].list_field.pop(0) | ||||||
|  |         self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {})) | ||||||
|  |         doc.save() | ||||||
|  |         doc.reload() | ||||||
|  |  | ||||||
|  |         doc.embedded_field.list_field[2].list_field.append(1) | ||||||
|  |         self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {})) | ||||||
|  |         doc.save() | ||||||
|  |         doc.reload() | ||||||
|  |         self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) | ||||||
|  |  | ||||||
|  |         doc.embedded_field.list_field[2].list_field.sort() | ||||||
|  |         doc.save() | ||||||
|  |         doc.reload() | ||||||
|  |         self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) | ||||||
|  |  | ||||||
|  |         del(doc.embedded_field.list_field[2].list_field[2]['hello']) | ||||||
|  |         self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) | ||||||
|  |         doc.save() | ||||||
|  |         doc.reload() | ||||||
|  |  | ||||||
|  |         del(doc.embedded_field.list_field[2].list_field) | ||||||
|  |         self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) | ||||||
|  |  | ||||||
|  |     def test_save_only_changed_fields(self): | ||||||
|  |         """Ensure save only sets / unsets changed fields | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         # Create person object and save it to the database | ||||||
|  |         person = self.Person(name='Test User', age=30) | ||||||
|  |         person.save() | ||||||
|  |         person.reload() | ||||||
|  |  | ||||||
|  |         same_person = self.Person.objects.get() | ||||||
|  |  | ||||||
|  |         person.age = 21 | ||||||
|  |         same_person.name = 'User' | ||||||
|  |  | ||||||
|  |         person.save() | ||||||
|  |         same_person.save() | ||||||
|  |  | ||||||
|  |         person = self.Person.objects.get() | ||||||
|  |         self.assertEquals(person.name, 'User') | ||||||
|  |         self.assertEquals(person.age, 21) | ||||||
|  |  | ||||||
|     def test_delete(self): |     def test_delete(self): | ||||||
|         """Ensure that document may be deleted using the delete method. |         """Ensure that document may be deleted using the delete method. | ||||||
|         """ |         """ | ||||||
| @@ -978,12 +1212,19 @@ class DocumentTest(unittest.TestCase): | |||||||
|         promoted_employee.details.position = 'Senior Developer' |         promoted_employee.details.position = 'Senior Developer' | ||||||
|         promoted_employee.save() |         promoted_employee.save() | ||||||
|  |  | ||||||
|         collection = self.db[self.Person._meta['collection']] |         promoted_employee.reload() | ||||||
|         employee_obj = collection.find_one({'name': 'Test Employee'}) |         self.assertEqual(promoted_employee.name, 'Test Employee') | ||||||
|         self.assertEqual(employee_obj['name'], 'Test Employee') |         self.assertEqual(promoted_employee.age, 50) | ||||||
|         self.assertEqual(employee_obj['age'], 50) |  | ||||||
|         # Ensure that the 'details' embedded object saved correctly |         # Ensure that the 'details' embedded object saved correctly | ||||||
|         self.assertEqual(employee_obj['details']['position'], 'Senior Developer') |         self.assertEqual(promoted_employee.details.position, 'Senior Developer') | ||||||
|  |  | ||||||
|  |         # Test removal | ||||||
|  |         promoted_employee.details = None | ||||||
|  |         promoted_employee.save() | ||||||
|  |  | ||||||
|  |         promoted_employee.reload() | ||||||
|  |         self.assertEqual(promoted_employee.details, None) | ||||||
|  |  | ||||||
|  |  | ||||||
|     def test_save_reference(self): |     def test_save_reference(self): | ||||||
|         """Ensure that a document reference field may be saved in the database. |         """Ensure that a document reference field may be saved in the database. | ||||||
|   | |||||||
| @@ -843,6 +843,7 @@ class FieldTest(unittest.TestCase): | |||||||
|             name = StringField() |             name = StringField() | ||||||
|             children = ListField(EmbeddedDocumentField('self')) |             children = ListField(EmbeddedDocumentField('self')) | ||||||
|  |  | ||||||
|  |         Tree.drop_collection | ||||||
|         tree = Tree(name="Tree") |         tree = Tree(name="Tree") | ||||||
|  |  | ||||||
|         first_child = TreeNode(name="Child 1") |         first_child = TreeNode(name="Child 1") | ||||||
| @@ -853,15 +854,42 @@ class FieldTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         third_child = TreeNode(name="Child 3") |         third_child = TreeNode(name="Child 3") | ||||||
|         first_child.children.append(third_child) |         first_child.children.append(third_child) | ||||||
|  |  | ||||||
|         tree.save() |         tree.save() | ||||||
|  |  | ||||||
|         tree_obj = Tree.objects.first() |  | ||||||
|         self.assertEqual(len(tree.children), 1) |         self.assertEqual(len(tree.children), 1) | ||||||
|         self.assertEqual(tree.children[0].name, first_child.name) |         self.assertEqual(tree.children[0].name, first_child.name) | ||||||
|         self.assertEqual(tree.children[0].children[0].name, second_child.name) |         self.assertEqual(tree.children[0].children[0].name, second_child.name) | ||||||
|         self.assertEqual(tree.children[0].children[1].name, third_child.name) |         self.assertEqual(tree.children[0].children[1].name, third_child.name) | ||||||
|  |  | ||||||
|  |         # Test updating | ||||||
|  |         tree.children[0].name = 'I am Child 1' | ||||||
|  |         tree.children[0].children[0].name = 'I am Child 2' | ||||||
|  |         tree.children[0].children[1].name = 'I am Child 3' | ||||||
|  |         tree.save() | ||||||
|  |  | ||||||
|  |         self.assertEqual(tree.children[0].name, 'I am Child 1') | ||||||
|  |         self.assertEqual(tree.children[0].children[0].name, 'I am Child 2') | ||||||
|  |         self.assertEqual(tree.children[0].children[1].name, 'I am Child 3') | ||||||
|  |  | ||||||
|  |         # Test removal | ||||||
|  |         self.assertEqual(len(tree.children[0].children), 2) | ||||||
|  |         del(tree.children[0].children[1]) | ||||||
|  |  | ||||||
|  |         tree.save() | ||||||
|  |         self.assertEqual(len(tree.children[0].children), 1) | ||||||
|  |  | ||||||
|  |         tree.children[0].children.pop(0) | ||||||
|  |         tree.save() | ||||||
|  |         self.assertEqual(len(tree.children[0].children), 0) | ||||||
|  |         self.assertEqual(tree.children[0].children, []) | ||||||
|  |  | ||||||
|  |         tree.children[0].children.insert(0, third_child) | ||||||
|  |         tree.children[0].children.insert(0, second_child) | ||||||
|  |         tree.save() | ||||||
|  |         self.assertEqual(len(tree.children[0].children), 2) | ||||||
|  |         self.assertEqual(tree.children[0].children[0].name, second_child.name) | ||||||
|  |         self.assertEqual(tree.children[0].children[1].name, third_child.name) | ||||||
|  |  | ||||||
|     def test_undefined_reference(self): |     def test_undefined_reference(self): | ||||||
|         """Ensure that ReferenceFields may reference undefined Documents. |         """Ensure that ReferenceFields may reference undefined Documents. | ||||||
|         """ |         """ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user