Merge branch 'dev' into feature/dev-indexes

This commit is contained in:
Ross Lawley
2011-06-06 11:40:21 +01:00
13 changed files with 697 additions and 91 deletions

View File

@@ -6,9 +6,11 @@ import connection
from connection import *
import queryset
from queryset import *
import signals
from signals import *
__all__ = (document.__all__ + fields.__all__ + connection.__all__ +
queryset.__all__)
queryset.__all__ + signals.__all__)
__author__ = 'Harry Marr'

View File

@@ -2,9 +2,12 @@ from queryset import QuerySet, QuerySetManager
from queryset import DoesNotExist, MultipleObjectsReturned
from queryset import DO_NOTHING
from mongoengine import signals
import sys
import pymongo
import pymongo.objectid
from operator import itemgetter
class NotRegistered(Exception):
@@ -126,6 +129,88 @@ class BaseField(object):
self.validate(value)
class DereferenceBaseField(BaseField):
"""Handles the lazy dereferencing of a queryset. Will dereference all
items in a list / dict rather than one at a time.
"""
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
from fields import ReferenceField, GenericReferenceField
from connection import _get_db
if instance is None:
# Document class being used rather than a document object
return self
# Get value from document instance if available
value_list = instance._data.get(self.name)
if not value_list:
return super(DereferenceBaseField, self).__get__(instance, owner)
is_list = False
if not hasattr(value_list, 'items'):
is_list = True
value_list = dict([(k,v) for k,v in enumerate(value_list)])
if isinstance(self.field, ReferenceField) and value_list:
db = _get_db()
dbref = {}
collections = {}
for k, v in value_list.items():
dbref[k] = v
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
collections.setdefault(v.collection, []).append((k, v))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = dict([(v.id, k) for k, v in dbrefs])
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
dbref[key] = get_document(ref['_cls'])._from_son(ref)
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
instance._data[self.name] = dbref
# Get value from document instance if available
if isinstance(self.field, GenericReferenceField) and value_list:
db = _get_db()
value_list = [(k,v) for k,v in value_list.items()]
dbref = {}
classes = {}
for k, v in value_list:
dbref[k] = v
# Save any DBRefs
if isinstance(v, (dict, pymongo.son.SON)):
classes.setdefault(v['_cls'], []).append((k, v))
# For each collection get the references
for doc_cls, dbrefs in classes.items():
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
doc_cls = get_document(doc_cls)
collection = doc_cls._meta['collection']
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
dbref[key] = doc_cls._from_son(ref)
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
instance._data[self.name] = dbref
return super(DereferenceBaseField, self).__get__(instance, owner)
class ObjectIdField(BaseField):
"""An field wrapper around MongoDB's ObjectIds.
"""
@@ -382,6 +467,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
class BaseDocument(object):
def __init__(self, **values):
signals.pre_init.send(self, values=values)
self._data = {}
# Assign default values to instance
for attr_name in self._fields.keys():
@@ -395,6 +482,8 @@ class BaseDocument(object):
except AttributeError:
pass
signals.post_init.send(self)
def validate(self):
"""Ensure that all fields' values are valid and that required fields
are present.

View File

@@ -1,3 +1,4 @@
from mongoengine import signals
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
ValidationError)
from queryset import OperationError
@@ -75,6 +76,8 @@ class Document(BaseDocument):
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
have recorded the write and will force an fsync on each server being written to.
"""
signals.pre_save.send(self)
if validate:
self.validate()
@@ -82,6 +85,7 @@ class Document(BaseDocument):
write_options = {}
doc = self.to_mongo()
created = '_id' not in doc
try:
collection = self.__class__.objects._collection
if force_insert:
@@ -96,12 +100,16 @@ class Document(BaseDocument):
id_field = self._meta['id_field']
self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self, created=created)
def delete(self, safe=False):
"""Delete the :class:`~mongoengine.Document` from the database. This
will only take effect if the document has been previously saved.
:param safe: check if the operation succeeded before returning
"""
signals.pre_delete.send(self)
id_field = self._meta['id_field']
object_id = self._fields[id_field].to_mongo(self[id_field])
try:
@@ -110,6 +118,8 @@ class Document(BaseDocument):
message = u'Could not delete document (%s)' % err.message
raise OperationError(message)
signals.post_delete.send(self)
@classmethod
def register_delete_rule(cls, document_cls, field_name, rule):
"""This method registers the delete rules to apply when removing this

View File

@@ -1,4 +1,5 @@
from base import BaseField, ObjectIdField, ValidationError, get_document
from base import (BaseField, DereferenceBaseField, ObjectIdField,
ValidationError, get_document)
from queryset import DO_NOTHING
from document import Document, EmbeddedDocument
from connection import _get_db
@@ -12,7 +13,6 @@ import pymongo.binary
import datetime, time
import decimal
import gridfs
import warnings
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
@@ -118,8 +118,8 @@ class EmailField(StringField):
EMAIL_REGEX = re.compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
)
def validate(self, value):
@@ -153,6 +153,7 @@ class IntField(BaseField):
def prepare_query_value(self, op, value):
return int(value)
class FloatField(BaseField):
"""An floating point number field.
"""
@@ -178,6 +179,7 @@ class FloatField(BaseField):
def prepare_query_value(self, op, value):
return float(value)
class DecimalField(BaseField):
"""A fixed-point decimal number field.
@@ -252,21 +254,21 @@ class DateTimeField(BaseField):
else:
usecs = 0
kwargs = {'microsecond': usecs}
try: # Seconds are optional, so try converting seconds first.
try: # Seconds are optional, so try converting seconds first.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6],
**kwargs)
except ValueError:
try: # Try without seconds.
try: # Try without seconds.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5],
**kwargs)
except ValueError: # Try without hour/minutes/seconds.
except ValueError: # Try without hour/minutes/seconds.
try:
return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3],
**kwargs)
except ValueError:
return None
class EmbeddedDocumentField(BaseField):
"""An embedded document field. Only valid values are subclasses of
:class:`~mongoengine.EmbeddedDocument`.
@@ -314,7 +316,7 @@ class EmbeddedDocumentField(BaseField):
return self.to_mongo(value)
class ListField(BaseField):
class ListField(DereferenceBaseField):
"""A list field that wraps a standard field, allowing multiple instances
of the field to be used as a list in the database.
"""
@@ -330,42 +332,6 @@ class ListField(BaseField):
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs)
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type
# Get value from document instance if available
value_list = instance._data.get(self.name)
if value_list:
deref_list = []
for value in value_list:
# Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)):
value = _get_db().dereference(value)
deref_list.append(referenced_type._from_son(value))
else:
deref_list.append(value)
instance._data[self.name] = deref_list
if isinstance(self.field, GenericReferenceField):
value_list = instance._data.get(self.name)
if value_list:
deref_list = []
for value in value_list:
# Dereference DBRefs
if isinstance(value, (dict, pymongo.son.SON)):
deref_list.append(self.field.dereference(value))
else:
deref_list.append(value)
instance._data[self.name] = deref_list
return super(ListField, self).__get__(instance, owner)
def to_python(self, value):
return [self.field.to_python(item) for item in value]
@@ -459,10 +425,10 @@ class DictField(BaseField):
if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value)
return super(DictField,self).prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value)
class MapField(BaseField):
class MapField(DereferenceBaseField):
"""A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified
field type.
@@ -494,47 +460,11 @@ class MapField(BaseField):
except Exception, err:
raise ValidationError('Invalid MapField item (%s)' % str(item))
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type
# Get value from document instance if available
value_dict = instance._data.get(self.name)
if value_dict:
deref_dict = []
for key,value in value_dict.iteritems():
# Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)):
value = _get_db().dereference(value)
deref_dict[key] = referenced_type._from_son(value)
else:
deref_dict[key] = value
instance._data[self.name] = deref_dict
if isinstance(self.field, GenericReferenceField):
value_dict = instance._data.get(self.name)
if value_dict:
deref_dict = []
for key,value in value_dict.iteritems():
# Dereference DBRefs
if isinstance(value, (dict, pymongo.son.SON)):
deref_dict[key] = self.field.dereference(value)
else:
deref_dict[key] = value
instance._data[self.name] = deref_dict
return super(MapField, self).__get__(instance, owner)
def to_python(self, value):
return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] )
return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
def to_mongo(self, value):
return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] )
return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()])
def prepare_query_value(self, op, value):
if op not in ('set', 'unset'):
@@ -752,11 +682,11 @@ class GridFSProxy(object):
self.newfile = self.fs.new_file(**kwargs)
self.grid_id = self.newfile._id
def put(self, file, **kwargs):
def put(self, file_obj, **kwargs):
if self.grid_id:
raise GridFSError('This document already has a file. Either delete '
'it or call replace to overwrite it')
self.grid_id = self.fs.put(file, **kwargs)
self.grid_id = self.fs.put(file_obj, **kwargs)
def write(self, string):
if self.grid_id:
@@ -785,9 +715,9 @@ class GridFSProxy(object):
self.grid_id = None
self.gridout = None
def replace(self, file, **kwargs):
def replace(self, file_obj, **kwargs):
self.delete()
self.put(file, **kwargs)
self.put(file_obj, **kwargs)
def close(self):
if self.newfile:

