diff --git a/.gitignore b/.gitignore index d67429a2..8951a0ce 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +.* +!.gitignore *.pyc .*.swp *.egg @@ -9,4 +11,4 @@ mongoengine.egg-info/ env/ .settings .project -.pydevproject \ No newline at end of file +.pydevproject diff --git a/mongoengine/base.py b/mongoengine/base.py index b7f7ce87..ffceb794 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -7,22 +7,32 @@ import pymongo import pymongo.objectid -_document_registry = {} - -def get_document(name): - return _document_registry[name] +class NotRegistered(Exception): + pass class ValidationError(Exception): pass +_document_registry = {} + +def get_document(name): + if name not in _document_registry: + raise NotRegistered(""" + `%s` has not been registered in the document registry. + Importing the document class automatically registers it, has it + been imported? + """.strip() % name) + return _document_registry[name] + + class BaseField(object): """A base class for fields in a MongoDB document. Instances of this class may be added to subclasses of `Document` to define a document's schema. """ - # Fields may have _types inserted into indexes by default + # Fields may have _types inserted into indexes by default _index_with_types = True _geo_index = False @@ -32,7 +42,7 @@ class BaseField(object): creation_counter = 0 auto_creation_counter = -1 - def __init__(self, db_field=None, name=None, required=False, default=None, + def __init__(self, db_field=None, name=None, required=False, default=None, unique=False, unique_with=None, primary_key=False, validation=None, choices=None): self.db_field = (db_field or name) if not primary_key else '_id' @@ -57,7 +67,7 @@ class BaseField(object): BaseField.creation_counter += 1 def __get__(self, instance, owner): - """Descriptor for retrieving a value from a field in a document. Do + """Descriptor for retrieving a value from a field in a document. Do any necessary conversion between Python and MongoDB types. 
""" if instance is None: @@ -167,8 +177,8 @@ class DocumentMetaclass(type): superclasses.update(base._superclasses) if hasattr(base, '_meta'): - # Ensure that the Document class may be subclassed - - # inheritance may be disabled to remove dependency on + # Ensure that the Document class may be subclassed - + # inheritance may be disabled to remove dependency on # additional fields _cls and _types if base._meta.get('allow_inheritance', True) == False: raise ValueError('Document %s may not be subclassed' % @@ -211,12 +221,12 @@ class DocumentMetaclass(type): module = attrs.get('__module__') - base_excs = tuple(base.DoesNotExist for base in bases + base_excs = tuple(base.DoesNotExist for base in bases if hasattr(base, 'DoesNotExist')) or (DoesNotExist,) exc = subclass_exception('DoesNotExist', base_excs, module) new_class.add_to_class('DoesNotExist', exc) - base_excs = tuple(base.MultipleObjectsReturned for base in bases + base_excs = tuple(base.MultipleObjectsReturned for base in bases if hasattr(base, 'MultipleObjectsReturned')) base_excs = base_excs or (MultipleObjectsReturned,) exc = subclass_exception('MultipleObjectsReturned', base_excs, module) @@ -238,12 +248,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): def __new__(cls, name, bases, attrs): super_new = super(TopLevelDocumentMetaclass, cls).__new__ - # Classes defined in this package are abstract and should not have + # Classes defined in this package are abstract and should not have # their own metadata with DB collection, etc. - # __metaclass__ is only set on the class with the __metaclass__ + # __metaclass__ is only set on the class with the __metaclass__ # attribute (i.e. it is not set on subclasses). This differentiates # 'real' documents from the 'Document' class - if attrs.get('__metaclass__') == TopLevelDocumentMetaclass: + # + # Also assume a class is abstract if it has abstract set to True in + # its meta dictionary. This allows custom Document superclasses. 
+ if (attrs.get('__metaclass__') == TopLevelDocumentMetaclass or + ('meta' in attrs and attrs['meta'].get('abstract', False))): + # Make sure no base class was non-abstract + non_abstract_bases = [b for b in bases + if hasattr(b,'_meta') and not b._meta.get('abstract', False)] + if non_abstract_bases: + raise ValueError("Abstract document cannot have non-abstract base") return super_new(cls, name, bases, attrs) collection = name.lower() @@ -266,6 +285,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): base_indexes += base._meta.get('indexes', []) meta = { + 'abstract': False, 'collection': collection, 'max_documents': None, 'max_size': None, @@ -289,13 +309,39 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): new_class = super_new(cls, name, bases, attrs) # Provide a default queryset unless one has been manually provided - if not hasattr(new_class, 'objects'): - new_class.objects = QuerySetManager() + manager = attrs.get('objects', QuerySetManager()) + if hasattr(manager, 'queryset_class'): + meta['queryset_class'] = manager.queryset_class + new_class.objects = manager user_indexes = [QuerySet._build_index_spec(new_class, spec) for spec in meta['indexes']] + base_indexes new_class._meta['indexes'] = user_indexes + unique_indexes = cls._unique_with_indexes(new_class) + new_class._meta['unique_indexes'] = unique_indexes + + for field_name, field in new_class._fields.items(): + # Check for custom primary key + if field.primary_key: + current_pk = new_class._meta['id_field'] + if current_pk and current_pk != field_name: + raise ValueError('Cannot override primary key field') + + if not current_pk: + new_class._meta['id_field'] = field_name + # Make 'Document.id' an alias to the real primary key field + new_class.id = field + + if not new_class._meta['id_field']: + new_class._meta['id_field'] = 'id' + new_class._fields['id'] = ObjectIdField(db_field='_id') + new_class.id = new_class._fields['id'] + + return new_class + + @classmethod + def 
_unique_with_indexes(cls, new_class, namespace=""): unique_indexes = [] for field_name, field in new_class._fields.items(): # Generate a list of indexes needed by uniqueness constraints @@ -321,28 +367,16 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): unique_fields += unique_with # Add the new index to the list - index = [(f, pymongo.ASCENDING) for f in unique_fields] + index = [("%s%s" % (namespace, f), pymongo.ASCENDING) for f in unique_fields] unique_indexes.append(index) - # Check for custom primary key - if field.primary_key: - current_pk = new_class._meta['id_field'] - if current_pk and current_pk != field_name: - raise ValueError('Cannot override primary key field') + # Grab any embedded document field unique indexes + if field.__class__.__name__ == "EmbeddedDocumentField": + field_namespace = "%s." % field_name + unique_indexes += cls._unique_with_indexes(field.document_type, + field_namespace) - if not current_pk: - new_class._meta['id_field'] = field_name - # Make 'Document.id' an alias to the real primary key field - new_class.id = field - - new_class._meta['unique_indexes'] = unique_indexes - - if not new_class._meta['id_field']: - new_class._meta['id_field'] = 'id' - new_class._fields['id'] = ObjectIdField(db_field='_id') - new_class.id = new_class._fields['id'] - - return new_class + return unique_indexes class BaseDocument(object): @@ -366,7 +400,7 @@ class BaseDocument(object): are present. 
""" # Get a list of tuples of field names and their current values - fields = [(field, getattr(self, name)) + fields = [(field, getattr(self, name)) for name, field in self._fields.items()] # Ensure that each field is matched to a valid value @@ -461,7 +495,7 @@ class BaseDocument(object): self._meta.get('allow_inheritance', True) == False): data['_cls'] = self._class_name data['_types'] = self._superclasses.keys() + [self._class_name] - if data.has_key('_id') and not data['_id']: + if data.has_key('_id') and data['_id'] is None: del data['_id'] return data diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 814fde13..7b5cd210 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,5 +1,6 @@ from pymongo import Connection import multiprocessing +import threading __all__ = ['ConnectionError', 'connect'] @@ -22,17 +23,22 @@ class ConnectionError(Exception): def _get_connection(reconnect=False): + """Handles the connection to the database + """ global _connection identity = get_identity() # Connect to the database if not already connected if _connection.get(identity) is None or reconnect: try: _connection[identity] = Connection(**_connection_settings) - except: - raise ConnectionError('Cannot connect to the database') + except Exception, e: + raise ConnectionError("Cannot connect to the database:\n%s" % e) return _connection[identity] def _get_db(reconnect=False): + """Handles database connections and authentication based on the current + identity + """ global _db, _connection identity = get_identity() # Connect if not already connected @@ -52,12 +58,17 @@ def _get_db(reconnect=False): return _db[identity] def get_identity(): + """Creates an identity key based on the current process and thread + identity. 
+ """ identity = multiprocessing.current_process()._identity identity = 0 if not identity else identity[0] + + identity = (identity, threading.current_thread().ident) return identity - + def connect(db, username=None, password=None, **kwargs): - """Connect to the database specified by the 'db' argument. Connection + """Connect to the database specified by the 'db' argument. Connection settings may be provided here as well if the database is not running on the default port on localhost. If authentication is needed, provide username and password arguments as well. diff --git a/mongoengine/django/auth.py b/mongoengine/django/auth.py index 595852ef..41d307cc 100644 --- a/mongoengine/django/auth.py +++ b/mongoengine/django/auth.py @@ -86,7 +86,7 @@ class User(Document): else: email = '@'.join([email_name, domain_part.lower()]) - user = User(username=username, email=email, date_joined=now) + user = cls(username=username, email=email, date_joined=now) user.set_password(password) user.save() return user diff --git a/mongoengine/document.py b/mongoengine/document.py index 196662c3..771b9229 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -40,44 +40,54 @@ class Document(BaseDocument): presence of `_cls` and `_types`, set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` dictionary. - A :class:`~mongoengine.Document` may use a **Capped Collection** by + A :class:`~mongoengine.Document` may use a **Capped Collection** by specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` dictionary. :attr:`max_documents` is the maximum number of documents that - is allowed to be stored in the collection, and :attr:`max_size` is the - maximum size of the collection in bytes. If :attr:`max_size` is not - specified and :attr:`max_documents` is, :attr:`max_size` defaults to + is allowed to be stored in the collection, and :attr:`max_size` is the + maximum size of the collection in bytes. 
If :attr:`max_size` is not + specified and :attr:`max_documents` is, :attr:`max_size` defaults to 10000000 bytes (10MB). Indexes may be created by specifying :attr:`indexes` in the :attr:`meta` - dictionary. The value should be a list of field names or tuples of field + dictionary. The value should be a list of field names or tuples of field names. Index direction may be specified by prefixing the field names with a **+** or **-** sign. """ __metaclass__ = TopLevelDocumentMetaclass - def save(self, safe=True, force_insert=False, validate=True): + def save(self, safe=True, force_insert=False, validate=True, write_options=None): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. - If ``safe=True`` and the operation is unsuccessful, an + If ``safe=True`` and the operation is unsuccessful, an :class:`~mongoengine.OperationError` will be raised. :param safe: check if the operation succeeded before returning - :param force_insert: only try to create a new document, don't allow + :param force_insert: only try to create a new document, don't allow updates of existing documents :param validate: validates the document; set to ``False`` to skip. + :param write_options: Extra keyword arguments are passed down to + :meth:`~pymongo.collection.Collection.save` OR + :meth:`~pymongo.collection.Collection.insert` + which will be used as options for the resultant ``getLastError`` command. + For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers + have recorded the write and will force an fsync on each server being written to. 
""" if validate: self.validate() + + if not write_options: + write_options = {} + doc = self.to_mongo() try: collection = self.__class__.objects._collection if force_insert: - object_id = collection.insert(doc, safe=safe) + object_id = collection.insert(doc, safe=safe, **write_options) else: - object_id = collection.save(doc, safe=safe) + object_id = collection.save(doc, safe=safe, **write_options) except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' if u'duplicate key' in unicode(err): @@ -131,9 +141,9 @@ class MapReduceDocument(object): """A document returned from a map/reduce query. :param collection: An instance of :class:`~pymongo.Collection` - :param key: Document/result key, often an instance of - :class:`~pymongo.objectid.ObjectId`. If supplied as - an ``ObjectId`` found in the given ``collection``, + :param key: Document/result key, often an instance of + :class:`~pymongo.objectid.ObjectId`. If supplied as + an ``ObjectId`` found in the given ``collection``, the object can be accessed via the ``object`` property. :param value: The result(s) for this key. @@ -148,7 +158,7 @@ class MapReduceDocument(object): @property def object(self): - """Lazy-load the object referenced by ``self.key``. ``self.key`` + """Lazy-load the object referenced by ``self.key``. ``self.key`` should be the ``primary_key``. 
""" id_field = self._document()._meta['id_field'] diff --git a/mongoengine/fields.py b/mongoengine/fields.py index c06fdd4d..11366dd0 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -17,7 +17,7 @@ import warnings __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', - 'ObjectIdField', 'ReferenceField', 'ValidationError', + 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] @@ -339,7 +339,7 @@ class ListField(BaseField): if isinstance(self.field, ReferenceField): referenced_type = self.field.document_type - # Get value from document instance if available + # Get value from document instance if available value_list = instance._data.get(self.name) if value_list: deref_list = [] @@ -449,7 +449,108 @@ class DictField(BaseField): 'contain "." or "$" characters') def lookup_member(self, member_name): - return self.basecls(db_field=member_name) + return DictField(basecls=self.basecls, db_field=member_name) + + def prepare_query_value(self, op, value): + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', + 'exact', 'iexact'] + + if op in match_operators and isinstance(value, basestring): + return StringField().prepare_query_value(op, value) + + return super(DictField,self).prepare_query_value(op, value) + + +class MapField(BaseField): + """A field that maps a name to a specified field type. Similar to + a DictField, except the 'value' of each item must match the specified + field type. + + .. 
versionadded:: 0.5 + """ + + def __init__(self, field=None, *args, **kwargs): + if not isinstance(field, BaseField): + raise ValidationError('Argument to MapField constructor must be ' + 'a valid field') + self.field = field + kwargs.setdefault('default', lambda: {}) + super(MapField, self).__init__(*args, **kwargs) + + def validate(self, value): + """Make sure that a list of valid fields is being used. + """ + if not isinstance(value, dict): + raise ValidationError('Only dictionaries may be used in a ' + 'MapField') + + if any(('.' in k or '$' in k) for k in value): + raise ValidationError('Invalid dictionary key name - keys may not ' + 'contain "." or "$" characters') + + try: + [self.field.validate(item) for item in value.values()] + except Exception, err: + raise ValidationError('Invalid MapField item (%s)' % str(item)) + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + if instance is None: + # Document class being used rather than a document object + return self + + if isinstance(self.field, ReferenceField): + referenced_type = self.field.document_type + # Get value from document instance if available + value_dict = instance._data.get(self.name) + if value_dict: + deref_dict = {} + for key,value in value_dict.iteritems(): + # Dereference DBRefs + if isinstance(value, (pymongo.dbref.DBRef)): + value = _get_db().dereference(value) + deref_dict[key] = referenced_type._from_son(value) + else: + deref_dict[key] = value + instance._data[self.name] = deref_dict + + if isinstance(self.field, GenericReferenceField): + value_dict = instance._data.get(self.name) + if value_dict: + deref_dict = {} + for key,value in value_dict.iteritems(): + # Dereference DBRefs + if isinstance(value, (dict, pymongo.son.SON)): + deref_dict[key] = self.field.dereference(value) + else: + deref_dict[key] = value + instance._data[self.name] = deref_dict + + return super(MapField, self).__get__(instance, owner) + + def to_python(self, 
value): + return dict( [(key,self.field.to_python(item)) for key,item in value.iteritems()] ) + + def to_mongo(self, value): + return dict( [(key,self.field.to_mongo(item)) for key,item in value.iteritems()] ) + + def prepare_query_value(self, op, value): + return self.field.prepare_query_value(op, value) + + def lookup_member(self, member_name): + return self.field.lookup_member(member_name) + + def _set_owner_document(self, owner_document): + self.field.owner_document = owner_document + self._owner_document = owner_document + + def _get_owner_document(self): + return self._owner_document + + owner_document = property(_get_owner_document, _set_owner_document) + class ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on @@ -522,6 +623,9 @@ class GenericReferenceField(BaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). + + note: Any documents used as a generic reference must be registered in the + document registry. Importing the model will automatically register it. + .. 
versionadded:: 0.3 """ @@ -601,6 +705,7 @@ class GridFSProxy(object): self.fs = gridfs.GridFS(_get_db()) # Filesystem instance self.newfile = None # Used for partial writes self.grid_id = grid_id # Store GridFS id for file + self.gridout = None def __getattr__(self, name): obj = self.get() @@ -614,8 +719,12 @@ class GridFSProxy(object): def get(self, id=None): if id: self.grid_id = id + if self.grid_id is None: + return None try: - return self.fs.get(id or self.grid_id) + if self.gridout is None: + self.gridout = self.fs.get(self.grid_id) + return self.gridout except: # File has been deleted return None @@ -643,11 +752,11 @@ class GridFSProxy(object): if not self.newfile: self.new_file() self.grid_id = self.newfile._id - self.newfile.writelines(lines) + self.newfile.writelines(lines) - def read(self): + def read(self, size=-1): try: - return self.get().read() + return self.get().read(size) except: return None @@ -655,6 +764,7 @@ class GridFSProxy(object): # Delete file from GridFS, FileField still remains self.fs.delete(self.grid_id) self.grid_id = None + self.gridout = None def replace(self, file, **kwargs): self.delete() diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 17ebc2e9..54d4845e 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -8,6 +8,7 @@ import pymongo.objectid import re import copy import itertools +import operator __all__ = ['queryset_manager', 'Q', 'InvalidQueryError', 'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] @@ -280,30 +281,30 @@ class QueryFieldList(object): ONLY = True EXCLUDE = False - def __init__(self, fields=[], direction=ONLY, always_include=[]): - self.direction = direction + def __init__(self, fields=[], value=ONLY, always_include=[]): + self.value = value self.fields = set(fields) self.always_include = set(always_include) def as_dict(self): - return dict((field, self.direction) for field in self.fields) + return dict((field, self.value) for field in self.fields) def 
__add__(self, f): if not self.fields: self.fields = f.fields - self.direction = f.direction - elif self.direction is self.ONLY and f.direction is self.ONLY: + self.value = f.value + elif self.value is self.ONLY and f.value is self.ONLY: self.fields = self.fields.intersection(f.fields) - elif self.direction is self.EXCLUDE and f.direction is self.EXCLUDE: + elif self.value is self.EXCLUDE and f.value is self.EXCLUDE: self.fields = self.fields.union(f.fields) - elif self.direction is self.ONLY and f.direction is self.EXCLUDE: + elif self.value is self.ONLY and f.value is self.EXCLUDE: self.fields -= f.fields - elif self.direction is self.EXCLUDE and f.direction is self.ONLY: - self.direction = self.ONLY + elif self.value is self.EXCLUDE and f.value is self.ONLY: + self.value = self.ONLY self.fields = f.fields - self.fields if self.always_include: - if self.direction is self.ONLY and self.fields: + if self.value is self.ONLY and self.fields: self.fields = self.fields.union(self.always_include) else: self.fields -= self.always_include @@ -311,7 +312,7 @@ class QueryFieldList(object): def reset(self): self.fields = set([]) - self.direction = self.ONLY + self.value = self.ONLY def __nonzero__(self): return bool(self.fields) @@ -334,6 +335,7 @@ class QuerySet(object): self._ordering = [] self._snapshot = False self._timeout = True + self._class_check = True # If inheritance is allowed, only return instances and instances of # subclasses of the class being used @@ -344,11 +346,26 @@ class QuerySet(object): self._limit = None self._skip = None + def clone(self): + """Creates a copy of the current :class:`~mongoengine.queryset.QuerySet`""" + c = self.__class__(self._document, self._collection_obj) + + copy_props = ('_initial_query', '_query_obj', '_where_clause', + '_loaded_fields', '_ordering', '_snapshot', + '_timeout', '_limit', '_skip') + + for prop in copy_props: + val = getattr(self, prop) + setattr(c, prop, copy.deepcopy(val)) + + return c + @property def 
_query(self): if self._mongo_query is None: self._mongo_query = self._query_obj.to_query(self._document) - self._mongo_query.update(self._initial_query) + if self._class_check: + self._mongo_query.update(self._initial_query) return self._mongo_query def ensure_index(self, key_or_list, drop_dups=False, background=False, @@ -399,7 +416,7 @@ class QuerySet(object): return index_list - def __call__(self, q_obj=None, **query): + def __call__(self, q_obj=None, class_check=True, **query): """Filter the selected documents by calling the :class:`~mongoengine.queryset.QuerySet` with a query. @@ -407,16 +424,17 @@ class QuerySet(object): the query; the :class:`~mongoengine.queryset.QuerySet` is filtered multiple times with different :class:`~mongoengine.queryset.Q` objects, only the last one will be used + :param class_check: If set to False bypass class name check when + querying collection :param query: Django-style query keyword arguments """ - #if q_obj: - #self._where_clause = q_obj.as_js(self._document) query = Q(**query) if q_obj: query &= q_obj self._query_obj &= query self._mongo_query = None self._cursor_obj = None + self._class_check = class_check return self def filter(self, *q_objs, **query): @@ -440,17 +458,17 @@ class QuerySet(object): drop_dups = self._document._meta.get('index_drop_dups', False) index_opts = self._document._meta.get('index_options', {}) + # Ensure indexes created by uniqueness constraints + for index in self._document._meta['unique_indexes']: + self._collection.ensure_index(index, unique=True, + background=background, drop_dups=drop_dups, **index_opts) + # Ensure document-defined indexes are created if self._document._meta['indexes']: for key_or_list in self._document._meta['indexes']: self._collection.ensure_index(key_or_list, background=background, **index_opts) - # Ensure indexes created by uniqueness constraints - for index in self._document._meta['unique_indexes']: - self._collection.ensure_index(index, unique=True, - 
background=background, drop_dups=drop_dups, **index_opts) - # If _types is being used (for polymorphism), it needs an index if '_types' in self._query: self._collection.ensure_index('_types', @@ -474,7 +492,7 @@ class QuerySet(object): } if self._loaded_fields: cursor_args['fields'] = self._loaded_fields.as_dict() - self._cursor_obj = self._collection.find(self._query, + self._cursor_obj = self._collection.find(self._query, **cursor_args) # Apply where clauses to cursor if self._where_clause: @@ -504,6 +522,15 @@ class QuerySet(object): fields = [] field = None for field_name in parts: + # Handle ListField indexing: + if field_name.isdigit(): + try: + field = field.field + except AttributeError, err: + raise InvalidQueryError( + "Can't use index on unsubscriptable field (%s)" % err) + fields.append(field_name) + continue if field is None: # Look up first field from the document if field_name == 'pk': @@ -528,14 +555,14 @@ class QuerySet(object): return '.'.join(parts) @classmethod - def _transform_query(cls, _doc_cls=None, **query): + def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): """Transform a query from Django-style format to Mongo format. 
""" operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists', 'not'] - geo_operators = ['within_distance', 'within_box', 'near'] - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', + geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'near', 'near_sphere'] + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] mongo_query = {} @@ -577,8 +604,12 @@ class QuerySet(object): if op in geo_operators: if op == "within_distance": value = {'$within': {'$center': value}} + elif op == "within_spherical_distance": + value = {'$within': {'$centerSphere': value}} elif op == "near": value = {'$near': value} + elif op == "near_sphere": + value = {'$nearSphere': value} elif op == 'within_box': value = {'$within': {'$box': value}} else: @@ -620,9 +651,9 @@ class QuerySet(object): raise self._document.DoesNotExist("%s matching query does not exist." % self._document._class_name) - def get_or_create(self, *q_objs, **query): - """Retrieve unique object or create, if it doesn't exist. Returns a tuple of - ``(object, created)``, where ``object`` is the retrieved or created object + def get_or_create(self, write_options=None, *q_objs, **query): + """Retrieve unique object or create, if it doesn't exist. Returns a tuple of + ``(object, created)``, where ``object`` is the retrieved or created object and ``created`` is a boolean specifying whether a new object was created. Raises :class:`~mongoengine.queryset.MultipleObjectsReturned` or `DocumentName.MultipleObjectsReturned` if multiple results are found. @@ -630,6 +661,10 @@ class QuerySet(object): dictionary of default values for the new document may be provided as a keyword argument called :attr:`defaults`. + :param write_options: optional extra keyword arguments used if we + have to create a new document. 
+ Passes any write_options onto :meth:`~mongoengine.document.Document.save` + .. versionadded:: 0.3 """ defaults = query.get('defaults', {}) @@ -641,7 +676,7 @@ class QuerySet(object): if count == 0: query.update(defaults) doc = self._document(**query) - doc.save() + doc.save(write_options=write_options) return doc, True elif count == 1: return self.first(), False @@ -725,7 +760,7 @@ class QuerySet(object): def __len__(self): return self.count() - def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None, + def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, scope=None, keep_temp=False): """Perform a map/reduce query using the current query spec and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, @@ -739,26 +774,26 @@ class QuerySet(object): :param map_f: map function, as :class:`~pymongo.code.Code` or string :param reduce_f: reduce function, as :class:`~pymongo.code.Code` or string + :param output: output collection name :param finalize_f: finalize function, an optional function that performs any post-reduction processing. :param scope: values to insert into map/reduce global scope. Optional. :param limit: number of objects from current query to provide to map/reduce method - :param keep_temp: keep temporary table (boolean, default ``True``) Returns an iterator yielding :class:`~mongoengine.document.MapReduceDocument`. - .. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo + .. note:: Map/Reduce changed in server version **>= 1.7.4**. The PyMongo :meth:`~pymongo.collection.Collection.map_reduce` helper requires - PyMongo version **>= 1.2**. + PyMongo version **>= 1.11**. .. 
versionadded:: 0.3 """ from document import MapReduceDocument if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.1.1") + raise NotImplementedError("Requires MongoDB >= 1.7.1") map_f_scope = {} if isinstance(map_f, pymongo.code.Code): @@ -789,8 +824,7 @@ class QuerySet(object): if limit: mr_args['limit'] = limit - - results = self._collection.map_reduce(map_f, reduce_f, **mr_args) + results = self._collection.map_reduce(map_f, reduce_f, output, **mr_args) results = results.find() if self._ordering: @@ -835,7 +869,7 @@ class QuerySet(object): self._skip, self._limit = key.start, key.stop except IndexError, err: # PyMongo raises an error if key.start == key.stop, catch it, - # bin it, kill it. + # bin it, kill it. start = key.start or 0 if start >= 0 and key.stop >= 0 and key.step is None: if start == key.stop: @@ -868,10 +902,8 @@ class QuerySet(object): .. versionadded:: 0.3 """ - fields = self._fields_to_dbfields(fields) - self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.ONLY) - return self - + fields = dict([(f, QueryFieldList.ONLY) for f in fields]) + return self.fields(**fields) def exclude(self, *fields): """Opposite to .only(), exclude some document's fields. :: @@ -880,8 +912,44 @@ class QuerySet(object): :param fields: fields to exclude """ - fields = self._fields_to_dbfields(fields) - self._loaded_fields += QueryFieldList(fields, direction=QueryFieldList.EXCLUDE) + fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) + return self.fields(**fields) + + def fields(self, **kwargs): + """Manipulate how you load this document's fields. Used by `.only()` + and `.exclude()` to manipulate which fields to retrieve. 
Fields also + allows for a greater level of control for example: + + Retrieving a Subrange of Array Elements + --------------------------------------- + + You can use the $slice operator to retrieve a subrange of elements in + an array :: + + post = BlogPost.objects(...).fields(slice__comments=5) // first 5 comments + + :param kwargs: A dictionary identifying what to include + + .. versionadded:: 0.5 + """ + + # Check for an operator and transform to mongo-style if there is + operators = ["slice"] + cleaned_fields = [] + for key, value in kwargs.items(): + parts = key.split('__') + op = None + if parts[0] in operators: + op = parts.pop(0) + value = {'$' + op: value} + key = '.'.join(parts) + cleaned_fields.append((key, value)) + + fields = sorted(cleaned_fields, key=operator.itemgetter(1)) + for value, group in itertools.groupby(fields, lambda x: x[1]): + fields = [field for field, value in group] + fields = self._fields_to_dbfields(fields) + self._loaded_fields += QueryFieldList(fields, value=value) return self def all_fields(self): @@ -917,6 +985,10 @@ class QuerySet(object): if key[0] in ('-', '+'): key = key[1:] key = key.replace('__', '.') + try: + key = QuerySet._translate_field_name(self._document, key) + except: + pass key_list.append((key, direction)) self._ordering = key_list @@ -1007,10 +1079,17 @@ class QuerySet(object): if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [field.db_field for field in fields] + parts = [] + + for field in fields: + if isinstance(field, str): + parts.append(field) + else: + parts.append(field.db_field) # Convert value to proper value field = fields[-1] + if op in (None, 'set', 'push', 'pull', 'addToSet'): value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): @@ -1029,22 +1108,27 @@ class QuerySet(object): return mongo_update - def update(self, safe_update=True, upsert=False, **update): - """Perform an atomic 
update on the fields matched by the query. When + def update(self, safe_update=True, upsert=False, write_options=None, **update): + """Perform an atomic update on the fields matched by the query. When ``safe_update`` is used, the number of affected documents is returned. - :param safe: check if the operation succeeded before returning - :param update: Django-style update keyword arguments + :param safe_update: check if the operation succeeded before returning + :param upsert: Any existing document with that "_id" is overwritten. + :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update` .. versionadded:: 0.2 """ if pymongo.version < '1.1.1': raise OperationError('update() method requires PyMongo 1.1.1+') + if not write_options: + write_options = {} + update = QuerySet._transform_update(self._document, **update) try: ret = self._collection.update(self._query, update, multi=True, - upsert=upsert, safe=safe_update) + upsert=upsert, safe=safe_update, + **write_options) if ret is not None and 'n' in ret: return ret['n'] except pymongo.errors.OperationFailure, err: @@ -1053,22 +1137,27 @@ class QuerySet(object): raise OperationError(message) raise OperationError(u'Update failed (%s)' % unicode(err)) - def update_one(self, safe_update=True, upsert=False, **update): - """Perform an atomic update on first field matched by the query. When + def update_one(self, safe_update=True, upsert=False, write_options=None, **update): + """Perform an atomic update on first field matched by the query. When ``safe_update`` is used, the number of affected documents is returned. - :param safe: check if the operation succeeded before returning + :param safe_update: check if the operation succeeded before returning + :param upsert: Any existing document with that "_id" is overwritten. + :param write_options: extra keyword arguments for :meth:`~pymongo.collection.Collection.update` :param update: Django-style update keyword arguments .. 
versionadded:: 0.2 """ + if not write_options: + write_options = {} update = QuerySet._transform_update(self._document, **update) try: # Explicitly provide 'multi=False' to newer versions of PyMongo # as the default may change to 'True' if pymongo.version >= '1.1.1': ret = self._collection.update(self._query, update, multi=False, - upsert=upsert, safe=safe_update) + upsert=upsert, safe=safe_update, + **write_options) else: # Older versions of PyMongo don't support 'multi' ret = self._collection.update(self._query, update, @@ -1082,8 +1171,8 @@ class QuerySet(object): return self def _sub_js_fields(self, code): - """When fields are specified with [~fieldname] syntax, where - *fieldname* is the Python name of a field, *fieldname* will be + """When fields are specified with [~fieldname] syntax, where + *fieldname* is the Python name of a field, *fieldname* will be substituted for the MongoDB name of the field (specified using the :attr:`name` keyword argument in a field's constructor). """ @@ -1106,9 +1195,9 @@ class QuerySet(object): options specified as keyword arguments. As fields in MongoEngine may use different names in the database (set - using the :attr:`db_field` keyword argument to a :class:`Field` + using the :attr:`db_field` keyword argument to a :class:`Field` constructor), a mechanism exists for replacing MongoEngine field names - with the database field names in Javascript code. When accessing a + with the database field names in Javascript code. When accessing a field, use square-bracket notation, and prefix the MongoEngine field name with a tilde (~). 
@@ -1241,8 +1330,11 @@ class QuerySet(object): class QuerySetManager(object): - def __init__(self, manager_func=None): - self._manager_func = manager_func + get_queryset = None + + def __init__(self, queryset_func=None): + if queryset_func: + self.get_queryset = queryset_func self._collections = {} def __get__(self, instance, owner): @@ -1259,7 +1351,7 @@ class QuerySetManager(object): # Create collection as a capped collection if specified if owner._meta['max_size'] or owner._meta['max_documents']: # Get max document limit and max byte size from meta - max_size = owner._meta['max_size'] or 10000000 # 10MB default + max_size = owner._meta['max_size'] or 10000000 # 10MB default max_documents = owner._meta['max_documents'] if collection in db.collection_names(): @@ -1286,11 +1378,11 @@ class QuerySetManager(object): # owner is the document that contains the QuerySetManager queryset_class = owner._meta['queryset_class'] or QuerySet queryset = queryset_class(owner, self._collections[(db, collection)]) - if self._manager_func: - if self._manager_func.func_code.co_argcount == 1: - queryset = self._manager_func(queryset) + if self.get_queryset: + if self.get_queryset.func_code.co_argcount == 1: + queryset = self.get_queryset(queryset) else: - queryset = self._manager_func(owner, queryset) + queryset = self.get_queryset(owner, queryset) return queryset diff --git a/tests/document.py b/tests/document.py index 0da7b93e..fe67312e 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1,13 +1,25 @@ import unittest from datetime import datetime import pymongo +import pickle from mongoengine import * +from mongoengine.base import BaseField from mongoengine.connection import _get_db +class PickleEmbedded(EmbeddedDocument): + date = DateTimeField(default=datetime.now) + +class PickleTest(Document): + number = IntField() + string = StringField() + embedded = EmbeddedDocumentField(PickleEmbedded) + lists = ListField(StringField()) + + class DocumentTest(unittest.TestCase): - + 
def setUp(self): connect(db='mongoenginetest') self.db = _get_db() @@ -17,6 +29,9 @@ class DocumentTest(unittest.TestCase): age = IntField() self.Person = Person + def tearDown(self): + self.Person.drop_collection() + def test_drop_collection(self): """Ensure that the collection may be dropped from the database. """ @@ -38,7 +53,7 @@ class DocumentTest(unittest.TestCase): name = name_field age = age_field non_field = True - + self.assertEqual(Person._fields['name'], name_field) self.assertEqual(Person._fields['age'], age_field) self.assertFalse('non_field' in Person._fields) @@ -60,7 +75,7 @@ class DocumentTest(unittest.TestCase): mammal_superclasses = {'Animal': Animal} self.assertEqual(Mammal._superclasses, mammal_superclasses) - + dog_superclasses = { 'Animal': Animal, 'Animal.Mammal': Mammal, @@ -68,7 +83,7 @@ class DocumentTest(unittest.TestCase): self.assertEqual(Dog._superclasses, dog_superclasses) def test_get_subclasses(self): - """Ensure that the correct list of subclasses is retrieved by the + """Ensure that the correct list of subclasses is retrieved by the _get_subclasses method. 
""" class Animal(Document): pass @@ -78,15 +93,15 @@ class DocumentTest(unittest.TestCase): class Dog(Mammal): pass mammal_subclasses = { - 'Animal.Mammal.Dog': Dog, + 'Animal.Mammal.Dog': Dog, 'Animal.Mammal.Human': Human } self.assertEqual(Mammal._get_subclasses(), mammal_subclasses) - + animal_subclasses = { 'Animal.Fish': Fish, 'Animal.Mammal': Mammal, - 'Animal.Mammal.Dog': Dog, + 'Animal.Mammal.Dog': Dog, 'Animal.Mammal.Human': Human } self.assertEqual(Animal._get_subclasses(), animal_subclasses) @@ -124,7 +139,7 @@ class DocumentTest(unittest.TestCase): self.assertTrue('name' in Employee._fields) self.assertTrue('salary' in Employee._fields) - self.assertEqual(Employee._meta['collection'], + self.assertEqual(Employee._meta['collection'], self.Person._meta['collection']) # Ensure that MRO error is not raised @@ -146,7 +161,7 @@ class DocumentTest(unittest.TestCase): class Dog(Animal): pass self.assertRaises(ValueError, create_dog_class) - + # Check that _cls etc aren't present on simple documents dog = Animal(name='dog') dog.save() @@ -161,7 +176,7 @@ class DocumentTest(unittest.TestCase): class Employee(self.Person): meta = {'allow_inheritance': False} self.assertRaises(ValueError, create_employee_class) - + # Test the same for embedded documents class Comment(EmbeddedDocument): content = StringField() @@ -176,6 +191,34 @@ class DocumentTest(unittest.TestCase): self.assertFalse('_cls' in comment.to_mongo()) self.assertFalse('_types' in comment.to_mongo()) + def test_abstract_documents(self): + """Ensure that a document superclass can be marked as abstract + thereby not using it as the name for the collection.""" + + class Animal(Document): + name = StringField() + meta = {'abstract': True} + + class Fish(Animal): pass + class Guppy(Fish): pass + + class Mammal(Animal): + meta = {'abstract': True} + class Human(Mammal): pass + + self.assertFalse('collection' in Animal._meta) + self.assertFalse('collection' in Mammal._meta) + + 
self.assertEqual(Fish._meta['collection'], 'fish') + self.assertEqual(Guppy._meta['collection'], 'fish') + self.assertEqual(Human._meta['collection'], 'human') + + def create_bad_abstract(): + class EvilHuman(Human): + evil = BooleanField(default=True) + meta = {'abstract': True} + self.assertRaises(ValueError, create_bad_abstract) + def test_collection_name(self): """Ensure that a collection with a specified name may be used. """ @@ -186,7 +229,7 @@ class DocumentTest(unittest.TestCase): class Person(Document): name = StringField() meta = {'collection': collection} - + user = Person(name="Test User") user.save() self.assertTrue(collection in self.db.collection_names()) @@ -200,6 +243,22 @@ class DocumentTest(unittest.TestCase): Person.drop_collection() self.assertFalse(collection in self.db.collection_names()) + def test_collection_name_and_primary(self): + """Ensure that a collection with a specified name may be used. + """ + + class Person(Document): + name = StringField(primary_key=True) + meta = {'collection': 'app'} + + user = Person(name="Test User") + user.save() + + user_obj = Person.objects[0] + self.assertEqual(user_obj.name, "Test User") + + Person.drop_collection() + def test_inherited_collections(self): """Ensure that subclassed documents don't override parents' collections. 
""" @@ -280,7 +339,7 @@ class DocumentTest(unittest.TestCase): tags = ListField(StringField()) meta = { 'indexes': [ - '-date', + '-date', 'tags', ('category', '-date') ], @@ -296,12 +355,12 @@ class DocumentTest(unittest.TestCase): list(BlogPost.objects) info = BlogPost.objects._collection.index_information() info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] + self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('addDate', -1)] in info) # tags is a list field so it shouldn't have _types in the index self.assertTrue([('tags', 1)] in info) - + class ExtendedBlogPost(BlogPost): title = StringField() meta = {'indexes': ['title']} @@ -311,7 +370,7 @@ class DocumentTest(unittest.TestCase): list(ExtendedBlogPost.objects) info = ExtendedBlogPost.objects._collection.index_information() info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] + self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('addDate', -1)] in info) self.assertTrue([('_types', 1), ('title', 1)] in info) @@ -334,6 +393,10 @@ class DocumentTest(unittest.TestCase): post2 = BlogPost(title='test2', slug='test') self.assertRaises(OperationError, post2.save) + + def test_unique_with(self): + """Ensure that unique_with constraints are applied to fields. + """ class Date(EmbeddedDocument): year = IntField(db_field='yr') @@ -357,6 +420,108 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() + def test_unique_embedded_document(self): + """Ensure that uniqueness constraints are applied to fields on embedded documents. 
+ """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField() + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) + self.assertRaises(OperationError, post3.save) + + BlogPost.drop_collection() + + def test_unique_with_embedded_document_and_embedded_unique(self): + """Ensure that uniqueness constraints are applied to fields on + embedded documents. And work with unique_with as well. + """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField(unique_with='sub.year') + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) + self.assertRaises(OperationError, post3.save) + + # Now there will be two docs with the same title and year + post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1')) + self.assertRaises(OperationError, post3.save) + + BlogPost.drop_collection() + + def test_unique_and_indexes(self): + """Ensure that 'unique' constraints aren't overridden by + meta.indexes. 
+ """ + class Customer(Document): + cust_id = IntField(unique=True, required=True) + meta = { + 'indexes': ['cust_id'], + 'allow_inheritance': False, + } + + Customer.drop_collection() + cust = Customer(cust_id=1) + cust.save() + + cust_dupe = Customer(cust_id=1) + try: + cust_dupe.save() + raise AssertionError, "We saved a dupe!" + except OperationError: + pass + Customer.drop_collection() + + def test_unique_and_primary(self): + """If you set a field as primary, then unexpected behaviour can occur. + You won't create a duplicate but you will update an existing document. + """ + + class User(Document): + name = StringField(primary_key=True, unique=True) + password = StringField() + + User.drop_collection() + + user = User(name='huangz', password='secret') + user.save() + + user = User(name='huangz', password='secret2') + user.save() + + self.assertEqual(User.objects.count(), 1) + self.assertEqual(User.objects.get().password, 'secret2') + + User.drop_collection() + def test_custom_id_field(self): """Ensure that documents may be created with custom primary keys. 
""" @@ -380,7 +545,7 @@ class DocumentTest(unittest.TestCase): class EmailUser(User): email = StringField() - + user = User(username='test', name='test user') user.save() @@ -391,20 +556,20 @@ class DocumentTest(unittest.TestCase): user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'test') self.assertTrue('username' not in user_son['_id']) - + User.drop_collection() - + user = User(pk='mongo', name='mongo user') user.save() - + user_obj = User.objects.first() self.assertEqual(user_obj.id, 'mongo') self.assertEqual(user_obj.pk, 'mongo') - + user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'mongo') self.assertTrue('username' not in user_son['_id']) - + User.drop_collection() def test_creation(self): @@ -457,18 +622,18 @@ class DocumentTest(unittest.TestCase): """ class Comment(EmbeddedDocument): content = StringField() - + self.assertTrue('content' in Comment._fields) self.assertFalse('id' in Comment._fields) self.assertFalse('collection' in Comment._meta) - + def test_embedded_document_validation(self): """Ensure that embedded documents may be validated. """ class Comment(EmbeddedDocument): date = DateTimeField() content = StringField(required=True) - + comment = Comment() self.assertRaises(ValidationError, comment.validate) @@ -496,7 +661,7 @@ class DocumentTest(unittest.TestCase): # Test skipping validation on save class Recipient(Document): email = EmailField(required=True) - + recipient = Recipient(email='root@localhost') self.assertRaises(ValidationError, recipient.save) try: @@ -517,19 +682,19 @@ class DocumentTest(unittest.TestCase): """Ensure that a document may be saved with a custom _id. 
""" # Create person object and save it to the database - person = self.Person(name='Test User', age=30, + person = self.Person(name='Test User', age=30, id='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._meta['collection']] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') - + def test_save_custom_pk(self): """Ensure that a document may be saved with a custom _id using pk alias. """ # Create person object and save it to the database - person = self.Person(name='Test User', age=30, + person = self.Person(name='Test User', age=30, pk='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id @@ -565,7 +730,7 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() def test_save_embedded_document(self): - """Ensure that a document with an embedded document field may be + """Ensure that a document with an embedded document field may be saved in the database. """ class EmployeeDetails(EmbeddedDocument): @@ -588,10 +753,38 @@ class DocumentTest(unittest.TestCase): # Ensure that the 'details' embedded object saved correctly self.assertEqual(employee_obj['details']['position'], 'Developer') + def test_updating_an_embedded_document(self): + """Ensure that a document with an embedded document field may be + saved in the database. 
+ """ + class EmployeeDetails(EmbeddedDocument): + position = StringField() + + class Employee(self.Person): + salary = IntField() + details = EmbeddedDocumentField(EmployeeDetails) + + # Create employee object and save it to the database + employee = Employee(name='Test Employee', age=50, salary=20000) + employee.details = EmployeeDetails(position='Developer') + employee.save() + + # Test updating an embedded document + promoted_employee = Employee.objects.get(name='Test Employee') + promoted_employee.details.position = 'Senior Developer' + promoted_employee.save() + + collection = self.db[self.Person._meta['collection']] + employee_obj = collection.find_one({'name': 'Test Employee'}) + self.assertEqual(employee_obj['name'], 'Test Employee') + self.assertEqual(employee_obj['age'], 50) + # Ensure that the 'details' embedded object saved correctly + self.assertEqual(employee_obj['details']['position'], 'Senior Developer') + def test_save_reference(self): """Ensure that a document reference field may be saved in the database. 
""" - + class BlogPost(Document): meta = {'collection': 'blogpost_1'} content = StringField() @@ -610,7 +803,7 @@ class DocumentTest(unittest.TestCase): post_obj = BlogPost.objects.first() # Test laziness - self.assertTrue(isinstance(post_obj._data['author'], + self.assertTrue(isinstance(post_obj._data['author'], pymongo.dbref.DBRef)) self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertEqual(post_obj.author.name, 'Test User') @@ -725,9 +918,25 @@ class DocumentTest(unittest.TestCase): self.Person.drop_collection() BlogPost.drop_collection() + def subclasses_and_unique_keys_works(self): - def tearDown(self): - self.Person.drop_collection() + class A(Document): + pass + + class B(A): + foo = BooleanField(unique=True) + + A.drop_collection() + B.drop_collection() + + A().save() + A().save() + B(foo=True).save() + + self.assertEquals(A.objects.count(), 2) + self.assertEquals(B.objects.count(), 1) + A.drop_collection() + B.drop_collection() def test_document_hash(self): """Test document in list, dict, set @@ -737,7 +946,7 @@ class DocumentTest(unittest.TestCase): class BlogPost(Document): pass - + # Clear old datas User.drop_collection() BlogPost.drop_collection() @@ -774,9 +983,46 @@ class DocumentTest(unittest.TestCase): # in Set all_user_set = set(User.objects.all()) - + self.assertTrue(u1 in all_user_set ) - + + def test_picklable(self): + + pickle_doc = PickleTest(number=1, string="OH HAI", lists=['1', '2']) + pickle_doc.embedded = PickleEmbedded() + pickle_doc.save() + + pickled_doc = pickle.dumps(pickle_doc) + resurrected = pickle.loads(pickled_doc) + + self.assertEquals(resurrected, pickle_doc) + + resurrected.string = "Working" + resurrected.save() + + pickle_doc.reload() + self.assertEquals(resurrected, pickle_doc) + + def test_write_options(self): + """Test that passing write_options works""" + + self.Person.drop_collection() + + write_options = {"fsync": True} + + author, created = self.Person.objects.get_or_create( + name='Test User', 
write_options=write_options) + author.save(write_options=write_options) + + self.Person.objects.update(set__name='Ross', write_options=write_options) + + author = self.Person.objects.first() + self.assertEquals(author.name, 'Ross') + + self.Person.objects.update_one(set__name='Test User', write_options=write_options) + author = self.Person.objects.first() + self.assertEquals(author.name, 'Test User') + if __name__ == '__main__': unittest.main() diff --git a/tests/fields.py b/tests/fields.py index f24b5eda..00b1c886 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -7,6 +7,7 @@ import gridfs from mongoengine import * from mongoengine.connection import _get_db +from mongoengine.base import _document_registry, NotRegistered class FieldTest(unittest.TestCase): @@ -45,7 +46,7 @@ class FieldTest(unittest.TestCase): """ class Person(Document): name = StringField() - + person = Person(name='Test User') self.assertEqual(person.id, None) @@ -95,7 +96,7 @@ class FieldTest(unittest.TestCase): link.url = 'http://www.google.com:8080' link.validate() - + def test_int_validation(self): """Ensure that invalid values cannot be assigned to int fields. """ @@ -129,12 +130,12 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, person.validate) person.height = 4.0 self.assertRaises(ValidationError, person.validate) - + def test_decimal_validation(self): """Ensure that invalid values cannot be assigned to decimal fields. 
""" class Person(Document): - height = DecimalField(min_value=Decimal('0.1'), + height = DecimalField(min_value=Decimal('0.1'), max_value=Decimal('3.5')) Person.drop_collection() @@ -249,7 +250,7 @@ class FieldTest(unittest.TestCase): post.save() post.reload() self.assertEqual(post.tags, ['fun', 'leisure']) - + comment1 = Comment(content='Good for you', order=1) comment2 = Comment(content='Yay.', order=0) comments = [comment1, comment2] @@ -261,12 +262,14 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() - def test_dict_validation(self): + def test_dict_field(self): """Ensure that dict types work as expected. """ class BlogPost(Document): info = DictField() + BlogPost.drop_collection() + post = BlogPost() post.info = 'my post' self.assertRaises(ValidationError, post.validate) @@ -281,7 +284,24 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, post.validate) post.info = {'title': 'test'} - post.validate() + post.save() + + post = BlogPost() + post.info = {'details': {'test': 'test'}} + post.save() + + post = BlogPost() + post.info = {'details': {'test': 3}} + post.save() + + self.assertEquals(BlogPost.objects.count(), 3) + self.assertEquals(BlogPost.objects.filter(info__title__exact='test').count(), 1) + self.assertEquals(BlogPost.objects.filter(info__details__test__exact='test').count(), 1) + + # Confirm handles non strings or non existing keys + self.assertEquals(BlogPost.objects.filter(info__details__test__exact=5).count(), 0) + self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) + BlogPost.drop_collection() def test_embedded_document_validation(self): """Ensure that invalid embedded documents cannot be assigned to @@ -315,7 +335,7 @@ class FieldTest(unittest.TestCase): person.validate() def test_embedded_document_inheritance(self): - """Ensure that subclasses of embedded documents may be provided to + """Ensure that subclasses of embedded documents may be provided to 
EmbeddedDocumentFields of the superclass' type. """ class User(EmbeddedDocument): @@ -327,7 +347,7 @@ class FieldTest(unittest.TestCase): class BlogPost(Document): content = StringField() author = EmbeddedDocumentField(User) - + post = BlogPost(content='What I did today...') post.author = User(name='Test User') post.author = PowerUser(name='Test User', power=47) @@ -370,7 +390,7 @@ class FieldTest(unittest.TestCase): User.drop_collection() BlogPost.drop_collection() - + def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. """ @@ -434,7 +454,7 @@ class FieldTest(unittest.TestCase): class TreeNode(EmbeddedDocument): name = StringField() children = ListField(EmbeddedDocumentField('self')) - + tree = Tree(name="Tree") first_child = TreeNode(name="Child 1") @@ -442,7 +462,7 @@ class FieldTest(unittest.TestCase): second_child = TreeNode(name="Child 2") first_child.children.append(second_child) - + third_child = TreeNode(name="Child 3") first_child.children.append(third_child) @@ -506,20 +526,20 @@ class FieldTest(unittest.TestCase): Member.drop_collection() BlogPost.drop_collection() - + def test_generic_reference(self): """Ensure that a GenericReferenceField properly dereferences items. 
""" class Link(Document): title = StringField() meta = {'allow_inheritance': False} - + class Post(Document): title = StringField() - + class Bookmark(Document): bookmark_object = GenericReferenceField() - + Link.drop_collection() Post.drop_collection() Bookmark.drop_collection() @@ -574,16 +594,49 @@ class FieldTest(unittest.TestCase): user = User(bookmarks=[post_1, link_1]) user.save() - + user = User.objects(bookmarks__all=[post_1, link_1]).first() - + self.assertEqual(user.bookmarks[0], post_1) self.assertEqual(user.bookmarks[1], link_1) - + Link.drop_collection() Post.drop_collection() User.drop_collection() + def test_generic_reference_document_not_registered(self): + """Ensure dereferencing out of the document registry throws a + `NotRegistered` error. + """ + class Link(Document): + title = StringField() + + class User(Document): + bookmarks = ListField(GenericReferenceField()) + + Link.drop_collection() + User.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + user = User(bookmarks=[link_1]) + user.save() + + # Mimic User and Link definitions being in a different file + # and the Link model not being imported in the User file. + del(_document_registry["Link"]) + + user = User.objects.first() + try: + user.bookmarks + raise AssertionError, "Link was removed from the registry" + except NotRegistered: + pass + + Link.drop_collection() + User.drop_collection() + def test_binary_fields(self): """Ensure that binary fields can be stored and retrieved. 
""" @@ -701,6 +754,12 @@ class FieldTest(unittest.TestCase): self.assertTrue(streamfile == result) self.assertEquals(result.file.read(), text + more_text) self.assertEquals(result.file.content_type, content_type) + result.file.seek(0) + self.assertEquals(result.file.tell(), 0) + self.assertEquals(result.file.read(len(text)), text) + self.assertEquals(result.file.tell(), len(text)) + self.assertEquals(result.file.read(len(more_text)), more_text) + self.assertEquals(result.file.tell(), len(text + more_text)) result.file.delete() # Ensure deleted file returns None @@ -721,7 +780,7 @@ class FieldTest(unittest.TestCase): result = SetFile.objects.first() self.assertTrue(setfile == result) self.assertEquals(result.file.read(), more_text) - result.file.delete() + result.file.delete() PutFile.drop_collection() StreamFile.drop_collection() @@ -785,5 +844,66 @@ class FieldTest(unittest.TestCase): self.assertEqual(d2.data, {}) self.assertEqual(d2.data2, {}) + def test_mapfield(self): + """Ensure that the MapField handles the declared type.""" + + class Simple(Document): + mapping = MapField(IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + def create_invalid_mapping(): + e.mapping['somestring'] = "abc" + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + def create_invalid_class(): + class NoDeclaredType(Document): + mapping = MapField() + + self.assertRaises(ValidationError, create_invalid_class) + + Simple.drop_collection() + + def test_complex_mapfield(self): + """Ensure that the MapField can handle complex declared types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Extensible(Document): + mapping = MapField(EmbeddedDocumentField(SettingBase)) + + Extensible.drop_collection() + + e = Extensible() + e.mapping['somestring'] = StringSetting(value='foo') + 
e.mapping['someint'] = IntegerSetting(value=42) + e.save() + + e2 = Extensible.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) + self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) + + def create_invalid_mapping(): + e.mapping['someint'] = 123 + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Extensible.drop_collection() + + if __name__ == '__main__': unittest.main() diff --git a/tests/queryset.py b/tests/queryset.py index d503cf3f..5cf08957 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -5,8 +5,9 @@ import unittest import pymongo from datetime import datetime, timedelta -from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, - DoesNotExist, QueryFieldList) +from mongoengine.queryset import (QuerySet, QuerySetManager, + MultipleObjectsReturned, DoesNotExist, + QueryFieldList) from mongoengine import * @@ -105,6 +106,10 @@ class QuerySetTest(unittest.TestCase): people = list(self.Person.objects[1:1]) self.assertEqual(len(people), 0) + # Test slice out of range + people = list(self.Person.objects[80000:80001]) + self.assertEqual(len(people), 0) + def test_find_one(self): """Ensure that a query using find_one returns a valid result. """ @@ -162,7 +167,7 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age__lt=30) self.assertEqual(person.name, "User A") - + def test_find_array_position(self): """Ensure that query by array position works. """ @@ -177,7 +182,7 @@ class QuerySetTest(unittest.TestCase): posts = ListField(EmbeddedDocumentField(Post)) Blog.drop_collection() - + Blog.objects.create(tags=['a', 'b']) self.assertEqual(len(Blog.objects(tags__0='a')), 1) self.assertEqual(len(Blog.objects(tags__0='b')), 0) @@ -207,6 +212,55 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() + def test_update_array_position(self): + """Ensure that updating by array position works. 
+ + Check update() and update_one() can take syntax like: + set__posts__1__comments__1__name="testc" + Check that it only works for ListFields. + """ + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + # Update all of the first comments of second posts of all blogs + blog = Blog.objects().update(set__posts__1__comments__0__name="testc") + testc_blogs = Blog.objects(posts__1__comments__0__name="testc") + self.assertEqual(len(testc_blogs), 2) + + Blog.drop_collection() + + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + # Update only the first blog returned by the query + blog = Blog.objects().update_one( + set__posts__1__comments__1__name="testc") + testc_blogs = Blog.objects(posts__1__comments__1__name="testc") + self.assertEqual(len(testc_blogs), 1) + + # Check that using this indexing syntax on a non-list fails + def non_list_indexing(): + Blog.objects().update(set__posts__1__comments__0__name__1="asdf") + self.assertRaises(InvalidQueryError, non_list_indexing) + + Blog.drop_collection() + def test_get_or_create(self): """Ensure that ``get_or_create`` returns one result or creates a new document. 
@@ -226,16 +280,16 @@ class QuerySetTest(unittest.TestCase): person, created = self.Person.objects.get_or_create(age=30) self.assertEqual(person.name, "User B") self.assertEqual(created, False) - + person, created = self.Person.objects.get_or_create(age__lt=30) self.assertEqual(person.name, "User A") self.assertEqual(created, False) - + # Try retrieving when no objects exists - new doc should be created kwargs = dict(age=50, defaults={'name': 'User C'}) person, created = self.Person.objects.get_or_create(**kwargs) self.assertEqual(created, True) - + person = self.Person.objects.get(age=50) self.assertEqual(person.name, "User C") @@ -328,7 +382,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(obj, person) obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first() self.assertEqual(obj, None) - + # Test unsafe expressions person = self.Person(name='Guido van Rossum [.\'Geek\']') person.save() @@ -593,6 +647,81 @@ class QuerySetTest(unittest.TestCase): Email.drop_collection() + def test_slicing_fields(self): + """Ensure that query slicing an array works. 
+ """ + class Numbers(Document): + n = ListField(IntField()) + + Numbers.drop_collection() + + numbers = Numbers(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) + numbers.save() + + # first three + numbers = Numbers.objects.fields(slice__n=3).get() + self.assertEquals(numbers.n, [0, 1, 2]) + + # last three + numbers = Numbers.objects.fields(slice__n=-3).get() + self.assertEquals(numbers.n, [-3, -2, -1]) + + # skip 2, limit 3 + numbers = Numbers.objects.fields(slice__n=[2, 3]).get() + self.assertEquals(numbers.n, [2, 3, 4]) + + # skip to fifth from last, limit 4 + numbers = Numbers.objects.fields(slice__n=[-5, 4]).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2]) + + # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__n=[-5, 10]).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method + numbers = Numbers.objects.fields(n={"$slice": [-5, 10]}).get() + self.assertEquals(numbers.n, [-5, -4, -3, -2, -1]) + + def test_slicing_nested_fields(self): + """Ensure that query slicing an embedded array works. 
+ """ + + class EmbeddedNumber(EmbeddedDocument): + n = ListField(IntField()) + + class Numbers(Document): + embedded = EmbeddedDocumentField(EmbeddedNumber) + + Numbers.drop_collection() + + numbers = Numbers() + numbers.embedded = EmbeddedNumber(n=[0,1,2,3,4,5,-5,-4,-3,-2,-1]) + numbers.save() + + # first three + numbers = Numbers.objects.fields(slice__embedded__n=3).get() + self.assertEquals(numbers.embedded.n, [0, 1, 2]) + + # last three + numbers = Numbers.objects.fields(slice__embedded__n=-3).get() + self.assertEquals(numbers.embedded.n, [-3, -2, -1]) + + # skip 2, limit 3 + numbers = Numbers.objects.fields(slice__embedded__n=[2, 3]).get() + self.assertEquals(numbers.embedded.n, [2, 3, 4]) + + # skip to fifth from last, limit 4 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 4]).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2]) + + # skip to fifth from last, limit 10 + numbers = Numbers.objects.fields(slice__embedded__n=[-5, 10]).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1]) + + # skip to fifth from last, limit 10 dict method + numbers = Numbers.objects.fields(embedded__n={"$slice": [-5, 10]}).get() + self.assertEquals(numbers.embedded.n, [-5, -4, -3, -2, -1]) + def test_find_embedded(self): """Ensure that an embedded document is properly returned from a query. 
""" @@ -674,7 +803,7 @@ class QuerySetTest(unittest.TestCase): posts = [post.id for post in q] published_posts = (post1, post2, post3, post5, post6) self.assertTrue(all(obj.id in posts for obj in published_posts)) - + # Check Q object combination date = datetime(2010, 1, 10) @@ -714,7 +843,7 @@ class QuerySetTest(unittest.TestCase): obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first() self.assertEqual(obj, person) - + obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first() self.assertEqual(obj, None) @@ -786,7 +915,7 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): name = StringField(db_field='doc-name') - comments = ListField(EmbeddedDocumentField(Comment), + comments = ListField(EmbeddedDocumentField(Comment), db_field='cmnts') BlogPost.drop_collection() @@ -958,7 +1087,7 @@ class QuerySetTest(unittest.TestCase): BlogPost.objects.update_one(unset__hits=1) post.reload() self.assertEqual(post.hits, None) - + BlogPost.drop_collection() def test_update_pull(self): @@ -1027,7 +1156,7 @@ class QuerySetTest(unittest.TestCase): """ # run a map/reduce operation spanning all posts - results = BlogPost.objects.map_reduce(map_f, reduce_f) + results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) self.assertEqual(len(results), 4) @@ -1038,7 +1167,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(film.value, 3) BlogPost.drop_collection() - + def test_map_reduce_with_custom_object_ids(self): """Ensure that QuerySet.map_reduce works properly with custom primary keys. 
@@ -1047,24 +1176,24 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): title = StringField(primary_key=True) tags = ListField(StringField()) - + post1 = BlogPost(title="Post #1", tags=["mongodb", "mongoengine"]) post2 = BlogPost(title="Post #2", tags=["django", "mongodb"]) post3 = BlogPost(title="Post #3", tags=["hitchcock films"]) - + post1.save() post2.save() post3.save() - + self.assertEqual(BlogPost._fields['title'].db_field, '_id') self.assertEqual(BlogPost._meta['id_field'], 'title') - + map_f = """ function() { emit(this._id, 1); } """ - + # reduce to a list of tag ids and counts reduce_f = """ function(key, values) { @@ -1075,10 +1204,10 @@ class QuerySetTest(unittest.TestCase): return total; } """ - - results = BlogPost.objects.map_reduce(map_f, reduce_f) + + results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) - + self.assertEqual(results[0].object, post1) self.assertEqual(results[1].object, post2) self.assertEqual(results[2].object, post3) @@ -1168,7 +1297,7 @@ class QuerySetTest(unittest.TestCase): finalize_f = """ function(key, value) { - // f(sec_since_epoch,y,z) = + // f(sec_since_epoch,y,z) = // log10(z) + ((y*sec_since_epoch) / 45000) z_10 = Math.log(value.z) / Math.log(10); weight = z_10 + ((value.y * value.t_s) / 45000); @@ -1187,6 +1316,7 @@ class QuerySetTest(unittest.TestCase): results = Link.objects.order_by("-value") results = results.map_reduce(map_f, reduce_f, + "myresults", finalize_f=finalize_f, scope=scope) results = list(results) @@ -1289,6 +1419,7 @@ class QuerySetTest(unittest.TestCase): class BlogPost(Document): tags = ListField(StringField()) deleted = BooleanField(default=False) + date = DateTimeField(default=datetime.now) @queryset_manager def objects(doc_cls, queryset): @@ -1296,7 +1427,7 @@ class QuerySetTest(unittest.TestCase): @queryset_manager def music_posts(doc_cls, queryset): - return queryset(tags='music', deleted=False) + return queryset(tags='music', 
deleted=False).order_by('-date') BlogPost.drop_collection() @@ -1312,7 +1443,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual([p.id for p in BlogPost.objects], [post1.id, post2.id, post3.id]) self.assertEqual([p.id for p in BlogPost.music_posts], - [post1.id, post2.id]) + [post2.id, post1.id]) BlogPost.drop_collection() @@ -1452,10 +1583,12 @@ class QuerySetTest(unittest.TestCase): class Test(Document): testdict = DictField() + Test.drop_collection() + t = Test(testdict={'f': 'Value'}) t.save() - self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 0) + self.assertEqual(len(Test.objects(testdict__f__startswith='Val')), 1) self.assertEqual(len(Test.objects(testdict__f='Value')), 1) Test.drop_collection() @@ -1514,12 +1647,12 @@ class QuerySetTest(unittest.TestCase): title = StringField() date = DateTimeField() location = GeoPointField() - + def __unicode__(self): return self.title - + Event.drop_collection() - + event1 = Event(title="Coltrane Motion @ Double Door", date=datetime.now() - timedelta(days=1), location=[41.909889, -87.677137]) @@ -1529,7 +1662,7 @@ class QuerySetTest(unittest.TestCase): event3 = Event(title="Coltrane Motion @ Empty Bottle", date=datetime.now(), location=[41.900474, -87.686638]) - + event1.save() event2.save() event3.save() @@ -1541,7 +1674,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(events.count(), 3) self.assertEqual(list(events), [event1, event3, event2]) - # find events within 5 miles of pitchfork office, chicago + # find events within 5 degrees of pitchfork office, chicago point_and_distance = [[41.9120459, -87.67892], 5] events = Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 2) @@ -1549,24 +1682,24 @@ class QuerySetTest(unittest.TestCase): self.assertTrue(event2 not in events) self.assertTrue(event1 in events) self.assertTrue(event3 in events) - + # ensure ordering is respected by "near" events = Event.objects(location__near=[41.9120459, 
-87.67892]) events = events.order_by("-date") self.assertEqual(events.count(), 3) self.assertEqual(list(events), [event3, event1, event2]) - - # find events around san francisco + + # find events within 10 degrees of san francisco point_and_distance = [[37.7566023, -122.415579], 10] events = Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 1) self.assertEqual(events[0], event2) - - # find events within 1 mile of greenpoint, broolyn, nyc, ny + + # find events within 1 degree of greenpoint, broolyn, nyc, ny point_and_distance = [[40.7237134, -73.9509714], 1] events = Event.objects(location__within_distance=point_and_distance) self.assertEqual(events.count(), 0) - + # ensure ordering is respected by "within_distance" point_and_distance = [[41.9120459, -87.67892], 10] events = Event.objects(location__within_distance=point_and_distance) @@ -1579,9 +1712,61 @@ class QuerySetTest(unittest.TestCase): events = Event.objects(location__within_box=box) self.assertEqual(events.count(), 1) self.assertEqual(events[0].id, event2.id) - + Event.drop_collection() + def test_spherical_geospatial_operators(self): + """Ensure that spherical geospatial queries are working + """ + class Point(Document): + location = GeoPointField() + + Point.drop_collection() + + # These points are one degree apart, which (according to Google Maps) + # is about 110 km apart at this place on the Earth. + north_point = Point(location=[-122, 38]) # Near Concord, CA + south_point = Point(location=[-122, 37]) # Near Santa Cruz, CA + north_point.save() + south_point.save() + + earth_radius = 6378.009; # in km (needs to be a float for dividing by) + + # Finds both points because they are within 60 km of the reference + # point equidistant between them. 
+ points = Point.objects(location__near_sphere=[-122, 37.5]) + self.assertEqual(points.count(), 2) + + # Same behavior for _within_spherical_distance + points = Point.objects( + location__within_spherical_distance=[[-122, 37.5], 60/earth_radius] + ); + self.assertEqual(points.count(), 2) + + # Finds both points, but orders the north point first because it's + # closer to the reference point to the north. + points = Point.objects(location__near_sphere=[-122, 38.5]) + self.assertEqual(points.count(), 2) + self.assertEqual(points[0].id, north_point.id) + self.assertEqual(points[1].id, south_point.id) + + # Finds both points, but orders the south point first because it's + # closer to the reference point to the south. + points = Point.objects(location__near_sphere=[-122, 36.5]) + self.assertEqual(points.count(), 2) + self.assertEqual(points[0].id, south_point.id) + self.assertEqual(points[1].id, north_point.id) + + # Finds only one point because only the first point is within 60km of + # the reference point to the south. + points = Point.objects( + location__within_spherical_distance=[[-122, 36.5], 60/earth_radius] + ); + self.assertEqual(points.count(), 1) + self.assertEqual(points[0].id, south_point.id) + + Point.drop_collection() + def test_custom_querysets(self): """Ensure that custom QuerySet classes may be used. """ @@ -1602,6 +1787,53 @@ class QuerySetTest(unittest.TestCase): Post.drop_collection() + def test_custom_querysets_set_manager_directly(self): + """Ensure that custom QuerySet classes may be used. 
+ """ + + class CustomQuerySet(QuerySet): + def not_empty(self): + return len(self) > 0 + + class CustomQuerySetManager(QuerySetManager): + queryset_class = CustomQuerySet + + class Post(Document): + objects = CustomQuerySetManager() + + Post.drop_collection() + + self.assertTrue(isinstance(Post.objects, CustomQuerySet)) + self.assertFalse(Post.objects.not_empty()) + + Post().save() + self.assertTrue(Post.objects.not_empty()) + + Post.drop_collection() + + def test_custom_querysets_managers_directly(self): + """Ensure that custom QuerySet classes may be used. + """ + + class CustomQuerySetManager(QuerySetManager): + + @staticmethod + def get_queryset(doc_cls, queryset): + return queryset(is_published=True) + + class Post(Document): + is_published = BooleanField(default=False) + published = CustomQuerySetManager() + + Post.drop_collection() + + Post().save() + Post(is_published=True).save() + self.assertEquals(Post.objects.count(), 2) + self.assertEquals(Post.published.count(), 1) + + Post.drop_collection() + def test_call_after_limits_set(self): """Ensure that re-filtering after slicing works """ @@ -1637,6 +1869,35 @@ class QuerySetTest(unittest.TestCase): Number.drop_collection() + def test_clone(self): + """Ensure that cloning clones complex querysets + """ + class Number(Document): + n = IntField() + + Number.drop_collection() + + for i in xrange(1, 101): + t = Number(n=i) + t.save() + + test = Number.objects + test2 = test.clone() + self.assertFalse(test == test2) + self.assertEqual(test.count(), test2.count()) + + test = test.filter(n__gt=11) + test2 = test.clone() + self.assertFalse(test == test2) + self.assertEqual(test.count(), test2.count()) + + test = test.limit(10) + test2 = test.clone() + self.assertFalse(test == test2) + self.assertEqual(test.count(), test2.count()) + + Number.drop_collection() + def test_unset_reference(self): class Comment(Document): text = StringField() @@ -1658,6 +1919,39 @@ class QuerySetTest(unittest.TestCase): 
Comment.drop_collection() Post.drop_collection() + def test_order_works_with_custom_db_field_names(self): + class Number(Document): + n = IntField(db_field='number') + + Number.drop_collection() + + n2 = Number.objects.create(n=2) + n1 = Number.objects.create(n=1) + + self.assertEqual(list(Number.objects), [n2,n1]) + self.assertEqual(list(Number.objects.order_by('n')), [n1,n2]) + + Number.drop_collection() + + def test_order_works_with_primary(self): + """Ensure that order_by and primary work. + """ + class Number(Document): + n = IntField(primary_key=True) + + Number.drop_collection() + + Number(n=1).save() + Number(n=2).save() + Number(n=3).save() + + numbers = [n.n for n in Number.objects.order_by('-n')] + self.assertEquals([3, 2, 1], numbers) + + numbers = [n.n for n in Number.objects.order_by('+n')] + self.assertEquals([1, 2, 3], numbers) + Number.drop_collection() + class QTest(unittest.TestCase): @@ -1679,7 +1973,7 @@ class QTest(unittest.TestCase): query = {'age': {'$gte': 18}, 'name': 'test'} self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query) - + def test_q_with_dbref(self): """Ensure Q objects handle DBRefs correctly""" connect(db='mongoenginetest') @@ -1721,7 +2015,7 @@ class QTest(unittest.TestCase): query = Q(x__lt=100) & Q(y__ne='NotMyString') query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100) mongo_query = { - 'x': {'$lt': 100, '$gt': -100}, + 'x': {'$lt': 100, '$gt': -100}, 'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']}, } self.assertEqual(query.to_query(TestDoc), mongo_query) @@ -1795,6 +2089,30 @@ class QTest(unittest.TestCase): for condition in conditions: self.assertTrue(condition in query['$or']) + + def test_q_clone(self): + + class TestDoc(Document): + x = IntField() + + TestDoc.drop_collection() + for i in xrange(1, 101): + t = TestDoc(x=i) + t.save() + + # Check normal cases work without an error + test = TestDoc.objects(Q(x__lt=7) & Q(x__gt=3)) + + self.assertEqual(test.count(), 3) + + test2 = test.clone() + 
self.assertEqual(test2.count(), 3) + self.assertFalse(test2 == test) + + test2.filter(x=6) + self.assertEqual(test2.count(), 1) + self.assertEqual(test.count(), 3) + class QueryFieldListTest(unittest.TestCase): def test_empty(self): q = QueryFieldList() @@ -1805,51 +2123,52 @@ class QueryFieldListTest(unittest.TestCase): def test_include_include(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'a': True, 'b': True}) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'b': True}) def test_include_exclude(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'a': True, 'b': True}) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) self.assertEqual(q.as_dict(), {'a': True}) def test_exclude_exclude(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) self.assertEqual(q.as_dict(), {'a': False, 'b': False}) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.EXCLUDE) self.assertEqual(q.as_dict(), {'a': False, 'b': False, 'c': False}) def test_exclude_include(self): q = QueryFieldList() - q += QueryFieldList(fields=['a', 'b'], direction=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['a', 'b'], value=QueryFieldList.EXCLUDE) self.assertEqual(q.as_dict(), {'a': False, 'b': False}) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += 
QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'c': True}) def test_always_include(self): q = QueryFieldList(always_include=['x', 'y']) - q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) - def test_reset(self): q = QueryFieldList(always_include=['x', 'y']) - q += QueryFieldList(fields=['a', 'b', 'x'], direction=QueryFieldList.EXCLUDE) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['a', 'b', 'x'], value=QueryFieldList.EXCLUDE) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'c': True}) q.reset() self.assertFalse(q) - q += QueryFieldList(fields=['b', 'c'], direction=QueryFieldList.ONLY) + q += QueryFieldList(fields=['b', 'c'], value=QueryFieldList.ONLY) self.assertEqual(q.as_dict(), {'x': True, 'y': True, 'b': True, 'c': True}) - - + def test_using_a_slice(self): + q = QueryFieldList() + q += QueryFieldList(fields=['a'], value={"$slice": 5}) + self.assertEqual(q.as_dict(), {'a': {"$slice": 5}}) if __name__ == '__main__':