Added delta tracking to documents.
All saves on exisiting items do set / unset operations only on changed fields. * Note lists and dicts generally do set operations for things like pop() del[key] As there is no easy map to unset and explicitly matches the new list / dict fixes #18
This commit is contained in:
parent
ea35fb1c54
commit
0ed79a839d
@ -5,6 +5,7 @@ Changelog
|
||||
Changes in dev
|
||||
==============
|
||||
|
||||
- Added delta tracking now only sets / unsets explicitly changed fields
|
||||
- Fixed saving so sets updated values rather than overwrites
|
||||
- Added ComplexDateTimeField - Handles datetimes correctly with microseconds
|
||||
- Added ComplexBaseField - for improved flexibility and performance
|
||||
|
@ -18,10 +18,21 @@ attribute syntax::
|
||||
|
||||
Saving and deleting documents
|
||||
=============================
|
||||
To save the document to the database, call the
|
||||
:meth:`~mongoengine.Document.save` method. If the document does not exist in
|
||||
the database, it will be created. If it does already exist, it will be
|
||||
updated.
|
||||
MongoEngine tracks changes to documents to provide efficient saving. To save
|
||||
the document to the database, call the :meth:`~mongoengine.Document.save` method.
|
||||
If the document does not exist in the database, it will be created. If it does
|
||||
already exist, then any changes will be updated atomically. For example::
|
||||
|
||||
>>> page = Page(title="Test Page")
|
||||
>>> page.save() # Performs an insert
|
||||
>>> page.title = "My Page"
|
||||
>>> page.save() # Performs an atomic set on the title field.
|
||||
|
||||
.. note::
|
||||
Changes to documents are tracked and on the whole perform `set` operations.
|
||||
|
||||
* ``list_field.pop(0)`` - *sets* the resulting list
|
||||
* ``del(list_field)`` - *unsets* whole list
|
||||
|
||||
To delete a document, call the :meth:`~mongoengine.Document.delete` method.
|
||||
Note that this will only work if the document exists in the database and has a
|
||||
|
@ -4,6 +4,7 @@ from queryset import DO_NOTHING
|
||||
|
||||
from mongoengine import signals
|
||||
|
||||
import weakref
|
||||
import sys
|
||||
import pymongo
|
||||
import pymongo.objectid
|
||||
@ -86,16 +87,19 @@ class BaseField(object):
|
||||
# Allow callable default values
|
||||
if callable(value):
|
||||
value = value()
|
||||
|
||||
# Convert lists / values so we can watch for any changes on them
|
||||
if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
|
||||
value = BaseList(value, instance=instance, name=self.name)
|
||||
elif isinstance(value, dict) and not isinstance(value, BaseDict):
|
||||
value = BaseDict(value, instance=instance, name=self.name)
|
||||
return value
|
||||
|
||||
def __set__(self, instance, value):
|
||||
"""Descriptor for assigning a value to a field in a document.
|
||||
"""
|
||||
key = self.name
|
||||
instance._data[key] = value
|
||||
# If the field set is in the _present_fields list add it so we can track
|
||||
if hasattr(instance, '_present_fields') and key and key not in instance._present_fields:
|
||||
instance._present_fields.append(self.name)
|
||||
instance._data[self.name] = value
|
||||
instance._mark_as_changed(self.name)
|
||||
|
||||
def to_python(self, value):
|
||||
"""Convert a MongoDB-compatible type to a Python type.
|
||||
@ -173,21 +177,27 @@ class ComplexBaseField(BaseField):
|
||||
db = _get_db()
|
||||
dbref = {}
|
||||
collections = {}
|
||||
for k, v in value_list.items():
|
||||
dbref[k] = v
|
||||
for k,v in value_list.items():
|
||||
|
||||
# Save any DBRefs
|
||||
if isinstance(v, (pymongo.dbref.DBRef)):
|
||||
# direct reference (DBRef)
|
||||
collections.setdefault(v.collection, []).append((k, v))
|
||||
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
|
||||
collections.setdefault(v.collection, []).append((k,v))
|
||||
elif isinstance(v, (dict, pymongo.son.SON)):
|
||||
if '_ref' in v:
|
||||
# generic reference
|
||||
collection = get_document(v['_cls'])._meta['collection']
|
||||
collections.setdefault(collection, []).append((k, v))
|
||||
collections.setdefault(collection, []).append((k,v))
|
||||
else:
|
||||
# Use BaseDict so can watch any changes
|
||||
dbref[k] = BaseDict(v, instance=instance, name=self.name)
|
||||
else:
|
||||
dbref[k] = v
|
||||
|
||||
# For each collection get the references
|
||||
for collection, dbrefs in collections.items():
|
||||
id_map = {}
|
||||
for k, v in dbrefs:
|
||||
for k,v in dbrefs:
|
||||
if isinstance(v, (pymongo.dbref.DBRef)):
|
||||
# direct reference (DBRef), has no _cls information
|
||||
id_map[v.id] = (k, None)
|
||||
@ -203,7 +213,9 @@ class ComplexBaseField(BaseField):
|
||||
dbref[key] = doc_cls._from_son(ref)
|
||||
|
||||
if is_list:
|
||||
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
|
||||
dbref = BaseList([v for k,v in sorted(dbref.items(), key=itemgetter(0))], instance=instance, name=self.name)
|
||||
else:
|
||||
dbref = BaseDict(dbref, instance=instance, name=self.name)
|
||||
instance._data[self.name] = dbref
|
||||
return super(ComplexBaseField, self).__get__(instance, owner)
|
||||
|
||||
@ -714,7 +726,7 @@ class BaseDocument(object):
|
||||
self._meta.get('allow_inheritance', True) == False):
|
||||
data['_cls'] = self._class_name
|
||||
data['_types'] = self._superclasses.keys() + [self._class_name]
|
||||
if data.has_key('_id') and data['_id'] is None:
|
||||
if '_id' in data and data['_id'] is None:
|
||||
del data['_id']
|
||||
return data
|
||||
|
||||
@ -751,9 +763,71 @@ class BaseDocument(object):
|
||||
else field.to_python(value))
|
||||
|
||||
obj = cls(**data)
|
||||
obj._present_fields = present_fields
|
||||
obj._changed_fields = []
|
||||
return obj
|
||||
|
||||
def _mark_as_changed(self, key):
|
||||
"""Marks a key as explicitly changed by the user
|
||||
"""
|
||||
if not key:
|
||||
return
|
||||
if hasattr(self, '_changed_fields') and key not in self._changed_fields:
|
||||
self._changed_fields.append(key)
|
||||
|
||||
def _get_changed_fields(self, key=''):
|
||||
"""Returns a list of all fields that have explicitly been changed.
|
||||
"""
|
||||
from mongoengine import EmbeddedDocument
|
||||
_changed_fields = []
|
||||
_changed_fields += getattr(self, '_changed_fields', [])
|
||||
|
||||
for field_name in self._fields:
|
||||
key = '%s.' % field_name
|
||||
field = getattr(self, field_name, None)
|
||||
if isinstance(field, EmbeddedDocument): # Grab all embedded fields that have been changed
|
||||
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k]
|
||||
elif isinstance(field, (list, tuple)): # Loop list fields as they contain documents
|
||||
for index, value in enumerate(field):
|
||||
if not hasattr(value, '_get_changed_fields'):
|
||||
continue
|
||||
list_key = "%s%s." % (key, index)
|
||||
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k]
|
||||
return _changed_fields
|
||||
|
||||
def _delta(self):
|
||||
"""Returns the delta (set, unset) of the changes for a document.
|
||||
Gets any values that have been explicitly changed.
|
||||
"""
|
||||
# Handles cases where not loaded from_son but has _id
|
||||
doc = self.to_mongo()
|
||||
set_fields = self._get_changed_fields()
|
||||
set_data = {}
|
||||
unset_data = {}
|
||||
if hasattr(self, '_changed_fields'):
|
||||
set_data = {}
|
||||
# Fetch each set item from its path
|
||||
for path in set_fields:
|
||||
parts = path.split('.')
|
||||
d = doc
|
||||
for p in parts:
|
||||
if hasattr(d, '__getattr__'):
|
||||
d = getattr(p, d)
|
||||
elif p.isdigit():
|
||||
d = d[int(p)]
|
||||
else:
|
||||
d = d.get(p)
|
||||
set_data[path] = d
|
||||
else:
|
||||
set_data = doc
|
||||
if '_id' in set_data:
|
||||
del(set_data['_id'])
|
||||
|
||||
for k,v in set_data.items():
|
||||
if not v:
|
||||
del(set_data[k])
|
||||
unset_data[k] = 1
|
||||
return set_data, unset_data
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__) and hasattr(other, 'id'):
|
||||
if self.id == other.id:
|
||||
@ -764,13 +838,112 @@ class BaseDocument(object):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
""" For list, dic key """
|
||||
""" For list, dict key """
|
||||
if self.pk is None:
|
||||
# For new object
|
||||
return super(BaseDocument,self).__hash__()
|
||||
else:
|
||||
return hash(self.pk)
|
||||
|
||||
|
||||
class BaseList(list):
|
||||
"""A special list so we can watch any changes
|
||||
"""
|
||||
|
||||
def __init__(self, list_items, instance, name):
|
||||
self.instance = weakref.proxy(instance)
|
||||
self.name = name
|
||||
super(BaseList, self).__init__(list_items)
|
||||
|
||||
def __setitem__(self, *args, **kwargs):
|
||||
if hasattr(self, 'instance') and hasattr(self, 'name'):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseDict, self).__setitem__(*args, **kwargs)
|
||||
|
||||
def __delitem__(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseList, self).__delitem__(*args, **kwargs)
|
||||
|
||||
def __delete__(self, *args, **kwargs):
|
||||
if hasattr(self, 'instance') and hasattr(self, 'name'):
|
||||
import ipdb; ipdb.set_trace()
|
||||
self.instance._mark_as_changed(self.name)
|
||||
delattr(self, 'instance')
|
||||
delattr(self, 'name')
|
||||
super(BaseDict, self).__delete__(*args, **kwargs)
|
||||
|
||||
def append(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
return super(BaseList, self).append(*args, **kwargs)
|
||||
|
||||
def extend(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
return super(BaseList, self).extend(*args, **kwargs)
|
||||
|
||||
def insert(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
return super(BaseList, self).insert(*args, **kwargs)
|
||||
|
||||
def pop(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
return super(BaseList, self).pop(*args, **kwargs)
|
||||
|
||||
def remove(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
return super(BaseList, self).remove(*args, **kwargs)
|
||||
|
||||
def reverse(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
return super(BaseList, self).reverse(*args, **kwargs)
|
||||
|
||||
def sort(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
return super(BaseList, self).sort(*args, **kwargs)
|
||||
|
||||
|
||||
class BaseDict(dict):
|
||||
"""A special dict so we can watch any changes
|
||||
"""
|
||||
|
||||
def __init__(self, dict_items, instance, name):
|
||||
self.instance = weakref.proxy(instance)
|
||||
self.name = name
|
||||
super(BaseDict, self).__init__(dict_items)
|
||||
|
||||
def __setitem__(self, *args, **kwargs):
|
||||
if hasattr(self, 'instance') and hasattr(self, 'name'):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseDict, self).__setitem__(*args, **kwargs)
|
||||
|
||||
def __setattr__(self, *args, **kwargs):
|
||||
if hasattr(self, 'instance') and hasattr(self, 'name'):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseDict, self).__setattr__(*args, **kwargs)
|
||||
|
||||
def __delete__(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseDict, self).__delete__(*args, **kwargs)
|
||||
|
||||
def __delitem__(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseDict, self).__delitem__(*args, **kwargs)
|
||||
|
||||
def __delattr__(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseDict, self).__delattr__(*args, **kwargs)
|
||||
|
||||
def clear(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseDict, self).clear(*args, **kwargs)
|
||||
|
||||
def pop(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseDict, self).clear(*args, **kwargs)
|
||||
|
||||
def popitem(self, *args, **kwargs):
|
||||
self.instance._mark_as_changed(self.name)
|
||||
super(BaseDict, self).clear(*args, **kwargs)
|
||||
|
||||
if sys.version_info < (2, 5):
|
||||
# Prior to Python 2.5, Exception was an old-style class
|
||||
import types
|
||||
|
@ -1,12 +1,11 @@
|
||||
from mongoengine import signals
|
||||
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
|
||||
ValidationError)
|
||||
ValidationError, BaseDict, BaseList)
|
||||
from queryset import OperationError
|
||||
from connection import _get_db
|
||||
|
||||
import pymongo
|
||||
|
||||
|
||||
__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError']
|
||||
|
||||
|
||||
@ -19,6 +18,18 @@ class EmbeddedDocument(BaseDocument):
|
||||
|
||||
__metaclass__ = DocumentMetaclass
|
||||
|
||||
def __delattr__(self, *args, **kwargs):
|
||||
"""Handle deletions of fields"""
|
||||
field_name = args[0]
|
||||
if field_name in self._fields:
|
||||
default = self._fields[field_name].default
|
||||
if callable(default):
|
||||
default = default()
|
||||
setattr(self, field_name, default)
|
||||
else:
|
||||
super(EmbeddedDocument, self).__delattr__(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
class Document(BaseDocument):
|
||||
"""The base class used for defining the structure and properties of
|
||||
@ -59,7 +70,6 @@ class Document(BaseDocument):
|
||||
disabled by either setting types to False on the specific index or
|
||||
by setting index_types to False on the meta dictionary for the document.
|
||||
"""
|
||||
|
||||
__metaclass__ = TopLevelDocumentMetaclass
|
||||
|
||||
def save(self, safe=True, force_insert=False, validate=True, write_options=None):
|
||||
@ -95,18 +105,15 @@ class Document(BaseDocument):
|
||||
collection = self.__class__.objects._collection
|
||||
if force_insert:
|
||||
object_id = collection.insert(doc, safe=safe, **write_options)
|
||||
elif '_id' in doc:
|
||||
# Perform a set rather than a save - this will only save set fields
|
||||
object_id = doc.pop('_id')
|
||||
collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options)
|
||||
|
||||
# Find and unset any fields explicitly set to None
|
||||
if hasattr(self, '_present_fields'):
|
||||
removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id'])
|
||||
if created:
|
||||
object_id = collection.save(doc, safe=safe, **write_options)
|
||||
else:
|
||||
object_id = doc['_id']
|
||||
updates, removals = self._delta()
|
||||
if updates:
|
||||
collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options)
|
||||
if removals:
|
||||
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
|
||||
else:
|
||||
object_id = collection.save(doc, safe=safe, **write_options)
|
||||
except pymongo.errors.OperationFailure, err:
|
||||
message = 'Could not save document (%s)'
|
||||
if u'duplicate key' in unicode(err):
|
||||
@ -114,7 +121,7 @@ class Document(BaseDocument):
|
||||
raise OperationError(message % unicode(err))
|
||||
id_field = self._meta['id_field']
|
||||
self[id_field] = self._fields[id_field].to_python(object_id)
|
||||
|
||||
self._changed_fields = []
|
||||
signals.post_save.send(self, created=created)
|
||||
|
||||
def delete(self, safe=False):
|
||||
@ -135,14 +142,6 @@ class Document(BaseDocument):
|
||||
|
||||
signals.post_delete.send(self)
|
||||
|
||||
@classmethod
|
||||
def register_delete_rule(cls, document_cls, field_name, rule):
|
||||
"""This method registers the delete rules to apply when removing this
|
||||
object.
|
||||
"""
|
||||
cls._meta['delete_rules'][(document_cls, field_name)] = rule
|
||||
|
||||
|
||||
def reload(self):
|
||||
"""Reloads all attributes from the database.
|
||||
|
||||
@ -151,7 +150,29 @@ class Document(BaseDocument):
|
||||
id_field = self._meta['id_field']
|
||||
obj = self.__class__.objects(**{id_field: self[id_field]}).first()
|
||||
for field in self._fields:
|
||||
setattr(self, field, obj[field])
|
||||
setattr(self, field, self._reload(field, obj[field]))
|
||||
self._changed_fields = []
|
||||
|
||||
def _reload(self, key, value):
|
||||
"""Used by :meth:`~mongoengine.Document.reload` to ensure the
|
||||
correct instance is linked to self.
|
||||
"""
|
||||
if isinstance(value, BaseDict):
|
||||
value = [(k, self._reload(k,v)) for k,v in value.items()]
|
||||
value = BaseDict(value, instance=self, name=key)
|
||||
elif isinstance(value, BaseList):
|
||||
value = [self._reload(key, v) for v in value]
|
||||
value = BaseList(value, instance=self, name=key)
|
||||
elif isinstance(value, EmbeddedDocument):
|
||||
value._changed_fields = []
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def register_delete_rule(cls, document_cls, field_name, rule):
|
||||
"""This method registers the delete rules to apply when removing this
|
||||
object.
|
||||
"""
|
||||
cls._meta['delete_rules'][(document_cls, field_name)] = rule
|
||||
|
||||
@classmethod
|
||||
def drop_collection(cls):
|
||||
|
@ -347,9 +347,9 @@ class ComplexDateTimeField(StringField):
|
||||
return datetime.datetime.now()
|
||||
return self._convert_from_string(data)
|
||||
|
||||
def __set__(self, obj, val):
|
||||
data = self._convert_from_datetime(val)
|
||||
return super(ComplexDateTimeField, self).__set__(obj, data)
|
||||
def __set__(self, instance, value):
|
||||
value = self._convert_from_datetime(value)
|
||||
return super(ComplexDateTimeField, self).__set__(instance, value)
|
||||
|
||||
def validate(self, value):
|
||||
if not isinstance(value, datetime.datetime):
|
||||
@ -686,11 +686,13 @@ class GridFSProxy(object):
|
||||
.. versionadded:: 0.4
|
||||
"""
|
||||
|
||||
def __init__(self, grid_id=None):
|
||||
def __init__(self, grid_id=None, key=None, instance=None):
|
||||
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
|
||||
self.newfile = None # Used for partial writes
|
||||
self.grid_id = grid_id # Store GridFS id for file
|
||||
self.gridout = None
|
||||
self.key = key
|
||||
self.instance = instance
|
||||
|
||||
def __getattr__(self, name):
|
||||
obj = self.get()
|
||||
@ -723,6 +725,7 @@ class GridFSProxy(object):
|
||||
raise GridFSError('This document already has a file. Either delete '
|
||||
'it or call replace to overwrite it')
|
||||
self.grid_id = self.fs.put(file_obj, **kwargs)
|
||||
self._mark_as_changed()
|
||||
|
||||
def write(self, string):
|
||||
if self.grid_id:
|
||||
@ -750,6 +753,12 @@ class GridFSProxy(object):
|
||||
self.fs.delete(self.grid_id)
|
||||
self.grid_id = None
|
||||
self.gridout = None
|
||||
self._mark_as_changed()
|
||||
|
||||
def _mark_as_changed(self):
|
||||
"""Inform the instance that `self.key` has been changed"""
|
||||
if self.instance:
|
||||
self.instance._mark_as_changed(self.key)
|
||||
|
||||
def replace(self, file_obj, **kwargs):
|
||||
self.delete()
|
||||
@ -777,10 +786,14 @@ class FileField(BaseField):
|
||||
grid_file = instance._data.get(self.name)
|
||||
self.grid_file = grid_file
|
||||
if self.grid_file:
|
||||
if not self.grid_file.key:
|
||||
self.grid_file.key = self.name
|
||||
self.grid_file.instance = instance
|
||||
return self.grid_file
|
||||
return GridFSProxy()
|
||||
return GridFSProxy(key=self.name, instance=instance)
|
||||
|
||||
def __set__(self, instance, value):
|
||||
key = self.name
|
||||
if isinstance(value, file) or isinstance(value, str):
|
||||
# using "FileField() = file/string" notation
|
||||
grid_file = instance._data.get(self.name)
|
||||
@ -794,10 +807,12 @@ class FileField(BaseField):
|
||||
grid_file.put(value)
|
||||
else:
|
||||
# Create a new proxy object as we don't already have one
|
||||
instance._data[self.name] = GridFSProxy()
|
||||
instance._data[self.name].put(value)
|
||||
instance._data[key] = GridFSProxy(key=key, instance=instance)
|
||||
instance._data[key].put(value)
|
||||
else:
|
||||
instance._data[self.name] = value
|
||||
instance._data[key] = value
|
||||
|
||||
instance._mark_as_changed(key)
|
||||
|
||||
def to_mongo(self, value):
|
||||
# Store the GridFS file id in MongoDB
|
||||
|
@ -281,9 +281,7 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
for k, m in group_obj.members.iteritems():
|
||||
self.assertTrue('User' in m.__class__.__name__)
|
||||
self.assertEqual(group_obj.members, {})
|
||||
|
||||
UserA.drop_collection()
|
||||
UserB.drop_collection()
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import unittest
|
||||
|
@ -2,6 +2,7 @@ import unittest
|
||||
from datetime import datetime
|
||||
import pymongo
|
||||
import pickle
|
||||
import weakref
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.base import BaseField
|
||||
@ -11,6 +12,7 @@ from mongoengine.connection import _get_db
|
||||
class PickleEmbedded(EmbeddedDocument):
|
||||
date = DateTimeField(default=datetime.now)
|
||||
|
||||
|
||||
class PickleTest(Document):
|
||||
number = IntField()
|
||||
string = StringField()
|
||||
@ -717,6 +719,47 @@ class DocumentTest(unittest.TestCase):
|
||||
self.assertEqual(person.name, "Mr Test User")
|
||||
self.assertEqual(person.age, 21)
|
||||
|
||||
def test_reload_referencing(self):
|
||||
"""Ensures reloading updates weakrefs correctly
|
||||
"""
|
||||
class Embedded(EmbeddedDocument):
|
||||
dict_field = DictField()
|
||||
list_field = ListField()
|
||||
|
||||
class Doc(Document):
|
||||
dict_field = DictField()
|
||||
list_field = ListField()
|
||||
embedded_field = EmbeddedDocumentField(Embedded)
|
||||
|
||||
Doc.drop_collection
|
||||
doc = Doc()
|
||||
doc.dict_field = {'hello': 'world'}
|
||||
doc.list_field = ['1', 2, {'hello': 'world'}]
|
||||
|
||||
embedded_1 = Embedded()
|
||||
embedded_1.dict_field = {'hello': 'world'}
|
||||
embedded_1.list_field = ['1', 2, {'hello': 'world'}]
|
||||
doc.embedded_field = embedded_1
|
||||
doc.save()
|
||||
|
||||
doc.reload()
|
||||
doc.list_field.append(1)
|
||||
doc.dict_field['woot'] = "woot"
|
||||
doc.embedded_field.list_field.append(1)
|
||||
doc.embedded_field.dict_field['woot'] = "woot"
|
||||
|
||||
self.assertEquals(doc._get_changed_fields(), [
|
||||
'list_field', 'dict_field', 'embedded_field.list_field',
|
||||
'embedded_field.dict_field'])
|
||||
doc.save()
|
||||
|
||||
doc.reload()
|
||||
self.assertEquals(doc._get_changed_fields(), [])
|
||||
self.assertEquals(len(doc.list_field), 4)
|
||||
self.assertEquals(len(doc.dict_field), 2)
|
||||
self.assertEquals(len(doc.embedded_field.list_field), 4)
|
||||
self.assertEquals(len(doc.embedded_field.dict_field), 2)
|
||||
|
||||
def test_dictionary_access(self):
|
||||
"""Ensure that dictionary-style field access works properly.
|
||||
"""
|
||||
@ -873,6 +916,197 @@ class DocumentTest(unittest.TestCase):
|
||||
self.assertEqual(person.name, None)
|
||||
self.assertEqual(person.age, None)
|
||||
|
||||
def test_delta(self):
|
||||
|
||||
class Doc(Document):
|
||||
string_field = StringField()
|
||||
int_field = IntField()
|
||||
dict_field = DictField()
|
||||
list_field = ListField()
|
||||
|
||||
Doc.drop_collection
|
||||
doc = Doc()
|
||||
doc.save()
|
||||
|
||||
doc = Doc.objects.first()
|
||||
self.assertEquals(doc._get_changed_fields(), [])
|
||||
self.assertEquals(doc._delta(), ({}, {}))
|
||||
|
||||
doc.string_field = 'hello'
|
||||
self.assertEquals(doc._delta(), ({'string_field': 'hello'}, {}))
|
||||
|
||||
doc._changed_fields = []
|
||||
doc.int_field = 1
|
||||
self.assertEquals(doc._delta(), ({'int_field': 1}, {}))
|
||||
|
||||
doc._changed_fields = []
|
||||
dict_value = {'hello': 'world', 'ping': 'pong'}
|
||||
doc.dict_field = dict_value
|
||||
self.assertEquals(doc._delta(), ({'dict_field': dict_value}, {}))
|
||||
|
||||
doc._changed_fields = []
|
||||
list_value = ['1', 2, {'hello': 'world'}]
|
||||
doc.list_field = list_value
|
||||
self.assertEquals(doc._delta(), ({'list_field': list_value}, {}))
|
||||
|
||||
# Test unsetting
|
||||
doc._changed_fields = []
|
||||
doc._unset_fields = []
|
||||
doc.dict_field = {}
|
||||
self.assertEquals(doc._delta(), ({}, {'dict_field': 1}))
|
||||
|
||||
doc._changed_fields = []
|
||||
doc._unset_fields = {}
|
||||
doc.list_field = []
|
||||
self.assertEquals(doc._delta(), ({}, {'list_field': 1}))
|
||||
|
||||
def test_delta_recursive(self):
|
||||
|
||||
class Embedded(EmbeddedDocument):
|
||||
string_field = StringField()
|
||||
int_field = IntField()
|
||||
dict_field = DictField()
|
||||
list_field = ListField()
|
||||
|
||||
class Doc(Document):
|
||||
string_field = StringField()
|
||||
int_field = IntField()
|
||||
dict_field = DictField()
|
||||
list_field = ListField()
|
||||
embedded_field = EmbeddedDocumentField(Embedded)
|
||||
|
||||
Doc.drop_collection
|
||||
doc = Doc()
|
||||
doc.save()
|
||||
|
||||
doc = Doc.objects.first()
|
||||
self.assertEquals(doc._get_changed_fields(), [])
|
||||
self.assertEquals(doc._delta(), ({}, {}))
|
||||
|
||||
embedded_1 = Embedded()
|
||||
embedded_1.string_field = 'hello'
|
||||
embedded_1.int_field = 1
|
||||
embedded_1.dict_field = {'hello': 'world'}
|
||||
embedded_1.list_field = ['1', 2, {'hello': 'world'}]
|
||||
doc.embedded_field = embedded_1
|
||||
|
||||
embedded_delta = {
|
||||
'_types': ['Embedded'],
|
||||
'_cls': 'Embedded',
|
||||
'string_field': 'hello',
|
||||
'int_field': 1,
|
||||
'dict_field': {'hello': 'world'},
|
||||
'list_field': ['1', 2, {'hello': 'world'}]
|
||||
}
|
||||
self.assertEquals(doc.embedded_field._delta(), (embedded_delta, {}))
|
||||
self.assertEquals(doc._delta(), ({'embedded_field': embedded_delta}, {}))
|
||||
|
||||
doc.save()
|
||||
doc.reload()
|
||||
|
||||
doc.embedded_field.dict_field = {}
|
||||
self.assertEquals(doc.embedded_field._delta(), ({}, {'dict_field': 1}))
|
||||
self.assertEquals(doc._delta(), ({}, {'embedded_field.dict_field': 1}))
|
||||
doc.save()
|
||||
doc.reload()
|
||||
self.assertEquals(doc.embedded_field.dict_field, {})
|
||||
|
||||
doc.embedded_field.list_field = []
|
||||
self.assertEquals(doc.embedded_field._delta(), ({}, {'list_field': 1}))
|
||||
self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field': 1}))
|
||||
doc.save()
|
||||
doc.reload()
|
||||
self.assertEquals(doc.embedded_field.list_field, [])
|
||||
|
||||
embedded_2 = Embedded()
|
||||
embedded_2.string_field = 'hello'
|
||||
embedded_2.int_field = 1
|
||||
embedded_2.dict_field = {'hello': 'world'}
|
||||
embedded_2.list_field = ['1', 2, {'hello': 'world'}]
|
||||
|
||||
doc.embedded_field.list_field = ['1', 2, embedded_2]
|
||||
self.assertEquals(doc.embedded_field._delta(), ({
|
||||
'list_field': ['1', 2, {
|
||||
'_cls': 'Embedded',
|
||||
'_types': ['Embedded'],
|
||||
'string_field': 'hello',
|
||||
'dict_field': {'hello': 'world'},
|
||||
'int_field': 1,
|
||||
'list_field': ['1', 2, {'hello': 'world'}],
|
||||
}]
|
||||
}, {}))
|
||||
|
||||
self.assertEquals(doc._delta(), ({
|
||||
'embedded_field.list_field': ['1', 2, {
|
||||
'_cls': 'Embedded',
|
||||
'_types': ['Embedded'],
|
||||
'string_field': 'hello',
|
||||
'dict_field': {'hello': 'world'},
|
||||
'int_field': 1,
|
||||
'list_field': ['1', 2, {'hello': 'world'}],
|
||||
}]
|
||||
}, {}))
|
||||
doc.save()
|
||||
doc.reload()
|
||||
|
||||
self.assertEquals(doc.embedded_field.list_field[0], '1')
|
||||
self.assertEquals(doc.embedded_field.list_field[1], 2)
|
||||
for k in doc.embedded_field.list_field[2]._fields:
|
||||
self.assertEquals(doc.embedded_field.list_field[2][k], embedded_2[k])
|
||||
|
||||
doc.embedded_field.list_field[2].string_field = 'world'
|
||||
self.assertEquals(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {}))
|
||||
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {}))
|
||||
doc.save()
|
||||
doc.reload()
|
||||
self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world')
|
||||
|
||||
# Test list native methods
|
||||
doc.embedded_field.list_field[2].list_field.pop(0)
|
||||
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {}))
|
||||
doc.save()
|
||||
doc.reload()
|
||||
|
||||
doc.embedded_field.list_field[2].list_field.append(1)
|
||||
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {}))
|
||||
doc.save()
|
||||
doc.reload()
|
||||
self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1])
|
||||
|
||||
doc.embedded_field.list_field[2].list_field.sort()
|
||||
doc.save()
|
||||
doc.reload()
|
||||
self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}])
|
||||
|
||||
del(doc.embedded_field.list_field[2].list_field[2]['hello'])
|
||||
self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {}))
|
||||
doc.save()
|
||||
doc.reload()
|
||||
|
||||
del(doc.embedded_field.list_field[2].list_field)
|
||||
self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1}))
|
||||
|
||||
def test_save_only_changed_fields(self):
|
||||
"""Ensure save only sets / unsets changed fields
|
||||
"""
|
||||
|
||||
# Create person object and save it to the database
|
||||
person = self.Person(name='Test User', age=30)
|
||||
person.save()
|
||||
person.reload()
|
||||
|
||||
same_person = self.Person.objects.get()
|
||||
|
||||
person.age = 21
|
||||
same_person.name = 'User'
|
||||
|
||||
person.save()
|
||||
same_person.save()
|
||||
|
||||
person = self.Person.objects.get()
|
||||
self.assertEquals(person.name, 'User')
|
||||
self.assertEquals(person.age, 21)
|
||||
|
||||
def test_delete(self):
|
||||
"""Ensure that document may be deleted using the delete method.
|
||||
"""
|
||||
@ -978,12 +1212,19 @@ class DocumentTest(unittest.TestCase):
|
||||
promoted_employee.details.position = 'Senior Developer'
|
||||
promoted_employee.save()
|
||||
|
||||
collection = self.db[self.Person._meta['collection']]
|
||||
employee_obj = collection.find_one({'name': 'Test Employee'})
|
||||
self.assertEqual(employee_obj['name'], 'Test Employee')
|
||||
self.assertEqual(employee_obj['age'], 50)
|
||||
promoted_employee.reload()
|
||||
self.assertEqual(promoted_employee.name, 'Test Employee')
|
||||
self.assertEqual(promoted_employee.age, 50)
|
||||
# Ensure that the 'details' embedded object saved correctly
|
||||
self.assertEqual(employee_obj['details']['position'], 'Senior Developer')
|
||||
self.assertEqual(promoted_employee.details.position, 'Senior Developer')
|
||||
|
||||
# Test removal
|
||||
promoted_employee.details = None
|
||||
promoted_employee.save()
|
||||
|
||||
promoted_employee.reload()
|
||||
self.assertEqual(promoted_employee.details, None)
|
||||
|
||||
|
||||
def test_save_reference(self):
|
||||
"""Ensure that a document reference field may be saved in the database.
|
||||
|
@ -843,6 +843,7 @@ class FieldTest(unittest.TestCase):
|
||||
name = StringField()
|
||||
children = ListField(EmbeddedDocumentField('self'))
|
||||
|
||||
Tree.drop_collection
|
||||
tree = Tree(name="Tree")
|
||||
|
||||
first_child = TreeNode(name="Child 1")
|
||||
@ -853,15 +854,42 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
third_child = TreeNode(name="Child 3")
|
||||
first_child.children.append(third_child)
|
||||
|
||||
tree.save()
|
||||
|
||||
tree_obj = Tree.objects.first()
|
||||
self.assertEqual(len(tree.children), 1)
|
||||
self.assertEqual(tree.children[0].name, first_child.name)
|
||||
self.assertEqual(tree.children[0].children[0].name, second_child.name)
|
||||
self.assertEqual(tree.children[0].children[1].name, third_child.name)
|
||||
|
||||
# Test updating
|
||||
tree.children[0].name = 'I am Child 1'
|
||||
tree.children[0].children[0].name = 'I am Child 2'
|
||||
tree.children[0].children[1].name = 'I am Child 3'
|
||||
tree.save()
|
||||
|
||||
self.assertEqual(tree.children[0].name, 'I am Child 1')
|
||||
self.assertEqual(tree.children[0].children[0].name, 'I am Child 2')
|
||||
self.assertEqual(tree.children[0].children[1].name, 'I am Child 3')
|
||||
|
||||
# Test removal
|
||||
self.assertEqual(len(tree.children[0].children), 2)
|
||||
del(tree.children[0].children[1])
|
||||
|
||||
tree.save()
|
||||
self.assertEqual(len(tree.children[0].children), 1)
|
||||
|
||||
tree.children[0].children.pop(0)
|
||||
tree.save()
|
||||
self.assertEqual(len(tree.children[0].children), 0)
|
||||
self.assertEqual(tree.children[0].children, [])
|
||||
|
||||
tree.children[0].children.insert(0, third_child)
|
||||
tree.children[0].children.insert(0, second_child)
|
||||
tree.save()
|
||||
self.assertEqual(len(tree.children[0].children), 2)
|
||||
self.assertEqual(tree.children[0].children[0].name, second_child.name)
|
||||
self.assertEqual(tree.children[0].children[1].name, third_child.name)
|
||||
|
||||
def test_undefined_reference(self):
|
||||
"""Ensure that ReferenceFields may reference undefined Documents.
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user