44
mongoengine/signals.py Normal file
View File

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save',
'pre_delete', 'post_delete']
signals_available = False
try:
from blinker import Namespace
signals_available = True
except ImportError:
class Namespace(object):
def signal(self, name, doc=None):
return _FakeSignal(name, doc)
class _FakeSignal(object):
"""If blinker is unavailable, create a fake class with the same
interface that allows sending of signals but will fail with an
error on anything else. Instead of doing anything on send, it
will just ignore the arguments and do nothing instead.
"""
def __init__(self, name, doc=None):
self.name = name
self.__doc__ = doc
def _fail(self, *args, **kwargs):
raise RuntimeError('signalling support is unavailable '
'because the blinker library is '
'not installed.')
send = lambda *a, **kw: None
connect = disconnect = has_receivers_for = receivers_for = \
temporarily_connected_to = _fail
del _fail
# the namespace for code signals. If you are not mongoengine code, do
# not put signals in here. Create your own namespace instead.
_signals = Namespace()
pre_init = _signals.signal('pre_init')
post_init = _signals.signal('post_init')
pre_save = _signals.signal('pre_save')
post_save = _signals.signal('post_save')
pre_delete = _signals.signal('pre_delete')
post_delete = _signals.signal('post_delete')

59
mongoengine/tests.py Normal file
View File

