Merge pull request #1 from MongoEngine/master

pull new changes from original
This commit is contained in:
iici-gli 2016-09-04 23:43:04 -04:00 committed by GitHub
commit e5b6a12977
37 changed files with 1228 additions and 488 deletions

View File

@ -5,6 +5,7 @@ python:
- '3.2'
- '3.3'
- '3.4'
- '3.5'
- pypy
- pypy3
env:
@ -24,7 +25,8 @@ install:
- sudo apt-get install python-dev python3-dev libopenjpeg-dev zlib1g-dev libjpeg-turbo8-dev
libtiff4-dev libjpeg8-dev libfreetype6-dev liblcms2-dev libwebp-dev tcl8.5-dev tk8.5-dev
python-tk
- travis_retry pip install tox>=1.9 coveralls
# virtualenv>=14.0.0 has dropped Python 3.2 support
- travis_retry pip install "virtualenv<14.0.0" "tox>=1.9" coveralls
- travis_retry tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- -e test
script:
- tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage

11
AUTHORS
View File

@ -230,3 +230,14 @@ that much better:
* Amit Lichtenberg (https://github.com/amitlicht)
* Gang Li (https://github.com/iici-gli)
* Lars Butler (https://github.com/larsbutler)
* George Macon (https://github.com/gmacon)
* Ashley Whetter (https://github.com/AWhetter)
* Paul-Armand Verhaegen (https://github.com/paularmand)
* Steven Rossiter (https://github.com/BeardedSteve)
* Luo Peng (https://github.com/RussellLuo)
* Bryan Bennett (https://github.com/bbenne10)
* Gilb's Gilb's (https://github.com/gilbsgilbs)
* Joshua Nedrud (https://github.com/Neurostack)
* Shu Shen (https://github.com/shushen)
* xiaost7 (https://github.com/xiaost7)
* Victor Varvaryuk

View File

@ -19,10 +19,10 @@ MongoEngine
About
=====
MongoEngine is a Python Object-Document Mapper for working with MongoDB.
Documentation available at http://mongoengine-odm.rtfd.org - there is currently
a `tutorial <http://readthedocs.org/docs/mongoengine-odm/en/latest/tutorial.html>`_, a `user guide
<https://mongoengine-odm.readthedocs.org/en/latest/guide/index.html>`_ and an `API reference
<http://readthedocs.org/docs/mongoengine-odm/en/latest/apireference.html>`_.
Documentation available at https://mongoengine-odm.readthedocs.io - there is currently
a `tutorial <https://mongoengine-odm.readthedocs.io/tutorial.html>`_, a `user guide
<https://mongoengine-odm.readthedocs.io/guide/index.html>`_ and an `API reference
<https://mongoengine-odm.readthedocs.io/apireference.html>`_.
Installation
============
@ -48,7 +48,9 @@ Optional Dependencies
Examples
========
Some simple examples of what MongoEngine code looks like::
Some simple examples of what MongoEngine code looks like:
.. code :: python
class BlogPost(Document):
title = StringField(required=True, max_length=200)
@ -97,7 +99,7 @@ Some simple examples of what MongoEngine code looks like::
Tests
=====
To run the test suite, ensure you are running a local instance of MongoDB on
the standard port, and run: ``python setup.py nosetests``.
the standard port and have installed ``nose`` and ``rednose``, and run: ``python setup.py nosetests``.
To run the test suite on every supported Python version and every supported PyMongo version,
you can use ``tox``.

View File

@ -2,8 +2,50 @@
Changelog
=========
Changes in 0.10.1 - DEV
Changes in 0.10.7 - DEV
=======================
- Fixed the bug where dynamic doc has index inside a dict field #1278
- Fixed not being able to specify `use_db_field=False` on `ListField(EmbeddedDocumentField)` instances
- Fixed cascade delete mixing among collections #1224
- Add `signal_kwargs` argument to `Document.save`, `Document.delete` and `BaseQuerySet.insert` to be passed to signals calls #1206
- Raise `OperationError` when trying to do a `drop_collection` on document with no collection set.
- count on ListField of EmbeddedDocumentField fails. #1187
- Fixed long fields stored as int32 in Python 3. #1253
- MapField now handles unicodes keys correctly. #1267
- ListField now handles negative indicies correctly. #1270
- Fixed AttributeError when initializing EmbeddedDocument with positional args. #681
- Fixed no_cursor_timeout error with pymongo 3.0+ #1304
- Replaced map-reduce based QuerySet.sum/average with aggregation-based implementations #1336
- Fixed support for `__` to escape field names that match operators names in `update` #1351
Changes in 0.10.6
=================
- Add support for mocking MongoEngine based on mongomock. #1151
- Fixed not being able to run tests on Windows. #1153
- Allow creation of sparse compound indexes. #1114
- count on ListField of EmbeddedDocumentField fails. #1187
Changes in 0.10.5
=================
- Fix for reloading of strict with special fields. #1156
Changes in 0.10.4
=================
- SaveConditionError is now importable from the top level package. #1165
- upsert_one method added. #1157
Changes in 0.10.3
=================
- Fix `read_preference` (it had chaining issues with PyMongo 2.x and it didn't work at all with PyMongo 3.x) #1042
Changes in 0.10.2
=================
- Allow shard key to point to a field in an embedded document. #551
- Allow arbirary metadata in fields. #1129
- ReferenceFields now support abstract document types. #837
Changes in 0.10.1
=================
- Fix infinite recursion with CASCADE delete rules under specific conditions. #1046
- Fix CachedReferenceField bug when loading cached docs as DBRef but failing to save them. #1047
- Fix ignored chained options #842
@ -13,6 +55,8 @@ Changes in 0.10.1 - DEV
- Fix ListField minus index assignment does not work. #1119
- Remove code that marks field as changed when the field has default but not existed in database #1126
- Remove test dependencies (nose and rednose) from install dependencies list. #1079
- Recursively build query when using elemMatch operator. #1130
- Fix instance back references for lists of embedded documents. #1131
Changes in 0.10.0
=================

View File

@ -17,6 +17,10 @@ class Post(Document):
tags = ListField(StringField(max_length=30))
comments = ListField(EmbeddedDocumentField(Comment))
# bugfix
meta = {'allow_inheritance': True}
class TextPost(Post):
content = StringField()
@ -45,7 +49,8 @@ print 'ALL POSTS'
print
for post in Post.objects:
print post.title
print '=' * post.title.count()
#print '=' * post.title.count()
print "=" * 20
if isinstance(post, TextPost):
print post.content

View File

@ -29,7 +29,7 @@ documents are serialized based on their field order.
Dynamic document schemas
========================
One of the benefits of MongoDb is dynamic schemas for a collection, whilst data
One of the benefits of MongoDB is dynamic schemas for a collection, whilst data
should be planned and organised (after all explicit is better than implicit!)
there are scenarios where having dynamic / expando style documents is desirable.
@ -75,6 +75,7 @@ are as follows:
* :class:`~mongoengine.fields.DynamicField`
* :class:`~mongoengine.fields.EmailField`
* :class:`~mongoengine.fields.EmbeddedDocumentField`
* :class:`~mongoengine.fields.EmbeddedDocumentListField`
* :class:`~mongoengine.fields.FileField`
* :class:`~mongoengine.fields.FloatField`
* :class:`~mongoengine.fields.GenericEmbeddedDocumentField`
@ -172,11 +173,11 @@ arguments can be set on all fields:
class Shirt(Document):
size = StringField(max_length=3, choices=SIZE)
:attr:`help_text` (Default: None)
Optional help text to output with the field -- used by form libraries
:attr:`verbose_name` (Default: None)
Optional human-readable name for the field -- used by form libraries
:attr:`**kwargs` (Optional)
You can supply additional metadata as arbitrary additional keyword
arguments. You can not override existing attributes, however. Common
choices include `help_text` and `verbose_name`, commonly used by form and
widget libraries.
List fields

View File

@ -13,3 +13,4 @@ User Guide
gridfs
signals
text-indexes
mongomock

21
docs/guide/mongomock.rst Normal file
View File

@ -0,0 +1,21 @@
==============================
Use mongomock for testing
==============================
`mongomock <https://github.com/vmalloc/mongomock/>`_ is a package to do just
what the name implies, mocking a mongo database.
To use with mongoengine, simply specify mongomock when connecting with
mongoengine:
.. code-block:: python
connect('mongoenginetest', host='mongomock://localhost')
conn = get_connection()
or with an alias:
.. code-block:: python
connect('mongoenginetest', host='mongomock://localhost', alias='testdb')
conn = get_connection('testdb')

View File

@ -237,7 +237,7 @@ is preferred for achieving this::
# All except for the first 5 people
users = User.objects[5:]
# 5 users, starting from the 10th user found
# 5 users, starting from the 11th user found
users = User.objects[10:15]
You may also index the query to retrieve a single result. If an item at that

View File

@ -14,7 +14,7 @@ import errors
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
list(queryset.__all__) + signals.__all__ + list(errors.__all__))
VERSION = (0, 10, 0)
VERSION = (0, 10, 6)
def get_version():

View File

@ -199,7 +199,8 @@ class BaseList(list):
def _mark_as_changed(self, key=None):
if hasattr(self._instance, '_mark_as_changed'):
if key:
self._instance._mark_as_changed('%s.%s' % (self._name, key))
self._instance._mark_as_changed('%s.%s' % (self._name,
key % len(self)))
else:
self._instance._mark_as_changed(self._name)
@ -210,7 +211,7 @@ class EmbeddedDocumentList(BaseList):
def __match_all(cls, i, kwargs):
items = kwargs.items()
return all([
getattr(i, k) == v or str(getattr(i, k)) == v for k, v in items
getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items
])
@classmethod

View File

@ -51,7 +51,7 @@ class BaseDocument(object):
# We only want named arguments.
field = iter(self._fields_ordered)
# If its an automatic id field then skip to the first defined field
if self._auto_id_field:
if getattr(self, '_auto_id_field', False):
next(field)
for value in args:
name = next(field)
@ -325,20 +325,17 @@ class BaseDocument(object):
if value is not None:
if isinstance(field, EmbeddedDocumentField):
if fields:
key = '%s.' % field_name
embedded_fields = [
i.replace(key, '') for i in fields
if i.startswith(key)]
if fields:
key = '%s.' % field_name
embedded_fields = [
i.replace(key, '') for i in fields
if i.startswith(key)]
else:
embedded_fields = []
value = field.to_mongo(value, use_db_field=use_db_field,
fields=embedded_fields)
else:
value = field.to_mongo(value)
embedded_fields = []
value = field.to_mongo(value, use_db_field=use_db_field,
fields=embedded_fields)
# Handle self generating fields
if value is None and field._auto_gen:
@ -835,10 +832,6 @@ class BaseDocument(object):
if index_list:
spec['fields'] = index_list
if spec.get('sparse', False) and len(spec['fields']) > 1:
raise ValueError(
'Sparse indexes can only have one field in them. '
'See https://jira.mongodb.org/browse/SERVER-2193')
return spec
@ -974,7 +967,7 @@ class BaseDocument(object):
if hasattr(getattr(field, 'field', None), 'lookup_member'):
new_field = field.field.lookup_member(field_name)
elif cls._dynamic and (isinstance(field, DynamicField) or
getattr(getattr(field, 'document_type'), '_dynamic')):
getattr(getattr(field, 'document_type', None), '_dynamic', None)):
new_field = DynamicField(db_field=field_name)
else:
# Look up subfield on the previous field or raise

View File

@ -41,8 +41,8 @@ class BaseField(object):
def __init__(self, db_field=None, name=None, required=False, default=None,
unique=False, unique_with=None, primary_key=False,
validation=None, choices=None, verbose_name=None,
help_text=None, null=False, sparse=False, custom_data=None):
validation=None, choices=None, null=False, sparse=False,
**kwargs):
"""
:param db_field: The database field to store this field in
(defaults to the name of the field)
@ -60,16 +60,15 @@ class BaseField(object):
field. Generally this is deprecated in favour of the
`FIELD.validate` method
:param choices: (optional) The valid choices
:param verbose_name: (optional) The verbose name for the field.
Designed to be human readable and is often used when generating
model forms from the document model.
:param help_text: (optional) The help text for this field and is often
used when generating model forms from the document model.
:param null: (optional) Is the field value can be null. If no and there is a default value
then the default value is set
:param sparse: (optional) `sparse=True` combined with `unique=True` and `required=False`
means that uniqueness won't be enforced for `None` values
:param custom_data: (optional) Custom metadata for this field.
:param **kwargs: (optional) Arbitrary indirection-free metadata for
this field can be supplied as additional keyword arguments and
accessed as attributes of the field. Must not conflict with any
existing attributes. Common metadata includes `verbose_name` and
`help_text`.
"""
self.db_field = (db_field or name) if not primary_key else '_id'
@ -83,12 +82,19 @@ class BaseField(object):
self.primary_key = primary_key
self.validation = validation
self.choices = choices
self.verbose_name = verbose_name
self.help_text = help_text
self.null = null
self.sparse = sparse
self._owner_document = None
self.custom_data = custom_data
# Detect and report conflicts between metadata and base properties.
conflicts = set(dir(self)) & set(kwargs)
if conflicts:
raise TypeError("%s already has attribute(s): %s" % (
self.__class__.__name__, ', '.join(conflicts) ))
# Assign metadata to the instance
# This efficient method is available because no __slots__ are defined.
self.__dict__.update(kwargs)
# Adjust the appropriate creation counter, and save our local copy.
if self.db_field == '_id':
@ -127,7 +133,7 @@ class BaseField(object):
if (self.name not in instance._data or
instance._data[self.name] != value):
instance._mark_as_changed(self.name)
except:
except Exception:
# Values cant be compared eg: naive and tz datetimes
# So mark it as changed
instance._mark_as_changed(self.name)
@ -135,6 +141,10 @@ class BaseField(object):
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument):
value._instance = weakref.proxy(instance)
elif isinstance(value, (list, tuple)):
for v in value:
if isinstance(v, EmbeddedDocument):
v._instance = weakref.proxy(instance)
instance._data[self.name] = value
def error(self, message="", errors=None, field_name=None):
@ -148,7 +158,7 @@ class BaseField(object):
"""
return value
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
"""Convert a Python type to a MongoDB-compatible type.
"""
return self.to_python(value)
@ -275,8 +285,6 @@ class ComplexBaseField(BaseField):
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
"""
Document = _import_class('Document')
if isinstance(value, basestring):
return value
@ -296,6 +304,7 @@ class ComplexBaseField(BaseField):
value_dict = dict([(key, self.field.to_python(item))
for key, item in value.items()])
else:
Document = _import_class('Document')
value_dict = {}
for k, v in value.items():
if isinstance(v, Document):
@ -315,7 +324,7 @@ class ComplexBaseField(BaseField):
key=operator.itemgetter(0))]
return value_dict
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
"""Convert a Python type to a MongoDB-compatible type.
"""
Document = _import_class("Document")
@ -327,9 +336,10 @@ class ComplexBaseField(BaseField):
if hasattr(value, 'to_mongo'):
if isinstance(value, Document):
return GenericReferenceField().to_mongo(value)
return GenericReferenceField().to_mongo(
value, **kwargs)
cls = value.__class__
val = value.to_mongo()
val = value.to_mongo(**kwargs)
# If it's a document that is not inherited add _cls
if isinstance(value, EmbeddedDocument):
val['_cls'] = cls.__name__
@ -344,7 +354,7 @@ class ComplexBaseField(BaseField):
return value
if self.field:
value_dict = dict([(key, self.field.to_mongo(item))
value_dict = dict([(key, self.field.to_mongo(item, **kwargs))
for key, item in value.iteritems()])
else:
value_dict = {}
@ -363,19 +373,20 @@ class ComplexBaseField(BaseField):
meta.get('allow_inheritance', ALLOW_INHERITANCE)
is True)
if not allow_inheritance and not self.field:
value_dict[k] = GenericReferenceField().to_mongo(v)
value_dict[k] = GenericReferenceField().to_mongo(
v, **kwargs)
else:
collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'):
cls = v.__class__
val = v.to_mongo()
val = v.to_mongo(**kwargs)
# If it's a document that is not inherited add _cls
if isinstance(v, (Document, EmbeddedDocument)):
val['_cls'] = cls.__name__
value_dict[k] = val
else:
value_dict[k] = self.to_mongo(v)
value_dict[k] = self.to_mongo(v, **kwargs)
if is_list: # Convert back to a list
return [v for _, v in sorted(value_dict.items(),
@ -429,11 +440,11 @@ class ObjectIdField(BaseField):
try:
if not isinstance(value, ObjectId):
value = ObjectId(value)
except:
except Exception:
pass
return value
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
if not isinstance(value, ObjectId):
try:
return ObjectId(unicode(value))
@ -448,7 +459,7 @@ class ObjectIdField(BaseField):
def validate(self, value):
try:
ObjectId(unicode(value))
except:
except Exception:
self.error('Invalid Object ID')
@ -500,7 +511,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0][0]
except:
except (TypeError, IndexError):
return "Invalid Polygon must contain at least one valid linestring"
errors = []
@ -524,7 +535,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0]
except:
except (TypeError, IndexError):
return "Invalid LineString must contain at least one valid point"
errors = []
@ -555,7 +566,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0]
except:
except (TypeError, IndexError):
return "Invalid MultiPoint must contain at least one valid point"
errors = []
@ -574,7 +585,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0][0]
except:
except (TypeError, IndexError):
return "Invalid MultiLineString must contain at least one valid linestring"
errors = []
@ -596,7 +607,7 @@ class GeoJsonBaseField(BaseField):
# Quick and dirty validator
try:
value[0][0][0][0]
except:
except (TypeError, IndexError):
return "Invalid MultiPolygon must contain at least one valid Polygon"
errors = []
@ -608,7 +619,7 @@ class GeoJsonBaseField(BaseField):
if errors:
return "Invalid MultiPolygon:\n%s" % ", ".join(errors)
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
if isinstance(value, dict):
return value
return SON([("type", self._type), ("coordinates", value)])

View File

@ -38,8 +38,11 @@ def register_connection(alias, name=None, host=None, port=None,
:param username: username to authenticate with
:param password: password to authenticate with
:param authentication_source: database to authenticate against
:param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock://` as db host prefix)
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
.. versionchanged:: 0.10.6 - added mongomock support
"""
global _connection_settings
@ -54,8 +57,13 @@ def register_connection(alias, name=None, host=None, port=None,
}
# Handle uri style connections
if "://" in conn_settings['host']:
uri_dict = uri_parser.parse_uri(conn_settings['host'])
conn_host = conn_settings['host']
if conn_host.startswith('mongomock://'):
conn_settings['is_mock'] = True
# `mongomock://` is not a valid url prefix and must be replaced by `mongodb://`
conn_settings['host'] = conn_host.replace('mongomock://', 'mongodb://', 1)
elif '://' in conn_host:
uri_dict = uri_parser.parse_uri(conn_host)
conn_settings.update({
'name': uri_dict.get('database') or name,
'username': uri_dict.get('username'),
@ -106,7 +114,19 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn_settings.pop('password', None)
conn_settings.pop('authentication_source', None)
connection_class = MongoClient
is_mock = conn_settings.pop('is_mock', None)
if is_mock:
# Use MongoClient from mongomock
try:
import mongomock
except ImportError:
raise RuntimeError('You need mongomock installed '
'to mock MongoEngine.')
connection_class = mongomock.MongoClient
else:
# Use MongoClient from pymongo
connection_class = MongoClient
if 'replicaSet' in conn_settings:
# Discard port since it can't be used on MongoReplicaSetClient
conn_settings.pop('port', None)
@ -126,6 +146,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
connection_settings.pop('name', None)
connection_settings.pop('username', None)
connection_settings.pop('password', None)
connection_settings.pop('authentication_source', None)
if conn_settings == connection_settings and _connections.get(db_alias, None):
connection = _connections[db_alias]
break

View File

@ -1,5 +1,7 @@
from bson import DBRef, SON
from mongoengine.python_support import txt_type
from base import (
BaseDict, BaseList, EmbeddedDocumentList,
TopLevelDocumentMetaclass, get_document
@ -226,7 +228,7 @@ class DeReference(object):
data[k]._data[field_name] = self.object_map.get(
(v['_ref'].collection, v['_ref'].id), v)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = "{0}.{1}.{2}".format(name, k, field_name)
item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name)
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = '%s.%s' % (name, k) if name else name

View File

@ -217,7 +217,7 @@ class Document(BaseDocument):
Returns True if the document has been updated or False if the document
in the database doesn't match the query.
.. note:: All unsaved changes that has been made to the document are
.. note:: All unsaved changes that have been made to the document are
rejected if the method returns True.
:param query: the update will be performed only if the document in the
@ -250,7 +250,7 @@ class Document(BaseDocument):
def save(self, force_insert=False, validate=True, clean=True,
write_concern=None, cascade=None, cascade_kwargs=None,
_refs=None, save_condition=None, **kwargs):
_refs=None, save_condition=None, signal_kwargs=None, **kwargs):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@ -276,6 +276,8 @@ class Document(BaseDocument):
:param save_condition: only perform save if matching record in db
satisfies condition(s) (e.g. version number).
Raises :class:`OperationError` if the conditions are not satisfied
:parm signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
.. versionchanged:: 0.5
In existing documents it only saves changed fields using
@ -297,8 +299,11 @@ class Document(BaseDocument):
:class:`OperationError` exception raised if save_condition fails.
.. versionchanged:: 0.10.1
:class: save_condition failure now raises a `SaveConditionError`
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
signals.pre_save.send(self.__class__, document=self)
signal_kwargs = signal_kwargs or {}
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
if validate:
self.validate(clean=clean)
@ -311,7 +316,7 @@ class Document(BaseDocument):
created = ('_id' not in doc or self._created or force_insert)
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created)
created=created, **signal_kwargs)
try:
collection = self._get_collection()
@ -341,8 +346,12 @@ class Document(BaseDocument):
select_dict['_id'] = object_id
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
actual_key = self._db_field_map.get(k, k)
select_dict[actual_key] = doc[actual_key]
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
def is_new_object(last_error):
if last_error is not None:
@ -396,14 +405,15 @@ class Document(BaseDocument):
if created or id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self.__class__, document=self, created=created)
signals.post_save.send(self.__class__, document=self,
created=created, **signal_kwargs)
self._clear_changed_fields()
self._created = False
return self
def cascade_save(self, *args, **kwargs):
"""Recursively saves any references /
generic references on an objects"""
generic references on the document"""
_refs = kwargs.get('_refs', []) or []
ReferenceField = _import_class('ReferenceField')
@ -444,7 +454,12 @@ class Document(BaseDocument):
select_dict = {'pk': self.pk}
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
select_dict[k] = getattr(self, k)
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = self
for ak in actual_key:
val = getattr(val, ak)
select_dict['__'.join(actual_key)] = val
return select_dict
def update(self, **kwargs):
@ -467,18 +482,24 @@ class Document(BaseDocument):
# Need to add shard key to query, or you get an error
return self._qs.filter(**self._object_key).update_one(**kwargs)
def delete(self, **write_concern):
def delete(self, signal_kwargs=None, **write_concern):
"""Delete the :class:`~mongoengine.Document` from the database. This
will only take effect if the document has been previously saved.
:parm signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
``getLastError`` command. For example,
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
signals.pre_delete.send(self.__class__, document=self)
signal_kwargs = signal_kwargs or {}
signals.pre_delete.send(self.__class__, document=self, **signal_kwargs)
# Delete FileFields separately
FileField = _import_class('FileField')
@ -492,7 +513,7 @@ class Document(BaseDocument):
except pymongo.errors.OperationFailure, err:
message = u'Could not delete document (%s)' % err.message
raise OperationError(message)
signals.post_delete.send(self.__class__, document=self)
signals.post_delete.send(self.__class__, document=self, **signal_kwargs)
def switch_db(self, db_alias, keep_created=True):
"""
@ -595,11 +616,16 @@ class Document(BaseDocument):
if not fields or field in fields:
try:
setattr(self, field, self._reload(field, obj[field]))
except KeyError:
# If field is removed from the database while the object
# is in memory, a reload would cause a KeyError
# i.e. obj.update(unset__field=1) followed by obj.reload()
delattr(self, field)
except (KeyError, AttributeError):
try:
# If field is a special field, e.g. items is stored as _reserved_items,
# an KeyError is thrown. So try to retrieve the field from _data
setattr(self, field, self._reload(field, obj._data.get(field)))
except KeyError:
# If field is removed from the database while the object
# is in memory, a reload would cause a KeyError
# i.e. obj.update(unset__field=1) followed by obj.reload()
delattr(self, field)
self._changed_fields = obj._changed_fields
self._created = False
@ -653,10 +679,20 @@ class Document(BaseDocument):
def drop_collection(cls):
"""Drops the entire collection associated with this
:class:`~mongoengine.Document` type from the database.
Raises :class:`OperationError` if the document has no collection set
(i.g. if it is `abstract`)
.. versionchanged:: 0.10.7
:class:`OperationError` exception raised if no collection available
"""
col_name = cls._get_collection_name()
if not col_name:
raise OperationError('Document %s has no collection defined '
'(is it abstract ?)' % cls)
cls._collection = None
db = cls._get_db()
db.drop_collection(cls._get_collection_name())
db.drop_collection(col_name)
@classmethod
def create_index(cls, keys, background=False, **kwargs):
@ -945,7 +981,7 @@ class MapReduceDocument(object):
if not isinstance(self.key, id_field_type):
try:
self.key = id_field_type(self.key)
except:
except Exception:
raise Exception("Could not cast key as %s" %
id_field_type.__name__)

View File

@ -6,7 +6,7 @@ from mongoengine.python_support import txt_type
__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
'OperationError', 'NotUniqueError', 'FieldDoesNotExist',
'ValidationError')
'ValidationError', 'SaveConditionError')
class NotRegistered(Exception):

View File

@ -8,6 +8,8 @@ import uuid
import warnings
from operator import itemgetter
import six
try:
import dateutil
except ImportError:
@ -18,6 +20,10 @@ else:
import pymongo
import gridfs
from bson import Binary, DBRef, SON, ObjectId
try:
from bson.int64 import Int64
except ImportError:
Int64 = long
from mongoengine.errors import ValidationError
from mongoengine.python_support import (PY3, bin_type, txt_type,
@ -65,7 +71,7 @@ class StringField(BaseField):
return value
try:
value = value.decode('utf-8')
except:
except Exception:
pass
return value
@ -194,7 +200,7 @@ class IntField(BaseField):
def validate(self, value):
try:
value = int(value)
except:
except Exception:
self.error('%s could not be converted to int' % value)
if self.min_value is not None and value < self.min_value:
@ -225,10 +231,13 @@ class LongField(BaseField):
pass
return value
def to_mongo(self, value, **kwargs):
return Int64(value)
def validate(self, value):
try:
value = long(value)
except:
except Exception:
self.error('%s could not be converted to long' % value)
if self.min_value is not None and value < self.min_value:
@ -260,10 +269,14 @@ class FloatField(BaseField):
return value
def validate(self, value):
if isinstance(value, int):
value = float(value)
if isinstance(value, six.integer_types):
try:
value = float(value)
except OverflowError:
self.error('The value is too large to be converted to float')
if not isinstance(value, float):
self.error('FloatField only accepts float values')
self.error('FloatField only accepts float and integer values')
if self.min_value is not None and value < self.min_value:
self.error('Float value is too small')
@ -325,7 +338,7 @@ class DecimalField(BaseField):
return value
return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding)
def to_mongo(self, value, use_db_field=True):
def to_mongo(self, value, **kwargs):
if value is None:
return value
if self.force_string:
@ -388,7 +401,7 @@ class DateTimeField(BaseField):
if not isinstance(new_value, (datetime.datetime, datetime.date)):
self.error(u'cannot parse date "%s"' % value)
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
if value is None:
return value
if isinstance(value, datetime.datetime):
@ -508,10 +521,10 @@ class ComplexDateTimeField(StringField):
original_value = value
try:
return self._convert_from_string(value)
except:
except Exception:
return original_value
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
value = self.to_python(value)
return self._convert_from_datetime(value)
@ -546,11 +559,10 @@ class EmbeddedDocumentField(BaseField):
return self.document_type._from_son(value, _auto_dereference=self._auto_dereference)
return value
def to_mongo(self, value, use_db_field=True, fields=[]):
def to_mongo(self, value, **kwargs):
if not isinstance(value, self.document_type):
return value
return self.document_type.to_mongo(value, use_db_field,
fields=fields)
return self.document_type.to_mongo(value, **kwargs)
def validate(self, value, clean=True):
"""Make sure that the document instance is an instance of the
@ -600,11 +612,11 @@ class GenericEmbeddedDocumentField(BaseField):
value.validate(clean=clean)
def to_mongo(self, document, use_db_field=True):
def to_mongo(self, document, **kwargs):
if document is None:
return None
data = document.to_mongo(use_db_field)
data = document.to_mongo(**kwargs)
if '_cls' not in data:
data['_cls'] = document._class_name
return data
@ -616,7 +628,7 @@ class DynamicField(BaseField):
Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data"""
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
"""Convert a Python type to a MongoDB compatible type.
"""
@ -625,7 +637,7 @@ class DynamicField(BaseField):
if hasattr(value, 'to_mongo'):
cls = value.__class__
val = value.to_mongo()
val = value.to_mongo(**kwargs)
# If we its a document thats not inherited add _cls
if isinstance(value, Document):
val = {"_ref": value.to_dbref(), "_cls": cls.__name__}
@ -643,7 +655,7 @@ class DynamicField(BaseField):
data = {}
for k, v in value.iteritems():
data[k] = self.to_mongo(v)
data[k] = self.to_mongo(v, **kwargs)
value = data
if is_list: # Convert back to a list
@ -697,7 +709,7 @@ class ListField(ComplexBaseField):
def prepare_query_value(self, op, value):
if self.field:
if op in ('set', 'unset') and (
if op in ('set', 'unset', None) and (
not isinstance(value, basestring) and
not isinstance(value, BaseDocument) and
hasattr(value, '__iter__')):
@ -755,8 +767,8 @@ class SortedListField(ListField):
self._order_reverse = kwargs.pop('reverse')
super(SortedListField, self).__init__(field, **kwargs)
def to_mongo(self, value):
value = super(SortedListField, self).to_mongo(value)
def to_mongo(self, value, **kwargs):
value = super(SortedListField, self).to_mongo(value, **kwargs)
if self._ordering is not None:
return sorted(value, key=itemgetter(self._ordering),
reverse=self._order_reverse)
@ -863,12 +875,11 @@ class ReferenceField(BaseField):
The options are:
* DO_NOTHING - don't do anything (default).
* NULLIFY - Updates the reference to null.
* CASCADE - Deletes the documents associated with the reference.
* DENY - Prevent the deletion of the reference object.
* PULL - Pull the reference from a :class:`~mongoengine.fields.ListField`
of references
* DO_NOTHING (0) - don't do anything (default).
* NULLIFY (1) - Updates the reference to null.
* CASCADE (2) - Deletes the documents associated with the reference.
* DENY (3) - Prevent the deletion of the reference object.
* PULL (4) - Pull the reference from a :class:`~mongoengine.fields.ListField` of references
Alternative syntax for registering delete rules (useful when implementing
bi-directional delete rules)
@ -879,7 +890,7 @@ class ReferenceField(BaseField):
content = StringField()
foo = ReferenceField('Foo')
Bar.register_delete_rule(Foo, 'bar', NULLIFY)
Foo.register_delete_rule(Bar, 'foo', NULLIFY)
.. note ::
`reverse_delete_rule` does not trigger pre / post delete signals to be
@ -896,6 +907,10 @@ class ReferenceField(BaseField):
or as the :class:`~pymongo.objectid.ObjectId`.id .
:param reverse_delete_rule: Determines what to do when the referring
object is deleted
.. note ::
A reference to an abstract document type is always stored as a
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
"""
if not isinstance(document_type, basestring):
if not issubclass(document_type, (Document, basestring)):
@ -928,33 +943,46 @@ class ReferenceField(BaseField):
self._auto_dereference = instance._fields[self.name]._auto_dereference
# Dereference DBRefs
if self._auto_dereference and isinstance(value, DBRef):
value = self.document_type._get_db().dereference(value)
if hasattr(value, 'cls'):
# Dereference using the class type specified in the reference
cls = get_document(value.cls)
else:
cls = self.document_type
value = cls._get_db().dereference(value)
if value is not None:
instance._data[self.name] = self.document_type._from_son(value)
instance._data[self.name] = cls._from_son(value)
return super(ReferenceField, self).__get__(instance, owner)
def to_mongo(self, document):
def to_mongo(self, document, **kwargs):
if isinstance(document, DBRef):
if not self.dbref:
return document.id
return document
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
if isinstance(document, Document):
# We need the id from the saved object to create the DBRef
id_ = document.pk
if id_ is None:
self.error('You can only reference documents once they have'
' been saved to the database')
# Use the attributes from the document instance, so that they
# override the attributes of this field's document type
cls = document
else:
id_ = document
cls = self.document_type
id_ = id_field.to_mongo(id_)
if self.dbref:
collection = self.document_type._get_collection_name()
id_field_name = cls._meta['id_field']
id_field = cls._fields[id_field_name]
id_ = id_field.to_mongo(id_, **kwargs)
if self.document_type._meta.get('abstract'):
collection = cls._get_collection_name()
return DBRef(collection, id_, cls=cls._class_name)
elif self.dbref:
collection = cls._get_collection_name()
return DBRef(collection, id_)
return id_
@ -983,6 +1011,14 @@ class ReferenceField(BaseField):
self.error('You can only reference documents once they have been '
'saved to the database')
if self.document_type._meta.get('abstract') and \
not isinstance(value, self.document_type):
self.error('%s is not an instance of abstract reference'
' type %s' % (value._class_name,
self.document_type._class_name)
)
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
@ -990,7 +1026,7 @@ class ReferenceField(BaseField):
class CachedReferenceField(BaseField):
"""
A referencefield with cache fields to purpose pseudo-joins
.. versionadded:: 0.9
"""
@ -1064,7 +1100,7 @@ class CachedReferenceField(BaseField):
return super(CachedReferenceField, self).__get__(instance, owner)
def to_mongo(self, document):
def to_mongo(self, document, **kwargs):
id_field_name = self.document_type._meta['id_field']
id_field = self.document_type._fields[id_field_name]
@ -1079,10 +1115,11 @@ class CachedReferenceField(BaseField):
# TODO: should raise here or will fail next statement
value = SON((
("_id", id_field.to_mongo(id_)),
("_id", id_field.to_mongo(id_, **kwargs)),
))
value.update(dict(document.to_mongo(fields=self.fields)))
kwargs['fields'] = self.fields
value.update(dict(document.to_mongo(**kwargs)))
return value
def prepare_query_value(self, op, value):
@ -1198,7 +1235,7 @@ class GenericReferenceField(BaseField):
doc = doc_cls._from_son(doc)
return doc
def to_mongo(self, document, use_db_field=True):
def to_mongo(self, document, **kwargs):
if document is None:
return None
@ -1217,7 +1254,7 @@ class GenericReferenceField(BaseField):
else:
id_ = document
id_ = id_field.to_mongo(id_)
id_ = id_field.to_mongo(id_, **kwargs)
collection = document._get_collection_name()
ref = DBRef(collection, id_)
return SON((
@ -1246,7 +1283,7 @@ class BinaryField(BaseField):
value = bin_type(value)
return super(BinaryField, self).__set__(instance, value)
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
return Binary(value)
def validate(self, value):
@ -1346,7 +1383,7 @@ class GridFSProxy(object):
if self.gridout is None:
self.gridout = self.fs.get(self.grid_id)
return self.gridout
except:
except Exception:
# File has been deleted
return None
@ -1384,7 +1421,7 @@ class GridFSProxy(object):
else:
try:
return gridout.read(size)
except:
except Exception:
return ""
def delete(self):
@ -1449,7 +1486,7 @@ class FileField(BaseField):
if grid_file:
try:
grid_file.delete()
except:
except Exception:
pass
# Create a new proxy object as we don't already have one
@ -1471,7 +1508,7 @@ class FileField(BaseField):
db_alias=db_alias,
collection_name=collection_name)
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
# Store the GridFS file id in MongoDB
if isinstance(value, self.proxy_class) and value.grid_id is not None:
return value.grid_id
@ -1683,17 +1720,17 @@ class SequenceField(BaseField):
:param collection_name: Name of the counter collection (default 'mongoengine.counters')
:param sequence_name: Name of the sequence in the collection (default 'ClassName.counter')
:param value_decorator: Any callable to use as a counter (default int)
Use any callable as `value_decorator` to transform calculated counter into
any value suitable for your needs, e.g. string or hexadecimal
representation of the default integer counter value.
.. note::
In case the counter is defined in the abstract document, it will be
common to all inherited documents and the default sequence name will
In case the counter is defined in the abstract document, it will be
common to all inherited documents and the default sequence name will
be the class name of the abstract document.
.. versionadded:: 0.5
.. versionchanged:: 0.8 added `value_decorator`
"""
@ -1817,11 +1854,11 @@ class UUIDField(BaseField):
if not isinstance(value, basestring):
value = unicode(value)
return uuid.UUID(value)
except:
except Exception:
return original_value
return value
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
if not self._binary:
return unicode(value)
elif isinstance(value, basestring):

View File

@ -266,7 +266,8 @@ class BaseQuerySet(object):
result = None
return result
def insert(self, doc_or_docs, load_bulk=True, write_concern=None):
def insert(self, doc_or_docs, load_bulk=True,
write_concern=None, signal_kwargs=None):
"""bulk insert documents
:param doc_or_docs: a document or list of documents to be inserted
@ -279,11 +280,15 @@ class BaseQuerySet(object):
``insert(..., {w: 2, fsync: True})`` will wait until at least
two servers have recorded the write and will force an fsync on
each server being written to.
:parm signal_kwargs: (optional) kwargs dictionary to be passed to
the signal calls.
By default returns document instances, set ``load_bulk`` to False to
return just ``ObjectIds``
.. versionadded:: 0.5
.. versionchanged:: 0.10.7
Add signal_kwargs argument
"""
Document = _import_class('Document')
@ -296,7 +301,6 @@ class BaseQuerySet(object):
return_one = True
docs = [docs]
raw = []
for doc in docs:
if not isinstance(doc, self._document):
msg = ("Some documents inserted aren't instances of %s"
@ -305,9 +309,12 @@ class BaseQuerySet(object):
if doc.pk and not doc._created:
msg = "Some documents have ObjectIds use doc.update() instead"
raise OperationError(msg)
raw.append(doc.to_mongo())
signals.pre_bulk_insert.send(self._document, documents=docs)
signal_kwargs = signal_kwargs or {}
signals.pre_bulk_insert.send(self._document,
documents=docs, **signal_kwargs)
raw = [doc.to_mongo() for doc in docs]
try:
ids = self._collection.insert(raw, **write_concern)
except pymongo.errors.DuplicateKeyError, err:
@ -324,7 +331,7 @@ class BaseQuerySet(object):
if not load_bulk:
signals.post_bulk_insert.send(
self._document, documents=docs, loaded=False)
self._document, documents=docs, loaded=False, **signal_kwargs)
return return_one and ids[0] or ids
documents = self.in_bulk(ids)
@ -332,7 +339,7 @@ class BaseQuerySet(object):
for obj_id in ids:
results.append(documents.get(obj_id))
signals.post_bulk_insert.send(
self._document, documents=results, loaded=True)
self._document, documents=results, loaded=True, **signal_kwargs)
return return_one and results[0] or results
def count(self, with_limit_and_skip=False):
@ -403,8 +410,10 @@ class BaseQuerySet(object):
rule = doc._meta['delete_rules'][rule_entry]
if rule == CASCADE:
cascade_refs = set() if cascade_refs is None else cascade_refs
for ref in queryset:
cascade_refs.add(ref.id)
# Handle recursive reference
if doc._collection == document_cls._collection:
for ref in queryset:
cascade_refs.add(ref.id)
ref_q = document_cls.objects(**{field_name + '__in': self, 'id__nin': cascade_refs})
ref_q_count = ref_q.count()
if ref_q_count > 0:
@ -425,7 +434,7 @@ class BaseQuerySet(object):
full_result=False, **update):
"""Perform an atomic update on the fields matched by the query.
:param upsert: Any existing document with that "_id" is overwritten.
:param upsert: insert if document doesn't exist (default ``False``)
:param multi: Update multiple documents.
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
@ -471,10 +480,36 @@ class BaseQuerySet(object):
raise OperationError(message)
raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, upsert=False, write_concern=None, **update):
"""Perform an atomic update on first field matched by the query.
def upsert_one(self, write_concern=None, **update):
"""Overwrite or add the first document matched by the query.
:param upsert: Any existing document with that "_id" is overwritten.
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
``getLastError`` command. For example,
``save(..., write_concern={w: 2, fsync: True}, ...)`` will
wait until at least two servers have recorded the write and
will force an fsync on the primary server.
:param update: Django-style update keyword arguments
:returns the new or overwritten document
.. versionadded:: 0.10.2
"""
atomic_update = self.update(multi=False, upsert=True, write_concern=write_concern,
full_result=True, **update)
if atomic_update['updatedExisting']:
document = self.get()
else:
document = self._document.objects.with_id(atomic_update['upserted'])
return document
def update_one(self, upsert=False, write_concern=None, **update):
"""Perform an atomic update on the fields of the first document
matched by the query.
:param upsert: insert if document doesn't exist (default ``False``)
:param write_concern: Extra keyword arguments are passed down which
will be used as options for the resultant
``getLastError`` command. For example,
@ -929,6 +964,7 @@ class BaseQuerySet(object):
validate_read_preference('read_preference', read_preference)
queryset = self.clone()
queryset._read_preference = read_preference
queryset._cursor_obj = None # we need to re-create the cursor object whenever we apply read_preference
return queryset
def scalar(self, *fields):
@ -1201,66 +1237,28 @@ class BaseQuerySet(object):
def sum(self, field):
"""Sum over the values of the specified field.
:param field: the field to sum over; use dot-notation to refer to
:param field: the field to sum over; use dot notation to refer to
embedded document fields
.. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work
with sharding.
"""
map_func = """
function() {
var path = '{{~%(field)s}}'.split('.'),
field = this;
for (p in path) {
if (typeof field != 'undefined')
field = field[path[p]];
else
break;
}
if (field && field.constructor == Array) {
field.forEach(function(item) {
emit(1, item||0);
});
} else if (typeof field != 'undefined') {
emit(1, field||0);
}
}
""" % dict(field=field)
reduce_func = Code("""
function(key, values) {
var sum = 0;
for (var i in values) {
sum += values[i];
}
return sum;
}
""")
for result in self.map_reduce(map_func, reduce_func, output='inline'):
return result.value
else:
return 0
def aggregate_sum(self, field):
"""Sum over the values of the specified field.
:param field: the field to sum over; use dot-notation to refer to
embedded document fields
This method is more performant than the regular `sum`, because it uses
the aggregation framework instead of map-reduce.
"""
result = self._document._get_collection().aggregate([
pipeline = [
{'$match': self._query},
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}}
])
]
# if we're performing a sum over a list field, we sum up all the
# elements in the list, hence we need to $unwind the arrays first
ListField = _import_class('ListField')
field_parts = field.split('.')
field_instances = self._document._lookup_field(field_parts)
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {'$unwind': '$' + field})
result = self._document._get_collection().aggregate(pipeline)
if IS_PYMONGO_3:
result = list(result)
result = tuple(result)
else:
result = result.get('result')
if result:
return result[0]['total']
return 0
@ -1268,73 +1266,26 @@ class BaseQuerySet(object):
def average(self, field):
"""Average over the values of the specified field.
:param field: the field to average over; use dot-notation to refer to
:param field: the field to average over; use dot notation to refer to
embedded document fields
.. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work
with sharding.
"""
map_func = """
function() {
var path = '{{~%(field)s}}'.split('.'),
field = this;
for (p in path) {
if (typeof field != 'undefined')
field = field[path[p]];
else
break;
}
if (field && field.constructor == Array) {
field.forEach(function(item) {
emit(1, {t: item||0, c: 1});
});
} else if (typeof field != 'undefined') {
emit(1, {t: field||0, c: 1});
}
}
""" % dict(field=field)
reduce_func = Code("""
function(key, values) {
var out = {t: 0, c: 0};
for (var i in values) {
var value = values[i];
out.t += value.t;
out.c += value.c;
}
return out;
}
""")
finalize_func = Code("""
function(key, value) {
return value.t / value.c;
}
""")
for result in self.map_reduce(map_func, reduce_func,
finalize_f=finalize_func, output='inline'):
return result.value
else:
return 0
def aggregate_average(self, field):
"""Average over the values of the specified field.
:param field: the field to average over; use dot-notation to refer to
embedded document fields
This method is more performant than the regular `average`, because it
uses the aggregation framework instead of map-reduce.
"""
result = self._document._get_collection().aggregate([
pipeline = [
{'$match': self._query},
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}}
])
]
# if we're performing an average over a list field, we average out
# all the elements in the list, hence we need to $unwind the arrays
# first
ListField = _import_class('ListField')
field_parts = field.split('.')
field_instances = self._document._lookup_field(field_parts)
if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {'$unwind': '$' + field})
result = self._document._get_collection().aggregate(pipeline)
if IS_PYMONGO_3:
result = list(result)
result = tuple(result)
else:
result = result.get('result')
if result:
@ -1351,7 +1302,7 @@ class BaseQuerySet(object):
Can only do direct simple mappings and cannot map across
:class:`~mongoengine.fields.ReferenceField` or
:class:`~mongoengine.fields.GenericReferenceField` for more complex
counting a manual map reduce call would is required.
counting a manual map reduce call is required.
If the field is a :class:`~mongoengine.fields.ListField`, the items within
each list will be counted individually.
@ -1425,7 +1376,7 @@ class BaseQuerySet(object):
msg = "The snapshot option is not anymore available with PyMongo 3+"
warnings.warn(msg, DeprecationWarning)
cursor_args = {
'no_cursor_timeout': self._timeout
'no_cursor_timeout': not self._timeout
}
if self._loaded_fields:
cursor_args[fields_name] = self._loaded_fields.as_dict()
@ -1442,8 +1393,16 @@ class BaseQuerySet(object):
def _cursor(self):
if self._cursor_obj is None:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned
# collection object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None:
self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args)
else:
self._cursor_obj = self._collection.find(self._query,
**self._cursor_args)
# Apply where clauses to cursor
if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
@ -1660,7 +1619,7 @@ class BaseQuerySet(object):
key = key.replace('__', '.')
try:
key = self._document._translate_field_name(key)
except:
except Exception:
pass
key_list.append((key, direction))

View File

@ -29,7 +29,7 @@ class QuerySetManager(object):
Document.objects is accessed.
"""
if instance is not None:
# Document class being used rather than a document object
# Document object being used rather than a document class
return self
# owner is the document that contains the QuerySetManager

View File

@ -38,7 +38,7 @@ class QuerySet(BaseQuerySet):
def __len__(self):
"""Since __len__ is called quite frequently (for example, as part of
list(qs) we populate the result cache and cache the length.
list(qs)), we populate the result cache and cache the length.
"""
if self._len is not None:
return self._len

View File

@ -26,12 +26,12 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
STRING_OPERATORS + CUSTOM_OPERATORS)
def query(_doc_cls=None, **query):
def query(_doc_cls=None, **kwargs):
"""Transform a query from Django-style format to Mongo format.
"""
mongo_query = {}
merge_query = defaultdict(list)
for key, value in sorted(query.items()):
for key, value in sorted(kwargs.items()):
if key == "__raw__":
mongo_query.update(value)
continue
@ -44,7 +44,7 @@ def query(_doc_cls=None, **query):
if len(parts) > 1 and parts[-1] in MATCH_OPERATORS:
op = parts.pop()
# Allw to escape operator-like field name by __
# Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == "":
parts.pop()
@ -105,13 +105,18 @@ def query(_doc_cls=None, **query):
if op:
if op in GEO_OPERATORS:
value = _geo_operator(field, op, value)
elif op in CUSTOM_OPERATORS:
if op in ('elem_match', 'match'):
value = field.prepare_query_value(op, value)
value = {"$elemMatch": value}
elif op in ('match', 'elemMatch'):
ListField = _import_class('ListField')
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
if (isinstance(value, dict) and isinstance(field, ListField) and
isinstance(field.field, EmbeddedDocumentField)):
value = query(field.field.document_type, **value)
else:
NotImplementedError("Custom method '%s' has not "
"been implemented" % op)
value = field.prepare_query_value(op, value)
value = {"$elemMatch": value}
elif op in CUSTOM_OPERATORS:
NotImplementedError("Custom method '%s' has not "
"been implemented" % op)
elif op not in STRING_OPERATORS:
value = {'$' + op: value}
@ -207,6 +212,10 @@ def update(_doc_cls=None, **update):
if parts[-1] in COMPARISON_OPERATORS:
match = parts.pop()
# Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == "":
parts.pop()
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
try:
@ -359,20 +368,24 @@ def _infer_geometry(value):
"type and coordinates keys")
elif isinstance(value, (list, set)):
# TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
# TODO: should both TypeError and IndexError be alike interpreted?
try:
value[0][0][0]
return {"$geometry": {"type": "Polygon", "coordinates": value}}
except:
except (TypeError, IndexError):
pass
try:
value[0][0]
return {"$geometry": {"type": "LineString", "coordinates": value}}
except:
except (TypeError, IndexError):
pass
try:
value[0]
return {"$geometry": {"type": "Point", "coordinates": value}}
except:
except (TypeError, IndexError):
pass
raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary "

View File

@ -1,2 +1,3 @@
pymongo>=2.7.1
nose
pymongo>=2.7.1
six==1.10.0

View File

@ -10,11 +10,12 @@ except ImportError:
DESCRIPTION = 'MongoEngine is a Python Object-Document ' + \
'Mapper for working with MongoDB.'
LONG_DESCRIPTION = None
try:
LONG_DESCRIPTION = open('README.rst').read()
except:
pass
with open('README.rst') as fin:
LONG_DESCRIPTION = fin.read()
except Exception:
LONG_DESCRIPTION = None
def get_version(version_tuple):
@ -77,7 +78,7 @@ setup(name='mongoengine',
long_description=LONG_DESCRIPTION,
platforms=['any'],
classifiers=CLASSIFIERS,
install_requires=['pymongo>=2.7.1'],
install_requires=['pymongo>=2.7.1', 'six'],
test_suite='nose.collector',
**extra_opts
)

View File

@ -5,6 +5,7 @@ import sys
sys.path[0:0] = [""]
import pymongo
from random import randint
from nose.plugins.skip import SkipTest
from datetime import datetime
@ -16,9 +17,11 @@ __all__ = ("IndexesTest", )
class IndexesTest(unittest.TestCase):
_MAX_RAND = 10 ** 10
def setUp(self):
self.connection = connect(db='mongoenginetest')
self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND))
self.connection = connect(db=self.db_name)
self.db = get_db()
class Person(Document):
@ -32,10 +35,7 @@ class IndexesTest(unittest.TestCase):
self.Person = Person
def tearDown(self):
for collection in self.db.collection_names():
if 'system.' in collection:
continue
self.db.drop_collection(collection)
self.connection.drop_database(self.db)
def test_indexes_document(self):
"""Ensure that indexes are used when meta[indexes] is specified for
@ -822,33 +822,29 @@ class IndexesTest(unittest.TestCase):
name = StringField(required=True)
term = StringField(required=True)
class Report(Document):
class ReportEmbedded(Document):
key = EmbeddedDocumentField(CompoundKey, primary_key=True)
text = StringField()
Report.drop_collection()
my_key = CompoundKey(name="n", term="ok")
report = Report(text="OK", key=my_key).save()
report = ReportEmbedded(text="OK", key=my_key).save()
self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}},
report.to_mongo())
self.assertEqual(report, Report.objects.get(pk=my_key))
self.assertEqual(report, ReportEmbedded.objects.get(pk=my_key))
def test_compound_key_dictfield(self):
class Report(Document):
class ReportDictField(Document):
key = DictField(primary_key=True)
text = StringField()
Report.drop_collection()
my_key = {"name": "n", "term": "ok"}
report = Report(text="OK", key=my_key).save()
report = ReportDictField(text="OK", key=my_key).save()
self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}},
report.to_mongo())
self.assertEqual(report, Report.objects.get(pk=my_key))
self.assertEqual(report, ReportDictField.objects.get(pk=my_key))
def test_string_indexes(self):
@ -863,6 +859,20 @@ class IndexesTest(unittest.TestCase):
self.assertTrue([('provider_ids.foo', 1)] in info)
self.assertTrue([('provider_ids.bar', 1)] in info)
def test_sparse_compound_indexes(self):
class MyDoc(Document):
provider_ids = DictField()
meta = {
"indexes": [{'fields': ("provider_ids.foo", "provider_ids.bar"),
'sparse': True}],
}
info = MyDoc.objects._collection.index_information()
self.assertEqual([('provider_ids.foo', 1), ('provider_ids.bar', 1)],
info['provider_ids.foo_1_provider_ids.bar_1']['key'])
self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse'])
def test_text_indexes(self):
class Book(Document):
@ -895,26 +905,38 @@ class IndexesTest(unittest.TestCase):
Issue #812
"""
# Use a new connection and database since dropping the database could
# cause concurrent tests to fail.
connection = connect(db='tempdatabase',
alias='test_indexes_after_database_drop')
class BlogPost(Document):
title = StringField()
slug = StringField(unique=True)
BlogPost.drop_collection()
meta = {'db_alias': 'test_indexes_after_database_drop'}
# Create Post #1
post1 = BlogPost(title='test1', slug='test')
post1.save()
try:
BlogPost.drop_collection()
# Drop the Database
self.connection.drop_database(BlogPost._get_db().name)
# Create Post #1
post1 = BlogPost(title='test1', slug='test')
post1.save()
# Re-create Post #1
post1 = BlogPost(title='test1', slug='test')
post1.save()
# Drop the Database
connection.drop_database('tempdatabase')
# Re-create Post #1
post1 = BlogPost(title='test1', slug='test')
post1.save()
# Create Post #2
post2 = BlogPost(title='test2', slug='test')
self.assertRaises(NotUniqueError, post2.save)
finally:
# Drop the temporary database at the end
connection.drop_database('tempdatabase')
# Create Post #2
post2 = BlogPost(title='test2', slug='test')
self.assertRaises(NotUniqueError, post2.save)
def test_index_dont_send_cls_option(self):
"""

View File

@ -411,7 +411,7 @@ class InheritanceTest(unittest.TestCase):
try:
class MyDocument(DateCreatedDocument, DateUpdatedDocument):
pass
except:
except Exception:
self.assertTrue(False, "Couldn't create MyDocument class")
def test_abstract_documents(self):

View File

@ -7,12 +7,13 @@ import os
import pickle
import unittest
import uuid
import weakref
from datetime import datetime
from bson import DBRef, ObjectId
from tests import fixtures
from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
PickleDyanmicEmbedded, PickleDynamicTest)
PickleDynamicEmbedded, PickleDynamicTest)
from mongoengine import *
from mongoengine.errors import (NotRegistered, InvalidDocumentError,
@ -30,6 +31,8 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
__all__ = ("InstanceTest",)
class InstanceTest(unittest.TestCase):
def setUp(self):
@ -63,6 +66,14 @@ class InstanceTest(unittest.TestCase):
list(self.Person._get_collection().find().sort("id")),
sorted(docs, key=lambda doc: doc["_id"]))
def assertHasInstance(self, field, instance):
self.assertTrue(hasattr(field, "_instance"))
self.assertTrue(field._instance is not None)
if isinstance(field._instance, weakref.ProxyType):
self.assertTrue(field._instance.__eq__(instance))
else:
self.assertEqual(field._instance, instance)
def test_capped_collection(self):
"""Ensure that capped collections work properly.
"""
@ -473,6 +484,20 @@ class InstanceTest(unittest.TestCase):
doc.reload()
Animal.drop_collection()
def test_reload_sharded_nested(self):
class SuperPhylum(EmbeddedDocument):
name = StringField()
class Animal(Document):
superphylum = EmbeddedDocumentField(SuperPhylum)
meta = {'shard_key': ('superphylum.name',)}
Animal.drop_collection()
doc = Animal(superphylum=SuperPhylum(name='Deuterostomia'))
doc.save()
doc.reload()
Animal.drop_collection()
def test_reload_referencing(self):
"""Ensures reloading updates weakrefs correctly
"""
@ -546,6 +571,28 @@ class InstanceTest(unittest.TestCase):
except Exception:
self.assertFalse("Threw wrong exception")
def test_reload_of_non_strict_with_special_field_name(self):
"""Ensures reloading works for documents with meta strict == False
"""
class Post(Document):
meta = {
'strict': False
}
title = StringField()
items = ListField()
Post.drop_collection()
Post._get_collection().insert({
"title": "Items eclipse",
"items": ["more lorem", "even more ipsum"]
})
post = Post.objects.first()
post.reload()
self.assertEqual(post.title, "Items eclipse")
self.assertEqual(post.items, ["more lorem", "even more ipsum"])
def test_dictionary_access(self):
"""Ensure that dictionary-style field access works properly.
"""
@ -608,10 +655,12 @@ class InstanceTest(unittest.TestCase):
embedded_field = EmbeddedDocumentField(Embedded)
Doc.drop_collection()
Doc(embedded_field=Embedded(string="Hi")).save()
doc = Doc(embedded_field=Embedded(string="Hi"))
self.assertHasInstance(doc.embedded_field, doc)
doc.save()
doc = Doc.objects.get()
self.assertEqual(doc, doc.embedded_field._instance)
self.assertHasInstance(doc.embedded_field, doc)
def test_embedded_document_complex_instance(self):
"""Ensure that embedded documents in complex fields can reference
@ -623,10 +672,25 @@ class InstanceTest(unittest.TestCase):
embedded_field = ListField(EmbeddedDocumentField(Embedded))
Doc.drop_collection()
Doc(embedded_field=[Embedded(string="Hi")]).save()
doc = Doc(embedded_field=[Embedded(string="Hi")])
self.assertHasInstance(doc.embedded_field[0], doc)
doc.save()
doc = Doc.objects.get()
self.assertEqual(doc, doc.embedded_field[0]._instance)
self.assertHasInstance(doc.embedded_field[0], doc)
def test_embedded_document_complex_instance_no_use_db_field(self):
"""Ensure that use_db_field is propagated to list of Emb Docs
"""
class Embedded(EmbeddedDocument):
string = StringField(db_field='s')
class Doc(Document):
embedded_field = ListField(EmbeddedDocumentField(Embedded))
d = Doc(embedded_field=[Embedded(string="Hi")]).to_mongo(
use_db_field=False).to_dict()
self.assertEqual(d['embedded_field'], [{'string': 'Hi'}])
def test_instance_is_set_on_setattr(self):
@ -639,11 +703,28 @@ class InstanceTest(unittest.TestCase):
Account.drop_collection()
acc = Account()
acc.email = Email(email='test@example.com')
self.assertTrue(hasattr(acc._data["email"], "_instance"))
self.assertHasInstance(acc._data["email"], acc)
acc.save()
acc1 = Account.objects.first()
self.assertTrue(hasattr(acc1._data["email"], "_instance"))
self.assertHasInstance(acc1._data["email"], acc1)
def test_instance_is_set_on_setattr_on_embedded_document_list(self):
class Email(EmbeddedDocument):
email = EmailField()
class Account(Document):
emails = EmbeddedDocumentListField(Email)
Account.drop_collection()
acc = Account()
acc.emails = [Email(email='test@example.com')]
self.assertHasInstance(acc._data["emails"][0], acc)
acc.save()
acc1 = Account.objects.first()
self.assertHasInstance(acc1._data["emails"][0], acc1)
def test_document_clean(self):
class TestDocument(Document):
@ -1825,6 +1906,62 @@ class InstanceTest(unittest.TestCase):
author.delete()
self.assertEqual(BlogPost.objects.count(), 0)
def test_reverse_delete_rule_with_custom_id_field(self):
"""Ensure that a referenced document with custom primary key
is also deleted upon deletion.
"""
class User(Document):
name = StringField(primary_key=True)
class Book(Document):
author = ReferenceField(User, reverse_delete_rule=CASCADE)
reviewer = ReferenceField(User, reverse_delete_rule=NULLIFY)
User.drop_collection()
Book.drop_collection()
user = User(name='Mike').save()
reviewer = User(name='John').save()
book = Book(author=user, reviewer=reviewer).save()
reviewer.delete()
self.assertEqual(Book.objects.count(), 1)
self.assertEqual(Book.objects.get().reviewer, None)
user.delete()
self.assertEqual(Book.objects.count(), 0)
def test_reverse_delete_rule_with_shared_id_among_collections(self):
"""Ensure that cascade delete rule doesn't mix id among collections.
"""
class User(Document):
id = IntField(primary_key=True)
class Book(Document):
id = IntField(primary_key=True)
author = ReferenceField(User, reverse_delete_rule=CASCADE)
User.drop_collection()
Book.drop_collection()
user_1 = User(id=1).save()
user_2 = User(id=2).save()
book_1 = Book(id=1, author=user_2).save()
book_2 = Book(id=2, author=user_1).save()
user_2.delete()
# Deleting user_2 should also delete book_1 but not book_2
self.assertEqual(Book.objects.count(), 1)
self.assertEqual(Book.objects.get(), book_2)
user_3 = User(id=3).save()
book_3 = Book(id=3, author=user_3).save()
user_3.delete()
# Deleting user_3 should also delete book_3
self.assertEqual(Book.objects.count(), 1)
self.assertEqual(Book.objects.get(), book_2)
def test_reverse_delete_rule_with_document_inheritance(self):
"""Ensure that a referenced document is also deleted upon deletion
of a child document.
@ -2180,7 +2317,7 @@ class InstanceTest(unittest.TestCase):
pickle_doc = PickleDynamicTest(
name="test", number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar")
pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar")
pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
pickle_doc.save()
@ -2683,6 +2820,32 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(OperationError, change_shard_key)
def test_shard_key_in_embedded_document(self):
class Foo(EmbeddedDocument):
foo = StringField()
class Bar(Document):
meta = {
'shard_key': ('foo.foo',)
}
foo = EmbeddedDocumentField(Foo)
bar = StringField()
foo_doc = Foo(foo='hello')
bar_doc = Bar(foo=foo_doc, bar='world')
bar_doc.save()
self.assertTrue(bar_doc.id is not None)
bar_doc.bar = 'baz'
bar_doc.save()
def change_shard_key():
bar_doc.foo.foo = 'something'
bar_doc.save()
self.assertRaises(OperationError, change_shard_key)
def test_shard_key_primary(self):
class LogEntry(Document):
machine = StringField(primary_key=True)
@ -2765,6 +2928,20 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42)
def test_positional_creation_embedded(self):
"""Ensure that embedded document may be created using positional arguments.
"""
job = self.Job("Test Job", 4)
self.assertEqual(job.name, "Test Job")
self.assertEqual(job.years, 4)
def test_mixed_creation_embedded(self):
"""Ensure that embedded document may be created using mixed arguments.
"""
job = self.Job("Test Job", years=4)
self.assertEqual(job.name, "Test Job")
self.assertEqual(job.years, 4)
def test_mixed_creation_dynamic(self):
"""Ensure that document may be created using mixed arguments.
"""

View File

@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
import sys
import six
from nose.plugins.skip import SkipTest
sys.path[0:0] = [""]
@ -10,6 +12,7 @@ import uuid
import math
import itertools
import re
import six
try:
import dateutil
@ -19,6 +22,10 @@ except ImportError:
from decimal import Decimal
from bson import Binary, DBRef, ObjectId
try:
from bson.int64 import Int64
except ImportError:
Int64 = long
from mongoengine import *
from mongoengine.connection import get_db
@ -399,20 +406,37 @@ class FieldTest(unittest.TestCase):
class Person(Document):
height = FloatField(min_value=0.1, max_value=3.5)
class BigPerson(Document):
height = FloatField()
person = Person()
person.height = 1.89
person.validate()
person.height = '2.0'
self.assertRaises(ValidationError, person.validate)
person.height = 0.01
self.assertRaises(ValidationError, person.validate)
person.height = 4.0
self.assertRaises(ValidationError, person.validate)
person_2 = Person(height='something invalid')
self.assertRaises(ValidationError, person_2.validate)
big_person = BigPerson()
for value, value_type in enumerate(six.integer_types):
big_person.height = value_type(value)
big_person.validate()
big_person.height = 2 ** 500
big_person.validate()
big_person.height = 2 ** 100000 # Too big for a float value
self.assertRaises(ValidationError, big_person.validate)
def test_decimal_validation(self):
"""Ensure that invalid values cannot be assigned to decimal fields.
"""
@ -1184,6 +1208,19 @@ class FieldTest(unittest.TestCase):
simple = simple.reload()
self.assertEqual(simple.widgets, [4])
def test_list_field_with_negative_indices(self):
class Simple(Document):
widgets = ListField()
simple = Simple(widgets=[1, 2, 3, 4]).save()
simple.widgets[-1] = 5
self.assertEqual(['widgets.3'], simple._changed_fields)
simple.save()
simple = simple.reload()
self.assertEqual(simple.widgets, [1, 2, 3, 5])
def test_list_field_complex(self):
"""Ensure that the list fields can handle the complex types."""
@ -1563,6 +1600,29 @@ class FieldTest(unittest.TestCase):
actions__friends__operation='drink',
actions__friends__object='beer').count())
def test_map_field_unicode(self):
class Info(EmbeddedDocument):
description = StringField()
value_list = ListField(field=StringField())
class BlogPost(Document):
info_dict = MapField(field=EmbeddedDocumentField(Info))
BlogPost.drop_collection()
tree = BlogPost(info_dict={
u"éééé": {
'description': u"VALUE: éééé"
}
})
tree.save()
self.assertEqual(BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, u"VALUE: éééé")
BlogPost.drop_collection()
def test_embedded_db_field(self):
class Embedded(EmbeddedDocument):
@ -1599,6 +1659,8 @@ class FieldTest(unittest.TestCase):
name = StringField()
preferences = EmbeddedDocumentField(PersonPreferences)
Person.drop_collection()
person = Person(name='Test User')
person.preferences = 'My Preferences'
self.assertRaises(ValidationError, person.validate)
@ -1631,12 +1693,39 @@ class FieldTest(unittest.TestCase):
content = StringField()
author = EmbeddedDocumentField(User)
BlogPost.drop_collection()
post = BlogPost(content='What I did today...')
post.author = PowerUser(name='Test User', power=47)
post.save()
self.assertEqual(47, BlogPost.objects.first().author.power)
def test_embedded_document_inheritance_with_list(self):
"""Ensure that nested list of subclassed embedded documents is
handled correctly.
"""
class Group(EmbeddedDocument):
name = StringField()
content = ListField(StringField())
class Basedoc(Document):
groups = ListField(EmbeddedDocumentField(Group))
meta = {'abstract': True}
class User(Basedoc):
doctype = StringField(require=True, default='userdata')
User.drop_collection()
content = ['la', 'le', 'lu']
group = Group(name='foo', content=content)
foobar = User(groups=[group])
foobar.save()
self.assertEqual(content, User.objects.first().groups[0].content)
def test_reference_validation(self):
"""Ensure that invalid docment objects cannot be assigned to reference
fields.
@ -2329,6 +2418,91 @@ class FieldTest(unittest.TestCase):
Member.drop_collection()
BlogPost.drop_collection()
def test_drop_abstract_document(self):
"""Ensure that an abstract document cannot be dropped given it
has no underlying collection.
"""
class AbstractDoc(Document):
name = StringField()
meta = {"abstract": True}
self.assertRaises(OperationError, AbstractDoc.drop_collection)
def test_reference_class_with_abstract_parent(self):
"""Ensure that a class with an abstract parent can be referenced.
"""
class Sibling(Document):
name = StringField()
meta = {"abstract": True}
class Sister(Sibling):
pass
class Brother(Sibling):
sibling = ReferenceField(Sibling)
Sister.drop_collection()
Brother.drop_collection()
sister = Sister(name="Alice")
sister.save()
brother = Brother(name="Bob", sibling=sister)
brother.save()
self.assertEquals(Brother.objects[0].sibling.name, sister.name)
Sister.drop_collection()
Brother.drop_collection()
def test_reference_abstract_class(self):
"""Ensure that an abstract class instance cannot be used in the
reference of that abstract class.
"""
class Sibling(Document):
name = StringField()
meta = {"abstract": True}
class Sister(Sibling):
pass
class Brother(Sibling):
sibling = ReferenceField(Sibling)
Sister.drop_collection()
Brother.drop_collection()
sister = Sibling(name="Alice")
brother = Brother(name="Bob", sibling=sister)
self.assertRaises(ValidationError, brother.save)
Sister.drop_collection()
Brother.drop_collection()
def test_abstract_reference_base_type(self):
"""Ensure that an an abstract reference fails validation when given a
Document that does not inherit from the abstract type.
"""
class Sibling(Document):
name = StringField()
meta = {"abstract": True}
class Brother(Sibling):
sibling = ReferenceField(Sibling)
class Mother(Document):
name = StringField()
Brother.drop_collection()
Mother.drop_collection()
mother = Mother(name="Carol")
mother.save()
brother = Brother(name="Bob", sibling=mother)
self.assertRaises(ValidationError, brother.save)
Brother.drop_collection()
Mother.drop_collection()
def test_generic_reference(self):
"""Ensure that a GenericReferenceField properly dereferences items.
"""
@ -3353,7 +3527,7 @@ class FieldTest(unittest.TestCase):
def __init__(self, **kwargs):
super(EnumField, self).__init__(**kwargs)
def to_mongo(self, value):
def to_mongo(self, value, **kwargs):
return value
def to_python(self, value):
@ -3520,6 +3694,19 @@ class FieldTest(unittest.TestCase):
self.assertRaises(FieldDoesNotExist, test)
def test_long_field_is_considered_as_int64(self):
"""
Tests that long fields are stored as long in mongo, even if long value
is small enough to be an int.
"""
class TestLongFieldConsideredAsInt64(Document):
some_long = LongField()
doc = TestLongFieldConsideredAsInt64(some_long=42).save()
db = get_db()
self.assertTrue(isinstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64))
self.assertTrue(isinstance(doc.some_long, six.integer_types))
class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
@ -3907,6 +4094,17 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
# modified
self.assertEqual(number, 2)
def test_unicode(self):
"""
Tests that unicode strings handled correctly
"""
post = self.BlogPost(comments=[
self.Comments(author='user1', message=u'сообщение'),
self.Comments(author='user2', message=u'хабарлама')
]).save()
self.assertEqual(post.comments.get(message=u'сообщение').author,
'user1')
def test_save(self):
"""
Tests the save method of a List of Embedded Documents.

View File

@ -26,7 +26,7 @@ class NewDocumentPickleTest(Document):
new_field = StringField()
class PickleDyanmicEmbedded(DynamicEmbeddedDocument):
class PickleDynamicEmbedded(DynamicEmbeddedDocument):
date = DateTimeField(default=datetime.now)

View File

@ -1,8 +1,11 @@
import unittest
from convert_to_new_inheritance_model import *
from decimalfield_as_float import *
from refrencefield_dbref_to_object_id import *
from referencefield_dbref_to_object_id import *
from turn_off_inheritance import *
from uuidfield_to_binary import *
if __name__ == '__main__':
unittest.main()

View File

@ -680,12 +680,21 @@ class QuerySetTest(unittest.TestCase):
def test_upsert_one(self):
self.Person.drop_collection()
self.Person.objects(name="Bob", age=30).update_one(upsert=True)
bob = self.Person.objects(name="Bob", age=30).upsert_one()
bob = self.Person.objects.first()
self.assertEqual("Bob", bob.name)
self.assertEqual(30, bob.age)
bob.name = "Bobby"
bob.save()
bobby = self.Person.objects(name="Bobby", age=30).upsert_one()
self.assertEqual("Bobby", bobby.name)
self.assertEqual(30, bobby.age)
self.assertEqual(bob.id, bobby.id)
def test_set_on_insert(self):
self.Person.drop_collection()
@ -2757,25 +2766,15 @@ class QuerySetTest(unittest.TestCase):
avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0
self.assertAlmostEqual(int(self.Person.objects.average('age')), avg)
self.assertAlmostEqual(
int(self.Person.objects.aggregate_average('age')), avg
)
self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.average('age')), avg)
self.assertEqual(
int(self.Person.objects.aggregate_average('age')), avg
)
# dot notation
self.Person(
name='person meta', person_meta=self.PersonMeta(weight=0)).save()
self.assertAlmostEqual(
int(self.Person.objects.average('person_meta.weight')), 0)
self.assertAlmostEqual(
int(self.Person.objects.aggregate_average('person_meta.weight')),
0
)
for i, weight in enumerate(ages):
self.Person(
@ -2784,19 +2783,11 @@ class QuerySetTest(unittest.TestCase):
self.assertAlmostEqual(
int(self.Person.objects.average('person_meta.weight')), avg
)
self.assertAlmostEqual(
int(self.Person.objects.aggregate_average('person_meta.weight')),
avg
)
self.Person(name='test meta none').save()
self.assertEqual(
int(self.Person.objects.average('person_meta.weight')), avg
)
self.assertEqual(
int(self.Person.objects.aggregate_average('person_meta.weight')),
avg
)
# test summing over a filtered queryset
over_50 = [a for a in ages if a >= 50]
@ -2805,10 +2796,6 @@ class QuerySetTest(unittest.TestCase):
self.Person.objects.filter(age__gte=50).average('age'),
avg
)
self.assertEqual(
self.Person.objects.filter(age__gte=50).aggregate_average('age'),
avg
)
def test_sum(self):
"""Ensure that field can be summed over correctly.
@ -2818,15 +2805,9 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='test%s' % i, age=age).save()
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
self.assertEqual(
self.Person.objects.aggregate_sum('age'), sum(ages)
)
self.Person(name='ageless person').save()
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
self.assertEqual(
self.Person.objects.aggregate_sum('age'), sum(ages)
)
for i, age in enumerate(ages):
self.Person(name='test meta%s' %
@ -2835,26 +2816,15 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(
self.Person.objects.sum('person_meta.weight'), sum(ages)
)
self.assertEqual(
self.Person.objects.aggregate_sum('person_meta.weight'),
sum(ages)
)
self.Person(name='weightless person').save()
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
self.assertEqual(
self.Person.objects.aggregate_sum('age'), sum(ages)
)
# test summing over a filtered queryset
self.assertEqual(
self.Person.objects.filter(age__gte=50).sum('age'),
sum([a for a in ages if a >= 50])
)
self.assertEqual(
self.Person.objects.filter(age__gte=50).aggregate_sum('age'),
sum([a for a in ages if a >= 50])
)
def test_embedded_average(self):
class Pay(EmbeddedDocument):
@ -2867,21 +2837,12 @@ class QuerySetTest(unittest.TestCase):
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(value=150)).save()
Doc(name='Wilson Junior', pay=Pay(value=150)).save()
Doc(name='Isabella Luanna', pay=Pay(value=530)).save()
Doc(name='Tayza mariana', pay=Pay(value=165)).save()
Doc(name='Eliana Costa', pay=Pay(value=115)).save()
Doc(name=u"Isabella Luanna",
pay=Pay(value=530)).save()
Doc(name=u"Tayza mariana",
pay=Pay(value=165)).save()
Doc(name=u"Eliana Costa",
pay=Pay(value=115)).save()
self.assertEqual(
Doc.objects.average('pay.value'),
240)
self.assertEqual(Doc.objects.average('pay.value'), 240)
def test_embedded_array_average(self):
class Pay(EmbeddedDocument):
@ -2889,26 +2850,16 @@ class QuerySetTest(unittest.TestCase):
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
pay = EmbeddedDocumentField(Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(values=[150, 100])).save()
Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save()
Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save()
Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save()
Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save()
Doc(name=u"Isabella Luanna",
pay=Pay(values=[530, 100])).save()
Doc(name=u"Tayza mariana",
pay=Pay(values=[165, 100])).save()
Doc(name=u"Eliana Costa",
pay=Pay(values=[115, 100])).save()
self.assertEqual(
Doc.objects.average('pay.values'),
170)
self.assertEqual(Doc.objects.average('pay.values'), 170)
def test_array_average(self):
class Doc(Document):
@ -2921,9 +2872,7 @@ class QuerySetTest(unittest.TestCase):
Doc(values=[165, 100]).save()
Doc(values=[115, 100]).save()
self.assertEqual(
Doc.objects.average('values'),
170)
self.assertEqual(Doc.objects.average('values'), 170)
def test_embedded_sum(self):
class Pay(EmbeddedDocument):
@ -2931,26 +2880,16 @@ class QuerySetTest(unittest.TestCase):
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
pay = EmbeddedDocumentField(Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(value=150)).save()
Doc(name='Wilson Junior', pay=Pay(value=150)).save()
Doc(name='Isabella Luanna', pay=Pay(value=530)).save()
Doc(name='Tayza mariana', pay=Pay(value=165)).save()
Doc(name='Eliana Costa', pay=Pay(value=115)).save()
Doc(name=u"Isabella Luanna",
pay=Pay(value=530)).save()
Doc(name=u"Tayza mariana",
pay=Pay(value=165)).save()
Doc(name=u"Eliana Costa",
pay=Pay(value=115)).save()
self.assertEqual(
Doc.objects.sum('pay.value'),
960)
self.assertEqual(Doc.objects.sum('pay.value'), 960)
def test_embedded_array_sum(self):
class Pay(EmbeddedDocument):
@ -2958,26 +2897,16 @@ class QuerySetTest(unittest.TestCase):
class Doc(Document):
name = StringField()
pay = EmbeddedDocumentField(
Pay)
pay = EmbeddedDocumentField(Pay)
Doc.drop_collection()
Doc(name=u"Wilson Junior",
pay=Pay(values=[150, 100])).save()
Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save()
Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save()
Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save()
Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save()
Doc(name=u"Isabella Luanna",
pay=Pay(values=[530, 100])).save()
Doc(name=u"Tayza mariana",
pay=Pay(values=[165, 100])).save()
Doc(name=u"Eliana Costa",
pay=Pay(values=[115, 100])).save()
self.assertEqual(
Doc.objects.sum('pay.values'),
1360)
self.assertEqual(Doc.objects.sum('pay.values'), 1360)
def test_array_sum(self):
class Doc(Document):
@ -2990,9 +2919,7 @@ class QuerySetTest(unittest.TestCase):
Doc(values=[165, 100]).save()
Doc(values=[115, 100]).save()
self.assertEqual(
Doc.objects.sum('values'),
1360)
self.assertEqual(Doc.objects.sum('values'), 1360)
def test_distinct(self):
"""Ensure that the QuerySet.distinct method works.
@ -3604,6 +3531,15 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(MyDoc.objects.count(), 10)
self.assertEqual(MyDoc.objects.none().count(), 0)
def test_count_list_embedded(self):
class B(EmbeddedDocument):
c = StringField()
class A(Document):
b = ListField(EmbeddedDocumentField(B))
self.assertEqual(A.objects(b=[{'c': 'c'}]).count(), 0)
def test_call_after_limits_set(self):
"""Ensure that re-filtering after slicing works
"""
@ -4105,6 +4041,10 @@ class QuerySetTest(unittest.TestCase):
Foo(shape="circle", color="purple", thick=False)])
b2.save()
b3 = Bar(foo=[Foo(shape="square", thick=True),
Foo(shape="circle", color="purple", thick=False)])
b3.save()
ak = list(
Bar.objects(foo__match={'shape': "square", "color": "purple"}))
self.assertEqual([b1], ak)
@ -4116,6 +4056,22 @@ class QuerySetTest(unittest.TestCase):
ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple")))
self.assertEqual([b1], ak)
ak = list(
Bar.objects(foo__elemMatch={'shape': "square", "color__exists": True}))
self.assertEqual([b1, b2], ak)
ak = list(
Bar.objects(foo__match={'shape': "square", "color__exists": True}))
self.assertEqual([b1, b2], ak)
ak = list(
Bar.objects(foo__elemMatch={'shape': "square", "color__exists": False}))
self.assertEqual([b3], ak)
ak = list(
Bar.objects(foo__match={'shape': "square", "color__exists": False}))
self.assertEqual([b3], ak)
def test_upsert_includes_cls(self):
"""Upserts should include _cls information for inheritable classes
"""
@ -4156,7 +4112,11 @@ class QuerySetTest(unittest.TestCase):
def test_read_preference(self):
class Bar(Document):
pass
txt = StringField()
meta = {
'indexes': [ 'txt' ]
}
Bar.drop_collection()
bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY))
@ -4168,9 +4128,51 @@ class QuerySetTest(unittest.TestCase):
error_class = TypeError
self.assertRaises(error_class, Bar.objects, read_preference='Primary')
# read_preference as a kwarg
bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(bars._cursor._Cursor__read_preference,
ReadPreference.SECONDARY_PREFERRED)
# read_preference as a query set method
bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(bars._cursor._Cursor__read_preference,
ReadPreference.SECONDARY_PREFERRED)
# read_preference after skip
bars = Bar.objects.skip(1) \
.read_preference(ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(bars._cursor._Cursor__read_preference,
ReadPreference.SECONDARY_PREFERRED)
# read_preference after limit
bars = Bar.objects.limit(1) \
.read_preference(ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(bars._cursor._Cursor__read_preference,
ReadPreference.SECONDARY_PREFERRED)
# read_preference after order_by
bars = Bar.objects.order_by('txt') \
.read_preference(ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(bars._cursor._Cursor__read_preference,
ReadPreference.SECONDARY_PREFERRED)
# read_preference after hint
bars = Bar.objects.hint([('txt', 1)]) \
.read_preference(ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
self.assertEqual(bars._cursor._Cursor__read_preference,
ReadPreference.SECONDARY_PREFERRED)
def test_json_simple(self):
@ -4824,5 +4826,6 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(1, Doc.objects(item__type__="axe").count())
if __name__ == '__main__':
unittest.main()

View File

@ -224,6 +224,10 @@ class TransformTest(unittest.TestCase):
self.assertEqual(1, Doc.objects(item__type__="axe").count())
self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count())
Doc.objects(id=doc.id).update(set__item__type__='sword')
self.assertEqual(1, Doc.objects(item__type__="sword").count())
self.assertEqual(0, Doc.objects(item__type__="axe").count())
def test_understandable_error_raised(self):
class Event(Document):
title = StringField()

View File

@ -8,6 +8,7 @@ try:
import unittest2 as unittest
except ImportError:
import unittest
from nose.plugins.skip import SkipTest
import pymongo
from bson.tz_util import utc
@ -51,6 +52,42 @@ class ConnectionTest(unittest.TestCase):
conn = get_connection('testdb')
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
def test_connect_in_mocking(self):
"""Ensure that the connect() method works properly in mocking.
"""
try:
import mongomock
except ImportError:
raise SkipTest('you need mongomock installed to run this testcase')
connect('mongoenginetest', host='mongomock://localhost')
conn = get_connection()
self.assertTrue(isinstance(conn, mongomock.MongoClient))
connect('mongoenginetest2', host='mongomock://localhost', alias='testdb2')
conn = get_connection('testdb2')
self.assertTrue(isinstance(conn, mongomock.MongoClient))
connect('mongoenginetest3', host='mongodb://localhost', is_mock=True, alias='testdb3')
conn = get_connection('testdb3')
self.assertTrue(isinstance(conn, mongomock.MongoClient))
connect('mongoenginetest4', is_mock=True, alias='testdb4')
conn = get_connection('testdb4')
self.assertTrue(isinstance(conn, mongomock.MongoClient))
connect(host='mongodb://localhost:27017/mongoenginetest5', is_mock=True, alias='testdb5')
conn = get_connection('testdb5')
self.assertTrue(isinstance(conn, mongomock.MongoClient))
connect(host='mongomock://localhost:27017/mongoenginetest6', alias='testdb6')
conn = get_connection('testdb6')
self.assertTrue(isinstance(conn, mongomock.MongoClient))
connect(host='mongomock://localhost:27017/mongoenginetest7', is_mock=True, alias='testdb7')
conn = get_connection('testdb7')
self.assertTrue(isinstance(conn, mongomock.MongoClient))
def test_disconnect(self):
"""Ensure that the disconnect() method works properly
"""
@ -151,7 +188,7 @@ class ConnectionTest(unittest.TestCase):
self.assertRaises(ConnectionError, get_db, 'test1')
# Authentication succeeds with "authSource"
test_conn2 = connect(
connect(
'mongoenginetest', alias='test2',
host=('mongodb://username2:password@localhost/'
'mongoenginetest?authSource=admin')

View File

@ -12,9 +12,13 @@ from mongoengine.context_managers import query_counter
class FieldTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
@classmethod
def setUpClass(cls):
cls.db = connect(db='mongoenginetest')
@classmethod
def tearDownClass(cls):
cls.db.drop_database('mongoenginetest')
def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced.
@ -304,6 +308,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection()
Post.drop_collection()
SimpleList.drop_collection()
u1 = User.objects.create(name='u1')
u2 = User.objects.create(name='u2')

View File

@ -25,6 +25,8 @@ class SignalTests(unittest.TestCase):
connect(db='mongoenginetest')
class Author(Document):
# Make the id deterministic for easier testing
id = SequenceField(primary_key=True)
name = StringField()
def __unicode__(self):
@ -33,7 +35,7 @@ class SignalTests(unittest.TestCase):
@classmethod
def pre_init(cls, sender, document, *args, **kwargs):
signal_output.append('pre_init signal, %s' % cls.__name__)
signal_output.append(str(kwargs['values']))
signal_output.append(kwargs['values'])
@classmethod
def post_init(cls, sender, document, **kwargs):
@ -43,48 +45,55 @@ class SignalTests(unittest.TestCase):
@classmethod
def pre_save(cls, sender, document, **kwargs):
signal_output.append('pre_save signal, %s' % document)
signal_output.append(kwargs)
@classmethod
def pre_save_post_validation(cls, sender, document, **kwargs):
signal_output.append('pre_save_post_validation signal, %s' % document)
if 'created' in kwargs:
if kwargs['created']:
signal_output.append('Is created')
else:
signal_output.append('Is updated')
if kwargs.pop('created', False):
signal_output.append('Is created')
else:
signal_output.append('Is updated')
signal_output.append(kwargs)
@classmethod
def post_save(cls, sender, document, **kwargs):
dirty_keys = document._delta()[0].keys() + document._delta()[1].keys()
signal_output.append('post_save signal, %s' % document)
signal_output.append('post_save dirty keys, %s' % dirty_keys)
if 'created' in kwargs:
if kwargs['created']:
signal_output.append('Is created')
else:
signal_output.append('Is updated')
if kwargs.pop('created', False):
signal_output.append('Is created')
else:
signal_output.append('Is updated')
signal_output.append(kwargs)
@classmethod
def pre_delete(cls, sender, document, **kwargs):
signal_output.append('pre_delete signal, %s' % document)
signal_output.append(kwargs)
@classmethod
def post_delete(cls, sender, document, **kwargs):
signal_output.append('post_delete signal, %s' % document)
signal_output.append(kwargs)
@classmethod
def pre_bulk_insert(cls, sender, documents, **kwargs):
signal_output.append('pre_bulk_insert signal, %s' % documents)
signal_output.append(kwargs)
@classmethod
def post_bulk_insert(cls, sender, documents, **kwargs):
signal_output.append('post_bulk_insert signal, %s' % documents)
if kwargs.get('loaded', False):
if kwargs.pop('loaded', False):
signal_output.append('Is loaded')
else:
signal_output.append('Not loaded')
signal_output.append(kwargs)
self.Author = Author
Author.drop_collection()
Author.id.set_next_value(0)
class Another(Document):
@ -96,10 +105,12 @@ class SignalTests(unittest.TestCase):
@classmethod
def pre_delete(cls, sender, document, **kwargs):
signal_output.append('pre_delete signal, %s' % document)
signal_output.append(kwargs)
@classmethod
def post_delete(cls, sender, document, **kwargs):
signal_output.append('post_delete signal, %s' % document)
signal_output.append(kwargs)
self.Another = Another
Another.drop_collection()
@ -118,6 +129,41 @@ class SignalTests(unittest.TestCase):
self.ExplicitId = ExplicitId
ExplicitId.drop_collection()
class Post(Document):
title = StringField()
content = StringField()
active = BooleanField(default=False)
def __unicode__(self):
return self.title
@classmethod
def pre_bulk_insert(cls, sender, documents, **kwargs):
signal_output.append('pre_bulk_insert signal, %s' %
[(doc, {'active': documents[n].active})
for n, doc in enumerate(documents)])
# make changes here, this is just an example -
# it could be anything that needs pre-validation or looks-ups before bulk bulk inserting
for document in documents:
if not document.active:
document.active = True
signal_output.append(kwargs)
@classmethod
def post_bulk_insert(cls, sender, documents, **kwargs):
signal_output.append('post_bulk_insert signal, %s' %
[(doc, {'active': documents[n].active})
for n, doc in enumerate(documents)])
if kwargs.pop('loaded', False):
signal_output.append('Is loaded')
else:
signal_output.append('Not loaded')
signal_output.append(kwargs)
self.Post = Post
Post.drop_collection()
# Save up the number of connected signals so that we can check at the
# end that all the signals we register get properly unregistered
self.pre_signals = (
@ -147,6 +193,9 @@ class SignalTests(unittest.TestCase):
signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId)
signals.pre_bulk_insert.connect(Post.pre_bulk_insert, sender=Post)
signals.post_bulk_insert.connect(Post.post_bulk_insert, sender=Post)
def tearDown(self):
signals.pre_init.disconnect(self.Author.pre_init)
signals.post_init.disconnect(self.Author.post_init)
@ -163,6 +212,9 @@ class SignalTests(unittest.TestCase):
signals.post_save.disconnect(self.ExplicitId.post_save)
signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert)
signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert)
# Check that all our signals got disconnected properly.
post_signals = (
len(signals.pre_init.receivers),
@ -199,66 +251,121 @@ class SignalTests(unittest.TestCase):
a.save()
self.get_signal_output(lambda: None) # eliminate signal output
a1 = self.Author.objects(name='Bill Shakespeare')[0]
self.assertEqual(self.get_signal_output(create_author), [
"pre_init signal, Author",
"{'name': 'Bill Shakespeare'}",
{'name': 'Bill Shakespeare'},
"post_init signal, Bill Shakespeare, document._created = True",
])
a1 = self.Author(name='Bill Shakespeare')
self.assertEqual(self.get_signal_output(a1.save), [
"pre_save signal, Bill Shakespeare",
{},
"pre_save_post_validation signal, Bill Shakespeare",
"Is created",
{},
"post_save signal, Bill Shakespeare",
"post_save dirty keys, ['name']",
"Is created"
"Is created",
{}
])
a1.reload()
a1.name = 'William Shakespeare'
self.assertEqual(self.get_signal_output(a1.save), [
"pre_save signal, William Shakespeare",
{},
"pre_save_post_validation signal, William Shakespeare",
"Is updated",
{},
"post_save signal, William Shakespeare",
"post_save dirty keys, ['name']",
"Is updated"
"Is updated",
{}
])
self.assertEqual(self.get_signal_output(a1.delete), [
'pre_delete signal, William Shakespeare',
{},
'post_delete signal, William Shakespeare',
{}
])
signal_output = self.get_signal_output(load_existing_author)
# test signal_output lines separately, because of random ObjectID after object load
self.assertEqual(signal_output[0],
self.assertEqual(self.get_signal_output(load_existing_author), [
"pre_init signal, Author",
)
self.assertEqual(signal_output[2],
"post_init signal, Bill Shakespeare, document._created = False",
)
{'id': 2, 'name': 'Bill Shakespeare'},
"post_init signal, Bill Shakespeare, document._created = False"
])
signal_output = self.get_signal_output(bulk_create_author_with_load)
# The output of this signal is not entirely deterministic. The reloaded
# object will have an object ID. Hence, we only check part of the output
self.assertEqual(signal_output[3], "pre_bulk_insert signal, [<Author: Bill Shakespeare>]"
)
self.assertEqual(signal_output[-2:],
["post_bulk_insert signal, [<Author: Bill Shakespeare>]",
"Is loaded",])
self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [
'pre_init signal, Author',
{'name': 'Bill Shakespeare'},
'post_init signal, Bill Shakespeare, document._created = True',
'pre_bulk_insert signal, [<Author: Bill Shakespeare>]',
{},
'pre_init signal, Author',
{'id': 3, 'name': 'Bill Shakespeare'},
'post_init signal, Bill Shakespeare, document._created = False',
'post_bulk_insert signal, [<Author: Bill Shakespeare>]',
'Is loaded',
{}
])
self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [
"pre_init signal, Author",
"{'name': 'Bill Shakespeare'}",
{'name': 'Bill Shakespeare'},
"post_init signal, Bill Shakespeare, document._created = True",
"pre_bulk_insert signal, [<Author: Bill Shakespeare>]",
{},
"post_bulk_insert signal, [<Author: Bill Shakespeare>]",
"Not loaded",
{}
])
def test_signal_kwargs(self):
""" Make sure signal_kwargs is passed to signals calls. """
def live_and_let_die():
a = self.Author(name='Bill Shakespeare')
a.save(signal_kwargs={'live': True, 'die': False})
a.delete(signal_kwargs={'live': False, 'die': True})
self.assertEqual(self.get_signal_output(live_and_let_die), [
"pre_init signal, Author",
{'name': 'Bill Shakespeare'},
"post_init signal, Bill Shakespeare, document._created = True",
"pre_save signal, Bill Shakespeare",
{'die': False, 'live': True},
"pre_save_post_validation signal, Bill Shakespeare",
"Is created",
{'die': False, 'live': True},
"post_save signal, Bill Shakespeare",
"post_save dirty keys, ['name']",
"Is created",
{'die': False, 'live': True},
'pre_delete signal, Bill Shakespeare',
{'die': True, 'live': False},
'post_delete signal, Bill Shakespeare',
{'die': True, 'live': False}
])
def bulk_create_author():
a1 = self.Author(name='Bill Shakespeare')
self.Author.objects.insert([a1], signal_kwargs={'key': True})
self.assertEqual(self.get_signal_output(bulk_create_author), [
'pre_init signal, Author',
{'name': 'Bill Shakespeare'},
'post_init signal, Bill Shakespeare, document._created = True',
'pre_bulk_insert signal, [<Author: Bill Shakespeare>]',
{'key': True},
'pre_init signal, Author',
{'id': 2, 'name': 'Bill Shakespeare'},
'post_init signal, Bill Shakespeare, document._created = False',
'post_bulk_insert signal, [<Author: Bill Shakespeare>]',
'Is loaded',
{'key': True}
])
def test_queryset_delete_signals(self):
@ -267,7 +374,9 @@ class SignalTests(unittest.TestCase):
self.Another(name='Bill Shakespeare').save()
self.assertEqual(self.get_signal_output(self.Another.objects.delete), [
'pre_delete signal, Bill Shakespeare',
{},
'post_delete signal, Bill Shakespeare',
{}
])
def test_signals_with_explicit_doc_ids(self):
@ -306,6 +415,23 @@ class SignalTests(unittest.TestCase):
ei.switch_db("testdb-1", keep_created=False)
self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
def test_signals_bulk_insert(self):
def bulk_set_active_post():
posts = [
self.Post(title='Post 1'),
self.Post(title='Post 2'),
self.Post(title='Post 3')
]
self.Post.objects.insert(posts)
results = self.get_signal_output(bulk_set_active_post)
self.assertEqual(results, [
"pre_bulk_insert signal, [(<Post: Post 1>, {'active': False}), (<Post: Post 2>, {'active': False}), (<Post: Post 3>, {'active': False})]",
{},
"post_bulk_insert signal, [(<Post: Post 1>, {'active': True}), (<Post: Post 2>, {'active': True}), (<Post: Post 3>, {'active': True})]",
'Is loaded',
{}
])
if __name__ == '__main__':
unittest.main()

View File

@ -1,5 +1,5 @@
[tox]
envlist = {py26,py27,py32,py33,py34,pypy,pypy3}-{mg27,mg28}
envlist = {py26,py27,py32,py33,py34,py35,pypy,pypy3}-{mg27,mg28}
#envlist = {py26,py27,py32,py33,py34,pypy,pypy3}-{mg27,mg28,mg30,mgdev}
[testenv]
@ -12,3 +12,6 @@ deps =
mg28: PyMongo>=2.8,<3.0
mg30: PyMongo>=3.0
mgdev: https://github.com/mongodb/mongo-python-driver/tarball/master
setenv =
PYTHON_EGG_CACHE = {envdir}/python-eggs
passenv = windir