Added delta tracking to documents.

All saves on existing items perform set / unset operations only on changed fields.
* Note: lists and dicts generally use set operations for things like pop() and del[key],
  as there is no easy mapping to unset; the set explicitly matches the new list / dict

fixes #18
This commit is contained in:
Ross Lawley
2011-06-10 17:22:05 +01:00
parent ea35fb1c54
commit 0ed79a839d
9 changed files with 552 additions and 65 deletions

View File

@@ -4,6 +4,7 @@ from queryset import DO_NOTHING
from mongoengine import signals
import weakref
import sys
import pymongo
import pymongo.objectid
@@ -86,16 +87,19 @@ class BaseField(object):
# Allow callable default values
if callable(value):
value = value()
# Convert lists / values so we can watch for any changes on them
if isinstance(value, (list, tuple)) and not isinstance(value, BaseList):
value = BaseList(value, instance=instance, name=self.name)
elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, instance=instance, name=self.name)
return value
def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document.
"""
key = self.name
instance._data[key] = value
# If the field set is in the _present_fields list add it so we can track
if hasattr(instance, '_present_fields') and key and key not in instance._present_fields:
instance._present_fields.append(self.name)
instance._data[self.name] = value
instance._mark_as_changed(self.name)
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
@@ -173,21 +177,27 @@ class ComplexBaseField(BaseField):
db = _get_db()
dbref = {}
collections = {}
for k, v in value_list.items():
dbref[k] = v
for k,v in value_list.items():
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
# direct reference (DBRef)
collections.setdefault(v.collection, []).append((k, v))
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
# generic reference
collection = get_document(v['_cls'])._meta['collection']
collections.setdefault(collection, []).append((k, v))
collections.setdefault(v.collection, []).append((k,v))
elif isinstance(v, (dict, pymongo.son.SON)):
if '_ref' in v:
# generic reference
collection = get_document(v['_cls'])._meta['collection']
collections.setdefault(collection, []).append((k,v))
else:
# Use BaseDict so can watch any changes
dbref[k] = BaseDict(v, instance=instance, name=self.name)
else:
dbref[k] = v
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = {}
for k, v in dbrefs:
for k,v in dbrefs:
if isinstance(v, (pymongo.dbref.DBRef)):
# direct reference (DBRef), has no _cls information
id_map[v.id] = (k, None)
@@ -203,7 +213,9 @@ class ComplexBaseField(BaseField):
dbref[key] = doc_cls._from_son(ref)
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
dbref = BaseList([v for k,v in sorted(dbref.items(), key=itemgetter(0))], instance=instance, name=self.name)
else:
dbref = BaseDict(dbref, instance=instance, name=self.name)
instance._data[self.name] = dbref
return super(ComplexBaseField, self).__get__(instance, owner)
@@ -304,7 +316,7 @@ class ComplexBaseField(BaseField):
if hasattr(value, 'iteritems'):
[self.field.validate(v) for k,v in value.iteritems()]
else:
[self.field.validate(v) for v in value]
[self.field.validate(v) for v in value]
except Exception, err:
raise ValidationError('Invalid %s item (%s)' % (
self.field.__class__.__name__, str(v)))
@@ -714,7 +726,7 @@ class BaseDocument(object):
self._meta.get('allow_inheritance', True) == False):
data['_cls'] = self._class_name
data['_types'] = self._superclasses.keys() + [self._class_name]
if data.has_key('_id') and data['_id'] is None:
if '_id' in data and data['_id'] is None:
del data['_id']
return data
@@ -751,9 +763,71 @@ class BaseDocument(object):
else field.to_python(value))
obj = cls(**data)
obj._present_fields = present_fields
obj._changed_fields = []
return obj
def _mark_as_changed(self, key):
    """Record ``key`` as a field the user has explicitly modified.

    Empty keys are ignored, and so are instances that do not carry a
    ``_changed_fields`` list. A key is recorded at most once.
    """
    if key and hasattr(self, '_changed_fields'):
        tracked = self._changed_fields
        if key not in tracked:
            tracked.append(key)
def _get_changed_fields(self, key=''):
    """Return the dotted paths of every field that was explicitly changed.

    The result starts with this document's own ``_changed_fields`` (if
    any) and is extended with the changes reported by embedded documents,
    prefixed with ``"<field>."``, and by list / tuple items that expose
    ``_get_changed_fields``, prefixed with ``"<field>.<index>."``.

    ``key`` is accepted so recursive calls can pass their prefix, but it is
    overwritten for each field before being read.
    """
    from mongoengine import EmbeddedDocument

    changed = list(getattr(self, '_changed_fields', []))
    for field_name in self._fields:
        key = '%s.' % field_name
        current = getattr(self, field_name, None)
        if isinstance(current, EmbeddedDocument):
            # Grab all embedded fields that have been changed
            nested = current._get_changed_fields(key)
            changed.extend("%s%s" % (key, k) for k in nested if k)
        elif isinstance(current, (list, tuple)):
            # Loop list fields as they may contain documents
            for index, item in enumerate(current):
                if hasattr(item, '_get_changed_fields'):
                    item_key = "%s%s." % (key, index)
                    nested = item._get_changed_fields(item_key)
                    changed.extend("%s%s" % (item_key, k) for k in nested if k)
    return changed
def _delta(self):
    """Returns the delta (set, unset) of the changes for a document.

    Gets any values that have been explicitly changed.

    When the document tracks changes (it has a ``_changed_fields``
    attribute) each changed dotted path is resolved against the
    ``to_mongo()`` representation. Otherwise the whole ``to_mongo()``
    output, minus ``_id``, is treated as the data to set.

    :returns: a ``(set_data, unset_data)`` tuple. ``set_data`` maps dotted
        paths to their new values; ``unset_data`` maps the paths whose new
        value is empty (``None``, ``''``, ``[]``, ``{}``) to ``1``.
    """
    import numbers

    # Handles cases where not loaded from_son but has _id
    doc = self.to_mongo()
    changed_paths = self._get_changed_fields()

    if hasattr(self, '_changed_fields'):
        set_data = {}
        # Fetch each set item from its path
        for path in changed_paths:
            value = doc
            for part in path.split('.'):
                if hasattr(value, '__getattr__'):
                    # Object with dynamic attribute access: read the
                    # attribute named by this path component.
                    value = getattr(value, part)
                elif part.isdigit():
                    value = value[int(part)]
                else:
                    value = value.get(part)
            set_data[path] = value
    else:
        set_data = doc
        if '_id' in set_data:
            del set_data['_id']

    unset_data = {}
    # Iterate over a snapshot so entries can be removed while looping.
    for path, value in list(set_data.items()):
        # Numbers (including 0 and False) are real values and must be
        # written with $set; only empty values are turned into $unset.
        if value or isinstance(value, numbers.Number):
            continue
        del set_data[path]
        unset_data[path] = 1
    return set_data, unset_data
def __eq__(self, other):
if isinstance(other, self.__class__) and hasattr(other, 'id'):
if self.id == other.id:
@@ -764,13 +838,112 @@ class BaseDocument(object):
return not self.__eq__(other)
def __hash__(self):
""" For list, dic key """
""" For list, dict key """
if self.pk is None:
# For new object
return super(BaseDocument,self).__hash__()
else:
return hash(self.pk)
class BaseList(list):
    """A special list so we can watch any changes

    Every mutating operation notifies the owning document by calling
    ``instance._mark_as_changed(name)`` before the list is modified.

    :param list_items: initial contents of the list
    :param instance: the owning document; only a weak proxy to it is kept
    :param name: the field name reported to the owner on every change
    """

    def __init__(self, list_items, instance, name):
        self.instance = weakref.proxy(instance)
        self.name = name
        super(BaseList, self).__init__(list_items)

    def _mark_as_changed(self):
        """Notify the owner, if this list is still attached to one."""
        if hasattr(self, 'instance') and hasattr(self, 'name'):
            self.instance._mark_as_changed(self.name)

    def __setitem__(self, *args, **kwargs):
        self._mark_as_changed()
        super(BaseList, self).__setitem__(*args, **kwargs)

    def __delitem__(self, *args, **kwargs):
        self._mark_as_changed()
        super(BaseList, self).__delitem__(*args, **kwargs)

    def __delete__(self, *args, **kwargs):
        # ``list`` defines no ``__delete__`` so there is nothing to delegate
        # to: report the change one last time and detach from the owner.
        if hasattr(self, 'instance') and hasattr(self, 'name'):
            self.instance._mark_as_changed(self.name)
            delattr(self, 'instance')
            delattr(self, 'name')

    def append(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).append(*args, **kwargs)

    def extend(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).extend(*args, **kwargs)

    def insert(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).insert(*args, **kwargs)

    def pop(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).pop(*args, **kwargs)

    def remove(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).remove(*args, **kwargs)

    def reverse(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).reverse(*args, **kwargs)

    def sort(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseList, self).sort(*args, **kwargs)
class BaseDict(dict):
    """A special dict so we can watch any changes

    Every mutating operation notifies the owning document by calling
    ``instance._mark_as_changed(name)`` before the dict is modified.

    :param dict_items: initial contents (a mapping or iterable of pairs)
    :param instance: the owning document; only a weak proxy to it is kept
    :param name: the field name reported to the owner on every change
    """

    def __init__(self, dict_items, instance, name):
        self.instance = weakref.proxy(instance)
        self.name = name
        super(BaseDict, self).__init__(dict_items)

    def _mark_as_changed(self):
        """Notify the owner, if this dict is attached to one.

        The guard also keeps the attribute assignments made in
        ``__init__`` from reporting a change before both ``instance`` and
        ``name`` exist.
        """
        if hasattr(self, 'instance') and hasattr(self, 'name'):
            self.instance._mark_as_changed(self.name)

    def __setitem__(self, *args, **kwargs):
        self._mark_as_changed()
        super(BaseDict, self).__setitem__(*args, **kwargs)

    def __setattr__(self, *args, **kwargs):
        self._mark_as_changed()
        super(BaseDict, self).__setattr__(*args, **kwargs)

    def __delete__(self, *args, **kwargs):
        # ``dict`` defines no ``__delete__`` so there is nothing to delegate
        # to; just report the change.
        self._mark_as_changed()

    def __delitem__(self, *args, **kwargs):
        self._mark_as_changed()
        super(BaseDict, self).__delitem__(*args, **kwargs)

    def __delattr__(self, *args, **kwargs):
        self._mark_as_changed()
        super(BaseDict, self).__delattr__(*args, **kwargs)

    def clear(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).clear(*args, **kwargs)

    def pop(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).pop(*args, **kwargs)

    def popitem(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).popitem(*args, **kwargs)

    def setdefault(self, *args, **kwargs):
        # May insert a key, so it is reported like any other mutation.
        self._mark_as_changed()
        return super(BaseDict, self).setdefault(*args, **kwargs)

    def update(self, *args, **kwargs):
        self._mark_as_changed()
        return super(BaseDict, self).update(*args, **kwargs)
if sys.version_info < (2, 5):
# Prior to Python 2.5, Exception was an old-style class
import types

View File

@@ -1,12 +1,11 @@
from mongoengine import signals
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
ValidationError)
ValidationError, BaseDict, BaseList)
from queryset import OperationError
from connection import _get_db
import pymongo
__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError']
@@ -19,6 +18,18 @@ class EmbeddedDocument(BaseDocument):
__metaclass__ = DocumentMetaclass
def __delattr__(self, *args, **kwargs):
    """Handle deletions of fields

    Deleting a declared field resets it to the field's default (calling
    the default first when it is callable) instead of removing the
    attribute. Any other attribute is deleted through the parent class.
    """
    attr = args[0]
    if attr not in self._fields:
        super(EmbeddedDocument, self).__delattr__(*args, **kwargs)
        return
    fallback = self._fields[attr].default
    if callable(fallback):
        fallback = fallback()
    setattr(self, attr, fallback)
class Document(BaseDocument):
"""The base class used for defining the structure and properties of
@@ -59,7 +70,6 @@ class Document(BaseDocument):
disabled by either setting types to False on the specific index or
by setting index_types to False on the meta dictionary for the document.
"""
__metaclass__ = TopLevelDocumentMetaclass
def save(self, safe=True, force_insert=False, validate=True, write_options=None):
@@ -95,18 +105,15 @@ class Document(BaseDocument):
collection = self.__class__.objects._collection
if force_insert:
object_id = collection.insert(doc, safe=safe, **write_options)
elif '_id' in doc:
# Perform a set rather than a save - this will only save set fields
object_id = doc.pop('_id')
collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options)
# Find and unset any fields explicitly set to None
if hasattr(self, '_present_fields'):
removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id'])
if removals:
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
else:
if created:
object_id = collection.save(doc, safe=safe, **write_options)
else:
object_id = doc['_id']
updates, removals = self._delta()
if updates:
collection.update({'_id': object_id}, {"$set": updates}, upsert=True, safe=safe, **write_options)
if removals:
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)'
if u'duplicate key' in unicode(err):
@@ -114,7 +121,7 @@ class Document(BaseDocument):
raise OperationError(message % unicode(err))
id_field = self._meta['id_field']
self[id_field] = self._fields[id_field].to_python(object_id)
self._changed_fields = []
signals.post_save.send(self, created=created)
def delete(self, safe=False):
@@ -135,14 +142,6 @@ class Document(BaseDocument):
signals.post_delete.send(self)
@classmethod
def register_delete_rule(cls, document_cls, field_name, rule):
"""This method registers the delete rules to apply when removing this
object.
"""
cls._meta['delete_rules'][(document_cls, field_name)] = rule
def reload(self):
"""Reloads all attributes from the database.
@@ -151,7 +150,29 @@ class Document(BaseDocument):
id_field = self._meta['id_field']
obj = self.__class__.objects(**{id_field: self[id_field]}).first()
for field in self._fields:
setattr(self, field, obj[field])
setattr(self, field, self._reload(field, obj[field]))
self._changed_fields = []
def _reload(self, key, value):
    """Used by :meth:`~mongoengine.Document.reload` to ensure the
    correct instance is linked to self.

    Change-tracking containers are rebuilt so they report to ``self``
    under the field name ``key``; embedded documents get their change
    list cleared. Other values are returned untouched.

    :param key: name of the document field the value belongs to
    :param value: the freshly loaded value for that field
    :returns: the value, re-linked to ``self`` where applicable
    """
    if isinstance(value, BaseDict):
        # Nested values keep reporting under the *field* name: the name is
        # what gets passed to _mark_as_changed, so a dict key is not a
        # valid change path.
        value = [(k, self._reload(key, v)) for k, v in value.items()]
        value = BaseDict(value, instance=self, name=key)
    elif isinstance(value, BaseList):
        value = [self._reload(key, v) for v in value]
        value = BaseList(value, instance=self, name=key)
    elif isinstance(value, EmbeddedDocument):
        value._changed_fields = []
    return value
