Merge branch 'dev' into pull_124

This commit is contained in:
Ross Lawley 2011-05-25 09:54:56 +01:00
commit 2ce70448b0
10 changed files with 1184 additions and 240 deletions

.gitignore
View File

@@ -1,3 +1,5 @@
+.*
+!.gitignore
*.pyc
.*.swp
*.egg
@@ -9,4 +11,4 @@ mongoengine.egg-info/
env/
.settings
.project
.pydevproject

View File

@@ -7,22 +7,32 @@ import pymongo
import pymongo.objectid

-_document_registry = {}
-
-def get_document(name):
-    return _document_registry[name]
+class NotRegistered(Exception):
+    pass

class ValidationError(Exception):
    pass

+_document_registry = {}
+
+def get_document(name):
+    if name not in _document_registry:
+        raise NotRegistered("""
+            `%s` has not been registered in the document registry.
+            Importing the document class automatically registers it, has it
+            been imported?
+        """.strip() % name)
+    return _document_registry[name]
+

class BaseField(object):
    """A base class for fields in a MongoDB document. Instances of this class
    may be added to subclasses of `Document` to define a document's schema.
    """

    # Fields may have _types inserted into indexes by default
    _index_with_types = True
    _geo_index = False
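
For illustration (not part of the patch itself), a minimal sketch of how the stricter registry lookup above is expected to behave; the Page document and the unknown name are hypothetical:

    # Defining a Document subclass registers it; looking up an unknown name
    # now raises NotRegistered instead of a bare KeyError.
    from mongoengine import Document, StringField
    from mongoengine.base import get_document, NotRegistered

    class Page(Document):
        title = StringField()

    assert get_document('Page') is Page

    try:
        get_document('UnimportedDoc')   # hypothetical, never defined or imported
    except NotRegistered:
        pass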
@@ -32,7 +42,7 @@ class BaseField(object):
    creation_counter = 0
    auto_creation_counter = -1

    def __init__(self, db_field=None, name=None, required=False, default=None,
                 unique=False, unique_with=None, primary_key=False,
                 validation=None, choices=None):
        self.db_field = (db_field or name) if not primary_key else '_id'
@@ -57,7 +67,7 @@ class BaseField(object):
        BaseField.creation_counter += 1

    def __get__(self, instance, owner):
        """Descriptor for retrieving a value from a field in a document. Do
        any necessary conversion between Python and MongoDB types.
        """
        if instance is None:
@@ -167,8 +177,8 @@ class DocumentMetaclass(type):
                superclasses.update(base._superclasses)

            if hasattr(base, '_meta'):
                # Ensure that the Document class may be subclassed -
                # inheritance may be disabled to remove dependency on
                # additional fields _cls and _types
                if base._meta.get('allow_inheritance', True) == False:
                    raise ValueError('Document %s may not be subclassed' %
@@ -211,12 +221,12 @@ class DocumentMetaclass(type):
            module = attrs.get('__module__')

            base_excs = tuple(base.DoesNotExist for base in bases
                              if hasattr(base, 'DoesNotExist')) or (DoesNotExist,)
            exc = subclass_exception('DoesNotExist', base_excs, module)
            new_class.add_to_class('DoesNotExist', exc)

            base_excs = tuple(base.MultipleObjectsReturned for base in bases
                              if hasattr(base, 'MultipleObjectsReturned'))
            base_excs = base_excs or (MultipleObjectsReturned,)
            exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
@@ -238,12 +248,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
    def __new__(cls, name, bases, attrs):
        super_new = super(TopLevelDocumentMetaclass, cls).__new__
        # Classes defined in this package are abstract and should not have
        # their own metadata with DB collection, etc.
        # __metaclass__ is only set on the class with the __metaclass__
        # attribute (i.e. it is not set on subclasses). This differentiates
        # 'real' documents from the 'Document' class
-        if attrs.get('__metaclass__') == TopLevelDocumentMetaclass:
+        #
+        # Also assume a class is abstract if it has abstract set to True in
+        # its meta dictionary. This allows custom Document superclasses.
+        if (attrs.get('__metaclass__') == TopLevelDocumentMetaclass or
+                ('meta' in attrs and attrs['meta'].get('abstract', False))):
+            # Make sure no base class was non-abstract
+            non_abstract_bases = [b for b in bases
+                                  if hasattr(b, '_meta') and not b._meta.get('abstract', False)]
+            if non_abstract_bases:
+                raise ValueError("Abstract document cannot have non-abstract base")
            return super_new(cls, name, bases, attrs)

        collection = name.lower()
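
As a quick sketch of what the abstract flag handled above enables (the class names are made up; the behaviour mirrors the test_abstract_documents test added later in this commit):

    from mongoengine import Document, StringField, DateTimeField

    class BaseItem(Document):
        created = DateTimeField()
        meta = {'abstract': True}    # no collection is allocated for this class

    class Article(BaseItem):
        title = StringField()        # concrete subclass gets its own collection

    assert 'collection' not in BaseItem._meta
    assert Article._meta['collection'] == 'article'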
@@ -266,6 +285,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
            base_indexes += base._meta.get('indexes', [])

        meta = {
+            'abstract': False,
            'collection': collection,
            'max_documents': None,
            'max_size': None,
@@ -289,13 +309,39 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
        new_class = super_new(cls, name, bases, attrs)

        # Provide a default queryset unless one has been manually provided
-        if not hasattr(new_class, 'objects'):
-            new_class.objects = QuerySetManager()
+        manager = attrs.get('objects', QuerySetManager())
+        if hasattr(manager, 'queryset_class'):
+            meta['queryset_class'] = manager.queryset_class
+        new_class.objects = manager

        user_indexes = [QuerySet._build_index_spec(new_class, spec)
                        for spec in meta['indexes']] + base_indexes
        new_class._meta['indexes'] = user_indexes

+        unique_indexes = cls._unique_with_indexes(new_class)
+        new_class._meta['unique_indexes'] = unique_indexes
+
+        for field_name, field in new_class._fields.items():
+            # Check for custom primary key
+            if field.primary_key:
+                current_pk = new_class._meta['id_field']
+                if current_pk and current_pk != field_name:
+                    raise ValueError('Cannot override primary key field')
+
+                if not current_pk:
+                    new_class._meta['id_field'] = field_name
+                    # Make 'Document.id' an alias to the real primary key field
+                    new_class.id = field
+
+        if not new_class._meta['id_field']:
+            new_class._meta['id_field'] = 'id'
+            new_class._fields['id'] = ObjectIdField(db_field='_id')
+            new_class.id = new_class._fields['id']
+
+        return new_class
+
+    @classmethod
+    def _unique_with_indexes(cls, new_class, namespace=""):
        unique_indexes = []
        for field_name, field in new_class._fields.items():
            # Generate a list of indexes needed by uniqueness constraints
@@ -321,28 +367,16 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
                unique_fields += unique_with

                # Add the new index to the list
-                index = [(f, pymongo.ASCENDING) for f in unique_fields]
+                index = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields]
                unique_indexes.append(index)

-            # Check for custom primary key
-            if field.primary_key:
-                current_pk = new_class._meta['id_field']
-                if current_pk and current_pk != field_name:
-                    raise ValueError('Cannot override primary key field')
-
-                if not current_pk:
-                    new_class._meta['id_field'] = field_name
-                    # Make 'Document.id' an alias to the real primary key field
-                    new_class.id = field
-
-        new_class._meta['unique_indexes'] = unique_indexes
-
-        if not new_class._meta['id_field']:
-            new_class._meta['id_field'] = 'id'
-            new_class._fields['id'] = ObjectIdField(db_field='_id')
-            new_class.id = new_class._fields['id']
-
-        return new_class
+            # Grab any embedded document field unique indexes
+            if field.__class__.__name__ == "EmbeddedDocumentField":
+                field_namespace = "%s." % field_name
+                unique_indexes += cls._unique_with_indexes(field.document_type,
+                                                           field_namespace)
+
+        return unique_indexes


class BaseDocument(object):
@@ -366,7 +400,7 @@ class BaseDocument(object):
        are present.
        """
        # Get a list of tuples of field names and their current values
        fields = [(field, getattr(self, name))
                  for name, field in self._fields.items()]

        # Ensure that each field is matched to a valid value
@@ -461,7 +495,7 @@ class BaseDocument(object):
                self._meta.get('allow_inheritance', True) == False):
            data['_cls'] = self._class_name
            data['_types'] = self._superclasses.keys() + [self._class_name]
-        if data.has_key('_id') and not data['_id']:
+        if data.has_key('_id') and data['_id'] is None:
            del data['_id']
        return data

View File

@@ -1,5 +1,6 @@
from pymongo import Connection
import multiprocessing
+import threading

__all__ = ['ConnectionError', 'connect']
@@ -22,17 +23,22 @@ class ConnectionError(Exception):

def _get_connection(reconnect=False):
+    """Handles the connection to the database
+    """
    global _connection
    identity = get_identity()
    # Connect to the database if not already connected
    if _connection.get(identity) is None or reconnect:
        try:
            _connection[identity] = Connection(**_connection_settings)
-        except:
-            raise ConnectionError('Cannot connect to the database')
+        except Exception, e:
+            raise ConnectionError("Cannot connect to the database:\n%s" % e)
    return _connection[identity]


def _get_db(reconnect=False):
+    """Handles database connections and authentication based on the current
+    identity
+    """
    global _db, _connection
    identity = get_identity()
    # Connect if not already connected
@@ -52,12 +58,17 @@ def _get_db(reconnect=False):
    return _db[identity]


def get_identity():
+    """Creates an identity key based on the current process and thread
+    identity.
+    """
    identity = multiprocessing.current_process()._identity
    identity = 0 if not identity else identity[0]
+    identity = (identity, threading.current_thread().ident)
    return identity


def connect(db, username=None, password=None, **kwargs):
    """Connect to the database specified by the 'db' argument. Connection
    settings may be provided here as well if the database is not running on
    the default port on localhost. If authentication is needed, provide
    username and password arguments as well.

View File

@@ -86,7 +86,7 @@ class User(Document):
        else:
            email = '@'.join([email_name, domain_part.lower()])

-        user = User(username=username, email=email, date_joined=now)
+        user = cls(username=username, email=email, date_joined=now)
        user.set_password(password)
        user.save()
        return user
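
With User(...) replaced by cls(...), create_user now builds instances of whatever subclass it is called on; a hedged sketch (the Member subclass and database name are hypothetical, and a live connection is needed because create_user saves the document):

    from mongoengine import StringField, connect
    from mongoengine.django.auth import User

    class Member(User):
        nickname = StringField()

    connect(db='example')   # hypothetical database
    member = Member.create_user('alice', 's3cret', 'alice@example.com')
    assert isinstance(member, Member)   # previously this came back as a plain User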

View File

@@ -40,44 +40,54 @@ class Document(BaseDocument):
    presence of `_cls` and `_types`, set :attr:`allow_inheritance` to
    ``False`` in the :attr:`meta` dictionary.

    A :class:`~mongoengine.Document` may use a **Capped Collection** by
    specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta`
    dictionary. :attr:`max_documents` is the maximum number of documents that
    is allowed to be stored in the collection, and :attr:`max_size` is the
    maximum size of the collection in bytes. If :attr:`max_size` is not
    specified and :attr:`max_documents` is, :attr:`max_size` defaults to
    10000000 bytes (10MB).

    Indexes may be created by specifying :attr:`indexes` in the :attr:`meta`
    dictionary. The value should be a list of field names or tuples of field
    names. Index direction may be specified by prefixing the field names with
    a **+** or **-** sign.
    """
    __metaclass__ = TopLevelDocumentMetaclass

-    def save(self, safe=True, force_insert=False, validate=True):
+    def save(self, safe=True, force_insert=False, validate=True, write_options=None):
        """Save the :class:`~mongoengine.Document` to the database. If the
        document already exists, it will be updated, otherwise it will be
        created.

        If ``safe=True`` and the operation is unsuccessful, an
        :class:`~mongoengine.OperationError` will be raised.

        :param safe: check if the operation succeeded before returning
        :param force_insert: only try to create a new document, don't allow
            updates of existing documents
        :param validate: validates the document; set to ``False`` to skip.
+        :param write_options: Extra keyword arguments are passed down to
+            :meth:`~pymongo.collection.Collection.save` OR
+            :meth:`~pymongo.collection.Collection.insert`
+            which will be used as options for the resultant ``getLastError`` command.
+            For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
+            have recorded the write and will force an fsync on each server being written to.
        """
        if validate:
            self.validate()
+
+        if not write_options:
+            write_options = {}
+
        doc = self.to_mongo()
        try:
            collection = self.__class__.objects._collection
            if force_insert:
-                object_id = collection.insert(doc, safe=safe)
+                object_id = collection.insert(doc, safe=safe, **write_options)
            else:
-                object_id = collection.save(doc, safe=safe)
+                object_id = collection.save(doc, safe=safe, **write_options)
        except pymongo.errors.OperationFailure, err:
            message = 'Could not save document (%s)'
            if u'duplicate key' in unicode(err):
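
A short sketch of the new write_options hook (the document class and database name are illustrative); the dict is forwarded to PyMongo's save/insert and ends up as getLastError options:

    from mongoengine import Document, StringField, connect

    class LogEntry(Document):
        message = StringField()

    connect(db='example')
    entry = LogEntry(message='payload received')
    # wait for two servers to acknowledge the write and force an fsync
    entry.save(write_options={'w': 2, 'fsync': True})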
@@ -131,9 +141,9 @@ class MapReduceDocument(object):
    """A document returned from a map/reduce query.

    :param collection: An instance of :class:`~pymongo.Collection`
    :param key: Document/result key, often an instance of
                :class:`~pymongo.objectid.ObjectId`. If supplied as
                an ``ObjectId`` found in the given ``collection``,
                the object can be accessed via the ``object`` property.
    :param value: The result(s) for this key.
@@ -148,7 +158,7 @@ class MapReduceDocument(object):

    @property
    def object(self):
        """Lazy-load the object referenced by ``self.key``. ``self.key``
        should be the ``primary_key``.
        """
        id_field = self._document()._meta['id_field']

View File

@@ -17,7 +17,7 @@ import warnings

__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
           'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
-           'ObjectIdField', 'ReferenceField', 'ValidationError',
+           'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField',
           'DecimalField', 'URLField', 'GenericReferenceField', 'FileField',
           'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField']
@@ -339,7 +339,7 @@ class ListField(BaseField):
        if isinstance(self.field, ReferenceField):
            referenced_type = self.field.document_type
            # Get value from document instance if available
            value_list = instance._data.get(self.name)
            if value_list:
                deref_list = []
@@ -449,7 +449,108 @@ class DictField(BaseField):
                                  'contain "." or "$" characters')

    def lookup_member(self, member_name):
-        return self.basecls(db_field=member_name)
+        return DictField(basecls=self.basecls, db_field=member_name)
+
+    def prepare_query_value(self, op, value):
+        match_operators = ['contains', 'icontains', 'startswith',
+                           'istartswith', 'endswith', 'iendswith',
+                           'exact', 'iexact']
+
+        if op in match_operators and isinstance(value, basestring):
+            return StringField().prepare_query_value(op, value)
+
+        return super(DictField, self).prepare_query_value(op, value)
+
+
+class MapField(BaseField):
+    """A field that maps a name to a specified field type. Similar to
+    a DictField, except the 'value' of each item must match the specified
+    field type.
+
+    .. versionadded:: 0.5
+    """
+
+    def __init__(self, field=None, *args, **kwargs):
+        if not isinstance(field, BaseField):
+            raise ValidationError('Argument to MapField constructor must be '
+                                  'a valid field')
+        self.field = field
+        kwargs.setdefault('default', lambda: {})
+        super(MapField, self).__init__(*args, **kwargs)
+
+    def validate(self, value):
+        """Make sure that a list of valid fields is being used.
+        """
+        if not isinstance(value, dict):
+            raise ValidationError('Only dictionaries may be used in a '
+                                  'DictField')
+
+        if any(('.' in k or '$' in k) for k in value):
+            raise ValidationError('Invalid dictionary key name - keys may not '
+                                  'contain "." or "$" characters')
+
+        try:
+            [self.field.validate(item) for item in value.values()]
+        except Exception, err:
+            raise ValidationError('Invalid MapField item (%s)' % str(item))
+
+    def __get__(self, instance, owner):
+        """Descriptor to automatically dereference references.
+        """
+        if instance is None:
+            # Document class being used rather than a document object
+            return self
+
+        if isinstance(self.field, ReferenceField):
+            referenced_type = self.field.document_type
+            # Get value from document instance if available
+            value_dict = instance._data.get(self.name)
+            if value_dict:
+                deref_dict = {}
+                for key, value in value_dict.iteritems():
+                    # Dereference DBRefs
+                    if isinstance(value, (pymongo.dbref.DBRef)):
+                        value = _get_db().dereference(value)
+                        deref_dict[key] = referenced_type._from_son(value)
+                    else:
+                        deref_dict[key] = value
+                instance._data[self.name] = deref_dict
+
+        if isinstance(self.field, GenericReferenceField):
+            value_dict = instance._data.get(self.name)
+            if value_dict:
+                deref_dict = {}
+                for key, value in value_dict.iteritems():
+                    # Dereference DBRefs
+                    if isinstance(value, (dict, pymongo.son.SON)):
+                        deref_dict[key] = self.field.dereference(value)
+                    else:
+                        deref_dict[key] = value
+                instance._data[self.name] = deref_dict
+
+        return super(MapField, self).__get__(instance, owner)
+
+    def to_python(self, value):
+        return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
+
+    def to_mongo(self, value):
+        return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()])
+
+    def prepare_query_value(self, op, value):
+        return self.field.prepare_query_value(op, value)
+
+    def lookup_member(self, member_name):
+        return self.field.lookup_member(member_name)
+
+    def _set_owner_document(self, owner_document):
+        self.field.owner_document = owner_document
+        self._owner_document = owner_document
+
+    def _get_owner_document(self, owner_document):
+        self._owner_document = owner_document
+
+    owner_document = property(_get_owner_document, _set_owner_document)
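
A usage sketch for the new MapField (the Survey document and its keys are made up); each value in the mapping is validated by the wrapped field, and keys follow the usual MongoDB rule of containing no '.' or '$':

    from mongoengine import Document, StringField, IntField, MapField, connect

    class Survey(Document):
        name = StringField()
        scores = MapField(field=IntField())   # arbitrary keys -> integer values

    connect(db='example')
    survey = Survey(name='q1', scores={'alice': 4, 'bob': 5})
    survey.save()

    survey.scores['carol'] = 'not a number'
    # survey.save() would now raise ValidationError from MapField.validate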


class ReferenceField(BaseField):
    """A reference to a document that will be automatically dereferenced on
@@ -522,6 +623,9 @@ class GenericReferenceField(BaseField):
    """A reference to *any* :class:`~mongoengine.document.Document` subclass
    that will be automatically dereferenced on access (lazily).

+    note: Any documents used as a generic reference must be registered in the
+    document registry. Importing the model will automatically register it.
+
    .. versionadded:: 0.3
    """
@@ -601,6 +705,7 @@ class GridFSProxy(object):
        self.fs = gridfs.GridFS(_get_db())  # Filesystem instance
        self.newfile = None                 # Used for partial writes
        self.grid_id = grid_id              # Store GridFS id for file
+        self.gridout = None

    def __getattr__(self, name):
        obj = self.get()
@@ -614,8 +719,12 @@ class GridFSProxy(object):
    def get(self, id=None):
        if id:
            self.grid_id = id
+        if self.grid_id is None:
+            return None
        try:
-            return self.fs.get(id or self.grid_id)
+            if self.gridout is None:
+                self.gridout = self.fs.get(self.grid_id)
+            return self.gridout
        except:
            # File has been deleted
            return None
@@ -643,11 +752,11 @@ class GridFSProxy(object):
        if not self.newfile:
            self.new_file()
            self.grid_id = self.newfile._id
        self.newfile.writelines(lines)

-    def read(self):
+    def read(self, size=-1):
        try:
-            return self.get().read()
+            return self.get().read(size)
        except:
            return None
@@ -655,6 +764,7 @@ class GridFSProxy(object):
        # Delete file from GridFS, FileField still remains
        self.fs.delete(self.grid_id)
        self.grid_id = None
+        self.gridout = None

    def replace(self, file, **kwargs):
        self.delete()

View File

@@ -8,6 +8,7 @@ import pymongo.objectid
import re
import copy
import itertools
+import operator

__all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
           'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
@@ -280,30 +281,30 @@ class QueryFieldList(object):
    ONLY = True
    EXCLUDE = False

-    def __init__(self, fields=[], direction=ONLY, always_include=[]):
-        self.direction = direction
+    def __init__(self, fields=[], value=ONLY, always_include=[]):
+        self.value = value
        self.fields = set(fields)
        self.always_include = set(always_include)

    def as_dict(self):
-        return dict((field, self.direction) for field in self.fields)
+        return dict((field, self.value) for field in self.fields)

    def __add__(self, f):
        if not self.fields:
            self.fields = f.fields
-            self.direction = f.direction
-        elif self.direction is self.ONLY and f.direction is self.ONLY:
+            self.value = f.value
+        elif self.value is self.ONLY and f.value is self.ONLY:
            self.fields = self.fields.intersection(f.fields)
-        elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE:
+        elif self.value is self.EXCLUDE and f.value is self.EXCLUDE:
            self.fields = self.fields.union(f.fields)
-        elif self.direction is self.ONLY and f.direction is self.EXCLUDE:
+        elif self.value is self.ONLY and f.value is self.EXCLUDE:
            self.fields -= f.fields
-        elif self.direction is self.EXCLUDE and f.direction is self.ONLY:
-            self.direction = self.ONLY
+        elif self.value is self.EXCLUDE and f.value is self.ONLY:
+            self.value = self.ONLY
            self.fields = f.fields - self.fields

        if self.always_include:
-            if self.direction is self.ONLY and self.fields:
+            if self.value is self.ONLY and self.fields:
                self.fields = self.fields.union(self.always_include)
            else:
                self.fields -= self.always_include
@@ -311,7 +312,7 @@ class QueryFieldList(object):
    def reset(self):
        self.fields = set([])
-        self.direction = self.ONLY
+        self.value = self.ONLY

    def __nonzero__(self):
        return bool(self.fields)
@@ -334,6 +335,7 @@ class QuerySet(object):
        self._ordering = []
        self._snapshot = False
        self._timeout = True
+        self._class_check = True

        # If inheritance is allowed, only return instances and instances of
        # subclasses of the class being used
@@ -344,11 +346,26 @@ class QuerySet(object):
        self._limit = None
        self._skip = None

+    def clone(self):
+        """Creates a copy of the current :class:`~mongoengine.queryset.QuerySet`"""
+        c = self.__class__(self._document, self._collection_obj)
+
+        copy_props = ('_initial_query', '_query_obj', '_where_clause',
+                      '_loaded_fields', '_ordering', '_snapshot',
+                      '_timeout', '_limit', '_skip')
+
+        for prop in copy_props:
+            val = getattr(self, prop)
+            setattr(c, prop, copy.deepcopy(val))
+
+        return c
+
    @property
    def _query(self):
        if self._mongo_query is None:
            self._mongo_query = self._query_obj.to_query(self._document)
-            self._mongo_query.update(self._initial_query)
+            if self._class_check:
+                self._mongo_query.update(self._initial_query)
        return self._mongo_query

    def ensure_index(self, key_or_list, drop_dups=False, background=False,
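
A sketch of how the new clone() is meant to be used, so one base queryset can branch into independent filters (the BlogPost document, its fields and the database name are hypothetical):

    from mongoengine import (Document, StringField, BooleanField,
                             DateTimeField, ListField, connect)

    class BlogPost(Document):
        title = StringField()
        published = BooleanField()
        date = DateTimeField()
        tags = ListField(StringField())

    connect(db='example')
    base = BlogPost.objects(tags='mongodb')

    recent = base.clone().filter(published=True).order_by('-date')
    drafts = base.clone().filter(published=False)
    # `base` itself still carries only the tags='mongodb' constraint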
@@ -399,7 +416,7 @@ class QuerySet(object):
        return index_list

-    def __call__(self, q_obj=None, **query):
+    def __call__(self, q_obj=None, class_check=True, **query):
        """Filter the selected documents by calling the
        :class:`~mongoengine.queryset.QuerySet` with a query.
@@ -407,16 +424,17 @@ class QuerySet(object):
            the query; the :class:`~mongoengine.queryset.QuerySet` is filtered
            multiple times with different :class:`~mongoengine.queryset.Q`
            objects, only the last one will be used
+        :param class_check: If set to False bypass class name check when
+            querying collection
        :param query: Django-style query keyword arguments
        """
-        #if q_obj:
-        #self._where_clause = q_obj.as_js(self._document)
        query = Q(**query)
        if q_obj:
            query &= q_obj
        self._query_obj &= query
        self._mongo_query = None
        self._cursor_obj = None
+        self._class_check = class_check
        return self

    def filter(self, *q_objs, **query):
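
The new class_check flag makes it possible to skip the automatic _types/_cls filter when several document classes share a collection; a hedged sketch (Person/Employee and the database name are illustrative):

    from mongoengine import Document, StringField, IntField, connect

    class Person(Document):
        name = StringField()

    class Employee(Person):
        salary = IntField()

    connect(db='example')

    # a normal call merges the class's initial _types filter into the query
    Employee.objects(salary__gte=50000)

    # class_check=False leaves the initial filter out, so any document in
    # the shared collection can match
    Employee.objects(class_check=False, salary__gte=50000)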
@@ -440,17 +458,17 @@ class QuerySet(object):
        drop_dups = self._document._meta.get('index_drop_dups', False)
        index_opts = self._document._meta.get('index_options', {})

+        # Ensure indexes created by uniqueness constraints
+        for index in self._document._meta['unique_indexes']:
+            self._collection.ensure_index(index, unique=True,
+                background=background, drop_dups=drop_dups, **index_opts)
+
        # Ensure document-defined indexes are created
        if self._document._meta['indexes']:
            for key_or_list in self._document._meta['indexes']:
                self._collection.ensure_index(key_or_list,
                    background=background, **index_opts)

-        # Ensure indexes created by uniqueness constraints
-        for index in self._document._meta['unique_indexes']:
-            self._collection.ensure_index(index, unique=True,
-                background=background, drop_dups=drop_dups, **index_opts)
-
        # If _types is being used (for polymorphism), it needs an index
        if '_types' in self._query:
            self._collection.ensure_index('_types',
@@ -474,7 +492,7 @@ class QuerySet(object):
            }
            if self._loaded_fields:
                cursor_args['fields'] = self._loaded_fields.as_dict()
            self._cursor_obj = self._collection.find(self._query,
                                                     **cursor_args)
            # Apply where clauses to cursor
            if self._where_clause:
@@ -504,6 +522,15 @@ class QuerySet(object):
        fields = []
        field = None
        for field_name in parts:
+            # Handle ListField indexing:
+            if field_name.isdigit():
+                try:
+                    field = field.field
+                except AttributeError, err:
+                    raise InvalidQueryError(
+                        "Can't use index on unsubscriptable field (%s)" % err)
+                fields.append(field_name)
+                continue
            if field is None:
                # Look up first field from the document
                if field_name == 'pk':
@@ -528,14 +555,14 @@ class QuerySet(object):
        return '.'.join(parts)

    @classmethod
-    def _transform_query(cls, _doc_cls=None, **query):
+    def _transform_query(cls, _doc_cls=None, _field_operation=False, **query):
        """Transform a query from Django-style format to Mongo format.
        """
        operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
                     'all', 'size', 'exists', 'not']
-        geo_operators = ['within_distance', 'within_box', 'near']
+        geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'near', 'near_sphere']
        match_operators = ['contains', 'icontains', 'startswith',
                           'istartswith', 'endswith', 'iendswith',
                           'exact', 'iexact']

        mongo_query = {}
@@ -577,8 +604,12 @@ class QuerySet(object):
            if op in geo_operators:
                if op == "within_distance":
                    value = {'$within': {'$center': value}}
+                elif op == "within_spherical_distance":
+                    value = {'$within': {'$centerSphere': value}}
                elif op == "near":
                    value = {'$near': value}
+                elif op == "near_sphere":
+                    value = {'$nearSphere': value}
                elif op == 'within_box':
                    value = {'$within': {'$box': value}}
                else:
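
The spherical variants map onto MongoDB's $nearSphere and $centerSphere; a short sketch (the Place document is illustrative, and the spherical radius is expressed in radians):

    from mongoengine import Document, StringField, GeoPointField, connect

    class Place(Document):
        name = StringField()
        location = GeoPointField()

    connect(db='example')

    # nearest places to a [lon, lat] point, using spherical geometry
    Place.objects(location__near_sphere=[-0.1257, 51.5085])

    # places within roughly 10km of the point (radius in radians: km / 6371)
    Place.objects(location__within_spherical_distance=[[-0.1257, 51.5085], 10.0 / 6371.0])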
@@ -620,9 +651,9 @@ class QuerySet(object):
            raise self._document.DoesNotExist("%s matching query does not exist."
                                              % self._document._class_name)

-    def get_or_create(self, *q_objs, **query):
+    def get_or_create(self, write_options=None, *q_objs, **query):
        """Retrieve unique object or create, if it doesn't exist. Returns a tuple of
        ``(object, created)``, where ``object`` is the retrieved or created object
        and ``created`` is a boolean specifying whether a new object was created. Raises
        :class:`~mongoengine.queryset.MultipleObjectsReturned` or
        `DocumentName.MultipleObjectsReturned` if multiple results are found.
@@ -630,6 +661,10 @@ class QuerySet(object):
        dictionary of default values for the new document may be provided as a
        keyword argument called :attr:`defaults`.

+        :param write_options: optional extra keyword arguments used if we
+            have to create a new document.
+            Passes any write_options onto :meth:`~mongoengine.document.Document.save`
+
        .. versionadded:: 0.3
        """
        defaults = query.get('defaults', {})
@@ -641,7 +676,7 @@ class QuerySet(object):
        if count == 0:
            query.update(defaults)
            doc = self._document(**query)
-            doc.save()
+            doc.save(write_options=write_options)
            return doc, True
        elif count == 1:
            return self.first(), False
@@ -725,7 +760,7 @@ class QuerySet(object):
    def __len__(self):
        return self.count()

-    def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None,
+    def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
                   scope=None, keep_temp=False):
        """Perform a map/reduce query using the current query spec
        and ordering. While ``map_reduce`` respects ``QuerySet`` chaining,
@@ -739,26 +774,26 @@ class QuerySet(object):
        :param map_f: map function, as :class:`~pymongo.code.Code` or string
        :param reduce_f: reduce function, as
                         :class:`~pymongo.code.Code` or string
+        :param output: output collection name
        :param finalize_f: finalize function, an optional function that
                           performs any post-reduction processing.
        :param scope: values to insert into map/reduce global scope. Optional.
        :param limit: number of objects from current query to provide
                      to map/reduce method
-        :param keep_temp: keep temporary table (boolean, default ``True``)

        Returns an iterator yielding
        :class:`~mongoengine.document.MapReduceDocument`.

-        .. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo
+        .. note:: Map/Reduce changed in server version **>= 1.7.4**. The PyMongo
           :meth:`~pymongo.collection.Collection.map_reduce` helper requires
-           PyMongo version **>= 1.2**.
+           PyMongo version **>= 1.11**.

        .. versionadded:: 0.3
        """
        from document import MapReduceDocument

        if not hasattr(self._collection, "map_reduce"):
-            raise NotImplementedError("Requires MongoDB >= 1.1.1")
+            raise NotImplementedError("Requires MongoDB >= 1.7.1")

        map_f_scope = {}
        if isinstance(map_f, pymongo.code.Code):
@@ -789,8 +824,7 @@ class QuerySet(object):
        if limit:
            mr_args['limit'] = limit

-        results = self._collection.map_reduce(map_f, reduce_f, **mr_args)
+        results = self._collection.map_reduce(map_f, reduce_f, output, **mr_args)
        results = results.find()

        if self._ordering:
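
map_reduce now requires an output collection name, matching the MongoDB 1.7.4+ behaviour of no longer creating implicit temporary collections; a hedged sketch reusing the hypothetical BlogPost document from the earlier example (the 'tag_counts' collection name is arbitrary):

    # count posts per tag into a 'tag_counts' collection
    map_f = """
        function() {
            this.tags.forEach(function(tag) { emit(tag, 1); });
        }
    """

    reduce_f = """
        function(key, values) {
            var total = 0;
            for (var i = 0; i < values.length; i++) { total += values[i]; }
            return total;
        }
    """

    for result in BlogPost.objects.map_reduce(map_f, reduce_f, "tag_counts"):
        print result.key, result.value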
@@ -835,7 +869,7 @@ class QuerySet(object):
                self._skip, self._limit = key.start, key.stop
            except IndexError, err:
                # PyMongo raises an error if key.start == key.stop, catch it,
                # bin it, kill it.
                start = key.start or 0
                if start >= 0 and key.stop >= 0 and key.step is None:
                    if start == key.stop:
@@ -868,10 +902,8 @@ class QuerySet(object):

        .. versionadded:: 0.3
        """
-        fields = self._fields_to_dbfields(fields)
-        self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY)
-        return self
+        fields = dict([(f, QueryFieldList.ONLY) for f in fields])
+        return self.fields(**fields)

    def exclude(self, *fields):
        """Opposite to .only(), exclude some document's fields. ::
@@ -880,8 +912,44 @@ class QuerySet(object):

        :param fields: fields to exclude
        """
-        fields = self._fields_to_dbfields(fields)
-        self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE)
+        fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields])
+        return self.fields(**fields)
+
+    def fields(self, **kwargs):
+        """Manipulate how you load this document's fields. Used by `.only()`
+        and `.exclude()` to manipulate which fields to retrieve. Fields also
+        allows for a greater level of control for example:
+
+        Retrieving a Subrange of Array Elements
+        ---------------------------------------
+
+        You can use the $slice operator to retrieve a subrange of elements in
+        an array ::
+
+            post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments
+
+        :param kwargs: A dictionary identifying what to include
+
+        .. versionadded:: 0.5
+        """
+        # Check for an operator and transform to mongo-style if there is
+        operators = ["slice"]
+        cleaned_fields = []
+        for key, value in kwargs.items():
+            parts = key.split('__')
+            op = None
+            if parts[0] in operators:
+                op = parts.pop(0)
+                value = {'$' + op: value}
+            key = '.'.join(parts)
+            cleaned_fields.append((key, value))
+
+        fields = sorted(cleaned_fields, key=operator.itemgetter(1))
+        for value, group in itertools.groupby(fields, lambda x: x[1]):
+            fields = [field for field, value in group]
+            fields = self._fields_to_dbfields(fields)
+            self._loaded_fields += QueryFieldList(fields, value=value)
+
        return self

    def all_fields(self):
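
only() and exclude() are now thin wrappers around fields(), which also accepts projection operators such as $slice; a sketch, assuming a hypothetical BlogPost document with title, tags and comments fields:

    # load only the first five comments of each post ($slice projection)
    posts = BlogPost.objects.fields(slice__comments=5)

    # only()/exclude() route through the same QueryFieldList machinery
    titles = BlogPost.objects.only('title')
    no_tags = BlogPost.objects.exclude('tags')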
@@ -917,6 +985,10 @@ class QuerySet(object):
            if key[0] in ('-', '+'):
                key = key[1:]
            key = key.replace('__', '.')
+            try:
+                key = QuerySet._translate_field_name(self._document, key)
+            except:
+                pass
            key_list.append((key, direction))

        self._ordering = key_list
@@ -1007,10 +1079,17 @@ class QuerySet(object):
        if _doc_cls:
            # Switch field names to proper names [set in Field(name='foo')]
            fields = QuerySet._lookup_field(_doc_cls, parts)
-            parts = [field.db_field for field in fields]
+            parts = []
+
+            for field in fields:
+                if isinstance(field, str):
+                    parts.append(field)
+                else:
+                    parts.append(field.db_field)

            # Convert value to proper value
            field = fields[-1]

            if op in (None, 'set', 'push', 'pull', 'addToSet'):
                value = field.prepare_query_value(op, value)
            elif op in ('pushAll', 'pullAll'):
@@ -1029,22 +1108,27 @@ class QuerySet(object):
        return mongo_update

-    def update(self, safe_update=True, upsert=False, **update):
+    def update(self, safe_update=True, upsert=False, write_options=None, **update):
        """Perform an atomic update on the fields matched by the query. When
        ``safe_update`` is used, the number of affected documents is returned.

-        :param safe: check if the operation succeeded before returning
-        :param update: Django-style update keyword arguments
+        :param safe_update: check if the operation succeeded before returning
+        :param upsert: Any existing document with that "_id" is overwritten.
+        :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`

        .. versionadded:: 0.2
        """
        if pymongo.version < '1.1.1':
            raise OperationError('update() method requires PyMongo 1.1.1+')

+        if not write_options:
+            write_options = {}
+
        update = QuerySet._transform_update(self._document, **update)
        try:
            ret = self._collection.update(self._query, update, multi=True,
-                                          upsert=upsert, safe=safe_update)
+                                          upsert=upsert, safe=safe_update,
+                                          **write_options)
            if ret is not None and 'n' in ret:
                return ret['n']
        except pymongo.errors.OperationFailure, err:
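
update() (and update_one() below) accept the same write_options pass-through as save(); a brief sketch against the hypothetical BlogPost document used earlier:

    # atomic multi-document update, fsynced before returning
    BlogPost.objects(published=False).update(set__published=True,
                                             write_options={'fsync': True})

    # single-document variant with a replica acknowledgement requirement
    BlogPost.objects(title='Draft').update_one(set__title='Draft (reviewed)',
                                               write_options={'w': 2})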
@@ -1053,22 +1137,27 @@ class QuerySet(object):
                raise OperationError(message)
            raise OperationError(u'Update failed (%s)' % unicode(err))

-    def update_one(self, safe_update=True, upsert=False, **update):
+    def update_one(self, safe_update=True, upsert=False, write_options=None, **update):
        """Perform an atomic update on first field matched by the query. When
        ``safe_update`` is used, the number of affected documents is returned.

-        :param safe: check if the operation succeeded before returning
+        :param safe_update: check if the operation succeeded before returning
+        :param upsert: Any existing document with that "_id" is overwritten.
+        :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
        :param update: Django-style update keyword arguments

        .. versionadded:: 0.2
        """
+        if not write_options:
+            write_options = {}
+
        update = QuerySet._transform_update(self._document, **update)
        try:
            # Explicitly provide 'multi=False' to newer versions of PyMongo
            # as the default may change to 'True'
            if pymongo.version >= '1.1.1':
                ret = self._collection.update(self._query, update, multi=False,
-                                              upsert=upsert, safe=safe_update)
+                                              upsert=upsert, safe=safe_update,
+                                              **write_options)
            else:
                # Older versions of PyMongo don't support 'multi'
                ret = self._collection.update(self._query, update,
@@ -1082,8 +1171,8 @@ class QuerySet(object):
        return self

    def _sub_js_fields(self, code):
        """When fields are specified with [~fieldname] syntax, where
        *fieldname* is the Python name of a field, *fieldname* will be
        substituted for the MongoDB name of the field (specified using the
        :attr:`name` keyword argument in a field's constructor).
        """
@@ -1106,9 +1195,9 @@ class QuerySet(object):
        options specified as keyword arguments.

        As fields in MongoEngine may use different names in the database (set
        using the :attr:`db_field` keyword argument to a :class:`Field`
        constructor), a mechanism exists for replacing MongoEngine field names
        with the database field names in Javascript code. When accessing a
        field, use square-bracket notation, and prefix the MongoEngine field
        name with a tilde (~).
@@ -1241,8 +1330,11 @@ class QuerySet(object):

class QuerySetManager(object):

-    def __init__(self, manager_func=None):
-        self._manager_func = manager_func
+    get_queryset = None
+
+    def __init__(self, queryset_func=None):
+        if queryset_func:
+            self.get_queryset = queryset_func
        self._collections = {}

    def __get__(self, instance, owner):
@@ -1259,7 +1351,7 @@ class QuerySetManager(object):
        # Create collection as a capped collection if specified
        if owner._meta['max_size'] or owner._meta['max_documents']:
            # Get max document limit and max byte size from meta
            max_size = owner._meta['max_size'] or 10000000  # 10MB default
            max_documents = owner._meta['max_documents']

            if collection in db.collection_names():
@@ -1286,11 +1378,11 @@ class QuerySetManager(object):
        # owner is the document that contains the QuerySetManager
        queryset_class = owner._meta['queryset_class'] or QuerySet
        queryset = queryset_class(owner, self._collections[(db, collection)])
-        if self._manager_func:
-            if self._manager_func.func_code.co_argcount == 1:
-                queryset = self._manager_func(queryset)
+        if self.get_queryset:
+            if self.get_queryset.func_code.co_argcount == 1:
+                queryset = self.get_queryset(queryset)
            else:
-                queryset = self._manager_func(owner, queryset)
+                queryset = self.get_queryset(owner, queryset)
        return queryset
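
Because the manager function is now stored on the instance as get_queryset, the familiar queryset_manager decorator keeps working unchanged; a sketch (the Post document and its fields are made up):

    from mongoengine import Document, StringField, BooleanField, queryset_manager

    class Post(Document):
        title = StringField()
        deleted = BooleanField(default=False)

        # the decorator wraps the function in a QuerySetManager; the function
        # becomes that manager's get_queryset hook
        @queryset_manager
        def live_posts(doc_cls, queryset):
            return queryset.filter(deleted=False)

    # Post.live_posts yields a QuerySet pre-filtered to deleted=False
    # (accessing it requires an active connection)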

View File

@@ -1,13 +1,25 @@
import unittest
from datetime import datetime
import pymongo
+import pickle

from mongoengine import *
+from mongoengine.base import BaseField
from mongoengine.connection import _get_db


+class PickleEmbedded(EmbeddedDocument):
+    date = DateTimeField(default=datetime.now)
+
+
+class PickleTest(Document):
+    number = IntField()
+    string = StringField()
+    embedded = EmbeddedDocumentField(PickleEmbedded)
+    lists = ListField(StringField())
+
+
class DocumentTest(unittest.TestCase):

    def setUp(self):
        connect(db='mongoenginetest')
        self.db = _get_db()
@@ -17,6 +29,9 @@ class DocumentTest(unittest.TestCase):
            age = IntField()
        self.Person = Person

+    def tearDown(self):
+        self.Person.drop_collection()
+
    def test_drop_collection(self):
        """Ensure that the collection may be dropped from the database.
        """
@@ -38,7 +53,7 @@ class DocumentTest(unittest.TestCase):
            name = name_field
            age = age_field
            non_field = True

        self.assertEqual(Person._fields['name'], name_field)
        self.assertEqual(Person._fields['age'], age_field)
        self.assertFalse('non_field' in Person._fields)
@@ -60,7 +75,7 @@ class DocumentTest(unittest.TestCase):

        mammal_superclasses = {'Animal': Animal}
        self.assertEqual(Mammal._superclasses, mammal_superclasses)

        dog_superclasses = {
            'Animal': Animal,
            'Animal.Mammal': Mammal,
@@ -68,7 +83,7 @@ class DocumentTest(unittest.TestCase):
        self.assertEqual(Dog._superclasses, dog_superclasses)

    def test_get_subclasses(self):
        """Ensure that the correct list of subclasses is retrieved by the
        _get_subclasses method.
        """
        class Animal(Document): pass
@@ -78,15 +93,15 @@ class DocumentTest(unittest.TestCase):
        class Dog(Mammal): pass

        mammal_subclasses = {
            'Animal.Mammal.Dog': Dog,
            'Animal.Mammal.Human': Human
        }
        self.assertEqual(Mammal._get_subclasses(), mammal_subclasses)

        animal_subclasses = {
            'Animal.Fish': Fish,
            'Animal.Mammal': Mammal,
            'Animal.Mammal.Dog': Dog,
            'Animal.Mammal.Human': Human
        }
        self.assertEqual(Animal._get_subclasses(), animal_subclasses)
@@ -124,7 +139,7 @@ class DocumentTest(unittest.TestCase):
        self.assertTrue('name' in Employee._fields)
        self.assertTrue('salary' in Employee._fields)
        self.assertEqual(Employee._meta['collection'],
                         self.Person._meta['collection'])

        # Ensure that MRO error is not raised
@@ -146,7 +161,7 @@ class DocumentTest(unittest.TestCase):
            class Dog(Animal):
                pass
        self.assertRaises(ValueError, create_dog_class)

        # Check that _cls etc aren't present on simple documents
        dog = Animal(name='dog')
        dog.save()
@@ -161,7 +176,7 @@ class DocumentTest(unittest.TestCase):
            class Employee(self.Person):
                meta = {'allow_inheritance': False}
        self.assertRaises(ValueError, create_employee_class)

        # Test the same for embedded documents
        class Comment(EmbeddedDocument):
            content = StringField()
@@ -176,6 +191,34 @@ class DocumentTest(unittest.TestCase):
        self.assertFalse('_cls' in comment.to_mongo())
        self.assertFalse('_types' in comment.to_mongo())
+    def test_abstract_documents(self):
+        """Ensure that a document superclass can be marked as abstract
+        thereby not using it as the name for the collection."""
+
+        class Animal(Document):
+            name = StringField()
+            meta = {'abstract': True}
+
+        class Fish(Animal): pass
+        class Guppy(Fish): pass
+
+        class Mammal(Animal):
+            meta = {'abstract': True}
+        class Human(Mammal): pass
+
+        self.assertFalse('collection' in Animal._meta)
+        self.assertFalse('collection' in Mammal._meta)
+
+        self.assertEqual(Fish._meta['collection'], 'fish')
+        self.assertEqual(Guppy._meta['collection'], 'fish')
+        self.assertEqual(Human._meta['collection'], 'human')
+
+        def create_bad_abstract():
+            class EvilHuman(Human):
+                evil = BooleanField(default=True)
+                meta = {'abstract': True}
+        self.assertRaises(ValueError, create_bad_abstract)
+
    def test_collection_name(self):
        """Ensure that a collection with a specified name may be used.
        """
@@ -186,7 +229,7 @@ class DocumentTest(unittest.TestCase):
        class Person(Document):
            name = StringField()
            meta = {'collection': collection}

        user = Person(name="Test User")
        user.save()
        self.assertTrue(collection in self.db.collection_names())
@@ -200,6 +243,22 @@ class DocumentTest(unittest.TestCase):
        Person.drop_collection()
        self.assertFalse(collection in self.db.collection_names())

+    def test_collection_name_and_primary(self):
+        """Ensure that a collection with a specified name may be used.
+        """
+        class Person(Document):
+            name = StringField(primary_key=True)
+            meta = {'collection': 'app'}
+
+        user = Person(name="Test User")
+        user.save()
+
+        user_obj = Person.objects[0]
+        self.assertEqual(user_obj.name, "Test User")
+
+        Person.drop_collection()
+
    def test_inherited_collections(self):
        """Ensure that subclassed documents don't override parents' collections.
        """
@@ -280,7 +339,7 @@ class DocumentTest(unittest.TestCase):
            tags = ListField(StringField())
            meta = {
                'indexes': [
                    '-date',
                    'tags',
                    ('category', '-date')
                ],
@@ -296,12 +355,12 @@ class DocumentTest(unittest.TestCase):
        list(BlogPost.objects)
        info = BlogPost.objects._collection.index_information()
        info = [value['key'] for key, value in info.iteritems()]
        self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)]
                        in info)
        self.assertTrue([('_types', 1), ('addDate', -1)] in info)
        # tags is a list field so it shouldn't have _types in the index
        self.assertTrue([('tags', 1)] in info)

        class ExtendedBlogPost(BlogPost):
            title = StringField()
            meta = {'indexes': ['title']}
@@ -311,7 +370,7 @@ class DocumentTest(unittest.TestCase):
        list(ExtendedBlogPost.objects)
        info = ExtendedBlogPost.objects._collection.index_information()
        info = [value['key'] for key, value in info.iteritems()]
        self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)]
                        in info)
        self.assertTrue([('_types', 1), ('addDate', -1)] in info)
        self.assertTrue([('_types', 1), ('title', 1)] in info)
@@ -334,6 +393,10 @@ class DocumentTest(unittest.TestCase):
        post2 = BlogPost(title='test2', slug='test')
        self.assertRaises(OperationError, post2.save)

+    def test_unique_with(self):
+        """Ensure that unique_with constraints are applied to fields.
+        """
        class Date(EmbeddedDocument):
            year = IntField(db_field='yr')

@@ -357,6 +420,108 @@ class DocumentTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_unique_embedded_document(self):
"""Ensure that uniqueness constraints are applied to fields on embedded documents.
"""
class SubDocument(EmbeddedDocument):
year = IntField(db_field='yr')
slug = StringField(unique=True)
class BlogPost(Document):
title = StringField()
sub = EmbeddedDocumentField(SubDocument)
BlogPost.drop_collection()
post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug'))
post2.save()
# Now there will be two docs with the same sub.slug
post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test'))
self.assertRaises(OperationError, post3.save)
BlogPost.drop_collection()
def test_unique_with_embedded_document_and_embedded_unique(self):
"""Ensure that uniqueness constraints are applied to fields on
embedded documents. And work with unique_with as well.
"""
class SubDocument(EmbeddedDocument):
year = IntField(db_field='yr')
slug = StringField(unique=True)
class BlogPost(Document):
title = StringField(unique_with='sub.year')
sub = EmbeddedDocumentField(SubDocument)
BlogPost.drop_collection()
post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug'))
post2.save()
# Now there will be two docs with the same sub.slug
post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test'))
self.assertRaises(OperationError, post3.save)
# Now there will be two docs with the same title and year
post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1'))
self.assertRaises(OperationError, post3.save)
BlogPost.drop_collection()
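For reference, a compact standalone sketch of the two declaration styles exercised above (Issue and Report are illustrative names): a plain unique constraint on an embedded field, and a unique_with that reaches into an embedded document via a dotted path. Duplicates are rejected at save() time with an OperationError, as the assertions above rely on.

from mongoengine import (Document, EmbeddedDocument, EmbeddedDocumentField,
                         StringField, IntField)

class Issue(EmbeddedDocument):
    number = IntField()
    code = StringField(unique=True)        # unique across parent documents

class Report(Document):
    # unique together with the embedded field, addressed by dotted path
    name = StringField(unique_with='issue.number')
    issue = EmbeddedDocumentField(Issue)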
def test_unique_and_indexes(self):
"""Ensure that 'unique' constraints aren't overridden by
meta.indexes.
"""
class Customer(Document):
cust_id = IntField(unique=True, required=True)
meta = {
'indexes': ['cust_id'],
'allow_inheritance': False,
}
Customer.drop_collection()
cust = Customer(cust_id=1)
cust.save()
cust_dupe = Customer(cust_id=1)
try:
cust_dupe.save()
raise AssertionError, "We saved a dupe!"
except OperationError:
pass
Customer.drop_collection()
def test_unique_and_primary(self):
"""If you set a field as primary, then unexpected behaviour can occur.
You won't create a duplicate but you will update an existing document.
"""
class User(Document):
name = StringField(primary_key=True, unique=True)
password = StringField()
User.drop_collection()
user = User(name='huangz', password='secret')
user.save()
user = User(name='huangz', password='secret2')
user.save()
self.assertEqual(User.objects.count(), 1)
self.assertEqual(User.objects.get().password, 'secret2')
User.drop_collection()
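Because the primary key becomes MongoDB's _id, saving a second instance with the same key overwrites the stored document instead of raising a duplicate-key error. A standalone sketch of the same behaviour (Account is an illustrative name; the database is the one used throughout this suite):

from mongoengine import Document, StringField, connect

connect(db='mongoenginetest')

class Account(Document):
    name = StringField(primary_key=True)
    password = StringField()

Account.drop_collection()
Account(name='huangz', password='secret').save()
Account(name='huangz', password='secret2').save()   # same _id: overwrites

assert Account.objects.count() == 1
assert Account.objects.get().password == 'secret2'
# The primary_key field is stored as _id, so the second save() acts as an
# update of the existing document.
assert Account.objects._collection.find_one()['_id'] == 'huangz'
Account.drop_collection()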
def test_custom_id_field(self): def test_custom_id_field(self):
"""Ensure that documents may be created with custom primary keys. """Ensure that documents may be created with custom primary keys.
""" """
@ -380,7 +545,7 @@ class DocumentTest(unittest.TestCase):
class EmailUser(User): class EmailUser(User):
email = StringField() email = StringField()
user = User(username='test', name='test user') user = User(username='test', name='test user')
user.save() user.save()
@ -391,20 +556,20 @@ class DocumentTest(unittest.TestCase):
user_son = User.objects._collection.find_one() user_son = User.objects._collection.find_one()
self.assertEqual(user_son['_id'], 'test') self.assertEqual(user_son['_id'], 'test')
self.assertTrue('username' not in user_son['_id']) self.assertTrue('username' not in user_son['_id'])
User.drop_collection() User.drop_collection()
user = User(pk='mongo', name='mongo user') user = User(pk='mongo', name='mongo user')
user.save() user.save()
user_obj = User.objects.first() user_obj = User.objects.first()
self.assertEqual(user_obj.id, 'mongo') self.assertEqual(user_obj.id, 'mongo')
self.assertEqual(user_obj.pk, 'mongo') self.assertEqual(user_obj.pk, 'mongo')
user_son = User.objects._collection.find_one() user_son = User.objects._collection.find_one()
self.assertEqual(user_son['_id'], 'mongo') self.assertEqual(user_son['_id'], 'mongo')
self.assertTrue('username' not in user_son['_id']) self.assertTrue('username' not in user_son['_id'])
User.drop_collection() User.drop_collection()
def test_creation(self): def test_creation(self):
@ -457,18 +622,18 @@ class DocumentTest(unittest.TestCase):
""" """
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
self.assertTrue('content' in Comment._fields) self.assertTrue('content' in Comment._fields)
self.assertFalse('id' in Comment._fields) self.assertFalse('id' in Comment._fields)
self.assertFalse('collection' in Comment._meta) self.assertFalse('collection' in Comment._meta)
def test_embedded_document_validation(self): def test_embedded_document_validation(self):
"""Ensure that embedded documents may be validated. """Ensure that embedded documents may be validated.
""" """
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
date = DateTimeField() date = DateTimeField()
content = StringField(required=True) content = StringField(required=True)
comment = Comment() comment = Comment()
self.assertRaises(ValidationError, comment.validate) self.assertRaises(ValidationError, comment.validate)
@ -496,7 +661,7 @@ class DocumentTest(unittest.TestCase):
# Test skipping validation on save # Test skipping validation on save
class Recipient(Document): class Recipient(Document):
email = EmailField(required=True) email = EmailField(required=True)
recipient = Recipient(email='root@localhost') recipient = Recipient(email='root@localhost')
self.assertRaises(ValidationError, recipient.save) self.assertRaises(ValidationError, recipient.save)
try: try:
@ -517,19 +682,19 @@ class DocumentTest(unittest.TestCase):
"""Ensure that a document may be saved with a custom _id. """Ensure that a document may be saved with a custom _id.
""" """
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30, person = self.Person(name='Test User', age=30,
id='497ce96f395f2f052a494fd4') id='497ce96f395f2f052a494fd4')
person.save() person.save()
# Ensure that the object is in the database with the correct _id # Ensure that the object is in the database with the correct _id
collection = self.db[self.Person._meta['collection']] collection = self.db[self.Person._meta['collection']]
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_custom_pk(self): def test_save_custom_pk(self):
"""Ensure that a document may be saved with a custom _id using pk alias. """Ensure that a document may be saved with a custom _id using pk alias.
""" """
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30, person = self.Person(name='Test User', age=30,
pk='497ce96f395f2f052a494fd4') pk='497ce96f395f2f052a494fd4')
person.save() person.save()
# Ensure that the object is in the database with the correct _id # Ensure that the object is in the database with the correct _id
@ -565,7 +730,7 @@ class DocumentTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_save_embedded_document(self): def test_save_embedded_document(self):
"""Ensure that a document with an embedded document field may be """Ensure that a document with an embedded document field may be
saved in the database. saved in the database.
""" """
class EmployeeDetails(EmbeddedDocument): class EmployeeDetails(EmbeddedDocument):
@ -588,10 +753,38 @@ class DocumentTest(unittest.TestCase):
# Ensure that the 'details' embedded object saved correctly # Ensure that the 'details' embedded object saved correctly
self.assertEqual(employee_obj['details']['position'], 'Developer') self.assertEqual(employee_obj['details']['position'], 'Developer')
def test_updating_an_embedded_document(self):
"""Ensure that a document with an embedded document field may be
saved in the database.
"""
class EmployeeDetails(EmbeddedDocument):
position = StringField()
class Employee(self.Person):
salary = IntField()
details = EmbeddedDocumentField(EmployeeDetails)
# Create employee object and save it to the database
employee = Employee(name='Test Employee', age=50, salary=20000)
employee.details = EmployeeDetails(position='Developer')
employee.save()
# Test updating an embedded document
promoted_employee = Employee.objects.get(name='Test Employee')
promoted_employee.details.position = 'Senior Developer'
promoted_employee.save()
collection = self.db[self.Person._meta['collection']]
employee_obj = collection.find_one({'name': 'Test Employee'})
self.assertEqual(employee_obj['name'], 'Test Employee')
self.assertEqual(employee_obj['age'], 50)
# Ensure that the 'details' embedded object saved correctly
self.assertEqual(employee_obj['details']['position'], 'Senior Developer')
def test_save_reference(self): def test_save_reference(self):
"""Ensure that a document reference field may be saved in the database. """Ensure that a document reference field may be saved in the database.
""" """
class BlogPost(Document): class BlogPost(Document):
meta = {'collection': 'blogpost_1'} meta = {'collection': 'blogpost_1'}
content = StringField() content = StringField()
@ -610,7 +803,7 @@ class DocumentTest(unittest.TestCase):
post_obj = BlogPost.objects.first() post_obj = BlogPost.objects.first()
# Test laziness # Test laziness
self.assertTrue(isinstance(post_obj._data['author'], self.assertTrue(isinstance(post_obj._data['author'],
pymongo.dbref.DBRef)) pymongo.dbref.DBRef))
self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertTrue(isinstance(post_obj.author, self.Person))
self.assertEqual(post_obj.author.name, 'Test User') self.assertEqual(post_obj.author.name, 'Test User')
@ -725,9 +918,25 @@ class DocumentTest(unittest.TestCase):
self.Person.drop_collection() self.Person.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def test_subclasses_and_unique_keys_works(self):
def tearDown(self): class A(Document):
self.Person.drop_collection() pass
class B(A):
foo = BooleanField(unique=True)
A.drop_collection()
B.drop_collection()
A().save()
A().save()
B(foo=True).save()
self.assertEquals(A.objects.count(), 2)
self.assertEquals(B.objects.count(), 1)
A.drop_collection()
B.drop_collection()
def test_document_hash(self): def test_document_hash(self):
"""Test document in list, dict, set """Test document in list, dict, set
@ -737,7 +946,7 @@ class DocumentTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
pass pass
# Clear old data # Clear old data
User.drop_collection() User.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
@ -774,9 +983,46 @@ class DocumentTest(unittest.TestCase):
# in Set # in Set
all_user_set = set(User.objects.all()) all_user_set = set(User.objects.all())
self.assertTrue(u1 in all_user_set ) self.assertTrue(u1 in all_user_set )
def test_picklable(self):
pickle_doc = PickleTest(number=1, string="OH HAI", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded()
pickle_doc.save()
pickled_doc = pickle.dumps(pickle_doc)
resurrected = pickle.loads(pickled_doc)
self.assertEquals(resurrected, pickle_doc)
resurrected.string = "Working"
resurrected.save()
pickle_doc.reload()
self.assertEquals(resurrected, pickle_doc)
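pickle can only serialise documents whose classes are importable, which is presumably why PickleTest and PickleEmbedded are defined outside this method. A standalone sketch of the round trip, assuming the pickle support exercised above (Note is an illustrative class):

import pickle

from mongoengine import Document, IntField, StringField, connect

connect(db='mongoenginetest')

class Note(Document):          # must live at module scope for pickle
    number = IntField()
    text = StringField()

Note.drop_collection()
note = Note(number=1, text='OH HAI')
note.save()

data = pickle.dumps(note)      # serialise
restored = pickle.loads(data)  # deserialise

assert restored.number == 1 and restored.text == 'OH HAI'
Note.drop_collection()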
def test_write_options(self):
"""Test that passing write_options works"""
self.Person.drop_collection()
write_options = {"fsync": True}
author, created = self.Person.objects.get_or_create(
name='Test User', write_options=write_options)
author.save(write_options=write_options)
self.Person.objects.update(set__name='Ross', write_options=write_options)
author = self.Person.objects.first()
self.assertEquals(author.name, 'Ross')
self.Person.objects.update_one(set__name='Test User', write_options=write_options)
author = self.Person.objects.first()
self.assertEquals(author.name, 'Test User')
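write_options is a dict of driver-level keyword arguments passed along with the write; the test above only exercises fsync, so anything beyond that (for example replica-set acknowledgement options) is an assumption about the pymongo version in use. A standalone sketch (Draft is an illustrative class):

from mongoengine import Document, StringField, connect

connect(db='mongoenginetest')

class Draft(Document):
    text = StringField()

Draft.drop_collection()
fsync_write = {'fsync': True}              # the option exercised above

Draft(text='draft').save(write_options=fsync_write)
Draft.objects.update(set__text='final', write_options=fsync_write)

assert Draft.objects.first().text == 'final'
Draft.drop_collection()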
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
View File
@ -7,6 +7,7 @@ import gridfs
from mongoengine import * from mongoengine import *
from mongoengine.connection import _get_db from mongoengine.connection import _get_db
from mongoengine.base import _document_registry, NotRegistered
class FieldTest(unittest.TestCase): class FieldTest(unittest.TestCase):
@ -45,7 +46,7 @@ class FieldTest(unittest.TestCase):
""" """
class Person(Document): class Person(Document):
name = StringField() name = StringField()
person = Person(name='Test User') person = Person(name='Test User')
self.assertEqual(person.id, None) self.assertEqual(person.id, None)
@ -95,7 +96,7 @@ class FieldTest(unittest.TestCase):
link.url = 'http://www.google.com:8080' link.url = 'http://www.google.com:8080'
link.validate() link.validate()
def test_int_validation(self): def test_int_validation(self):
"""Ensure that invalid values cannot be assigned to int fields. """Ensure that invalid values cannot be assigned to int fields.
""" """
@ -129,12 +130,12 @@ class FieldTest(unittest.TestCase):
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
person.height = 4.0 person.height = 4.0
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
def test_decimal_validation(self): def test_decimal_validation(self):
"""Ensure that invalid values cannot be assigned to decimal fields. """Ensure that invalid values cannot be assigned to decimal fields.
""" """
class Person(Document): class Person(Document):
height = DecimalField(min_value=Decimal('0.1'), height = DecimalField(min_value=Decimal('0.1'),
max_value=Decimal('3.5')) max_value=Decimal('3.5'))
Person.drop_collection() Person.drop_collection()
@ -249,7 +250,7 @@ class FieldTest(unittest.TestCase):
post.save() post.save()
post.reload() post.reload()
self.assertEqual(post.tags, ['fun', 'leisure']) self.assertEqual(post.tags, ['fun', 'leisure'])
comment1 = Comment(content='Good for you', order=1) comment1 = Comment(content='Good for you', order=1)
comment2 = Comment(content='Yay.', order=0) comment2 = Comment(content='Yay.', order=0)
comments = [comment1, comment2] comments = [comment1, comment2]
@ -261,12 +262,14 @@ class FieldTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_dict_validation(self): def test_dict_field(self):
"""Ensure that dict types work as expected. """Ensure that dict types work as expected.
""" """
class BlogPost(Document): class BlogPost(Document):
info = DictField() info = DictField()
BlogPost.drop_collection()
post = BlogPost() post = BlogPost()
post.info = 'my post' post.info = 'my post'
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
@ -281,7 +284,24 @@ class FieldTest(unittest.TestCase):
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
post.info = {'title': 'test'} post.info = {'title': 'test'}
post.validate() post.save()
post = BlogPost()
post.info = {'details': {'test': 'test'}}
post.save()
post = BlogPost()
post.info = {'details': {'test': 3}}
post.save()
self.assertEquals(BlogPost.objects.count(), 3)
self.assertEquals(BlogPost.objects.filter(info__title__exact='test').count(), 1)
self.assertEquals(BlogPost.objects.filter(info__details__test__exact='test').count(), 1)
# Confirm it handles non-string values and non-existent keys
self.assertEquals(BlogPost.objects.filter(info__details__test__exact=5).count(), 0)
self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0)
BlogPost.drop_collection()
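Nested keys inside a DictField are addressed with the usual double-underscore syntax, each segment mapping to a dotted key in the query (info__details__test__exact='test' queries on 'info.details.test'). A standalone sketch (Page is an illustrative class):

from mongoengine import Document, DictField, connect

connect(db='mongoenginetest')

class Page(Document):
    info = DictField()

Page.drop_collection()
Page(info={'details': {'test': 'test'}}).save()

# Double underscores walk into the stored dictionary key by key.
assert Page.objects(info__details__test__exact='test').count() == 1
# Missing keys simply match nothing rather than raising.
assert Page.objects(info__no_such_key__exact='x').count() == 0
Page.drop_collection()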
def test_embedded_document_validation(self): def test_embedded_document_validation(self):
"""Ensure that invalid embedded documents cannot be assigned to """Ensure that invalid embedded documents cannot be assigned to
@ -315,7 +335,7 @@ class FieldTest(unittest.TestCase):
person.validate() person.validate()
def test_embedded_document_inheritance(self): def test_embedded_document_inheritance(self):
"""Ensure that subclasses of embedded documents may be provided to """Ensure that subclasses of embedded documents may be provided to
EmbeddedDocumentFields of the superclass' type. EmbeddedDocumentFields of the superclass' type.
""" """
class User(EmbeddedDocument): class User(EmbeddedDocument):
@ -327,7 +347,7 @@ class FieldTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
author = EmbeddedDocumentField(User) author = EmbeddedDocumentField(User)
post = BlogPost(content='What I did today...') post = BlogPost(content='What I did today...')
post.author = User(name='Test User') post.author = User(name='Test User')
post.author = PowerUser(name='Test User', power=47) post.author = PowerUser(name='Test User', power=47)
@ -370,7 +390,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection() User.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def test_list_item_dereference(self): def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced. """Ensure that DBRef items in ListFields are dereferenced.
""" """
@ -434,7 +454,7 @@ class FieldTest(unittest.TestCase):
class TreeNode(EmbeddedDocument): class TreeNode(EmbeddedDocument):
name = StringField() name = StringField()
children = ListField(EmbeddedDocumentField('self')) children = ListField(EmbeddedDocumentField('self'))
tree = Tree(name="Tree") tree = Tree(name="Tree")
first_child = TreeNode(name="Child 1") first_child = TreeNode(name="Child 1")
@ -442,7 +462,7 @@ class FieldTest(unittest.TestCase):
second_child = TreeNode(name="Child 2") second_child = TreeNode(name="Child 2")
first_child.children.append(second_child) first_child.children.append(second_child)
third_child = TreeNode(name="Child 3") third_child = TreeNode(name="Child 3")
first_child.children.append(third_child) first_child.children.append(third_child)
@ -506,20 +526,20 @@ class FieldTest(unittest.TestCase):
Member.drop_collection() Member.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def test_generic_reference(self): def test_generic_reference(self):
"""Ensure that a GenericReferenceField properly dereferences items. """Ensure that a GenericReferenceField properly dereferences items.
""" """
class Link(Document): class Link(Document):
title = StringField() title = StringField()
meta = {'allow_inheritance': False} meta = {'allow_inheritance': False}
class Post(Document): class Post(Document):
title = StringField() title = StringField()
class Bookmark(Document): class Bookmark(Document):
bookmark_object = GenericReferenceField() bookmark_object = GenericReferenceField()
Link.drop_collection() Link.drop_collection()
Post.drop_collection() Post.drop_collection()
Bookmark.drop_collection() Bookmark.drop_collection()
@ -574,16 +594,49 @@ class FieldTest(unittest.TestCase):
user = User(bookmarks=[post_1, link_1]) user = User(bookmarks=[post_1, link_1])
user.save() user.save()
user = User.objects(bookmarks__all=[post_1, link_1]).first() user = User.objects(bookmarks__all=[post_1, link_1]).first()
self.assertEqual(user.bookmarks[0], post_1) self.assertEqual(user.bookmarks[0], post_1)
self.assertEqual(user.bookmarks[1], link_1) self.assertEqual(user.bookmarks[1], link_1)
Link.drop_collection() Link.drop_collection()
Post.drop_collection() Post.drop_collection()
User.drop_collection() User.drop_collection()
def test_generic_reference_document_not_registered(self):
"""Ensure dereferencing out of the document registry throws a
`NotRegistered` error.
"""
class Link(Document):
title = StringField()
class User(Document):
bookmarks = ListField(GenericReferenceField())
Link.drop_collection()
User.drop_collection()
link_1 = Link(title="Pitchfork")
link_1.save()
user = User(bookmarks=[link_1])
user.save()
# Mimic User and Link definitions being in a different file
# and the Link model not being imported in the User file.
del(_document_registry["Link"])
user = User.objects.first()
try:
user.bookmarks
raise AssertionError, "Link was removed from the registry"
except NotRegistered:
pass
Link.drop_collection()
User.drop_collection()
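The lookup that raises here is the module-level get_document() helper: importing the module that defines a Document subclass registers the class under its name, and GenericReferenceField dereferencing later resolves that name through the registry. A small sketch of the helper itself (Bookmarkable is an illustrative class):

from mongoengine import Document, StringField
from mongoengine.base import NotRegistered, get_document

class Bookmarkable(Document):
    title = StringField()

# Defining (importing) the class registers it under its name...
assert get_document('Bookmarkable') is Bookmarkable

# ...and looking up a name that was never registered raises NotRegistered.
try:
    get_document('NeverDefined')
except NotRegistered:
    pass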
def test_binary_fields(self): def test_binary_fields(self):
"""Ensure that binary fields can be stored and retrieved. """Ensure that binary fields can be stored and retrieved.
""" """
@ -701,6 +754,12 @@ class FieldTest(unittest.TestCase):
self.assertTrue(streamfile == result) self.assertTrue(streamfile == result)
self.assertEquals(result.file.read(), text + more_text) self.assertEquals(result.file.read(), text + more_text)
self.assertEquals(result.file.content_type, content_type) self.assertEquals(result.file.content_type, content_type)
result.file.seek(0)
self.assertEquals(result.file.tell(), 0)
self.assertEquals(result.file.read(len(text)), text)
self.assertEquals(result.file.tell(), len(text))
self.assertEquals(result.file.read(len(more_text)), more_text)
self.assertEquals(result.file.tell(), len(text + more_text))
result.file.delete() result.file.delete()
# Ensure deleted file returns None # Ensure deleted file returns None
@ -721,7 +780,7 @@ class FieldTest(unittest.TestCase):
result = SetFile.objects.first() result = SetFile.objects.first()
self.assertTrue(setfile == result) self.assertTrue(setfile == result)
self.assertEquals(result.file.read(), more_text) self.assertEquals(result.file.read(), more_text)
result.file.delete() result.file.delete()
PutFile.drop_collection() PutFile.drop_collection()
StreamFile.drop_collection() StreamFile.drop_collection()
@ -785,5 +844,66 @@ class FieldTest(unittest.TestCase):
self.assertEqual(d2.data, {}) self.assertEqual(d2.data, {})
self.assertEqual(d2.data2, {}) self.assertEqual(d2.data2, {})
def test_mapfield(self):
"""Ensure that the MapField handles the declared type."""
class Simple(Document):
mapping = MapField(IntField())
Simple.drop_collection()
e = Simple()
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
def create_invalid_class():
class NoDeclaredType(Document):
mapping = MapField()
self.assertRaises(ValidationError, create_invalid_class)
Simple.drop_collection()
def test_complex_mapfield(self):
"""Ensure that the MapField can handle complex declared types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Extensible(Document):
mapping = MapField(EmbeddedDocumentField(SettingBase))
Extensible.drop_collection()
e = Extensible()
e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42)
e.save()
e2 = Extensible.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
def create_invalid_mapping():
e.mapping['someint'] = 123
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Extensible.drop_collection()
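MapField behaves like a DictField whose values are all validated against one declared field type, so arbitrary string keys map to homogeneously typed values. A standalone sketch mirroring the embedded-document case above (Threshold and Config are illustrative names):

from mongoengine import (Document, EmbeddedDocument, EmbeddedDocumentField,
                         IntField, MapField, connect)

connect(db='mongoenginetest')

class Threshold(EmbeddedDocument):
    value = IntField()

class Config(Document):
    limits = MapField(EmbeddedDocumentField(Threshold))

Config.drop_collection()
cfg = Config()
cfg.limits['retries'] = Threshold(value=3)     # keys are free-form strings
cfg.limits['timeout'] = Threshold(value=30)
cfg.save()

reloaded = Config.objects.get(id=cfg.id)
assert reloaded.limits['retries'].value == 3   # values keep their type
Config.drop_collection()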
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
View File
@ -5,8 +5,9 @@ import unittest
import pymongo import pymongo
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, from mongoengine.queryset import (QuerySet, QuerySetManager,
DoesNotExist, QueryFieldList) MultipleObjectsReturned, DoesNotExist,
QueryFieldList)
from mongoengine import * from mongoengine import *
@ -105,6 +106,10 @@ class QuerySetTest(unittest.TestCase):
people = list(self.Person.objects[1:1]) people = list(self.Person.objects[1:1])
self.assertEqual(len(people), 0) self.assertEqual(len(people), 0)
# Test slice out of range
people = list(self.Person.objects[80000:80001])
self.assertEqual(len(people), 0)
def test_find_one(self): def test_find_one(self):
"""Ensure that a query using find_one returns a valid result. """Ensure that a query using find_one returns a valid result.
""" """
@ -162,7 +167,7 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects.get(age__lt=30) person = self.Person.objects.get(age__lt=30)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
def test_find_array_position(self): def test_find_array_position(self):
"""Ensure that query by array position works. """Ensure that query by array position works.
""" """
@ -177,7 +182,7 @@ class QuerySetTest(unittest.TestCase):
posts = ListField(EmbeddedDocumentField(Post)) posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection() Blog.drop_collection()
Blog.objects.create(tags=['a', 'b']) Blog.objects.create(tags=['a', 'b'])
self.assertEqual(len(Blog.objects(tags__0='a')), 1) self.assertEqual(len(Blog.objects(tags__0='a')), 1)
self.assertEqual(len(Blog.objects(tags__0='b')), 0) self.assertEqual(len(Blog.objects(tags__0='b')), 0)
@ -207,6 +212,55 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection() Blog.drop_collection()
def test_update_array_position(self):
"""Ensure that updating by array position works.
Check update() and update_one() can take syntax like:
set__posts__1__comments__1__name="testc"
Check that it only works for ListFields.
"""
class Comment(EmbeddedDocument):
name = StringField()
class Post(EmbeddedDocument):
comments = ListField(EmbeddedDocumentField(Comment))
class Blog(Document):
tags = ListField(StringField())
posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection()
comment1 = Comment(name='testa')
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
# Update all of the first comments of second posts of all blogs
blog = Blog.objects().update(set__posts__1__comments__0__name="testc")
testc_blogs = Blog.objects(posts__1__comments__0__name="testc")
self.assertEqual(len(testc_blogs), 2)
Blog.drop_collection()
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
# Update only the first blog returned by the query
blog = Blog.objects().update_one(
set__posts__1__comments__1__name="testc")
testc_blogs = Blog.objects(posts__1__comments__1__name="testc")
self.assertEqual(len(testc_blogs), 1)
# Check that using this indexing syntax on a non-list fails
def non_list_indexing():
Blog.objects().update(set__posts__1__comments__0__name__1="asdf")
self.assertRaises(InvalidQueryError, non_list_indexing)
Blog.drop_collection()
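The numeric segments in those update keys are list indexes, so set__posts__1__comments__0__name='testc' issues an atomic $set on the dotted path 'posts.1.comments.0.name'; using the syntax on anything that is not a ListField fails with InvalidQueryError, as checked above. A standalone sketch (Journal, Entry and Reply are illustrative names):

from mongoengine import (Document, EmbeddedDocument, EmbeddedDocumentField,
                         ListField, StringField, connect)

connect(db='mongoenginetest')

class Reply(EmbeddedDocument):
    name = StringField()

class Entry(EmbeddedDocument):
    replies = ListField(EmbeddedDocumentField(Reply))

class Journal(Document):
    entries = ListField(EmbeddedDocumentField(Entry))

Journal.drop_collection()
Journal(entries=[Entry(replies=[Reply(name='a'), Reply(name='b')])]).save()

# Numeric segments index into the lists: this is a $set on
# 'entries.0.replies.1.name'.
Journal.objects.update(set__entries__0__replies__1__name='edited')
assert Journal.objects.first().entries[0].replies[1].name == 'edited'
Journal.drop_collection()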
def test_get_or_create(self): def test_get_or_create(self):
"""Ensure that ``get_or_create`` returns one result or creates a new """Ensure that ``get_or_create`` returns one result or creates a new
document. document.
@ -226,16 +280,16 @@ class QuerySetTest(unittest.TestCase):
person, created = self.Person.objects.get_or_create(age=30) person, created = self.Person.objects.get_or_create(age=30)
self.assertEqual(person.name, "User B") self.assertEqual(person.name, "User B")
self.assertEqual(created, False) self.assertEqual(created, False)
person, created = self.Person.objects.get_or_create(age__lt=30) person, created = self.Person.objects.get_or_create(age__lt=30)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
self.assertEqual(created, False) self.assertEqual(created, False)
# Try retrieving when no objects exists - new doc should be created # Try retrieving when no objects exists - new doc should be created
kwargs = dict(age=50, defaults={'name': 'User C'}) kwargs = dict(age=50, defaults={'name': 'User C'})
person, created = self.Person.objects.get_or_create(**kwargs) person, created = self.Person.objects.get_or_create(**kwargs)
self.assertEqual(created, True) self.assertEqual(created, True)
person = self.Person.objects.get(age=50) person = self.Person.objects.get(age=50)
self.assertEqual(person.name, "User C") self.assertEqual(person.name, "User C")
@ -328,7 +382,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(obj, person) self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first()
self.assertEqual(obj, None) self.assertEqual(obj, None)
# Test unsafe expressions # Test unsafe expressions
person = self.Person(name='Guido van Rossum [.\'Geek\']') person = self.Person(name='Guido van Rossum [.\'Geek\']')
person.save() person.save()
@ -593,6 +647,81 @@ class QuerySetTest(unittest.TestCase):
Email.drop_collection() Email.drop_collection()
def test_slicing_fields(self):
"""Ensure that query slicing an array works.
"""
class Numbers(Document):
n = ListField(IntField())
Numbers.drop_collection()
numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
numbers.save()
# first three
numbers = Numbers.objects.fields(slice__n=3).get()
self.assertEquals(numbers.n, [0, 1, 2])
# last three
numbers = Numbers.objects.fields(slice__n=-3).get()
self.assertEquals(numbers.n, [-3, -2, -1])
# skip 2, limit 3
numbers = Numbers.objects.fields(slice__n=[2, 3]).get()
self.assertEquals(numbers.n, [2, 3, 4])
# skip to fifth from last, limit 4
numbers = Numbers.objects.fields(slice__n=[-5, 4]).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2])
# skip to fifth from last, limit 10
numbers = Numbers.objects.fields(slice__n=[-5, 10]).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2, -1])
# skip to fifth from last, limit 10 dict method
numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2, -1])
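fields(slice__n=...) is the ODM spelling of MongoDB's $slice projection: a positive integer takes that many elements from the front, a negative one from the back, and a [skip, limit] pair starts at skip; the raw dict form at the end is equivalent. A standalone sketch (Scores is an illustrative class):

from mongoengine import Document, IntField, ListField, connect

connect(db='mongoenginetest')

class Scores(Document):
    values = ListField(IntField())

Scores.drop_collection()
Scores(values=range(10)).save()

assert Scores.objects.fields(slice__values=3).get().values == [0, 1, 2]
assert Scores.objects.fields(slice__values=-2).get().values == [8, 9]
assert Scores.objects.fields(values={'$slice': [5, 2]}).get().values == [5, 6]
Scores.drop_collection()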
def test_slicing_nested_fields(self):
"""Ensure that query slicing an embedded array works.
"""
class EmbeddedNumber(EmbeddedDocument):
n = ListField(IntField())
class Numbers(Document):
embedded = EmbeddedDocumentField(EmbeddedNumber)
Numbers.drop_collection()
numbers = Numbers()
numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
numbers.save()
# first three
numbers = Numbers.objects.fields(slice__embedded__n=3).get()
self.assertEquals(numbers.embedded.n, [0, 1, 2])
# last three
numbers = Numbers.objects.fields(slice__embedded__n=-3).get()
self.assertEquals(numbers.embedded.n, [-3, -2, -1])
# skip 2, limit 3
numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get()
self.assertEquals(numbers.embedded.n, [2, 3, 4])
# skip to fifth from last, limit 4
numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2])
# skip to fifth from last, limit 10
numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1])
# skip to fifth from last, limit 10 dict method
numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1])
def test_find_embedded(self): def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query. """Ensure that an embedded document is properly returned from a query.
""" """
@ -674,7 +803,7 @@ class QuerySetTest(unittest.TestCase):
posts = [post.id for post in q] posts = [post.id for post in q]
published_posts = (post1, post2, post3, post5, post6) published_posts = (post1, post2, post3, post5, post6)
self.assertTrue(all(obj.id in posts for obj in published_posts)) self.assertTrue(all(obj.id in posts for obj in published_posts))
# Check Q object combination # Check Q object combination
date = datetime(2010, 1, 10) date = datetime(2010, 1, 10)
@ -714,7 +843,7 @@ class QuerySetTest(unittest.TestCase):
obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first()
self.assertEqual(obj, person) self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first()
self.assertEqual(obj, None) self.assertEqual(obj, None)
@ -786,7 +915,7 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
name = StringField(db_field='doc-name') name = StringField(db_field='doc-name')
comments = ListField(EmbeddedDocumentField(Comment), comments = ListField(EmbeddedDocumentField(Comment),
db_field='cmnts') db_field='cmnts')
BlogPost.drop_collection() BlogPost.drop_collection()
@ -958,7 +1087,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects.update_one(unset__hits=1) BlogPost.objects.update_one(unset__hits=1)
post.reload() post.reload()
self.assertEqual(post.hits, None) self.assertEqual(post.hits, None)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_update_pull(self): def test_update_pull(self):
@ -1027,7 +1156,7 @@ class QuerySetTest(unittest.TestCase):
""" """
# run a map/reduce operation spanning all posts # run a map/reduce operation spanning all posts
results = BlogPost.objects.map_reduce(map_f, reduce_f) results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results) results = list(results)
self.assertEqual(len(results), 4) self.assertEqual(len(results), 4)
@ -1038,7 +1167,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(film.value, 3) self.assertEqual(film.value, 3)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_map_reduce_with_custom_object_ids(self): def test_map_reduce_with_custom_object_ids(self):
"""Ensure that QuerySet.map_reduce works properly with custom """Ensure that QuerySet.map_reduce works properly with custom
primary keys. primary keys.
@ -1047,24 +1176,24 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
title = StringField(primary_key=True) title = StringField(primary_key=True)
tags = ListField(StringField()) tags = ListField(StringField())
post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"]) post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"])
post2 = BlogPost(title="Post #2", tags=["django", "mongodb"]) post2 = BlogPost(title="Post #2", tags=["django", "mongodb"])
post3 = BlogPost(title="Post #3", tags=["hitchcock films"]) post3 = BlogPost(title="Post #3", tags=["hitchcock films"])
post1.save() post1.save()
post2.save() post2.save()
post3.save() post3.save()
self.assertEqual(BlogPost._fields['title'].db_field, '_id') self.assertEqual(BlogPost._fields['title'].db_field, '_id')
self.assertEqual(BlogPost._meta['id_field'], 'title') self.assertEqual(BlogPost._meta['id_field'], 'title')
map_f = """ map_f = """
function() { function() {
emit(this._id, 1); emit(this._id, 1);
} }
""" """
# reduce to a list of tag ids and counts # reduce to a list of tag ids and counts
reduce_f = """ reduce_f = """
function(key, values) { function(key, values) {
@ -1075,10 +1204,10 @@ class QuerySetTest(unittest.TestCase):
return total; return total;
} }
""" """
results = BlogPost.objects.map_reduce(map_f, reduce_f) results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results) results = list(results)
self.assertEqual(results[0].object, post1) self.assertEqual(results[0].object, post1)
self.assertEqual(results[1].object, post2) self.assertEqual(results[1].object, post2)
self.assertEqual(results[2].object, post3) self.assertEqual(results[2].object, post3)
@ -1168,7 +1297,7 @@ class QuerySetTest(unittest.TestCase):
finalize_f = """ finalize_f = """
function(key, value) { function(key, value) {
// f(sec_since_epoch,y,z) = // f(sec_since_epoch,y,z) =
// log10(z) + ((y*sec_since_epoch) / 45000) // log10(z) + ((y*sec_since_epoch) / 45000)
z_10 = Math.log(value.z) / Math.log(10); z_10 = Math.log(value.z) / Math.log(10);
weight = z_10 + ((value.y * value.t_s) / 45000); weight = z_10 + ((value.y * value.t_s) / 45000);
@ -1187,6 +1316,7 @@ class QuerySetTest(unittest.TestCase):
results = Link.objects.order_by("-value") results = Link.objects.order_by("-value")
results = results.map_reduce(map_f, results = results.map_reduce(map_f,
reduce_f, reduce_f,
"myresults",
finalize_f=finalize_f, finalize_f=finalize_f,
scope=scope) scope=scope)
results = list(results) results = list(results)
@ -1289,6 +1419,7 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
tags = ListField(StringField()) tags = ListField(StringField())
deleted = BooleanField(default=False) deleted = BooleanField(default=False)
date = DateTimeField(default=datetime.now)
@queryset_manager @queryset_manager
def objects(doc_cls, queryset): def objects(doc_cls, queryset):
@ -1296,7 +1427,7 @@ class QuerySetTest(unittest.TestCase):
@queryset_manager @queryset_manager
def music_posts(doc_cls, queryset): def music_posts(doc_cls, queryset):
return queryset(tags='music', deleted=False) return queryset(tags='music', deleted=False).order_by('-date')
BlogPost.drop_collection() BlogPost.drop_collection()
@ -1312,7 +1443,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual([p.id for p in BlogPost.objects], self.assertEqual([p.id for p in BlogPost.objects],
[post1.id, post2.id, post3.id]) [post1.id, post2.id, post3.id])
self.assertEqual([p.id for p in BlogPost.music_posts], self.assertEqual([p.id for p in BlogPost.music_posts],
[post1.id, post2.id]) [post2.id, post1.id])
BlogPost.drop_collection() BlogPost.drop_collection()
@ -1452,10 +1583,12 @@ class QuerySetTest(unittest.TestCase):
class Test(Document): class Test(Document):
testdict = DictField() testdict = DictField()
Test.drop_collection()
t = Test(testdict={'f': 'Value'}) t = Test(testdict={'f': 'Value'})
t.save() t.save()
self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 0) self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1)
self.assertEqual(len(Test.objects(testdict__f='Value')), 1) self.assertEqual(len(Test.objects(testdict__f='Value')), 1)
Test.drop_collection() Test.drop_collection()
@ -1514,12 +1647,12 @@ class QuerySetTest(unittest.TestCase):
title = StringField() title = StringField()
date = DateTimeField() date = DateTimeField()
location = GeoPointField() location = GeoPointField()
def __unicode__(self): def __unicode__(self):
return self.title return self.title
Event.drop_collection() Event.drop_collection()
event1 = Event(title="Coltrane Motion @ Double Door", event1 = Event(title="Coltrane Motion @ Double Door",
date=datetime.now() - timedelta(days=1), date=datetime.now() - timedelta(days=1),
location=[41.909889, -87.677137]) location=[41.909889, -87.677137])
@ -1529,7 +1662,7 @@ class QuerySetTest(unittest.TestCase):
event3 = Event(title="Coltrane Motion @ Empty Bottle", event3 = Event(title="Coltrane Motion @ Empty Bottle",
date=datetime.now(), date=datetime.now(),
location=[41.900474, -87.686638]) location=[41.900474, -87.686638])
event1.save() event1.save()
event2.save() event2.save()
event3.save() event3.save()
@ -1541,7 +1674,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(events.count(), 3) self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event1, event3, event2]) self.assertEqual(list(events), [event1, event3, event2])
# find events within 5 miles of pitchfork office, chicago # find events within 5 degrees of pitchfork office, chicago
point_and_distance = [[41.9120459, -87.67892], 5] point_and_distance = [[41.9120459, -87.67892], 5]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 2) self.assertEqual(events.count(), 2)
@ -1549,24 +1682,24 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(event2 not in events) self.assertTrue(event2 not in events)
self.assertTrue(event1 in events) self.assertTrue(event1 in events)
self.assertTrue(event3 in events) self.assertTrue(event3 in events)
# ensure ordering is respected by "near" # ensure ordering is respected by "near"
events = Event.objects(location__near=[41.9120459, -87.67892]) events = Event.objects(location__near=[41.9120459, -87.67892])
events = events.order_by("-date") events = events.order_by("-date")
self.assertEqual(events.count(), 3) self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event3, event1, event2]) self.assertEqual(list(events), [event3, event1, event2])
# find events around san francisco # find events within 10 degrees of san francisco
point_and_distance = [[37.7566023, -122.415579], 10] point_and_distance = [[37.7566023, -122.415579], 10]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 1) self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2) self.assertEqual(events[0], event2)
# find events within 1 mile of greenpoint, brooklyn, nyc, ny # find events within 1 degree of greenpoint, brooklyn, nyc, ny
point_and_distance = [[40.7237134, -73.9509714], 1] point_and_distance = [[40.7237134, -73.9509714], 1]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 0) self.assertEqual(events.count(), 0)
# ensure ordering is respected by "within_distance" # ensure ordering is respected by "within_distance"
point_and_distance = [[41.9120459, -87.67892], 10] point_and_distance = [[41.9120459, -87.67892], 10]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
@ -1579,9 +1712,61 @@ class QuerySetTest(unittest.TestCase):
events = Event.objects(location__within_box=box) events = Event.objects(location__within_box=box)
self.assertEqual(events.count(), 1) self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id) self.assertEqual(events[0].id, event2.id)
Event.drop_collection() Event.drop_collection()
def test_spherical_geospatial_operators(self):
"""Ensure that spherical geospatial queries are working
"""
class Point(Document):
location = GeoPointField()
Point.drop_collection()
# These points are one degree apart, which (according to Google Maps)
# is about 110 km apart at this place on the Earth.
north_point = Point(location=[-122, 38]) # Near Concord, CA
south_point = Point(location=[-122, 37]) # Near Santa Cruz, CA
north_point.save()
south_point.save()
earth_radius = 6378.009  # in km (must be a float so 60 / earth_radius is not truncated by integer division)
# Finds both points because they are within 60 km of the reference
# point equidistant between them.
points = Point.objects(location__near_sphere=[-122, 37.5])
self.assertEqual(points.count(), 2)
# Same behavior for _within_spherical_distance
points = Point.objects(
location__within_spherical_distance=[[-122, 37.5], 60/earth_radius]
)
self.assertEqual(points.count(), 2)
# Finds both points, but orders the north point first because it's
# closer to the reference point to the north.
points = Point.objects(location__near_sphere=[-122, 38.5])
self.assertEqual(points.count(), 2)
self.assertEqual(points[0].id, north_point.id)
self.assertEqual(points[1].id, south_point.id)
# Finds both points, but orders the south point first because it's
# closer to the reference point to the south.
points = Point.objects(location__near_sphere=[-122, 36.5])
self.assertEqual(points.count(), 2)
self.assertEqual(points[0].id, south_point.id)
self.assertEqual(points[1].id, north_point.id)
# Finds only one point because only the first point is within 60km of
# the reference point to the south.
points = Point.objects(
location__within_spherical_distance=[[-122, 36.5], 60/earth_radius]
)
self.assertEqual(points.count(), 1)
self.assertEqual(points[0].id, south_point.id)
Point.drop_collection()
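The divisor above is what converts a surface distance into the radians that MongoDB's spherical operators expect: distance_in_km / earth_radius_in_km. A tiny sketch of that conversion:

EARTH_RADIUS_KM = 6378.009            # the value used in the test above

def km_to_radians(km):
    """Convert a distance over the Earth's surface into the radian
    value expected by near_sphere / within_spherical_distance."""
    return km / EARTH_RADIUS_KM

sixty_km = km_to_radians(60.0)        # ~0.0094 radians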
def test_custom_querysets(self): def test_custom_querysets(self):
"""Ensure that custom QuerySet classes may be used. """Ensure that custom QuerySet classes may be used.
""" """
@ -1602,6 +1787,53 @@ class QuerySetTest(unittest.TestCase):
Post.drop_collection() Post.drop_collection()
def test_custom_querysets_set_manager_directly(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySet(QuerySet):
def not_empty(self):
return len(self) > 0
class CustomQuerySetManager(QuerySetManager):
queryset_class = CustomQuerySet
class Post(Document):
objects = CustomQuerySetManager()
Post.drop_collection()
self.assertTrue(isinstance(Post.objects, CustomQuerySet))
self.assertFalse(Post.objects.not_empty())
Post().save()
self.assertTrue(Post.objects.not_empty())
Post.drop_collection()
def test_custom_querysets_managers_directly(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySetManager(QuerySetManager):
@staticmethod
def get_queryset(doc_cls, queryset):
return queryset(is_published=True)
class Post(Document):
is_published = BooleanField(default=False)
published = CustomQuerySetManager()
Post.drop_collection()
Post().save()
Post(is_published=True).save()
self.assertEquals(Post.objects.count(), 2)
self.assertEquals(Post.published.count(), 1)
Post.drop_collection()
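Two extension points are exercised in the tests above: queryset_class on a QuerySetManager swaps in a richer QuerySet, and a get_queryset staticmethod pre-filters every query the manager hands out. A standalone sketch of the first style with a domain-specific query method (Player, ScoreQuerySet and ScoreManager are illustrative names):

from mongoengine import Document, IntField, connect
from mongoengine.queryset import QuerySet, QuerySetManager

connect(db='mongoenginetest')

class ScoreQuerySet(QuerySet):
    def high_scores(self, threshold):
        return self.filter(points__gte=threshold)

class ScoreManager(QuerySetManager):
    queryset_class = ScoreQuerySet

class Player(Document):
    points = IntField(default=0)
    objects = ScoreManager()

Player.drop_collection()
Player(points=10).save()
Player(points=90).save()
assert Player.objects.high_scores(50).count() == 1
Player.drop_collection()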
def test_call_after_limits_set(self): def test_call_after_limits_set(self):
"""Ensure that re-filtering after slicing works """Ensure that re-filtering after slicing works
""" """
@ -1637,6 +1869,35 @@ class QuerySetTest(unittest.TestCase):
Number.drop_collection() Number.drop_collection()
def test_clone(self):
"""Ensure that cloning clones complex querysets
"""
class Number(Document):
n = IntField()
Number.drop_collection()
for i in xrange(1, 101):
t = Number(n=i)
t.save()
test = Number.objects
test2 = test.clone()
self.assertFalse(test == test2)
self.assertEqual(test.count(), test2.count())
test = test.filter(n__gt=11)
test2 = test.clone()
self.assertFalse(test == test2)
self.assertEqual(test.count(), test2.count())
test = test.limit(10)
test2 = test.clone()
self.assertFalse(test == test2)
self.assertEqual(test.count(), test2.count())
Number.drop_collection()
def test_unset_reference(self): def test_unset_reference(self):
class Comment(Document): class Comment(Document):
text = StringField() text = StringField()
@ -1658,6 +1919,39 @@ class QuerySetTest(unittest.TestCase):
Comment.drop_collection() Comment.drop_collection()
Post.drop_collection() Post.drop_collection()
def test_order_works_with_custom_db_field_names(self):
class Number(Document):
n = IntField(db_field='number')
Number.drop_collection()
n2 = Number.objects.create(n=2)
n1 = Number.objects.create(n=1)
self.assertEqual(list(Number.objects), [n2,n1])
self.assertEqual(list(Number.objects.order_by('n')), [n1,n2])
Number.drop_collection()
def test_order_works_with_primary(self):
"""Ensure that order_by and primary work.
"""
class Number(Document):
n = IntField(primary_key=True)
Number.drop_collection()
Number(n=1).save()
Number(n=2).save()
Number(n=3).save()
numbers = [n.n for n in Number.objects.order_by('-n')]
self.assertEquals([3, 2, 1], numbers)
numbers = [n.n for n in Number.objects.order_by('+n')]
self.assertEquals([1, 2, 3], numbers)
Number.drop_collection()
class QTest(unittest.TestCase): class QTest(unittest.TestCase):
@ -1679,7 +1973,7 @@ class QTest(unittest.TestCase):
query = {'age': {'$gte': 18}, 'name': 'test'} query = {'age': {'$gte': 18}, 'name': 'test'}
self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query)
def test_q_with_dbref(self): def test_q_with_dbref(self):
"""Ensure Q objects handle DBRefs correctly""" """Ensure Q objects handle DBRefs correctly"""
connect(db='mongoenginetest') connect(db='mongoenginetest')
@ -1721,7 +2015,7 @@ class QTest(unittest.TestCase):
query = Q(x__lt=100) & Q(y__ne='NotMyString') query = Q(x__lt=100) & Q(y__ne='NotMyString')
query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100)
mongo_query = { mongo_query = {
'x': {'$lt': 100, '$gt': -100}, 'x': {'$lt': 100, '$gt': -100},
'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']},
} }
self.assertEqual(query.to_query(TestDoc), mongo_query) self.assertEqual(query.to_query(TestDoc), mongo_query)
@ -1795,6 +2089,30 @@ class QTest(unittest.TestCase):
for condition in conditions: for condition in conditions:
self.assertTrue(condition in query['$or']) self.assertTrue(condition in query['$or'])
def test_q_clone(self):
class TestDoc(Document):
x = IntField()
TestDoc.drop_collection()
for i in xrange(1, 101):
t = TestDoc(x=i)
t.save()
# Check normal cases work without an error
test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3))
self.assertEqual(test.count(), 3)
test2 = test.clone()
self.assertEqual(test2.count(), 3)
self.assertFalse(test2 == test)
test2.filter(x=6)
self.assertEqual(test2.count(), 1)
self.assertEqual(test.count(), 3)
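Q objects compose with & and |, with | producing an $or query like the one inspected earlier in this class. A standalone sketch (Ticket is an illustrative class):

from mongoengine import Document, IntField, StringField, connect
from mongoengine.queryset import Q

connect(db='mongoenginetest')

class Ticket(Document):
    priority = IntField()
    state = StringField()

Ticket.drop_collection()
Ticket(priority=1, state='open').save()
Ticket(priority=5, state='closed').save()

# | widens the match, & narrows it.
assert Ticket.objects(Q(state='open') | Q(priority__gte=4)).count() == 2
assert Ticket.objects(Q(state='closed') & Q(priority__gte=4)).count() == 1
Ticket.drop_collection()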
class QueryFieldListTest(unittest.TestCase): class QueryFieldListTest(unittest.TestCase):
def test_empty(self): def test_empty(self):
q = QueryFieldList() q = QueryFieldList()
@ -1805,51 +2123,52 @@ class QueryFieldListTest(unittest.TestCase):
def test_include_include(self): def test_include_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'b': True}) self.assertEqual(q.as_dict(), {'b': True})
def test_include_exclude(self): def test_include_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': True}) self.assertEqual(q.as_dict(), {'a': True})
def test_exclude_exclude(self): def test_exclude_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False})
def test_exclude_include(self): def test_exclude_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'c': True}) self.assertEqual(q.as_dict(), {'c': True})
def test_always_include(self): def test_always_include(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
def test_reset(self): def test_reset(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
q.reset() q.reset()
self.assertFalse(q) self.assertFalse(q)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
def test_using_a_slice(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a'], value={"$slice": 5})
self.assertEqual(q.as_dict(), {'a': {"$slice": 5}})
if __name__ == '__main__': if __name__ == '__main__':