Merge remote-tracking branch 'MongoEngine/master'

This commit is contained in:
Breeze.kay 2015-06-19 11:14:51 +08:00
commit 0bbbbdde80
19 changed files with 765 additions and 184 deletions

View File

@ -223,3 +223,4 @@ that much better:
* Kiryl Yermakou (https://github.com/rma4ok)
* Matthieu Rigal (https://github.com/MRigal)
* Charanpal Dhanjal (https://github.com/charanpald)
* Emmanuel Leblond (https://github.com/touilleMan)

View File

@ -19,6 +19,13 @@ Changes in 0.9.X - DEV
- Added __ support to escape field name in fields lookup keywords that match operators names #949
- Support for PyMongo 3+ #946
- Fix for issue where FileField deletion did not free space in GridFS.
- No_dereference() not respected on embedded docs containing reference. #517
- Document save raise an exception if save_condition fails #1005
- Fixes some internal _id handling issue. #961
- Updated URL and Email Field regex validators, added schemes argument to URLField validation. #652
- Removed get_or_create() deprecated since 0.8.0. #300
- Capped collection multiple of 256. #1011
- Added `BaseQuerySet.aggregate_sum` and `BaseQuerySet.aggregate_average` methods.
Changes in 0.9.0
================

View File

@ -315,12 +315,12 @@ reference with a delete rule specification. A delete rule is specified by
supplying the :attr:`reverse_delete_rule` attributes on the
:class:`ReferenceField` definition, like this::
class Employee(Document):
class ProfilePage(Document):
...
profile_page = ReferenceField('ProfilePage', reverse_delete_rule=mongoengine.NULLIFY)
employee = ReferenceField('Employee', reverse_delete_rule=mongoengine.CASCADE)
The declaration in this example means that when an :class:`Employee` object is
removed, the :class:`ProfilePage` that belongs to that employee is removed as
removed, the :class:`ProfilePage` that references that employee is removed as
well. If a whole batch of employees is removed, all profile pages that are
linked are removed as well.
@ -447,8 +447,10 @@ A :class:`~mongoengine.Document` may use a **Capped Collection** by specifying
:attr:`max_documents` and :attr:`max_size` in the :attr:`meta` dictionary.
:attr:`max_documents` is the maximum number of documents that is allowed to be
stored in the collection, and :attr:`max_size` is the maximum size of the
collection in bytes. If :attr:`max_size` is not specified and
:attr:`max_documents` is, :attr:`max_size` defaults to 10000000 bytes (10MB).
collection in bytes. :attr:`max_size` is rounded up to the next multiple of 256
by MongoDB internally and mongoengine before. Use also a multiple of 256 to
avoid confusions. If :attr:`max_size` is not specified and
:attr:`max_documents` is, :attr:`max_size` defaults to 10485760 bytes (10MB).
The following example shows a :class:`Log` document that will be limited to
1000 entries and 2MB of disk space::
@ -465,19 +467,26 @@ You can specify indexes on collections to make querying faster. This is done
by creating a list of index specifications called :attr:`indexes` in the
:attr:`~mongoengine.Document.meta` dictionary, where an index specification may
either be a single field name, a tuple containing multiple field names, or a
dictionary containing a full index definition. A direction may be specified on
fields by prefixing the field name with a **+** (for ascending) or a **-** sign
(for descending). Note that direction only matters on multi-field indexes.
Text indexes may be specified by prefixing the field name with a **$**. ::
dictionary containing a full index definition.
A direction may be specified on fields by prefixing the field name with a
**+** (for ascending) or a **-** sign (for descending). Note that direction
only matters on multi-field indexes. Text indexes may be specified by prefixing
the field name with a **$**. Hashed indexes may be specified by prefixing
the field name with a **#**::
class Page(Document):
category = IntField()
title = StringField()
rating = StringField()
created = DateTimeField()
meta = {
'indexes': [
'title',
'$title', # text index
'#title', # hashed index
('title', '-rating'),
('category', '_cls'),
{
'fields': ['created'],
'expireAfterSeconds': 3600
@ -532,11 +541,14 @@ There are a few top level defaults for all indexes that can be set::
:attr:`index_background` (Optional)
Set the default value for if an index should be indexed in the background
:attr:`index_cls` (Optional)
A way to turn off a specific index for _cls.
:attr:`index_drop_dups` (Optional)
Set the default value for if an index should drop duplicates
:attr:`index_cls` (Optional)
A way to turn off a specific index for _cls.
.. note:: Since MongoDB 3.0 drop_dups is not supported anymore. Raises a Warning
and has no effect
Compound Indexes and Indexing sub documents

View File

@ -263,21 +263,11 @@ no document matches the query, and
if more than one document matched the query. These exceptions are merged into
your document definitions eg: `MyDoc.DoesNotExist`
A variation of this method exists,
:meth:`~mongoengine.queryset.QuerySet.get_or_create`, that will create a new
document with the query arguments if no documents match the query. An
additional keyword argument, :attr:`defaults` may be provided, which will be
used as default values for the new document, in the case that it should need
to be created::
>>> a, created = User.objects.get_or_create(name='User A', defaults={'age': 30})
>>> b, created = User.objects.get_or_create(name='User A', defaults={'age': 40})
>>> a.name == b.name and a.age == b.age
True
.. warning::
:meth:`~mongoengine.queryset.QuerySet.get_or_create` method is deprecated
since :mod:`mongoengine` 0.8.
A variation of this method, get_or_create() existed, but it was unsafe. It
could not be made safe, because there are no transactions in mongoDB. Other
approaches should be investigated, to ensure you don't accidentally duplicate
data when using something similar to this method. Therefore it was deprecated
in 0.8 and removed in 0.10.
Default Document queries
========================

View File

@ -184,7 +184,7 @@ class BaseDocument(object):
self__initialised = False
# Check if the user has created a new instance of a class
if (self._is_document and self__initialised
and self__created and name == self._meta['id_field']):
and self__created and name == self._meta.get('id_field')):
super(BaseDocument, self).__setattr__('_created', False)
super(BaseDocument, self).__setattr__(name, value)
@ -672,7 +672,7 @@ class BaseDocument(object):
@classmethod
def _get_collection_name(cls):
"""Returns the collection name for this class.
"""Returns the collection name for this class. None for abstract class
"""
return cls._meta.get('collection', None)
@ -782,7 +782,7 @@ class BaseDocument(object):
allow_inheritance = cls._meta.get('allow_inheritance',
ALLOW_INHERITANCE)
include_cls = (allow_inheritance and not spec.get('sparse', False) and
spec.get('cls', True))
spec.get('cls', True) and '_cls' not in spec['fields'])
# 733: don't include cls if index_cls is False unless there is an explicit cls with the index
include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True))
@ -795,16 +795,25 @@ class BaseDocument(object):
# ASCENDING from +
# DESCENDING from -
# GEO2D from *
# TEXT from $
# HASHED from #
# GEOSPHERE from (
# GEOHAYSTACK from )
# GEO2D from *
direction = pymongo.ASCENDING
if key.startswith("-"):
direction = pymongo.DESCENDING
elif key.startswith("*"):
direction = pymongo.GEO2D
elif key.startswith("$"):
direction = pymongo.TEXT
if key.startswith(("+", "-", "*", "$")):
elif key.startswith("#"):
direction = pymongo.HASHED
elif key.startswith("("):
direction = pymongo.GEOSPHERE
elif key.startswith(")"):
direction = pymongo.GEOHAYSTACK
elif key.startswith("*"):
direction = pymongo.GEO2D
if key.startswith(("+", "-", "*", "$", "#", "(", ")")):
key = key[1:]
# Use real field name, do it manually because we need field
@ -827,7 +836,8 @@ class BaseDocument(object):
index_list.append((key, direction))
# Don't add cls to a geo index
if include_cls and direction is not pymongo.GEO2D:
if include_cls and direction not in (
pymongo.GEO2D, pymongo.GEOHAYSTACK, pymongo.GEOSPHERE):
index_list.insert(0, ('_cls', 1))
if index_list:
@ -973,8 +983,13 @@ class BaseDocument(object):
if hasattr(getattr(field, 'field', None), 'lookup_member'):
new_field = field.field.lookup_member(field_name)
else:
# Look up subfield on the previous field
# Look up subfield on the previous field or raise
try:
new_field = field.lookup_member(field_name)
except AttributeError:
raise LookUpError('Cannot resolve subfield or operator {} '
'on the field {}'.format(
field_name, field.name))
if not new_field and isinstance(field, ComplexBaseField):
if hasattr(field.field, 'document_type') and cls._dynamic \
and field.field.document_type._dynamic:

View File

@ -290,6 +290,7 @@ class ComplexBaseField(BaseField):
return value
if self.field:
self.field._auto_dereference = self._auto_dereference
value_dict = dict([(key, self.field.to_python(item))
for key, item in value.items()])
else:
@ -424,8 +425,11 @@ class ObjectIdField(BaseField):
"""
def to_python(self, value):
try:
if not isinstance(value, ObjectId):
value = ObjectId(value)
except:
pass
return value
def to_mongo(self, value):

View File

@ -385,15 +385,17 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class._auto_id_field = getattr(parent_doc_cls,
'_auto_id_field', False)
if not new_class._meta.get('id_field'):
# After 0.10, find not existing names, instead of overwriting
id_name, id_db_name = cls.get_auto_id_names(new_class)
new_class._auto_id_field = True
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class._fields['id'].name = 'id'
new_class.id = new_class._fields['id']
new_class._meta['id_field'] = id_name
new_class._fields[id_name] = ObjectIdField(db_field=id_db_name)
new_class._fields[id_name].name = id_name
new_class.id = new_class._fields[id_name]
new_class._db_field_map[id_name] = id_db_name
new_class._reverse_db_field_map[id_db_name] = id_name
# Prepend id field to _fields_ordered
if 'id' in new_class._fields and 'id' not in new_class._fields_ordered:
new_class._fields_ordered = ('id', ) + new_class._fields_ordered
new_class._fields_ordered = (id_name, ) + new_class._fields_ordered
# Merge in exceptions with parent hierarchy
exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned)
@ -408,6 +410,19 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
return new_class
def get_auto_id_names(self):
id_name, id_db_name = ('id', '_id')
if id_name not in self._fields and \
id_db_name not in (v.db_field for v in self._fields.values()):
return id_name, id_db_name
id_basename, id_db_basename, i = 'auto_id', '_auto_id', 0
while id_name in self._fields or \
id_db_name in (v.db_field for v in self._fields.values()):
id_name = '{0}_{1}'.format(id_basename, i)
id_db_name = '{0}_{1}'.format(id_db_basename, i)
i += 1
return id_name, id_db_name
class MetaDict(dict):

View File

@ -128,21 +128,25 @@ class DeReference(object):
"""
object_map = {}
for collection, dbrefs in self.reference_map.iteritems():
refs = [dbref for dbref in dbrefs
if unicode(dbref).encode('utf-8') not in object_map]
if hasattr(collection, 'objects'): # We have a document class for the refs
col_name = collection._get_collection_name()
refs = [dbref for dbref in dbrefs
if (col_name, dbref) not in object_map]
references = collection.objects.in_bulk(refs)
for key, doc in references.iteritems():
object_map[key] = doc
object_map[(col_name, key)] = doc
else: # Generic reference: use the refs data to convert to document
if isinstance(doc_type, (ListField, DictField, MapField,)):
continue
refs = [dbref for dbref in dbrefs
if (collection, dbref) not in object_map]
if doc_type:
references = doc_type._get_db()[collection].find({'_id': {'$in': refs}})
for ref in references:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc
object_map[(collection, doc.id)] = doc
else:
references = get_db()[collection].find({'_id': {'$in': refs}})
for ref in references:
@ -154,7 +158,7 @@ class DeReference(object):
for x in collection.split('_')))._from_son(ref)
else:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc
object_map[(collection, doc.id)] = doc
return object_map
def _attach_objects(self, items, depth=0, instance=None, name=None):
@ -180,7 +184,8 @@ class DeReference(object):
if isinstance(items, (dict, SON)):
if '_ref' in items:
return self.object_map.get(items['_ref'].id, items)
return self.object_map.get(
(items['_ref'].collection, items['_ref'].id), items)
elif '_cls' in items:
doc = get_document(items['_cls'])._from_son(items)
_cls = doc._data.pop('_cls', None)
@ -216,9 +221,11 @@ class DeReference(object):
for field_name, field in v._fields.iteritems():
v = data[k]._data.get(field_name, None)
if isinstance(v, (DBRef)):
data[k]._data[field_name] = self.object_map.get(v.id, v)
data[k]._data[field_name] = self.object_map.get(
(v.collection, v.id), v)
elif isinstance(v, (dict, SON)) and '_ref' in v:
data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v)
data[k]._data[field_name] = self.object_map.get(
(v['_ref'].collection , v['_ref'].id), v)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = "{0}.{1}.{2}".format(name, k, field_name)
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
@ -226,7 +233,7 @@ class DeReference(object):
item_name = '%s.%s' % (name, k) if name else name
data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name)
elif hasattr(v, 'id'):
data[k] = self.object_map.get(v.id, v)
data[k] = self.object_map.get((v.collection, v.id), v)
if instance and name:
if is_list:

View File

@ -1,4 +1,4 @@
import warnings
import pymongo
import re
@ -17,6 +17,7 @@ from mongoengine.base import (
get_document
)
from mongoengine.errors import InvalidQueryError, InvalidDocumentError
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.queryset import (OperationError, NotUniqueError,
QuerySet, transform)
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
@ -113,9 +114,11 @@ class Document(BaseDocument):
specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta`
dictionary. :attr:`max_documents` is the maximum number of documents that
is allowed to be stored in the collection, and :attr:`max_size` is the
maximum size of the collection in bytes. If :attr:`max_size` is not
maximum size of the collection in bytes. :attr:`max_size` is rounded up
to the next multiple of 256 by MongoDB internally and mongoengine before.
Use also a multiple of 256 to avoid confusions. If :attr:`max_size` is not
specified and :attr:`max_documents` is, :attr:`max_size` defaults to
10000000 bytes (10MB).
10485760 bytes (10MB).
Indexes may be created by specifying :attr:`indexes` in the :attr:`meta`
dictionary. The value should be a list of field names or tuples of field
@ -136,7 +139,7 @@ class Document(BaseDocument):
By default, any extra attribute existing in stored data but not declared
in your model will raise a :class:`~mongoengine.FieldDoesNotExist` error.
This can be disabled by setting :attr:`strict` to ``False``
in the :attr:`meta` dictionnary.
in the :attr:`meta` dictionary.
"""
# The __metaclass__ attribute is removed by 2to3 when running with Python3
@ -144,13 +147,15 @@ class Document(BaseDocument):
my_metaclass = TopLevelDocumentMetaclass
__metaclass__ = TopLevelDocumentMetaclass
__slots__ = ('__objects')
__slots__ = ('__objects',)
def pk():
"""Primary key alias
"""
def fget(self):
if 'id_field' not in self._meta:
return None
return getattr(self, self._meta['id_field'])
def fset(self, value):
@ -171,10 +176,13 @@ class Document(BaseDocument):
db = cls._get_db()
collection_name = cls._get_collection_name()
# Create collection as a capped collection if specified
if cls._meta['max_size'] or cls._meta['max_documents']:
if cls._meta.get('max_size') or cls._meta.get('max_documents'):
# Get max document limit and max byte size from meta
max_size = cls._meta['max_size'] or 10000000 # 10MB default
max_documents = cls._meta['max_documents']
max_size = cls._meta.get('max_size') or 10 * 2 ** 20 # 10MB default
max_documents = cls._meta.get('max_documents')
# Round up to next 256 bytes as MongoDB would do it to avoid exception
if max_size % 256:
max_size = (max_size // 256 + 1) * 256
if collection_name in db.collection_names():
cls._collection = db[collection_name]
@ -265,7 +273,8 @@ class Document(BaseDocument):
to cascading saves. Implies ``cascade=True``.
:param _refs: A list of processed references used in cascading saves
:param save_condition: only perform save if matching record in db
satisfies condition(s) (e.g., version number)
satisfies condition(s) (e.g. version number).
Raises :class:`OperationError` if the conditions are not satisfied
.. versionchanged:: 0.5
In existing documents it only saves changed fields using
@ -283,6 +292,8 @@ class Document(BaseDocument):
.. versionchanged:: 0.8.5
Optional save_condition that only overwrites existing documents
if the condition is satisfied in the current db record.
.. versionchanged:: 0.10
:class:`OperationError` exception raised if save_condition fails.
"""
signals.pre_save.send(self.__class__, document=self)
@ -347,6 +358,9 @@ class Document(BaseDocument):
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=upsert, **write_concern)
if not upsert and last_error['nModified'] == 0:
raise OperationError('Race condition preventing'
' document update detected')
created = is_new_object(last_error)
if cascade is None:
@ -636,23 +650,51 @@ class Document(BaseDocument):
db.drop_collection(cls._get_collection_name())
@classmethod
def ensure_index(cls, key_or_list, drop_dups=False, background=False,
**kwargs):
"""Ensure that the given indexes are in place.
def create_index(cls, keys, background=False, **kwargs):
"""Creates the given indexes if required.
:param key_or_list: a single index key or a list of index keys (to
:param keys: a single index key or a list of index keys (to
construct a multi-field index); keys may be prefixed with a **+**
or a **-** to determine the index ordering
:param background: Allows index creation in the background
"""
index_spec = cls._build_index_spec(key_or_list)
index_spec = cls._build_index_spec(keys)
index_spec = index_spec.copy()
fields = index_spec.pop('fields')
drop_dups = kwargs.get('drop_dups', False)
if IS_PYMONGO_3 and drop_dups:
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
elif not IS_PYMONGO_3:
index_spec['drop_dups'] = drop_dups
index_spec['background'] = background
index_spec.update(kwargs)
if IS_PYMONGO_3:
return cls._get_collection().create_index(fields, **index_spec)
else:
return cls._get_collection().ensure_index(fields, **index_spec)
@classmethod
def ensure_index(cls, key_or_list, drop_dups=False, background=False,
**kwargs):
"""Ensure that the given indexes are in place. Deprecated in favour
of create_index.
:param key_or_list: a single index key or a list of index keys (to
construct a multi-field index); keys may be prefixed with a **+**
or a **-** to determine the index ordering
:param background: Allows index creation in the background
:param drop_dups: Was removed/ignored with MongoDB >2.7.5. The value
will be removed if PyMongo3+ is used
"""
if IS_PYMONGO_3 and drop_dups:
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
elif not IS_PYMONGO_3:
kwargs.update({'drop_dups': drop_dups})
return cls.create_index(key_or_list, background=background, **kwargs)
@classmethod
def ensure_indexes(cls):
"""Checks the document meta data and ensures all the indexes exist.
@ -666,6 +708,9 @@ class Document(BaseDocument):
drop_dups = cls._meta.get('index_drop_dups', False)
index_opts = cls._meta.get('index_opts') or {}
index_cls = cls._meta.get('index_cls', True)
if IS_PYMONGO_3 and drop_dups:
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
collection = cls._get_collection()
# 746: when connection is via mongos, the read preference is not necessarily an indication that
@ -694,6 +739,9 @@ class Document(BaseDocument):
if 'cls' in opts:
del opts['cls']
if IS_PYMONGO_3:
collection.create_index(fields, background=background, **opts)
else:
collection.ensure_index(fields, background=background,
drop_dups=drop_dups, **opts)
@ -707,6 +755,10 @@ class Document(BaseDocument):
if 'cls' in index_opts:
del index_opts['cls']
if IS_PYMONGO_3:
collection.create_index('_cls', background=background,
**index_opts)
else:
collection.ensure_index('_cls', background=background,
**index_opts)

View File

@ -119,22 +119,31 @@ class URLField(StringField):
"""
_URL_REGEX = re.compile(
r'^(?:http|ftp)s?://' # http:// or https://
# domain...
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'
r'^(?:[a-z0-9\.\-]*)://' # scheme is validated separately
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|' # domain...
r'localhost|' # localhost...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|' # ...or ipv4
r'\[?[A-F0-9]*:[A-F0-9:]+\]?)' # ...or ipv6
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
_URL_SCHEMES = ['http', 'https', 'ftp', 'ftps']
def __init__(self, verify_exists=False, url_regex=None, **kwargs):
def __init__(self, verify_exists=False, url_regex=None, schemes=None, **kwargs):
self.verify_exists = verify_exists
self.url_regex = url_regex or self._URL_REGEX
self.schemes = schemes or self._URL_SCHEMES
super(URLField, self).__init__(**kwargs)
def validate(self, value):
# Check first if the scheme is valid
scheme = value.split('://')[0].lower()
if scheme not in self.schemes:
self.error('Invalid scheme {} in URL: {}'.format(scheme, value))
return
# Then check full URL
if not self.url_regex.match(value):
self.error('Invalid URL: %s' % value)
self.error('Invalid URL: {}'.format(value))
return
if self.verify_exists:
@ -162,7 +171,7 @@ class EmailField(StringField):
# quoted-string
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"'
# domain (max length of an ICAAN TLD is 22 characters)
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,22}$', re.IGNORECASE
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}|[A-Z0-9-]{2,}(?<!-))$', re.IGNORECASE
)
def validate(self, value):
@ -546,7 +555,7 @@ class EmbeddedDocumentField(BaseField):
def to_python(self, value):
if not isinstance(value, self.document_type):
return self.document_type._from_son(value)
return self.document_type._from_son(value, _auto_dereference=self._auto_dereference)
return value
def to_mongo(self, value, use_db_field=True, fields=[]):
@ -1004,6 +1013,7 @@ class CachedReferenceField(BaseField):
"""
A referencefield with cache fields to purpose pseudo-joins
.. versionadded:: 0.9
"""
@ -1674,12 +1684,21 @@ class SequenceField(BaseField):
cluster of machines, it is easier to create an object ID than have
global, uniformly increasing sequence numbers.
:param collection_name: Name of the counter collection (default 'mongoengine.counters')
:param sequence_name: Name of the sequence in the collection (default 'ClassName.counter')
:param value_decorator: Any callable to use as a counter (default int)
Use any callable as `value_decorator` to transform calculated counter into
any value suitable for your needs, e.g. string or hexadecimal
representation of the default integer counter value.
.. versionadded:: 0.5
.. note::
In case the counter is defined in the abstract document, it will be
common to all inherited documents and the default sequence name will
be the class name of the abstract document.
.. versionadded:: 0.5
.. versionchanged:: 0.8 added `value_decorator`
"""
@ -1694,7 +1713,7 @@ class SequenceField(BaseField):
self.sequence_name = sequence_name
self.value_decorator = (callable(value_decorator) and
value_decorator or self.VALUE_DECORATOR)
return super(SequenceField, self).__init__(*args, **kwargs)
super(SequenceField, self).__init__(*args, **kwargs)
def generate(self):
"""
@ -1740,7 +1759,7 @@ class SequenceField(BaseField):
if self.sequence_name:
return self.sequence_name
owner = self.owner_document
if issubclass(owner, Document):
if issubclass(owner, Document) and not owner._meta.get('abstract'):
return owner._get_collection_name()
else:
return ''.join('_%s' % c if c.isupper() else c

View File

@ -258,54 +258,6 @@ class BaseQuerySet(object):
"""
return self._document(**kwargs).save()
def get_or_create(self, write_concern=None, auto_save=True,
*q_objs, **query):
"""Retrieve unique object or create, if it doesn't exist. Returns a
tuple of ``(object, created)``, where ``object`` is the retrieved or
created object and ``created`` is a boolean specifying whether a new
object was created. Raises
:class:`~mongoengine.queryset.MultipleObjectsReturned` or
`DocumentName.MultipleObjectsReturned` if multiple results are found.
A new document will be created if the document doesn't exists; a
dictionary of default values for the new document may be provided as a
keyword argument called :attr:`defaults`.
.. note:: This requires two separate operations and therefore a
race condition exists. Because there are no transactions in
mongoDB other approaches should be investigated, to ensure you
don't accidentally duplicate data when using this method. This is
now scheduled to be removed before 1.0
:param write_concern: optional extra keyword arguments used if we
have to create a new document.
Passes any write_concern onto :meth:`~mongoengine.Document.save`
:param auto_save: if the object is to be saved automatically if
not found.
.. deprecated:: 0.8
.. versionchanged:: 0.6 - added `auto_save`
.. versionadded:: 0.3
"""
msg = ("get_or_create is scheduled to be deprecated. The approach is "
"flawed without transactions. Upserts should be preferred.")
warnings.warn(msg, DeprecationWarning)
defaults = query.get('defaults', {})
if 'defaults' in query:
del query['defaults']
try:
doc = self.get(*q_objs, **query)
return doc, False
except self._document.DoesNotExist:
query.update(defaults)
doc = self._document(**query)
if auto_save:
doc.save(write_concern=write_concern)
return doc, True
def first(self):
"""Retrieve the first object matching the query.
"""
@ -1296,6 +1248,27 @@ class BaseQuerySet(object):
else:
return 0
def aggregate_sum(self, field):
"""Sum over the values of the specified field.
:param field: the field to sum over; use dot-notation to refer to
embedded document fields
This method is more performant than the regular `sum`, because it uses
the aggregation framework instead of map-reduce.
"""
result = self._document._get_collection().aggregate([
{ '$match': self._query },
{ '$group': { '_id': 'sum', 'total': { '$sum': '$' + field } } }
])
if IS_PYMONGO_3:
result = list(result)
else:
result = result.get('result')
if result:
return result[0]['total']
return 0
def average(self, field):
"""Average over the values of the specified field.
@ -1351,6 +1324,27 @@ class BaseQuerySet(object):
else:
return 0
def aggregate_average(self, field):
"""Average over the values of the specified field.
:param field: the field to average over; use dot-notation to refer to
embedded document fields
This method is more performant than the regular `average`, because it
uses the aggregation framework instead of map-reduce.
"""
result = self._document._get_collection().aggregate([
{ '$match': self._query },
{ '$group': { '_id': 'avg', 'total': { '$avg': '$' + field } } }
])
if IS_PYMONGO_3:
result = list(result)
else:
result = result.get('result')
if result:
return result[0]['total']
return 0
def item_frequencies(self, field, normalize=False, map_reduce=True):
"""Returns a dictionary of all items present in a field across
the whole queried set of documents, and their corresponding frequency.

View File

@ -275,6 +275,60 @@ class IndexesTest(unittest.TestCase):
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('current.location.point', '2d')] in info)
def test_explicit_geosphere_index(self):
"""Ensure that geosphere indexes work when created via meta[indexes]
"""
class Place(Document):
location = DictField()
meta = {
'allow_inheritance': True,
'indexes': [
'(location.point',
]
}
self.assertEqual([{'fields': [('location.point', '2dsphere')]}],
Place._meta['index_specs'])
Place.ensure_indexes()
info = Place._get_collection().index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('location.point', '2dsphere')] in info)
def test_explicit_geohaystack_index(self):
"""Ensure that geohaystack indexes work when created via meta[indexes]
"""
raise SkipTest('GeoHaystack index creation is not supported for now'
'from meta, as it requires a bucketSize parameter.')
class Place(Document):
location = DictField()
name = StringField()
meta = {
'indexes': [
(')location.point', 'name')
]
}
self.assertEqual([{'fields': [('location.point', 'geoHaystack'), ('name', 1)]}],
Place._meta['index_specs'])
Place.ensure_indexes()
info = Place._get_collection().index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('location.point', 'geoHaystack')] in info)
def test_create_geohaystack_index(self):
"""Ensure that geohaystack indexes can be created
"""
class Place(Document):
location = DictField()
name = StringField()
Place.create_index({'fields': (')location.point', 'name')}, bucketSize=10)
info = Place._get_collection().index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('location.point', 'geoHaystack'), ('name', 1)] in info)
def test_dictionary_indexes(self):
"""Ensure that indexes are used when meta[indexes] contains
dictionaries instead of lists.
@ -822,6 +876,18 @@ class IndexesTest(unittest.TestCase):
key = indexes["title_text"]["key"]
self.assertTrue(('_fts', 'text') in key)
def test_hashed_indexes(self):
class Book(Document):
ref_id = StringField()
meta = {
"indexes": ["#ref_id"],
}
indexes = Book.objects._collection.index_information()
self.assertTrue("ref_id_hashed" in indexes)
self.assertTrue(('ref_id', 'hashed') in indexes["ref_id_hashed"]["key"])
def test_indexes_after_database_drop(self):
"""
Test to ensure that indexes are re-created on a collection even
@ -909,6 +975,30 @@ class IndexesTest(unittest.TestCase):
}
})
def test_compound_index_underscore_cls_not_overwritten(self):
"""
Test that the compound index doesn't get another _cls when it is specified
"""
class TestDoc(Document):
shard_1 = StringField()
txt_1 = StringField()
meta = {
'collection': 'test',
'allow_inheritance': True,
'sparse': True,
'shard_key': 'shard_1',
'indexes': [
('shard_1', '_cls', 'txt_1'),
]
}
TestDoc.drop_collection()
TestDoc.ensure_indexes()
index_info = TestDoc._get_collection().index_information()
self.assertTrue('shard_1_1__cls_1_txt_1_1' in index_info)
if __name__ == '__main__':
unittest.main()

View File

@ -307,6 +307,69 @@ class InheritanceTest(unittest.TestCase):
doc = Animal(name='dog')
self.assertFalse('_cls' in doc.to_mongo())
def test_abstract_handle_ids_in_metaclass_properly(self):
class City(Document):
continent = StringField()
meta = {'abstract': True,
'allow_inheritance': False}
class EuropeanCity(City):
name = StringField()
berlin = EuropeanCity(name='Berlin', continent='Europe')
self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._fields_ordered), 3)
self.assertEqual(berlin._fields_ordered[0], 'id')
def test_auto_id_not_set_if_specific_in_parent_class(self):
class City(Document):
continent = StringField()
city_id = IntField(primary_key=True)
meta = {'abstract': True,
'allow_inheritance': False}
class EuropeanCity(City):
name = StringField()
berlin = EuropeanCity(name='Berlin', continent='Europe')
self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._fields_ordered), 3)
self.assertEqual(berlin._fields_ordered[0], 'city_id')
def test_auto_id_vs_non_pk_id_field(self):
class City(Document):
continent = StringField()
id = IntField()
meta = {'abstract': True,
'allow_inheritance': False}
class EuropeanCity(City):
name = StringField()
berlin = EuropeanCity(name='Berlin', continent='Europe')
self.assertEqual(len(berlin._db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._reverse_db_field_map), len(berlin._fields_ordered))
self.assertEqual(len(berlin._fields_ordered), 4)
self.assertEqual(berlin._fields_ordered[0], 'auto_id_0')
berlin.save()
self.assertEqual(berlin.pk, berlin.auto_id_0)
def test_abstract_document_creation_does_not_fail(self):
class City(Document):
continent = StringField()
meta = {'abstract': True,
'allow_inheritance': False}
bkk = City(continent='asia')
self.assertEqual(None, bkk.pk)
# TODO: expected error? Shouldn't we create a new error type?
self.assertRaises(KeyError, lambda: setattr(bkk, 'pk', 1))
def test_allow_inheritance_embedded_document(self):
"""Ensure embedded documents respect inheritance
"""

View File

@ -88,7 +88,7 @@ class InstanceTest(unittest.TestCase):
options = Log.objects._collection.options()
self.assertEqual(options['capped'], True)
self.assertEqual(options['max'], 10)
self.assertTrue(options['size'] >= 4096)
self.assertEqual(options['size'], 4096)
# Check that the document cannot be redefined with different options
def recreate_log_document():
@ -103,6 +103,69 @@ class InstanceTest(unittest.TestCase):
Log.drop_collection()
def test_capped_collection_default(self):
"""Ensure that capped collections defaults work properly.
"""
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_documents': 10,
}
Log.drop_collection()
# Create a doc to create the collection
Log().save()
options = Log.objects._collection.options()
self.assertEqual(options['capped'], True)
self.assertEqual(options['max'], 10)
self.assertEqual(options['size'], 10 * 2**20)
# Check that the document with default value can be recreated
def recreate_log_document():
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_documents': 10,
}
# Create the collection by accessing Document.objects
Log.objects
recreate_log_document()
Log.drop_collection()
def test_capped_collection_no_max_size_problems(self):
"""Ensure that capped collections with odd max_size work properly.
MongoDB rounds up max_size to next multiple of 256, recreating a doc
with the same spec failed in mongoengine <0.10
"""
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_size': 10000,
}
Log.drop_collection()
# Create a doc to create the collection
Log().save()
options = Log.objects._collection.options()
self.assertEqual(options['capped'], True)
self.assertTrue(options['size'] >= 10000)
# Check that the document with odd max_size value can be recreated
def recreate_log_document():
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_size': 10000,
}
# Create the collection by accessing Document.objects
Log.objects
recreate_log_document()
Log.drop_collection()
def test_repr(self):
"""Ensure that unicode representation works
"""
@ -954,11 +1017,12 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(w1.save_id, UUID(1))
self.assertEqual(w1.count, 0)
# mismatch in save_condition prevents save
# mismatch in save_condition prevents save and raise exception
flip(w1)
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)
w1.save(save_condition={'save_id': UUID(42)})
self.assertRaises(OperationError,
w1.save, save_condition={'save_id': UUID(42)})
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.count, 0)
@ -986,7 +1050,8 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(w1.count, 2)
flip(w2)
flip(w2)
w2.save(save_condition={'save_id': old_id})
self.assertRaises(OperationError,
w2.save, save_condition={'save_id': old_id})
w2.reload()
self.assertFalse(w2.toggle)
self.assertEqual(w2.count, 2)
@ -998,7 +1063,8 @@ class InstanceTest(unittest.TestCase):
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 3)
flip(w1)
w1.save(save_condition={'count__gte': w1.count})
self.assertRaises(OperationError,
w1.save, save_condition={'count__gte': w1.count})
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 3)

View File

@ -39,6 +39,7 @@ class FieldTest(unittest.TestCase):
def tearDown(self):
self.db.drop_collection('fs.files')
self.db.drop_collection('fs.chunks')
self.db.drop_collection('mongoengine.counters')
def test_default_values_nothing_set(self):
"""Ensure that default field values are used when creating a document.
@ -341,6 +342,23 @@ class FieldTest(unittest.TestCase):
link.url = 'http://www.google.com:8080'
link.validate()
def test_url_scheme_validation(self):
"""Ensure that URLFields validate urls with specific schemes properly.
"""
class Link(Document):
url = URLField()
class SchemeLink(Document):
url = URLField(schemes=['ws', 'irc'])
link = Link()
link.url = 'ws://google.com'
self.assertRaises(ValidationError, link.validate)
scheme_link = SchemeLink()
scheme_link.url = 'ws://google.com'
scheme_link.validate()
def test_int_validation(self):
"""Ensure that invalid values cannot be assigned to int fields.
"""
@ -2134,9 +2152,7 @@ class FieldTest(unittest.TestCase):
obj = Product.objects(company=None).first()
self.assertEqual(obj, me)
obj, created = Product.objects.get_or_create(company=None)
self.assertEqual(created, False)
obj = Product.objects.get(company=None)
self.assertEqual(obj, me)
def test_reference_query_conversion(self):
@ -2954,6 +2970,57 @@ class FieldTest(unittest.TestCase):
self.assertEqual(1, post.comments[0].id)
self.assertEqual(2, post.comments[1].id)
def test_inherited_sequencefield(self):
class Base(Document):
name = StringField()
counter = SequenceField()
meta = {'abstract': True}
class Foo(Base):
pass
class Bar(Base):
pass
bar = Bar(name='Bar')
bar.save()
foo = Foo(name='Foo')
foo.save()
self.assertTrue('base.counter' in
self.db['mongoengine.counters'].find().distinct('_id'))
self.assertFalse(('foo.counter' or 'bar.counter') in
self.db['mongoengine.counters'].find().distinct('_id'))
self.assertNotEqual(foo.counter, bar.counter)
self.assertEqual(foo._fields['counter'].owner_document, Base)
self.assertEqual(bar._fields['counter'].owner_document, Base)
def test_no_inherited_sequencefield(self):
class Base(Document):
name = StringField()
meta = {'abstract': True}
class Foo(Base):
counter = SequenceField()
class Bar(Base):
counter = SequenceField()
bar = Bar(name='Bar')
bar.save()
foo = Foo(name='Foo')
foo.save()
self.assertFalse('base.counter' in
self.db['mongoengine.counters'].find().distinct('_id'))
self.assertTrue(('foo.counter' and 'bar.counter') in
self.db['mongoengine.counters'].find().distinct('_id'))
self.assertEqual(foo.counter, bar.counter)
self.assertEqual(foo._fields['counter'].owner_document, Foo)
self.assertEqual(bar._fields['counter'].owner_document, Bar)
def test_generic_embedded_document(self):
class Car(EmbeddedDocument):
name = StringField()
@ -3088,7 +3155,6 @@ class FieldTest(unittest.TestCase):
self.assertTrue(user.validate() is None)
user = User(email=("Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S"
"ucictfqpdkK9iS1zeFw8sg7s7cwAF7suIfUfeyueLpfosjn3"
"aJIazqqWkm7.net"))
self.assertTrue(user.validate() is None)

View File

@ -13,7 +13,7 @@ import pymongo
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference
from bson import ObjectId
from bson import ObjectId, DBRef
from mongoengine import *
from mongoengine.connection import get_connection, get_db
@ -340,8 +340,7 @@ class QuerySetTest(unittest.TestCase):
write_concern = {"fsync": True}
author, created = self.Person.objects.get_or_create(
name='Test User', write_concern=write_concern)
author = self.Person.objects.create(name='Test User')
author.save(write_concern=write_concern)
result = self.Person.objects.update(
@ -630,6 +629,40 @@ class QuerySetTest(unittest.TestCase):
self.assertRaises(ValidationError, Doc.objects().update, dt_f="datetime", upsert=True)
self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True)
def test_update_related_models( self ):
class TestPerson( Document ):
name = StringField()
class TestOrganization( Document ):
name = StringField()
owner = ReferenceField( TestPerson )
TestPerson.drop_collection()
TestOrganization.drop_collection()
p = TestPerson( name='p1' )
p.save()
o = TestOrganization( name='o1' )
o.save()
o.owner = p
p.name = 'p2'
self.assertEqual( o._get_changed_fields(), [ 'owner' ] )
self.assertEqual( p._get_changed_fields(), [ 'name' ] )
o.save()
self.assertEqual( o._get_changed_fields(), [] )
self.assertEqual( p._get_changed_fields(), [ 'name' ] ) # Fails; it's empty
# This will do NOTHING at all, even though we changed the name
p.save()
p.reload()
self.assertEqual( p.name, 'p2' ) # Fails; it's still `p1`
def test_upsert(self):
self.Person.drop_collection()
@ -659,37 +692,42 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual("Bob", bob.name)
self.assertEqual(30, bob.age)
def test_get_or_create(self):
"""Ensure that ``get_or_create`` returns one result or creates a new
document.
"""
person1 = self.Person(name="User A", age=20)
person1.save()
person2 = self.Person(name="User B", age=30)
person2.save()
def test_save_and_only_on_fields_with_default(self):
class Embed(EmbeddedDocument):
field = IntField()
# Retrieve the first person from the database
self.assertRaises(MultipleObjectsReturned,
self.Person.objects.get_or_create)
self.assertRaises(self.Person.MultipleObjectsReturned,
self.Person.objects.get_or_create)
class B(Document):
meta = {'collection': 'b'}
# Use a query to filter the people found to just person2
person, created = self.Person.objects.get_or_create(age=30)
self.assertEqual(person.name, "User B")
self.assertEqual(created, False)
field = IntField(default=1)
embed = EmbeddedDocumentField(Embed, default=Embed)
embed_no_default = EmbeddedDocumentField(Embed)
person, created = self.Person.objects.get_or_create(age__lt=30)
self.assertEqual(person.name, "User A")
self.assertEqual(created, False)
# Creating {field : 2, embed : {field: 2}, embed_no_default: {field: 2}}
val = 2
embed = Embed()
embed.field = val
record = B()
record.field = val
record.embed = embed
record.embed_no_default = embed
record.save()
# Try retrieving when no objects exists - new doc should be created
kwargs = dict(age=50, defaults={'name': 'User C'})
person, created = self.Person.objects.get_or_create(**kwargs)
self.assertEqual(created, True)
# Checking it was saved correctly
record.reload()
self.assertEqual(record.field, 2)
self.assertEqual(record.embed_no_default.field, 2)
self.assertEqual(record.embed.field, 2)
person = self.Person.objects.get(age=50)
self.assertEqual(person.name, "User C")
# Request only the _id field and save
clone = B.objects().only('id').first()
clone.save()
# Reload the record and see that the embed data is not lost
record.reload()
self.assertEqual(record.field, 2)
self.assertEqual(record.embed_no_default.field, 2)
self.assertEqual(record.embed.field, 2)
def test_bulk_insert(self):
"""Ensure that bulk insert works
@ -2668,26 +2706,58 @@ class QuerySetTest(unittest.TestCase):
avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0
self.assertAlmostEqual(int(self.Person.objects.average('age')), avg)
self.assertAlmostEqual(
int(self.Person.objects.aggregate_average('age')), avg
)
self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.average('age')), avg)
self.assertEqual(
int(self.Person.objects.aggregate_average('age')), avg
)
# dot notation
self.Person(
name='person meta', person_meta=self.PersonMeta(weight=0)).save()
self.assertAlmostEqual(
int(self.Person.objects.average('person_meta.weight')), 0)
self.assertAlmostEqual(
int(self.Person.objects.aggregate_average('person_meta.weight')),
0
)
for i, weight in enumerate(ages):
self.Person(
name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save()
self.assertAlmostEqual(
int(self.Person.objects.average('person_meta.weight')), avg)
int(self.Person.objects.average('person_meta.weight')), avg
)
self.assertAlmostEqual(
int(self.Person.objects.aggregate_average('person_meta.weight')),
avg
)
self.Person(name='test meta none').save()
self.assertEqual(
int(self.Person.objects.average('person_meta.weight')), avg)
int(self.Person.objects.average('person_meta.weight')), avg
)
self.assertEqual(
int(self.Person.objects.aggregate_average('person_meta.weight')),
avg
)
# test summing over a filtered queryset
over_50 = [a for a in ages if a >= 50]
avg = float(sum(over_50)) / len(over_50)
self.assertEqual(
self.Person.objects.filter(age__gte=50).average('age'),
avg
)
self.assertEqual(
self.Person.objects.filter(age__gte=50).aggregate_average('age'),
avg
)
def test_sum(self):
"""Ensure that field can be summed over correctly.
@ -2696,20 +2766,44 @@ class QuerySetTest(unittest.TestCase):
for i, age in enumerate(ages):
self.Person(name='test%s' % i, age=age).save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
self.assertEqual(
self.Person.objects.aggregate_sum('age'), sum(ages)
)
self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
self.assertEqual(
self.Person.objects.aggregate_sum('age'), sum(ages)
)
for i, age in enumerate(ages):
self.Person(name='test meta%s' %
i, person_meta=self.PersonMeta(weight=age)).save()
self.assertEqual(
int(self.Person.objects.sum('person_meta.weight')), sum(ages))
self.Person.objects.sum('person_meta.weight'), sum(ages)
)
self.assertEqual(
self.Person.objects.aggregate_sum('person_meta.weight'),
sum(ages)
)
self.Person(name='weightless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
self.assertEqual(
self.Person.objects.aggregate_sum('age'), sum(ages)
)
# test summing over a filtered queryset
self.assertEqual(
self.Person.objects.filter(age__gte=50).sum('age'),
sum([a for a in ages if a >= 50])
)
self.assertEqual(
self.Person.objects.filter(age__gte=50).aggregate_sum('age'),
sum([a for a in ages if a >= 50])
)
def test_embedded_average(self):
class Pay(EmbeddedDocument):
@ -3655,11 +3749,9 @@ class QuerySetTest(unittest.TestCase):
def test_scalar(self):
class Organization(Document):
id = ObjectIdField('_id')
name = StringField()
class User(Document):
id = ObjectIdField('_id')
name = StringField()
organization = ObjectIdField()
@ -4185,6 +4277,41 @@ class QuerySetTest(unittest.TestCase):
Organization))
self.assertTrue(isinstance(qs.first().organization, Organization))
def test_no_dereference_embedded_doc(self):
class User(Document):
name = StringField()
class Member(EmbeddedDocument):
name = StringField()
user = ReferenceField(User)
class Organization(Document):
name = StringField()
members = ListField(EmbeddedDocumentField(Member))
ceo = ReferenceField(User)
member = EmbeddedDocumentField(Member)
admin = ListField(ReferenceField(User))
Organization.drop_collection()
User.drop_collection()
user = User(name="Flash")
user.save()
member = Member(name="Flash", user=user)
company = Organization(name="Mongo Inc", ceo=user, member=member)
company.admin.append(user)
company.members.append(member)
company.save()
result = Organization.objects().no_dereference().first()
self.assertTrue(isinstance(result.admin[0], (DBRef, ObjectId)))
self.assertTrue(isinstance(result.member.user, (DBRef, ObjectId)))
self.assertTrue(isinstance(result.members[0].user, (DBRef, ObjectId)))
def test_cached_queryset(self):
class Person(Document):
name = StringField()
@ -4622,6 +4749,13 @@ class QuerySetTest(unittest.TestCase):
self.assertEquals(Animal.objects(folded_ears=True).count(), 1)
self.assertEquals(Animal.objects(whiskers_length=5.1).count(), 1)
def test_loop_via_invalid_id_does_not_crash(self):
class Person(Document):
name = StringField()
Person.objects.delete()
Person._get_collection().update({"name": "a"}, {"$set": {"_id": ""}}, upsert=True)
for p in Person.objects():
self.assertEqual(p.name, 'a')
if __name__ == '__main__':
unittest.main()

View File

@ -224,6 +224,15 @@ class TransformTest(unittest.TestCase):
self.assertEqual(1, Doc.objects(item__type__="axe").count())
self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count())
def test_understandable_error_raised(self):
class Event(Document):
title = StringField()
location = GeoPointField()
box = [(35.0, -125.0), (40.0, -100.0)]
# I *meant* to execute location__within_box=box
events = Event.objects(location__within=box)
self.assertRaises(InvalidQueryError, lambda: events.count())
if __name__ == '__main__':
unittest.main()

View File

@ -1026,6 +1026,43 @@ class FieldTest(unittest.TestCase):
self.assertEqual(type(foo.bar), Bar)
self.assertEqual(type(foo.baz), Baz)
def test_document_reload_reference_integrity(self):
"""
Ensure reloading a document with multiple similar id
in different collections doesn't mix them.
"""
class Topic(Document):
id = IntField(primary_key=True)
class User(Document):
id = IntField(primary_key=True)
name = StringField()
class Message(Document):
id = IntField(primary_key=True)
topic = ReferenceField(Topic)
author = ReferenceField(User)
Topic.drop_collection()
User.drop_collection()
Message.drop_collection()
# All objects share the same id, but each in a different collection
topic = Topic(id=1).save()
user = User(id=1, name='user-name').save()
Message(id=1, topic=topic, author=user).save()
concurrent_change_user = User.objects.get(id=1)
concurrent_change_user.name = 'new-name'
concurrent_change_user.save()
self.assertNotEqual(user.name, 'new-name')
msg = Message.objects.get(id=1)
msg.reload()
self.assertEqual(msg.topic, topic)
self.assertEqual(msg.author, user)
self.assertEqual(msg.author.name, 'new-name')
def test_list_lookup_not_checked_in_map(self):
"""Ensure we dereference list data correctly
"""