@classmethod
def register_delete_rule(cls, document_cls, field_name, rule):
    """Store the delete rule to apply when an object of this class is
    removed, keyed by the ``(document_cls, field_name)`` pair.
    """
    rules = cls._meta['delete_rules']
    rules[(document_cls, field_name)] = rule
@classmethod
def drop_collection(cls):

View File

@@ -347,9 +347,9 @@ class ComplexDateTimeField(StringField):
return datetime.datetime.now()
return self._convert_from_string(data)
def __set__(self, obj, val):
data = self._convert_from_datetime(val)
return super(ComplexDateTimeField, self).__set__(obj, data)
def __set__(self, instance, value):
    """Run the value through ``_convert_from_datetime`` and store the
    result via the parent descriptor.
    """
    converted = self._convert_from_datetime(value)
    return super(ComplexDateTimeField, self).__set__(instance, converted)
def validate(self, value):
if not isinstance(value, datetime.datetime):
@@ -686,11 +686,13 @@ class GridFSProxy(object):
.. versionadded:: 0.4
"""
def __init__(self, grid_id=None):
def __init__(self, grid_id=None, key=None, instance=None):
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
self.newfile = None # Used for partial writes
self.grid_id = grid_id # Store GridFS id for file
self.gridout = None
self.key = key
self.instance = instance
def __getattr__(self, name):
obj = self.get()
@@ -723,6 +725,7 @@ class GridFSProxy(object):
raise GridFSError('This document already has a file. Either delete '
'it or call replace to overwrite it')
self.grid_id = self.fs.put(file_obj, **kwargs)
self._mark_as_changed()
def write(self, string):
if self.grid_id:
@@ -750,6 +753,12 @@ class GridFSProxy(object):
self.fs.delete(self.grid_id)
self.grid_id = None
self.gridout = None
self._mark_as_changed()
def _mark_as_changed(self):
    """Tell the owning instance, when there is one, that ``self.key``
    has been changed.
    """
    owner = self.instance
    if not owner:
        return
    owner._mark_as_changed(self.key)
def replace(self, file_obj, **kwargs):
self.delete()
@@ -777,10 +786,14 @@ class FileField(BaseField):
grid_file = instance._data.get(self.name)
self.grid_file = grid_file
if self.grid_file:
if not self.grid_file.key:
self.grid_file.key = self.name
self.grid_file.instance = instance
return self.grid_file
return GridFSProxy()
return GridFSProxy(key=self.name, instance=instance)
def __set__(self, instance, value):
key = self.name
if isinstance(value, file) or isinstance(value, str):
# using "FileField() = file/string" notation
grid_file = instance._data.get(self.name)
@@ -794,10 +807,12 @@ class FileField(BaseField):
grid_file.put(value)
else:
# Create a new proxy object as we don't already have one
instance._data[self.name] = GridFSProxy()
instance._data[self.name].put(value)
instance._data[key] = GridFSProxy(key=key, instance=instance)
instance._data[key].put(value)
else:
instance._data[self.name] = value
instance._data[key] = value
instance._mark_as_changed(key)
def to_mongo(self, value):
# Store the GridFS file id in MongoDB