Merge branch 'dev' into feature/dev-indexes
This commit is contained in:
commit
8553022b0e
1
AUTHORS
1
AUTHORS
@ -3,3 +3,4 @@ Matt Dennewitz <mattdennewitz@gmail.com>
|
||||
Deepak Thukral <iapain@yahoo.com>
|
||||
Florian Schlachter <flori@n-schlachter.de>
|
||||
Steve Challis <steve@stevechallis.com>
|
||||
Ross Lawley <ross.lawley@gmail.com>
|
||||
|
@ -5,6 +5,9 @@ Changelog
|
||||
Changes in dev
|
||||
==============
|
||||
|
||||
- Added blinker signal support
|
||||
- Added query_counter context manager for tests
|
||||
- Added DereferenceBaseField - for improved performance in field dereferencing
|
||||
- Added optional map_reduce method item_frequencies
|
||||
- Added inline_map_reduce option to map_reduce
|
||||
- Updated connection exception so it provides more info on the cause.
|
||||
|
@ -11,3 +11,4 @@ User Guide
|
||||
document-instances
|
||||
querying
|
||||
gridfs
|
||||
signals
|
||||
|
49
docs/guide/signals.rst
Normal file
49
docs/guide/signals.rst
Normal file
@ -0,0 +1,49 @@
|
||||
.. _signals:
|
||||
|
||||
Signals
|
||||
=======
|
||||
|
||||
.. versionadded:: 0.5
|
||||
|
||||
Signal support is provided by the excellent `blinker`_ library and
|
||||
will gracefully fall back if it is not available.
|
||||
|
||||
|
||||
The following document signals exist in MongoEngine and are pretty self explaintary:
|
||||
|
||||
* `mongoengine.signals.pre_init`
|
||||
* `mongoengine.signals.post_init`
|
||||
* `mongoengine.signals.pre_save`
|
||||
* `mongoengine.signals.post_save`
|
||||
* `mongoengine.signals.pre_delete`
|
||||
* `mongoengine.signals.post_delete`
|
||||
|
||||
Example usage::
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine import signals
|
||||
|
||||
class Author(Document):
|
||||
name = StringField()
|
||||
|
||||
def __unicode__(self):
|
||||
return self.name
|
||||
|
||||
@classmethod
|
||||
def pre_save(cls, instance, **kwargs):
|
||||
logging.debug("Pre Save: %s" % instance.name)
|
||||
|
||||
@classmethod
|
||||
def post_save(cls, instance, **kwargs):
|
||||
logging.debug("Post Save: %s" % instance.name)
|
||||
if 'created' in kwargs:
|
||||
if kwargs['created']:
|
||||
logging.debug("Created")
|
||||
else:
|
||||
logging.debug("Updated")
|
||||
|
||||
signals.pre_save.connect(Author.pre_save)
|
||||
signals.post_save.connect(Author.post_save)
|
||||
|
||||
|
||||
.. _blinker: http://pypi.python.org/pypi/blinker
|
@ -6,9 +6,11 @@ import connection
|
||||
from connection import *
|
||||
import queryset
|
||||
from queryset import *
|
||||
import signals
|
||||
from signals import *
|
||||
|
||||
__all__ = (document.__all__ + fields.__all__ + connection.__all__ +
|
||||
queryset.__all__)
|
||||
queryset.__all__ + signals.__all__)
|
||||
|
||||
__author__ = 'Harry Marr'
|
||||
|
||||
|
@ -2,9 +2,12 @@ from queryset import QuerySet, QuerySetManager
|
||||
from queryset import DoesNotExist, MultipleObjectsReturned
|
||||
from queryset import DO_NOTHING
|
||||
|
||||
from mongoengine import signals
|
||||
|
||||
import sys
|
||||
import pymongo
|
||||
import pymongo.objectid
|
||||
from operator import itemgetter
|
||||
|
||||
|
||||
class NotRegistered(Exception):
|
||||
@ -126,6 +129,88 @@ class BaseField(object):
|
||||
|
||||
self.validate(value)
|
||||
|
||||
|
||||
class DereferenceBaseField(BaseField):
|
||||
"""Handles the lazy dereferencing of a queryset. Will dereference all
|
||||
items in a list / dict rather than one at a time.
|
||||
"""
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
"""Descriptor to automatically dereference references.
|
||||
"""
|
||||
from fields import ReferenceField, GenericReferenceField
|
||||
from connection import _get_db
|
||||
|
||||
if instance is None:
|
||||
# Document class being used rather than a document object
|
||||
return self
|
||||
|
||||
# Get value from document instance if available
|
||||
value_list = instance._data.get(self.name)
|
||||
if not value_list:
|
||||
return super(DereferenceBaseField, self).__get__(instance, owner)
|
||||
|
||||
is_list = False
|
||||
if not hasattr(value_list, 'items'):
|
||||
is_list = True
|
||||
value_list = dict([(k,v) for k,v in enumerate(value_list)])
|
||||
|
||||
if isinstance(self.field, ReferenceField) and value_list:
|
||||
db = _get_db()
|
||||
dbref = {}
|
||||
collections = {}
|
||||
|
||||
for k, v in value_list.items():
|
||||
dbref[k] = v
|
||||
# Save any DBRefs
|
||||
if isinstance(v, (pymongo.dbref.DBRef)):
|
||||
collections.setdefault(v.collection, []).append((k, v))
|
||||
|
||||
# For each collection get the references
|
||||
for collection, dbrefs in collections.items():
|
||||
id_map = dict([(v.id, k) for k, v in dbrefs])
|
||||
references = db[collection].find({'_id': {'$in': id_map.keys()}})
|
||||
for ref in references:
|
||||
key = id_map[ref['_id']]
|
||||
dbref[key] = get_document(ref['_cls'])._from_son(ref)
|
||||
|
||||
if is_list:
|
||||
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
|
||||
instance._data[self.name] = dbref
|
||||
|
||||
# Get value from document instance if available
|
||||
if isinstance(self.field, GenericReferenceField) and value_list:
|
||||
db = _get_db()
|
||||
value_list = [(k,v) for k,v in value_list.items()]
|
||||
dbref = {}
|
||||
classes = {}
|
||||
|
||||
for k, v in value_list:
|
||||
dbref[k] = v
|
||||
# Save any DBRefs
|
||||
if isinstance(v, (dict, pymongo.son.SON)):
|
||||
classes.setdefault(v['_cls'], []).append((k, v))
|
||||
|
||||
# For each collection get the references
|
||||
for doc_cls, dbrefs in classes.items():
|
||||
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
|
||||
doc_cls = get_document(doc_cls)
|
||||
collection = doc_cls._meta['collection']
|
||||
references = db[collection].find({'_id': {'$in': id_map.keys()}})
|
||||
|
||||
for ref in references:
|
||||
key = id_map[ref['_id']]
|
||||
dbref[key] = doc_cls._from_son(ref)
|
||||
|
||||
if is_list:
|
||||
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
|
||||
|
||||
instance._data[self.name] = dbref
|
||||
|
||||
return super(DereferenceBaseField, self).__get__(instance, owner)
|
||||
|
||||
|
||||
|
||||
class ObjectIdField(BaseField):
|
||||
"""An field wrapper around MongoDB's ObjectIds.
|
||||
"""
|
||||
@ -382,6 +467,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
||||
class BaseDocument(object):
|
||||
|
||||
def __init__(self, **values):
|
||||
signals.pre_init.send(self, values=values)
|
||||
|
||||
self._data = {}
|
||||
# Assign default values to instance
|
||||
for attr_name in self._fields.keys():
|
||||
@ -395,6 +482,8 @@ class BaseDocument(object):
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
signals.post_init.send(self)
|
||||
|
||||
def validate(self):
|
||||
"""Ensure that all fields' values are valid and that required fields
|
||||
are present.
|
||||
|
@ -1,3 +1,4 @@
|
||||
from mongoengine import signals
|
||||
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
|
||||
ValidationError)
|
||||
from queryset import OperationError
|
||||
@ -75,6 +76,8 @@ class Document(BaseDocument):
|
||||
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
|
||||
have recorded the write and will force an fsync on each server being written to.
|
||||
"""
|
||||
signals.pre_save.send(self)
|
||||
|
||||
if validate:
|
||||
self.validate()
|
||||
|
||||
@ -82,6 +85,7 @@ class Document(BaseDocument):
|
||||
write_options = {}
|
||||
|
||||
doc = self.to_mongo()
|
||||
created = '_id' not in doc
|
||||
try:
|
||||
collection = self.__class__.objects._collection
|
||||
if force_insert:
|
||||
@ -96,12 +100,16 @@ class Document(BaseDocument):
|
||||
id_field = self._meta['id_field']
|
||||
self[id_field] = self._fields[id_field].to_python(object_id)
|
||||
|
||||
signals.post_save.send(self, created=created)
|
||||
|
||||
def delete(self, safe=False):
|
||||
"""Delete the :class:`~mongoengine.Document` from the database. This
|
||||
will only take effect if the document has been previously saved.
|
||||
|
||||
:param safe: check if the operation succeeded before returning
|
||||
"""
|
||||
signals.pre_delete.send(self)
|
||||
|
||||
id_field = self._meta['id_field']
|
||||
object_id = self._fields[id_field].to_mongo(self[id_field])
|
||||
try:
|
||||
@ -110,6 +118,8 @@ class Document(BaseDocument):
|
||||
message = u'Could not delete document (%s)' % err.message
|
||||
raise OperationError(message)
|
||||
|
||||
signals.post_delete.send(self)
|
||||
|
||||
@classmethod
|
||||
def register_delete_rule(cls, document_cls, field_name, rule):
|
||||
"""This method registers the delete rules to apply when removing this
|
||||
|
@ -1,4 +1,5 @@
|
||||
from base import BaseField, ObjectIdField, ValidationError, get_document
|
||||
from base import (BaseField, DereferenceBaseField, ObjectIdField,
|
||||
ValidationError, get_document)
|
||||
from queryset import DO_NOTHING
|
||||
from document import Document, EmbeddedDocument
|
||||
from connection import _get_db
|
||||
@ -12,7 +13,6 @@ import pymongo.binary
|
||||
import datetime, time
|
||||
import decimal
|
||||
import gridfs
|
||||
import warnings
|
||||
|
||||
|
||||
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
|
||||
@ -118,8 +118,8 @@ class EmailField(StringField):
|
||||
|
||||
EMAIL_REGEX = re.compile(
|
||||
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
|
||||
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
|
||||
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
|
||||
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
|
||||
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
|
||||
)
|
||||
|
||||
def validate(self, value):
|
||||
@ -153,6 +153,7 @@ class IntField(BaseField):
|
||||
def prepare_query_value(self, op, value):
|
||||
return int(value)
|
||||
|
||||
|
||||
class FloatField(BaseField):
|
||||
"""An floating point number field.
|
||||
"""
|
||||
@ -178,6 +179,7 @@ class FloatField(BaseField):
|
||||
def prepare_query_value(self, op, value):
|
||||
return float(value)
|
||||
|
||||
|
||||
class DecimalField(BaseField):
|
||||
"""A fixed-point decimal number field.
|
||||
|
||||
@ -252,21 +254,21 @@ class DateTimeField(BaseField):
|
||||
else:
|
||||
usecs = 0
|
||||
kwargs = {'microsecond': usecs}
|
||||
try: # Seconds are optional, so try converting seconds first.
|
||||
try: # Seconds are optional, so try converting seconds first.
|
||||
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6],
|
||||
**kwargs)
|
||||
|
||||
except ValueError:
|
||||
try: # Try without seconds.
|
||||
try: # Try without seconds.
|
||||
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5],
|
||||
**kwargs)
|
||||
except ValueError: # Try without hour/minutes/seconds.
|
||||
except ValueError: # Try without hour/minutes/seconds.
|
||||
try:
|
||||
return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3],
|
||||
**kwargs)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
class EmbeddedDocumentField(BaseField):
|
||||
"""An embedded document field. Only valid values are subclasses of
|
||||
:class:`~mongoengine.EmbeddedDocument`.
|
||||
@ -314,7 +316,7 @@ class EmbeddedDocumentField(BaseField):
|
||||
return self.to_mongo(value)
|
||||
|
||||
|
||||
class ListField(BaseField):
|
||||
class ListField(DereferenceBaseField):
|
||||
"""A list field that wraps a standard field, allowing multiple instances
|
||||
of the field to be used as a list in the database.
|
||||
"""
|
||||
@ -330,42 +332,6 @@ class ListField(BaseField):
|
||||
kwargs.setdefault('default', lambda: [])
|
||||
super(ListField, self).__init__(**kwargs)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
"""Descriptor to automatically dereference references.
|
||||
"""
|
||||
if instance is None:
|
||||
# Document class being used rather than a document object
|
||||
return self
|
||||
|
||||
if isinstance(self.field, ReferenceField):
|
||||
referenced_type = self.field.document_type
|
||||
# Get value from document instance if available
|
||||
value_list = instance._data.get(self.name)
|
||||
if value_list:
|
||||
deref_list = []
|
||||
for value in value_list:
|
||||
# Dereference DBRefs
|
||||
if isinstance(value, (pymongo.dbref.DBRef)):
|
||||
value = _get_db().dereference(value)
|
||||
deref_list.append(referenced_type._from_son(value))
|
||||
else:
|
||||
deref_list.append(value)
|
||||
instance._data[self.name] = deref_list
|
||||
|
||||
if isinstance(self.field, GenericReferenceField):
|
||||
value_list = instance._data.get(self.name)
|
||||
if value_list:
|
||||
deref_list = []
|
||||
for value in value_list:
|
||||
# Dereference DBRefs
|
||||
if isinstance(value, (dict, pymongo.son.SON)):
|
||||
deref_list.append(self.field.dereference(value))
|
||||
else:
|
||||
deref_list.append(value)
|
||||
instance._data[self.name] = deref_list
|
||||
|
||||
return super(ListField, self).__get__(instance, owner)
|
||||
|
||||
def to_python(self, value):
|
||||
return [self.field.to_python(item) for item in value]
|
||||
|
||||
@ -459,10 +425,10 @@ class DictField(BaseField):
|
||||
if op in match_operators and isinstance(value, basestring):
|
||||
return StringField().prepare_query_value(op, value)
|
||||
|
||||
return super(DictField,self).prepare_query_value(op, value)
|
||||
return super(DictField, self).prepare_query_value(op, value)
|
||||
|
||||
|
||||
class MapField(BaseField):
|
||||
class MapField(DereferenceBaseField):
|
||||
"""A field that maps a name to a specified field type. Similar to
|
||||
a DictField, except the 'value' of each item must match the specified
|
||||
field type.
|
||||
@ -494,47 +460,11 @@ class MapField(BaseField):
|
||||
except Exception, err:
|
||||
raise ValidationError('Invalid MapField item (%s)' % str(item))
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
"""Descriptor to automatically dereference references.
|
||||
"""
|
||||
if instance is None:
|
||||
# Document class being used rather than a document object
|
||||
return self
|
||||
|
||||
if isinstance(self.field, ReferenceField):
|
||||
referenced_type = self.field.document_type
|
||||
# Get value from document instance if available
|
||||
value_dict = instance._data.get(self.name)
|
||||
if value_dict:
|
||||
deref_dict = []
|
||||
for key,value in value_dict.iteritems():
|
||||
# Dereference DBRefs
|
||||
if isinstance(value, (pymongo.dbref.DBRef)):
|
||||
value = _get_db().dereference(value)
|
||||
deref_dict[key] = referenced_type._from_son(value)
|
||||
else:
|
||||
deref_dict[key] = value
|
||||
instance._data[self.name] = deref_dict
|
||||
|
||||
if isinstance(self.field, GenericReferenceField):
|
||||
value_dict = instance._data.get(self.name)
|
||||
if value_dict:
|
||||
deref_dict = []
|
||||
for key,value in value_dict.iteritems():
|
||||
# Dereference DBRefs
|
||||
if isinstance(value, (dict, pymongo.son.SON)):
|
||||
deref_dict[key] = self.field.dereference(value)
|
||||
else:
|
||||
deref_dict[key] = value
|
||||
instance._data[self.name] = deref_dict
|
||||
|
||||
return super(MapField, self).__get__(instance, owner)
|
||||
|
||||
def to_python(self, value):
|
||||
return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] )
|
||||
return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
|
||||
|
||||
def to_mongo(self, value):
|
||||
return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] )
|
||||
return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()])
|
||||
|
||||
def prepare_query_value(self, op, value):
|
||||
if op not in ('set', 'unset'):
|
||||
@ -752,11 +682,11 @@ class GridFSProxy(object):
|
||||
self.newfile = self.fs.new_file(**kwargs)
|
||||
self.grid_id = self.newfile._id
|
||||
|
||||
def put(self, file, **kwargs):
|
||||
def put(self, file_obj, **kwargs):
|
||||
if self.grid_id:
|
||||
raise GridFSError('This document already has a file. Either delete '
|
||||
'it or call replace to overwrite it')
|
||||
self.grid_id = self.fs.put(file, **kwargs)
|
||||
self.grid_id = self.fs.put(file_obj, **kwargs)
|
||||
|
||||
def write(self, string):
|
||||
if self.grid_id:
|
||||
@ -785,9 +715,9 @@ class GridFSProxy(object):
|
||||
self.grid_id = None
|
||||
self.gridout = None
|
||||
|
||||
def replace(self, file, **kwargs):
|
||||
def replace(self, file_obj, **kwargs):
|
||||
self.delete()
|
||||
self.put(file, **kwargs)
|
||||
self.put(file_obj, **kwargs)
|
||||
|
||||
def close(self):
|
||||
if self.newfile:
|
||||
|
44
mongoengine/signals.py
Normal file
44
mongoengine/signals.py
Normal file
@ -0,0 +1,44 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save',
|
||||
'pre_delete', 'post_delete']
|
||||
|
||||
signals_available = False
|
||||
try:
|
||||
from blinker import Namespace
|
||||
signals_available = True
|
||||
except ImportError:
|
||||
class Namespace(object):
|
||||
def signal(self, name, doc=None):
|
||||
return _FakeSignal(name, doc)
|
||||
|
||||
class _FakeSignal(object):
|
||||
"""If blinker is unavailable, create a fake class with the same
|
||||
interface that allows sending of signals but will fail with an
|
||||
error on anything else. Instead of doing anything on send, it
|
||||
will just ignore the arguments and do nothing instead.
|
||||
"""
|
||||
|
||||
def __init__(self, name, doc=None):
|
||||
self.name = name
|
||||
self.__doc__ = doc
|
||||
|
||||
def _fail(self, *args, **kwargs):
|
||||
raise RuntimeError('signalling support is unavailable '
|
||||
'because the blinker library is '
|
||||
'not installed.')
|
||||
send = lambda *a, **kw: None
|
||||
connect = disconnect = has_receivers_for = receivers_for = \
|
||||
temporarily_connected_to = _fail
|
||||
del _fail
|
||||
|
||||
# the namespace for code signals. If you are not mongoengine code, do
|
||||
# not put signals in here. Create your own namespace instead.
|
||||
_signals = Namespace()
|
||||
|
||||
pre_init = _signals.signal('pre_init')
|
||||
post_init = _signals.signal('post_init')
|
||||
pre_save = _signals.signal('pre_save')
|
||||
post_save = _signals.signal('post_save')
|
||||
pre_delete = _signals.signal('pre_delete')
|
||||
post_delete = _signals.signal('post_delete')
|
59
mongoengine/tests.py
Normal file
59
mongoengine/tests.py
Normal file
@ -0,0 +1,59 @@
|
||||
from mongoengine.connection import _get_db
|
||||
|
||||
|
||||
class query_counter(object):
|
||||
""" Query_counter contextmanager to get the number of queries. """
|
||||
|
||||
def __init__(self):
|
||||
""" Construct the query_counter. """
|
||||
self.counter = 0
|
||||
self.db = _get_db()
|
||||
|
||||
def __enter__(self):
|
||||
""" On every with block we need to drop the profile collection. """
|
||||
self.db.set_profiling_level(0)
|
||||
self.db.system.profile.drop()
|
||||
self.db.set_profiling_level(2)
|
||||
return self
|
||||
|
||||
def __exit__(self, t, value, traceback):
|
||||
""" Reset the profiling level. """
|
||||
self.db.set_profiling_level(0)
|
||||
|
||||
def __eq__(self, value):
|
||||
""" == Compare querycounter. """
|
||||
return value == self._get_count()
|
||||
|
||||
def __ne__(self, value):
|
||||
""" != Compare querycounter. """
|
||||
return not self.__eq__(value)
|
||||
|
||||
def __lt__(self, value):
|
||||
""" < Compare querycounter. """
|
||||
return self._get_count() < value
|
||||
|
||||
def __le__(self, value):
|
||||
""" <= Compare querycounter. """
|
||||
return self._get_count() <= value
|
||||
|
||||
def __gt__(self, value):
|
||||
""" > Compare querycounter. """
|
||||
return self._get_count() > value
|
||||
|
||||
def __ge__(self, value):
|
||||
""" >= Compare querycounter. """
|
||||
return self._get_count() >= value
|
||||
|
||||
def __int__(self):
|
||||
""" int representation. """
|
||||
return self._get_count()
|
||||
|
||||
def __repr__(self):
|
||||
""" repr query_counter as the number of queries. """
|
||||
return u"%s" % self._get_count()
|
||||
|
||||
def _get_count(self):
|
||||
""" Get the number of queries. """
|
||||
count = self.db.system.profile.find().count() - self.counter
|
||||
self.counter += 1
|
||||
return count
|
2
setup.py
2
setup.py
@ -45,6 +45,6 @@ setup(name='mongoengine',
|
||||
long_description=LONG_DESCRIPTION,
|
||||
platforms=['any'],
|
||||
classifiers=CLASSIFIERS,
|
||||
install_requires=['pymongo'],
|
||||
install_requires=['pymongo', 'blinker'],
|
||||
test_suite='tests',
|
||||
)
|
||||
|
288
tests/dereference.py
Normal file
288
tests/dereference.py
Normal file
@ -0,0 +1,288 @@
|
||||
import unittest
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import _get_db
|
||||
from mongoengine.tests import query_counter
|
||||
|
||||
|
||||
class FieldTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
self.db = _get_db()
|
||||
|
||||
def test_list_item_dereference(self):
|
||||
"""Ensure that DBRef items in ListFields are dereferenced.
|
||||
"""
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
members = ListField(ReferenceField(User))
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
for i in xrange(1, 51):
|
||||
user = User(name='user %s' % i)
|
||||
user.save()
|
||||
|
||||
group = Group(members=User.objects)
|
||||
group.save()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
group_obj = Group.objects.first()
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 2)
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
def test_recursive_reference(self):
|
||||
"""Ensure that ReferenceFields can reference their own documents.
|
||||
"""
|
||||
class Employee(Document):
|
||||
name = StringField()
|
||||
boss = ReferenceField('self')
|
||||
friends = ListField(ReferenceField('self'))
|
||||
|
||||
bill = Employee(name='Bill Lumbergh')
|
||||
bill.save()
|
||||
|
||||
michael = Employee(name='Michael Bolton')
|
||||
michael.save()
|
||||
|
||||
samir = Employee(name='Samir Nagheenanajar')
|
||||
samir.save()
|
||||
|
||||
friends = [michael, samir]
|
||||
peter = Employee(name='Peter Gibbons', boss=bill, friends=friends)
|
||||
peter.save()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
peter = Employee.objects.with_id(peter.id)
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
peter.boss
|
||||
self.assertEqual(q, 2)
|
||||
|
||||
peter.friends
|
||||
self.assertEqual(q, 3)
|
||||
|
||||
def test_generic_reference(self):
|
||||
|
||||
class UserA(Document):
|
||||
name = StringField()
|
||||
|
||||
class UserB(Document):
|
||||
name = StringField()
|
||||
|
||||
class UserC(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
members = ListField(GenericReferenceField())
|
||||
|
||||
UserA.drop_collection()
|
||||
UserB.drop_collection()
|
||||
UserC.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
members = []
|
||||
for i in xrange(1, 51):
|
||||
a = UserA(name='User A %s' % i)
|
||||
a.save()
|
||||
|
||||
b = UserB(name='User B %s' % i)
|
||||
b.save()
|
||||
|
||||
c = UserC(name='User C %s' % i)
|
||||
c.save()
|
||||
|
||||
members += [a, b, c]
|
||||
|
||||
group = Group(members=members)
|
||||
group.save()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
group_obj = Group.objects.first()
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 4)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 4)
|
||||
|
||||
UserA.drop_collection()
|
||||
UserB.drop_collection()
|
||||
UserC.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
def test_map_field_reference(self):
|
||||
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
members = MapField(ReferenceField(User))
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
members = []
|
||||
for i in xrange(1, 51):
|
||||
user = User(name='user %s' % i)
|
||||
user.save()
|
||||
members.append(user)
|
||||
|
||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
||||
group.save()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
group_obj = Group.objects.first()
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 2)
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
def ztest_generic_reference_dict_field(self):
|
||||
|
||||
class UserA(Document):
|
||||
name = StringField()
|
||||
|
||||
class UserB(Document):
|
||||
name = StringField()
|
||||
|
||||
class UserC(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
members = DictField()
|
||||
|
||||
UserA.drop_collection()
|
||||
UserB.drop_collection()
|
||||
UserC.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
members = []
|
||||
for i in xrange(1, 51):
|
||||
a = UserA(name='User A %s' % i)
|
||||
a.save()
|
||||
|
||||
b = UserB(name='User B %s' % i)
|
||||
b.save()
|
||||
|
||||
c = UserC(name='User C %s' % i)
|
||||
c.save()
|
||||
|
||||
members += [a, b, c]
|
||||
|
||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
||||
group.save()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
group_obj = Group.objects.first()
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 4)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 4)
|
||||
|
||||
group.members = {}
|
||||
group.save()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
group_obj = Group.objects.first()
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
UserA.drop_collection()
|
||||
UserB.drop_collection()
|
||||
UserC.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
def test_generic_reference_map_field(self):
|
||||
|
||||
class UserA(Document):
|
||||
name = StringField()
|
||||
|
||||
class UserB(Document):
|
||||
name = StringField()
|
||||
|
||||
class UserC(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
members = MapField(GenericReferenceField())
|
||||
|
||||
UserA.drop_collection()
|
||||
UserB.drop_collection()
|
||||
UserC.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
members = []
|
||||
for i in xrange(1, 51):
|
||||
a = UserA(name='User A %s' % i)
|
||||
a.save()
|
||||
|
||||
b = UserB(name='User B %s' % i)
|
||||
b.save()
|
||||
|
||||
c = UserC(name='User C %s' % i)
|
||||
c.save()
|
||||
|
||||
members += [a, b, c]
|
||||
|
||||
group = Group(members=dict([(str(u.id), u) for u in members]))
|
||||
group.save()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
group_obj = Group.objects.first()
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 4)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 4)
|
||||
|
||||
group.members = {}
|
||||
group.save()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
group_obj = Group.objects.first()
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
[m for m in group_obj.members]
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
UserA.drop_collection()
|
||||
UserB.drop_collection()
|
||||
UserC.drop_collection()
|
||||
Group.drop_collection()
|
130
tests/signals.py
Normal file
130
tests/signals.py
Normal file
@ -0,0 +1,130 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import unittest
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine import signals
|
||||
|
||||
signal_output = []
|
||||
|
||||
|
||||
class SignalTests(unittest.TestCase):
|
||||
"""
|
||||
Testing signals before/after saving and deleting.
|
||||
"""
|
||||
|
||||
def get_signal_output(self, fn, *args, **kwargs):
|
||||
# Flush any existing signal output
|
||||
global signal_output
|
||||
signal_output = []
|
||||
fn(*args, **kwargs)
|
||||
return signal_output
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
class Author(Document):
|
||||
name = StringField()
|
||||
|
||||
def __unicode__(self):
|
||||
return self.name
|
||||
|
||||
@classmethod
|
||||
def pre_init(cls, instance, **kwargs):
|
||||
signal_output.append('pre_init signal, %s' % cls.__name__)
|
||||
signal_output.append(str(kwargs['values']))
|
||||
|
||||
@classmethod
|
||||
def post_init(cls, instance, **kwargs):
|
||||
signal_output.append('post_init signal, %s' % instance)
|
||||
|
||||
@classmethod
|
||||
def pre_save(cls, instance, **kwargs):
|
||||
signal_output.append('pre_save signal, %s' % instance)
|
||||
|
||||
@classmethod
|
||||
def post_save(cls, instance, **kwargs):
|
||||
signal_output.append('post_save signal, %s' % instance)
|
||||
if 'created' in kwargs:
|
||||
if kwargs['created']:
|
||||
signal_output.append('Is created')
|
||||
else:
|
||||
signal_output.append('Is updated')
|
||||
|
||||
@classmethod
|
||||
def pre_delete(cls, instance, **kwargs):
|
||||
signal_output.append('pre_delete signal, %s' % instance)
|
||||
|
||||
@classmethod
|
||||
def post_delete(cls, instance, **kwargs):
|
||||
signal_output.append('post_delete signal, %s' % instance)
|
||||
|
||||
self.Author = Author
|
||||
|
||||
# Save up the number of connected signals so that we can check at the end
|
||||
# that all the signals we register get properly unregistered
|
||||
self.pre_signals = (
|
||||
len(signals.pre_init.receivers),
|
||||
len(signals.post_init.receivers),
|
||||
len(signals.pre_save.receivers),
|
||||
len(signals.post_save.receivers),
|
||||
len(signals.pre_delete.receivers),
|
||||
len(signals.post_delete.receivers)
|
||||
)
|
||||
|
||||
signals.pre_init.connect(Author.pre_init)
|
||||
signals.post_init.connect(Author.post_init)
|
||||
signals.pre_save.connect(Author.pre_save)
|
||||
signals.post_save.connect(Author.post_save)
|
||||
signals.pre_delete.connect(Author.pre_delete)
|
||||
signals.post_delete.connect(Author.post_delete)
|
||||
|
||||
def tearDown(self):
|
||||
signals.pre_init.disconnect(self.Author.pre_init)
|
||||
signals.post_init.disconnect(self.Author.post_init)
|
||||
signals.post_delete.disconnect(self.Author.post_delete)
|
||||
signals.pre_delete.disconnect(self.Author.pre_delete)
|
||||
signals.post_save.disconnect(self.Author.post_save)
|
||||
signals.pre_save.disconnect(self.Author.pre_save)
|
||||
|
||||
# Check that all our signals got disconnected properly.
|
||||
post_signals = (
|
||||
len(signals.pre_init.receivers),
|
||||
len(signals.post_init.receivers),
|
||||
len(signals.pre_save.receivers),
|
||||
len(signals.post_save.receivers),
|
||||
len(signals.pre_delete.receivers),
|
||||
len(signals.post_delete.receivers)
|
||||
)
|
||||
|
||||
self.assertEqual(self.pre_signals, post_signals)
|
||||
|
||||
def test_model_signals(self):
|
||||
""" Model saves should throw some signals. """
|
||||
|
||||
def create_author():
|
||||
a1 = self.Author(name='Bill Shakespeare')
|
||||
|
||||
self.assertEqual(self.get_signal_output(create_author), [
|
||||
"pre_init signal, Author",
|
||||
"{'name': 'Bill Shakespeare'}",
|
||||
"post_init signal, Bill Shakespeare",
|
||||
])
|
||||
|
||||
a1 = self.Author(name='Bill Shakespeare')
|
||||
self.assertEqual(self.get_signal_output(a1.save), [
|
||||
"pre_save signal, Bill Shakespeare",
|
||||
"post_save signal, Bill Shakespeare",
|
||||
"Is created"
|
||||
])
|
||||
|
||||
a1.reload()
|
||||
a1.name='William Shakespeare'
|
||||
self.assertEqual(self.get_signal_output(a1.save), [
|
||||
"pre_save signal, William Shakespeare",
|
||||
"post_save signal, William Shakespeare",
|
||||
"Is updated"
|
||||
])
|
||||
|
||||
self.assertEqual(self.get_signal_output(a1.delete), [
|
||||
'pre_delete signal, William Shakespeare',
|
||||
'post_delete signal, William Shakespeare',
|
||||
])
|
Loading…
x
Reference in New Issue
Block a user