Added delta tracking to documents.
All saves on existing items do set / unset operations only on changed fields. * Note: lists and dicts generally do set operations for things like pop() and del[key], as there is no easy mapping to unset; the set explicitly matches the new list / dict. Fixes #18
This commit is contained in:
@@ -4,6 +4,7 @@ from queryset import DO_NOTHING
|
||||
|
||||
from mongoengine import signals
|
||||
|
||||
import weakref
|
||||
import sys
|
||||
import pymongo
|
||||
import pymongo.objectid
|
||||
@@ -86,16 +87,19 @@ class BaseField(object):
|
||||
# Allow callable default values
|
||||
if callable(value):
|
||||
value = value()
|
||||
|
||||
# Convert lists / values so we can watch for any changes on them
|
||||
if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
|
||||
value = BaseList(value, instance=instance, name=self.name)
|
||||
elif isinstance(value, dict) and not isinstance(value, BaseDict):
|
||||
value = BaseDict(value, instance=instance, name=self.name)
|
||||
return value
|
||||
|
||||
def __set__(self, instance, value):
|
||||
"""Descriptor for assigning a value to a field in a document.
|
||||
"""
|
||||
key = self.name
|
||||
instance._data[key] = value
|
||||
# If the field set is in the _present_fields list add it so we can track
|
||||
if hasattr(instance, '_present_fields') and key and key not in instance._present_fields:
|
||||
instance._present_fields.append(self.name)
|
||||
instance._data[self.name] = value
|
||||
instance._mark_as_changed(self.name)
|
||||
|
||||
def to_python(self, value):
|
||||
"""Convert a MongoDB-compatible type to a Python type.
|
||||
@@ -173,21 +177,27 @@ class ComplexBaseField(BaseField):
|
||||
db = _get_db()
|
||||
dbref = {}
|
||||
collections = {}
|
||||
for k, v in value_list.items():
|
||||
dbref[k] = v
|
||||
for k,v in value_list.items():
|
||||
|
||||
# Save any DBRefs
|
||||
if isinstance(v, (pymongo.dbref.DBRef)):
|
||||
# direct reference (DBRef)
|
||||
collections.setdefault(v.collection, []).append((k, v))
|
||||
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
|
||||
# generic reference
|
||||
collection = get_document(v['_cls'])._meta['collection']
|
||||
collections.setdefault(collection, []).append((k, v))
|
||||
collections.setdefault(v.collection, []).append((k,v))
|
||||
elif isinstance(v, (dict, pymongo.son.SON)):
|
||||
if '_ref' in v:
|
||||
# generic reference
|
||||
collection = get_document(v['_cls'])._meta['collection']
|
||||
collections.setdefault(collection, []).append((k,v))
|
||||
else:
|
||||
# Use BaseDict so can watch any changes
|
||||
dbref[k] = BaseDict(v, instance=instance, name=self.name)
|
||||
else:
|
||||
dbref[k] = v
|
||||
|
||||
# For each collection get the references
|
||||
for collection, dbrefs in collections.items():
|
||||
id_map = {}
|
||||
for k, v in dbrefs:
|
||||
for k,v in dbrefs:
|
||||
if isinstance(v, (pymongo.dbref.DBRef)):
|
||||
# direct reference (DBRef), has no _cls information
|
||||
id_map[v.id] = (k, None)
|
||||
@@ -203,7 +213,9 @@ class ComplexBaseField(BaseField):
|
||||
dbref[key] = doc_cls._from_son(ref)
|
||||
|
||||
if is_list:
|
||||
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
|
||||
dbref = BaseList([v for k,v in sorted(dbref.items(), key=itemgetter(0))], instance=instance, name=self.name)
|
||||
else:
|
||||
dbref = BaseDict(dbref, instance=instance, name=self.name)
|
||||
instance._data[self.name] = dbref
|
||||
return super(ComplexBaseField, self).__get__(instance, owner)
|
||||
|
||||
@@ -304,7 +316,7 @@ class ComplexBaseField(BaseField):
|
||||
if hasattr(value, 'iteritems'):
|
||||
[self.field.validate(v) for k,v in value.iteritems()]
|
||||
else:
|
||||
[self.field.validate(v) for v in value]
|
||||
[self.field.validate(v) for v in value]
|
||||
except Exception, err:
|
||||
raise ValidationError('Invalid %s item (%s)' % (
|
||||
self.field.__class__.__name__, str(v)))
|
||||
@@ -714,7 +726,7 @@ class BaseDocument(object):
|
||||
self._meta.get('allow_inheritance', True) == False):
|
||||
data['_cls'] = self._class_name
|
||||
data['_types'] = self._superclasses.keys() + [self._class_name]
|
||||
if data.has_key('_id') and data['_id'] is None:
|
||||
if '_id' in data and data['_id'] is None:
|
||||
del data['_id']
|
||||
return data
|
||||
|
||||
@@ -751,9 +763,71 @@ class BaseDocument(object):
|
||||
else field.to_python(value))
|
||||
|
||||
obj = cls(**data)
|
||||
obj._present_fields = present_fields
|
||||
obj._changed_fields = []
|
||||
return obj
|
||||
|
||||
def _mark_as_changed(self, key):
    """Record ``key`` as a field that was explicitly changed by the user.

    Falsy keys are ignored. Nothing is recorded when the object has no
    ``_changed_fields`` list, and a key is never recorded twice.
    """
    if key and hasattr(self, '_changed_fields'):
        tracked = self._changed_fields
        if key not in tracked:
            tracked.append(key)
|
||||
|
||||
def _get_changed_fields(self, key=''):
    """Return a list of all fields that have explicitly been changed.

    The result starts with a copy of this object's own ``_changed_fields``
    (if it has any) and is extended with dotted paths for changes found in
    embedded documents (``"field.sub"``) and in list/tuple items that
    expose ``_get_changed_fields`` (``"field.<index>.sub"``).

    ``key`` is accepted and forwarded to nested calls, but this
    implementation does not use the value it receives.
    """
    from mongoengine import EmbeddedDocument

    changed = list(getattr(self, '_changed_fields', []))

    for field_name in self._fields:
        prefix = '%s.' % field_name
        current = getattr(self, field_name, None)

        if isinstance(current, EmbeddedDocument):
            # Grab all embedded fields that have been changed
            nested = current._get_changed_fields(prefix)
            changed.extend("%s%s" % (prefix, k) for k in nested if k)
        elif isinstance(current, (list, tuple)):
            # Loop list fields as they may contain documents
            for position, item in enumerate(current):
                if hasattr(item, '_get_changed_fields'):
                    item_prefix = "%s%s." % (prefix, position)
                    nested = item._get_changed_fields(item_prefix)
                    changed.extend(
                        "%s%s" % (item_prefix, k) for k in nested if k)

    return changed
|
||||
|
||||
def _delta(self):
    """Return the delta ``(set_data, unset_data)`` of changes for a document.

    When the object tracks changes (it has a ``_changed_fields``
    attribute), every dotted path reported by ``_get_changed_fields()`` is
    resolved against ``to_mongo()`` and only those values are considered.
    Otherwise the whole ``to_mongo()`` output, minus ``_id``, is used.

    Values that are ``None`` or an empty sized object (``''``, ``[]``,
    ``{}``...) are moved to ``unset_data`` as ``{path: 1}``. Falsy scalars
    such as ``0`` and ``False`` are real values and stay in ``set_data``.

    :returns: a ``(set_data, unset_data)`` tuple of mappings keyed by path.
    """
    # Handles cases where not loaded from_son but has _id
    doc = self.to_mongo()
    changed_paths = self._get_changed_fields()
    unset_data = {}

    if hasattr(self, '_changed_fields'):
        set_data = {}
        # Fetch each set item from its path
        for path in changed_paths:
            node = doc
            for part in path.split('.'):
                if hasattr(node, '__getattr__'):
                    # getattr takes (object, name); the arguments were
                    # previously swapped.
                    node = getattr(node, part)
                elif part.isdigit():
                    node = node[int(part)]
                else:
                    node = node.get(part)
            set_data[path] = node
    else:
        set_data = doc
        if '_id' in set_data:
            del set_data['_id']

    # Iterate over a snapshot so entries can be removed safely on both
    # Python 2 and Python 3.
    for path, value in list(set_data.items()):
        is_empty = value is None or (hasattr(value, '__len__') and not value)
        if is_empty:
            del set_data[path]
            unset_data[path] = 1

    return set_data, unset_data
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__) and hasattr(other, 'id'):
|
||||
if self.id == other.id:
|
||||
@@ -764,13 +838,112 @@ class BaseDocument(object):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
""" For list, dic key """
|
||||
""" For list, dict key """
|
||||
if self.pk is None:
|
||||
# For new object
|
||||
return super(BaseDocument,self).__hash__()
|
||||
else:
|
||||
return hash(self.pk)
|
||||
|
||||
|
||||
class BaseList(list):
    """A special list so we can watch any changes

    Every mutating operation first calls ``instance._mark_as_changed(name)``
    on the owning object, then delegates to :class:`list`. The owner is held
    through ``weakref.proxy``.

    :param list_items: initial contents of the list
    :param instance: owning object; must provide ``_mark_as_changed(key)``
    :param name: key reported to the owner whenever the list is mutated
    """

    def __init__(self, list_items, instance, name):
        self.instance = weakref.proxy(instance)
        self.name = name
        super(BaseList, self).__init__(list_items)

    def _mark_as_changed(self):
        """Notify the owner, if the tracking attributes are still attached."""
        if hasattr(self, 'instance') and hasattr(self, 'name'):
            self.instance._mark_as_changed(self.name)

    def __setitem__(self, *args, **kwargs):
        self._mark_as_changed()
        # Must delegate through BaseList: super(BaseDict, self) raises
        # TypeError because a BaseList is not a BaseDict.
        return super(BaseList, self).__setitem__(*args, **kwargs)

    def __delitem__(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).__delitem__(*args, **kwargs)

    def __setslice__(self, *args, **kwargs):
        # Python 2 only: simple slice assignment bypasses __setitem__.
        self._mark_as_changed()
        return super(BaseList, self).__setslice__(*args, **kwargs)

    def __delslice__(self, *args, **kwargs):
        # Python 2 only: simple slice deletion bypasses __delitem__.
        self._mark_as_changed()
        return super(BaseList, self).__delslice__(*args, **kwargs)

    def __iadd__(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).__iadd__(*args, **kwargs)

    def __imul__(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).__imul__(*args, **kwargs)

    def __delete__(self, *args, **kwargs):
        """Mark the owner as changed and detach the tracking attributes.

        ``list`` defines no ``__delete__``, so the parent implementation is
        only called when one actually exists.
        """
        if hasattr(self, 'instance') and hasattr(self, 'name'):
            self._mark_as_changed()
            delattr(self, 'instance')
            delattr(self, 'name')
        parent_delete = getattr(super(BaseList, self), '__delete__', None)
        if parent_delete is not None:
            return parent_delete(*args, **kwargs)

    def append(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).append(*args, **kwargs)

    def extend(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).extend(*args, **kwargs)

    def insert(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).insert(*args, **kwargs)

    def pop(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).pop(*args, **kwargs)

    def remove(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).remove(*args, **kwargs)

    def reverse(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).reverse(*args, **kwargs)

    def sort(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).sort(*args, **kwargs)
|
||||
|
||||
|
||||
class BaseDict(dict):
    """A special dict so we can watch any changes

    Every mutating operation first calls ``instance._mark_as_changed(name)``
    on the owning object, then delegates to :class:`dict`. The owner is held
    through ``weakref.proxy``.

    :param dict_items: initial contents (anything ``dict()`` accepts)
    :param instance: owning object; must provide ``_mark_as_changed(key)``
    :param name: key reported to the owner whenever the dict is mutated
    """

    def __init__(self, dict_items, instance, name):
        # These assignments go through __setattr__ below; nothing is marked
        # because tracking only starts once both attributes exist.
        self.instance = weakref.proxy(instance)
        self.name = name
        super(BaseDict, self).__init__(dict_items)

    def _mark_as_changed(self):
        """Notify the owner, if the tracking attributes are attached."""
        if hasattr(self, 'instance') and hasattr(self, 'name'):
            self.instance._mark_as_changed(self.name)

    def __setitem__(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).__setitem__(*args, **kwargs)

    def __setattr__(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).__setattr__(*args, **kwargs)

    def __delete__(self, *args, **kwargs):
        """Mark the owner as changed.

        ``dict`` defines no ``__delete__``, so the parent implementation is
        only called when one actually exists.
        """
        self._mark_as_changed()
        parent_delete = getattr(super(BaseDict, self), '__delete__', None)
        if parent_delete is not None:
            return parent_delete(*args, **kwargs)

    def __delitem__(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).__delitem__(*args, **kwargs)

    def __delattr__(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).__delattr__(*args, **kwargs)

    def clear(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).clear(*args, **kwargs)

    def pop(self, *args, **kwargs):
        self._mark_as_changed()
        # Delegate to dict.pop and return its result; this previously
        # called clear(), wiping the whole dict and returning None.
        return super(BaseDict, self).pop(*args, **kwargs)

    def popitem(self, *args, **kwargs):
        self._mark_as_changed()
        # Delegate to dict.popitem and return the removed pair; this
        # previously called clear().
        return super(BaseDict, self).popitem(*args, **kwargs)

    def setdefault(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).setdefault(*args, **kwargs)

    def update(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).update(*args, **kwargs)
|
||||
|
||||
if sys.version_info < (2, 5):
|
||||
# Prior to Python 2.5, Exception was an old-style class
|
||||
import types
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from mongoengine import signals
|
||||
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
|
||||
ValidationError)
|
||||
ValidationError, BaseDict, BaseList)
|
||||
from queryset import OperationError
|
||||
from connection import _get_db
|
||||
|
||||
import pymongo
|
||||
|
||||
|
||||
__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError']
|
||||
|
||||
|
||||
@@ -19,6 +18,18 @@ class EmbeddedDocument(BaseDocument):
|
||||
|
||||
__metaclass__ = DocumentMetaclass
|
||||
|
||||
def __delattr__(self, *args, **kwargs):
    """Handle deletions of fields

    Deleting a declared field does not remove the attribute: it is reset
    to the field's ``default`` (calling it first when it is callable).
    Any other attribute name is deleted through the parent class.
    """
    attr = args[0]
    if attr not in self._fields:
        super(EmbeddedDocument, self).__delattr__(*args, **kwargs)
        return

    fallback = self._fields[attr].default
    setattr(self, attr, fallback() if callable(fallback) else fallback)
|
||||
|
||||
|
||||
|
||||
class Document(BaseDocument):
|
||||
"""The base class used for defining the structure and properties of
|
||||
@@ -59,7 +70,6 @@ class Document(BaseDocument):
|
||||
disabled by either setting types to False on the specific index or
|
||||
by setting index_types to False on the meta dictionary for the document.
|
||||
"""
|
||||
|
||||
__metaclass__ = TopLevelDocumentMetaclass
|
||||
|
||||
def save(self, safe=True, force_insert=False, validate=True, write_options=None):
|
||||
@@ -95,18 +105,15 @@ class Document(BaseDocument):
|
||||
collection = self.__class__.objects._collection
|
||||
if force_insert:
|
||||
object_id = collection.insert(doc, safe=safe, **write_options)
|
||||
elif '_id' in doc:
|
||||
# Perform a set rather than a save - this will only save set fields
|
||||
object_id = doc.pop('_id')
|
||||
collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options)
|
||||
|
||||
# Find and unset any fields explicitly set to None
|
||||
if hasattr(self, '_present_fields'):
|
||||
removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id'])
|
||||
if removals:
|
||||
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
|
||||
else:
|
||||
if created:
|
||||
object_id = collection.save(doc, safe=safe, **write_options)
|
||||
else:
|
||||
object_id = doc['_id']
|
||||
updates, removals = self._delta()
|
||||
if updates:
|
||||
collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options)
|
||||
if removals:
|
||||
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
|
||||
except pymongo.errors.OperationFailure, err:
|
||||
message = 'Could not save document (%s)'
|
||||
if u'duplicate key' in unicode(err):
|
||||
@@ -114,7 +121,7 @@ class Document(BaseDocument):
|
||||
raise OperationError(message % unicode(err))
|
||||
id_field = self._meta['id_field']
|
||||
self[id_field] = self._fields[id_field].to_python(object_id)
|
||||
|
||||
self._changed_fields = []
|
||||
signals.post_save.send(self, created=created)
|
||||
|
||||
def delete(self, safe=False):
|
||||
@@ -135,14 +142,6 @@ class Document(BaseDocument):
|
||||
|
||||
signals.post_delete.send(self)
|
||||
|
||||
@classmethod
|
||||
def register_delete_rule(cls, document_cls, field_name, rule):
|
||||
"""This method registers the delete rules to apply when removing this
|
||||
object.
|
||||
"""
|
||||
cls._meta['delete_rules'][(document_cls, field_name)] = rule
|
||||
|
||||
|
||||
def reload(self):
|
||||
"""Reloads all attributes from the database.
|
||||
|
||||
@@ -151,7 +150,29 @@ class Document(BaseDocument):
|
||||
id_field = self._meta['id_field']
|
||||
obj = self.__class__.objects(**{id_field: self[id_field]}).first()
|
||||
for field in self._fields:
|
||||
setattr(self, field, obj[field])
|
||||
setattr(self, field, self._reload(field, obj[field]))
|
||||
self._changed_fields = []
|
||||
|
||||
def _reload(self, key, value):
    """Used by :meth:`~mongoengine.Document.reload` to ensure the
    correct instance is linked to self.

    ``BaseDict`` and ``BaseList`` values are rebuilt recursively with
    ``instance=self`` and ``name=key``. Embedded documents get their
    ``_changed_fields`` reset to an empty list. Anything else is returned
    as it was passed in.
    """
    if isinstance(value, BaseDict):
        # NOTE(review): nested values are reloaded under their own dict key
        # rather than the owning field's ``key`` -- confirm this is intended.
        pairs = [(k, self._reload(k, v)) for k, v in value.items()]
        return BaseDict(pairs, instance=self, name=key)

    if isinstance(value, BaseList):
        items = [self._reload(key, v) for v in value]
        return BaseList(items, instance=self, name=key)

    if isinstance(value, EmbeddedDocument):
        value._changed_fields = []
    return value
