Merge remote branch 'upstream/dev' into dev
This commit is contained in:
commit
6081fc6faf
1
AUTHORS
1
AUTHORS
@ -3,3 +3,4 @@ Matt Dennewitz <mattdennewitz@gmail.com>
|
|||||||
Deepak Thukral <iapain@yahoo.com>
|
Deepak Thukral <iapain@yahoo.com>
|
||||||
Florian Schlachter <flori@n-schlachter.de>
|
Florian Schlachter <flori@n-schlachter.de>
|
||||||
Steve Challis <steve@stevechallis.com>
|
Steve Challis <steve@stevechallis.com>
|
||||||
|
Ross Lawley <ross.lawley@gmail.com>
|
||||||
|
@ -5,6 +5,11 @@ Changelog
|
|||||||
Changes in dev
|
Changes in dev
|
||||||
==============
|
==============
|
||||||
|
|
||||||
|
- Added slave_okay kwarg to queryset
|
||||||
|
- Added insert method for bulk inserts
|
||||||
|
- Added blinker signal support
|
||||||
|
- Added query_counter context manager for tests
|
||||||
|
- Added DereferenceBaseField - for improved performance in field dereferencing
|
||||||
- Added optional map_reduce method item_frequencies
|
- Added optional map_reduce method item_frequencies
|
||||||
- Added inline_map_reduce option to map_reduce
|
- Added inline_map_reduce option to map_reduce
|
||||||
- Updated connection exception so it provides more info on the cause.
|
- Updated connection exception so it provides more info on the cause.
|
||||||
|
@ -49,10 +49,11 @@ Storage
|
|||||||
=======
|
=======
|
||||||
With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`,
|
With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`,
|
||||||
it is useful to have a Django file storage backend that wraps this. The new
|
it is useful to have a Django file storage backend that wraps this. The new
|
||||||
storage module is called :class:`~mongoengine.django.GridFSStorage`. Using it
|
storage module is called :class:`~mongoengine.django.storage.GridFSStorage`.
|
||||||
is very similar to using the default FileSystemStorage.::
|
Using it is very similar to using the default FileSystemStorage.::
|
||||||
|
|
||||||
fs = mongoengine.django.GridFSStorage()
|
from mongoengine.django.storage import GridFSStorage
|
||||||
|
fs = GridFSStorage()
|
||||||
|
|
||||||
filename = fs.save('hello.txt', 'Hello, World!')
|
filename = fs.save('hello.txt', 'Hello, World!')
|
||||||
|
|
||||||
|
@ -341,9 +341,10 @@ Indexes
|
|||||||
You can specify indexes on collections to make querying faster. This is done
|
You can specify indexes on collections to make querying faster. This is done
|
||||||
by creating a list of index specifications called :attr:`indexes` in the
|
by creating a list of index specifications called :attr:`indexes` in the
|
||||||
:attr:`~mongoengine.Document.meta` dictionary, where an index specification may
|
:attr:`~mongoengine.Document.meta` dictionary, where an index specification may
|
||||||
either be a single field name, or a tuple containing multiple field names. A
|
either be a single field name, a tuple containing multiple field names, or a
|
||||||
direction may be specified on fields by prefixing the field name with a **+**
|
dictionary containing a full index definition. A direction may be specified on
|
||||||
or a **-** sign. Note that direction only matters on multi-field indexes. ::
|
fields by prefixing the field name with a **+** or a **-** sign. Note that
|
||||||
|
direction only matters on multi-field indexes. ::
|
||||||
|
|
||||||
class Page(Document):
|
class Page(Document):
|
||||||
title = StringField()
|
title = StringField()
|
||||||
@ -352,6 +353,21 @@ or a **-** sign. Note that direction only matters on multi-field indexes. ::
|
|||||||
'indexes': ['title', ('title', '-rating')]
|
'indexes': ['title', ('title', '-rating')]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
If a dictionary is passed then the following options are available:
|
||||||
|
|
||||||
|
:attr:`fields` (Default: None)
|
||||||
|
The fields to index. Specified in the same format as described above.
|
||||||
|
|
||||||
|
:attr:`types` (Default: True)
|
||||||
|
Whether the index should have the :attr:`_types` field added automatically
|
||||||
|
to the start of the index.
|
||||||
|
|
||||||
|
:attr:`sparse` (Default: False)
|
||||||
|
Whether the index should be sparse.
|
||||||
|
|
||||||
|
:attr:`unique` (Default: False)
|
||||||
|
Whether the index should be sparse.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Geospatial indexes will be automatically created for all
|
Geospatial indexes will be automatically created for all
|
||||||
:class:`~mongoengine.GeoPointField`\ s
|
:class:`~mongoengine.GeoPointField`\ s
|
||||||
|
@ -11,3 +11,4 @@ User Guide
|
|||||||
document-instances
|
document-instances
|
||||||
querying
|
querying
|
||||||
gridfs
|
gridfs
|
||||||
|
signals
|
||||||
|
49
docs/guide/signals.rst
Normal file
49
docs/guide/signals.rst
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
.. _signals:
|
||||||
|
|
||||||
|
Signals
|
||||||
|
=======
|
||||||
|
|
||||||
|
.. versionadded:: 0.5
|
||||||
|
|
||||||
|
Signal support is provided by the excellent `blinker`_ library and
|
||||||
|
will gracefully fall back if it is not available.
|
||||||
|
|
||||||
|
|
||||||
|
The following document signals exist in MongoEngine and are pretty self explaintary:
|
||||||
|
|
||||||
|
* `mongoengine.signals.pre_init`
|
||||||
|
* `mongoengine.signals.post_init`
|
||||||
|
* `mongoengine.signals.pre_save`
|
||||||
|
* `mongoengine.signals.post_save`
|
||||||
|
* `mongoengine.signals.pre_delete`
|
||||||
|
* `mongoengine.signals.post_delete`
|
||||||
|
|
||||||
|
Example usage::
|
||||||
|
|
||||||
|
from mongoengine import *
|
||||||
|
from mongoengine import signals
|
||||||
|
|
||||||
|
class Author(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
def __unicode__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_save(cls, instance, **kwargs):
|
||||||
|
logging.debug("Pre Save: %s" % instance.name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def post_save(cls, instance, **kwargs):
|
||||||
|
logging.debug("Post Save: %s" % instance.name)
|
||||||
|
if 'created' in kwargs:
|
||||||
|
if kwargs['created']:
|
||||||
|
logging.debug("Created")
|
||||||
|
else:
|
||||||
|
logging.debug("Updated")
|
||||||
|
|
||||||
|
signals.pre_save.connect(Author.pre_save)
|
||||||
|
signals.post_save.connect(Author.post_save)
|
||||||
|
|
||||||
|
|
||||||
|
.. _blinker: http://pypi.python.org/pypi/blinker
|
@ -6,9 +6,11 @@ import connection
|
|||||||
from connection import *
|
from connection import *
|
||||||
import queryset
|
import queryset
|
||||||
from queryset import *
|
from queryset import *
|
||||||
|
import signals
|
||||||
|
from signals import *
|
||||||
|
|
||||||
__all__ = (document.__all__ + fields.__all__ + connection.__all__ +
|
__all__ = (document.__all__ + fields.__all__ + connection.__all__ +
|
||||||
queryset.__all__)
|
queryset.__all__ + signals.__all__)
|
||||||
|
|
||||||
__author__ = 'Harry Marr'
|
__author__ = 'Harry Marr'
|
||||||
|
|
||||||
|
@ -2,9 +2,12 @@ from queryset import QuerySet, QuerySetManager
|
|||||||
from queryset import DoesNotExist, MultipleObjectsReturned
|
from queryset import DoesNotExist, MultipleObjectsReturned
|
||||||
from queryset import DO_NOTHING
|
from queryset import DO_NOTHING
|
||||||
|
|
||||||
|
from mongoengine import signals
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import pymongo
|
import pymongo
|
||||||
import pymongo.objectid
|
import pymongo.objectid
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
|
|
||||||
class NotRegistered(Exception):
|
class NotRegistered(Exception):
|
||||||
@ -126,6 +129,88 @@ class BaseField(object):
|
|||||||
|
|
||||||
self.validate(value)
|
self.validate(value)
|
||||||
|
|
||||||
|
|
||||||
|
class DereferenceBaseField(BaseField):
|
||||||
|
"""Handles the lazy dereferencing of a queryset. Will dereference all
|
||||||
|
items in a list / dict rather than one at a time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __get__(self, instance, owner):
|
||||||
|
"""Descriptor to automatically dereference references.
|
||||||
|
"""
|
||||||
|
from fields import ReferenceField, GenericReferenceField
|
||||||
|
from connection import _get_db
|
||||||
|
|
||||||
|
if instance is None:
|
||||||
|
# Document class being used rather than a document object
|
||||||
|
return self
|
||||||
|
|
||||||
|
# Get value from document instance if available
|
||||||
|
value_list = instance._data.get(self.name)
|
||||||
|
if not value_list:
|
||||||
|
return super(DereferenceBaseField, self).__get__(instance, owner)
|
||||||
|
|
||||||
|
is_list = False
|
||||||
|
if not hasattr(value_list, 'items'):
|
||||||
|
is_list = True
|
||||||
|
value_list = dict([(k,v) for k,v in enumerate(value_list)])
|
||||||
|
|
||||||
|
if isinstance(self.field, ReferenceField) and value_list:
|
||||||
|
db = _get_db()
|
||||||
|
dbref = {}
|
||||||
|
collections = {}
|
||||||
|
|
||||||
|
for k, v in value_list.items():
|
||||||
|
dbref[k] = v
|
||||||
|
# Save any DBRefs
|
||||||
|
if isinstance(v, (pymongo.dbref.DBRef)):
|
||||||
|
collections.setdefault(v.collection, []).append((k, v))
|
||||||
|
|
||||||
|
# For each collection get the references
|
||||||
|
for collection, dbrefs in collections.items():
|
||||||
|
id_map = dict([(v.id, k) for k, v in dbrefs])
|
||||||
|
references = db[collection].find({'_id': {'$in': id_map.keys()}})
|
||||||
|
for ref in references:
|
||||||
|
key = id_map[ref['_id']]
|
||||||
|
dbref[key] = get_document(ref['_cls'])._from_son(ref)
|
||||||
|
|
||||||
|
if is_list:
|
||||||
|
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
|
||||||
|
instance._data[self.name] = dbref
|
||||||
|
|
||||||
|
# Get value from document instance if available
|
||||||
|
if isinstance(self.field, GenericReferenceField) and value_list:
|
||||||
|
db = _get_db()
|
||||||
|
value_list = [(k,v) for k,v in value_list.items()]
|
||||||
|
dbref = {}
|
||||||
|
classes = {}
|
||||||
|
|
||||||
|
for k, v in value_list:
|
||||||
|
dbref[k] = v
|
||||||
|
# Save any DBRefs
|
||||||
|
if isinstance(v, (dict, pymongo.son.SON)):
|
||||||
|
classes.setdefault(v['_cls'], []).append((k, v))
|
||||||
|
|
||||||
|
# For each collection get the references
|
||||||
|
for doc_cls, dbrefs in classes.items():
|
||||||
|
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
|
||||||
|
doc_cls = get_document(doc_cls)
|
||||||
|
collection = doc_cls._meta['collection']
|
||||||
|
references = db[collection].find({'_id': {'$in': id_map.keys()}})
|
||||||
|
|
||||||
|
for ref in references:
|
||||||
|
key = id_map[ref['_id']]
|
||||||
|
dbref[key] = doc_cls._from_son(ref)
|
||||||
|
|
||||||
|
if is_list:
|
||||||
|
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
|
||||||
|
|
||||||
|
instance._data[self.name] = dbref
|
||||||
|
|
||||||
|
return super(DereferenceBaseField, self).__get__(instance, owner)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectIdField(BaseField):
|
class ObjectIdField(BaseField):
|
||||||
"""An field wrapper around MongoDB's ObjectIds.
|
"""An field wrapper around MongoDB's ObjectIds.
|
||||||
"""
|
"""
|
||||||
@ -382,6 +467,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
class BaseDocument(object):
|
class BaseDocument(object):
|
||||||
|
|
||||||
def __init__(self, **values):
|
def __init__(self, **values):
|
||||||
|
signals.pre_init.send(self, values=values)
|
||||||
|
|
||||||
self._data = {}
|
self._data = {}
|
||||||
# Assign default values to instance
|
# Assign default values to instance
|
||||||
for attr_name in self._fields.keys():
|
for attr_name in self._fields.keys():
|
||||||
@ -395,6 +482,8 @@ class BaseDocument(object):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
signals.post_init.send(self)
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""Ensure that all fields' values are valid and that required fields
|
"""Ensure that all fields' values are valid and that required fields
|
||||||
are present.
|
are present.
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from mongoengine import signals
|
||||||
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
|
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
|
||||||
ValidationError)
|
ValidationError)
|
||||||
from queryset import OperationError
|
from queryset import OperationError
|
||||||
@ -75,6 +76,8 @@ class Document(BaseDocument):
|
|||||||
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
|
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
|
||||||
have recorded the write and will force an fsync on each server being written to.
|
have recorded the write and will force an fsync on each server being written to.
|
||||||
"""
|
"""
|
||||||
|
signals.pre_save.send(self)
|
||||||
|
|
||||||
if validate:
|
if validate:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
||||||
@ -82,6 +85,7 @@ class Document(BaseDocument):
|
|||||||
write_options = {}
|
write_options = {}
|
||||||
|
|
||||||
doc = self.to_mongo()
|
doc = self.to_mongo()
|
||||||
|
created = '_id' not in doc
|
||||||
try:
|
try:
|
||||||
collection = self.__class__.objects._collection
|
collection = self.__class__.objects._collection
|
||||||
if force_insert:
|
if force_insert:
|
||||||
@ -96,12 +100,16 @@ class Document(BaseDocument):
|
|||||||
id_field = self._meta['id_field']
|
id_field = self._meta['id_field']
|
||||||
self[id_field] = self._fields[id_field].to_python(object_id)
|
self[id_field] = self._fields[id_field].to_python(object_id)
|
||||||
|
|
||||||
|
signals.post_save.send(self, created=created)
|
||||||
|
|
||||||
def delete(self, safe=False):
|
def delete(self, safe=False):
|
||||||
"""Delete the :class:`~mongoengine.Document` from the database. This
|
"""Delete the :class:`~mongoengine.Document` from the database. This
|
||||||
will only take effect if the document has been previously saved.
|
will only take effect if the document has been previously saved.
|
||||||
|
|
||||||
:param safe: check if the operation succeeded before returning
|
:param safe: check if the operation succeeded before returning
|
||||||
"""
|
"""
|
||||||
|
signals.pre_delete.send(self)
|
||||||
|
|
||||||
id_field = self._meta['id_field']
|
id_field = self._meta['id_field']
|
||||||
object_id = self._fields[id_field].to_mongo(self[id_field])
|
object_id = self._fields[id_field].to_mongo(self[id_field])
|
||||||
try:
|
try:
|
||||||
@ -110,6 +118,8 @@ class Document(BaseDocument):
|
|||||||
message = u'Could not delete document (%s)' % err.message
|
message = u'Could not delete document (%s)' % err.message
|
||||||
raise OperationError(message)
|
raise OperationError(message)
|
||||||
|
|
||||||
|
signals.post_delete.send(self)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_delete_rule(cls, document_cls, field_name, rule):
|
def register_delete_rule(cls, document_cls, field_name, rule):
|
||||||
"""This method registers the delete rules to apply when removing this
|
"""This method registers the delete rules to apply when removing this
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from base import BaseField, ObjectIdField, ValidationError, get_document
|
from base import (BaseField, DereferenceBaseField, ObjectIdField,
|
||||||
|
ValidationError, get_document)
|
||||||
from queryset import DO_NOTHING
|
from queryset import DO_NOTHING
|
||||||
from document import Document, EmbeddedDocument
|
from document import Document, EmbeddedDocument
|
||||||
from connection import _get_db
|
from connection import _get_db
|
||||||
@ -12,7 +13,6 @@ import pymongo.binary
|
|||||||
import datetime, time
|
import datetime, time
|
||||||
import decimal
|
import decimal
|
||||||
import gridfs
|
import gridfs
|
||||||
import warnings
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
|
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
|
||||||
@ -153,6 +153,7 @@ class IntField(BaseField):
|
|||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
return int(value)
|
return int(value)
|
||||||
|
|
||||||
|
|
||||||
class FloatField(BaseField):
|
class FloatField(BaseField):
|
||||||
"""An floating point number field.
|
"""An floating point number field.
|
||||||
"""
|
"""
|
||||||
@ -178,6 +179,7 @@ class FloatField(BaseField):
|
|||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
return float(value)
|
return float(value)
|
||||||
|
|
||||||
|
|
||||||
class DecimalField(BaseField):
|
class DecimalField(BaseField):
|
||||||
"""A fixed-point decimal number field.
|
"""A fixed-point decimal number field.
|
||||||
|
|
||||||
@ -227,6 +229,10 @@ class BooleanField(BaseField):
|
|||||||
|
|
||||||
class DateTimeField(BaseField):
|
class DateTimeField(BaseField):
|
||||||
"""A datetime field.
|
"""A datetime field.
|
||||||
|
|
||||||
|
Note: Microseconds are rounded to the nearest millisecond.
|
||||||
|
Pre UTC microsecond support is effecively broken see
|
||||||
|
`tests.field.test_datetime` for more information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
@ -255,7 +261,6 @@ class DateTimeField(BaseField):
|
|||||||
try: # Seconds are optional, so try converting seconds first.
|
try: # Seconds are optional, so try converting seconds first.
|
||||||
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6],
|
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6],
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
try: # Try without seconds.
|
try: # Try without seconds.
|
||||||
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5],
|
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5],
|
||||||
@ -267,6 +272,7 @@ class DateTimeField(BaseField):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class EmbeddedDocumentField(BaseField):
|
class EmbeddedDocumentField(BaseField):
|
||||||
"""An embedded document field. Only valid values are subclasses of
|
"""An embedded document field. Only valid values are subclasses of
|
||||||
:class:`~mongoengine.EmbeddedDocument`.
|
:class:`~mongoengine.EmbeddedDocument`.
|
||||||
@ -314,7 +320,7 @@ class EmbeddedDocumentField(BaseField):
|
|||||||
return self.to_mongo(value)
|
return self.to_mongo(value)
|
||||||
|
|
||||||
|
|
||||||
class ListField(BaseField):
|
class ListField(DereferenceBaseField):
|
||||||
"""A list field that wraps a standard field, allowing multiple instances
|
"""A list field that wraps a standard field, allowing multiple instances
|
||||||
of the field to be used as a list in the database.
|
of the field to be used as a list in the database.
|
||||||
"""
|
"""
|
||||||
@ -330,42 +336,6 @@ class ListField(BaseField):
|
|||||||
kwargs.setdefault('default', lambda: [])
|
kwargs.setdefault('default', lambda: [])
|
||||||
super(ListField, self).__init__(**kwargs)
|
super(ListField, self).__init__(**kwargs)
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
|
||||||
"""Descriptor to automatically dereference references.
|
|
||||||
"""
|
|
||||||
if instance is None:
|
|
||||||
# Document class being used rather than a document object
|
|
||||||
return self
|
|
||||||
|
|
||||||
if isinstance(self.field, ReferenceField):
|
|
||||||
referenced_type = self.field.document_type
|
|
||||||
# Get value from document instance if available
|
|
||||||
value_list = instance._data.get(self.name)
|
|
||||||
if value_list:
|
|
||||||
deref_list = []
|
|
||||||
for value in value_list:
|
|
||||||
# Dereference DBRefs
|
|
||||||
if isinstance(value, (pymongo.dbref.DBRef)):
|
|
||||||
value = _get_db().dereference(value)
|
|
||||||
deref_list.append(referenced_type._from_son(value))
|
|
||||||
else:
|
|
||||||
deref_list.append(value)
|
|
||||||
instance._data[self.name] = deref_list
|
|
||||||
|
|
||||||
if isinstance(self.field, GenericReferenceField):
|
|
||||||
value_list = instance._data.get(self.name)
|
|
||||||
if value_list:
|
|
||||||
deref_list = []
|
|
||||||
for value in value_list:
|
|
||||||
# Dereference DBRefs
|
|
||||||
if isinstance(value, (dict, pymongo.son.SON)):
|
|
||||||
deref_list.append(self.field.dereference(value))
|
|
||||||
else:
|
|
||||||
deref_list.append(value)
|
|
||||||
instance._data[self.name] = deref_list
|
|
||||||
|
|
||||||
return super(ListField, self).__get__(instance, owner)
|
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
return [self.field.to_python(item) for item in value]
|
return [self.field.to_python(item) for item in value]
|
||||||
|
|
||||||
@ -462,7 +432,7 @@ class DictField(BaseField):
|
|||||||
return super(DictField, self).prepare_query_value(op, value)
|
return super(DictField, self).prepare_query_value(op, value)
|
||||||
|
|
||||||
|
|
||||||
class MapField(BaseField):
|
class MapField(DereferenceBaseField):
|
||||||
"""A field that maps a name to a specified field type. Similar to
|
"""A field that maps a name to a specified field type. Similar to
|
||||||
a DictField, except the 'value' of each item must match the specified
|
a DictField, except the 'value' of each item must match the specified
|
||||||
field type.
|
field type.
|
||||||
@ -494,42 +464,6 @@ class MapField(BaseField):
|
|||||||
except Exception, err:
|
except Exception, err:
|
||||||
raise ValidationError('Invalid MapField item (%s)' % str(item))
|
raise ValidationError('Invalid MapField item (%s)' % str(item))
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
|
||||||
"""Descriptor to automatically dereference references.
|
|
||||||
"""
|
|
||||||
if instance is None:
|
|
||||||
# Document class being used rather than a document object
|
|
||||||
return self
|
|
||||||
|
|
||||||
if isinstance(self.field, ReferenceField):
|
|
||||||
referenced_type = self.field.document_type
|
|
||||||
# Get value from document instance if available
|
|
||||||
value_dict = instance._data.get(self.name)
|
|
||||||
if value_dict:
|
|
||||||
deref_dict = []
|
|
||||||
for key,value in value_dict.iteritems():
|
|
||||||
# Dereference DBRefs
|
|
||||||
if isinstance(value, (pymongo.dbref.DBRef)):
|
|
||||||
value = _get_db().dereference(value)
|
|
||||||
deref_dict[key] = referenced_type._from_son(value)
|
|
||||||
else:
|
|
||||||
deref_dict[key] = value
|
|
||||||
instance._data[self.name] = deref_dict
|
|
||||||
|
|
||||||
if isinstance(self.field, GenericReferenceField):
|
|
||||||
value_dict = instance._data.get(self.name)
|
|
||||||
if value_dict:
|
|
||||||
deref_dict = []
|
|
||||||
for key,value in value_dict.iteritems():
|
|
||||||
# Dereference DBRefs
|
|
||||||
if isinstance(value, (dict, pymongo.son.SON)):
|
|
||||||
deref_dict[key] = self.field.dereference(value)
|
|
||||||
else:
|
|
||||||
deref_dict[key] = value
|
|
||||||
instance._data[self.name] = deref_dict
|
|
||||||
|
|
||||||
return super(MapField, self).__get__(instance, owner)
|
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
|
return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
|
||||||
|
|
||||||
@ -752,11 +686,11 @@ class GridFSProxy(object):
|
|||||||
self.newfile = self.fs.new_file(**kwargs)
|
self.newfile = self.fs.new_file(**kwargs)
|
||||||
self.grid_id = self.newfile._id
|
self.grid_id = self.newfile._id
|
||||||
|
|
||||||
def put(self, file, **kwargs):
|
def put(self, file_obj, **kwargs):
|
||||||
if self.grid_id:
|
if self.grid_id:
|
||||||
raise GridFSError('This document already has a file. Either delete '
|
raise GridFSError('This document already has a file. Either delete '
|
||||||
'it or call replace to overwrite it')
|
'it or call replace to overwrite it')
|
||||||
self.grid_id = self.fs.put(file, **kwargs)
|
self.grid_id = self.fs.put(file_obj, **kwargs)
|
||||||
|
|
||||||
def write(self, string):
|
def write(self, string):
|
||||||
if self.grid_id:
|
if self.grid_id:
|
||||||
@ -785,9 +719,9 @@ class GridFSProxy(object):
|
|||||||
self.grid_id = None
|
self.grid_id = None
|
||||||
self.gridout = None
|
self.gridout = None
|
||||||
|
|
||||||
def replace(self, file, **kwargs):
|
def replace(self, file_obj, **kwargs):
|
||||||
self.delete()
|
self.delete()
|
||||||
self.put(file, **kwargs)
|
self.put(file_obj, **kwargs)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.newfile:
|
if self.newfile:
|
||||||
|
@ -336,6 +336,7 @@ class QuerySet(object):
|
|||||||
self._snapshot = False
|
self._snapshot = False
|
||||||
self._timeout = True
|
self._timeout = True
|
||||||
self._class_check = True
|
self._class_check = True
|
||||||
|
self._slave_okay = False
|
||||||
|
|
||||||
# If inheritance is allowed, only return instances and instances of
|
# If inheritance is allowed, only return instances and instances of
|
||||||
# subclasses of the class being used
|
# subclasses of the class being used
|
||||||
@ -352,7 +353,7 @@ class QuerySet(object):
|
|||||||
|
|
||||||
copy_props = ('_initial_query', '_query_obj', '_where_clause',
|
copy_props = ('_initial_query', '_query_obj', '_where_clause',
|
||||||
'_loaded_fields', '_ordering', '_snapshot',
|
'_loaded_fields', '_ordering', '_snapshot',
|
||||||
'_timeout', '_limit', '_skip')
|
'_timeout', '_limit', '_skip', '_slave_okay')
|
||||||
|
|
||||||
for prop in copy_props:
|
for prop in copy_props:
|
||||||
val = getattr(self, prop)
|
val = getattr(self, prop)
|
||||||
@ -376,21 +377,27 @@ class QuerySet(object):
|
|||||||
construct a multi-field index); keys may be prefixed with a **+**
|
construct a multi-field index); keys may be prefixed with a **+**
|
||||||
or a **-** to determine the index ordering
|
or a **-** to determine the index ordering
|
||||||
"""
|
"""
|
||||||
index_list = QuerySet._build_index_spec(self._document, key_or_list)
|
index_spec = QuerySet._build_index_spec(self._document, key_or_list)
|
||||||
self._collection.ensure_index(index_list, drop_dups=drop_dups,
|
self._collection.ensure_index(
|
||||||
background=background)
|
index_spec['fields'],
|
||||||
|
drop_dups=drop_dups,
|
||||||
|
background=background,
|
||||||
|
sparse=index_spec.get('sparse', False),
|
||||||
|
unique=index_spec.get('unique', False))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_index_spec(cls, doc_cls, key_or_list):
|
def _build_index_spec(cls, doc_cls, spec):
|
||||||
"""Build a PyMongo index spec from a MongoEngine index spec.
|
"""Build a PyMongo index spec from a MongoEngine index spec.
|
||||||
"""
|
"""
|
||||||
if isinstance(key_or_list, basestring):
|
if isinstance(spec, basestring):
|
||||||
key_or_list = [key_or_list]
|
spec = {'fields': [spec]}
|
||||||
|
if isinstance(spec, (list, tuple)):
|
||||||
|
spec = {'fields': spec}
|
||||||
|
|
||||||
index_list = []
|
index_list = []
|
||||||
use_types = doc_cls._meta.get('allow_inheritance', True)
|
use_types = doc_cls._meta.get('allow_inheritance', True)
|
||||||
for key in key_or_list:
|
for key in spec['fields']:
|
||||||
# Get direction from + or -
|
# Get direction from + or -
|
||||||
direction = pymongo.ASCENDING
|
direction = pymongo.ASCENDING
|
||||||
if key.startswith("-"):
|
if key.startswith("-"):
|
||||||
@ -411,12 +418,20 @@ class QuerySet(object):
|
|||||||
use_types = False
|
use_types = False
|
||||||
|
|
||||||
# If _types is being used, prepend it to every specified index
|
# If _types is being used, prepend it to every specified index
|
||||||
if doc_cls._meta.get('allow_inheritance') and use_types:
|
if (spec.get('types', True) and doc_cls._meta.get('allow_inheritance')
|
||||||
|
and use_types):
|
||||||
index_list.insert(0, ('_types', 1))
|
index_list.insert(0, ('_types', 1))
|
||||||
|
|
||||||
return index_list
|
spec['fields'] = index_list
|
||||||
|
|
||||||
def __call__(self, q_obj=None, class_check=True, **query):
|
if spec.get('sparse', False) and len(spec['fields']) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
'Sparse indexes can only have one field in them. '
|
||||||
|
'See https://jira.mongodb.org/browse/SERVER-2193')
|
||||||
|
|
||||||
|
return spec
|
||||||
|
|
||||||
|
def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query):
|
||||||
"""Filter the selected documents by calling the
|
"""Filter the selected documents by calling the
|
||||||
:class:`~mongoengine.queryset.QuerySet` with a query.
|
:class:`~mongoengine.queryset.QuerySet` with a query.
|
||||||
|
|
||||||
@ -426,6 +441,8 @@ class QuerySet(object):
|
|||||||
objects, only the last one will be used
|
objects, only the last one will be used
|
||||||
:param class_check: If set to False bypass class name check when
|
:param class_check: If set to False bypass class name check when
|
||||||
querying collection
|
querying collection
|
||||||
|
:param slave_okay: if True, allows this query to be run against a
|
||||||
|
replica secondary.
|
||||||
:param query: Django-style query keyword arguments
|
:param query: Django-style query keyword arguments
|
||||||
"""
|
"""
|
||||||
query = Q(**query)
|
query = Q(**query)
|
||||||
@ -465,9 +482,12 @@ class QuerySet(object):
|
|||||||
|
|
||||||
# Ensure document-defined indexes are created
|
# Ensure document-defined indexes are created
|
||||||
if self._document._meta['indexes']:
|
if self._document._meta['indexes']:
|
||||||
for key_or_list in self._document._meta['indexes']:
|
for spec in self._document._meta['indexes']:
|
||||||
self._collection.ensure_index(key_or_list,
|
opts = index_opts.copy()
|
||||||
background=background, **index_opts)
|
opts['unique'] = spec.get('unique', False)
|
||||||
|
opts['sparse'] = spec.get('sparse', False)
|
||||||
|
self._collection.ensure_index(spec['fields'],
|
||||||
|
background=background, **opts)
|
||||||
|
|
||||||
# If _types is being used (for polymorphism), it needs an index
|
# If _types is being used (for polymorphism), it needs an index
|
||||||
if '_types' in self._query:
|
if '_types' in self._query:
|
||||||
@ -484,16 +504,22 @@ class QuerySet(object):
|
|||||||
return self._collection_obj
|
return self._collection_obj
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _cursor(self):
|
def _cursor_args(self):
|
||||||
if self._cursor_obj is None:
|
|
||||||
cursor_args = {
|
cursor_args = {
|
||||||
'snapshot': self._snapshot,
|
'snapshot': self._snapshot,
|
||||||
'timeout': self._timeout,
|
'timeout': self._timeout,
|
||||||
|
'slave_okay': self._slave_okay
|
||||||
}
|
}
|
||||||
if self._loaded_fields:
|
if self._loaded_fields:
|
||||||
cursor_args['fields'] = self._loaded_fields.as_dict()
|
cursor_args['fields'] = self._loaded_fields.as_dict()
|
||||||
|
return cursor_args
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _cursor(self):
|
||||||
|
if self._cursor_obj is None:
|
||||||
|
|
||||||
self._cursor_obj = self._collection.find(self._query,
|
self._cursor_obj = self._collection.find(self._query,
|
||||||
**cursor_args)
|
**self._cursor_args)
|
||||||
# Apply where clauses to cursor
|
# Apply where clauses to cursor
|
||||||
if self._where_clause:
|
if self._where_clause:
|
||||||
self._cursor_obj.where(self._where_clause)
|
self._cursor_obj.where(self._where_clause)
|
||||||
@ -702,6 +728,46 @@ class QuerySet(object):
|
|||||||
result = None
|
result = None
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def insert(self, doc_or_docs, load_bulk=True):
|
||||||
|
"""bulk insert documents
|
||||||
|
|
||||||
|
:param docs_or_doc: a document or list of documents to be inserted
|
||||||
|
:param load_bulk (optional): If True returns the list of document instances
|
||||||
|
|
||||||
|
By default returns document instances, set ``load_bulk`` to False to
|
||||||
|
return just ``ObjectIds``
|
||||||
|
|
||||||
|
.. versionadded:: 0.5
|
||||||
|
"""
|
||||||
|
from document import Document
|
||||||
|
|
||||||
|
docs = doc_or_docs
|
||||||
|
return_one = False
|
||||||
|
if isinstance(docs, Document) or issubclass(docs.__class__, Document):
|
||||||
|
return_one = True
|
||||||
|
docs = [docs]
|
||||||
|
|
||||||
|
raw = []
|
||||||
|
for doc in docs:
|
||||||
|
if not isinstance(doc, self._document):
|
||||||
|
msg = "Some documents inserted aren't instances of %s" % str(self._document)
|
||||||
|
raise OperationError(msg)
|
||||||
|
if doc.pk:
|
||||||
|
msg = "Some documents have ObjectIds use doc.update() instead"
|
||||||
|
raise OperationError(msg)
|
||||||
|
raw.append(doc.to_mongo())
|
||||||
|
|
||||||
|
ids = self._collection.insert(raw)
|
||||||
|
|
||||||
|
if not load_bulk:
|
||||||
|
return return_one and ids[0] or ids
|
||||||
|
|
||||||
|
documents = self.in_bulk(ids)
|
||||||
|
results = []
|
||||||
|
for obj_id in ids:
|
||||||
|
results.append(documents.get(obj_id))
|
||||||
|
return return_one and results[0] or results
|
||||||
|
|
||||||
def with_id(self, object_id):
|
def with_id(self, object_id):
|
||||||
"""Retrieve the object matching the id provided.
|
"""Retrieve the object matching the id provided.
|
||||||
|
|
||||||
@ -710,7 +776,7 @@ class QuerySet(object):
|
|||||||
id_field = self._document._meta['id_field']
|
id_field = self._document._meta['id_field']
|
||||||
object_id = self._document._fields[id_field].to_mongo(object_id)
|
object_id = self._document._fields[id_field].to_mongo(object_id)
|
||||||
|
|
||||||
result = self._collection.find_one({'_id': object_id})
|
result = self._collection.find_one({'_id': object_id}, **self._cursor_args)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
result = self._document._from_son(result)
|
result = self._document._from_son(result)
|
||||||
return result
|
return result
|
||||||
@ -726,7 +792,8 @@ class QuerySet(object):
|
|||||||
"""
|
"""
|
||||||
doc_map = {}
|
doc_map = {}
|
||||||
|
|
||||||
docs = self._collection.find({'_id': {'$in': object_ids}})
|
docs = self._collection.find({'_id': {'$in': object_ids}},
|
||||||
|
**self._cursor_args)
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
doc_map[doc['_id']] = self._document._from_son(doc)
|
doc_map[doc['_id']] = self._document._from_son(doc)
|
||||||
|
|
||||||
@ -1023,6 +1090,7 @@ class QuerySet(object):
|
|||||||
:param enabled: whether or not snapshot mode is enabled
|
:param enabled: whether or not snapshot mode is enabled
|
||||||
"""
|
"""
|
||||||
self._snapshot = enabled
|
self._snapshot = enabled
|
||||||
|
return self
|
||||||
|
|
||||||
def timeout(self, enabled):
|
def timeout(self, enabled):
|
||||||
"""Enable or disable the default mongod timeout when querying.
|
"""Enable or disable the default mongod timeout when querying.
|
||||||
@ -1030,6 +1098,15 @@ class QuerySet(object):
|
|||||||
:param enabled: whether or not the timeout is used
|
:param enabled: whether or not the timeout is used
|
||||||
"""
|
"""
|
||||||
self._timeout = enabled
|
self._timeout = enabled
|
||||||
|
return self
|
||||||
|
|
||||||
|
def slave_okay(self, enabled):
|
||||||
|
"""Enable or disable the slave_okay when querying.
|
||||||
|
|
||||||
|
:param enabled: whether or not the slave_okay is enabled
|
||||||
|
"""
|
||||||
|
self._slave_okay = enabled
|
||||||
|
return self
|
||||||
|
|
||||||
def delete(self, safe=False):
|
def delete(self, safe=False):
|
||||||
"""Delete the documents matched by the query.
|
"""Delete the documents matched by the query.
|
||||||
|
44
mongoengine/signals.py
Normal file
44
mongoengine/signals.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
__all__ = ['pre_init', 'post_init', 'pre_save', 'post_save',
|
||||||
|
'pre_delete', 'post_delete']
|
||||||
|
|
||||||
|
signals_available = False
|
||||||
|
try:
|
||||||
|
from blinker import Namespace
|
||||||
|
signals_available = True
|
||||||
|
except ImportError:
|
||||||
|
class Namespace(object):
|
||||||
|
def signal(self, name, doc=None):
|
||||||
|
return _FakeSignal(name, doc)
|
||||||
|
|
||||||
|
class _FakeSignal(object):
|
||||||
|
"""If blinker is unavailable, create a fake class with the same
|
||||||
|
interface that allows sending of signals but will fail with an
|
||||||
|
error on anything else. Instead of doing anything on send, it
|
||||||
|
will just ignore the arguments and do nothing instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name, doc=None):
|
||||||
|
self.name = name
|
||||||
|
self.__doc__ = doc
|
||||||
|
|
||||||
|
def _fail(self, *args, **kwargs):
|
||||||
|
raise RuntimeError('signalling support is unavailable '
|
||||||
|
'because the blinker library is '
|
||||||
|
'not installed.')
|
||||||
|
send = lambda *a, **kw: None
|
||||||
|
connect = disconnect = has_receivers_for = receivers_for = \
|
||||||
|
temporarily_connected_to = _fail
|
||||||
|
del _fail
|
||||||
|
|
||||||
|
# the namespace for code signals. If you are not mongoengine code, do
|
||||||
|
# not put signals in here. Create your own namespace instead.
|
||||||
|
_signals = Namespace()
|
||||||
|
|
||||||
|
pre_init = _signals.signal('pre_init')
|
||||||
|
post_init = _signals.signal('post_init')
|
||||||
|
pre_save = _signals.signal('pre_save')
|
||||||
|
post_save = _signals.signal('post_save')
|
||||||
|
pre_delete = _signals.signal('pre_delete')
|
||||||
|
post_delete = _signals.signal('post_delete')
|
59
mongoengine/tests.py
Normal file
59
mongoengine/tests.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
from mongoengine.connection import _get_db
|
||||||
|
|
||||||
|
|
||||||
|
class query_counter(object):
|
||||||
|
""" Query_counter contextmanager to get the number of queries. """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
""" Construct the query_counter. """
|
||||||
|
self.counter = 0
|
||||||
|
self.db = _get_db()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
""" On every with block we need to drop the profile collection. """
|
||||||
|
self.db.set_profiling_level(0)
|
||||||
|
self.db.system.profile.drop()
|
||||||
|
self.db.set_profiling_level(2)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, t, value, traceback):
|
||||||
|
""" Reset the profiling level. """
|
||||||
|
self.db.set_profiling_level(0)
|
||||||
|
|
||||||
|
def __eq__(self, value):
|
||||||
|
""" == Compare querycounter. """
|
||||||
|
return value == self._get_count()
|
||||||
|
|
||||||
|
def __ne__(self, value):
|
||||||
|
""" != Compare querycounter. """
|
||||||
|
return not self.__eq__(value)
|
||||||
|
|
||||||
|
def __lt__(self, value):
|
||||||
|
""" < Compare querycounter. """
|
||||||
|
return self._get_count() < value
|
||||||
|
|
||||||
|
def __le__(self, value):
|
||||||
|
""" <= Compare querycounter. """
|
||||||
|
return self._get_count() <= value
|
||||||
|
|
||||||
|
def __gt__(self, value):
|
||||||
|
""" > Compare querycounter. """
|
||||||
|
return self._get_count() > value
|
||||||
|
|
||||||
|
def __ge__(self, value):
|
||||||
|
""" >= Compare querycounter. """
|
||||||
|
return self._get_count() >= value
|
||||||
|
|
||||||
|
def __int__(self):
|
||||||
|
""" int representation. """
|
||||||
|
return self._get_count()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
""" repr query_counter as the number of queries. """
|
||||||
|
return u"%s" % self._get_count()
|
||||||
|
|
||||||
|
def _get_count(self):
|
||||||
|
""" Get the number of queries. """
|
||||||
|
count = self.db.system.profile.find().count() - self.counter
|
||||||
|
self.counter += 1
|
||||||
|
return count
|
2
setup.py
2
setup.py
@ -45,6 +45,6 @@ setup(name='mongoengine',
|
|||||||
long_description=LONG_DESCRIPTION,
|
long_description=LONG_DESCRIPTION,
|
||||||
platforms=['any'],
|
platforms=['any'],
|
||||||
classifiers=CLASSIFIERS,
|
classifiers=CLASSIFIERS,
|
||||||
install_requires=['pymongo'],
|
install_requires=['pymongo', 'blinker'],
|
||||||
test_suite='tests',
|
test_suite='tests',
|
||||||
)
|
)
|
||||||
|
288
tests/dereference.py
Normal file
288
tests/dereference.py
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from mongoengine import *
|
||||||
|
from mongoengine.connection import _get_db
|
||||||
|
from mongoengine.tests import query_counter
|
||||||
|
|
||||||
|
|
||||||
|
class FieldTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
connect(db='mongoenginetest')
|
||||||
|
self.db = _get_db()
|
||||||
|
|
||||||
|
def test_list_item_dereference(self):
|
||||||
|
"""Ensure that DBRef items in ListFields are dereferenced.
|
||||||
|
"""
|
||||||
|
class User(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
members = ListField(ReferenceField(User))
|
||||||
|
|
||||||
|
User.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
for i in xrange(1, 51):
|
||||||
|
user = User(name='user %s' % i)
|
||||||
|
user.save()
|
||||||
|
|
||||||
|
group = Group(members=User.objects)
|
||||||
|
group.save()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
|
group_obj = Group.objects.first()
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 2)
|
||||||
|
|
||||||
|
User.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
def test_recursive_reference(self):
|
||||||
|
"""Ensure that ReferenceFields can reference their own documents.
|
||||||
|
"""
|
||||||
|
class Employee(Document):
|
||||||
|
name = StringField()
|
||||||
|
boss = ReferenceField('self')
|
||||||
|
friends = ListField(ReferenceField('self'))
|
||||||
|
|
||||||
|
bill = Employee(name='Bill Lumbergh')
|
||||||
|
bill.save()
|
||||||
|
|
||||||
|
michael = Employee(name='Michael Bolton')
|
||||||
|
michael.save()
|
||||||
|
|
||||||
|
samir = Employee(name='Samir Nagheenanajar')
|
||||||
|
samir.save()
|
||||||
|
|
||||||
|
friends = [michael, samir]
|
||||||
|
peter = Employee(name='Peter Gibbons', boss=bill, friends=friends)
|
||||||
|
peter.save()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
|
peter = Employee.objects.with_id(peter.id)
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
peter.boss
|
||||||
|
self.assertEqual(q, 2)
|
||||||
|
|
||||||
|
peter.friends
|
||||||
|
self.assertEqual(q, 3)
|
||||||
|
|
||||||
|
def test_generic_reference(self):
|
||||||
|
|
||||||
|
class UserA(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class UserB(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class UserC(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
members = ListField(GenericReferenceField())
|
||||||
|
|
||||||
|
UserA.drop_collection()
|
||||||
|
UserB.drop_collection()
|
||||||
|
UserC.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
members = []
|
||||||
|
for i in xrange(1, 51):
|
||||||
|
a = UserA(name='User A %s' % i)
|
||||||
|
a.save()
|
||||||
|
|
||||||
|
b = UserB(name='User B %s' % i)
|
||||||
|
b.save()
|
||||||
|
|
||||||
|
c = UserC(name='User C %s' % i)
|
||||||
|
c.save()
|
||||||
|
|
||||||
|
members += [a, b, c]
|
||||||
|
|
||||||
|
group = Group(members=members)
|
||||||
|
group.save()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
|
group_obj = Group.objects.first()
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 4)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 4)
|
||||||
|
|
||||||
|
UserA.drop_collection()
|
||||||
|
UserB.drop_collection()
|
||||||
|
UserC.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
def test_map_field_reference(self):
|
||||||
|
|
||||||
|
class User(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
members = MapField(ReferenceField(User))
|
||||||
|
|
||||||
|
User.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
members = []
|
||||||
|
for i in xrange(1, 51):
|
||||||
|
user = User(name='user %s' % i)
|
||||||
|
user.save()
|
||||||
|
members.append(user)
|
||||||
|
|
||||||
|
group = Group(members=dict([(str(u.id), u) for u in members]))
|
||||||
|
group.save()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
|
group_obj = Group.objects.first()
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 2)
|
||||||
|
|
||||||
|
User.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
def ztest_generic_reference_dict_field(self):
|
||||||
|
|
||||||
|
class UserA(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class UserB(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class UserC(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
members = DictField()
|
||||||
|
|
||||||
|
UserA.drop_collection()
|
||||||
|
UserB.drop_collection()
|
||||||
|
UserC.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
members = []
|
||||||
|
for i in xrange(1, 51):
|
||||||
|
a = UserA(name='User A %s' % i)
|
||||||
|
a.save()
|
||||||
|
|
||||||
|
b = UserB(name='User B %s' % i)
|
||||||
|
b.save()
|
||||||
|
|
||||||
|
c = UserC(name='User C %s' % i)
|
||||||
|
c.save()
|
||||||
|
|
||||||
|
members += [a, b, c]
|
||||||
|
|
||||||
|
group = Group(members=dict([(str(u.id), u) for u in members]))
|
||||||
|
group.save()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
|
group_obj = Group.objects.first()
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 4)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 4)
|
||||||
|
|
||||||
|
group.members = {}
|
||||||
|
group.save()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
|
group_obj = Group.objects.first()
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
UserA.drop_collection()
|
||||||
|
UserB.drop_collection()
|
||||||
|
UserC.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
def test_generic_reference_map_field(self):
|
||||||
|
|
||||||
|
class UserA(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class UserB(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class UserC(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
members = MapField(GenericReferenceField())
|
||||||
|
|
||||||
|
UserA.drop_collection()
|
||||||
|
UserB.drop_collection()
|
||||||
|
UserC.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
members = []
|
||||||
|
for i in xrange(1, 51):
|
||||||
|
a = UserA(name='User A %s' % i)
|
||||||
|
a.save()
|
||||||
|
|
||||||
|
b = UserB(name='User B %s' % i)
|
||||||
|
b.save()
|
||||||
|
|
||||||
|
c = UserC(name='User C %s' % i)
|
||||||
|
c.save()
|
||||||
|
|
||||||
|
members += [a, b, c]
|
||||||
|
|
||||||
|
group = Group(members=dict([(str(u.id), u) for u in members]))
|
||||||
|
group.save()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
|
group_obj = Group.objects.first()
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 4)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 4)
|
||||||
|
|
||||||
|
group.members = {}
|
||||||
|
group.save()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
|
group_obj = Group.objects.first()
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
[m for m in group_obj.members]
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
UserA.drop_collection()
|
||||||
|
UserB.drop_collection()
|
||||||
|
UserC.drop_collection()
|
||||||
|
Group.drop_collection()
|
@ -377,6 +377,40 @@ class DocumentTest(unittest.TestCase):
|
|||||||
|
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_dictionary_indexes(self):
|
||||||
|
"""Ensure that indexes are used when meta[indexes] contains dictionaries
|
||||||
|
instead of lists.
|
||||||
|
"""
|
||||||
|
class BlogPost(Document):
|
||||||
|
date = DateTimeField(db_field='addDate', default=datetime.now)
|
||||||
|
category = StringField()
|
||||||
|
tags = ListField(StringField())
|
||||||
|
meta = {
|
||||||
|
'indexes': [
|
||||||
|
{ 'fields': ['-date'], 'unique': True,
|
||||||
|
'sparse': True, 'types': False },
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
info = BlogPost.objects._collection.index_information()
|
||||||
|
# _id, '-date'
|
||||||
|
self.assertEqual(len(info), 3)
|
||||||
|
|
||||||
|
# Indexes are lazy so use list() to perform query
|
||||||
|
list(BlogPost.objects)
|
||||||
|
info = BlogPost.objects._collection.index_information()
|
||||||
|
info = [(value['key'],
|
||||||
|
value.get('unique', False),
|
||||||
|
value.get('sparse', False))
|
||||||
|
for key, value in info.iteritems()]
|
||||||
|
self.assertTrue(([('addDate', -1)], True, True) in info)
|
||||||
|
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
|
|
||||||
def test_unique(self):
|
def test_unique(self):
|
||||||
"""Ensure that uniqueness constraints are applied to fields.
|
"""Ensure that uniqueness constraints are applied to fields.
|
||||||
"""
|
"""
|
||||||
|
@ -187,6 +187,66 @@ class FieldTest(unittest.TestCase):
|
|||||||
log.time = '1pm'
|
log.time = '1pm'
|
||||||
self.assertRaises(ValidationError, log.validate)
|
self.assertRaises(ValidationError, log.validate)
|
||||||
|
|
||||||
|
def test_datetime(self):
|
||||||
|
"""Tests showing pymongo datetime fields handling of microseconds.
|
||||||
|
Microseconds are rounded to the nearest millisecond and pre UTC
|
||||||
|
handling is wonky.
|
||||||
|
|
||||||
|
See: http://api.mongodb.org/python/current/api/bson/son.html#dt
|
||||||
|
"""
|
||||||
|
class LogEntry(Document):
|
||||||
|
date = DateTimeField()
|
||||||
|
|
||||||
|
LogEntry.drop_collection()
|
||||||
|
|
||||||
|
# Post UTC - microseconds are rounded (down) nearest millisecond and dropped
|
||||||
|
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
|
||||||
|
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01)
|
||||||
|
log = LogEntry()
|
||||||
|
log.date = d1
|
||||||
|
log.save()
|
||||||
|
log.reload()
|
||||||
|
self.assertNotEquals(log.date, d1)
|
||||||
|
self.assertEquals(log.date, d2)
|
||||||
|
|
||||||
|
# Post UTC - microseconds are rounded (down) nearest millisecond
|
||||||
|
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
|
||||||
|
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9000)
|
||||||
|
log.date = d1
|
||||||
|
log.save()
|
||||||
|
log.reload()
|
||||||
|
self.assertNotEquals(log.date, d1)
|
||||||
|
self.assertEquals(log.date, d2)
|
||||||
|
|
||||||
|
# Pre UTC dates microseconds below 1000 are dropped
|
||||||
|
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
|
||||||
|
d2 = datetime.datetime(1969, 12, 31, 23, 59, 59)
|
||||||
|
log.date = d1
|
||||||
|
log.save()
|
||||||
|
log.reload()
|
||||||
|
self.assertNotEquals(log.date, d1)
|
||||||
|
self.assertEquals(log.date, d2)
|
||||||
|
|
||||||
|
# Pre UTC microseconds above 1000 is wonky.
|
||||||
|
# log.date has an invalid microsecond value so I can't construct
|
||||||
|
# a date to compare.
|
||||||
|
#
|
||||||
|
# However, the timedelta is predicable with pre UTC timestamps
|
||||||
|
# It always adds 16 seconds and [777216-776217] microseconds
|
||||||
|
for i in xrange(1001, 3113, 33):
|
||||||
|
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
|
||||||
|
log.date = d1
|
||||||
|
log.save()
|
||||||
|
log.reload()
|
||||||
|
self.assertNotEquals(log.date, d1)
|
||||||
|
|
||||||
|
delta = log.date - d1
|
||||||
|
self.assertEquals(delta.seconds, 16)
|
||||||
|
microseconds = 777216 - (i % 1000)
|
||||||
|
self.assertEquals(delta.microseconds, microseconds)
|
||||||
|
|
||||||
|
LogEntry.drop_collection()
|
||||||
|
|
||||||
def test_list_validation(self):
|
def test_list_validation(self):
|
||||||
"""Ensure that a list field only accepts lists with valid elements.
|
"""Ensure that a list field only accepts lists with valid elements.
|
||||||
"""
|
"""
|
||||||
|
@ -9,6 +9,7 @@ from mongoengine.queryset import (QuerySet, QuerySetManager,
|
|||||||
MultipleObjectsReturned, DoesNotExist,
|
MultipleObjectsReturned, DoesNotExist,
|
||||||
QueryFieldList)
|
QueryFieldList)
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
|
from mongoengine.tests import query_counter
|
||||||
|
|
||||||
|
|
||||||
class QuerySetTest(unittest.TestCase):
|
class QuerySetTest(unittest.TestCase):
|
||||||
@ -331,6 +332,125 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
person = self.Person.objects.get(age=50)
|
person = self.Person.objects.get(age=50)
|
||||||
self.assertEqual(person.name, "User C")
|
self.assertEqual(person.name, "User C")
|
||||||
|
|
||||||
|
def test_bulk_insert(self):
|
||||||
|
"""Ensure that query by array position works.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Comment(EmbeddedDocument):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Post(EmbeddedDocument):
|
||||||
|
comments = ListField(EmbeddedDocumentField(Comment))
|
||||||
|
|
||||||
|
class Blog(Document):
|
||||||
|
title = StringField()
|
||||||
|
tags = ListField(StringField())
|
||||||
|
posts = ListField(EmbeddedDocumentField(Post))
|
||||||
|
|
||||||
|
Blog.drop_collection()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(q, 0)
|
||||||
|
|
||||||
|
comment1 = Comment(name='testa')
|
||||||
|
comment2 = Comment(name='testb')
|
||||||
|
post1 = Post(comments=[comment1, comment2])
|
||||||
|
post2 = Post(comments=[comment2, comment2])
|
||||||
|
|
||||||
|
blogs = []
|
||||||
|
for i in xrange(1, 100):
|
||||||
|
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
|
||||||
|
|
||||||
|
Blog.objects.insert(blogs, load_bulk=False)
|
||||||
|
self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert
|
||||||
|
|
||||||
|
Blog.objects.insert(blogs)
|
||||||
|
self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk
|
||||||
|
|
||||||
|
Blog.drop_collection()
|
||||||
|
|
||||||
|
comment1 = Comment(name='testa')
|
||||||
|
comment2 = Comment(name='testb')
|
||||||
|
post1 = Post(comments=[comment1, comment2])
|
||||||
|
post2 = Post(comments=[comment2, comment2])
|
||||||
|
blog1 = Blog(title="code", posts=[post1, post2])
|
||||||
|
blog2 = Blog(title="mongodb", posts=[post2, post1])
|
||||||
|
blog1, blog2 = Blog.objects.insert([blog1, blog2])
|
||||||
|
self.assertEqual(blog1.title, "code")
|
||||||
|
self.assertEqual(blog2.title, "mongodb")
|
||||||
|
|
||||||
|
self.assertEqual(Blog.objects.count(), 2)
|
||||||
|
|
||||||
|
# test handles people trying to upsert
|
||||||
|
def throw_operation_error():
|
||||||
|
blogs = Blog.objects
|
||||||
|
Blog.objects.insert(blogs)
|
||||||
|
|
||||||
|
self.assertRaises(OperationError, throw_operation_error)
|
||||||
|
|
||||||
|
# test handles other classes being inserted
|
||||||
|
def throw_operation_error_wrong_doc():
|
||||||
|
class Author(Document):
|
||||||
|
pass
|
||||||
|
Blog.objects.insert(Author())
|
||||||
|
|
||||||
|
self.assertRaises(OperationError, throw_operation_error_wrong_doc)
|
||||||
|
|
||||||
|
def throw_operation_error_not_a_document():
|
||||||
|
Blog.objects.insert("HELLO WORLD")
|
||||||
|
|
||||||
|
self.assertRaises(OperationError, throw_operation_error_not_a_document)
|
||||||
|
|
||||||
|
Blog.drop_collection()
|
||||||
|
|
||||||
|
blog1 = Blog(title="code", posts=[post1, post2])
|
||||||
|
blog1 = Blog.objects.insert(blog1)
|
||||||
|
self.assertEqual(blog1.title, "code")
|
||||||
|
self.assertEqual(Blog.objects.count(), 1)
|
||||||
|
|
||||||
|
Blog.drop_collection()
|
||||||
|
blog1 = Blog(title="code", posts=[post1, post2])
|
||||||
|
obj_id = Blog.objects.insert(blog1, load_bulk=False)
|
||||||
|
self.assertEquals(obj_id.__class__.__name__, 'ObjectId')
|
||||||
|
|
||||||
|
def test_slave_okay(self):
|
||||||
|
"""Ensures that a query can take slave_okay syntax
|
||||||
|
"""
|
||||||
|
person1 = self.Person(name="User A", age=20)
|
||||||
|
person1.save()
|
||||||
|
person2 = self.Person(name="User B", age=30)
|
||||||
|
person2.save()
|
||||||
|
|
||||||
|
# Retrieve the first person from the database
|
||||||
|
person = self.Person.objects.slave_okay(True).first()
|
||||||
|
self.assertTrue(isinstance(person, self.Person))
|
||||||
|
self.assertEqual(person.name, "User A")
|
||||||
|
self.assertEqual(person.age, 20)
|
||||||
|
|
||||||
|
def test_cursor_args(self):
|
||||||
|
"""Ensures the cursor args can be set as expected
|
||||||
|
"""
|
||||||
|
p = self.Person.objects
|
||||||
|
# Check default
|
||||||
|
self.assertEqual(p._cursor_args,
|
||||||
|
{'snapshot': False, 'slave_okay': False, 'timeout': True})
|
||||||
|
|
||||||
|
p.snapshot(False).slave_okay(False).timeout(False)
|
||||||
|
self.assertEqual(p._cursor_args,
|
||||||
|
{'snapshot': False, 'slave_okay': False, 'timeout': False})
|
||||||
|
|
||||||
|
p.snapshot(True).slave_okay(False).timeout(False)
|
||||||
|
self.assertEqual(p._cursor_args,
|
||||||
|
{'snapshot': True, 'slave_okay': False, 'timeout': False})
|
||||||
|
|
||||||
|
p.snapshot(True).slave_okay(True).timeout(False)
|
||||||
|
self.assertEqual(p._cursor_args,
|
||||||
|
{'snapshot': True, 'slave_okay': True, 'timeout': False})
|
||||||
|
|
||||||
|
p.snapshot(True).slave_okay(True).timeout(True)
|
||||||
|
self.assertEqual(p._cursor_args,
|
||||||
|
{'snapshot': True, 'slave_okay': True, 'timeout': True})
|
||||||
|
|
||||||
def test_repeated_iteration(self):
|
def test_repeated_iteration(self):
|
||||||
"""Ensure that QuerySet rewinds itself one iteration finishes.
|
"""Ensure that QuerySet rewinds itself one iteration finishes.
|
||||||
"""
|
"""
|
||||||
@ -2099,8 +2219,27 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
Number.drop_collection()
|
Number.drop_collection()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_index(self):
|
||||||
|
"""Ensure that manual creation of indexes works.
|
||||||
|
"""
|
||||||
|
class Comment(Document):
|
||||||
|
message = StringField()
|
||||||
|
|
||||||
|
Comment.objects.ensure_index('message')
|
||||||
|
|
||||||
|
info = Comment.objects._collection.index_information()
|
||||||
|
info = [(value['key'],
|
||||||
|
value.get('unique', False),
|
||||||
|
value.get('sparse', False))
|
||||||
|
for key, value in info.iteritems()]
|
||||||
|
self.assertTrue(([('_types', 1), ('message', 1)], False, False) in info)
|
||||||
|
|
||||||
|
|
||||||
class QTest(unittest.TestCase):
|
class QTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
connect(db='mongoenginetest')
|
||||||
|
|
||||||
def test_empty_q(self):
|
def test_empty_q(self):
|
||||||
"""Ensure that empty Q objects won't hurt.
|
"""Ensure that empty Q objects won't hurt.
|
||||||
"""
|
"""
|
||||||
|
130
tests/signals.py
Normal file
130
tests/signals.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from mongoengine import *
|
||||||
|
from mongoengine import signals
|
||||||
|
|
||||||
|
signal_output = []
|
||||||
|
|
||||||
|
|
||||||
|
class SignalTests(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Testing signals before/after saving and deleting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_signal_output(self, fn, *args, **kwargs):
|
||||||
|
# Flush any existing signal output
|
||||||
|
global signal_output
|
||||||
|
signal_output = []
|
||||||
|
fn(*args, **kwargs)
|
||||||
|
return signal_output
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
connect(db='mongoenginetest')
|
||||||
|
class Author(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
def __unicode__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_init(cls, instance, **kwargs):
|
||||||
|
signal_output.append('pre_init signal, %s' % cls.__name__)
|
||||||
|
signal_output.append(str(kwargs['values']))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def post_init(cls, instance, **kwargs):
|
||||||
|
signal_output.append('post_init signal, %s' % instance)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_save(cls, instance, **kwargs):
|
||||||
|
signal_output.append('pre_save signal, %s' % instance)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def post_save(cls, instance, **kwargs):
|
||||||
|
signal_output.append('post_save signal, %s' % instance)
|
||||||
|
if 'created' in kwargs:
|
||||||
|
if kwargs['created']:
|
||||||
|
signal_output.append('Is created')
|
||||||
|
else:
|
||||||
|
signal_output.append('Is updated')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_delete(cls, instance, **kwargs):
|
||||||
|
signal_output.append('pre_delete signal, %s' % instance)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def post_delete(cls, instance, **kwargs):
|
||||||
|
signal_output.append('post_delete signal, %s' % instance)
|
||||||
|
|
||||||
|
self.Author = Author
|
||||||
|
|
||||||
|
# Save up the number of connected signals so that we can check at the end
|
||||||
|
# that all the signals we register get properly unregistered
|
||||||
|
self.pre_signals = (
|
||||||
|
len(signals.pre_init.receivers),
|
||||||
|
len(signals.post_init.receivers),
|
||||||
|
len(signals.pre_save.receivers),
|
||||||
|
len(signals.post_save.receivers),
|
||||||
|
len(signals.pre_delete.receivers),
|
||||||
|
len(signals.post_delete.receivers)
|
||||||
|
)
|
||||||
|
|
||||||
|
signals.pre_init.connect(Author.pre_init)
|
||||||
|
signals.post_init.connect(Author.post_init)
|
||||||
|
signals.pre_save.connect(Author.pre_save)
|
||||||
|
signals.post_save.connect(Author.post_save)
|
||||||
|
signals.pre_delete.connect(Author.pre_delete)
|
||||||
|
signals.post_delete.connect(Author.post_delete)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
signals.pre_init.disconnect(self.Author.pre_init)
|
||||||
|
signals.post_init.disconnect(self.Author.post_init)
|
||||||
|
signals.post_delete.disconnect(self.Author.post_delete)
|
||||||
|
signals.pre_delete.disconnect(self.Author.pre_delete)
|
||||||
|
signals.post_save.disconnect(self.Author.post_save)
|
||||||
|
signals.pre_save.disconnect(self.Author.pre_save)
|
||||||
|
|
||||||
|
# Check that all our signals got disconnected properly.
|
||||||
|
post_signals = (
|
||||||
|
len(signals.pre_init.receivers),
|
||||||
|
len(signals.post_init.receivers),
|
||||||
|
len(signals.pre_save.receivers),
|
||||||
|
len(signals.post_save.receivers),
|
||||||
|
len(signals.pre_delete.receivers),
|
||||||
|
len(signals.post_delete.receivers)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(self.pre_signals, post_signals)
|
||||||
|
|
||||||
|
def test_model_signals(self):
|
||||||
|
""" Model saves should throw some signals. """
|
||||||
|
|
||||||
|
def create_author():
|
||||||
|
a1 = self.Author(name='Bill Shakespeare')
|
||||||
|
|
||||||
|
self.assertEqual(self.get_signal_output(create_author), [
|
||||||
|
"pre_init signal, Author",
|
||||||
|
"{'name': 'Bill Shakespeare'}",
|
||||||
|
"post_init signal, Bill Shakespeare",
|
||||||
|
])
|
||||||
|
|
||||||
|
a1 = self.Author(name='Bill Shakespeare')
|
||||||
|
self.assertEqual(self.get_signal_output(a1.save), [
|
||||||
|
"pre_save signal, Bill Shakespeare",
|
||||||
|
"post_save signal, Bill Shakespeare",
|
||||||
|
"Is created"
|
||||||
|
])
|
||||||
|
|
||||||
|
a1.reload()
|
||||||
|
a1.name='William Shakespeare'
|
||||||
|
self.assertEqual(self.get_signal_output(a1.save), [
|
||||||
|
"pre_save signal, William Shakespeare",
|
||||||
|
"post_save signal, William Shakespeare",
|
||||||
|
"Is updated"
|
||||||
|
])
|
||||||
|
|
||||||
|
self.assertEqual(self.get_signal_output(a1.delete), [
|
||||||
|
'pre_delete signal, William Shakespeare',
|
||||||
|
'post_delete signal, William Shakespeare',
|
||||||
|
])
|
Loading…
x
Reference in New Issue
Block a user