@@ -0,0 +1,59 @@
from mongoengine.connection import _get_db
class query_counter(object):
""" Query_counter contextmanager to get the number of queries. """
def __init__(self):
""" Construct the query_counter. """
self.counter = 0
self.db = _get_db()
def __enter__(self):
""" On every with block we need to drop the profile collection. """
self.db.set_profiling_level(0)
self.db.system.profile.drop()
self.db.set_profiling_level(2)
return self
def __exit__(self, t, value, traceback):
""" Reset the profiling level. """
self.db.set_profiling_level(0)
def __eq__(self, value):
""" == Compare querycounter. """
return value == self._get_count()
def __ne__(self, value):
""" != Compare querycounter. """
return not self.__eq__(value)
def __lt__(self, value):
""" < Compare querycounter. """
return self._get_count() < value
def __le__(self, value):
""" <= Compare querycounter. """
return self._get_count() <= value
def __gt__(self, value):
""" > Compare querycounter. """
return self._get_count() > value
def __ge__(self, value):
""" >= Compare querycounter. """
return self._get_count() >= value
def __int__(self):
""" int representation. """
return self._get_count()
def __repr__(self):
""" repr query_counter as the number of queries. """
return u"%s" % self._get_count()
def _get_count(self):
""" Get the number of queries. """
count = self.db.system.profile.find().count() - self.counter
self.counter += 1
return count