|
||||
|
||||
@classmethod
def register_delete_rule(cls, document_cls, field_name, rule):
    """Register the delete rule to apply when removing this object.

    The rule is stored in ``cls._meta['delete_rules']`` under the
    ``(document_cls, field_name)`` key, replacing any rule already there.
    """
    rule_key = (document_cls, field_name)
    delete_rules = cls._meta['delete_rules']
    delete_rules[rule_key] = rule
|
||||
|
||||
@classmethod
|
||||
def drop_collection(cls):
|
||||
|
||||
@@ -347,9 +347,9 @@ class ComplexDateTimeField(StringField):
|
||||
return datetime.datetime.now()
|
||||
return self._convert_from_string(data)
|
||||
|
||||
def __set__(self, obj, val):
|
||||
data = self._convert_from_datetime(val)
|
||||
return super(ComplexDateTimeField, self).__set__(obj, data)
|
||||
def __set__(self, instance, value):
|
||||
value = self._convert_from_datetime(value)
|
||||
return super(ComplexDateTimeField, self).__set__(instance, value)
|
||||
|
||||
def validate(self, value):
|
||||
if not isinstance(value, datetime.datetime):
|
||||
@@ -686,11 +686,13 @@ class GridFSProxy(object):
|
||||
.. versionadded:: 0.4
|
||||
"""
|
||||
|
||||
def __init__(self, grid_id=None):
|
||||
def __init__(self, grid_id=None, key=None, instance=None):
|
||||
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
|
||||
self.newfile = None # Used for partial writes
|
||||
self.grid_id = grid_id # Store GridFS id for file
|
||||
self.gridout = None
|
||||
self.key = key
|
||||
self.instance = instance
|
||||
|
||||
def __getattr__(self, name):
|
||||
obj = self.get()
|
||||
@@ -723,6 +725,7 @@ class GridFSProxy(object):
|
||||
raise GridFSError('This document already has a file. Either delete '
|
||||
'it or call replace to overwrite it')
|
||||
self.grid_id = self.fs.put(file_obj, **kwargs)
|
||||
self._mark_as_changed()
|
||||
|
||||
def write(self, string):
|
||||
if self.grid_id:
|
||||
@@ -750,6 +753,12 @@ class GridFSProxy(object):
|
||||
self.fs.delete(self.grid_id)
|
||||
self.grid_id = None
|
||||
self.gridout = None
|
||||
self._mark_as_changed()
|
||||
|
||||
def _mark_as_changed(self):
    """Inform the instance that `self.key` has been changed

    Does nothing when ``self.instance`` is falsy.
    """
    if not self.instance:
        return
    self.instance._mark_as_changed(self.key)
|
||||
|
||||
def replace(self, file_obj, **kwargs):
|
||||
self.delete()
|
||||
@@ -777,10 +786,14 @@ class FileField(BaseField):
|
||||
grid_file = instance._data.get(self.name)
|
||||
self.grid_file = grid_file
|
||||
if self.grid_file:
|
||||
if not self.grid_file.key:
|
||||
self.grid_file.key = self.name
|
||||
self.grid_file.instance = instance
|
||||
return self.grid_file
|
||||
return GridFSProxy()
|
||||
return GridFSProxy(key=self.name, instance=instance)
|
||||
|
||||
def __set__(self, instance, value):
|
||||
key = self.name
|
||||
if isinstance(value, file) or isinstance(value, str):
|
||||
# using "FileField() = file/string" notation
|
||||
grid_file = instance._data.get(self.name)
|
||||
@@ -794,10 +807,12 @@ class FileField(BaseField):
|
||||
grid_file.put(value)
|
||||
else:
|
||||
# Create a new proxy object as we don't already have one
|
||||
instance._data[self.name] = GridFSProxy()
|
||||
instance._data[self.name].put(value)
|
||||
instance._data[key] = GridFSProxy(key=key, instance=instance)
|
||||
instance._data[key].put(value)
|
||||
else:
|
||||
instance._data[self.name] = value
|
||||
instance._data[key] = value
|
||||
|
||||
instance._mark_as_changed(key)
|
||||
|
||||
def to_mongo(self, value):
|
||||
# Store the GridFS file id in MongoDB
|
||||
|
||||
Reference in New Issue
Block a user