Merge branch 'dev' into pull_124

This commit is contained in:
Ross Lawley 2011-05-25 09:54:56 +01:00
commit 2ce70448b0
10 changed files with 1184 additions and 240 deletions

2
.gitignore vendored
View File

@ -1,3 +1,5 @@
.*
!.gitignore
*.pyc *.pyc
.*.swp .*.swp
*.egg *.egg

View File

@ -7,16 +7,26 @@ import pymongo
import pymongo.objectid import pymongo.objectid
_document_registry = {} class NotRegistered(Exception):
pass
def get_document(name):
return _document_registry[name]
class ValidationError(Exception): class ValidationError(Exception):
pass pass
_document_registry = {}
def get_document(name):
if name not in _document_registry:
raise NotRegistered("""
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip() % name)
return _document_registry[name]
class BaseField(object): class BaseField(object):
"""A base class for fields in a MongoDB document. Instances of this class """A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema. may be added to subclasses of `Document` to define a document's schema.
@ -243,7 +253,16 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# __metaclass__ is only set on the class with the __metaclass__ # __metaclass__ is only set on the class with the __metaclass__
# attribute (i.e. it is not set on subclasses). This differentiates # attribute (i.e. it is not set on subclasses). This differentiates
# 'real' documents from the 'Document' class # 'real' documents from the 'Document' class
if attrs.get('__metaclass__') == TopLevelDocumentMetaclass: #
# Also assume a class is abstract if it has abstract set to True in
# its meta dictionary. This allows custom Document superclasses.
if (attrs.get('__metaclass__') == TopLevelDocumentMetaclass or
('meta' in attrs and attrs['meta'].get('abstract', False))):
# Make sure no base class was non-abstract
non_abstract_bases = [b for b in bases
if hasattr(b,'_meta') and not b._meta.get('abstract', False)]
if non_abstract_bases:
raise ValueError("Abstract document cannot have non-abstract base")
return super_new(cls, name, bases, attrs) return super_new(cls, name, bases, attrs)
collection = name.lower() collection = name.lower()
@ -266,6 +285,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
base_indexes += base._meta.get('indexes', []) base_indexes += base._meta.get('indexes', [])
meta = { meta = {
'abstract': False,
'collection': collection, 'collection': collection,
'max_documents': None, 'max_documents': None,
'max_size': None, 'max_size': None,
@ -289,13 +309,39 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class = super_new(cls, name, bases, attrs) new_class = super_new(cls, name, bases, attrs)
# Provide a default queryset unless one has been manually provided # Provide a default queryset unless one has been manually provided
if not hasattr(new_class, 'objects'): manager = attrs.get('objects', QuerySetManager())
new_class.objects = QuerySetManager() if hasattr(manager, 'queryset_class'):
meta['queryset_class'] = manager.queryset_class
new_class.objects = manager
user_indexes = [QuerySet._build_index_spec(new_class, spec) user_indexes = [QuerySet._build_index_spec(new_class, spec)
for spec in meta['indexes']] + base_indexes for spec in meta['indexes']] + base_indexes
new_class._meta['indexes'] = user_indexes new_class._meta['indexes'] = user_indexes
unique_indexes = cls._unique_with_indexes(new_class)
new_class._meta['unique_indexes'] = unique_indexes
for field_name, field in new_class._fields.items():
# Check for custom primary key
if field.primary_key:
current_pk = new_class._meta['id_field']
if current_pk and current_pk != field_name:
raise ValueError('Cannot override primary key field')
if not current_pk:
new_class._meta['id_field'] = field_name
# Make 'Document.id' an alias to the real primary key field
new_class.id = field
if not new_class._meta['id_field']:
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
return new_class
@classmethod
def _unique_with_indexes(cls, new_class, namespace=""):
unique_indexes = [] unique_indexes = []
for field_name, field in new_class._fields.items(): for field_name, field in new_class._fields.items():
# Generate a list of indexes needed by uniqueness constraints # Generate a list of indexes needed by uniqueness constraints
@ -321,28 +367,16 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
unique_fields += unique_with unique_fields += unique_with
# Add the new index to the list # Add the new index to the list
index = [(f, pymongo.ASCENDING) for f in unique_fields] index = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields]
unique_indexes.append(index) unique_indexes.append(index)
# Check for custom primary key # Grab any embedded document field unique indexes
if field.primary_key: if field.__class__.__name__ == "EmbeddedDocumentField":
current_pk = new_class._meta['id_field'] field_namespace = "%s." % field_name
if current_pk and current_pk != field_name: unique_indexes += cls._unique_with_indexes(field.document_type,
raise ValueError('Cannot override primary key field') field_namespace)
if not current_pk: return unique_indexes
new_class._meta['id_field'] = field_name
# Make 'Document.id' an alias to the real primary key field
new_class.id = field
new_class._meta['unique_indexes'] = unique_indexes
if not new_class._meta['id_field']:
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
return new_class
class BaseDocument(object): class BaseDocument(object):
@ -461,7 +495,7 @@ class BaseDocument(object):
self._meta.get('allow_inheritance', True) == False): self._meta.get('allow_inheritance', True) == False):
data['_cls'] = self._class_name data['_cls'] = self._class_name
data['_types'] = self._superclasses.keys() + [self._class_name] data['_types'] = self._superclasses.keys() + [self._class_name]
if data.has_key('_id') and not data['_id']: if data.has_key('_id') and data['_id'] is None:
del data['_id'] del data['_id']
return data return data

View File

@ -1,5 +1,6 @@
from pymongo import Connection from pymongo import Connection
import multiprocessing import multiprocessing
import threading
__all__ = ['ConnectionError', 'connect'] __all__ = ['ConnectionError', 'connect']
@ -22,17 +23,22 @@ class ConnectionError(Exception):
def _get_connection(reconnect=False): def _get_connection(reconnect=False):
"""Handles the connection to the database
"""
global _connection global _connection
identity = get_identity() identity = get_identity()
# Connect to the database if not already connected # Connect to the database if not already connected
if _connection.get(identity) is None or reconnect: if _connection.get(identity) is None or reconnect:
try: try:
_connection[identity] = Connection(**_connection_settings) _connection[identity] = Connection(**_connection_settings)
except: except Exception, e:
raise ConnectionError('Cannot connect to the database') raise ConnectionError("Cannot connect to the database:\n%s" % e)
return _connection[identity] return _connection[identity]
def _get_db(reconnect=False): def _get_db(reconnect=False):
"""Handles database connections and authentication based on the current
identity
"""
global _db, _connection global _db, _connection
identity = get_identity() identity = get_identity()
# Connect if not already connected # Connect if not already connected
@ -52,8 +58,13 @@ def _get_db(reconnect=False):
return _db[identity] return _db[identity]
def get_identity(): def get_identity():
"""Creates an identity key based on the current process and thread
identity.
"""
identity = multiprocessing.current_process()._identity identity = multiprocessing.current_process()._identity
identity = 0 if not identity else identity[0] identity = 0 if not identity else identity[0]
identity = (identity, threading.current_thread().ident)
return identity return identity
def connect(db, username=None, password=None, **kwargs): def connect(db, username=None, password=None, **kwargs):

View File

@ -86,7 +86,7 @@ class User(Document):
else: else:
email = '@'.join([email_name, domain_part.lower()]) email = '@'.join([email_name, domain_part.lower()])
user = User(username=username, email=email, date_joined=now) user = cls(username=username, email=email, date_joined=now)
user.set_password(password) user.set_password(password)
user.save() user.save()
return user return user

View File

@ -56,7 +56,7 @@ class Document(BaseDocument):
__metaclass__ = TopLevelDocumentMetaclass __metaclass__ = TopLevelDocumentMetaclass
def save(self, safe=True, force_insert=False, validate=True): def save(self, safe=True, force_insert=False, validate=True, write_options=None):
"""Save the :class:`~mongoengine.Document` to the database. If the """Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be document already exists, it will be updated, otherwise it will be
created. created.
@ -68,16 +68,26 @@ class Document(BaseDocument):
:param force_insert: only try to create a new document, don't allow :param force_insert: only try to create a new document, don't allow
updates of existing documents updates of existing documents
:param validate: validates the document; set to ``False`` to skip. :param validate: validates the document; set to ``False`` to skip.
:param write_options: Extra keyword arguments are passed down to
:meth:`~pymongo.collection.Collection.save` OR
:meth:`~pymongo.collection.Collection.insert`
which will be used as options for the resultant ``getLastError`` command.
For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers
have recorded the write and will force an fsync on each server being written to.
""" """
if validate: if validate:
self.validate() self.validate()
if not write_options:
write_options = {}
doc = self.to_mongo() doc = self.to_mongo()
try: try:
collection = self.__class__.objects._collection collection = self.__class__.objects._collection
if force_insert: if force_insert:
object_id = collection.insert(doc, safe=safe) object_id = collection.insert(doc, safe=safe, **write_options)
else: else:
object_id = collection.save(doc, safe=safe) object_id = collection.save(doc, safe=safe, **write_options)
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
message = 'Could not save document (%s)' message = 'Could not save document (%s)'
if u'duplicate key' in unicode(err): if u'duplicate key' in unicode(err):

View File

@ -17,7 +17,7 @@ import warnings
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
'ObjectIdField', 'ReferenceField', 'ValidationError', 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField',
'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField',
'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField']
@ -449,7 +449,108 @@ class DictField(BaseField):
'contain "." or "$" characters') 'contain "." or "$" characters')
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.basecls(db_field=member_name) return DictField(basecls=self.basecls, db_field=member_name)
def prepare_query_value(self, op, value):
match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith',
'exact', 'iexact']
if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value)
return super(DictField,self).prepare_query_value(op, value)
class MapField(BaseField):
"""A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified
field type.
.. versionadded:: 0.5
"""
def __init__(self, field=None, *args, **kwargs):
if not isinstance(field, BaseField):
raise ValidationError('Argument to MapField constructor must be '
'a valid field')
self.field = field
kwargs.setdefault('default', lambda: {})
super(MapField, self).__init__(*args, **kwargs)
def validate(self, value):
"""Make sure that a list of valid fields is being used.
"""
if not isinstance(value, dict):
raise ValidationError('Only dictionaries may be used in a '
'DictField')
if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not '
'contain "." or "$" characters')
try:
[self.field.validate(item) for item in value.values()]
except Exception, err:
raise ValidationError('Invalid MapField item (%s)' % str(item))
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
if instance is None:
# Document class being used rather than a document object
return self
if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type
# Get value from document instance if available
value_dict = instance._data.get(self.name)
if value_dict:
deref_dict = []
for key,value in value_dict.iteritems():
# Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)):
value = _get_db().dereference(value)
deref_dict[key] = referenced_type._from_son(value)
else:
deref_dict[key] = value
instance._data[self.name] = deref_dict
if isinstance(self.field, GenericReferenceField):
value_dict = instance._data.get(self.name)
if value_dict:
deref_dict = []
for key,value in value_dict.iteritems():
# Dereference DBRefs
if isinstance(value, (dict, pymongo.son.SON)):
deref_dict[key] = self.field.dereference(value)
else:
deref_dict[key] = value
instance._data[self.name] = deref_dict
return super(MapField, self).__get__(instance, owner)
def to_python(self, value):
return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] )
def to_mongo(self, value):
return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] )
def prepare_query_value(self, op, value):
return self.field.prepare_query_value(op, value)
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
class ReferenceField(BaseField): class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on """A reference to a document that will be automatically dereferenced on
@ -522,6 +623,9 @@ class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass """A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily). that will be automatically dereferenced on access (lazily).
note: Any documents used as a generic reference must be registered in the
document registry. Importing the model will automatically register it.
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
@ -601,6 +705,7 @@ class GridFSProxy(object):
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
self.newfile = None # Used for partial writes self.newfile = None # Used for partial writes
self.grid_id = grid_id # Store GridFS id for file self.grid_id = grid_id # Store GridFS id for file
self.gridout = None
def __getattr__(self, name): def __getattr__(self, name):
obj = self.get() obj = self.get()
@ -614,8 +719,12 @@ class GridFSProxy(object):
def get(self, id=None): def get(self, id=None):
if id: if id:
self.grid_id = id self.grid_id = id
if self.grid_id is None:
return None
try: try:
return self.fs.get(id or self.grid_id) if self.gridout is None:
self.gridout = self.fs.get(self.grid_id)
return self.gridout
except: except:
# File has been deleted # File has been deleted
return None return None
@ -645,9 +754,9 @@ class GridFSProxy(object):
self.grid_id = self.newfile._id self.grid_id = self.newfile._id
self.newfile.writelines(lines) self.newfile.writelines(lines)
def read(self): def read(self, size=-1):
try: try:
return self.get().read() return self.get().read(size)
except: except:
return None return None
@ -655,6 +764,7 @@ class GridFSProxy(object):
# Delete file from GridFS, FileField still remains # Delete file from GridFS, FileField still remains
self.fs.delete(self.grid_id) self.fs.delete(self.grid_id)
self.grid_id = None self.grid_id = None
self.gridout = None
def replace(self, file, **kwargs): def replace(self, file, **kwargs):
self.delete() self.delete()

View File

@ -8,6 +8,7 @@ import pymongo.objectid
import re import re
import copy import copy
import itertools import itertools
import operator
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError', __all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] 'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
@ -280,30 +281,30 @@ class QueryFieldList(object):
ONLY = True ONLY = True
EXCLUDE = False EXCLUDE = False
def __init__(self, fields=[], direction=ONLY, always_include=[]): def __init__(self, fields=[], value=ONLY, always_include=[]):
self.direction = direction self.value = value
self.fields = set(fields) self.fields = set(fields)
self.always_include = set(always_include) self.always_include = set(always_include)
def as_dict(self): def as_dict(self):
return dict((field, self.direction) for field in self.fields) return dict((field, self.value) for field in self.fields)
def __add__(self, f): def __add__(self, f):
if not self.fields: if not self.fields:
self.fields = f.fields self.fields = f.fields
self.direction = f.direction self.value = f.value
elif self.direction is self.ONLY and f.direction is self.ONLY: elif self.value is self.ONLY and f.value is self.ONLY:
self.fields = self.fields.intersection(f.fields) self.fields = self.fields.intersection(f.fields)
elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE: elif self.value is self.EXCLUDE and f.value is self.EXCLUDE:
self.fields = self.fields.union(f.fields) self.fields = self.fields.union(f.fields)
elif self.direction is self.ONLY and f.direction is self.EXCLUDE: elif self.value is self.ONLY and f.value is self.EXCLUDE:
self.fields -= f.fields self.fields -= f.fields
elif self.direction is self.EXCLUDE and f.direction is self.ONLY: elif self.value is self.EXCLUDE and f.value is self.ONLY:
self.direction = self.ONLY self.value = self.ONLY
self.fields = f.fields - self.fields self.fields = f.fields - self.fields
if self.always_include: if self.always_include:
if self.direction is self.ONLY and self.fields: if self.value is self.ONLY and self.fields:
self.fields = self.fields.union(self.always_include) self.fields = self.fields.union(self.always_include)
else: else:
self.fields -= self.always_include self.fields -= self.always_include
@ -311,7 +312,7 @@ class QueryFieldList(object):
def reset(self): def reset(self):
self.fields = set([]) self.fields = set([])
self.direction = self.ONLY self.value = self.ONLY
def __nonzero__(self): def __nonzero__(self):
return bool(self.fields) return bool(self.fields)
@ -334,6 +335,7 @@ class QuerySet(object):
self._ordering = [] self._ordering = []
self._snapshot = False self._snapshot = False
self._timeout = True self._timeout = True
self._class_check = True
# If inheritance is allowed, only return instances and instances of # If inheritance is allowed, only return instances and instances of
# subclasses of the class being used # subclasses of the class being used
@ -344,11 +346,26 @@ class QuerySet(object):
self._limit = None self._limit = None
self._skip = None self._skip = None
def clone(self):
"""Creates a copy of the current :class:`~mongoengine.queryset.QuerySet`"""
c = self.__class__(self._document, self._collection_obj)
copy_props = ('_initial_query', '_query_obj', '_where_clause',
'_loaded_fields', '_ordering', '_snapshot',
'_timeout', '_limit', '_skip')
for prop in copy_props:
val = getattr(self, prop)
setattr(c, prop, copy.deepcopy(val))
return c
@property @property
def _query(self): def _query(self):
if self._mongo_query is None: if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document) self._mongo_query = self._query_obj.to_query(self._document)
self._mongo_query.update(self._initial_query) if self._class_check:
self._mongo_query.update(self._initial_query)
return self._mongo_query return self._mongo_query
def ensure_index(self, key_or_list, drop_dups=False, background=False, def ensure_index(self, key_or_list, drop_dups=False, background=False,
@ -399,7 +416,7 @@ class QuerySet(object):
return index_list return index_list
def __call__(self, q_obj=None, **query): def __call__(self, q_obj=None, class_check=True, **query):
"""Filter the selected documents by calling the """Filter the selected documents by calling the
:class:`~mongoengine.queryset.QuerySet` with a query. :class:`~mongoengine.queryset.QuerySet` with a query.
@ -407,16 +424,17 @@ class QuerySet(object):
the query; the :class:`~mongoengine.queryset.QuerySet` is filtered the query; the :class:`~mongoengine.queryset.QuerySet` is filtered
multiple times with different :class:`~mongoengine.queryset.Q` multiple times with different :class:`~mongoengine.queryset.Q`
objects, only the last one will be used objects, only the last one will be used
:param class_check: If set to False bypass class name check when
querying collection
:param query: Django-style query keyword arguments :param query: Django-style query keyword arguments
""" """
#if q_obj:
#self._where_clause = q_obj.as_js(self._document)
query = Q(**query) query = Q(**query)
if q_obj: if q_obj:
query &= q_obj query &= q_obj
self._query_obj &= query self._query_obj &= query
self._mongo_query = None self._mongo_query = None
self._cursor_obj = None self._cursor_obj = None
self._class_check = class_check
return self return self
def filter(self, *q_objs, **query): def filter(self, *q_objs, **query):
@ -440,17 +458,17 @@ class QuerySet(object):
drop_dups = self._document._meta.get('index_drop_dups', False) drop_dups = self._document._meta.get('index_drop_dups', False)
index_opts = self._document._meta.get('index_options', {}) index_opts = self._document._meta.get('index_options', {})
# Ensure indexes created by uniqueness constraints
for index in self._document._meta['unique_indexes']:
self._collection.ensure_index(index, unique=True,
background=background, drop_dups=drop_dups, **index_opts)
# Ensure document-defined indexes are created # Ensure document-defined indexes are created
if self._document._meta['indexes']: if self._document._meta['indexes']:
for key_or_list in self._document._meta['indexes']: for key_or_list in self._document._meta['indexes']:
self._collection.ensure_index(key_or_list, self._collection.ensure_index(key_or_list,
background=background, **index_opts) background=background, **index_opts)
# Ensure indexes created by uniqueness constraints
for index in self._document._meta['unique_indexes']:
self._collection.ensure_index(index, unique=True,
background=background, drop_dups=drop_dups, **index_opts)
# If _types is being used (for polymorphism), it needs an index # If _types is being used (for polymorphism), it needs an index
if '_types' in self._query: if '_types' in self._query:
self._collection.ensure_index('_types', self._collection.ensure_index('_types',
@ -504,6 +522,15 @@ class QuerySet(object):
fields = [] fields = []
field = None field = None
for field_name in parts: for field_name in parts:
# Handle ListField indexing:
if field_name.isdigit():
try:
field = field.field
except AttributeError, err:
raise InvalidQueryError(
"Can't use index on unsubscriptable field (%s)" % err)
fields.append(field_name)
continue
if field is None: if field is None:
# Look up first field from the document # Look up first field from the document
if field_name == 'pk': if field_name == 'pk':
@ -528,12 +555,12 @@ class QuerySet(object):
return '.'.join(parts) return '.'.join(parts)
@classmethod @classmethod
def _transform_query(cls, _doc_cls=None, **query): def _transform_query(cls, _doc_cls=None, _field_operation=False, **query):
"""Transform a query from Django-style format to Mongo format. """Transform a query from Django-style format to Mongo format.
""" """
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'not'] 'all', 'size', 'exists', 'not']
geo_operators = ['within_distance', 'within_box', 'near'] geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'near', 'near_sphere']
match_operators = ['contains', 'icontains', 'startswith', match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'istartswith', 'endswith', 'iendswith',
'exact', 'iexact'] 'exact', 'iexact']
@ -577,8 +604,12 @@ class QuerySet(object):
if op in geo_operators: if op in geo_operators:
if op == "within_distance": if op == "within_distance":
value = {'$within': {'$center': value}} value = {'$within': {'$center': value}}
elif op == "within_spherical_distance":
value = {'$within': {'$centerSphere': value}}
elif op == "near": elif op == "near":
value = {'$near': value} value = {'$near': value}
elif op == "near_sphere":
value = {'$nearSphere': value}
elif op == 'within_box': elif op == 'within_box':
value = {'$within': {'$box': value}} value = {'$within': {'$box': value}}
else: else:
@ -620,7 +651,7 @@ class QuerySet(object):
raise self._document.DoesNotExist("%s matching query does not exist." raise self._document.DoesNotExist("%s matching query does not exist."
% self._document._class_name) % self._document._class_name)
def get_or_create(self, *q_objs, **query): def get_or_create(self, write_options=None, *q_objs, **query):
"""Retrieve unique object or create, if it doesn't exist. Returns a tuple of """Retrieve unique object or create, if it doesn't exist. Returns a tuple of
``(object, created)``, where ``object`` is the retrieved or created object ``(object, created)``, where ``object`` is the retrieved or created object
and ``created`` is a boolean specifying whether a new object was created. Raises and ``created`` is a boolean specifying whether a new object was created. Raises
@ -630,6 +661,10 @@ class QuerySet(object):
dictionary of default values for the new document may be provided as a dictionary of default values for the new document may be provided as a
keyword argument called :attr:`defaults`. keyword argument called :attr:`defaults`.
:param write_options: optional extra keyword arguments used if we
have to create a new document.
Passes any write_options onto :meth:`~mongoengine.document.Document.save`
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
defaults = query.get('defaults', {}) defaults = query.get('defaults', {})
@ -641,7 +676,7 @@ class QuerySet(object):
if count == 0: if count == 0:
query.update(defaults) query.update(defaults)
doc = self._document(**query) doc = self._document(**query)
doc.save() doc.save(write_options=write_options)
return doc, True return doc, True
elif count == 1: elif count == 1:
return self.first(), False return self.first(), False
@ -725,7 +760,7 @@ class QuerySet(object):
def __len__(self): def __len__(self):
return self.count() return self.count()
def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None, def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
scope=None, keep_temp=False): scope=None, keep_temp=False):
"""Perform a map/reduce query using the current query spec """Perform a map/reduce query using the current query spec
and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, and ordering. While ``map_reduce`` respects ``QuerySet`` chaining,
@ -739,26 +774,26 @@ class QuerySet(object):
:param map_f: map function, as :class:`~pymongo.code.Code` or string :param map_f: map function, as :class:`~pymongo.code.Code` or string
:param reduce_f: reduce function, as :param reduce_f: reduce function, as
:class:`~pymongo.code.Code` or string :class:`~pymongo.code.Code` or string
:param output: output collection name
:param finalize_f: finalize function, an optional function that :param finalize_f: finalize function, an optional function that
performs any post-reduction processing. performs any post-reduction processing.
:param scope: values to insert into map/reduce global scope. Optional. :param scope: values to insert into map/reduce global scope. Optional.
:param limit: number of objects from current query to provide :param limit: number of objects from current query to provide
to map/reduce method to map/reduce method
:param keep_temp: keep temporary table (boolean, default ``True``)
Returns an iterator yielding Returns an iterator yielding
:class:`~mongoengine.document.MapReduceDocument`. :class:`~mongoengine.document.MapReduceDocument`.
.. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo .. note:: Map/Reduce changed in server version **>= 1.7.4**. The PyMongo
:meth:`~pymongo.collection.Collection.map_reduce` helper requires :meth:`~pymongo.collection.Collection.map_reduce` helper requires
PyMongo version **>= 1.2**. PyMongo version **>= 1.11**.
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
from document import MapReduceDocument from document import MapReduceDocument
if not hasattr(self._collection, "map_reduce"): if not hasattr(self._collection, "map_reduce"):
raise NotImplementedError("Requires MongoDB >= 1.1.1") raise NotImplementedError("Requires MongoDB >= 1.7.1")
map_f_scope = {} map_f_scope = {}
if isinstance(map_f, pymongo.code.Code): if isinstance(map_f, pymongo.code.Code):
@ -789,8 +824,7 @@ class QuerySet(object):
if limit: if limit:
mr_args['limit'] = limit mr_args['limit'] = limit
results = self._collection.map_reduce(map_f, reduce_f, output, **mr_args)
results = self._collection.map_reduce(map_f, reduce_f, **mr_args)
results = results.find() results = results.find()
if self._ordering: if self._ordering:
@ -868,10 +902,8 @@ class QuerySet(object):
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
fields = self._fields_to_dbfields(fields) fields = dict([(f, QueryFieldList.ONLY) for f in fields])
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY) return self.fields(**fields)
return self
def exclude(self, *fields): def exclude(self, *fields):
"""Opposite to .only(), exclude some document's fields. :: """Opposite to .only(), exclude some document's fields. ::
@ -880,8 +912,44 @@ class QuerySet(object):
:param fields: fields to exclude :param fields: fields to exclude
""" """
fields = self._fields_to_dbfields(fields) fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields])
self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE) return self.fields(**fields)
def fields(self, **kwargs):
"""Manipulate how you load this document's fields. Used by `.only()`
and `.exclude()` to manipulate which fields to retrieve. Fields also
allows for a greater level of control for example:
Retrieving a Subrange of Array Elements
---------------------------------------
You can use the $slice operator to retrieve a subrange of elements in
an array ::
post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments
:param kwargs: A dictionary identifying what to include
.. versionadded:: 0.5
"""
# Check for an operator and transform to mongo-style if there is
operators = ["slice"]
cleaned_fields = []
for key, value in kwargs.items():
parts = key.split('__')
op = None
if parts[0] in operators:
op = parts.pop(0)
value = {'$' + op: value}
key = '.'.join(parts)
cleaned_fields.append((key, value))
fields = sorted(cleaned_fields, key=operator.itemgetter(1))
for value, group in itertools.groupby(fields, lambda x: x[1]):
fields = [field for field, value in group]
fields = self._fields_to_dbfields(fields)
self._loaded_fields += QueryFieldList(fields, value=value)
return self return self
def all_fields(self): def all_fields(self):
@ -917,6 +985,10 @@ class QuerySet(object):
if key[0] in ('-', '+'): if key[0] in ('-', '+'):
key = key[1:] key = key[1:]
key = key.replace('__', '.') key = key.replace('__', '.')
try:
key = QuerySet._translate_field_name(self._document, key)
except:
pass
key_list.append((key, direction)) key_list.append((key, direction))
self._ordering = key_list self._ordering = key_list
@ -1007,10 +1079,17 @@ class QuerySet(object):
if _doc_cls: if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')] # Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts) fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [field.db_field for field in fields] parts = []
for field in fields:
if isinstance(field, str):
parts.append(field)
else:
parts.append(field.db_field)
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]
if op in (None, 'set', 'push', 'pull', 'addToSet'): if op in (None, 'set', 'push', 'pull', 'addToSet'):
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'): elif op in ('pushAll', 'pullAll'):
@ -1029,22 +1108,27 @@ class QuerySet(object):
return mongo_update return mongo_update
def update(self, safe_update=True, upsert=False, **update): def update(self, safe_update=True, upsert=False, write_options=None, **update):
"""Perform an atomic update on the fields matched by the query. When """Perform an atomic update on the fields matched by the query. When
``safe_update`` is used, the number of affected documents is returned. ``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe_update: check if the operation succeeded before returning
:param update: Django-style update keyword arguments :param upsert: Any existing document with that "_id" is overwritten.
:param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
if pymongo.version < '1.1.1': if pymongo.version < '1.1.1':
raise OperationError('update() method requires PyMongo 1.1.1+') raise OperationError('update() method requires PyMongo 1.1.1+')
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update) update = QuerySet._transform_update(self._document, **update)
try: try:
ret = self._collection.update(self._query, update, multi=True, ret = self._collection.update(self._query, update, multi=True,
upsert=upsert, safe=safe_update) upsert=upsert, safe=safe_update,
**write_options)
if ret is not None and 'n' in ret: if ret is not None and 'n' in ret:
return ret['n'] return ret['n']
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
@ -1053,22 +1137,27 @@ class QuerySet(object):
raise OperationError(message) raise OperationError(message)
raise OperationError(u'Update failed (%s)' % unicode(err)) raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, safe_update=True, upsert=False, **update): def update_one(self, safe_update=True, upsert=False, write_options=None, **update):
"""Perform an atomic update on first field matched by the query. When """Perform an atomic update on first field matched by the query. When
``safe_update`` is used, the number of affected documents is returned. ``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe_update: check if the operation succeeded before returning
:param upsert: Any existing document with that "_id" is overwritten.
:param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update`
:param update: Django-style update keyword arguments :param update: Django-style update keyword arguments
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
if not write_options:
write_options = {}
update = QuerySet._transform_update(self._document, **update) update = QuerySet._transform_update(self._document, **update)
try: try:
# Explicitly provide 'multi=False' to newer versions of PyMongo # Explicitly provide 'multi=False' to newer versions of PyMongo
# as the default may change to 'True' # as the default may change to 'True'
if pymongo.version >= '1.1.1': if pymongo.version >= '1.1.1':
ret = self._collection.update(self._query, update, multi=False, ret = self._collection.update(self._query, update, multi=False,
upsert=upsert, safe=safe_update) upsert=upsert, safe=safe_update,
**write_options)
else: else:
# Older versions of PyMongo don't support 'multi' # Older versions of PyMongo don't support 'multi'
ret = self._collection.update(self._query, update, ret = self._collection.update(self._query, update,
@ -1241,8 +1330,11 @@ class QuerySet(object):
class QuerySetManager(object): class QuerySetManager(object):
def __init__(self, manager_func=None): get_queryset = None
self._manager_func = manager_func
def __init__(self, queryset_func=None):
if queryset_func:
self.get_queryset = queryset_func
self._collections = {} self._collections = {}
def __get__(self, instance, owner): def __get__(self, instance, owner):
@ -1259,7 +1351,7 @@ class QuerySetManager(object):
# Create collection as a capped collection if specified # Create collection as a capped collection if specified
if owner._meta['max_size'] or owner._meta['max_documents']: if owner._meta['max_size'] or owner._meta['max_documents']:
# Get max document limit and max byte size from meta # Get max document limit and max byte size from meta
max_size = owner._meta['max_size'] or 10000000 # 10MB default max_size = owner._meta['max_size'] or 10000000 # 10MB default
max_documents = owner._meta['max_documents'] max_documents = owner._meta['max_documents']
if collection in db.collection_names(): if collection in db.collection_names():
@ -1286,11 +1378,11 @@ class QuerySetManager(object):
# owner is the document that contains the QuerySetManager # owner is the document that contains the QuerySetManager
queryset_class = owner._meta['queryset_class'] or QuerySet queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collections[(db, collection)]) queryset = queryset_class(owner, self._collections[(db, collection)])
if self._manager_func: if self.get_queryset:
if self._manager_func.func_code.co_argcount == 1: if self.get_queryset.func_code.co_argcount == 1:
queryset = self._manager_func(queryset) queryset = self.get_queryset(queryset)
else: else:
queryset = self._manager_func(owner, queryset) queryset = self.get_queryset(owner, queryset)
return queryset return queryset

View File

@ -1,11 +1,23 @@
import unittest import unittest
from datetime import datetime from datetime import datetime
import pymongo import pymongo
import pickle
from mongoengine import * from mongoengine import *
from mongoengine.base import BaseField
from mongoengine.connection import _get_db from mongoengine.connection import _get_db
class PickleEmbedded(EmbeddedDocument):
date = DateTimeField(default=datetime.now)
class PickleTest(Document):
number = IntField()
string = StringField()
embedded = EmbeddedDocumentField(PickleEmbedded)
lists = ListField(StringField())
class DocumentTest(unittest.TestCase): class DocumentTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -17,6 +29,9 @@ class DocumentTest(unittest.TestCase):
age = IntField() age = IntField()
self.Person = Person self.Person = Person
def tearDown(self):
self.Person.drop_collection()
def test_drop_collection(self): def test_drop_collection(self):
"""Ensure that the collection may be dropped from the database. """Ensure that the collection may be dropped from the database.
""" """
@ -176,6 +191,34 @@ class DocumentTest(unittest.TestCase):
self.assertFalse('_cls' in comment.to_mongo()) self.assertFalse('_cls' in comment.to_mongo())
self.assertFalse('_types' in comment.to_mongo()) self.assertFalse('_types' in comment.to_mongo())
def test_abstract_documents(self):
"""Ensure that a document superclass can be marked as abstract
thereby not using it as the name for the collection."""
class Animal(Document):
name = StringField()
meta = {'abstract': True}
class Fish(Animal): pass
class Guppy(Fish): pass
class Mammal(Animal):
meta = {'abstract': True}
class Human(Mammal): pass
self.assertFalse('collection' in Animal._meta)
self.assertFalse('collection' in Mammal._meta)
self.assertEqual(Fish._meta['collection'], 'fish')
self.assertEqual(Guppy._meta['collection'], 'fish')
self.assertEqual(Human._meta['collection'], 'human')
def create_bad_abstract():
class EvilHuman(Human):
evil = BooleanField(default=True)
meta = {'abstract': True}
self.assertRaises(ValueError, create_bad_abstract)
def test_collection_name(self): def test_collection_name(self):
"""Ensure that a collection with a specified name may be used. """Ensure that a collection with a specified name may be used.
""" """
@ -200,6 +243,22 @@ class DocumentTest(unittest.TestCase):
Person.drop_collection() Person.drop_collection()
self.assertFalse(collection in self.db.collection_names()) self.assertFalse(collection in self.db.collection_names())
def test_collection_name_and_primary(self):
"""Ensure that a collection with a specified name may be used.
"""
class Person(Document):
name = StringField(primary_key=True)
meta = {'collection': 'app'}
user = Person(name="Test User")
user.save()
user_obj = Person.objects[0]
self.assertEqual(user_obj.name, "Test User")
Person.drop_collection()
def test_inherited_collections(self): def test_inherited_collections(self):
"""Ensure that subclassed documents don't override parents' collections. """Ensure that subclassed documents don't override parents' collections.
""" """
@ -334,6 +393,10 @@ class DocumentTest(unittest.TestCase):
post2 = BlogPost(title='test2', slug='test') post2 = BlogPost(title='test2', slug='test')
self.assertRaises(OperationError, post2.save) self.assertRaises(OperationError, post2.save)
def test_unique_with(self):
"""Ensure that unique_with constraints are applied to fields.
"""
class Date(EmbeddedDocument): class Date(EmbeddedDocument):
year = IntField(db_field='yr') year = IntField(db_field='yr')
@ -357,6 +420,108 @@ class DocumentTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_unique_embedded_document(self):
"""Ensure that uniqueness constraints are applied to fields on embedded documents.
"""
class SubDocument(EmbeddedDocument):
year = IntField(db_field='yr')
slug = StringField(unique=True)
class BlogPost(Document):
title = StringField()
sub = EmbeddedDocumentField(SubDocument)
BlogPost.drop_collection()
post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug'))
post2.save()
# Now there will be two docs with the same sub.slug
post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test'))
self.assertRaises(OperationError, post3.save)
BlogPost.drop_collection()
def test_unique_with_embedded_document_and_embedded_unique(self):
"""Ensure that uniqueness constraints are applied to fields on
embedded documents. And work with unique_with as well.
"""
class SubDocument(EmbeddedDocument):
year = IntField(db_field='yr')
slug = StringField(unique=True)
class BlogPost(Document):
title = StringField(unique_with='sub.year')
sub = EmbeddedDocumentField(SubDocument)
BlogPost.drop_collection()
post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test"))
post1.save()
# sub.slug is different so won't raise exception
post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug'))
post2.save()
# Now there will be two docs with the same sub.slug
post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test'))
self.assertRaises(OperationError, post3.save)
# Now there will be two docs with the same title and year
post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1'))
self.assertRaises(OperationError, post3.save)
BlogPost.drop_collection()
def test_unique_and_indexes(self):
"""Ensure that 'unique' constraints aren't overridden by
meta.indexes.
"""
class Customer(Document):
cust_id = IntField(unique=True, required=True)
meta = {
'indexes': ['cust_id'],
'allow_inheritance': False,
}
Customer.drop_collection()
cust = Customer(cust_id=1)
cust.save()
cust_dupe = Customer(cust_id=1)
try:
cust_dupe.save()
raise AssertionError, "We saved a dupe!"
except OperationError:
pass
Customer.drop_collection()
def test_unique_and_primary(self):
"""If you set a field as primary, then unexpected behaviour can occur.
You won't create a duplicate but you will update an existing document.
"""
class User(Document):
name = StringField(primary_key=True, unique=True)
password = StringField()
User.drop_collection()
user = User(name='huangz', password='secret')
user.save()
user = User(name='huangz', password='secret2')
user.save()
self.assertEqual(User.objects.count(), 1)
self.assertEqual(User.objects.get().password, 'secret2')
User.drop_collection()
def test_custom_id_field(self): def test_custom_id_field(self):
"""Ensure that documents may be created with custom primary keys. """Ensure that documents may be created with custom primary keys.
""" """
@ -588,6 +753,34 @@ class DocumentTest(unittest.TestCase):
# Ensure that the 'details' embedded object saved correctly # Ensure that the 'details' embedded object saved correctly
self.assertEqual(employee_obj['details']['position'], 'Developer') self.assertEqual(employee_obj['details']['position'], 'Developer')
def test_updating_an_embedded_document(self):
"""Ensure that a document with an embedded document field may be
saved in the database.
"""
class EmployeeDetails(EmbeddedDocument):
position = StringField()
class Employee(self.Person):
salary = IntField()
details = EmbeddedDocumentField(EmployeeDetails)
# Create employee object and save it to the database
employee = Employee(name='Test Employee', age=50, salary=20000)
employee.details = EmployeeDetails(position='Developer')
employee.save()
# Test updating an embedded document
promoted_employee = Employee.objects.get(name='Test Employee')
promoted_employee.details.position = 'Senior Developer'
promoted_employee.save()
collection = self.db[self.Person._meta['collection']]
employee_obj = collection.find_one({'name': 'Test Employee'})
self.assertEqual(employee_obj['name'], 'Test Employee')
self.assertEqual(employee_obj['age'], 50)
# Ensure that the 'details' embedded object saved correctly
self.assertEqual(employee_obj['details']['position'], 'Senior Developer')
def test_save_reference(self): def test_save_reference(self):
"""Ensure that a document reference field may be saved in the database. """Ensure that a document reference field may be saved in the database.
""" """
@ -725,9 +918,25 @@ class DocumentTest(unittest.TestCase):
self.Person.drop_collection() self.Person.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def subclasses_and_unique_keys_works(self):
def tearDown(self): class A(Document):
self.Person.drop_collection() pass
class B(A):
foo = BooleanField(unique=True)
A.drop_collection()
B.drop_collection()
A().save()
A().save()
B(foo=True).save()
self.assertEquals(A.objects.count(), 2)
self.assertEquals(B.objects.count(), 1)
A.drop_collection()
B.drop_collection()
def test_document_hash(self): def test_document_hash(self):
"""Test document in list, dict, set """Test document in list, dict, set
@ -777,6 +986,43 @@ class DocumentTest(unittest.TestCase):
self.assertTrue(u1 in all_user_set ) self.assertTrue(u1 in all_user_set )
def test_picklable(self):
pickle_doc = PickleTest(number=1, string="OH HAI", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded()
pickle_doc.save()
pickled_doc = pickle.dumps(pickle_doc)
resurrected = pickle.loads(pickled_doc)
self.assertEquals(resurrected, pickle_doc)
resurrected.string = "Working"
resurrected.save()
pickle_doc.reload()
self.assertEquals(resurrected, pickle_doc)
def test_write_options(self):
"""Test that passing write_options works"""
self.Person.drop_collection()
write_options = {"fsync": True}
author, created = self.Person.objects.get_or_create(
name='Test User', write_options=write_options)
author.save(write_options=write_options)
self.Person.objects.update(set__name='Ross', write_options=write_options)
author = self.Person.objects.first()
self.assertEquals(author.name, 'Ross')
self.Person.objects.update_one(set__name='Test User', write_options=write_options)
author = self.Person.objects.first()
self.assertEquals(author.name, 'Test User')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -7,6 +7,7 @@ import gridfs
from mongoengine import * from mongoengine import *
from mongoengine.connection import _get_db from mongoengine.connection import _get_db
from mongoengine.base import _document_registry, NotRegistered
class FieldTest(unittest.TestCase): class FieldTest(unittest.TestCase):
@ -261,12 +262,14 @@ class FieldTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_dict_validation(self): def test_dict_field(self):
"""Ensure that dict types work as expected. """Ensure that dict types work as expected.
""" """
class BlogPost(Document): class BlogPost(Document):
info = DictField() info = DictField()
BlogPost.drop_collection()
post = BlogPost() post = BlogPost()
post.info = 'my post' post.info = 'my post'
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
@ -281,7 +284,24 @@ class FieldTest(unittest.TestCase):
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
post.info = {'title': 'test'} post.info = {'title': 'test'}
post.validate() post.save()
post = BlogPost()
post.info = {'details': {'test': 'test'}}
post.save()
post = BlogPost()
post.info = {'details': {'test': 3}}
post.save()
self.assertEquals(BlogPost.objects.count(), 3)
self.assertEquals(BlogPost.objects.filter(info__title__exact='test').count(), 1)
self.assertEquals(BlogPost.objects.filter(info__details__test__exact='test').count(), 1)
# Confirm handles non strings or non existing keys
self.assertEquals(BlogPost.objects.filter(info__details__test__exact=5).count(), 0)
self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0)
BlogPost.drop_collection()
def test_embedded_document_validation(self): def test_embedded_document_validation(self):
"""Ensure that invalid embedded documents cannot be assigned to """Ensure that invalid embedded documents cannot be assigned to
@ -584,6 +604,39 @@ class FieldTest(unittest.TestCase):
Post.drop_collection() Post.drop_collection()
User.drop_collection() User.drop_collection()
def test_generic_reference_document_not_registered(self):
"""Ensure dereferencing out of the document registry throws a
`NotRegistered` error.
"""
class Link(Document):
title = StringField()
class User(Document):
bookmarks = ListField(GenericReferenceField())
Link.drop_collection()
User.drop_collection()
link_1 = Link(title="Pitchfork")
link_1.save()
user = User(bookmarks=[link_1])
user.save()
# Mimic User and Link definitions being in a different file
# and the Link model not being imported in the User file.
del(_document_registry["Link"])
user = User.objects.first()
try:
user.bookmarks
raise AssertionError, "Link was removed from the registry"
except NotRegistered:
pass
Link.drop_collection()
User.drop_collection()
def test_binary_fields(self): def test_binary_fields(self):
"""Ensure that binary fields can be stored and retrieved. """Ensure that binary fields can be stored and retrieved.
""" """
@ -701,6 +754,12 @@ class FieldTest(unittest.TestCase):
self.assertTrue(streamfile == result) self.assertTrue(streamfile == result)
self.assertEquals(result.file.read(), text + more_text) self.assertEquals(result.file.read(), text + more_text)
self.assertEquals(result.file.content_type, content_type) self.assertEquals(result.file.content_type, content_type)
result.file.seek(0)
self.assertEquals(result.file.tell(), 0)
self.assertEquals(result.file.read(len(text)), text)
self.assertEquals(result.file.tell(), len(text))
self.assertEquals(result.file.read(len(more_text)), more_text)
self.assertEquals(result.file.tell(), len(text + more_text))
result.file.delete() result.file.delete()
# Ensure deleted file returns None # Ensure deleted file returns None
@ -785,5 +844,66 @@ class FieldTest(unittest.TestCase):
self.assertEqual(d2.data, {}) self.assertEqual(d2.data, {})
self.assertEqual(d2.data2, {}) self.assertEqual(d2.data2, {})
def test_mapfield(self):
"""Ensure that the MapField handles the declared type."""
class Simple(Document):
mapping = MapField(IntField())
Simple.drop_collection()
e = Simple()
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
def create_invalid_class():
class NoDeclaredType(Document):
mapping = MapField()
self.assertRaises(ValidationError, create_invalid_class)
Simple.drop_collection()
def test_complex_mapfield(self):
"""Ensure that the MapField can handle complex declared types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Extensible(Document):
mapping = MapField(EmbeddedDocumentField(SettingBase))
Extensible.drop_collection()
e = Extensible()
e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42)
e.save()
e2 = Extensible.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
def create_invalid_mapping():
e.mapping['someint'] = 123
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Extensible.drop_collection()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -5,8 +5,9 @@ import unittest
import pymongo import pymongo
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, from mongoengine.queryset import (QuerySet, QuerySetManager,
DoesNotExist, QueryFieldList) MultipleObjectsReturned, DoesNotExist,
QueryFieldList)
from mongoengine import * from mongoengine import *
@ -105,6 +106,10 @@ class QuerySetTest(unittest.TestCase):
people = list(self.Person.objects[1:1]) people = list(self.Person.objects[1:1])
self.assertEqual(len(people), 0) self.assertEqual(len(people), 0)
# Test slice out of range
people = list(self.Person.objects[80000:80001])
self.assertEqual(len(people), 0)
def test_find_one(self): def test_find_one(self):
"""Ensure that a query using find_one returns a valid result. """Ensure that a query using find_one returns a valid result.
""" """
@ -207,6 +212,55 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection() Blog.drop_collection()
def test_update_array_position(self):
"""Ensure that updating by array position works.
Check update() and update_one() can take syntax like:
set__posts__1__comments__1__name="testc"
Check that it only works for ListFields.
"""
class Comment(EmbeddedDocument):
name = StringField()
class Post(EmbeddedDocument):
comments = ListField(EmbeddedDocumentField(Comment))
class Blog(Document):
tags = ListField(StringField())
posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection()
comment1 = Comment(name='testa')
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
# Update all of the first comments of second posts of all blogs
blog = Blog.objects().update(set__posts__1__comments__0__name="testc")
testc_blogs = Blog.objects(posts__1__comments__0__name="testc")
self.assertEqual(len(testc_blogs), 2)
Blog.drop_collection()
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
# Update only the first blog returned by the query
blog = Blog.objects().update_one(
set__posts__1__comments__1__name="testc")
testc_blogs = Blog.objects(posts__1__comments__1__name="testc")
self.assertEqual(len(testc_blogs), 1)
# Check that using this indexing syntax on a non-list fails
def non_list_indexing():
Blog.objects().update(set__posts__1__comments__0__name__1="asdf")
self.assertRaises(InvalidQueryError, non_list_indexing)
Blog.drop_collection()
def test_get_or_create(self): def test_get_or_create(self):
"""Ensure that ``get_or_create`` returns one result or creates a new """Ensure that ``get_or_create`` returns one result or creates a new
document. document.
@ -593,6 +647,81 @@ class QuerySetTest(unittest.TestCase):
Email.drop_collection() Email.drop_collection()
def test_slicing_fields(self):
"""Ensure that query slicing an array works.
"""
class Numbers(Document):
n = ListField(IntField())
Numbers.drop_collection()
numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
numbers.save()
# first three
numbers = Numbers.objects.fields(slice__n=3).get()
self.assertEquals(numbers.n, [0, 1, 2])
# last three
numbers = Numbers.objects.fields(slice__n=-3).get()
self.assertEquals(numbers.n, [-3, -2, -1])
# skip 2, limit 3
numbers = Numbers.objects.fields(slice__n=[2, 3]).get()
self.assertEquals(numbers.n, [2, 3, 4])
# skip to fifth from last, limit 4
numbers = Numbers.objects.fields(slice__n=[-5, 4]).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2])
# skip to fifth from last, limit 10
numbers = Numbers.objects.fields(slice__n=[-5, 10]).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2, -1])
# skip to fifth from last, limit 10 dict method
numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get()
self.assertEquals(numbers.n, [-5, -4, -3, -2, -1])
def test_slicing_nested_fields(self):
"""Ensure that query slicing an embedded array works.
"""
class EmbeddedNumber(EmbeddedDocument):
n = ListField(IntField())
class Numbers(Document):
embedded = EmbeddedDocumentField(EmbeddedNumber)
Numbers.drop_collection()
numbers = Numbers()
numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1])
numbers.save()
# first three
numbers = Numbers.objects.fields(slice__embedded__n=3).get()
self.assertEquals(numbers.embedded.n, [0, 1, 2])
# last three
numbers = Numbers.objects.fields(slice__embedded__n=-3).get()
self.assertEquals(numbers.embedded.n, [-3, -2, -1])
# skip 2, limit 3
numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get()
self.assertEquals(numbers.embedded.n, [2, 3, 4])
# skip to fifth from last, limit 4
numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2])
# skip to fifth from last, limit 10
numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1])
# skip to fifth from last, limit 10 dict method
numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get()
self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1])
def test_find_embedded(self): def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query. """Ensure that an embedded document is properly returned from a query.
""" """
@ -1027,7 +1156,7 @@ class QuerySetTest(unittest.TestCase):
""" """
# run a map/reduce operation spanning all posts # run a map/reduce operation spanning all posts
results = BlogPost.objects.map_reduce(map_f, reduce_f) results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results) results = list(results)
self.assertEqual(len(results), 4) self.assertEqual(len(results), 4)
@ -1076,7 +1205,7 @@ class QuerySetTest(unittest.TestCase):
} }
""" """
results = BlogPost.objects.map_reduce(map_f, reduce_f) results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results) results = list(results)
self.assertEqual(results[0].object, post1) self.assertEqual(results[0].object, post1)
@ -1187,6 +1316,7 @@ class QuerySetTest(unittest.TestCase):
results = Link.objects.order_by("-value") results = Link.objects.order_by("-value")
results = results.map_reduce(map_f, results = results.map_reduce(map_f,
reduce_f, reduce_f,
"myresults",
finalize_f=finalize_f, finalize_f=finalize_f,
scope=scope) scope=scope)
results = list(results) results = list(results)
@ -1289,6 +1419,7 @@ class QuerySetTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
tags = ListField(StringField()) tags = ListField(StringField())
deleted = BooleanField(default=False) deleted = BooleanField(default=False)
date = DateTimeField(default=datetime.now)
@queryset_manager @queryset_manager
def objects(doc_cls, queryset): def objects(doc_cls, queryset):
@ -1296,7 +1427,7 @@ class QuerySetTest(unittest.TestCase):
@queryset_manager @queryset_manager
def music_posts(doc_cls, queryset): def music_posts(doc_cls, queryset):
return queryset(tags='music', deleted=False) return queryset(tags='music', deleted=False).order_by('-date')
BlogPost.drop_collection() BlogPost.drop_collection()
@ -1312,7 +1443,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual([p.id for p in BlogPost.objects], self.assertEqual([p.id for p in BlogPost.objects],
[post1.id, post2.id, post3.id]) [post1.id, post2.id, post3.id])
self.assertEqual([p.id for p in BlogPost.music_posts], self.assertEqual([p.id for p in BlogPost.music_posts],
[post1.id, post2.id]) [post2.id, post1.id])
BlogPost.drop_collection() BlogPost.drop_collection()
@ -1452,10 +1583,12 @@ class QuerySetTest(unittest.TestCase):
class Test(Document): class Test(Document):
testdict = DictField() testdict = DictField()
Test.drop_collection()
t = Test(testdict={'f': 'Value'}) t = Test(testdict={'f': 'Value'})
t.save() t.save()
self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 0) self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1)
self.assertEqual(len(Test.objects(testdict__f='Value')), 1) self.assertEqual(len(Test.objects(testdict__f='Value')), 1)
Test.drop_collection() Test.drop_collection()
@ -1541,7 +1674,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(events.count(), 3) self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event1, event3, event2]) self.assertEqual(list(events), [event1, event3, event2])
# find events within 5 miles of pitchfork office, chicago # find events within 5 degrees of pitchfork office, chicago
point_and_distance = [[41.9120459, -87.67892], 5] point_and_distance = [[41.9120459, -87.67892], 5]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 2) self.assertEqual(events.count(), 2)
@ -1556,13 +1689,13 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(events.count(), 3) self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event3, event1, event2]) self.assertEqual(list(events), [event3, event1, event2])
# find events around san francisco # find events within 10 degrees of san francisco
point_and_distance = [[37.7566023, -122.415579], 10] point_and_distance = [[37.7566023, -122.415579], 10]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 1) self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2) self.assertEqual(events[0], event2)
# find events within 1 mile of greenpoint, broolyn, nyc, ny # find events within 1 degree of greenpoint, broolyn, nyc, ny
point_and_distance = [[40.7237134, -73.9509714], 1] point_and_distance = [[40.7237134, -73.9509714], 1]
events = Event.objects(location__within_distance=point_and_distance) events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 0) self.assertEqual(events.count(), 0)
@ -1582,6 +1715,58 @@ class QuerySetTest(unittest.TestCase):
Event.drop_collection() Event.drop_collection()
def test_spherical_geospatial_operators(self):
"""Ensure that spherical geospatial queries are working
"""
class Point(Document):
location = GeoPointField()
Point.drop_collection()
# These points are one degree apart, which (according to Google Maps)
# is about 110 km apart at this place on the Earth.
north_point = Point(location=[-122, 38]) # Near Concord, CA
south_point = Point(location=[-122, 37]) # Near Santa Cruz, CA
north_point.save()
south_point.save()
earth_radius = 6378.009; # in km (needs to be a float for dividing by)
# Finds both points because they are within 60 km of the reference
# point equidistant between them.
points = Point.objects(location__near_sphere=[-122, 37.5])
self.assertEqual(points.count(), 2)
# Same behavior for _within_spherical_distance
points = Point.objects(
location__within_spherical_distance=[[-122, 37.5], 60/earth_radius]
);
self.assertEqual(points.count(), 2)
# Finds both points, but orders the north point first because it's
# closer to the reference point to the north.
points = Point.objects(location__near_sphere=[-122, 38.5])
self.assertEqual(points.count(), 2)
self.assertEqual(points[0].id, north_point.id)
self.assertEqual(points[1].id, south_point.id)
# Finds both points, but orders the south point first because it's
# closer to the reference point to the south.
points = Point.objects(location__near_sphere=[-122, 36.5])
self.assertEqual(points.count(), 2)
self.assertEqual(points[0].id, south_point.id)
self.assertEqual(points[1].id, north_point.id)
# Finds only one point because only the first point is within 60km of
# the reference point to the south.
points = Point.objects(
location__within_spherical_distance=[[-122, 36.5], 60/earth_radius]
);
self.assertEqual(points.count(), 1)
self.assertEqual(points[0].id, south_point.id)
Point.drop_collection()
def test_custom_querysets(self): def test_custom_querysets(self):
"""Ensure that custom QuerySet classes may be used. """Ensure that custom QuerySet classes may be used.
""" """
@ -1602,6 +1787,53 @@ class QuerySetTest(unittest.TestCase):
Post.drop_collection() Post.drop_collection()
def test_custom_querysets_set_manager_directly(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySet(QuerySet):
def not_empty(self):
return len(self) > 0
class CustomQuerySetManager(QuerySetManager):
queryset_class = CustomQuerySet
class Post(Document):
objects = CustomQuerySetManager()
Post.drop_collection()
self.assertTrue(isinstance(Post.objects, CustomQuerySet))
self.assertFalse(Post.objects.not_empty())
Post().save()
self.assertTrue(Post.objects.not_empty())
Post.drop_collection()
def test_custom_querysets_managers_directly(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySetManager(QuerySetManager):
@staticmethod
def get_queryset(doc_cls, queryset):
return queryset(is_published=True)
class Post(Document):
is_published = BooleanField(default=False)
published = CustomQuerySetManager()
Post.drop_collection()
Post().save()
Post(is_published=True).save()
self.assertEquals(Post.objects.count(), 2)
self.assertEquals(Post.published.count(), 1)
Post.drop_collection()
def test_call_after_limits_set(self): def test_call_after_limits_set(self):
"""Ensure that re-filtering after slicing works """Ensure that re-filtering after slicing works
""" """
@ -1637,6 +1869,35 @@ class QuerySetTest(unittest.TestCase):
Number.drop_collection() Number.drop_collection()
def test_clone(self):
"""Ensure that cloning clones complex querysets
"""
class Number(Document):
n = IntField()
Number.drop_collection()
for i in xrange(1, 101):
t = Number(n=i)
t.save()
test = Number.objects
test2 = test.clone()
self.assertFalse(test == test2)
self.assertEqual(test.count(), test2.count())
test = test.filter(n__gt=11)
test2 = test.clone()
self.assertFalse(test == test2)
self.assertEqual(test.count(), test2.count())
test = test.limit(10)
test2 = test.clone()
self.assertFalse(test == test2)
self.assertEqual(test.count(), test2.count())
Number.drop_collection()
def test_unset_reference(self): def test_unset_reference(self):
class Comment(Document): class Comment(Document):
text = StringField() text = StringField()
@ -1658,6 +1919,39 @@ class QuerySetTest(unittest.TestCase):
Comment.drop_collection() Comment.drop_collection()
Post.drop_collection() Post.drop_collection()
def test_order_works_with_custom_db_field_names(self):
class Number(Document):
n = IntField(db_field='number')
Number.drop_collection()
n2 = Number.objects.create(n=2)
n1 = Number.objects.create(n=1)
self.assertEqual(list(Number.objects), [n2,n1])
self.assertEqual(list(Number.objects.order_by('n')), [n1,n2])
Number.drop_collection()
def test_order_works_with_primary(self):
"""Ensure that order_by and primary work.
"""
class Number(Document):
n = IntField(primary_key=True)
Number.drop_collection()
Number(n=1).save()
Number(n=2).save()
Number(n=3).save()
numbers = [n.n for n in Number.objects.order_by('-n')]
self.assertEquals([3, 2, 1], numbers)
numbers = [n.n for n in Number.objects.order_by('+n')]
self.assertEquals([1, 2, 3], numbers)
Number.drop_collection()
class QTest(unittest.TestCase): class QTest(unittest.TestCase):
@ -1795,6 +2089,30 @@ class QTest(unittest.TestCase):
for condition in conditions: for condition in conditions:
self.assertTrue(condition in query['$or']) self.assertTrue(condition in query['$or'])
def test_q_clone(self):
class TestDoc(Document):
x = IntField()
TestDoc.drop_collection()
for i in xrange(1, 101):
t = TestDoc(x=i)
t.save()
# Check normal cases work without an error
test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3))
self.assertEqual(test.count(), 3)
test2 = test.clone()
self.assertEqual(test2.count(), 3)
self.assertFalse(test2 == test)
test2.filter(x=6)
self.assertEqual(test2.count(), 1)
self.assertEqual(test.count(), 3)
class QueryFieldListTest(unittest.TestCase): class QueryFieldListTest(unittest.TestCase):
def test_empty(self): def test_empty(self):
q = QueryFieldList() q = QueryFieldList()
@ -1805,51 +2123,52 @@ class QueryFieldListTest(unittest.TestCase):
def test_include_include(self): def test_include_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'b': True}) self.assertEqual(q.as_dict(), {'b': True})
def test_include_exclude(self): def test_include_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'a': True, 'b': True}) self.assertEqual(q.as_dict(), {'a': True, 'b': True})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': True}) self.assertEqual(q.as_dict(), {'a': True})
def test_exclude_exclude(self): def test_exclude_exclude(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False})
def test_exclude_include(self): def test_exclude_include(self):
q = QueryFieldList() q = QueryFieldList()
q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE)
self.assertEqual(q.as_dict(), {'a': False, 'b': False}) self.assertEqual(q.as_dict(), {'a': False, 'b': False})
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'c': True}) self.assertEqual(q.as_dict(), {'c': True})
def test_always_include(self): def test_always_include(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
def test_reset(self): def test_reset(self):
q = QueryFieldList(always_include=['x', 'y']) q = QueryFieldList(always_include=['x', 'y'])
q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True})
q.reset() q.reset()
self.assertFalse(q) self.assertFalse(q)
q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY)
self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True})
def test_using_a_slice(self):
q = QueryFieldList()
q += QueryFieldList(fields=['a'], value={"$slice": 5})
self.assertEqual(q.as_dict(), {'a': {"$slice": 5}})
if __name__ == '__main__': if __name__ == '__main__':