diff --git a/docs/changelog.rst b/docs/changelog.rst index b2a855d5..8388b05a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,6 +2,11 @@ Changelog ========= +Changes in 0.8 +============== +- Remove _types and just use _cls for inheritance (MongoEngine/mongoengine#148) + + Changes in 0.7.X ================ - Unicode fix for repr (MongoEngine/mongoengine#133) diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 3ee77961..cf3b5a6f 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -461,9 +461,10 @@ If a dictionary is passed then the following options are available: :attr:`fields` (Default: None) The fields to index. Specified in the same format as described above. -:attr:`types` (Default: True) - Whether the index should have the :attr:`_types` field added automatically - to the start of the index. +:attr:`cls` (Default: True) + If you have polymorphic models that inherit and have `allow_inheritance` + turned on, you can configure whether the index should have the + :attr:`_cls` field added automatically to the start of the index. :attr:`sparse` (Default: False) Whether the index should be sparse. @@ -590,14 +591,14 @@ convenient and efficient retrieval of related documents:: Working with existing data -------------------------- To enable correct retrieval of documents involved in this kind of heirarchy, -two extra attributes are stored on each document in the database: :attr:`_cls` -and :attr:`_types`. These are hidden from the user through the MongoEngine -interface, but may not be present if you are trying to use MongoEngine with -an existing database. For this reason, you may disable this inheritance -mechansim, removing the dependency of :attr:`_cls` and :attr:`_types`, enabling -you to work with existing databases. To disable inheritance on a document -class, set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` -dictionary:: +an extra attribute is stored on each document in the database: :attr:`_cls`. +These are hidden from the user through the MongoEngine interface, but may not +be present if you are trying to use MongoEngine with an existing database. + +For this reason, you may disable this inheritance mechansim, removing the +dependency of :attr:`_cls`, enabling you to work with existing databases. +To disable inheritance on a document class, set :attr:`allow_inheritance` to +``False`` in the :attr:`meta` dictionary:: # Will work with data in an existing collection named 'cmsPage' class Page(Document): diff --git a/docs/upgrade.rst b/docs/upgrade.rst index 82ac7cac..99e3078c 100644 --- a/docs/upgrade.rst +++ b/docs/upgrade.rst @@ -2,6 +2,45 @@ Upgrading ========= +0.7 to 0.8 +========== + +Inheritance +----------- + +The inheritance model has changed, we no longer need to store an array of +`types` with the model we can just use the classname in `_cls`. This means +that you will have to update your indexes for each of your inherited classes +like so: + + # 1. Declaration of the class + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': True, + 'indexes': ['name'] + } + + # 2. Remove _types + collection = Animal._get_collection() + collection.update({}, {"$unset": {"_types": 1}}, multi=True) + + # 3. Confirm extra data is removed + count = collection.find({'_types': {"$exists": True}}).count() + assert count == 0 + + # 4. Remove indexes + info = collection.index_information() + indexes_to_drop = [key for key, value in info.iteritems() + if '_types' in dict(value['key'])] + for index in indexes_to_drop: + collection.drop_index(index) + + # 5. Recreate indexes + Animal.objects._ensure_indexes() + + + 0.6 to 0.7 ========== diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 9044e617..d92165c5 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -9,10 +9,10 @@ from queryset import * import signals from signals import * -__all__ = (document.__all__ + fields.__all__ + connection.__all__ + - queryset.__all__ + signals.__all__) +__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + + list(queryset.__all__) + signals.__all__) -VERSION = (0, 7, 5) +VERSION = (0, 8, 0, '+') def get_version(): diff --git a/mongoengine/base.py b/mongoengine/base.py deleted file mode 100644 index fa12e35d..00000000 --- a/mongoengine/base.py +++ /dev/null @@ -1,1523 +0,0 @@ -import operator -import sys -import warnings -import weakref - -from collections import defaultdict -from functools import partial - -from queryset import QuerySet, QuerySetManager -from queryset import DoesNotExist, MultipleObjectsReturned -from queryset import DO_NOTHING - -from mongoengine import signals -from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, - to_str_keys_recursive) - -import pymongo -from bson import ObjectId -from bson.dbref import DBRef - -ALLOW_INHERITANCE = True - -_document_registry = {} -_class_registry = {} - - -class NotRegistered(Exception): - pass - - -class InvalidDocumentError(Exception): - pass - - -class ValidationError(AssertionError): - """Validation exception. - - May represent an error validating a field or a - document containing fields with validation errors. - - :ivar errors: A dictionary of errors for fields within this - document or list, or None if the error is for an - individual field. - """ - - errors = {} - field_name = None - _message = None - - def __init__(self, message="", **kwargs): - self.errors = kwargs.get('errors', {}) - self.field_name = kwargs.get('field_name') - self.message = message - - def __str__(self): - return txt_type(self.message) - - def __repr__(self): - return '%s(%s,)' % (self.__class__.__name__, self.message) - - def __getattribute__(self, name): - message = super(ValidationError, self).__getattribute__(name) - if name == 'message': - if self.field_name: - message = '%s' % message - if self.errors: - message = '%s(%s)' % (message, self._format_errors()) - return message - - def _get_message(self): - return self._message - - def _set_message(self, message): - self._message = message - - message = property(_get_message, _set_message) - - def to_dict(self): - """Returns a dictionary of all errors within a document - - Keys are field names or list indices and values are the - validation error messages, or a nested dictionary of - errors for an embedded document or list. - """ - - def build_dict(source): - errors_dict = {} - if not source: - return errors_dict - if isinstance(source, dict): - for field_name, error in source.iteritems(): - errors_dict[field_name] = build_dict(error) - elif isinstance(source, ValidationError) and source.errors: - return build_dict(source.errors) - else: - return unicode(source) - return errors_dict - if not self.errors: - return {} - return build_dict(self.errors) - - def _format_errors(self): - """Returns a string listing all errors within a document""" - - def generate_key(value, prefix=''): - if isinstance(value, list): - value = ' '.join([generate_key(k) for k in value]) - if isinstance(value, dict): - value = ' '.join( - [generate_key(v, k) for k, v in value.iteritems()]) - - results = "%s.%s" % (prefix, value) if prefix else value - return results - - error_dict = defaultdict(list) - for k, v in self.to_dict().iteritems(): - error_dict[generate_key(v)].append(k) - return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()]) - - -def get_document(name): - doc = _document_registry.get(name, None) - if not doc: - # Possible old style names - end = ".%s" % name - possible_match = [k for k in _document_registry.keys() - if k.endswith(end)] - if len(possible_match) == 1: - doc = _document_registry.get(possible_match.pop(), None) - if not doc: - raise NotRegistered(""" - `%s` has not been registered in the document registry. - Importing the document class automatically registers it, has it - been imported? - """.strip() % name) - return doc - - -class BaseField(object): - """A base class for fields in a MongoDB document. Instances of this class - may be added to subclasses of `Document` to define a document's schema. - - .. versionchanged:: 0.5 - added verbose and help text - """ - - name = None - - # Fields may have _types inserted into indexes by default - _index_with_types = True - _geo_index = False - - # These track each time a Field instance is created. Used to retain order. - # The auto_creation_counter is used for fields that MongoEngine implicitly - # creates, creation_counter is used for all user-specified fields. - creation_counter = 0 - auto_creation_counter = -1 - - def __init__(self, db_field=None, name=None, required=False, default=None, - unique=False, unique_with=None, primary_key=False, - validation=None, choices=None, verbose_name=None, - help_text=None): - self.db_field = (db_field or name) if not primary_key else '_id' - if name: - msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" - warnings.warn(msg, DeprecationWarning) - self.name = None - self.required = required or primary_key - self.default = default - self.unique = bool(unique or unique_with) - self.unique_with = unique_with - self.primary_key = primary_key - self.validation = validation - self.choices = choices - self.verbose_name = verbose_name - self.help_text = help_text - - # Adjust the appropriate creation counter, and save our local copy. - if self.db_field == '_id': - self.creation_counter = BaseField.auto_creation_counter - BaseField.auto_creation_counter -= 1 - else: - self.creation_counter = BaseField.creation_counter - BaseField.creation_counter += 1 - - def __get__(self, instance, owner): - """Descriptor for retrieving a value from a field in a document. Do - any necessary conversion between Python and MongoDB types. - """ - if instance is None: - # Document class being used rather than a document object - return self - - # Get value from document instance if available, if not use default - value = instance._data.get(self.name) - - if value is None: - value = self.default - # Allow callable default values - if callable(value): - value = value() - - return value - - def __set__(self, instance, value): - """Descriptor for assigning a value to a field in a document. - """ - instance._data[self.name] = value - if instance._initialised: - instance._mark_as_changed(self.name) - - def error(self, message="", errors=None, field_name=None): - """Raises a ValidationError. - """ - field_name = field_name if field_name else self.name - raise ValidationError(message, errors=errors, field_name=field_name) - - def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. - """ - return value - - def to_mongo(self, value): - """Convert a Python type to a MongoDB-compatible type. - """ - return self.to_python(value) - - def prepare_query_value(self, op, value): - """Prepare a value that is being used in a query for PyMongo. - """ - return value - - def validate(self, value): - """Perform validation on a value. - """ - pass - - def _validate(self, value): - Document = _import_class('Document') - EmbeddedDocument = _import_class('EmbeddedDocument') - # check choices - if self.choices: - is_cls = isinstance(value, (Document, EmbeddedDocument)) - value_to_check = value.__class__ if is_cls else value - err_msg = 'an instance' if is_cls else 'one' - if isinstance(self.choices[0], (list, tuple)): - option_keys = [k for k, v in self.choices] - if value_to_check not in option_keys: - msg = ('Value must be %s of %s' % - (err_msg, unicode(option_keys))) - self.error(msg) - elif value_to_check not in self.choices: - msg = ('Value must be %s of %s' % - (err_msg, unicode(self.choices))) - self.error() - - # check validation argument - if self.validation is not None: - if callable(self.validation): - if not self.validation(value): - self.error('Value does not match custom validation method') - else: - raise ValueError('validation argument for "%s" must be a ' - 'callable.' % self.name) - - self.validate(value) - - -class ComplexBaseField(BaseField): - """Handles complex fields, such as lists / dictionaries. - - Allows for nesting of embedded documents inside complex types. - Handles the lazy dereferencing of a queryset by lazily dereferencing all - items in a list / dict rather than one at a time. - - .. versionadded:: 0.5 - """ - - field = None - __dereference = False - - def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ - if instance is None: - # Document class being used rather than a document object - return self - - ReferenceField = _import_class('ReferenceField') - GenericReferenceField = _import_class('GenericReferenceField') - dereference = self.field is None or isinstance(self.field, - (GenericReferenceField, ReferenceField)) - if not self._dereference and instance._initialised and dereference: - instance._data[self.name] = self._dereference( - instance._data.get(self.name), max_depth=1, instance=instance, - name=self.name - ) - - value = super(ComplexBaseField, self).__get__(instance, owner) - - # Convert lists / values so we can watch for any changes on them - if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): - value = BaseList(value, instance, self.name) - instance._data[self.name] = value - elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, instance, self.name) - instance._data[self.name] = value - - if (instance._initialised and isinstance(value, (BaseList, BaseDict)) - and not value._dereferenced): - value = self._dereference( - value, max_depth=1, instance=instance, name=self.name - ) - value._dereferenced = True - instance._data[self.name] = value - - return value - - def __set__(self, instance, value): - """Descriptor for assigning a value to a field in a document. - """ - instance._data[self.name] = value - instance._mark_as_changed(self.name) - - def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. - """ - Document = _import_class('Document') - - if isinstance(value, basestring): - return value - - if hasattr(value, 'to_python'): - return value.to_python() - - is_list = False - if not hasattr(value, 'items'): - try: - is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) - except TypeError: # Not iterable return the value - return value - - if self.field: - value_dict = dict([(key, self.field.to_python(item)) - for key, item in value.items()]) - else: - value_dict = {} - for k, v in value.items(): - if isinstance(v, Document): - # We need the id from the saved object to create the DBRef - if v.pk is None: - self.error('You can only reference documents once they' - ' have been saved to the database') - collection = v._get_collection_name() - value_dict[k] = DBRef(collection, v.pk) - elif hasattr(v, 'to_python'): - value_dict[k] = v.to_python() - else: - value_dict[k] = self.to_python(v) - - if is_list: # Convert back to a list - return [v for k, v in sorted(value_dict.items(), - key=operator.itemgetter(0))] - return value_dict - - def to_mongo(self, value): - """Convert a Python type to a MongoDB-compatible type. - """ - Document = _import_class("Document") - - if isinstance(value, basestring): - return value - - if hasattr(value, 'to_mongo'): - return value.to_mongo() - - is_list = False - if not hasattr(value, 'items'): - try: - is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) - except TypeError: # Not iterable return the value - return value - - if self.field: - value_dict = dict([(key, self.field.to_mongo(item)) - for key, item in value.items()]) - else: - value_dict = {} - for k, v in value.items(): - if isinstance(v, Document): - # We need the id from the saved object to create the DBRef - if v.pk is None: - self.error('You can only reference documents once they' - ' have been saved to the database') - - # If its a document that is not inheritable it won't have - # _types / _cls data so make it a generic reference allows - # us to dereference - meta = getattr(v, '_meta', {}) - allow_inheritance = ( - meta.get('allow_inheritance', ALLOW_INHERITANCE) - == False) - if allow_inheritance and not self.field: - GenericReferenceField = _import_class("GenericReferenceField") - value_dict[k] = GenericReferenceField().to_mongo(v) - else: - collection = v._get_collection_name() - value_dict[k] = DBRef(collection, v.pk) - elif hasattr(v, 'to_mongo'): - value_dict[k] = v.to_mongo() - else: - value_dict[k] = self.to_mongo(v) - - if is_list: # Convert back to a list - return [v for k, v in sorted(value_dict.items(), - key=operator.itemgetter(0))] - return value_dict - - def validate(self, value): - """If field is provided ensure the value is valid. - """ - errors = {} - if self.field: - if hasattr(value, 'iteritems') or hasattr(value, 'items'): - sequence = value.iteritems() - else: - sequence = enumerate(value) - for k, v in sequence: - try: - self.field._validate(v) - except ValidationError, error: - errors[k] = error.errors or error - except (ValueError, AssertionError), error: - errors[k] = error - - if errors: - field_class = self.field.__class__.__name__ - self.error('Invalid %s item (%s)' % (field_class, value), - errors=errors) - # Don't allow empty values if required - if self.required and not value: - self.error('Field is required and cannot be empty') - - def prepare_query_value(self, op, value): - return self.to_mongo(value) - - def lookup_member(self, member_name): - if self.field: - return self.field.lookup_member(member_name) - return None - - def _set_owner_document(self, owner_document): - if self.field: - self.field.owner_document = owner_document - self._owner_document = owner_document - - def _get_owner_document(self, owner_document): - self._owner_document = owner_document - - owner_document = property(_get_owner_document, _set_owner_document) - - @property - def _dereference(self,): - if not self.__dereference: - DeReference = _import_class("DeReference") - self.__dereference = DeReference() # Cached - return self.__dereference - - -class ObjectIdField(BaseField): - """An field wrapper around MongoDB's ObjectIds. - """ - - def to_python(self, value): - if not isinstance(value, ObjectId): - value = ObjectId(value) - return value - - def to_mongo(self, value): - if not isinstance(value, ObjectId): - try: - return ObjectId(unicode(value)) - except Exception, e: - # e.message attribute has been deprecated since Python 2.6 - self.error(unicode(e)) - return value - - def prepare_query_value(self, op, value): - return self.to_mongo(value) - - def validate(self, value): - try: - ObjectId(unicode(value)) - except: - self.error('Invalid Object ID') - - -class DocumentMetaclass(type): - """Metaclass for all documents. - """ - - def __new__(cls, name, bases, attrs): - flattened_bases = cls._get_bases(bases) - super_new = super(DocumentMetaclass, cls).__new__ - - # If a base class just call super - metaclass = attrs.get('my_metaclass') - if metaclass and issubclass(metaclass, DocumentMetaclass): - return super_new(cls, name, bases, attrs) - - attrs['_is_document'] = attrs.get('_is_document', False) - - # EmbeddedDocuments could have meta data for inheritance - if 'meta' in attrs: - attrs['_meta'] = attrs.pop('meta') - - # Handle document Fields - - # Merge all fields from subclasses - doc_fields = {} - for base in flattened_bases[::-1]: - if hasattr(base, '_fields'): - doc_fields.update(base._fields) - - # Standard object mixin - merge in any Fields - if not hasattr(base, '_meta'): - base_fields = {} - for attr_name, attr_value in base.__dict__.iteritems(): - if not isinstance(attr_value, BaseField): - continue - attr_value.name = attr_name - if not attr_value.db_field: - attr_value.db_field = attr_name - base_fields[attr_name] = attr_value - doc_fields.update(base_fields) - - # Discover any document fields - field_names = {} - for attr_name, attr_value in attrs.iteritems(): - if not isinstance(attr_value, BaseField): - continue - attr_value.name = attr_name - if not attr_value.db_field: - attr_value.db_field = attr_name - doc_fields[attr_name] = attr_value - - # Count names to ensure no db_field redefinitions - field_names[attr_value.db_field] = field_names.get( - attr_value.db_field, 0) + 1 - - # Ensure no duplicate db_fields - duplicate_db_fields = [k for k, v in field_names.items() if v > 1] - if duplicate_db_fields: - msg = ("Multiple db_fields defined for: %s " % - ", ".join(duplicate_db_fields)) - raise InvalidDocumentError(msg) - - # Set _fields and db_field maps - attrs['_fields'] = doc_fields - attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) - for k, v in doc_fields.iteritems()]) - attrs['_reverse_db_field_map'] = dict( - (v, k) for k, v in attrs['_db_field_map'].iteritems()) - - # - # Set document hierarchy - # - superclasses = {} - class_name = [name] - for base in flattened_bases: - if (not getattr(base, '_is_base_cls', True) and - not getattr(base, '_meta', {}).get('abstract', True)): - # Collate heirarchy for _cls and _types - class_name.append(base.__name__) - - # Get superclasses from superclass - superclasses[base._class_name] = base - superclasses.update(base._superclasses) - - if hasattr(base, '_meta'): - # Warn if allow_inheritance isn't set and prevent - # inheritance of classes where inheritance is set to False - allow_inheritance = base._meta.get('allow_inheritance', - ALLOW_INHERITANCE) - if (not getattr(base, '_is_base_cls', True) - and allow_inheritance is None): - warnings.warn( - "%s uses inheritance, the default for " - "allow_inheritance is changing to off by default. " - "Please add it to the document meta." % name, - FutureWarning - ) - elif (allow_inheritance == False and - not base._meta.get('abstract')): - raise ValueError('Document %s may not be subclassed' % - base.__name__) - - attrs['_class_name'] = '.'.join(reversed(class_name)) - attrs['_superclasses'] = superclasses - - # Create the new_class - new_class = super_new(cls, name, bases, attrs) - - # Handle delete rules - Document, EmbeddedDocument, DictField = cls._import_classes() - for field in new_class._fields.itervalues(): - f = field - f.owner_document = new_class - delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) - if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): - delete_rule = getattr(f.field, - 'reverse_delete_rule', - DO_NOTHING) - if isinstance(f, DictField) and delete_rule != DO_NOTHING: - msg = ("Reverse delete rules are not supported " - "for %s (field: %s)" % - (field.__class__.__name__, field.name)) - raise InvalidDocumentError(msg) - - f = field.field - - if delete_rule != DO_NOTHING: - if issubclass(new_class, EmbeddedDocument): - msg = ("Reverse delete rules are not supported for " - "EmbeddedDocuments (field: %s)" % field.name) - raise InvalidDocumentError(msg) - f.document_type.register_delete_rule(new_class, - field.name, delete_rule) - - if (field.name and hasattr(Document, field.name) and - EmbeddedDocument not in new_class.mro()): - msg = ("%s is a document method and not a valid " - "field name" % field.name) - raise InvalidDocumentError(msg) - - # Add class to the _document_registry - _document_registry[new_class._class_name] = new_class - - # In Python 2, User-defined methods objects have special read-only - # attributes 'im_func' and 'im_self' which contain the function obj - # and class instance object respectively. With Python 3 these special - # attributes have been replaced by __func__ and __self__. The Blinker - # module continues to use im_func and im_self, so the code below - # copies __func__ into im_func and __self__ into im_self for - # classmethod objects in Document derived classes. - if PY3: - for key, val in new_class.__dict__.items(): - if isinstance(val, classmethod): - f = val.__get__(new_class) - if hasattr(f, '__func__') and not hasattr(f, 'im_func'): - f.__dict__.update({'im_func': getattr(f, '__func__')}) - if hasattr(f, '__self__') and not hasattr(f, 'im_self'): - f.__dict__.update({'im_self': getattr(f, '__self__')}) - - return new_class - - def add_to_class(self, name, value): - setattr(self, name, value) - - @classmethod - def _get_bases(cls, bases): - if isinstance(bases, BasesTuple): - return bases - seen = [] - bases = cls.__get_bases(bases) - unique_bases = (b for b in bases if not (b in seen or seen.append(b))) - return BasesTuple(unique_bases) - - @classmethod - def __get_bases(cls, bases): - for base in bases: - if base is object: - continue - yield base - for child_base in cls.__get_bases(base.__bases__): - yield child_base - - @classmethod - def _import_classes(cls): - Document = _import_class('Document') - EmbeddedDocument = _import_class('EmbeddedDocument') - DictField = _import_class('DictField') - return (Document, EmbeddedDocument, DictField) - - -class TopLevelDocumentMetaclass(DocumentMetaclass): - """Metaclass for top-level documents (i.e. documents that have their own - collection in the database. - """ - - def __new__(cls, name, bases, attrs): - flattened_bases = cls._get_bases(bases) - super_new = super(TopLevelDocumentMetaclass, cls).__new__ - - # Set default _meta data if base class, otherwise get user defined meta - if (attrs.get('my_metaclass') == TopLevelDocumentMetaclass): - # defaults - attrs['_meta'] = { - 'abstract': True, - 'max_documents': None, - 'max_size': None, - 'ordering': [], # default ordering applied at runtime - 'indexes': [], # indexes to be ensured at runtime - 'id_field': None, - 'index_background': False, - 'index_drop_dups': False, - 'index_opts': None, - 'delete_rules': None, - 'allow_inheritance': None, - } - attrs['_is_base_cls'] = True - attrs['_meta'].update(attrs.get('meta', {})) - else: - attrs['_meta'] = attrs.get('meta', {}) - # Explictly set abstract to false unless set - attrs['_meta']['abstract'] = attrs['_meta'].get('abstract', False) - attrs['_is_base_cls'] = False - - # Set flag marking as document class - as opposed to an object mixin - attrs['_is_document'] = True - - # Ensure queryset_class is inherited - if 'objects' in attrs: - manager = attrs['objects'] - if hasattr(manager, 'queryset_class'): - attrs['_meta']['queryset_class'] = manager.queryset_class - - # Clean up top level meta - if 'meta' in attrs: - del(attrs['meta']) - - # Find the parent document class - parent_doc_cls = [b for b in flattened_bases - if b.__class__ == TopLevelDocumentMetaclass] - parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0] - - # Prevent classes setting collection different to their parents - # If parent wasn't an abstract class - if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) - and not parent_doc_cls._meta.get('abstract', True)): - msg = "Trying to set a collection on a subclass (%s)" % name - warnings.warn(msg, SyntaxWarning) - del(attrs['_meta']['collection']) - - # Ensure abstract documents have abstract bases - if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): - if (parent_doc_cls and - not parent_doc_cls._meta.get('abstract', False)): - msg = "Abstract document cannot have non-abstract base" - raise ValueError(msg) - return super_new(cls, name, bases, attrs) - - # Merge base class metas. - # Uses a special MetaDict that handles various merging rules - meta = MetaDict() - for base in flattened_bases[::-1]: - # Add any mixin metadata from plain objects - if hasattr(base, 'meta'): - meta.merge(base.meta) - elif hasattr(base, '_meta'): - meta.merge(base._meta) - - # Set collection in the meta if its callable - if (getattr(base, '_is_document', False) and - not base._meta.get('abstract')): - collection = meta.get('collection', None) - if callable(collection): - meta['collection'] = collection(base) - - meta.merge(attrs.get('_meta', {})) # Top level meta - - # Only simple classes (direct subclasses of Document) - # may set allow_inheritance to False - simple_class = all([b._meta.get('abstract') - for b in flattened_bases if hasattr(b, '_meta')]) - if (not simple_class and meta['allow_inheritance'] == False and - not meta['abstract']): - raise ValueError('Only direct subclasses of Document may set ' - '"allow_inheritance" to False') - - # Set default collection name - if 'collection' not in meta: - meta['collection'] = ''.join('_%s' % c if c.isupper() else c - for c in name).strip('_').lower() - attrs['_meta'] = meta - - # Call super and get the new class - new_class = super_new(cls, name, bases, attrs) - - meta = new_class._meta - - # Set index specifications - meta['index_specs'] = [QuerySet._build_index_spec(new_class, spec) - for spec in meta['indexes']] - unique_indexes = cls._unique_with_indexes(new_class) - new_class._meta['unique_indexes'] = unique_indexes - - # If collection is a callable - call it and set the value - collection = meta.get('collection') - if callable(collection): - new_class._meta['collection'] = collection(new_class) - - # Provide a default queryset unless one has been set - manager = attrs.get('objects', QuerySetManager()) - new_class.objects = manager - - # Validate the fields and set primary key if needed - for field_name, field in new_class._fields.iteritems(): - if field.primary_key: - # Ensure only one primary key is set - current_pk = new_class._meta.get('id_field') - if current_pk and current_pk != field_name: - raise ValueError('Cannot override primary key field') - - # Set primary key - if not current_pk: - new_class._meta['id_field'] = field_name - new_class.id = field - - # Set primary key if not defined by the document - if not new_class._meta.get('id_field'): - new_class._meta['id_field'] = 'id' - new_class._fields['id'] = ObjectIdField(db_field='_id') - new_class.id = new_class._fields['id'] - - # Merge in exceptions with parent hierarchy - exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) - module = attrs.get('__module__') - for exc in exceptions_to_merge: - name = exc.__name__ - parents = tuple(getattr(base, name) for base in flattened_bases - if hasattr(base, name)) or (exc,) - # Create new exception and set to new_class - exception = type(name, parents, {'__module__': module}) - setattr(new_class, name, exception) - - return new_class - - @classmethod - def _unique_with_indexes(cls, new_class, namespace=""): - """ - Find and set unique indexes - """ - unique_indexes = [] - for field_name, field in new_class._fields.items(): - # Generate a list of indexes needed by uniqueness constraints - if field.unique: - field.required = True - unique_fields = [field.db_field] - - # Add any unique_with fields to the back of the index spec - if field.unique_with: - if isinstance(field.unique_with, basestring): - field.unique_with = [field.unique_with] - - # Convert unique_with field names to real field names - unique_with = [] - for other_name in field.unique_with: - parts = other_name.split('.') - # Lookup real name - parts = QuerySet._lookup_field(new_class, parts) - name_parts = [part.db_field for part in parts] - unique_with.append('.'.join(name_parts)) - # Unique field should be required - parts[-1].required = True - unique_fields += unique_with - - # Add the new index to the list - index = [("%s%s" % (namespace, f), pymongo.ASCENDING) - for f in unique_fields] - unique_indexes.append(index) - - # Grab any embedded document field unique indexes - if (field.__class__.__name__ == "EmbeddedDocumentField" and - field.document_type != new_class): - field_namespace = "%s." % field_name - unique_indexes += cls._unique_with_indexes(field.document_type, - field_namespace) - - return unique_indexes - - -class MetaDict(dict): - """Custom dictionary for meta classes. - Handles the merging of set indexes - """ - _merge_options = ('indexes',) - - def merge(self, new_options): - for k, v in new_options.iteritems(): - if k in self._merge_options: - self[k] = self.get(k, []) + v - else: - self[k] = v - - -class BaseDocument(object): - - _dynamic = False - _created = True - _dynamic_lock = True - _initialised = False - - def __init__(self, **values): - signals.pre_init.send(self.__class__, document=self, values=values) - - self._data = {} - - # Assign default values to instance - for key, field in self._fields.iteritems(): - if self._db_field_map.get(key, key) in values: - continue - value = getattr(self, key, None) - setattr(self, key, value) - - # Set passed values after initialisation - if self._dynamic: - self._dynamic_fields = {} - dynamic_data = {} - for key, value in values.iteritems(): - if key in self._fields or key == '_id': - setattr(self, key, value) - elif self._dynamic: - dynamic_data[key] = value - else: - for key, value in values.iteritems(): - key = self._reverse_db_field_map.get(key, key) - setattr(self, key, value) - - # Set any get_fieldname_display methods - self.__set_field_display() - - if self._dynamic: - self._dynamic_lock = False - for key, value in dynamic_data.iteritems(): - setattr(self, key, value) - - # Flag initialised - self._initialised = True - signals.post_init.send(self.__class__, document=self) - - def __setattr__(self, name, value): - # Handle dynamic data only if an initialised dynamic document - if self._dynamic and not self._dynamic_lock: - - field = None - if not hasattr(self, name) and not name.startswith('_'): - DynamicField = _import_class("DynamicField") - field = DynamicField(db_field=name) - field.name = name - self._dynamic_fields[name] = field - - if not name.startswith('_'): - value = self.__expand_dynamic_values(name, value) - - # Handle marking data as changed - if name in self._dynamic_fields: - self._data[name] = value - if hasattr(self, '_changed_fields'): - self._mark_as_changed(name) - - if (self._is_document and not self._created and - name in self._meta.get('shard_key', tuple()) and - self._data.get(name) != value): - OperationError = _import_class('OperationError') - msg = "Shard Keys are immutable. Tried to update %s" % name - raise OperationError(msg) - - super(BaseDocument, self).__setattr__(name, value) - - def __expand_dynamic_values(self, name, value): - """expand any dynamic values to their correct types / values""" - if not isinstance(value, (dict, list, tuple)): - return value - - is_list = False - if not hasattr(value, 'items'): - is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) - - if not is_list and '_cls' in value: - cls = get_document(value['_cls']) - return cls(**value) - - data = {} - for k, v in value.items(): - key = name if is_list else k - data[k] = self.__expand_dynamic_values(key, v) - - if is_list: # Convert back to a list - data_items = sorted(data.items(), key=operator.itemgetter(0)) - value = [v for k, v in data_items] - else: - value = data - - # Convert lists / values so we can watch for any changes on them - if (isinstance(value, (list, tuple)) and - not isinstance(value, BaseList)): - value = BaseList(value, self, name) - elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, self, name) - - return value - - def validate(self): - """Ensure that all fields' values are valid and that required fields - are present. - """ - # Get a list of tuples of field names and their current values - fields = [(field, getattr(self, name)) - for name, field in self._fields.items()] - - # Ensure that each field is matched to a valid value - errors = {} - for field, value in fields: - if value is not None: - try: - field._validate(value) - except ValidationError, error: - errors[field.name] = error.errors or error - except (ValueError, AttributeError, AssertionError), error: - errors[field.name] = error - elif field.required: - errors[field.name] = ValidationError('Field is required', - field_name=field.name) - if errors: - raise ValidationError('ValidationError', errors=errors) - - def to_mongo(self): - """Return data dictionary ready for use with MongoDB. - """ - data = {} - for field_name, field in self._fields.items(): - value = getattr(self, field_name, None) - if value is not None: - data[field.db_field] = field.to_mongo(value) - # Only add _cls and _types if allow_inheritance is not False - if not (hasattr(self, '_meta') and - self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == False): - data['_cls'] = self._class_name - data['_types'] = self._superclasses.keys() + [self._class_name] - if '_id' in data and data['_id'] is None: - del data['_id'] - - if not self._dynamic: - return data - - for name, field in self._dynamic_fields.items(): - data[name] = field.to_mongo(self._data.get(name, None)) - return data - - @classmethod - def _get_collection_name(cls): - """Returns the collection name for this class. - """ - return cls._meta.get('collection', None) - - @classmethod - def _from_son(cls, son): - """Create an instance of a Document (subclass) from a PyMongo SON. - """ - # get the class name from the document, falling back to the given - # class if unavailable - class_name = son.get('_cls', cls._class_name) - data = dict(("%s" % key, value) for key, value in son.items()) - if not UNICODE_KWARGS: - # python 2.6.4 and lower cannot handle unicode keys - # passed to class constructor example: cls(**data) - to_str_keys_recursive(data) - - if '_types' in data: - del data['_types'] - - if '_cls' in data: - del data['_cls'] - - # Return correct subclass for document type - if class_name != cls._class_name: - cls = get_document(class_name) - - changed_fields = [] - errors_dict = {} - - for field_name, field in cls._fields.items(): - if field.db_field in data: - value = data[field.db_field] - try: - data[field_name] = (value if value is None - else field.to_python(value)) - if field_name != field.db_field: - del data[field.db_field] - except (AttributeError, ValueError), e: - errors_dict[field_name] = e - elif field.default: - default = field.default - if callable(default): - default = default() - if isinstance(default, BaseDocument): - changed_fields.append(field_name) - - if errors_dict: - errors = "\n".join(["%s - %s" % (k, v) - for k, v in errors_dict.items()]) - msg = ("Invalid data to create a `%s` instance.\n%s" - % (cls._class_name, errors)) - raise InvalidDocumentError(msg) - - obj = cls(**data) - obj._changed_fields = changed_fields - obj._created = False - return obj - - def _mark_as_changed(self, key): - """Marks a key as explicitly changed by the user - """ - if not key: - return - key = self._db_field_map.get(key, key) - if (hasattr(self, '_changed_fields') and - key not in self._changed_fields): - self._changed_fields.append(key) - - def _get_changed_fields(self, key='', inspected=None): - """Returns a list of all fields that have explicitly been changed. - """ - EmbeddedDocument = _import_class("EmbeddedDocument") - DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") - _changed_fields = [] - _changed_fields += getattr(self, '_changed_fields', []) - - inspected = inspected or set() - if hasattr(self, 'id'): - if self.id in inspected: - return _changed_fields - inspected.add(self.id) - - field_list = self._fields.copy() - if self._dynamic: - field_list.update(self._dynamic_fields) - - for field_name in field_list: - - db_field_name = self._db_field_map.get(field_name, field_name) - key = '%s.' % db_field_name - field = self._data.get(field_name, None) - if hasattr(field, 'id'): - if field.id in inspected: - continue - inspected.add(field.id) - - if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) - and db_field_name not in _changed_fields): - # Find all embedded fields that have been changed - changed = field._get_changed_fields(key, inspected) - _changed_fields += ["%s%s" % (key, k) for k in changed if k] - elif (isinstance(field, (list, tuple, dict)) and - db_field_name not in _changed_fields): - # Loop list / dict fields as they contain documents - # Determine the iterator to use - if not hasattr(field, 'items'): - iterator = enumerate(field) - else: - iterator = field.iteritems() - for index, value in iterator: - if not hasattr(value, '_get_changed_fields'): - continue - list_key = "%s%s." % (key, index) - changed = value._get_changed_fields(list_key, inspected) - _changed_fields += ["%s%s" % (list_key, k) - for k in changed if k] - return _changed_fields - - def _delta(self): - """Returns the delta (set, unset) of the changes for a document. - Gets any values that have been explicitly changed. - """ - # Handles cases where not loaded from_son but has _id - doc = self.to_mongo() - set_fields = self._get_changed_fields() - set_data = {} - unset_data = {} - parts = [] - if hasattr(self, '_changed_fields'): - set_data = {} - # Fetch each set item from its path - for path in set_fields: - parts = path.split('.') - d = doc - new_path = [] - for p in parts: - if isinstance(d, DBRef): - break - elif p.isdigit(): - d = d[int(p)] - elif hasattr(d, 'get'): - d = d.get(p) - new_path.append(p) - path = '.'.join(new_path) - set_data[path] = d - else: - set_data = doc - if '_id' in set_data: - del(set_data['_id']) - - # Determine if any changed items were actually unset. - for path, value in set_data.items(): - if value or isinstance(value, bool): - continue - - # If we've set a value that ain't the default value dont unset it. - default = None - if (self._dynamic and len(parts) and - parts[0] in self._dynamic_fields): - del(set_data[path]) - unset_data[path] = 1 - continue - elif path in self._fields: - default = self._fields[path].default - else: # Perform a full lookup for lists / embedded lookups - d = self - parts = path.split('.') - db_field_name = parts.pop() - for p in parts: - if p.isdigit(): - d = d[int(p)] - elif (hasattr(d, '__getattribute__') and - not isinstance(d, dict)): - real_path = d._reverse_db_field_map.get(p, p) - d = getattr(d, real_path) - else: - d = d.get(p) - - if hasattr(d, '_fields'): - field_name = d._reverse_db_field_map.get(db_field_name, - db_field_name) - - if field_name in d._fields: - default = d._fields.get(field_name).default - else: - default = None - - if default is not None: - if callable(default): - default = default() - if default != value: - continue - - del(set_data[path]) - unset_data[path] = 1 - return set_data, unset_data - - @classmethod - def _geo_indices(cls, inspected=None): - inspected = inspected or [] - geo_indices = [] - inspected.append(cls) - - EmbeddedDocumentField = _import_class("EmbeddedDocumentField") - GeoPointField = _import_class("GeoPointField") - - for field in cls._fields.values(): - if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): - continue - if hasattr(field, 'document_type'): - field_cls = field.document_type - if field_cls in inspected: - continue - if hasattr(field_cls, '_geo_indices'): - geo_indices += field_cls._geo_indices(inspected) - elif field._geo_index: - geo_indices.append(field) - return geo_indices - - def __getstate__(self): - removals = ("get_%s_display" % k - for k, v in self._fields.items() if v.choices) - for k in removals: - if hasattr(self, k): - delattr(self, k) - return self.__dict__ - - def __setstate__(self, __dict__): - self.__dict__ = __dict__ - self.__set_field_display() - - def __set_field_display(self): - """Dynamically set the display value for a field with choices""" - for attr_name, field in self._fields.items(): - if field.choices: - setattr(self, - 'get_%s_display' % attr_name, - partial(self.__get_field_display, field=field)) - - def __get_field_display(self, field): - """Returns the display value for a choice field""" - value = getattr(self, field.name) - if field.choices and isinstance(field.choices[0], (list, tuple)): - return dict(field.choices).get(value, value) - return value - - def __iter__(self): - return iter(self._fields) - - def __getitem__(self, name): - """Dictionary-style field access, return a field's value if present. - """ - try: - if name in self._fields: - return getattr(self, name) - except AttributeError: - pass - raise KeyError(name) - - def __setitem__(self, name, value): - """Dictionary-style field access, set a field's value. - """ - # Ensure that the field exists before settings its value - if name not in self._fields: - raise KeyError(name) - return setattr(self, name, value) - - def __contains__(self, name): - try: - val = getattr(self, name) - return val is not None - except AttributeError: - return False - - def __len__(self): - return len(self._data) - - def __repr__(self): - try: - u = self.__str__() - except (UnicodeEncodeError, UnicodeDecodeError): - u = '[Bad Unicode data]' - repr_type = type(u) - return repr_type('<%s: %s>' % (self.__class__.__name__, u)) - - def __str__(self): - if hasattr(self, '__unicode__'): - if PY3: - return self.__unicode__() - else: - return unicode(self).encode('utf-8') - return txt_type('%s object' % self.__class__.__name__) - - def __eq__(self, other): - if isinstance(other, self.__class__) and hasattr(other, 'id'): - if self.id == other.id: - return True - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - if self.pk is None: - # For new object - return super(BaseDocument, self).__hash__() - else: - return hash(self.pk) - - -class BasesTuple(tuple): - """Special class to handle introspection of bases tuple in __new__""" - pass - - -class BaseList(list): - """A special list so we can watch any changes - """ - - _dereferenced = False - _instance = None - _name = None - - def __init__(self, list_items, instance, name): - self._instance = weakref.proxy(instance) - self._name = name - return super(BaseList, self).__init__(list_items) - - def __setitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__setitem__(*args, **kwargs) - - def __delitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__delitem__(*args, **kwargs) - - def __getstate__(self): - self.observer = None - return self - - def __setstate__(self, state): - self = state - return self - - def append(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).append(*args, **kwargs) - - def extend(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).extend(*args, **kwargs) - - def insert(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).insert(*args, **kwargs) - - def pop(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).pop(*args, **kwargs) - - def remove(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).remove(*args, **kwargs) - - def reverse(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).reverse(*args, **kwargs) - - def sort(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).sort(*args, **kwargs) - - def _mark_as_changed(self): - if hasattr(self._instance, '_mark_as_changed'): - self._instance._mark_as_changed(self._name) - - -class BaseDict(dict): - """A special dict so we can watch any changes - """ - - _dereferenced = False - _instance = None - _name = None - - def __init__(self, dict_items, instance, name): - self._instance = weakref.proxy(instance) - self._name = name - return super(BaseDict, self).__init__(dict_items) - - def __setitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__setitem__(*args, **kwargs) - - def __delete__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delete__(*args, **kwargs) - - def __delitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delitem__(*args, **kwargs) - - def __delattr__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delattr__(*args, **kwargs) - - def __getstate__(self): - self.instance = None - self._dereferenced = False - return self - - def __setstate__(self, state): - self = state - return self - - def clear(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).clear(*args, **kwargs) - - def pop(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).pop(*args, **kwargs) - - def popitem(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).popitem(*args, **kwargs) - - def update(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).update(*args, **kwargs) - - def _mark_as_changed(self): - if hasattr(self._instance, '_mark_as_changed'): - self._instance._mark_as_changed(self._name) - - -def _import_class(cls_name): - """Cached mechanism for imports""" - if cls_name in _class_registry: - return _class_registry.get(cls_name) - - doc_classes = ['Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument'] - field_classes = ['DictField', 'DynamicField', 'EmbeddedDocumentField', - 'GenericReferenceField', 'GeoPointField', - 'ReferenceField'] - queryset_classes = ['OperationError'] - deref_classes = ['DeReference'] - - if cls_name in doc_classes: - from mongoengine import document as module - import_classes = doc_classes - elif cls_name in field_classes: - from mongoengine import fields as module - import_classes = field_classes - elif cls_name in queryset_classes: - from mongoengine import queryset as module - import_classes = queryset_classes - elif cls_name in deref_classes: - from mongoengine import dereference as module - import_classes = deref_classes - else: - raise ValueError('No import set for: ' % cls_name) - - for cls in import_classes: - _class_registry[cls] = getattr(module, cls) - - return _class_registry.get(cls_name) diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py new file mode 100644 index 00000000..1d4a6ebe --- /dev/null +++ b/mongoengine/base/__init__.py @@ -0,0 +1,5 @@ +from .common import * +from .datastructures import * +from .document import * +from .fields import * +from .metaclasses import * diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py new file mode 100644 index 00000000..648561be --- /dev/null +++ b/mongoengine/base/common.py @@ -0,0 +1,25 @@ +from mongoengine.errors import NotRegistered + +__all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry') + +ALLOW_INHERITANCE = True + +_document_registry = {} + + +def get_document(name): + doc = _document_registry.get(name, None) + if not doc: + # Possible old style names + end = ".%s" % name + possible_match = [k for k in _document_registry.keys() + if k.endswith(end)] + if len(possible_match) == 1: + doc = _document_registry.get(possible_match.pop(), None) + if not doc: + raise NotRegistered(""" + `%s` has not been registered in the document registry. + Importing the document class automatically registers it, has it + been imported? + """.strip() % name) + return doc diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py new file mode 100644 index 00000000..9a7620e6 --- /dev/null +++ b/mongoengine/base/datastructures.py @@ -0,0 +1,124 @@ +import weakref + +__all__ = ("BaseDict", "BaseList") + + +class BaseDict(dict): + """A special dict so we can watch any changes + """ + + _dereferenced = False + _instance = None + _name = None + + def __init__(self, dict_items, instance, name): + self._instance = weakref.proxy(instance) + self._name = name + return super(BaseDict, self).__init__(dict_items) + + def __setitem__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).__setitem__(*args, **kwargs) + + def __delete__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).__delete__(*args, **kwargs) + + def __delitem__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).__delitem__(*args, **kwargs) + + def __delattr__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).__delattr__(*args, **kwargs) + + def __getstate__(self): + self.instance = None + self._dereferenced = False + return self + + def __setstate__(self, state): + self = state + return self + + def clear(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).clear(*args, **kwargs) + + def pop(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).pop(*args, **kwargs) + + def popitem(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).popitem(*args, **kwargs) + + def update(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseDict, self).update(*args, **kwargs) + + def _mark_as_changed(self): + if hasattr(self._instance, '_mark_as_changed'): + self._instance._mark_as_changed(self._name) + + +class BaseList(list): + """A special list so we can watch any changes + """ + + _dereferenced = False + _instance = None + _name = None + + def __init__(self, list_items, instance, name): + self._instance = weakref.proxy(instance) + self._name = name + return super(BaseList, self).__init__(list_items) + + def __setitem__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).__setitem__(*args, **kwargs) + + def __delitem__(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).__delitem__(*args, **kwargs) + + def __getstate__(self): + self.observer = None + return self + + def __setstate__(self, state): + self = state + return self + + def append(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).append(*args, **kwargs) + + def extend(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).extend(*args, **kwargs) + + def insert(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).insert(*args, **kwargs) + + def pop(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).pop(*args, **kwargs) + + def remove(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).remove(*args, **kwargs) + + def reverse(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).reverse(*args, **kwargs) + + def sort(self, *args, **kwargs): + self._mark_as_changed() + return super(BaseList, self).sort(*args, **kwargs) + + def _mark_as_changed(self): + if hasattr(self._instance, '_mark_as_changed'): + self._instance._mark_as_changed(self._name) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py new file mode 100644 index 00000000..af97e1f2 --- /dev/null +++ b/mongoengine/base/document.py @@ -0,0 +1,644 @@ +import operator +from functools import partial + +import pymongo +from bson.dbref import DBRef + +from mongoengine import signals +from mongoengine.common import _import_class +from mongoengine.errors import (ValidationError, InvalidDocumentError, + LookUpError) +from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, + to_str_keys_recursive) + +from .common import get_document, ALLOW_INHERITANCE +from .datastructures import BaseDict, BaseList +from .fields import ComplexBaseField + +__all__ = ('BaseDocument', ) + + +class BaseDocument(object): + + _dynamic = False + _created = True + _dynamic_lock = True + _initialised = False + + def __init__(self, **values): + signals.pre_init.send(self.__class__, document=self, values=values) + + self._data = {} + + # Assign default values to instance + for key, field in self._fields.iteritems(): + if self._db_field_map.get(key, key) in values: + continue + value = getattr(self, key, None) + setattr(self, key, value) + + # Set passed values after initialisation + if self._dynamic: + self._dynamic_fields = {} + dynamic_data = {} + for key, value in values.iteritems(): + if key in self._fields or key == '_id': + setattr(self, key, value) + elif self._dynamic: + dynamic_data[key] = value + else: + for key, value in values.iteritems(): + key = self._reverse_db_field_map.get(key, key) + setattr(self, key, value) + + # Set any get_fieldname_display methods + self.__set_field_display() + + if self._dynamic: + self._dynamic_lock = False + for key, value in dynamic_data.iteritems(): + setattr(self, key, value) + + # Flag initialised + self._initialised = True + signals.post_init.send(self.__class__, document=self) + + def __setattr__(self, name, value): + # Handle dynamic data only if an initialised dynamic document + if self._dynamic and not self._dynamic_lock: + + field = None + if not hasattr(self, name) and not name.startswith('_'): + DynamicField = _import_class("DynamicField") + field = DynamicField(db_field=name) + field.name = name + self._dynamic_fields[name] = field + + if not name.startswith('_'): + value = self.__expand_dynamic_values(name, value) + + # Handle marking data as changed + if name in self._dynamic_fields: + self._data[name] = value + if hasattr(self, '_changed_fields'): + self._mark_as_changed(name) + + if (self._is_document and not self._created and + name in self._meta.get('shard_key', tuple()) and + self._data.get(name) != value): + OperationError = _import_class('OperationError') + msg = "Shard Keys are immutable. Tried to update %s" % name + raise OperationError(msg) + + super(BaseDocument, self).__setattr__(name, value) + + def __getstate__(self): + removals = ("get_%s_display" % k + for k, v in self._fields.items() if v.choices) + for k in removals: + if hasattr(self, k): + delattr(self, k) + return self.__dict__ + + def __setstate__(self, __dict__): + self.__dict__ = __dict__ + self.__set_field_display() + + def __iter__(self): + return iter(self._fields) + + def __getitem__(self, name): + """Dictionary-style field access, return a field's value if present. + """ + try: + if name in self._fields: + return getattr(self, name) + except AttributeError: + pass + raise KeyError(name) + + def __setitem__(self, name, value): + """Dictionary-style field access, set a field's value. + """ + # Ensure that the field exists before settings its value + if name not in self._fields: + raise KeyError(name) + return setattr(self, name, value) + + def __contains__(self, name): + try: + val = getattr(self, name) + return val is not None + except AttributeError: + return False + + def __len__(self): + return len(self._data) + + def __repr__(self): + try: + u = self.__str__() + except (UnicodeEncodeError, UnicodeDecodeError): + u = '[Bad Unicode data]' + repr_type = type(u) + return repr_type('<%s: %s>' % (self.__class__.__name__, u)) + + def __str__(self): + if hasattr(self, '__unicode__'): + if PY3: + return self.__unicode__() + else: + return unicode(self).encode('utf-8') + return txt_type('%s object' % self.__class__.__name__) + + def __eq__(self, other): + if isinstance(other, self.__class__) and hasattr(other, 'id'): + if self.id == other.id: + return True + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + if self.pk is None: + # For new object + return super(BaseDocument, self).__hash__() + else: + return hash(self.pk) + + def to_mongo(self): + """Return data dictionary ready for use with MongoDB. + """ + data = {} + for field_name, field in self._fields.items(): + value = getattr(self, field_name, None) + if value is not None: + data[field.db_field] = field.to_mongo(value) + # Only add _cls if allow_inheritance is not False + if not (hasattr(self, '_meta') and + self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == False): + data['_cls'] = self._class_name + if '_id' in data and data['_id'] is None: + del data['_id'] + + if not self._dynamic: + return data + + for name, field in self._dynamic_fields.items(): + data[name] = field.to_mongo(self._data.get(name, None)) + return data + + def validate(self): + """Ensure that all fields' values are valid and that required fields + are present. + """ + # Get a list of tuples of field names and their current values + fields = [(field, getattr(self, name)) + for name, field in self._fields.items()] + + # Ensure that each field is matched to a valid value + errors = {} + for field, value in fields: + if value is not None: + try: + field._validate(value) + except ValidationError, error: + errors[field.name] = error.errors or error + except (ValueError, AttributeError, AssertionError), error: + errors[field.name] = error + elif field.required: + errors[field.name] = ValidationError('Field is required', + field_name=field.name) + if errors: + raise ValidationError('ValidationError', errors=errors) + + def __expand_dynamic_values(self, name, value): + """expand any dynamic values to their correct types / values""" + if not isinstance(value, (dict, list, tuple)): + return value + + is_list = False + if not hasattr(value, 'items'): + is_list = True + value = dict([(k, v) for k, v in enumerate(value)]) + + if not is_list and '_cls' in value: + cls = get_document(value['_cls']) + return cls(**value) + + data = {} + for k, v in value.items(): + key = name if is_list else k + data[k] = self.__expand_dynamic_values(key, v) + + if is_list: # Convert back to a list + data_items = sorted(data.items(), key=operator.itemgetter(0)) + value = [v for k, v in data_items] + else: + value = data + + # Convert lists / values so we can watch for any changes on them + if (isinstance(value, (list, tuple)) and + not isinstance(value, BaseList)): + value = BaseList(value, self, name) + elif isinstance(value, dict) and not isinstance(value, BaseDict): + value = BaseDict(value, self, name) + + return value + + def _mark_as_changed(self, key): + """Marks a key as explicitly changed by the user + """ + if not key: + return + key = self._db_field_map.get(key, key) + if (hasattr(self, '_changed_fields') and + key not in self._changed_fields): + self._changed_fields.append(key) + + def _get_changed_fields(self, key='', inspected=None): + """Returns a list of all fields that have explicitly been changed. + """ + EmbeddedDocument = _import_class("EmbeddedDocument") + DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") + _changed_fields = [] + _changed_fields += getattr(self, '_changed_fields', []) + + inspected = inspected or set() + if hasattr(self, 'id'): + if self.id in inspected: + return _changed_fields + inspected.add(self.id) + + field_list = self._fields.copy() + if self._dynamic: + field_list.update(self._dynamic_fields) + + for field_name in field_list: + + db_field_name = self._db_field_map.get(field_name, field_name) + key = '%s.' % db_field_name + field = self._data.get(field_name, None) + if hasattr(field, 'id'): + if field.id in inspected: + continue + inspected.add(field.id) + + if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) + and db_field_name not in _changed_fields): + # Find all embedded fields that have been changed + changed = field._get_changed_fields(key, inspected) + _changed_fields += ["%s%s" % (key, k) for k in changed if k] + elif (isinstance(field, (list, tuple, dict)) and + db_field_name not in _changed_fields): + # Loop list / dict fields as they contain documents + # Determine the iterator to use + if not hasattr(field, 'items'): + iterator = enumerate(field) + else: + iterator = field.iteritems() + for index, value in iterator: + if not hasattr(value, '_get_changed_fields'): + continue + list_key = "%s%s." % (key, index) + changed = value._get_changed_fields(list_key, inspected) + _changed_fields += ["%s%s" % (list_key, k) + for k in changed if k] + return _changed_fields + + def _delta(self): + """Returns the delta (set, unset) of the changes for a document. + Gets any values that have been explicitly changed. + """ + # Handles cases where not loaded from_son but has _id + doc = self.to_mongo() + set_fields = self._get_changed_fields() + set_data = {} + unset_data = {} + parts = [] + if hasattr(self, '_changed_fields'): + set_data = {} + # Fetch each set item from its path + for path in set_fields: + parts = path.split('.') + d = doc + new_path = [] + for p in parts: + if isinstance(d, DBRef): + break + elif p.isdigit(): + d = d[int(p)] + elif hasattr(d, 'get'): + d = d.get(p) + new_path.append(p) + path = '.'.join(new_path) + set_data[path] = d + else: + set_data = doc + if '_id' in set_data: + del(set_data['_id']) + + # Determine if any changed items were actually unset. + for path, value in set_data.items(): + if value or isinstance(value, bool): + continue + + # If we've set a value that ain't the default value dont unset it. + default = None + if (self._dynamic and len(parts) and + parts[0] in self._dynamic_fields): + del(set_data[path]) + unset_data[path] = 1 + continue + elif path in self._fields: + default = self._fields[path].default + else: # Perform a full lookup for lists / embedded lookups + d = self + parts = path.split('.') + db_field_name = parts.pop() + for p in parts: + if p.isdigit(): + d = d[int(p)] + elif (hasattr(d, '__getattribute__') and + not isinstance(d, dict)): + real_path = d._reverse_db_field_map.get(p, p) + d = getattr(d, real_path) + else: + d = d.get(p) + + if hasattr(d, '_fields'): + field_name = d._reverse_db_field_map.get(db_field_name, + db_field_name) + + if field_name in d._fields: + default = d._fields.get(field_name).default + else: + default = None + + if default is not None: + if callable(default): + default = default() + if default != value: + continue + + del(set_data[path]) + unset_data[path] = 1 + return set_data, unset_data + + @classmethod + def _get_collection_name(cls): + """Returns the collection name for this class. + """ + return cls._meta.get('collection', None) + + @classmethod + def _from_son(cls, son): + """Create an instance of a Document (subclass) from a PyMongo SON. + """ + # get the class name from the document, falling back to the given + # class if unavailable + class_name = son.get('_cls', cls._class_name) + data = dict(("%s" % key, value) for key, value in son.items()) + if not UNICODE_KWARGS: + # python 2.6.4 and lower cannot handle unicode keys + # passed to class constructor example: cls(**data) + to_str_keys_recursive(data) + + if '_cls' in data: + del data['_cls'] + + # Return correct subclass for document type + if class_name != cls._class_name: + cls = get_document(class_name) + + changed_fields = [] + errors_dict = {} + + for field_name, field in cls._fields.items(): + if field.db_field in data: + value = data[field.db_field] + try: + data[field_name] = (value if value is None + else field.to_python(value)) + if field_name != field.db_field: + del data[field.db_field] + except (AttributeError, ValueError), e: + errors_dict[field_name] = e + elif field.default: + default = field.default + if callable(default): + default = default() + if isinstance(default, BaseDocument): + changed_fields.append(field_name) + + if errors_dict: + errors = "\n".join(["%s - %s" % (k, v) + for k, v in errors_dict.items()]) + msg = ("Invalid data to create a `%s` instance.\n%s" + % (cls._class_name, errors)) + raise InvalidDocumentError(msg) + + obj = cls(**data) + obj._changed_fields = changed_fields + obj._created = False + return obj + + @classmethod + def _build_index_spec(cls, spec): + """Build a PyMongo index spec from a MongoEngine index spec. + """ + if isinstance(spec, basestring): + spec = {'fields': [spec]} + elif isinstance(spec, (list, tuple)): + spec = {'fields': list(spec)} + elif isinstance(spec, dict): + spec = dict(spec) + + index_list = [] + direction = None + + # Check to see if we need to include _cls + allow_inheritance = cls._meta.get('allow_inheritance', + ALLOW_INHERITANCE) != False + include_cls = allow_inheritance and not spec.get('sparse', False) + + for key in spec['fields']: + # If inherited spec continue + if isinstance(key, (list, tuple)): + continue + + # ASCENDING from +, + # DESCENDING from - + # GEO2D from * + direction = pymongo.ASCENDING + if key.startswith("-"): + direction = pymongo.DESCENDING + elif key.startswith("*"): + direction = pymongo.GEO2D + if key.startswith(("+", "-", "*")): + key = key[1:] + + # Use real field name, do it manually because we need field + # objects for the next part (list field checking) + parts = key.split('.') + if parts in (['pk'], ['id'], ['_id']): + key = '_id' + fields = [] + else: + fields = cls._lookup_field(parts) + parts = [field if field == '_id' else field.db_field + for field in fields] + key = '.'.join(parts) + index_list.append((key, direction)) + + # Don't add cls to a geo index + if include_cls and direction is not pymongo.GEO2D: + index_list.insert(0, ('_cls', 1)) + + spec['fields'] = index_list + if spec.get('sparse', False) and len(spec['fields']) > 1: + raise ValueError( + 'Sparse indexes can only have one field in them. ' + 'See https://jira.mongodb.org/browse/SERVER-2193') + + return spec + + @classmethod + def _unique_with_indexes(cls, namespace=""): + """ + Find and set unique indexes + """ + unique_indexes = [] + for field_name, field in cls._fields.items(): + # Generate a list of indexes needed by uniqueness constraints + if field.unique: + field.required = True + unique_fields = [field.db_field] + + # Add any unique_with fields to the back of the index spec + if field.unique_with: + if isinstance(field.unique_with, basestring): + field.unique_with = [field.unique_with] + + # Convert unique_with field names to real field names + unique_with = [] + for other_name in field.unique_with: + parts = other_name.split('.') + # Lookup real name + parts = cls._lookup_field(parts) + name_parts = [part.db_field for part in parts] + unique_with.append('.'.join(name_parts)) + # Unique field should be required + parts[-1].required = True + unique_fields += unique_with + + # Add the new index to the list + index = [("%s%s" % (namespace, f), pymongo.ASCENDING) + for f in unique_fields] + unique_indexes.append(index) + + # Grab any embedded document field unique indexes + if (field.__class__.__name__ == "EmbeddedDocumentField" and + field.document_type != cls): + field_namespace = "%s." % field_name + doc_cls = field.document_type + unique_indexes += doc_cls._unique_with_indexes(field_namespace) + + return unique_indexes + + @classmethod + def _lookup_field(cls, parts): + """Lookup a field based on its attribute and return a list containing + the field's parents and the field. + """ + if not isinstance(parts, (list, tuple)): + parts = [parts] + fields = [] + field = None + + for field_name in parts: + # Handle ListField indexing: + if field_name.isdigit(): + new_field = field.field + fields.append(field_name) + continue + + if field is None: + # Look up first field from the document + if field_name == 'pk': + # Deal with "primary key" alias + field_name = cls._meta['id_field'] + if field_name in cls._fields: + field = cls._fields[field_name] + elif cls._dynamic: + DynamicField = _import_class('DynamicField') + field = DynamicField(db_field=field_name) + else: + raise LookUpError('Cannot resolve field "%s"' + % field_name) + else: + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') + if isinstance(field, (ReferenceField, GenericReferenceField)): + raise LookUpError('Cannot perform join in mongoDB: %s' % + '__'.join(parts)) + if hasattr(getattr(field, 'field', None), 'lookup_member'): + new_field = field.field.lookup_member(field_name) + else: + # Look up subfield on the previous field + new_field = field.lookup_member(field_name) + if not new_field and isinstance(field, ComplexBaseField): + fields.append(field_name) + continue + elif not new_field: + raise LookUpError('Cannot resolve field "%s"' + % field_name) + field = new_field # update field to the new field type + fields.append(field) + return fields + + @classmethod + def _translate_field_name(cls, field, sep='.'): + """Translate a field attribute name to a database field name. + """ + parts = field.split(sep) + parts = [f.db_field for f in cls._lookup_field(parts)] + return '.'.join(parts) + + @classmethod + def _geo_indices(cls, inspected=None): + inspected = inspected or [] + geo_indices = [] + inspected.append(cls) + + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + GeoPointField = _import_class("GeoPointField") + + for field in cls._fields.values(): + if not isinstance(field, (EmbeddedDocumentField, GeoPointField)): + continue + if hasattr(field, 'document_type'): + field_cls = field.document_type + if field_cls in inspected: + continue + if hasattr(field_cls, '_geo_indices'): + geo_indices += field_cls._geo_indices(inspected) + elif field._geo_index: + geo_indices.append(field) + return geo_indices + + def __set_field_display(self): + """Dynamically set the display value for a field with choices""" + for attr_name, field in self._fields.items(): + if field.choices: + setattr(self, + 'get_%s_display' % attr_name, + partial(self.__get_field_display, field=field)) + + def __get_field_display(self, field): + """Returns the display value for a choice field""" + value = getattr(self, field.name) + if field.choices and isinstance(field.choices[0], (list, tuple)): + return dict(field.choices).get(value, value) + return value diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py new file mode 100644 index 00000000..44f5e131 --- /dev/null +++ b/mongoengine/base/fields.py @@ -0,0 +1,371 @@ +import operator +import warnings + +from bson import DBRef, ObjectId + +from mongoengine.common import _import_class +from mongoengine.errors import ValidationError + +from .common import ALLOW_INHERITANCE +from .datastructures import BaseDict, BaseList + +__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField") + + +class BaseField(object): + """A base class for fields in a MongoDB document. Instances of this class + may be added to subclasses of `Document` to define a document's schema. + + .. versionchanged:: 0.5 - added verbose and help text + """ + + name = None + _geo_index = False + + # These track each time a Field instance is created. Used to retain order. + # The auto_creation_counter is used for fields that MongoEngine implicitly + # creates, creation_counter is used for all user-specified fields. + creation_counter = 0 + auto_creation_counter = -1 + + def __init__(self, db_field=None, name=None, required=False, default=None, + unique=False, unique_with=None, primary_key=False, + validation=None, choices=None, verbose_name=None, + help_text=None): + self.db_field = (db_field or name) if not primary_key else '_id' + if name: + msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" + warnings.warn(msg, DeprecationWarning) + self.name = None + self.required = required or primary_key + self.default = default + self.unique = bool(unique or unique_with) + self.unique_with = unique_with + self.primary_key = primary_key + self.validation = validation + self.choices = choices + self.verbose_name = verbose_name + self.help_text = help_text + + # Adjust the appropriate creation counter, and save our local copy. + if self.db_field == '_id': + self.creation_counter = BaseField.auto_creation_counter + BaseField.auto_creation_counter -= 1 + else: + self.creation_counter = BaseField.creation_counter + BaseField.creation_counter += 1 + + def __get__(self, instance, owner): + """Descriptor for retrieving a value from a field in a document. Do + any necessary conversion between Python and MongoDB types. + """ + if instance is None: + # Document class being used rather than a document object + return self + + # Get value from document instance if available, if not use default + value = instance._data.get(self.name) + + if value is None: + value = self.default + # Allow callable default values + if callable(value): + value = value() + + return value + + def __set__(self, instance, value): + """Descriptor for assigning a value to a field in a document. + """ + instance._data[self.name] = value + if instance._initialised: + instance._mark_as_changed(self.name) + + def error(self, message="", errors=None, field_name=None): + """Raises a ValidationError. + """ + field_name = field_name if field_name else self.name + raise ValidationError(message, errors=errors, field_name=field_name) + + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type. + """ + return value + + def to_mongo(self, value): + """Convert a Python type to a MongoDB-compatible type. + """ + return self.to_python(value) + + def prepare_query_value(self, op, value): + """Prepare a value that is being used in a query for PyMongo. + """ + return value + + def validate(self, value): + """Perform validation on a value. + """ + pass + + def _validate(self, value): + Document = _import_class('Document') + EmbeddedDocument = _import_class('EmbeddedDocument') + # check choices + if self.choices: + is_cls = isinstance(value, (Document, EmbeddedDocument)) + value_to_check = value.__class__ if is_cls else value + err_msg = 'an instance' if is_cls else 'one' + if isinstance(self.choices[0], (list, tuple)): + option_keys = [k for k, v in self.choices] + if value_to_check not in option_keys: + msg = ('Value must be %s of %s' % + (err_msg, unicode(option_keys))) + self.error(msg) + elif value_to_check not in self.choices: + msg = ('Value must be %s of %s' % + (err_msg, unicode(self.choices))) + self.error() + + # check validation argument + if self.validation is not None: + if callable(self.validation): + if not self.validation(value): + self.error('Value does not match custom validation method') + else: + raise ValueError('validation argument for "%s" must be a ' + 'callable.' % self.name) + + self.validate(value) + + +class ComplexBaseField(BaseField): + """Handles complex fields, such as lists / dictionaries. + + Allows for nesting of embedded documents inside complex types. + Handles the lazy dereferencing of a queryset by lazily dereferencing all + items in a list / dict rather than one at a time. + + .. versionadded:: 0.5 + """ + + field = None + __dereference = False + + def __get__(self, instance, owner): + """Descriptor to automatically dereference references. + """ + if instance is None: + # Document class being used rather than a document object + return self + + ReferenceField = _import_class('ReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') + dereference = self.field is None or isinstance(self.field, + (GenericReferenceField, ReferenceField)) + if not self._dereference and instance._initialised and dereference: + instance._data[self.name] = self._dereference( + instance._data.get(self.name), max_depth=1, instance=instance, + name=self.name + ) + + value = super(ComplexBaseField, self).__get__(instance, owner) + + # Convert lists / values so we can watch for any changes on them + if (isinstance(value, (list, tuple)) and + not isinstance(value, BaseList)): + value = BaseList(value, instance, self.name) + instance._data[self.name] = value + elif isinstance(value, dict) and not isinstance(value, BaseDict): + value = BaseDict(value, instance, self.name) + instance._data[self.name] = value + + if (instance._initialised and isinstance(value, (BaseList, BaseDict)) + and not value._dereferenced): + value = self._dereference( + value, max_depth=1, instance=instance, name=self.name + ) + value._dereferenced = True + instance._data[self.name] = value + + return value + + def __set__(self, instance, value): + """Descriptor for assigning a value to a field in a document. + """ + instance._data[self.name] = value + instance._mark_as_changed(self.name) + + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type. + """ + Document = _import_class('Document') + + if isinstance(value, basestring): + return value + + if hasattr(value, 'to_python'): + return value.to_python() + + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k, v) for k, v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value + + if self.field: + value_dict = dict([(key, self.field.to_python(item)) + for key, item in value.items()]) + else: + value_dict = {} + for k, v in value.items(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + self.error('You can only reference documents once they' + ' have been saved to the database') + collection = v._get_collection_name() + value_dict[k] = DBRef(collection, v.pk) + elif hasattr(v, 'to_python'): + value_dict[k] = v.to_python() + else: + value_dict[k] = self.to_python(v) + + if is_list: # Convert back to a list + return [v for k, v in sorted(value_dict.items(), + key=operator.itemgetter(0))] + return value_dict + + def to_mongo(self, value): + """Convert a Python type to a MongoDB-compatible type. + """ + Document = _import_class("Document") + + if isinstance(value, basestring): + return value + + if hasattr(value, 'to_mongo'): + return value.to_mongo() + + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k, v) for k, v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value + + if self.field: + value_dict = dict([(key, self.field.to_mongo(item)) + for key, item in value.items()]) + else: + value_dict = {} + for k, v in value.items(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + self.error('You can only reference documents once they' + ' have been saved to the database') + + # If its a document that is not inheritable it won't have + # any _cls data so make it a generic reference allows + # us to dereference + meta = getattr(v, '_meta', {}) + allow_inheritance = ( + meta.get('allow_inheritance', ALLOW_INHERITANCE) + == False) + if allow_inheritance and not self.field: + GenericReferenceField = _import_class( + "GenericReferenceField") + value_dict[k] = GenericReferenceField().to_mongo(v) + else: + collection = v._get_collection_name() + value_dict[k] = DBRef(collection, v.pk) + elif hasattr(v, 'to_mongo'): + value_dict[k] = v.to_mongo() + else: + value_dict[k] = self.to_mongo(v) + + if is_list: # Convert back to a list + return [v for k, v in sorted(value_dict.items(), + key=operator.itemgetter(0))] + return value_dict + + def validate(self, value): + """If field is provided ensure the value is valid. + """ + errors = {} + if self.field: + if hasattr(value, 'iteritems') or hasattr(value, 'items'): + sequence = value.iteritems() + else: + sequence = enumerate(value) + for k, v in sequence: + try: + self.field._validate(v) + except ValidationError, error: + errors[k] = error.errors or error + except (ValueError, AssertionError), error: + errors[k] = error + + if errors: + field_class = self.field.__class__.__name__ + self.error('Invalid %s item (%s)' % (field_class, value), + errors=errors) + # Don't allow empty values if required + if self.required and not value: + self.error('Field is required and cannot be empty') + + def prepare_query_value(self, op, value): + return self.to_mongo(value) + + def lookup_member(self, member_name): + if self.field: + return self.field.lookup_member(member_name) + return None + + def _set_owner_document(self, owner_document): + if self.field: + self.field.owner_document = owner_document + self._owner_document = owner_document + + def _get_owner_document(self, owner_document): + self._owner_document = owner_document + + owner_document = property(_get_owner_document, _set_owner_document) + + @property + def _dereference(self,): + if not self.__dereference: + DeReference = _import_class("DeReference") + self.__dereference = DeReference() # Cached + return self.__dereference + + +class ObjectIdField(BaseField): + """A field wrapper around MongoDB's ObjectIds. + """ + + def to_python(self, value): + if not isinstance(value, ObjectId): + value = ObjectId(value) + return value + + def to_mongo(self, value): + if not isinstance(value, ObjectId): + try: + return ObjectId(unicode(value)) + except Exception, e: + # e.message attribute has been deprecated since Python 2.6 + self.error(unicode(e)) + return value + + def prepare_query_value(self, op, value): + return self.to_mongo(value) + + def validate(self, value): + try: + ObjectId(unicode(value)) + except: + self.error('Invalid Object ID') diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py new file mode 100644 index 00000000..f87b03e4 --- /dev/null +++ b/mongoengine/base/metaclasses.py @@ -0,0 +1,388 @@ +import warnings + +import pymongo + +from mongoengine.common import _import_class +from mongoengine.errors import InvalidDocumentError +from mongoengine.python_support import PY3 +from mongoengine.queryset import (DO_NOTHING, DoesNotExist, + MultipleObjectsReturned, + QuerySet, QuerySetManager) + +from .common import _document_registry, ALLOW_INHERITANCE +from .fields import BaseField, ComplexBaseField, ObjectIdField + +__all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass') + + +class DocumentMetaclass(type): + """Metaclass for all documents. + """ + + def __new__(cls, name, bases, attrs): + flattened_bases = cls._get_bases(bases) + super_new = super(DocumentMetaclass, cls).__new__ + + # If a base class just call super + metaclass = attrs.get('my_metaclass') + if metaclass and issubclass(metaclass, DocumentMetaclass): + return super_new(cls, name, bases, attrs) + + attrs['_is_document'] = attrs.get('_is_document', False) + + # EmbeddedDocuments could have meta data for inheritance + if 'meta' in attrs: + attrs['_meta'] = attrs.pop('meta') + + # Handle document Fields + + # Merge all fields from subclasses + doc_fields = {} + for base in flattened_bases[::-1]: + if hasattr(base, '_fields'): + doc_fields.update(base._fields) + + # Standard object mixin - merge in any Fields + if not hasattr(base, '_meta'): + base_fields = {} + for attr_name, attr_value in base.__dict__.iteritems(): + if not isinstance(attr_value, BaseField): + continue + attr_value.name = attr_name + if not attr_value.db_field: + attr_value.db_field = attr_name + base_fields[attr_name] = attr_value + doc_fields.update(base_fields) + + # Discover any document fields + field_names = {} + for attr_name, attr_value in attrs.iteritems(): + if not isinstance(attr_value, BaseField): + continue + attr_value.name = attr_name + if not attr_value.db_field: + attr_value.db_field = attr_name + doc_fields[attr_name] = attr_value + + # Count names to ensure no db_field redefinitions + field_names[attr_value.db_field] = field_names.get( + attr_value.db_field, 0) + 1 + + # Ensure no duplicate db_fields + duplicate_db_fields = [k for k, v in field_names.items() if v > 1] + if duplicate_db_fields: + msg = ("Multiple db_fields defined for: %s " % + ", ".join(duplicate_db_fields)) + raise InvalidDocumentError(msg) + + # Set _fields and db_field maps + attrs['_fields'] = doc_fields + attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) + for k, v in doc_fields.iteritems()]) + attrs['_reverse_db_field_map'] = dict( + (v, k) for k, v in attrs['_db_field_map'].iteritems()) + + # + # Set document hierarchy + # + superclasses = () + class_name = [name] + for base in flattened_bases: + if (not getattr(base, '_is_base_cls', True) and + not getattr(base, '_meta', {}).get('abstract', True)): + # Collate heirarchy for _cls and _subclasses + class_name.append(base.__name__) + + if hasattr(base, '_meta'): + # Warn if allow_inheritance isn't set and prevent + # inheritance of classes where inheritance is set to False + allow_inheritance = base._meta.get('allow_inheritance', + ALLOW_INHERITANCE) + if (not getattr(base, '_is_base_cls', True) + and allow_inheritance is None): + warnings.warn( + "%s uses inheritance, the default for " + "allow_inheritance is changing to off by default. " + "Please add it to the document meta." % name, + FutureWarning + ) + elif (allow_inheritance == False and + not base._meta.get('abstract')): + raise ValueError('Document %s may not be subclassed' % + base.__name__) + + # Get superclasses from last base superclass + document_bases = [b for b in flattened_bases + if hasattr(b, '_class_name')] + if document_bases: + superclasses = document_bases[0]._superclasses + superclasses += (document_bases[0]._class_name, ) + + _cls = '.'.join(reversed(class_name)) + attrs['_class_name'] = _cls + attrs['_superclasses'] = superclasses + attrs['_subclasses'] = (_cls, ) + attrs['_types'] = attrs['_subclasses'] # TODO depreciate _types + + # Create the new_class + new_class = super_new(cls, name, bases, attrs) + + # Set _subclasses + for base in document_bases: + if _cls not in base._subclasses: + base._subclasses += (_cls,) + base._types = base._subclasses # TODO depreciate _types + + # Handle delete rules + Document, EmbeddedDocument, DictField = cls._import_classes() + for field in new_class._fields.itervalues(): + f = field + f.owner_document = new_class + delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) + if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): + delete_rule = getattr(f.field, + 'reverse_delete_rule', + DO_NOTHING) + if isinstance(f, DictField) and delete_rule != DO_NOTHING: + msg = ("Reverse delete rules are not supported " + "for %s (field: %s)" % + (field.__class__.__name__, field.name)) + raise InvalidDocumentError(msg) + + f = field.field + + if delete_rule != DO_NOTHING: + if issubclass(new_class, EmbeddedDocument): + msg = ("Reverse delete rules are not supported for " + "EmbeddedDocuments (field: %s)" % field.name) + raise InvalidDocumentError(msg) + f.document_type.register_delete_rule(new_class, + field.name, delete_rule) + + if (field.name and hasattr(Document, field.name) and + EmbeddedDocument not in new_class.mro()): + msg = ("%s is a document method and not a valid " + "field name" % field.name) + raise InvalidDocumentError(msg) + + # Add class to the _document_registry + _document_registry[new_class._class_name] = new_class + + # In Python 2, User-defined methods objects have special read-only + # attributes 'im_func' and 'im_self' which contain the function obj + # and class instance object respectively. With Python 3 these special + # attributes have been replaced by __func__ and __self__. The Blinker + # module continues to use im_func and im_self, so the code below + # copies __func__ into im_func and __self__ into im_self for + # classmethod objects in Document derived classes. + if PY3: + for key, val in new_class.__dict__.items(): + if isinstance(val, classmethod): + f = val.__get__(new_class) + if hasattr(f, '__func__') and not hasattr(f, 'im_func'): + f.__dict__.update({'im_func': getattr(f, '__func__')}) + if hasattr(f, '__self__') and not hasattr(f, 'im_self'): + f.__dict__.update({'im_self': getattr(f, '__self__')}) + + return new_class + + def add_to_class(self, name, value): + setattr(self, name, value) + + @classmethod + def _get_bases(cls, bases): + if isinstance(bases, BasesTuple): + return bases + seen = [] + bases = cls.__get_bases(bases) + unique_bases = (b for b in bases if not (b in seen or seen.append(b))) + return BasesTuple(unique_bases) + + @classmethod + def __get_bases(cls, bases): + for base in bases: + if base is object: + continue + yield base + for child_base in cls.__get_bases(base.__bases__): + yield child_base + + @classmethod + def _import_classes(cls): + Document = _import_class('Document') + EmbeddedDocument = _import_class('EmbeddedDocument') + DictField = _import_class('DictField') + return (Document, EmbeddedDocument, DictField) + + +class TopLevelDocumentMetaclass(DocumentMetaclass): + """Metaclass for top-level documents (i.e. documents that have their own + collection in the database. + """ + + def __new__(cls, name, bases, attrs): + flattened_bases = cls._get_bases(bases) + super_new = super(TopLevelDocumentMetaclass, cls).__new__ + + # Set default _meta data if base class, otherwise get user defined meta + if (attrs.get('my_metaclass') == TopLevelDocumentMetaclass): + # defaults + attrs['_meta'] = { + 'abstract': True, + 'max_documents': None, + 'max_size': None, + 'ordering': [], # default ordering applied at runtime + 'indexes': [], # indexes to be ensured at runtime + 'id_field': None, + 'index_background': False, + 'index_drop_dups': False, + 'index_opts': None, + 'delete_rules': None, + 'allow_inheritance': None, + } + attrs['_is_base_cls'] = True + attrs['_meta'].update(attrs.get('meta', {})) + else: + attrs['_meta'] = attrs.get('meta', {}) + # Explictly set abstract to false unless set + attrs['_meta']['abstract'] = attrs['_meta'].get('abstract', False) + attrs['_is_base_cls'] = False + + # Set flag marking as document class - as opposed to an object mixin + attrs['_is_document'] = True + + # Ensure queryset_class is inherited + if 'objects' in attrs: + manager = attrs['objects'] + if hasattr(manager, 'queryset_class'): + attrs['_meta']['queryset_class'] = manager.queryset_class + + # Clean up top level meta + if 'meta' in attrs: + del(attrs['meta']) + + # Find the parent document class + parent_doc_cls = [b for b in flattened_bases + if b.__class__ == TopLevelDocumentMetaclass] + parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0] + + # Prevent classes setting collection different to their parents + # If parent wasn't an abstract class + if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) + and not parent_doc_cls._meta.get('abstract', True)): + msg = "Trying to set a collection on a subclass (%s)" % name + warnings.warn(msg, SyntaxWarning) + del(attrs['_meta']['collection']) + + # Ensure abstract documents have abstract bases + if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): + if (parent_doc_cls and + not parent_doc_cls._meta.get('abstract', False)): + msg = "Abstract document cannot have non-abstract base" + raise ValueError(msg) + return super_new(cls, name, bases, attrs) + + # Merge base class metas. + # Uses a special MetaDict that handles various merging rules + meta = MetaDict() + for base in flattened_bases[::-1]: + # Add any mixin metadata from plain objects + if hasattr(base, 'meta'): + meta.merge(base.meta) + elif hasattr(base, '_meta'): + meta.merge(base._meta) + + # Set collection in the meta if its callable + if (getattr(base, '_is_document', False) and + not base._meta.get('abstract')): + collection = meta.get('collection', None) + if callable(collection): + meta['collection'] = collection(base) + + meta.merge(attrs.get('_meta', {})) # Top level meta + + # Only simple classes (direct subclasses of Document) + # may set allow_inheritance to False + simple_class = all([b._meta.get('abstract') + for b in flattened_bases if hasattr(b, '_meta')]) + if (not simple_class and meta['allow_inheritance'] == False and + not meta['abstract']): + raise ValueError('Only direct subclasses of Document may set ' + '"allow_inheritance" to False') + + # Set default collection name + if 'collection' not in meta: + meta['collection'] = ''.join('_%s' % c if c.isupper() else c + for c in name).strip('_').lower() + attrs['_meta'] = meta + + # Call super and get the new class + new_class = super_new(cls, name, bases, attrs) + + meta = new_class._meta + + # Set index specifications + meta['index_specs'] = [new_class._build_index_spec(spec) + for spec in meta['indexes']] + unique_indexes = new_class._unique_with_indexes() + new_class._meta['unique_indexes'] = unique_indexes + + # If collection is a callable - call it and set the value + collection = meta.get('collection') + if callable(collection): + new_class._meta['collection'] = collection(new_class) + + # Provide a default queryset unless one has been set + manager = attrs.get('objects', QuerySetManager()) + new_class.objects = manager + + # Validate the fields and set primary key if needed + for field_name, field in new_class._fields.iteritems(): + if field.primary_key: + # Ensure only one primary key is set + current_pk = new_class._meta.get('id_field') + if current_pk and current_pk != field_name: + raise ValueError('Cannot override primary key field') + + # Set primary key + if not current_pk: + new_class._meta['id_field'] = field_name + new_class.id = field + + # Set primary key if not defined by the document + if not new_class._meta.get('id_field'): + new_class._meta['id_field'] = 'id' + new_class._fields['id'] = ObjectIdField(db_field='_id') + new_class.id = new_class._fields['id'] + + # Merge in exceptions with parent hierarchy + exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned) + module = attrs.get('__module__') + for exc in exceptions_to_merge: + name = exc.__name__ + parents = tuple(getattr(base, name) for base in flattened_bases + if hasattr(base, name)) or (exc,) + # Create new exception and set to new_class + exception = type(name, parents, {'__module__': module}) + setattr(new_class, name, exception) + + return new_class + + +class MetaDict(dict): + """Custom dictionary for meta classes. + Handles the merging of set indexes + """ + _merge_options = ('indexes',) + + def merge(self, new_options): + for k, v in new_options.iteritems(): + if k in self._merge_options: + self[k] = self.get(k, []) + v + else: + self[k] = v + + +class BasesTuple(tuple): + """Special class to handle introspection of bases tuple in __new__""" + pass diff --git a/mongoengine/common.py b/mongoengine/common.py new file mode 100644 index 00000000..c284777e --- /dev/null +++ b/mongoengine/common.py @@ -0,0 +1,35 @@ +_class_registry_cache = {} + + +def _import_class(cls_name): + """Cached mechanism for imports""" + if cls_name in _class_registry_cache: + return _class_registry_cache.get(cls_name) + + doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument', + 'MapReduceDocument') + field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', + 'GenericReferenceField', 'GeoPointField', + 'ReferenceField', 'StringField') + queryset_classes = ('OperationError',) + deref_classes = ('DeReference',) + + if cls_name in doc_classes: + from mongoengine import document as module + import_classes = doc_classes + elif cls_name in field_classes: + from mongoengine import fields as module + import_classes = field_classes + elif cls_name in queryset_classes: + from mongoengine import queryset as module + import_classes = queryset_classes + elif cls_name in deref_classes: + from mongoengine import dereference as module + import_classes = deref_classes + else: + raise ValueError('No import set for: ' % cls_name) + + for cls in import_classes: + _class_registry_cache[cls] = getattr(module, cls) + + return _class_registry_cache.get(cls_name) \ No newline at end of file diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 386dbf4b..59cc0a58 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -164,7 +164,7 @@ class DeReference(object): if isinstance(items, (dict, SON)): if '_ref' in items: return self.object_map.get(items['_ref'].id, items) - elif '_types' in items and '_cls' in items: + elif '_cls' in items: doc = get_document(items['_cls'])._from_son(items) doc._data = self._attach_objects(doc._data, depth, doc, None) return doc diff --git a/mongoengine/django/shortcuts.py b/mongoengine/django/shortcuts.py index 637cee15..9cc8370b 100644 --- a/mongoengine/django/shortcuts.py +++ b/mongoengine/django/shortcuts.py @@ -1,6 +1,6 @@ from mongoengine.queryset import QuerySet from mongoengine.base import BaseDocument -from mongoengine.base import ValidationError +from mongoengine.errors import ValidationError def _get_queryset(cls): """Inspired by django.shortcuts.*""" diff --git a/mongoengine/document.py b/mongoengine/document.py index 7b3afafb..b1ce13ad 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -11,9 +11,9 @@ from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, from queryset import OperationError, NotUniqueError from connection import get_db, DEFAULT_CONNECTION_NAME -__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', +__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', - 'InvalidCollectionError', 'NotUniqueError'] + 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') class InvalidCollectionError(Exception): @@ -28,11 +28,11 @@ class EmbeddedDocument(BaseDocument): A :class:`~mongoengine.EmbeddedDocument` subclass may be itself subclassed, to create a specialised version of the embedded document that will be - stored in the same collection. To facilitate this behaviour, `_cls` and - `_types` fields are added to documents (hidden though the MongoEngine - interface though). To disable this behaviour and remove the dependence on - the presence of `_cls` and `_types`, set :attr:`allow_inheritance` to - ``False`` in the :attr:`meta` dictionary. + stored in the same collection. To facilitate this behaviour a `_cls` + field is added to documents (hidden though the MongoEngine interface). + To disable this behaviour and remove the dependence on the presence of + `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` + dictionary. """ # The __metaclass__ attribute is removed by 2to3 when running with Python3 @@ -76,11 +76,11 @@ class Document(BaseDocument): A :class:`~mongoengine.Document` subclass may be itself subclassed, to create a specialised version of the document that will be stored in the - same collection. To facilitate this behaviour, `_cls` and `_types` - fields are added to documents (hidden though the MongoEngine interface - though). To disable this behaviour and remove the dependence on the - presence of `_cls` and `_types`, set :attr:`allow_inheritance` to - ``False`` in the :attr:`meta` dictionary. + same collection. To facilitate this behaviour a `_cls` + field is added to documents (hidden though the MongoEngine interface). + To disable this behaviour and remove the dependence on the presence of + `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` + dictionary. A :class:`~mongoengine.Document` may use a **Capped Collection** by specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` @@ -101,10 +101,10 @@ class Document(BaseDocument): production systems where index creation is performed as part of a deployment system. - By default, _types will be added to the start of every index (that + By default, _cls will be added to the start of every index (that doesn't contain a list) if allow_inheritance is True. This can be disabled by either setting types to False on the specific index or - by setting index_types to False on the meta dictionary for the document. + by setting index_cls to False on the meta dictionary for the document. """ # The __metaclass__ attribute is removed by 2to3 when running with Python3 diff --git a/mongoengine/errors.py b/mongoengine/errors.py new file mode 100644 index 00000000..eb72503d --- /dev/null +++ b/mongoengine/errors.py @@ -0,0 +1,124 @@ +from collections import defaultdict + +from .python_support import txt_type + + +__all__ = ('NotRegistered', 'InvalidDocumentError', 'ValidationError') + + +class NotRegistered(Exception): + pass + + +class InvalidDocumentError(Exception): + pass + + +class LookUpError(AttributeError): + pass + + +class DoesNotExist(Exception): + pass + + +class MultipleObjectsReturned(Exception): + pass + + +class InvalidQueryError(Exception): + pass + + +class OperationError(Exception): + pass + + +class NotUniqueError(OperationError): + pass + + +class ValidationError(AssertionError): + """Validation exception. + + May represent an error validating a field or a + document containing fields with validation errors. + + :ivar errors: A dictionary of errors for fields within this + document or list, or None if the error is for an + individual field. + """ + + errors = {} + field_name = None + _message = None + + def __init__(self, message="", **kwargs): + self.errors = kwargs.get('errors', {}) + self.field_name = kwargs.get('field_name') + self.message = message + + def __str__(self): + return txt_type(self.message) + + def __repr__(self): + return '%s(%s,)' % (self.__class__.__name__, self.message) + + def __getattribute__(self, name): + message = super(ValidationError, self).__getattribute__(name) + if name == 'message': + if self.field_name: + message = '%s' % message + if self.errors: + message = '%s(%s)' % (message, self._format_errors()) + return message + + def _get_message(self): + return self._message + + def _set_message(self, message): + self._message = message + + message = property(_get_message, _set_message) + + def to_dict(self): + """Returns a dictionary of all errors within a document + + Keys are field names or list indices and values are the + validation error messages, or a nested dictionary of + errors for an embedded document or list. + """ + + def build_dict(source): + errors_dict = {} + if not source: + return errors_dict + if isinstance(source, dict): + for field_name, error in source.iteritems(): + errors_dict[field_name] = build_dict(error) + elif isinstance(source, ValidationError) and source.errors: + return build_dict(source.errors) + else: + return unicode(source) + return errors_dict + if not self.errors: + return {} + return build_dict(self.errors) + + def _format_errors(self): + """Returns a string listing all errors within a document""" + + def generate_key(value, prefix=''): + if isinstance(value, list): + value = ' '.join([generate_key(k) for k in value]) + if isinstance(value, dict): + value = ' '.join( + [generate_key(v, k) for k, v in value.iteritems()]) + + results = "%s.%s" % (prefix, value) if prefix else value + return results + + error_dict = defaultdict(list) + for k, v in self.to_dict().iteritems(): + error_dict[generate_key(v)].append(k) + return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()]) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 01d3fc63..9bcba9f1 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -12,10 +12,11 @@ from operator import itemgetter import gridfs from bson import Binary, DBRef, SON, ObjectId +from mongoengine.errors import ValidationError from mongoengine.python_support import (PY3, bin_type, txt_type, str_types, StringIO) from base import (BaseField, ComplexBaseField, ObjectIdField, - ValidationError, get_document, BaseDocument) + get_document, BaseDocument) from queryset import DO_NOTHING, QuerySet from document import Document, EmbeddedDocument from connection import get_db, DEFAULT_CONNECTION_NAME @@ -568,9 +569,6 @@ class ListField(ComplexBaseField): Required means it cannot be empty - as the default for ListFields is [] """ - # ListFields cannot be indexed with _types - MongoDB doesn't support this - _index_with_types = False - def __init__(self, field=None, **kwargs): self.field = field kwargs.setdefault('default', lambda: []) diff --git a/mongoengine/queryset/__init__.py b/mongoengine/queryset/__init__.py new file mode 100644 index 00000000..f6feeab7 --- /dev/null +++ b/mongoengine/queryset/__init__.py @@ -0,0 +1,11 @@ +from mongoengine.errors import (DoesNotExist, MultipleObjectsReturned, + InvalidQueryError, OperationError, + NotUniqueError) +from .field_list import * +from .manager import * +from .queryset import * +from .transform import * +from .visitor import * + +__all__ = (field_list.__all__ + manager.__all__ + queryset.__all__ + + transform.__all__ + visitor.__all__) diff --git a/mongoengine/queryset/field_list.py b/mongoengine/queryset/field_list.py new file mode 100644 index 00000000..1c825fa9 --- /dev/null +++ b/mongoengine/queryset/field_list.py @@ -0,0 +1,51 @@ + +__all__ = ('QueryFieldList',) + + +class QueryFieldList(object): + """Object that handles combinations of .only() and .exclude() calls""" + ONLY = 1 + EXCLUDE = 0 + + def __init__(self, fields=[], value=ONLY, always_include=[]): + self.value = value + self.fields = set(fields) + self.always_include = set(always_include) + self._id = None + + def __add__(self, f): + if not self.fields: + self.fields = f.fields + self.value = f.value + elif self.value is self.ONLY and f.value is self.ONLY: + self.fields = self.fields.intersection(f.fields) + elif self.value is self.EXCLUDE and f.value is self.EXCLUDE: + self.fields = self.fields.union(f.fields) + elif self.value is self.ONLY and f.value is self.EXCLUDE: + self.fields -= f.fields + elif self.value is self.EXCLUDE and f.value is self.ONLY: + self.value = self.ONLY + self.fields = f.fields - self.fields + + if '_id' in f.fields: + self._id = f.value + + if self.always_include: + if self.value is self.ONLY and self.fields: + self.fields = self.fields.union(self.always_include) + else: + self.fields -= self.always_include + return self + + def __nonzero__(self): + return bool(self.fields) + + def as_dict(self): + field_list = dict((field, self.value) for field in self.fields) + if self._id is not None: + field_list['_id'] = self._id + return field_list + + def reset(self): + self.fields = set([]) + self.value = self.ONLY diff --git a/mongoengine/queryset/manager.py b/mongoengine/queryset/manager.py new file mode 100644 index 00000000..7376e3c6 --- /dev/null +++ b/mongoengine/queryset/manager.py @@ -0,0 +1,61 @@ +from functools import partial +from .queryset import QuerySet + +__all__ = ('queryset_manager', 'QuerySetManager') + + +class QuerySetManager(object): + """ + The default QuerySet Manager. + + Custom QuerySet Manager functions can extend this class and users can + add extra queryset functionality. Any custom manager methods must accept a + :class:`~mongoengine.Document` class as its first argument, and a + :class:`~mongoengine.queryset.QuerySet` as its second argument. + + The method function should return a :class:`~mongoengine.queryset.QuerySet` + , probably the same one that was passed in, but modified in some way. + """ + + get_queryset = None + + def __init__(self, queryset_func=None): + if queryset_func: + self.get_queryset = queryset_func + self._collections = {} + + def __get__(self, instance, owner): + """Descriptor for instantiating a new QuerySet object when + Document.objects is accessed. + """ + if instance is not None: + # Document class being used rather than a document object + return self + + # owner is the document that contains the QuerySetManager + queryset_class = owner._meta.get('queryset_class') or QuerySet + queryset = queryset_class(owner, owner._get_collection()) + if self.get_queryset: + arg_count = self.get_queryset.func_code.co_argcount + if arg_count == 1: + queryset = self.get_queryset(queryset) + elif arg_count == 2: + queryset = self.get_queryset(owner, queryset) + else: + queryset = partial(self.get_queryset, owner, queryset) + return queryset + + +def queryset_manager(func): + """Decorator that allows you to define custom QuerySet managers on + :class:`~mongoengine.Document` classes. The manager must be a function that + accepts a :class:`~mongoengine.Document` class as its first argument, and a + :class:`~mongoengine.queryset.QuerySet` as its second argument. The method + function should return a :class:`~mongoengine.queryset.QuerySet`, probably + the same one that was passed in, but modified in some way. + """ + if func.func_code.co_argcount == 1: + import warnings + msg = 'Methods decorated with queryset_manager should take 2 arguments' + warnings.warn(msg, DeprecationWarning) + return QuerySetManager(func) diff --git a/mongoengine/queryset.py b/mongoengine/queryset/queryset.py similarity index 59% rename from mongoengine/queryset.py rename to mongoengine/queryset/queryset.py index c774322e..51080663 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -4,20 +4,21 @@ import copy import itertools import operator -from collections import defaultdict -from functools import partial - -from mongoengine.python_support import product, reduce - import pymongo from bson.code import Code from mongoengine import signals +from mongoengine.common import _import_class +from mongoengine.errors import (OperationError, NotUniqueError, + InvalidQueryError) -__all__ = ['queryset_manager', 'Q', 'InvalidQueryError', - 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL'] +from . import transform +from .field_list import QueryFieldList +from .visitor import Q +__all__ = ('QuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') + # The maximum number of items to display in a QuerySet.__repr__ REPR_OUTPUT_SIZE = 20 @@ -28,308 +29,9 @@ CASCADE = 2 DENY = 3 PULL = 4 - -class DoesNotExist(Exception): - pass - - -class MultipleObjectsReturned(Exception): - pass - - -class InvalidQueryError(Exception): - pass - - -class OperationError(Exception): - pass - - -class NotUniqueError(OperationError): - pass - - RE_TYPE = type(re.compile('')) -class QNodeVisitor(object): - """Base visitor class for visiting Q-object nodes in a query tree. - """ - - def visit_combination(self, combination): - """Called by QCombination objects. - """ - return combination - - def visit_query(self, query): - """Called by (New)Q objects. - """ - return query - - -class SimplificationVisitor(QNodeVisitor): - """Simplifies query trees by combinging unnecessary 'and' connection nodes - into a single Q-object. - """ - - def visit_combination(self, combination): - if combination.operation == combination.AND: - # The simplification only applies to 'simple' queries - if all(isinstance(node, Q) for node in combination.children): - queries = [node.query for node in combination.children] - return Q(**self._query_conjunction(queries)) - return combination - - def _query_conjunction(self, queries): - """Merges query dicts - effectively &ing them together. - """ - query_ops = set() - combined_query = {} - for query in queries: - ops = set(query.keys()) - # Make sure that the same operation isn't applied more than once - # to a single field - intersection = ops.intersection(query_ops) - if intersection: - msg = 'Duplicate query conditions: ' - raise InvalidQueryError(msg + ', '.join(intersection)) - - query_ops.update(ops) - combined_query.update(copy.deepcopy(query)) - return combined_query - - -class QueryTreeTransformerVisitor(QNodeVisitor): - """Transforms the query tree in to a form that may be used with MongoDB. - """ - - def visit_combination(self, combination): - if combination.operation == combination.AND: - # MongoDB doesn't allow us to have too many $or operations in our - # queries, so the aim is to move the ORs up the tree to one - # 'master' $or. Firstly, we must find all the necessary parts (part - # of an AND combination or just standard Q object), and store them - # separately from the OR parts. - or_groups = [] - and_parts = [] - for node in combination.children: - if isinstance(node, QCombination): - if node.operation == node.OR: - # Any of the children in an $or component may cause - # the query to succeed - or_groups.append(node.children) - elif node.operation == node.AND: - and_parts.append(node) - elif isinstance(node, Q): - and_parts.append(node) - - # Now we combine the parts into a usable query. AND together all of - # the necessary parts. Then for each $or part, create a new query - # that ANDs the necessary part with the $or part. - clauses = [] - for or_group in product(*or_groups): - q_object = reduce(lambda a, b: a & b, and_parts, Q()) - q_object = reduce(lambda a, b: a & b, or_group, q_object) - clauses.append(q_object) - # Finally, $or the generated clauses in to one query. Each of the - # clauses is sufficient for the query to succeed. - return reduce(lambda a, b: a | b, clauses, Q()) - - if combination.operation == combination.OR: - children = [] - # Crush any nested ORs in to this combination as MongoDB doesn't - # support nested $or operations - for node in combination.children: - if (isinstance(node, QCombination) and - node.operation == combination.OR): - children += node.children - else: - children.append(node) - combination.children = children - - return combination - - -class QueryCompilerVisitor(QNodeVisitor): - """Compiles the nodes in a query tree to a PyMongo-compatible query - dictionary. - """ - - def __init__(self, document): - self.document = document - - def visit_combination(self, combination): - if combination.operation == combination.OR: - return {'$or': combination.children} - elif combination.operation == combination.AND: - return self._mongo_query_conjunction(combination.children) - return combination - - def visit_query(self, query): - return QuerySet._transform_query(self.document, **query.query) - - def _mongo_query_conjunction(self, queries): - """Merges Mongo query dicts - effectively &ing them together. - """ - combined_query = {} - for query in queries: - for field, ops in query.items(): - if field not in combined_query: - combined_query[field] = ops - else: - # The field is already present in the query the only way - # we can merge is if both the existing value and the new - # value are operation dicts, reject anything else - if (not isinstance(combined_query[field], dict) or - not isinstance(ops, dict)): - message = 'Conflicting values for ' + field - raise InvalidQueryError(message) - - current_ops = set(combined_query[field].keys()) - new_ops = set(ops.keys()) - # Make sure that the same operation isn't applied more than - # once to a single field - intersection = current_ops.intersection(new_ops) - if intersection: - msg = 'Duplicate query conditions: ' - raise InvalidQueryError(msg + ', '.join(intersection)) - - # Right! We've got two non-overlapping dicts of operations! - combined_query[field].update(copy.deepcopy(ops)) - return combined_query - - -class QNode(object): - """Base class for nodes in query trees. - """ - - AND = 0 - OR = 1 - - def to_query(self, document): - query = self.accept(SimplificationVisitor()) - query = query.accept(QueryTreeTransformerVisitor()) - query = query.accept(QueryCompilerVisitor(document)) - return query - - def accept(self, visitor): - raise NotImplementedError - - def _combine(self, other, operation): - """Combine this node with another node into a QCombination object. - """ - if getattr(other, 'empty', True): - return self - - if self.empty: - return other - - return QCombination(operation, [self, other]) - - @property - def empty(self): - return False - - def __or__(self, other): - return self._combine(other, self.OR) - - def __and__(self, other): - return self._combine(other, self.AND) - - -class QCombination(QNode): - """Represents the combination of several conditions by a given logical - operator. - """ - - def __init__(self, operation, children): - self.operation = operation - self.children = [] - for node in children: - # If the child is a combination of the same type, we can merge its - # children directly into this combinations children - if isinstance(node, QCombination) and node.operation == operation: - self.children += node.children - else: - self.children.append(node) - - def accept(self, visitor): - for i in range(len(self.children)): - if isinstance(self.children[i], QNode): - self.children[i] = self.children[i].accept(visitor) - - return visitor.visit_combination(self) - - @property - def empty(self): - return not bool(self.children) - - -class Q(QNode): - """A simple query object, used in a query tree to build up more complex - query structures. - """ - - def __init__(self, **query): - self.query = query - - def accept(self, visitor): - return visitor.visit_query(self) - - @property - def empty(self): - return not bool(self.query) - - -class QueryFieldList(object): - """Object that handles combinations of .only() and .exclude() calls""" - ONLY = 1 - EXCLUDE = 0 - - def __init__(self, fields=[], value=ONLY, always_include=[]): - self.value = value - self.fields = set(fields) - self.always_include = set(always_include) - self._id = None - - def as_dict(self): - field_list = dict((field, self.value) for field in self.fields) - if self._id is not None: - field_list['_id'] = self._id - return field_list - - def __add__(self, f): - if not self.fields: - self.fields = f.fields - self.value = f.value - elif self.value is self.ONLY and f.value is self.ONLY: - self.fields = self.fields.intersection(f.fields) - elif self.value is self.EXCLUDE and f.value is self.EXCLUDE: - self.fields = self.fields.union(f.fields) - elif self.value is self.ONLY and f.value is self.EXCLUDE: - self.fields -= f.fields - elif self.value is self.EXCLUDE and f.value is self.ONLY: - self.value = self.ONLY - self.fields = f.fields - self.fields - - if '_id' in f.fields: - self._id = f.value - - if self.always_include: - if self.value is self.ONLY and self.fields: - self.fields = self.fields.union(self.always_include) - else: - self.fields -= self.always_include - return self - - def reset(self): - self.fields = set([]) - self.value = self.ONLY - - def __nonzero__(self): - return bool(self.fields) - - class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, providing :class:`~mongoengine.Document` objects as the results. @@ -357,7 +59,7 @@ class QuerySet(object): # If inheritance is allowed, only return instances and instances of # subclasses of the class being used if document._meta.get('allow_inheritance') != False: - self._initial_query = {'_types': self._document._class_name} + self._initial_query = {"_cls": {"$in": self._document._subclasses}} self._loaded_fields = QueryFieldList(always_include=['_cls']) self._cursor_obj = None self._limit = None @@ -397,7 +99,7 @@ class QuerySet(object): construct a multi-field index); keys may be prefixed with a **+** or a **-** to determine the index ordering """ - index_spec = QuerySet._build_index_spec(self._document, key_or_list) + index_spec = self._document._build_index_spec(key_or_list) index_spec = index_spec.copy() fields = index_spec.pop('fields') index_spec['drop_dups'] = drop_dups @@ -448,26 +150,26 @@ class QuerySet(object): background = self._document._meta.get('index_background', False) drop_dups = self._document._meta.get('index_drop_dups', False) index_opts = self._document._meta.get('index_opts') or {} - index_types = self._document._meta.get('index_types', True) + index_cls = self._document._meta.get('index_cls', True) # determine if an index which we are creating includes - # _type as its first field; if so, we can avoid creating - # an extra index on _type, as mongodb will use the existing - # index to service queries against _type - types_indexed = False + # _cls as its first field; if so, we can avoid creating + # an extra index on _cls, as mongodb will use the existing + # index to service queries against _cls + cls_indexed = False - def includes_types(fields): + def includes_cls(fields): first_field = None if len(fields): if isinstance(fields[0], basestring): first_field = fields[0] elif isinstance(fields[0], (list, tuple)) and len(fields[0]): first_field = fields[0][0] - return first_field == '_types' + return first_field == '_cls' # Ensure indexes created by uniqueness constraints for index in self._document._meta['unique_indexes']: - types_indexed = types_indexed or includes_types(index) + cls_indexed = cls_indexed or includes_cls(index) self._collection.ensure_index(index, unique=True, background=background, drop_dups=drop_dups, **index_opts) @@ -477,16 +179,16 @@ class QuerySet(object): for spec in index_spec: spec = spec.copy() fields = spec.pop('fields') - types_indexed = types_indexed or includes_types(fields) + cls_indexed = cls_indexed or includes_cls(fields) opts = index_opts.copy() opts.update(spec) self._collection.ensure_index(fields, background=background, **opts) - # If _types is being used (for polymorphism), it needs an index, - # only if another index doesn't begin with _types - if index_types and '_types' in self._query and not types_indexed: - self._collection.ensure_index('_types', + # If _cls is being used (for polymorphism), it needs an index, + # only if another index doesn't begin with _cls + if index_cls and '_cls' in self._query and not cls_indexed: + self._collection.ensure_index('_cls', background=background, **index_opts) # Add geo indicies @@ -495,79 +197,14 @@ class QuerySet(object): self._collection.ensure_index(index_spec, background=background, **index_opts) - @classmethod - def _build_index_spec(cls, doc_cls, spec): - """Build a PyMongo index spec from a MongoEngine index spec. - """ - if isinstance(spec, basestring): - spec = {'fields': [spec]} - elif isinstance(spec, (list, tuple)): - spec = {'fields': list(spec)} - elif isinstance(spec, dict): - spec = dict(spec) - - index_list = [] - direction = None - - allow_inheritance = doc_cls._meta.get('allow_inheritance') != False - - # If sparse - dont include types - use_types = allow_inheritance and not spec.get('sparse', False) - - for key in spec['fields']: - # If inherited spec continue - if isinstance(key, (list, tuple)): - continue - - # Get ASCENDING direction from +, DESCENDING from -, and GEO2D from * - direction = pymongo.ASCENDING - if key.startswith("-"): - direction = pymongo.DESCENDING - elif key.startswith("*"): - direction = pymongo.GEO2D - if key.startswith(("+", "-", "*")): - key = key[1:] - - # Use real field name, do it manually because we need field - # objects for the next part (list field checking) - parts = key.split('.') - if parts in (['pk'], ['id'], ['_id']): - key = '_id' - fields = [] - else: - fields = QuerySet._lookup_field(doc_cls, parts) - parts = [field if field == '_id' else field.db_field - for field in fields] - key = '.'.join(parts) - index_list.append((key, direction)) - - # Check if a list field is being used, don't use _types if it is - if use_types and not all(f._index_with_types for f in fields): - use_types = False - - # If _types is being used, prepend it to every specified index - index_types = doc_cls._meta.get('index_types', True) - - if (spec.get('types', index_types) and use_types - and direction is not pymongo.GEO2D): - index_list.insert(0, ('_types', 1)) - - spec['fields'] = index_list - if spec.get('sparse', False) and len(spec['fields']) > 1: - raise ValueError( - 'Sparse indexes can only have one field in them. ' - 'See https://jira.mongodb.org/browse/SERVER-2193') - - return spec - @classmethod def _reset_already_indexed(cls, document=None): - """Helper to reset already indexed, can be useful for testing purposes""" + """Helper to reset already indexed, can be useful for testing purposes + """ if document: cls.__already_indexed.discard(document) cls.__already_indexed.clear() - @property def _collection(self): """Property that returns the collection object. This allows us to @@ -624,195 +261,12 @@ class QuerySet(object): self._cursor_obj.hint(self._hint) return self._cursor_obj - @classmethod - def _lookup_field(cls, document, parts): - """Lookup a field based on its attribute and return a list containing - the field's parents and the field. - """ - if not isinstance(parts, (list, tuple)): - parts = [parts] - fields = [] - field = None - - for field_name in parts: - # Handle ListField indexing: - if field_name.isdigit(): - try: - new_field = field.field - except AttributeError, err: - raise InvalidQueryError( - "Can't use index on unsubscriptable field (%s)" % err) - fields.append(field_name) - continue - - if field is None: - # Look up first field from the document - if field_name == 'pk': - # Deal with "primary key" alias - field_name = document._meta['id_field'] - if field_name in document._fields: - field = document._fields[field_name] - elif document._dynamic: - from fields import DynamicField - field = DynamicField(db_field=field_name) - else: - raise InvalidQueryError('Cannot resolve field "%s"' - % field_name) - else: - from mongoengine.fields import ReferenceField, GenericReferenceField - if isinstance(field, (ReferenceField, GenericReferenceField)): - raise InvalidQueryError('Cannot perform join in mongoDB: %s' % '__'.join(parts)) - if hasattr(getattr(field, 'field', None), 'lookup_member'): - new_field = field.field.lookup_member(field_name) - else: - # Look up subfield on the previous field - new_field = field.lookup_member(field_name) - from base import ComplexBaseField - if not new_field and isinstance(field, ComplexBaseField): - fields.append(field_name) - continue - elif not new_field: - raise InvalidQueryError('Cannot resolve field "%s"' - % field_name) - field = new_field # update field to the new field type - fields.append(field) - return fields - - @classmethod - def _translate_field_name(cls, doc_cls, field, sep='.'): - """Translate a field attribute name to a database field name. - """ - parts = field.split(sep) - parts = [f.db_field for f in QuerySet._lookup_field(doc_cls, parts)] - return '.'.join(parts) - - @classmethod - def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): - """Transform a query from Django-style format to Mongo format. - """ - operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists', 'not'] - geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'within_polygon', 'near', 'near_sphere'] - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', - 'exact', 'iexact'] - custom_operators = ['match'] - - mongo_query = {} - merge_query = defaultdict(list) - for key, value in query.items(): - if key == "__raw__": - mongo_query.update(value) - continue - - parts = key.split('__') - indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] - parts = [part for part in parts if not part.isdigit()] - # Check for an operator and transform to mongo-style if there is - op = None - if parts[-1] in operators + match_operators + geo_operators + custom_operators: - op = parts.pop() - - negate = False - if parts[-1] == 'not': - parts.pop() - negate = True - - if _doc_cls: - # Switch field names to proper names [set in Field(name='foo')] - fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [] - - cleaned_fields = [] - for field in fields: - append_field = True - if isinstance(field, basestring): - parts.append(field) - append_field = False - else: - parts.append(field.db_field) - if append_field: - cleaned_fields.append(field) - - # Convert value to proper value - field = cleaned_fields[-1] - - singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] - singular_ops += match_operators - if op in singular_ops: - if isinstance(field, basestring): - if op in match_operators and isinstance(value, basestring): - from mongoengine import StringField - value = StringField.prepare_query_value(op, value) - else: - value = field - else: - value = field.prepare_query_value(op, value) - elif op in ('in', 'nin', 'all', 'near'): - # 'in', 'nin' and 'all' require a list of values - value = [field.prepare_query_value(op, v) for v in value] - - # if op and op not in match_operators: - if op: - if op in geo_operators: - if op == "within_distance": - value = {'$within': {'$center': value}} - elif op == "within_spherical_distance": - value = {'$within': {'$centerSphere': value}} - elif op == "within_polygon": - value = {'$within': {'$polygon': value}} - elif op == "near": - value = {'$near': value} - elif op == "near_sphere": - value = {'$nearSphere': value} - elif op == 'within_box': - value = {'$within': {'$box': value}} - else: - raise NotImplementedError("Geo method '%s' has not " - "been implemented" % op) - elif op in custom_operators: - if op == 'match': - value = {"$elemMatch": value} - else: - NotImplementedError("Custom method '%s' has not " - "been implemented" % op) - elif op not in match_operators: - value = {'$' + op: value} - - if negate: - value = {'$not': value} - - for i, part in indices: - parts.insert(i, part) - key = '.'.join(parts) - if op is None or key not in mongo_query: - mongo_query[key] = value - elif key in mongo_query: - if key in mongo_query and isinstance(mongo_query[key], dict): - mongo_query[key].update(value) - else: - # Store for manually merging later - merge_query[key].append(value) - - # The queryset has been filter in such a way we must manually merge - for k, v in merge_query.items(): - merge_query[k].append(mongo_query[k]) - del mongo_query[k] - if isinstance(v, list): - value = [{k:val} for val in v] - if '$and' in mongo_query.keys(): - mongo_query['$and'].append(value) - else: - mongo_query['$and'] = value - - return mongo_query - def get(self, *q_objs, **query): """Retrieve the the matching object raising :class:`~mongoengine.queryset.MultipleObjectsReturned` or - `DocumentName.MultipleObjectsReturned` exception if multiple results and - :class:`~mongoengine.queryset.DoesNotExist` or `DocumentName.DoesNotExist` - if no results are found. + `DocumentName.MultipleObjectsReturned` exception if multiple results + and :class:`~mongoengine.queryset.DoesNotExist` or + `DocumentName.DoesNotExist` if no results are found. .. versionadded:: 0.3 """ @@ -910,7 +364,7 @@ class QuerySet(object): .. versionadded:: 0.5 """ - from document import Document + Document = _import_class('Document') if not write_options: write_options = {} @@ -1064,7 +518,7 @@ class QuerySet(object): .. versionadded:: 0.3 """ - from document import MapReduceDocument + MapReduceDocument = _import_class('MapReduceDocument') if not hasattr(self._collection, "map_reduce"): raise NotImplementedError("Requires MongoDB >= 1.7.1") @@ -1267,14 +721,16 @@ class QuerySet(object): .. versionadded:: 0.5 """ - self._loaded_fields = QueryFieldList(always_include=self._loaded_fields.always_include) + self._loaded_fields = QueryFieldList( + always_include=self._loaded_fields.always_include) return self def _fields_to_dbfields(self, fields): """Translate fields paths to its db equivalents""" ret = [] for field in fields: - field = ".".join(f.db_field for f in QuerySet._lookup_field(self._document, field.split('.'))) + field = ".".join(f.db_field for f in + self._document._lookup_field(field.split('.'))) ret.append(field) return ret @@ -1288,7 +744,8 @@ class QuerySet(object): """ key_list = [] for key in keys: - if not key: continue + if not key: + continue direction = pymongo.ASCENDING if key[0] == '-': direction = pymongo.DESCENDING @@ -1296,7 +753,7 @@ class QuerySet(object): key = key[1:] key = key.replace('__', '.') try: - key = QuerySet._translate_field_name(self._document, key) + key = self._document._translate_field_name(key) except: pass key_list.append((key, direction)) @@ -1389,107 +846,6 @@ class QuerySet(object): self._collection.remove(self._query, safe=safe) - @classmethod - def _transform_update(cls, _doc_cls=None, **update): - """Transform an update spec from Django-style format to Mongo format. - """ - operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all', - 'pull', 'pull_all', 'add_to_set'] - match_operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists', 'not'] - - mongo_update = {} - for key, value in update.items(): - if key == "__raw__": - mongo_update.update(value) - continue - parts = key.split('__') - # Check for an operator and transform to mongo-style if there is - op = None - if parts[0] in operators: - op = parts.pop(0) - # Convert Pythonic names to Mongo equivalents - if op in ('push_all', 'pull_all'): - op = op.replace('_all', 'All') - elif op == 'dec': - # Support decrement by flipping a positive value's sign - # and using 'inc' - op = 'inc' - if value > 0: - value = -value - elif op == 'add_to_set': - op = op.replace('_to_set', 'ToSet') - - match = None - if parts[-1] in match_operators: - match = parts.pop() - - if _doc_cls: - # Switch field names to proper names [set in Field(name='foo')] - fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [] - - cleaned_fields = [] - for field in fields: - append_field = True - if isinstance(field, basestring): - # Convert the S operator to $ - if field == 'S': - field = '$' - parts.append(field) - append_field = False - else: - parts.append(field.db_field) - if append_field: - cleaned_fields.append(field) - - # Convert value to proper value - field = cleaned_fields[-1] - - if op in (None, 'set', 'push', 'pull'): - if field.required or value is not None: - value = field.prepare_query_value(op, value) - elif op in ('pushAll', 'pullAll'): - value = [field.prepare_query_value(op, v) for v in value] - elif op == 'addToSet': - if isinstance(value, (list, tuple, set)): - value = [field.prepare_query_value(op, v) for v in value] - elif field.required or value is not None: - value = field.prepare_query_value(op, value) - - if match: - match = '$' + match - value = {match: value} - - key = '.'.join(parts) - - if not op: - raise InvalidQueryError("Updates must supply an operation " - "eg: set__FIELD=value") - - if 'pull' in op and '.' in key: - # Dot operators don't work on pull operations - # it uses nested dict syntax - if op == 'pullAll': - raise InvalidQueryError("pullAll operations only support " - "a single field depth") - - parts.reverse() - for key in parts: - value = {key: value} - elif op == 'addToSet' and isinstance(value, list): - value = {key: {"$each": value}} - else: - value = {key: value} - key = '$' + op - - if key not in mongo_update: - mongo_update[key] = value - elif key in mongo_update and isinstance(mongo_update[key], dict): - mongo_update[key].update(value) - - return mongo_update - def update(self, safe_update=True, upsert=False, multi=True, write_options=None, **update): """Perform an atomic update on the fields matched by the query. When ``safe_update`` is used, the number of affected documents is returned. @@ -1506,14 +862,9 @@ class QuerySet(object): if not write_options: write_options = {} - update = QuerySet._transform_update(self._document, **update) + update = transform.update(self._document, **update) query = self._query - # SERVER-5247 hack - remove_types = "_types" in query and ".$." in unicode(update) - if remove_types: - del query["_types"] - try: ret = self._collection.update(query, update, multi=multi, upsert=upsert, safe=safe_update, @@ -1537,30 +888,8 @@ class QuerySet(object): .. versionadded:: 0.2 """ - if not update: - raise OperationError("No update parameters, would remove data") - - if not write_options: - write_options = {} - update = QuerySet._transform_update(self._document, **update) - query = self._query - - # SERVER-5247 hack - remove_types = "_types" in query and ".$." in unicode(update) - if remove_types: - del query["_types"] - - try: - # Explicitly provide 'multi=False' to newer versions of PyMongo - # as the default may change to 'True' - ret = self._collection.update(query, update, multi=False, - upsert=upsert, safe=safe_update, - **write_options) - - if ret is not None and 'n' in ret: - return ret['n'] - except pymongo.errors.OperationFailure, e: - raise OperationError(u'Update failed [%s]' % unicode(e)) + return self.update(safe_update=True, upsert=False, multi=False, + write_options=None, **update) def __iter__(self): self.rewind() @@ -1611,14 +940,14 @@ class QuerySet(object): def field_sub(match): # Extract just the field name, and look up the field objects field_name = match.group(1).split('.') - fields = QuerySet._lookup_field(self._document, field_name) + fields = self._document._lookup_field(field_name) # Substitute the correct name for the field into the javascript return u'["%s"]' % fields[-1].db_field def field_path_sub(match): # Extract just the field name, and look up the field objects field_name = match.group(1).split('.') - fields = QuerySet._lookup_field(self._document, field_name) + fields = self._document._lookup_field(field_name) # Substitute the correct name for the field into the javascript return ".".join([f.db_field for f in fields]) @@ -1650,8 +979,7 @@ class QuerySet(object): """ code = self._sub_js_fields(code) - fields = [QuerySet._translate_field_name(self._document, f) - for f in fields] + fields = [self._document._translate_field_name(f) for f in fields] collection = self._document._get_collection_name() scope = { @@ -1925,63 +1253,5 @@ class QuerySet(object): @property def _dereference(self): if not self.__dereference: - from dereference import DeReference - self.__dereference = DeReference() # Cached + self.__dereference = _import_class('DeReference')() return self.__dereference - - -class QuerySetManager(object): - """ - The default QuerySet Manager. - - Custom QuerySet Manager functions can extend this class and users can - add extra queryset functionality. Any custom manager methods must accept a - :class:`~mongoengine.Document` class as its first argument, and a - :class:`~mongoengine.queryset.QuerySet` as its second argument. - - The method function should return a :class:`~mongoengine.queryset.QuerySet` - , probably the same one that was passed in, but modified in some way. - """ - - get_queryset = None - - def __init__(self, queryset_func=None): - if queryset_func: - self.get_queryset = queryset_func - self._collections = {} - - def __get__(self, instance, owner): - """Descriptor for instantiating a new QuerySet object when - Document.objects is accessed. - """ - if instance is not None: - # Document class being used rather than a document object - return self - - # owner is the document that contains the QuerySetManager - queryset_class = owner._meta.get('queryset_class') or QuerySet - queryset = queryset_class(owner, owner._get_collection()) - if self.get_queryset: - arg_count = self.get_queryset.func_code.co_argcount - if arg_count == 1: - queryset = self.get_queryset(queryset) - elif arg_count == 2: - queryset = self.get_queryset(owner, queryset) - else: - queryset = partial(self.get_queryset, owner, queryset) - return queryset - - -def queryset_manager(func): - """Decorator that allows you to define custom QuerySet managers on - :class:`~mongoengine.Document` classes. The manager must be a function that - accepts a :class:`~mongoengine.Document` class as its first argument, and a - :class:`~mongoengine.queryset.QuerySet` as its second argument. The method - function should return a :class:`~mongoengine.queryset.QuerySet`, probably - the same one that was passed in, but modified in some way. - """ - if func.func_code.co_argcount == 1: - import warnings - msg = 'Methods decorated with queryset_manager should take 2 arguments' - warnings.warn(msg, DeprecationWarning) - return QuerySetManager(func) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py new file mode 100644 index 00000000..8ee84eed --- /dev/null +++ b/mongoengine/queryset/transform.py @@ -0,0 +1,237 @@ +from collections import defaultdict + +from mongoengine.common import _import_class +from mongoengine.errors import InvalidQueryError, LookUpError + +__all__ = ('query', 'update') + + +COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', + 'all', 'size', 'exists', 'not') +GEO_OPERATORS = ('within_distance', 'within_spherical_distance', + 'within_box', 'within_polygon', 'near', 'near_sphere') +STRING_OPERATORS = ('contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', + 'exact', 'iexact') +CUSTOM_OPERATORS = ('match',) +MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + + STRING_OPERATORS + CUSTOM_OPERATORS) + +UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', + 'push_all', 'pull', 'pull_all', 'add_to_set') + + +def query(_doc_cls=None, _field_operation=False, **query): + """Transform a query from Django-style format to Mongo format. + """ + mongo_query = {} + merge_query = defaultdict(list) + for key, value in query.items(): + if key == "__raw__": + mongo_query.update(value) + continue + + parts = key.split('__') + indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] + parts = [part for part in parts if not part.isdigit()] + # Check for an operator and transform to mongo-style if there is + op = None + if parts[-1] in MATCH_OPERATORS: + op = parts.pop() + + negate = False + if parts[-1] == 'not': + parts.pop() + negate = True + + if _doc_cls: + # Switch field names to proper names [set in Field(name='foo')] + try: + fields = _doc_cls._lookup_field(parts) + except Exception, e: + raise InvalidQueryError(e) + parts = [] + + cleaned_fields = [] + for field in fields: + append_field = True + if isinstance(field, basestring): + parts.append(field) + append_field = False + else: + parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) + + # Convert value to proper value + field = cleaned_fields[-1] + + singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] + singular_ops += STRING_OPERATORS + if op in singular_ops: + if isinstance(field, basestring): + if (op in STRING_OPERATORS and + isinstance(value, basestring)): + StringField = _import_class('StringField') + value = StringField.prepare_query_value(op, value) + else: + value = field + else: + value = field.prepare_query_value(op, value) + elif op in ('in', 'nin', 'all', 'near'): + # 'in', 'nin' and 'all' require a list of values + value = [field.prepare_query_value(op, v) for v in value] + + # if op and op not in COMPARISON_OPERATORS: + if op: + if op in GEO_OPERATORS: + if op == "within_distance": + value = {'$within': {'$center': value}} + elif op == "within_spherical_distance": + value = {'$within': {'$centerSphere': value}} + elif op == "within_polygon": + value = {'$within': {'$polygon': value}} + elif op == "near": + value = {'$near': value} + elif op == "near_sphere": + value = {'$nearSphere': value} + elif op == 'within_box': + value = {'$within': {'$box': value}} + else: + raise NotImplementedError("Geo method '%s' has not " + "been implemented" % op) + elif op in CUSTOM_OPERATORS: + if op == 'match': + value = {"$elemMatch": value} + else: + NotImplementedError("Custom method '%s' has not " + "been implemented" % op) + elif op not in STRING_OPERATORS: + value = {'$' + op: value} + + if negate: + value = {'$not': value} + + for i, part in indices: + parts.insert(i, part) + key = '.'.join(parts) + if op is None or key not in mongo_query: + mongo_query[key] = value + elif key in mongo_query: + if key in mongo_query and isinstance(mongo_query[key], dict): + mongo_query[key].update(value) + else: + # Store for manually merging later + merge_query[key].append(value) + + # The queryset has been filter in such a way we must manually merge + for k, v in merge_query.items(): + merge_query[k].append(mongo_query[k]) + del mongo_query[k] + if isinstance(v, list): + value = [{k:val} for val in v] + if '$and' in mongo_query.keys(): + mongo_query['$and'].append(value) + else: + mongo_query['$and'] = value + + return mongo_query + + +def update(_doc_cls=None, **update): + """Transform an update spec from Django-style format to Mongo format. + """ + mongo_update = {} + for key, value in update.items(): + if key == "__raw__": + mongo_update.update(value) + continue + parts = key.split('__') + # Check for an operator and transform to mongo-style if there is + op = None + if parts[0] in UPDATE_OPERATORS: + op = parts.pop(0) + # Convert Pythonic names to Mongo equivalents + if op in ('push_all', 'pull_all'): + op = op.replace('_all', 'All') + elif op == 'dec': + # Support decrement by flipping a positive value's sign + # and using 'inc' + op = 'inc' + if value > 0: + value = -value + elif op == 'add_to_set': + op = op.replace('_to_set', 'ToSet') + + match = None + if parts[-1] in COMPARISON_OPERATORS: + match = parts.pop() + + if _doc_cls: + # Switch field names to proper names [set in Field(name='foo')] + try: + fields = _doc_cls._lookup_field(parts) + except Exception, e: + raise InvalidQueryError(e) + parts = [] + + cleaned_fields = [] + for field in fields: + append_field = True + if isinstance(field, basestring): + # Convert the S operator to $ + if field == 'S': + field = '$' + parts.append(field) + append_field = False + else: + parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) + + # Convert value to proper value + field = cleaned_fields[-1] + + if op in (None, 'set', 'push', 'pull'): + if field.required or value is not None: + value = field.prepare_query_value(op, value) + elif op in ('pushAll', 'pullAll'): + value = [field.prepare_query_value(op, v) for v in value] + elif op == 'addToSet': + if isinstance(value, (list, tuple, set)): + value = [field.prepare_query_value(op, v) for v in value] + elif field.required or value is not None: + value = field.prepare_query_value(op, value) + + if match: + match = '$' + match + value = {match: value} + + key = '.'.join(parts) + + if not op: + raise InvalidQueryError("Updates must supply an operation " + "eg: set__FIELD=value") + + if 'pull' in op and '.' in key: + # Dot operators don't work on pull operations + # it uses nested dict syntax + if op == 'pullAll': + raise InvalidQueryError("pullAll operations only support " + "a single field depth") + + parts.reverse() + for key in parts: + value = {key: value} + elif op == 'addToSet' and isinstance(value, list): + value = {key: {"$each": value}} + else: + value = {key: value} + key = '$' + op + + if key not in mongo_update: + mongo_update[key] = value + elif key in mongo_update and isinstance(mongo_update[key], dict): + mongo_update[key].update(value) + + return mongo_update diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py new file mode 100644 index 00000000..94d6a5e1 --- /dev/null +++ b/mongoengine/queryset/visitor.py @@ -0,0 +1,237 @@ +import copy + +from mongoengine.errors import InvalidQueryError +from mongoengine.python_support import product, reduce + +from mongoengine.queryset import transform + +__all__ = ('Q',) + + +class QNodeVisitor(object): + """Base visitor class for visiting Q-object nodes in a query tree. + """ + + def visit_combination(self, combination): + """Called by QCombination objects. + """ + return combination + + def visit_query(self, query): + """Called by (New)Q objects. + """ + return query + + +class SimplificationVisitor(QNodeVisitor): + """Simplifies query trees by combinging unnecessary 'and' connection nodes + into a single Q-object. + """ + + def visit_combination(self, combination): + if combination.operation == combination.AND: + # The simplification only applies to 'simple' queries + if all(isinstance(node, Q) for node in combination.children): + queries = [node.query for node in combination.children] + return Q(**self._query_conjunction(queries)) + return combination + + def _query_conjunction(self, queries): + """Merges query dicts - effectively &ing them together. + """ + query_ops = set() + combined_query = {} + for query in queries: + ops = set(query.keys()) + # Make sure that the same operation isn't applied more than once + # to a single field + intersection = ops.intersection(query_ops) + if intersection: + msg = 'Duplicate query conditions: ' + raise InvalidQueryError(msg + ', '.join(intersection)) + + query_ops.update(ops) + combined_query.update(copy.deepcopy(query)) + return combined_query + + +class QueryTreeTransformerVisitor(QNodeVisitor): + """Transforms the query tree in to a form that may be used with MongoDB. + """ + + def visit_combination(self, combination): + if combination.operation == combination.AND: + # MongoDB doesn't allow us to have too many $or operations in our + # queries, so the aim is to move the ORs up the tree to one + # 'master' $or. Firstly, we must find all the necessary parts (part + # of an AND combination or just standard Q object), and store them + # separately from the OR parts. + or_groups = [] + and_parts = [] + for node in combination.children: + if isinstance(node, QCombination): + if node.operation == node.OR: + # Any of the children in an $or component may cause + # the query to succeed + or_groups.append(node.children) + elif node.operation == node.AND: + and_parts.append(node) + elif isinstance(node, Q): + and_parts.append(node) + + # Now we combine the parts into a usable query. AND together all of + # the necessary parts. Then for each $or part, create a new query + # that ANDs the necessary part with the $or part. + clauses = [] + for or_group in product(*or_groups): + q_object = reduce(lambda a, b: a & b, and_parts, Q()) + q_object = reduce(lambda a, b: a & b, or_group, q_object) + clauses.append(q_object) + # Finally, $or the generated clauses in to one query. Each of the + # clauses is sufficient for the query to succeed. + return reduce(lambda a, b: a | b, clauses, Q()) + + if combination.operation == combination.OR: + children = [] + # Crush any nested ORs in to this combination as MongoDB doesn't + # support nested $or operations + for node in combination.children: + if (isinstance(node, QCombination) and + node.operation == combination.OR): + children += node.children + else: + children.append(node) + combination.children = children + + return combination + + +class QueryCompilerVisitor(QNodeVisitor): + """Compiles the nodes in a query tree to a PyMongo-compatible query + dictionary. + """ + + def __init__(self, document): + self.document = document + + def visit_combination(self, combination): + if combination.operation == combination.OR: + return {'$or': combination.children} + elif combination.operation == combination.AND: + return self._mongo_query_conjunction(combination.children) + return combination + + def visit_query(self, query): + return transform.query(self.document, **query.query) + + def _mongo_query_conjunction(self, queries): + """Merges Mongo query dicts - effectively &ing them together. + """ + combined_query = {} + for query in queries: + for field, ops in query.items(): + if field not in combined_query: + combined_query[field] = ops + else: + # The field is already present in the query the only way + # we can merge is if both the existing value and the new + # value are operation dicts, reject anything else + if (not isinstance(combined_query[field], dict) or + not isinstance(ops, dict)): + message = 'Conflicting values for ' + field + raise InvalidQueryError(message) + + current_ops = set(combined_query[field].keys()) + new_ops = set(ops.keys()) + # Make sure that the same operation isn't applied more than + # once to a single field + intersection = current_ops.intersection(new_ops) + if intersection: + msg = 'Duplicate query conditions: ' + raise InvalidQueryError(msg + ', '.join(intersection)) + + # Right! We've got two non-overlapping dicts of operations! + combined_query[field].update(copy.deepcopy(ops)) + return combined_query + + +class QNode(object): + """Base class for nodes in query trees. + """ + + AND = 0 + OR = 1 + + def to_query(self, document): + query = self.accept(SimplificationVisitor()) + query = query.accept(QueryTreeTransformerVisitor()) + query = query.accept(QueryCompilerVisitor(document)) + return query + + def accept(self, visitor): + raise NotImplementedError + + def _combine(self, other, operation): + """Combine this node with another node into a QCombination object. + """ + if getattr(other, 'empty', True): + return self + + if self.empty: + return other + + return QCombination(operation, [self, other]) + + @property + def empty(self): + return False + + def __or__(self, other): + return self._combine(other, self.OR) + + def __and__(self, other): + return self._combine(other, self.AND) + + +class QCombination(QNode): + """Represents the combination of several conditions by a given logical + operator. + """ + + def __init__(self, operation, children): + self.operation = operation + self.children = [] + for node in children: + # If the child is a combination of the same type, we can merge its + # children directly into this combinations children + if isinstance(node, QCombination) and node.operation == operation: + self.children += node.children + else: + self.children.append(node) + + def accept(self, visitor): + for i in range(len(self.children)): + if isinstance(self.children[i], QNode): + self.children[i] = self.children[i].accept(visitor) + + return visitor.visit_combination(self) + + @property + def empty(self): + return not bool(self.children) + + +class Q(QNode): + """A simple query object, used in a query tree to build up more complex + query structures. + """ + + def __init__(self, **query): + self.query = query + + def accept(self, visitor): + return visitor.visit_query(self) + + @property + def empty(self): + return not bool(self.query) diff --git a/setup.cfg b/setup.cfg index d95a9176..3f3faa8c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,4 +8,4 @@ detailed-errors = 1 #cover-package = mongoengine py3where = build where = tests -#tests = test_bugfix.py \ No newline at end of file +#tests = document/__init__.py \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..f2a43b05 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +from .all_warnings import AllWarnings +from .document import * \ No newline at end of file diff --git a/tests/test_all_warnings.py b/tests/all_warnings/__init__.py similarity index 91% rename from tests/test_all_warnings.py rename to tests/all_warnings/__init__.py index 9b38fa61..72de8222 100644 --- a/tests/test_all_warnings.py +++ b/tests/all_warnings/__init__.py @@ -1,11 +1,19 @@ +""" +This test has been put into a module. This is because it tests warnings that +only get triggered on first hit. This way we can ensure its imported into the +top level and called first by the test suite. +""" + import unittest import warnings from mongoengine import * -from mongoengine.tests import query_counter -class TestWarnings(unittest.TestCase): +__all__ = ('AllWarnings', ) + + +class AllWarnings(unittest.TestCase): def setUp(self): conn = connect(db='mongoenginetest') diff --git a/tests/document/__init__.py b/tests/document/__init__.py new file mode 100644 index 00000000..1ef25201 --- /dev/null +++ b/tests/document/__init__.py @@ -0,0 +1,11 @@ +# TODO EXPLICT IMPORTS + +from class_methods import * +from delta import * +from dynamic import * +from indexes import * +from inheritance import * +from instance import * + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py new file mode 100644 index 00000000..8050998c --- /dev/null +++ b/tests/document/class_methods.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +from __future__ import with_statement +import unittest + +from mongoengine import * + +from mongoengine.queryset import NULLIFY +from mongoengine.connection import get_db + +__all__ = ("ClassMethodsTest", ) + + +class ClassMethodsTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + class Person(Document): + name = StringField() + age = IntField() + + non_field = True + + meta = {"allow_inheritance": True} + + self.Person = Person + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_definition(self): + """Ensure that document may be defined using fields. + """ + self.assertEqual(['age', 'name', 'id'], self.Person._fields.keys()) + self.assertEqual([IntField, StringField, ObjectIdField], + [x.__class__ for x in self.Person._fields.values()]) + + def test_get_db(self): + """Ensure that get_db returns the expected db. + """ + db = self.Person._get_db() + self.assertEqual(self.db, db) + + def test_get_collection_name(self): + """Ensure that get_collection_name returns the expected collection + name. + """ + collection_name = 'person' + self.assertEqual(collection_name, self.Person._get_collection_name()) + + def test_get_collection(self): + """Ensure that get_collection returns the expected collection. + """ + collection_name = 'person' + collection = self.Person._get_collection() + self.assertEqual(self.db[collection_name], collection) + + def test_drop_collection(self): + """Ensure that the collection may be dropped from the database. + """ + collection_name = 'person' + self.Person(name='Test').save() + self.assertTrue(collection_name in self.db.collection_names()) + + self.Person.drop_collection() + self.assertFalse(collection_name in self.db.collection_names()) + + def test_register_delete_rule(self): + """Ensure that register delete rule adds a delete rule to the document + meta. + """ + class Job(Document): + employee = ReferenceField(self.Person) + + self.assertEqual(self.Person._meta.get('delete_rules'), None) + + self.Person.register_delete_rule(Job, 'employee', NULLIFY) + self.assertEqual(self.Person._meta['delete_rules'], + {(Job, 'employee'): NULLIFY}) + + def test_collection_naming(self): + """Ensure that a collection with a specified name may be used. + """ + + class DefaultNamingTest(Document): + pass + self.assertEqual('default_naming_test', + DefaultNamingTest._get_collection_name()) + + class CustomNamingTest(Document): + meta = {'collection': 'pimp_my_collection'} + + self.assertEqual('pimp_my_collection', + CustomNamingTest._get_collection_name()) + + class DynamicNamingTest(Document): + meta = {'collection': lambda c: "DYNAMO"} + self.assertEqual('DYNAMO', DynamicNamingTest._get_collection_name()) + + # Use Abstract class to handle backwards compatibility + class BaseDocument(Document): + meta = { + 'abstract': True, + 'collection': lambda c: c.__name__.lower() + } + + class OldNamingConvention(BaseDocument): + pass + self.assertEqual('oldnamingconvention', + OldNamingConvention._get_collection_name()) + + class InheritedAbstractNamingTest(BaseDocument): + meta = {'collection': 'wibble'} + self.assertEqual('wibble', + InheritedAbstractNamingTest._get_collection_name()) + + # Mixin tests + class BaseMixin(object): + meta = { + 'collection': lambda c: c.__name__.lower() + } + + class OldMixinNamingConvention(Document, BaseMixin): + pass + self.assertEqual('oldmixinnamingconvention', + OldMixinNamingConvention._get_collection_name()) + + class BaseMixin(object): + meta = { + 'collection': lambda c: c.__name__.lower() + } + + class BaseDocument(Document, BaseMixin): + meta = {'allow_inheritance': True} + + class MyDocument(BaseDocument): + pass + + self.assertEqual('basedocument', MyDocument._get_collection_name()) + + def test_custom_collection_name_operations(self): + """Ensure that a collection with a specified name is used as expected. + """ + collection_name = 'personCollTest' + + class Person(Document): + name = StringField() + meta = {'collection': collection_name} + + Person(name="Test User").save() + self.assertTrue(collection_name in self.db.collection_names()) + + user_obj = self.db[collection_name].find_one() + self.assertEqual(user_obj['name'], "Test User") + + user_obj = Person.objects[0] + self.assertEqual(user_obj.name, "Test User") + + Person.drop_collection() + self.assertFalse(collection_name in self.db.collection_names()) + + def test_collection_name_and_primary(self): + """Ensure that a collection with a specified name may be used. + """ + + class Person(Document): + name = StringField(primary_key=True) + meta = {'collection': 'app'} + + Person(name="Test User").save() + + user_obj = Person.objects.first() + self.assertEqual(user_obj.name, "Test User") + + Person.drop_collection() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/delta.py b/tests/document/delta.py new file mode 100644 index 00000000..f8a071d6 --- /dev/null +++ b/tests/document/delta.py @@ -0,0 +1,688 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import * +from mongoengine.connection import get_db + +__all__ = ("DeltaTest",) + + +class DeltaTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + class Person(Document): + name = StringField() + age = IntField() + + non_field = True + + meta = {"allow_inheritance": True} + + self.Person = Person + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_delta(self): + self.delta(Document) + self.delta(DynamicDocument) + + def delta(self, DocClass): + + class Doc(DocClass): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEqual(doc._get_changed_fields(), ['string_field']) + self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEqual(doc._get_changed_fields(), ['int_field']) + self.assertEqual(doc._delta(), ({'int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc.dict_field = {} + self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) + + doc._changed_fields = [] + doc.list_field = [] + self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + + def test_delta_recursive(self): + self.delta_recursive(Document, EmbeddedDocument) + self.delta_recursive(DynamicDocument, EmbeddedDocument) + self.delta_recursive(Document, DynamicEmbeddedDocument) + self.delta_recursive(DynamicDocument, DynamicEmbeddedDocument) + + def delta_recursive(self, DocClass, EmbeddedClass): + + class Embedded(EmbeddedClass): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + + class Doc(DocClass): + string_field = StringField() + int_field = IntField() + dict_field = DictField() + list_field = ListField() + embedded_field = EmbeddedDocumentField(Embedded) + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + self.assertEqual(doc._get_changed_fields(), ['embedded_field']) + + embedded_delta = { + 'string_field': 'hello', + 'int_field': 1, + 'dict_field': {'hello': 'world'}, + 'list_field': ['1', 2, {'hello': 'world'}] + } + self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) + embedded_delta.update({ + '_cls': 'Embedded', + }) + self.assertEqual(doc._delta(), + ({'embedded_field': embedded_delta}, {})) + + doc.save() + doc = doc.reload(10) + + doc.embedded_field.dict_field = {} + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.dict_field']) + self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) + self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.dict_field, {}) + + doc.embedded_field.list_field = [] + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.list_field']) + self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) + self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field, []) + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = ['1', 2, {'hello': 'world'}] + + doc.embedded_field.list_field = ['1', 2, embedded_2] + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.list_field']) + self.assertEqual(doc.embedded_field._delta(), ({ + 'list_field': ['1', 2, { + '_cls': 'Embedded', + 'string_field': 'hello', + 'dict_field': {'hello': 'world'}, + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + + self.assertEqual(doc._delta(), ({ + 'embedded_field.list_field': ['1', 2, { + '_cls': 'Embedded', + 'string_field': 'hello', + 'dict_field': {'hello': 'world'}, + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + doc.save() + doc = doc.reload(10) + + self.assertEqual(doc.embedded_field.list_field[0], '1') + self.assertEqual(doc.embedded_field.list_field[1], 2) + for k in doc.embedded_field.list_field[2]._fields: + self.assertEqual(doc.embedded_field.list_field[2][k], + embedded_2[k]) + + doc.embedded_field.list_field[2].string_field = 'world' + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.list_field.2.string_field']) + self.assertEqual(doc.embedded_field._delta(), + ({'list_field.2.string_field': 'world'}, {})) + self.assertEqual(doc._delta(), + ({'embedded_field.list_field.2.string_field': 'world'}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].string_field, + 'world') + + # Test multiple assignments + doc.embedded_field.list_field[2].string_field = 'hello world' + doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] + self.assertEqual(doc._get_changed_fields(), + ['embedded_field.list_field']) + self.assertEqual(doc.embedded_field._delta(), ({ + 'list_field': ['1', 2, { + '_cls': 'Embedded', + 'string_field': 'hello world', + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + 'dict_field': {'hello': 'world'}}]}, {})) + self.assertEqual(doc._delta(), ({ + 'embedded_field.list_field': ['1', 2, { + '_cls': 'Embedded', + 'string_field': 'hello world', + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + 'dict_field': {'hello': 'world'}} + ]}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].string_field, + 'hello world') + + # Test list native methods + doc.embedded_field.list_field[2].list_field.pop(0) + self.assertEqual(doc._delta(), + ({'embedded_field.list_field.2.list_field': + [2, {'hello': 'world'}]}, {})) + doc.save() + doc = doc.reload(10) + + doc.embedded_field.list_field[2].list_field.append(1) + self.assertEqual(doc._delta(), + ({'embedded_field.list_field.2.list_field': + [2, {'hello': 'world'}, 1]}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].list_field, + [2, {'hello': 'world'}, 1]) + + doc.embedded_field.list_field[2].list_field.sort(key=str) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].list_field, + [1, 2, {'hello': 'world'}]) + + del(doc.embedded_field.list_field[2].list_field[2]['hello']) + self.assertEqual(doc._delta(), + ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) + doc.save() + doc = doc.reload(10) + + del(doc.embedded_field.list_field[2].list_field) + self.assertEqual(doc._delta(), + ({}, {'embedded_field.list_field.2.list_field': 1})) + + doc.save() + doc = doc.reload(10) + + doc.dict_field['Embedded'] = embedded_1 + doc.save() + doc = doc.reload(10) + + doc.dict_field['Embedded'].string_field = 'Hello World' + self.assertEqual(doc._get_changed_fields(), + ['dict_field.Embedded.string_field']) + self.assertEqual(doc._delta(), + ({'dict_field.Embedded.string_field': 'Hello World'}, {})) + + def test_circular_reference_deltas(self): + self.circular_reference_deltas(Document, Document) + self.circular_reference_deltas(Document, DynamicDocument) + self.circular_reference_deltas(DynamicDocument, Document) + self.circular_reference_deltas(DynamicDocument, DynamicDocument) + + def circular_reference_deltas(self, DocClass1, DocClass2): + + class Person(DocClass1): + name = StringField() + owns = ListField(ReferenceField('Organization')) + + class Organization(DocClass2): + name = StringField() + owner = ReferenceField('Person') + + person = Person(name="owner") + person.save() + organization = Organization(name="company") + organization.save() + + person.owns.append(organization) + organization.owner = person + + person.save() + organization.save() + + p = Person.objects[0].select_related() + o = Organization.objects.first() + self.assertEqual(p.owns[0], o) + self.assertEqual(o.owner, p) + + def test_circular_reference_deltas_2(self): + self.circular_reference_deltas_2(Document, Document) + self.circular_reference_deltas_2(Document, DynamicDocument) + self.circular_reference_deltas_2(DynamicDocument, Document) + self.circular_reference_deltas_2(DynamicDocument, DynamicDocument) + + def circular_reference_deltas_2(self, DocClass1, DocClass2): + + class Person(DocClass1): + name = StringField() + owns = ListField(ReferenceField('Organization')) + employer = ReferenceField('Organization') + + class Organization(DocClass2): + name = StringField() + owner = ReferenceField('Person') + employees = ListField(ReferenceField('Person')) + + Person.drop_collection() + Organization.drop_collection() + + person = Person(name="owner") + person.save() + + employee = Person(name="employee") + employee.save() + + organization = Organization(name="company") + organization.save() + + person.owns.append(organization) + organization.owner = person + + organization.employees.append(employee) + employee.employer = organization + + person.save() + organization.save() + employee.save() + + p = Person.objects.get(name="owner") + e = Person.objects.get(name="employee") + o = Organization.objects.first() + + self.assertEqual(p.owns[0], o) + self.assertEqual(o.owner, p) + self.assertEqual(e.employer, o) + + def test_delta_db_field(self): + self.delta_db_field(Document) + self.delta_db_field(DynamicDocument) + + def delta_db_field(self, DocClass): + + class Doc(DocClass): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEqual(doc._get_changed_fields(), ['db_string_field']) + self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEqual(doc._get_changed_fields(), ['db_int_field']) + self.assertEqual(doc._delta(), ({'db_int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) + self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEqual(doc._get_changed_fields(), ['db_list_field']) + self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc.dict_field = {} + self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) + self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1})) + + doc._changed_fields = [] + doc.list_field = [] + self.assertEqual(doc._get_changed_fields(), ['db_list_field']) + self.assertEqual(doc._delta(), ({}, {'db_list_field': 1})) + + # Test it saves that data + doc = Doc() + doc.save() + + doc.string_field = 'hello' + doc.int_field = 1 + doc.dict_field = {'hello': 'world'} + doc.list_field = ['1', 2, {'hello': 'world'}] + doc.save() + doc = doc.reload(10) + + self.assertEqual(doc.string_field, 'hello') + self.assertEqual(doc.int_field, 1) + self.assertEqual(doc.dict_field, {'hello': 'world'}) + self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}]) + + def test_delta_recursive_db_field(self): + self.delta_recursive_db_field(Document, EmbeddedDocument) + self.delta_recursive_db_field(Document, DynamicEmbeddedDocument) + self.delta_recursive_db_field(DynamicDocument, EmbeddedDocument) + self.delta_recursive_db_field(DynamicDocument, DynamicEmbeddedDocument) + + def delta_recursive_db_field(self, DocClass, EmbeddedClass): + + class Embedded(EmbeddedClass): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + + class Doc(DocClass): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + embedded_field = EmbeddedDocumentField(Embedded, + db_field='db_embedded_field') + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + self.assertEqual(doc._get_changed_fields(), ['db_embedded_field']) + + embedded_delta = { + 'db_string_field': 'hello', + 'db_int_field': 1, + 'db_dict_field': {'hello': 'world'}, + 'db_list_field': ['1', 2, {'hello': 'world'}] + } + self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) + embedded_delta.update({ + '_cls': 'Embedded', + }) + self.assertEqual(doc._delta(), + ({'db_embedded_field': embedded_delta}, {})) + + doc.save() + doc = doc.reload(10) + + doc.embedded_field.dict_field = {} + self.assertEqual(doc._get_changed_fields(), + ['db_embedded_field.db_dict_field']) + self.assertEqual(doc.embedded_field._delta(), + ({}, {'db_dict_field': 1})) + self.assertEqual(doc._delta(), + ({}, {'db_embedded_field.db_dict_field': 1})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.dict_field, {}) + + doc.embedded_field.list_field = [] + self.assertEqual(doc._get_changed_fields(), + ['db_embedded_field.db_list_field']) + self.assertEqual(doc.embedded_field._delta(), + ({}, {'db_list_field': 1})) + self.assertEqual(doc._delta(), + ({}, {'db_embedded_field.db_list_field': 1})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field, []) + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = ['1', 2, {'hello': 'world'}] + + doc.embedded_field.list_field = ['1', 2, embedded_2] + self.assertEqual(doc._get_changed_fields(), + ['db_embedded_field.db_list_field']) + self.assertEqual(doc.embedded_field._delta(), ({ + 'db_list_field': ['1', 2, { + '_cls': 'Embedded', + 'db_string_field': 'hello', + 'db_dict_field': {'hello': 'world'}, + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + + self.assertEqual(doc._delta(), ({ + 'db_embedded_field.db_list_field': ['1', 2, { + '_cls': 'Embedded', + 'db_string_field': 'hello', + 'db_dict_field': {'hello': 'world'}, + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + doc.save() + doc = doc.reload(10) + + self.assertEqual(doc.embedded_field.list_field[0], '1') + self.assertEqual(doc.embedded_field.list_field[1], 2) + for k in doc.embedded_field.list_field[2]._fields: + self.assertEqual(doc.embedded_field.list_field[2][k], + embedded_2[k]) + + doc.embedded_field.list_field[2].string_field = 'world' + self.assertEqual(doc._get_changed_fields(), + ['db_embedded_field.db_list_field.2.db_string_field']) + self.assertEqual(doc.embedded_field._delta(), + ({'db_list_field.2.db_string_field': 'world'}, {})) + self.assertEqual(doc._delta(), + ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, + {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].string_field, + 'world') + + # Test multiple assignments + doc.embedded_field.list_field[2].string_field = 'hello world' + doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] + self.assertEqual(doc._get_changed_fields(), + ['db_embedded_field.db_list_field']) + self.assertEqual(doc.embedded_field._delta(), ({ + 'db_list_field': ['1', 2, { + '_cls': 'Embedded', + 'db_string_field': 'hello world', + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + 'db_dict_field': {'hello': 'world'}}]}, {})) + self.assertEqual(doc._delta(), ({ + 'db_embedded_field.db_list_field': ['1', 2, { + '_cls': 'Embedded', + 'db_string_field': 'hello world', + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + 'db_dict_field': {'hello': 'world'}} + ]}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].string_field, + 'hello world') + + # Test list native methods + doc.embedded_field.list_field[2].list_field.pop(0) + self.assertEqual(doc._delta(), + ({'db_embedded_field.db_list_field.2.db_list_field': + [2, {'hello': 'world'}]}, {})) + doc.save() + doc = doc.reload(10) + + doc.embedded_field.list_field[2].list_field.append(1) + self.assertEqual(doc._delta(), + ({'db_embedded_field.db_list_field.2.db_list_field': + [2, {'hello': 'world'}, 1]}, {})) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].list_field, + [2, {'hello': 'world'}, 1]) + + doc.embedded_field.list_field[2].list_field.sort(key=str) + doc.save() + doc = doc.reload(10) + self.assertEqual(doc.embedded_field.list_field[2].list_field, + [1, 2, {'hello': 'world'}]) + + del(doc.embedded_field.list_field[2].list_field[2]['hello']) + self.assertEqual(doc._delta(), + ({'db_embedded_field.db_list_field.2.db_list_field': + [1, 2, {}]}, {})) + doc.save() + doc = doc.reload(10) + + del(doc.embedded_field.list_field[2].list_field) + self.assertEqual(doc._delta(), ({}, + {'db_embedded_field.db_list_field.2.db_list_field': 1})) + + def test_delta_for_dynamic_documents(self): + class Person(DynamicDocument): + name = StringField() + meta = {'allow_inheritance': True} + + Person.drop_collection() + + p = Person(name="James", age=34) + self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', + '_cls': 'Person'}, {})) + + p.doc = 123 + del(p.doc) + self.assertEqual(p._delta(), ({'age': 34, 'name': 'James', + '_cls': 'Person'}, {'doc': 1})) + + p = Person() + p.name = "Dean" + p.age = 22 + p.save() + + p.age = 24 + self.assertEqual(p.age, 24) + self.assertEqual(p._get_changed_fields(), ['age']) + self.assertEqual(p._delta(), ({'age': 24}, {})) + + p = self.Person.objects(age=22).get() + p.age = 24 + self.assertEqual(p.age, 24) + self.assertEqual(p._get_changed_fields(), ['age']) + self.assertEqual(p._delta(), ({'age': 24}, {})) + + p.save() + self.assertEqual(1, self.Person.objects(age=24).count()) + + def test_dynamic_delta(self): + + class Doc(DynamicDocument): + pass + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEqual(doc._get_changed_fields(), ['string_field']) + self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEqual(doc._get_changed_fields(), ['int_field']) + self.assertEqual(doc._delta(), ({'int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc.dict_field = {} + self.assertEqual(doc._get_changed_fields(), ['dict_field']) + self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) + + doc._changed_fields = [] + doc.list_field = [] + self.assertEqual(doc._get_changed_fields(), ['list_field']) + self.assertEqual(doc._delta(), ({}, {'list_field': 1})) diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py new file mode 100644 index 00000000..ef279179 --- /dev/null +++ b/tests/document/dynamic.py @@ -0,0 +1,270 @@ +import unittest + +from mongoengine import * +from mongoengine.connection import get_db + +__all__ = ("DynamicTest", ) + + +class DynamicTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + class Person(DynamicDocument): + name = StringField() + meta = {'allow_inheritance': True} + + Person.drop_collection() + + self.Person = Person + + def test_simple_dynamic_document(self): + """Ensures simple dynamic documents are saved correctly""" + + p = self.Person() + p.name = "James" + p.age = 34 + + self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James", + "age": 34}) + + p.save() + + self.assertEqual(self.Person.objects.first().age, 34) + + # Confirm no changes to self.Person + self.assertFalse(hasattr(self.Person, 'age')) + + def test_change_scope_of_variable(self): + """Test changing the scope of a dynamic field has no adverse effects""" + p = self.Person() + p.name = "Dean" + p.misc = 22 + p.save() + + p = self.Person.objects.get() + p.misc = {'hello': 'world'} + p.save() + + p = self.Person.objects.get() + self.assertEqual(p.misc, {'hello': 'world'}) + + def test_delete_dynamic_field(self): + """Test deleting a dynamic field works""" + self.Person.drop_collection() + p = self.Person() + p.name = "Dean" + p.misc = 22 + p.save() + + p = self.Person.objects.get() + p.misc = {'hello': 'world'} + p.save() + + p = self.Person.objects.get() + self.assertEqual(p.misc, {'hello': 'world'}) + collection = self.db[self.Person._get_collection_name()] + obj = collection.find_one() + self.assertEqual(sorted(obj.keys()), ['_cls', '_id', 'misc', 'name']) + + del(p.misc) + p.save() + + p = self.Person.objects.get() + self.assertFalse(hasattr(p, 'misc')) + + obj = collection.find_one() + self.assertEqual(sorted(obj.keys()), ['_cls', '_id', 'name']) + + def test_dynamic_document_queries(self): + """Ensure we can query dynamic fields""" + p = self.Person() + p.name = "Dean" + p.age = 22 + p.save() + + self.assertEqual(1, self.Person.objects(age=22).count()) + p = self.Person.objects(age=22) + p = p.get() + self.assertEqual(22, p.age) + + def test_complex_dynamic_document_queries(self): + class Person(DynamicDocument): + name = StringField() + + Person.drop_collection() + + p = Person(name="test") + p.age = "ten" + p.save() + + p1 = Person(name="test1") + p1.age = "less then ten and a half" + p1.save() + + p2 = Person(name="test2") + p2.age = 10 + p2.save() + + self.assertEqual(Person.objects(age__icontains='ten').count(), 2) + self.assertEqual(Person.objects(age__gte=10).count(), 1) + + def test_complex_data_lookups(self): + """Ensure you can query dynamic document dynamic fields""" + p = self.Person() + p.misc = {'hello': 'world'} + p.save() + + self.assertEqual(1, self.Person.objects(misc__hello='world').count()) + + def test_inheritance(self): + """Ensure that dynamic document plays nice with inheritance""" + class Employee(self.Person): + salary = IntField() + + Employee.drop_collection() + + self.assertTrue('name' in Employee._fields) + self.assertTrue('salary' in Employee._fields) + self.assertEqual(Employee._get_collection_name(), + self.Person._get_collection_name()) + + joe_bloggs = Employee() + joe_bloggs.name = "Joe Bloggs" + joe_bloggs.salary = 10 + joe_bloggs.age = 20 + joe_bloggs.save() + + self.assertEqual(1, self.Person.objects(age=20).count()) + self.assertEqual(1, Employee.objects(age=20).count()) + + joe_bloggs = self.Person.objects.first() + self.assertTrue(isinstance(joe_bloggs, Employee)) + + def test_embedded_dynamic_document(self): + """Test dynamic embedded documents""" + class Embedded(DynamicEmbeddedDocument): + pass + + class Doc(DynamicDocument): + pass + + Doc.drop_collection() + doc = Doc() + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + self.assertEqual(doc.to_mongo(), {"_cls": "Doc", + "embedded_field": { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ['1', 2, {'hello': 'world'}] + } + }) + doc.save() + + doc = Doc.objects.first() + self.assertEqual(doc.embedded_field.__class__, Embedded) + self.assertEqual(doc.embedded_field.string_field, "hello") + self.assertEqual(doc.embedded_field.int_field, 1) + self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) + self.assertEqual(doc.embedded_field.list_field, + ['1', 2, {'hello': 'world'}]) + + def test_complex_embedded_documents(self): + """Test complex dynamic embedded documents setups""" + class Embedded(DynamicEmbeddedDocument): + pass + + class Doc(DynamicDocument): + pass + + Doc.drop_collection() + doc = Doc() + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = ['1', 2, {'hello': 'world'}] + + embedded_1.list_field = ['1', 2, embedded_2] + doc.embedded_field = embedded_1 + + self.assertEqual(doc.to_mongo(), {"_cls": "Doc", + "embedded_field": { + "_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ['1', 2, + {"_cls": "Embedded", + "string_field": "hello", + "int_field": 1, + "dict_field": {"hello": "world"}, + "list_field": ['1', 2, {'hello': 'world'}]} + ] + } + }) + doc.save() + doc = Doc.objects.first() + self.assertEqual(doc.embedded_field.__class__, Embedded) + self.assertEqual(doc.embedded_field.string_field, "hello") + self.assertEqual(doc.embedded_field.int_field, 1) + self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) + self.assertEqual(doc.embedded_field.list_field[0], '1') + self.assertEqual(doc.embedded_field.list_field[1], 2) + + embedded_field = doc.embedded_field.list_field[2] + + self.assertEqual(embedded_field.__class__, Embedded) + self.assertEqual(embedded_field.string_field, "hello") + self.assertEqual(embedded_field.int_field, 1) + self.assertEqual(embedded_field.dict_field, {'hello': 'world'}) + self.assertEqual(embedded_field.list_field, ['1', 2, + {'hello': 'world'}]) + + def test_dynamic_and_embedded(self): + """Ensure embedded documents play nicely""" + + class Address(EmbeddedDocument): + city = StringField() + + class Person(DynamicDocument): + name = StringField() + meta = {'allow_inheritance': True} + + Person.drop_collection() + + Person(name="Ross", address=Address(city="London")).save() + + person = Person.objects.first() + person.address.city = "Lundenne" + person.save() + + self.assertEqual(Person.objects.first().address.city, "Lundenne") + + person = Person.objects.first() + person.address = Address(city="Londinium") + person.save() + + self.assertEqual(Person.objects.first().address.city, "Londinium") + + person = Person.objects.first() + person.age = 35 + person.save() + self.assertEqual(Person.objects.first().age, 35) diff --git a/tests/document/indexes.py b/tests/document/indexes.py new file mode 100644 index 00000000..a6b74cd0 --- /dev/null +++ b/tests/document/indexes.py @@ -0,0 +1,637 @@ +# -*- coding: utf-8 -*- +from __future__ import with_statement +import bson +import os +import pickle +import pymongo +import sys +import unittest +import uuid +import warnings + +from nose.plugins.skip import SkipTest +from datetime import datetime + +from tests.fixtures import Base, Mixin, PickleEmbedded, PickleTest + +from mongoengine import * +from mongoengine.errors import (NotRegistered, InvalidDocumentError, + InvalidQueryError) +from mongoengine.queryset import NULLIFY, Q +from mongoengine.connection import get_db, get_connection + +TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') + +__all__ = ("InstanceTest", ) + + +class InstanceTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + class Person(Document): + name = StringField() + age = IntField() + + non_field = True + + meta = {"allow_inheritance": True} + + self.Person = Person + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_indexes_document(self, ): + """Ensure that indexes are used when meta[indexes] is specified for + Documents + """ + index_test(Document) + + def test_indexes_dynamic_document(self, ): + """Ensure that indexes are used when meta[indexes] is specified for + Dynamic Documents + """ + index_test(DynamicDocument) + + def index_test(self, InheritFrom): + + class BlogPost(InheritFrom): + date = DateTimeField(db_field='addDate', default=datetime.now) + category = StringField() + tags = ListField(StringField()) + meta = { + 'indexes': [ + '-date', + 'tags', + ('category', '-date') + ], + 'allow_inheritance': True + } + + expected_specs = [{'fields': [('_cls', 1), ('addDate', -1)]}, + {'fields': [('_cls', 1), ('tags', 1)]}, + {'fields': [('_cls', 1), ('category', 1), + ('addDate', -1)]}] + self.assertEqual(expected_specs, BlogPost._meta['index_specs']) + + BlogPost.objects._ensure_indexes() + info = BlogPost.objects._collection.index_information() + # _id, '-date', 'tags', ('cat', 'date') + # NB: there is no index on _cls by itself, since + # the indices on -date and tags will both contain + # _cls as first element in the key + self.assertEqual(len(info), 4) + info = [value['key'] for key, value in info.iteritems()] + for expected in expected_specs: + self.assertTrue(expected['fields'] in info) + + class ExtendedBlogPost(BlogPost): + title = StringField() + meta = {'indexes': ['title']} + + expected_specs.append({'fields': [('_cls', 1), ('title', 1)]}) + self.assertEqual(expected_specs, ExtendedBlogPost._meta['index_specs']) + + BlogPost.drop_collection() + + ExtendedBlogPost.objects._ensure_indexes() + info = ExtendedBlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + for expected in expected_specs: + self.assertTrue(expected['fields'] in info) + + def test_inherited_index(self): + """Ensure index specs are inhertited correctly""" + + class A(Document): + title = StringField() + meta = { + 'indexes': [ + { + 'fields': ('title',), + }, + ], + 'allow_inheritance': True, + } + + class B(A): + description = StringField() + + self.assertEqual(A._meta['index_specs'], B._meta['index_specs']) + self.assertEqual([{'fields': [('_cls', 1), ('title', 1)]}], + A._meta['index_specs']) + + def test_build_index_spec_is_not_destructive(self): + + class MyDoc(Document): + keywords = StringField() + + meta = { + 'indexes': ['keywords'], + 'allow_inheritance': False + } + + self.assertEqual(MyDoc._meta['index_specs'], + [{'fields': [('keywords', 1)]}]) + + # Force index creation + MyDoc.objects._ensure_indexes() + + self.assertEqual(MyDoc._meta['index_specs'], + [{'fields': [('keywords', 1)]}]) + + def test_embedded_document_index_meta(self): + """Ensure that embedded document indexes are created explicitly + """ + class Rank(EmbeddedDocument): + title = StringField(required=True) + + class Person(Document): + name = StringField(required=True) + rank = EmbeddedDocumentField(Rank, required=False) + + meta = { + 'indexes': [ + 'rank.title', + ], + 'allow_inheritance': False + } + + self.assertEqual([{'fields': [('rank.title', 1)]}], + Person._meta['index_specs']) + + Person.drop_collection() + + # Indexes are lazy so use list() to perform query + list(Person.objects) + info = Person.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('rank.title', 1)] in info) + + def test_explicit_geo2d_index(self): + """Ensure that geo2d indexes work when created via meta[indexes] + """ + class Place(Document): + location = DictField() + meta = { + 'allow_inheritance': True, + 'indexes': [ + '*location.point', + ] + } + + self.assertEqual([{'fields': [('location.point', '2d')]}], + Place._meta['index_specs']) + + Place.objects()._ensure_indexes() + info = Place._get_collection().index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('location.point', '2d')] in info) + + def test_dictionary_indexes(self): + """Ensure that indexes are used when meta[indexes] contains + dictionaries instead of lists. + """ + class BlogPost(Document): + date = DateTimeField(db_field='addDate', default=datetime.now) + category = StringField() + tags = ListField(StringField()) + meta = { + 'indexes': [ + {'fields': ['-date'], 'unique': True, + 'sparse': True, 'types': False}, + ], + } + + self.assertEqual([{'fields': [('addDate', -1)], 'unique': True, + 'sparse': True, 'types': False}], + BlogPost._meta['index_specs']) + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + # _id, '-date' + self.assertEqual(len(info), 3) + + # Indexes are lazy so use list() to perform query + list(BlogPost.objects) + info = BlogPost.objects._collection.index_information() + info = [(value['key'], + value.get('unique', False), + value.get('sparse', False)) + for key, value in info.iteritems()] + self.assertTrue(([('addDate', -1)], True, True) in info) + + BlogPost.drop_collection() + + def test_abstract_index_inheritance(self): + + class UserBase(Document): + user_guid = StringField(required=True) + meta = { + 'abstract': True, + 'indexes': ['user_guid'], + 'allow_inheritance': True + } + + class Person(UserBase): + name = StringField() + + meta = { + 'indexes': ['name'], + } + + Person(name="test", user_guid='123').save() + + self.assertEqual(1, Person.objects.count()) + info = Person.objects._collection.index_information() + self.assertEqual(info.keys(), ['_cls_1_name_1', '_cls_1_user_guid_1', + '_id_']) + + def test_disable_index_creation(self): + """Tests setting auto_create_index to False on the connection will + disable any index generation. + """ + class User(Document): + meta = { + 'indexes': ['user_guid'], + 'auto_create_index': False + } + user_guid = StringField(required=True) + + + User.drop_collection() + + u = User(user_guid='123') + u.save() + + self.assertEqual(1, User.objects.count()) + info = User.objects._collection.index_information() + self.assertEqual(info.keys(), ['_id_']) + User.drop_collection() + + def test_embedded_document_index(self): + """Tests settings an index on an embedded document + """ + class Date(EmbeddedDocument): + year = IntField(db_field='yr') + + class BlogPost(Document): + title = StringField() + date = EmbeddedDocumentField(Date) + + meta = { + 'indexes': [ + '-date.year' + ], + } + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + self.assertEqual(info.keys(), ['_cls_1_date.yr_-1', '_id_']) + BlogPost.drop_collection() + + def test_list_embedded_document_index(self): + """Ensure list embedded documents can be indexed + """ + class Tag(EmbeddedDocument): + name = StringField(db_field='tag') + + class BlogPost(Document): + title = StringField() + tags = ListField(EmbeddedDocumentField(Tag)) + + meta = { + 'indexes': [ + 'tags.name' + ] + } + + BlogPost.drop_collection() + + info = BlogPost.objects._collection.index_information() + # we don't use _cls in with list fields by default + self.assertEqual(info.keys(), ['_id_', '_cls_1_tags.tag_1']) + + post1 = BlogPost(title="Embedded Indexes tests in place", + tags=[Tag(name="about"), Tag(name="time")] + ) + post1.save() + BlogPost.drop_collection() + + def test_recursive_embedded_objects_dont_break_indexes(self): + + class RecursiveObject(EmbeddedDocument): + obj = EmbeddedDocumentField('self') + + class RecursiveDocument(Document): + recursive_obj = EmbeddedDocumentField(RecursiveObject) + meta = {'allow_inheritance': True} + + RecursiveDocument.objects._ensure_indexes() + info = RecursiveDocument._get_collection().index_information() + self.assertEqual(info.keys(), ['_id_', '_cls_1']) + + def test_geo_indexes_recursion(self): + + class Location(Document): + name = StringField() + location = GeoPointField() + + class Parent(Document): + name = StringField() + location = ReferenceField(Location) + + Location.drop_collection() + Parent.drop_collection() + + list(Parent.objects) + + collection = Parent._get_collection() + info = collection.index_information() + + self.assertFalse('location_2d' in info) + + self.assertEqual(len(Parent._geo_indices()), 0) + self.assertEqual(len(Location._geo_indices()), 1) + + def test_covered_index(self): + """Ensure that covered indexes can be used + """ + + class Test(Document): + a = IntField() + + meta = { + 'indexes': ['a'], + 'allow_inheritance': False + } + + Test.drop_collection() + + obj = Test(a=1) + obj.save() + + # Need to be explicit about covered indexes as mongoDB doesn't know if + # the documents returned might have more keys in that here. + query_plan = Test.objects(id=obj.id).exclude('a').explain() + self.assertFalse(query_plan['indexOnly']) + + query_plan = Test.objects(id=obj.id).only('id').explain() + self.assertTrue(query_plan['indexOnly']) + + query_plan = Test.objects(a=1).only('a').exclude('id').explain() + self.assertTrue(query_plan['indexOnly']) + + def test_index_on_id(self): + + class BlogPost(Document): + meta = { + 'indexes': [ + ['categories', 'id'] + ], + 'allow_inheritance': False + } + + title = StringField(required=True) + description = StringField(required=True) + categories = ListField() + + BlogPost.drop_collection() + + indexes = BlogPost.objects._collection.index_information() + self.assertEqual(indexes['categories_1__id_1']['key'], + [('categories', 1), ('_id', 1)]) + + def test_hint(self): + + class BlogPost(Document): + tags = ListField(StringField()) + meta = { + 'indexes': [ + 'tags', + ], + } + + BlogPost.drop_collection() + + for i in xrange(0, 10): + tags = [("tag %i" % n) for n in xrange(0, i % 2)] + BlogPost(tags=tags).save() + + self.assertEqual(BlogPost.objects.count(), 10) + self.assertEqual(BlogPost.objects.hint().count(), 10) + self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) + + self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).count(), 10) + + def invalid_index(): + BlogPost.objects.hint('tags') + self.assertRaises(TypeError, invalid_index) + + def invalid_index_2(): + return BlogPost.objects.hint(('tags', 1)) + self.assertRaises(TypeError, invalid_index_2) + + def test_unique(self): + """Ensure that uniqueness constraints are applied to fields. + """ + class BlogPost(Document): + title = StringField() + slug = StringField(unique=True) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', slug='test') + post1.save() + + # Two posts with the same slug is not allowed + post2 = BlogPost(title='test2', slug='test') + self.assertRaises(NotUniqueError, post2.save) + + # Ensure backwards compatibilty for errors + self.assertRaises(OperationError, post2.save) + + def test_unique_with(self): + """Ensure that unique_with constraints are applied to fields. + """ + class Date(EmbeddedDocument): + year = IntField(db_field='yr') + + class BlogPost(Document): + title = StringField() + date = EmbeddedDocumentField(Date) + slug = StringField(unique_with='date.year') + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', date=Date(year=2009), slug='test') + post1.save() + + # day is different so won't raise exception + post2 = BlogPost(title='test2', date=Date(year=2010), slug='test') + post2.save() + + # Now there will be two docs with the same slug and the same day: fail + post3 = BlogPost(title='test3', date=Date(year=2010), slug='test') + self.assertRaises(OperationError, post3.save) + + BlogPost.drop_collection() + + def test_unique_embedded_document(self): + """Ensure that uniqueness constraints are applied to fields on embedded documents. + """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField() + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) + self.assertRaises(NotUniqueError, post3.save) + + BlogPost.drop_collection() + + def test_unique_with_embedded_document_and_embedded_unique(self): + """Ensure that uniqueness constraints are applied to fields on + embedded documents. And work with unique_with as well. + """ + class SubDocument(EmbeddedDocument): + year = IntField(db_field='yr') + slug = StringField(unique=True) + + class BlogPost(Document): + title = StringField(unique_with='sub.year') + sub = EmbeddedDocumentField(SubDocument) + + BlogPost.drop_collection() + + post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) + post1.save() + + # sub.slug is different so won't raise exception + post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) + post2.save() + + # Now there will be two docs with the same sub.slug + post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) + self.assertRaises(NotUniqueError, post3.save) + + # Now there will be two docs with the same title and year + post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1')) + self.assertRaises(NotUniqueError, post3.save) + + BlogPost.drop_collection() + + def test_ttl_indexes(self): + + class Log(Document): + created = DateTimeField(default=datetime.now) + meta = { + 'indexes': [ + {'fields': ['created'], 'expireAfterSeconds': 3600} + ] + } + + Log.drop_collection() + + if pymongo.version_tuple[0] < 2 and pymongo.version_tuple[1] < 3: + raise SkipTest('pymongo needs to be 2.3 or higher for this test') + + connection = get_connection() + version_array = connection.server_info()['versionArray'] + if version_array[0] < 2 and version_array[1] < 2: + raise SkipTest('MongoDB needs to be 2.2 or higher for this test') + + # Indexes are lazy so use list() to perform query + list(Log.objects) + info = Log.objects._collection.index_information() + self.assertEqual(3600, + info['_cls_1_created_1']['expireAfterSeconds']) + + def test_unique_and_indexes(self): + """Ensure that 'unique' constraints aren't overridden by + meta.indexes. + """ + class Customer(Document): + cust_id = IntField(unique=True, required=True) + meta = { + 'indexes': ['cust_id'], + 'allow_inheritance': False, + } + + Customer.drop_collection() + cust = Customer(cust_id=1) + cust.save() + + cust_dupe = Customer(cust_id=1) + try: + cust_dupe.save() + raise AssertionError, "We saved a dupe!" + except NotUniqueError: + pass + Customer.drop_collection() + + def test_unique_and_primary(self): + """If you set a field as primary, then unexpected behaviour can occur. + You won't create a duplicate but you will update an existing document. + """ + + class User(Document): + name = StringField(primary_key=True, unique=True) + password = StringField() + + User.drop_collection() + + user = User(name='huangz', password='secret') + user.save() + + user = User(name='huangz', password='secret2') + user.save() + + self.assertEqual(User.objects.count(), 1) + self.assertEqual(User.objects.get().password, 'secret2') + + User.drop_collection() + + def test_types_index_with_pk(self): + """Ensure you can use `pk` as part of a query""" + + class Comment(EmbeddedDocument): + comment_id = IntField(required=True) + + try: + class BlogPost(Document): + comments = EmbeddedDocumentField(Comment) + meta = {'indexes': [ + {'fields': ['pk', 'comments.comment_id'], + 'unique': True}]} + except UnboundLocalError: + self.fail('Unbound local error at types index + pk definition') + + info = BlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + index_item = [('_cls', 1), ('_id', 1), ('comments.comment_id', 1)] + self.assertTrue(index_item in info) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py new file mode 100644 index 00000000..d269ac0e --- /dev/null +++ b/tests/document/inheritance.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +import unittest +import warnings + +from datetime import datetime + +from tests.fixtures import Base + +from mongoengine import Document, EmbeddedDocument, connect +from mongoengine.connection import get_db +from mongoengine.fields import (BooleanField, GenericReferenceField, + IntField, StringField) + +__all__ = ('InheritanceTest', ) + + +class InheritanceTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_superclasses(self): + """Ensure that the correct list of superclasses is assembled. + """ + class Animal(Document): + meta = {'allow_inheritance': True} + class Fish(Animal): pass + class Guppy(Fish): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + self.assertEqual(Animal._superclasses, ()) + self.assertEqual(Fish._superclasses, ('Animal',)) + self.assertEqual(Guppy._superclasses, ('Animal', 'Animal.Fish')) + self.assertEqual(Mammal._superclasses, ('Animal',)) + self.assertEqual(Dog._superclasses, ('Animal', 'Animal.Mammal')) + self.assertEqual(Human._superclasses, ('Animal', 'Animal.Mammal')) + + def test_external_superclasses(self): + """Ensure that the correct list of super classes is assembled when + importing part of the model. + """ + class Animal(Base): pass + class Fish(Animal): pass + class Guppy(Fish): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + self.assertEqual(Animal._superclasses, ('Base', )) + self.assertEqual(Fish._superclasses, ('Base', 'Base.Animal',)) + self.assertEqual(Guppy._superclasses, ('Base', 'Base.Animal', + 'Base.Animal.Fish')) + self.assertEqual(Mammal._superclasses, ('Base', 'Base.Animal',)) + self.assertEqual(Dog._superclasses, ('Base', 'Base.Animal', + 'Base.Animal.Mammal')) + self.assertEqual(Human._superclasses, ('Base', 'Base.Animal', + 'Base.Animal.Mammal')) + + def test_subclasses(self): + """Ensure that the correct list of _subclasses (subclasses) is + assembled. + """ + class Animal(Document): + meta = {'allow_inheritance': True} + class Fish(Animal): pass + class Guppy(Fish): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + self.assertEqual(Animal._subclasses, ('Animal', + 'Animal.Fish', + 'Animal.Fish.Guppy', + 'Animal.Mammal', + 'Animal.Mammal.Dog', + 'Animal.Mammal.Human')) + self.assertEqual(Fish._subclasses, ('Animal.Fish', + 'Animal.Fish.Guppy',)) + self.assertEqual(Guppy._subclasses, ('Animal.Fish.Guppy',)) + self.assertEqual(Mammal._subclasses, ('Animal.Mammal', + 'Animal.Mammal.Dog', + 'Animal.Mammal.Human')) + self.assertEqual(Human._subclasses, ('Animal.Mammal.Human',)) + + def test_external_subclasses(self): + """Ensure that the correct list of _subclasses (subclasses) is + assembled when importing part of the model. + """ + class Animal(Base): pass + class Fish(Animal): pass + class Guppy(Fish): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + self.assertEqual(Animal._subclasses, ('Base.Animal', + 'Base.Animal.Fish', + 'Base.Animal.Fish.Guppy', + 'Base.Animal.Mammal', + 'Base.Animal.Mammal.Dog', + 'Base.Animal.Mammal.Human')) + self.assertEqual(Fish._subclasses, ('Base.Animal.Fish', + 'Base.Animal.Fish.Guppy',)) + self.assertEqual(Guppy._subclasses, ('Base.Animal.Fish.Guppy',)) + self.assertEqual(Mammal._subclasses, ('Base.Animal.Mammal', + 'Base.Animal.Mammal.Dog', + 'Base.Animal.Mammal.Human')) + self.assertEqual(Human._subclasses, ('Base.Animal.Mammal.Human',)) + + def test_dynamic_declarations(self): + """Test that declaring an extra class updates meta data""" + + class Animal(Document): + meta = {'allow_inheritance': True} + + self.assertEqual(Animal._superclasses, ()) + self.assertEqual(Animal._subclasses, ('Animal',)) + + # Test dynamically adding a class changes the meta data + class Fish(Animal): + pass + + self.assertEqual(Animal._superclasses, ()) + self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish')) + + self.assertEqual(Fish._superclasses, ('Animal', )) + self.assertEqual(Fish._subclasses, ('Animal.Fish',)) + + # Test dynamically adding an inherited class changes the meta data + class Pike(Fish): + pass + + self.assertEqual(Animal._superclasses, ()) + self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish', + 'Animal.Fish.Pike')) + + self.assertEqual(Fish._superclasses, ('Animal', )) + self.assertEqual(Fish._subclasses, ('Animal.Fish', 'Animal.Fish.Pike')) + + self.assertEqual(Pike._superclasses, ('Animal', 'Animal.Fish')) + self.assertEqual(Pike._subclasses, ('Animal.Fish.Pike',)) + + def test_inheritance_meta_data(self): + """Ensure that document may inherit fields from a superclass document. + """ + class Person(Document): + name = StringField() + age = IntField() + + meta = {'allow_inheritance': True} + + class Employee(Person): + salary = IntField() + + self.assertEqual(['salary', 'age', 'name', 'id'], + Employee._fields.keys()) + self.assertEqual(Employee._get_collection_name(), + Person._get_collection_name()) + + + def test_polymorphic_queries(self): + """Ensure that the correct subclasses are returned from a query + """ + + class Animal(Document): + meta = {'allow_inheritance': True} + class Fish(Animal): pass + class Mammal(Animal): pass + class Dog(Mammal): pass + class Human(Mammal): pass + + Animal.drop_collection() + + Animal().save() + Fish().save() + Mammal().save() + Dog().save() + Human().save() + + classes = [obj.__class__ for obj in Animal.objects] + self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) + + classes = [obj.__class__ for obj in Mammal.objects] + self.assertEqual(classes, [Mammal, Dog, Human]) + + classes = [obj.__class__ for obj in Human.objects] + self.assertEqual(classes, [Human]) + + + def test_allow_inheritance(self): + """Ensure that inheritance may be disabled on simple classes and that + _cls and _subclasses will not be used. + """ + + class Animal(Document): + name = StringField() + meta = {'allow_inheritance': False} + + def create_dog_class(): + class Dog(Animal): + pass + + self.assertRaises(ValueError, create_dog_class) + + # Check that _cls etc aren't present on simple documents + dog = Animal(name='dog') + dog.save() + + collection = self.db[Animal._get_collection_name()] + obj = collection.find_one() + self.assertFalse('_cls' in obj) + + def test_cant_turn_off_inheritance_on_subclass(self): + """Ensure if inheritance is on in a subclass you cant turn it off + """ + + class Animal(Document): + name = StringField() + meta = {'allow_inheritance': True} + + def create_mammal_class(): + class Mammal(Animal): + meta = {'allow_inheritance': False} + self.assertRaises(ValueError, create_mammal_class) + + def test_allow_inheritance_abstract_document(self): + """Ensure that abstract documents can set inheritance rules and that + _cls will not be used. + """ + class FinalDocument(Document): + meta = {'abstract': True, + 'allow_inheritance': False} + + class Animal(FinalDocument): + name = StringField() + + def create_mammal_class(): + class Mammal(Animal): + pass + self.assertRaises(ValueError, create_mammal_class) + + # Check that _cls isn't present in simple documents + doc = Animal(name='dog') + self.assertFalse('_cls' in doc.to_mongo()) + + def test_allow_inheritance_embedded_document(self): + """Ensure embedded documents respect inheritance + """ + + class Comment(EmbeddedDocument): + content = StringField() + meta = {'allow_inheritance': False} + + def create_special_comment(): + class SpecialComment(Comment): + pass + + self.assertRaises(ValueError, create_special_comment) + + doc = Comment(content='test') + self.assertFalse('_cls' in doc.to_mongo()) + + class Comment(EmbeddedDocument): + content = StringField() + meta = {'allow_inheritance': True} + + doc = Comment(content='test') + self.assertTrue('_cls' in doc.to_mongo()) + + def test_document_inheritance(self): + """Ensure mutliple inheritance of abstract documents + """ + class DateCreatedDocument(Document): + meta = { + 'allow_inheritance': True, + 'abstract': True, + } + + class DateUpdatedDocument(Document): + meta = { + 'allow_inheritance': True, + 'abstract': True, + } + + try: + class MyDocument(DateCreatedDocument, DateUpdatedDocument): + pass + except: + self.assertTrue(False, "Couldn't create MyDocument class") + + def test_abstract_documents(self): + """Ensure that a document superclass can be marked as abstract + thereby not using it as the name for the collection.""" + + defaults = {'index_background': True, + 'index_drop_dups': True, + 'index_opts': {'hello': 'world'}, + 'allow_inheritance': True, + 'queryset_class': 'QuerySet', + 'db_alias': 'myDB', + 'shard_key': ('hello', 'world')} + + meta_settings = {'abstract': True} + meta_settings.update(defaults) + + class Animal(Document): + name = StringField() + meta = meta_settings + + class Fish(Animal): pass + class Guppy(Fish): pass + + class Mammal(Animal): + meta = {'abstract': True} + class Human(Mammal): pass + + for k, v in defaults.iteritems(): + for cls in [Animal, Fish, Guppy]: + self.assertEqual(cls._meta[k], v) + + self.assertFalse('collection' in Animal._meta) + self.assertFalse('collection' in Mammal._meta) + + self.assertEqual(Animal._get_collection_name(), None) + self.assertEqual(Mammal._get_collection_name(), None) + + self.assertEqual(Fish._get_collection_name(), 'fish') + self.assertEqual(Guppy._get_collection_name(), 'fish') + self.assertEqual(Human._get_collection_name(), 'human') + + def create_bad_abstract(): + class EvilHuman(Human): + evil = BooleanField(default=True) + meta = {'abstract': True} + self.assertRaises(ValueError, create_bad_abstract) + + def test_inherited_collections(self): + """Ensure that subclassed documents don't override parents' + collections + """ + + class Drink(Document): + name = StringField() + meta = {'allow_inheritance': True} + + class Drinker(Document): + drink = GenericReferenceField() + + try: + warnings.simplefilter("error") + + class AcloholicDrink(Drink): + meta = {'collection': 'booze'} + + except SyntaxWarning: + warnings.simplefilter("ignore") + + class AlcoholicDrink(Drink): + meta = {'collection': 'booze'} + + else: + raise AssertionError("SyntaxWarning should be triggered") + + warnings.resetwarnings() + + Drink.drop_collection() + AlcoholicDrink.drop_collection() + Drinker.drop_collection() + + red_bull = Drink(name='Red Bull') + red_bull.save() + + programmer = Drinker(drink=red_bull) + programmer.save() + + beer = AlcoholicDrink(name='Beer') + beer.save() + real_person = Drinker(drink=beer) + real_person.save() + + self.assertEqual(Drinker.objects[0].drink.name, red_bull.name) + self.assertEqual(Drinker.objects[1].drink.name, beer.name) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_document.py b/tests/document/instance.py similarity index 50% rename from tests/test_document.py rename to tests/document/instance.py index a09aaeca..95f37d9e 100644 --- a/tests/test_document.py +++ b/tests/document/instance.py @@ -15,14 +15,17 @@ from datetime import datetime from tests.fixtures import Base, Mixin, PickleEmbedded, PickleTest from mongoengine import * -from mongoengine.base import NotRegistered, InvalidDocumentError -from mongoengine.queryset import InvalidQueryError +from mongoengine.errors import (NotRegistered, InvalidDocumentError, + InvalidQueryError) +from mongoengine.queryset import NULLIFY, Q from mongoengine.connection import get_db, get_connection TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') +__all__ = ("InstanceTest",) -class DocumentTest(unittest.TestCase): + +class InstanceTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') @@ -32,59 +35,57 @@ class DocumentTest(unittest.TestCase): name = StringField() age = IntField() - meta = {'allow_inheritance': True} + non_field = True + + meta = {"allow_inheritance": True} self.Person = Person def tearDown(self): - self.Person.drop_collection() + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) - def test_drop_collection(self): - """Ensure that the collection may be dropped from the database. + def test_capped_collection(self): + """Ensure that capped collections work properly. """ - self.Person(name='Test').save() + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_documents': 10, + 'max_size': 90000, + } - collection = self.Person._get_collection_name() - self.assertTrue(collection in self.db.collection_names()) + Log.drop_collection() - self.Person.drop_collection() - self.assertFalse(collection in self.db.collection_names()) + # Ensure that the collection handles up to its maximum + for _ in range(10): + Log().save() - def test_queryset_resurrects_dropped_collection(self): + self.assertEqual(len(Log.objects), 10) - self.Person.objects().item_frequencies('name') - self.Person.drop_collection() + # Check that extra documents don't increase the size + Log().save() + self.assertEqual(len(Log.objects), 10) - self.assertEqual({}, self.Person.objects().item_frequencies('name')) + options = Log.objects._collection.options() + self.assertEqual(options['capped'], True) + self.assertEqual(options['max'], 10) + self.assertEqual(options['size'], 90000) - class Actor(self.Person): - pass + # Check that the document cannot be redefined with different options + def recreate_log_document(): + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_documents': 11, + } + # Create the collection by accessing Document.objects + Log.objects + self.assertRaises(InvalidCollectionError, recreate_log_document) - # Ensure works correctly with inhertited classes - Actor.objects().item_frequencies('name') - self.Person.drop_collection() - self.assertEqual({}, Actor.objects().item_frequencies('name')) - - def test_definition(self): - """Ensure that document may be defined using fields. - """ - name_field = StringField() - age_field = IntField() - - class Person(Document): - name = name_field - age = age_field - non_field = True - - self.assertEqual(Person._fields['name'], name_field) - self.assertEqual(Person._fields['age'], age_field) - self.assertFalse('non_field' in Person._fields) - self.assertTrue('id' in Person._fields) - # Test iteration over fields - fields = list(Person()) - self.assertTrue('name' in fields and 'age' in fields) - # Ensure Document isn't treated like an actual document - self.assertFalse(hasattr(Document, '_fields')) + Log.drop_collection() def test_repr(self): """Ensure that unicode representation works @@ -95,146 +96,22 @@ class DocumentTest(unittest.TestCase): def __unicode__(self): return self.title - Article.drop_collection() + doc = Article(title=u'привет мир') - Article(title=u'привет мир').save() + self.assertEqual('', repr(doc)) - self.assertEqual('', repr(Article.objects.first())) - self.assertEqual('[]', repr(Article.objects.all())) + def test_queryset_resurrects_dropped_collection(self): + self.Person.drop_collection() - def test_collection_naming(self): - """Ensure that a collection with a specified name may be used. - """ + self.assertEqual([], list(self.Person.objects())) - class DefaultNamingTest(Document): - pass - self.assertEqual('default_naming_test', DefaultNamingTest._get_collection_name()) - - class CustomNamingTest(Document): - meta = {'collection': 'pimp_my_collection'} - - self.assertEqual('pimp_my_collection', CustomNamingTest._get_collection_name()) - - class DynamicNamingTest(Document): - meta = {'collection': lambda c: "DYNAMO"} - self.assertEqual('DYNAMO', DynamicNamingTest._get_collection_name()) - - # Use Abstract class to handle backwards compatibility - class BaseDocument(Document): - meta = { - 'abstract': True, - 'collection': lambda c: c.__name__.lower() - } - - class OldNamingConvention(BaseDocument): - pass - self.assertEqual('oldnamingconvention', OldNamingConvention._get_collection_name()) - - class InheritedAbstractNamingTest(BaseDocument): - meta = {'collection': 'wibble'} - self.assertEqual('wibble', InheritedAbstractNamingTest._get_collection_name()) - - - # Mixin tests - class BaseMixin(object): - meta = { - 'collection': lambda c: c.__name__.lower() - } - - class OldMixinNamingConvention(Document, BaseMixin): - pass - self.assertEqual('oldmixinnamingconvention', OldMixinNamingConvention._get_collection_name()) - - class BaseMixin(object): - meta = { - 'collection': lambda c: c.__name__.lower() - } - - class BaseDocument(Document, BaseMixin): - meta = {'allow_inheritance': True} - - class MyDocument(BaseDocument): + class Actor(self.Person): pass - self.assertEqual('basedocument', MyDocument._get_collection_name()) - - def test_get_superclasses(self): - """Ensure that the correct list of superclasses is assembled. - """ - class Animal(Document): - meta = {'allow_inheritance': True} - class Fish(Animal): pass - class Mammal(Animal): pass - class Human(Mammal): pass - class Dog(Mammal): pass - - mammal_superclasses = {'Animal': Animal} - self.assertEqual(Mammal._superclasses, mammal_superclasses) - - dog_superclasses = { - 'Animal': Animal, - 'Animal.Mammal': Mammal, - } - self.assertEqual(Dog._superclasses, dog_superclasses) - - def test_external_superclasses(self): - """Ensure that the correct list of sub and super classes is assembled. - when importing part of the model - """ - class Animal(Base): pass - class Fish(Animal): pass - class Mammal(Animal): pass - class Human(Mammal): pass - class Dog(Mammal): pass - - mammal_superclasses = {'Base': Base, 'Base.Animal': Animal} - self.assertEqual(Mammal._superclasses, mammal_superclasses) - - dog_superclasses = { - 'Base': Base, - 'Base.Animal': Animal, - 'Base.Animal.Mammal': Mammal, - } - self.assertEqual(Dog._superclasses, dog_superclasses) - - Base.drop_collection() - - h = Human() - h.save() - - self.assertEqual(Human.objects.count(), 1) - self.assertEqual(Mammal.objects.count(), 1) - self.assertEqual(Animal.objects.count(), 1) - self.assertEqual(Base.objects.count(), 1) - Base.drop_collection() - - def test_polymorphic_queries(self): - """Ensure that the correct subclasses are returned from a query""" - class Animal(Document): - meta = {'allow_inheritance': True} - class Fish(Animal): pass - class Mammal(Animal): pass - class Human(Mammal): pass - class Dog(Mammal): pass - - Animal.drop_collection() - - Animal().save() - Fish().save() - Mammal().save() - Human().save() - Dog().save() - - classes = [obj.__class__ for obj in Animal.objects] - self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog]) - - classes = [obj.__class__ for obj in Mammal.objects] - self.assertEqual(classes, [Mammal, Human, Dog]) - - classes = [obj.__class__ for obj in Human.objects] - self.assertEqual(classes, [Human]) - - Animal.drop_collection() + # Ensure works correctly with inhertited classes + Actor.objects() + self.Person.drop_collection() + self.assertEqual([], list(Actor.objects())) def test_polymorphic_references(self): """Ensure that the correct subclasses are returned from a query when @@ -244,8 +121,8 @@ class DocumentTest(unittest.TestCase): meta = {'allow_inheritance': True} class Fish(Animal): pass class Mammal(Animal): pass - class Human(Mammal): pass class Dog(Mammal): pass + class Human(Mammal): pass class Zoo(Document): animals = ListField(ReferenceField(Animal)) @@ -256,8 +133,8 @@ class DocumentTest(unittest.TestCase): Animal().save() Fish().save() Mammal().save() - Human().save() Dog().save() + Human().save() # Save a reference to each animal zoo = Zoo(animals=Animal.objects) @@ -265,7 +142,7 @@ class DocumentTest(unittest.TestCase): zoo.reload() classes = [a.__class__ for a in Zoo.objects.first().animals] - self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog]) + self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) Zoo.drop_collection() @@ -278,7 +155,7 @@ class DocumentTest(unittest.TestCase): zoo.reload() classes = [a.__class__ for a in Zoo.objects.first().animals] - self.assertEqual(classes, [Animal, Fish, Mammal, Human, Dog]) + self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) Zoo.drop_collection() Animal.drop_collection() @@ -308,466 +185,8 @@ class DocumentTest(unittest.TestCase): self.assertEqual(list_stats, CompareStats.objects.first().stats) - def test_inheritance(self): - """Ensure that document may inherit fields from a superclass document. - """ - class Employee(self.Person): - salary = IntField() - self.assertTrue('name' in Employee._fields) - self.assertTrue('salary' in Employee._fields) - self.assertEqual(Employee._get_collection_name(), - self.Person._get_collection_name()) - # Ensure that MRO error is not raised - class A(Document): - meta = {'allow_inheritance': True} - class B(A): pass - class C(B): pass - - def test_allow_inheritance(self): - """Ensure that inheritance may be disabled on simple classes and that - _cls and _types will not be used. - """ - - class Animal(Document): - name = StringField() - meta = {'allow_inheritance': False} - - Animal.drop_collection() - def create_dog_class(): - class Dog(Animal): - pass - self.assertRaises(ValueError, create_dog_class) - - # Check that _cls etc aren't present on simple documents - dog = Animal(name='dog') - dog.save() - collection = self.db[Animal._get_collection_name()] - obj = collection.find_one() - self.assertFalse('_cls' in obj) - self.assertFalse('_types' in obj) - - Animal.drop_collection() - - def create_employee_class(): - class Employee(self.Person): - meta = {'allow_inheritance': False} - self.assertRaises(ValueError, create_employee_class) - - def test_allow_inheritance_abstract_document(self): - """Ensure that abstract documents can set inheritance rules and that - _cls and _types will not be used. - """ - class FinalDocument(Document): - meta = {'abstract': True, - 'allow_inheritance': False} - - class Animal(FinalDocument): - name = StringField() - - Animal.drop_collection() - def create_dog_class(): - class Dog(Animal): - pass - self.assertRaises(ValueError, create_dog_class) - - # Check that _cls etc aren't present on simple documents - dog = Animal(name='dog') - dog.save() - collection = self.db[Animal._get_collection_name()] - obj = collection.find_one() - self.assertFalse('_cls' in obj) - self.assertFalse('_types' in obj) - - Animal.drop_collection() - - def test_allow_inheritance_embedded_document(self): - - # Test the same for embedded documents - class Comment(EmbeddedDocument): - content = StringField() - meta = {'allow_inheritance': False} - - def create_special_comment(): - class SpecialComment(Comment): - pass - - self.assertRaises(ValueError, create_special_comment) - - comment = Comment(content='test') - self.assertFalse('_cls' in comment.to_mongo()) - self.assertFalse('_types' in comment.to_mongo()) - - class Comment(EmbeddedDocument): - content = StringField() - meta = {'allow_inheritance': True} - - comment = Comment(content='test') - self.assertTrue('_cls' in comment.to_mongo()) - self.assertTrue('_types' in comment.to_mongo()) - - def test_document_inheritance(self): - """Ensure mutliple inheritance of abstract docs works - """ - class DateCreatedDocument(Document): - meta = { - 'allow_inheritance': True, - 'abstract': True, - } - - class DateUpdatedDocument(Document): - meta = { - 'allow_inheritance': True, - 'abstract': True, - } - - try: - class MyDocument(DateCreatedDocument, DateUpdatedDocument): - pass - except: - self.assertTrue(False, "Couldn't create MyDocument class") - - def test_how_to_turn_off_inheritance(self): - """Demonstrates migrating from allow_inheritance = True to False. - """ - class Animal(Document): - name = StringField() - meta = { - 'indexes': ['name'] - } - - self.assertEqual(Animal._meta['index_specs'], - [{'fields': [('_types', 1), ('name', 1)]}]) - - Animal.drop_collection() - - dog = Animal(name='dog') - dog.save() - - collection = self.db[Animal._get_collection_name()] - obj = collection.find_one() - self.assertTrue('_cls' in obj) - self.assertTrue('_types' in obj) - - info = collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertEqual([[(u'_id', 1)], [(u'_types', 1), (u'name', 1)]], info) - - # Turn off inheritance - class Animal(Document): - name = StringField() - meta = { - 'allow_inheritance': False, - 'indexes': ['name'] - } - - self.assertEqual(Animal._meta['index_specs'], - [{'fields': [('name', 1)]}]) - collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True) - - # Confirm extra data is removed - obj = collection.find_one() - self.assertFalse('_cls' in obj) - self.assertFalse('_types' in obj) - - info = collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertEqual([[(u'_id', 1)], [(u'_types', 1), (u'name', 1)]], info) - - info = collection.index_information() - indexes_to_drop = [key for key, value in info.iteritems() if '_types' in dict(value['key'])] - for index in indexes_to_drop: - collection.drop_index(index) - - info = collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertEqual([[(u'_id', 1)]], info) - - # Recreate indexes - dog = Animal.objects.first() - dog.save() - info = collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertEqual([[(u'_id', 1)], [(u'name', 1),]], info) - - Animal.drop_collection() - - def test_abstract_documents(self): - """Ensure that a document superclass can be marked as abstract - thereby not using it as the name for the collection.""" - - defaults = {'index_background': True, - 'index_drop_dups': True, - 'index_opts': {'hello': 'world'}, - 'allow_inheritance': True, - 'queryset_class': 'QuerySet', - 'db_alias': 'myDB', - 'shard_key': ('hello', 'world')} - - meta_settings = {'abstract': True} - meta_settings.update(defaults) - - class Animal(Document): - name = StringField() - meta = meta_settings - - class Fish(Animal): pass - class Guppy(Fish): pass - - class Mammal(Animal): - meta = {'abstract': True} - class Human(Mammal): pass - - for k, v in defaults.iteritems(): - for cls in [Animal, Fish, Guppy]: - self.assertEqual(cls._meta[k], v) - - self.assertFalse('collection' in Animal._meta) - self.assertFalse('collection' in Mammal._meta) - - self.assertEqual(Animal._get_collection_name(), None) - self.assertEqual(Mammal._get_collection_name(), None) - - self.assertEqual(Fish._get_collection_name(), 'fish') - self.assertEqual(Guppy._get_collection_name(), 'fish') - self.assertEqual(Human._get_collection_name(), 'human') - - def create_bad_abstract(): - class EvilHuman(Human): - evil = BooleanField(default=True) - meta = {'abstract': True} - self.assertRaises(ValueError, create_bad_abstract) - - def test_collection_name(self): - """Ensure that a collection with a specified name may be used. - """ - collection = 'personCollTest' - if collection in self.db.collection_names(): - self.db.drop_collection(collection) - - class Person(Document): - name = StringField() - meta = {'collection': collection} - - user = Person(name="Test User") - user.save() - self.assertTrue(collection in self.db.collection_names()) - - user_obj = self.db[collection].find_one() - self.assertEqual(user_obj['name'], "Test User") - - user_obj = Person.objects[0] - self.assertEqual(user_obj.name, "Test User") - - Person.drop_collection() - self.assertFalse(collection in self.db.collection_names()) - - def test_collection_name_and_primary(self): - """Ensure that a collection with a specified name may be used. - """ - - class Person(Document): - name = StringField(primary_key=True) - meta = {'collection': 'app'} - - user = Person(name="Test User") - user.save() - - user_obj = Person.objects[0] - self.assertEqual(user_obj.name, "Test User") - - Person.drop_collection() - - def test_inherited_collections(self): - """Ensure that subclassed documents don't override parents' collections. - """ - - class Drink(Document): - name = StringField() - meta = {'allow_inheritance': True} - - class Drinker(Document): - drink = GenericReferenceField() - - try: - warnings.simplefilter("error") - - class AcloholicDrink(Drink): - meta = {'collection': 'booze'} - - except SyntaxWarning, w: - warnings.simplefilter("ignore") - - class AlcoholicDrink(Drink): - meta = {'collection': 'booze'} - - else: - raise AssertionError("SyntaxWarning should be triggered") - - warnings.resetwarnings() - - Drink.drop_collection() - AlcoholicDrink.drop_collection() - Drinker.drop_collection() - - red_bull = Drink(name='Red Bull') - red_bull.save() - - programmer = Drinker(drink=red_bull) - programmer.save() - - beer = AlcoholicDrink(name='Beer') - beer.save() - real_person = Drinker(drink=beer) - real_person.save() - - self.assertEqual(Drinker.objects[0].drink.name, red_bull.name) - self.assertEqual(Drinker.objects[1].drink.name, beer.name) - - def test_capped_collection(self): - """Ensure that capped collections work properly. - """ - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 10, - 'max_size': 90000, - } - - Log.drop_collection() - - # Ensure that the collection handles up to its maximum - for i in range(10): - Log().save() - - self.assertEqual(len(Log.objects), 10) - - # Check that extra documents don't increase the size - Log().save() - self.assertEqual(len(Log.objects), 10) - - options = Log.objects._collection.options() - self.assertEqual(options['capped'], True) - self.assertEqual(options['max'], 10) - self.assertEqual(options['size'], 90000) - - # Check that the document cannot be redefined with different options - def recreate_log_document(): - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 11, - } - # Create the collection by accessing Document.objects - Log.objects - self.assertRaises(InvalidCollectionError, recreate_log_document) - - Log.drop_collection() - - def test_indexes(self): - """Ensure that indexes are used when meta[indexes] is specified. - """ - class BlogPost(Document): - date = DateTimeField(db_field='addDate', default=datetime.now) - category = StringField() - tags = ListField(StringField()) - meta = { - 'indexes': [ - '-date', - 'tags', - ('category', '-date') - ], - 'allow_inheritance': True - } - - self.assertEqual(BlogPost._meta['index_specs'], - [{'fields': [('_types', 1), ('addDate', -1)]}, - {'fields': [('tags', 1)]}, - {'fields': [('_types', 1), ('category', 1), - ('addDate', -1)]}]) - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - # _id, '-date', 'tags', ('cat', 'date') - # NB: there is no index on _types by itself, since - # the indices on -date and tags will both contain - # _types as first element in the key - self.assertEqual(len(info), 4) - - # Indexes are lazy so use list() to perform query - list(BlogPost.objects) - info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info) - self.assertTrue([('_types', 1), ('addDate', -1)] in info) - # tags is a list field so it shouldn't have _types in the index - self.assertTrue([('tags', 1)] in info) - - class ExtendedBlogPost(BlogPost): - title = StringField() - meta = {'indexes': ['title']} - - self.assertEqual(ExtendedBlogPost._meta['index_specs'], - [{'fields': [('_types', 1), ('addDate', -1)]}, - {'fields': [('tags', 1)]}, - {'fields': [('_types', 1), ('category', 1), - ('addDate', -1)]}, - {'fields': [('_types', 1), ('title', 1)]}]) - - BlogPost.drop_collection() - - list(ExtendedBlogPost.objects) - info = ExtendedBlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info) - self.assertTrue([('_types', 1), ('addDate', -1)] in info) - self.assertTrue([('_types', 1), ('title', 1)] in info) - - BlogPost.drop_collection() - - def test_inherited_index(self): - """Ensure index specs are inhertited correctly""" - - class A(Document): - title = StringField() - meta = { - 'indexes': [ - { - 'fields': ('title',), - }, - ], - 'allow_inheritance': True, - } - - class B(A): - description = StringField() - - self.assertEqual(A._meta['index_specs'], B._meta['index_specs']) - self.assertEqual([{'fields': [('_types', 1), ('title', 1)]}], - A._meta['index_specs']) - - def test_build_index_spec_is_not_destructive(self): - - class MyDoc(Document): - keywords = StringField() - - meta = { - 'indexes': ['keywords'], - 'allow_inheritance': False - } - - self.assertEqual(MyDoc._meta['index_specs'], - [{'fields': [('keywords', 1)]}]) - - # Force index creation - MyDoc.objects._ensure_indexes() - - self.assertEqual(MyDoc._meta['index_specs'], - [{'fields': [('keywords', 1)]}]) def test_db_field_load(self): """Ensure we load data correctly @@ -812,477 +231,8 @@ class DocumentTest(unittest.TestCase): self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") self.assertEqual(Person.objects.get(name="Fred").rank, "Private") - def test_embedded_document_index_meta(self): - """Ensure that embedded document indexes are created explicitly - """ - class Rank(EmbeddedDocument): - title = StringField(required=True) - class Person(Document): - name = StringField(required=True) - rank = EmbeddedDocumentField(Rank, required=False) - meta = { - 'indexes': [ - 'rank.title', - ], - 'allow_inheritance': False - } - - self.assertEqual([{'fields': [('rank.title', 1)]}], - Person._meta['index_specs']) - - Person.drop_collection() - - # Indexes are lazy so use list() to perform query - list(Person.objects) - info = Person.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('rank.title', 1)] in info) - - def test_explicit_geo2d_index(self): - """Ensure that geo2d indexes work when created via meta[indexes] - """ - class Place(Document): - location = DictField() - meta = { - 'indexes': [ - '*location.point', - ], - } - - self.assertEqual([{'fields': [('location.point', '2d')]}], - Place._meta['index_specs']) - - Place.drop_collection() - - info = Place.objects._collection.index_information() - # Indexes are lazy so use list() to perform query - list(Place.objects) - info = Place.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - - self.assertTrue([('location.point', '2d')] in info) - - def test_dictionary_indexes(self): - """Ensure that indexes are used when meta[indexes] contains dictionaries - instead of lists. - """ - class BlogPost(Document): - date = DateTimeField(db_field='addDate', default=datetime.now) - category = StringField() - tags = ListField(StringField()) - meta = { - 'indexes': [ - {'fields': ['-date'], 'unique': True, - 'sparse': True, 'types': False }, - ], - } - - self.assertEqual([{'fields': [('addDate', -1)], 'unique': True, - 'sparse': True, 'types': False}], - BlogPost._meta['index_specs']) - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - # _id, '-date' - self.assertEqual(len(info), 3) - - # Indexes are lazy so use list() to perform query - list(BlogPost.objects) - info = BlogPost.objects._collection.index_information() - info = [(value['key'], - value.get('unique', False), - value.get('sparse', False)) - for key, value in info.iteritems()] - self.assertTrue(([('addDate', -1)], True, True) in info) - - BlogPost.drop_collection() - - def test_abstract_index_inheritance(self): - - class UserBase(Document): - meta = { - 'abstract': True, - 'indexes': ['user_guid'] - } - - user_guid = StringField(required=True) - - class Person(UserBase): - meta = { - 'indexes': ['name'], - } - - name = StringField() - - Person.drop_collection() - - p = Person(name="test", user_guid='123') - p.save() - - self.assertEqual(1, Person.objects.count()) - info = Person.objects._collection.index_information() - self.assertEqual(info.keys(), ['_types_1_user_guid_1', '_id_', '_types_1_name_1']) - Person.drop_collection() - - def test_disable_index_creation(self): - """Tests setting auto_create_index to False on the connection will - disable any index generation. - """ - class User(Document): - meta = { - 'indexes': ['user_guid'], - 'auto_create_index': False - } - user_guid = StringField(required=True) - - - User.drop_collection() - - u = User(user_guid='123') - u.save() - - self.assertEqual(1, User.objects.count()) - info = User.objects._collection.index_information() - self.assertEqual(info.keys(), ['_id_']) - User.drop_collection() - - def test_embedded_document_index(self): - """Tests settings an index on an embedded document - """ - class Date(EmbeddedDocument): - year = IntField(db_field='yr') - - class BlogPost(Document): - title = StringField() - date = EmbeddedDocumentField(Date) - - meta = { - 'indexes': [ - '-date.year' - ], - } - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - self.assertEqual(info.keys(), ['_types_1_date.yr_-1', '_id_']) - BlogPost.drop_collection() - - def test_list_embedded_document_index(self): - """Ensure list embedded documents can be indexed - """ - class Tag(EmbeddedDocument): - name = StringField(db_field='tag') - - class BlogPost(Document): - title = StringField() - tags = ListField(EmbeddedDocumentField(Tag)) - - meta = { - 'indexes': [ - 'tags.name' - ], - } - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - # we don't use _types in with list fields by default - self.assertEqual(info.keys(), ['_id_', '_types_1', 'tags.tag_1']) - - post1 = BlogPost(title="Embedded Indexes tests in place", - tags=[Tag(name="about"), Tag(name="time")] - ) - post1.save() - BlogPost.drop_collection() - - def test_recursive_embedded_objects_dont_break_indexes(self): - - class RecursiveObject(EmbeddedDocument): - obj = EmbeddedDocumentField('self') - - class RecursiveDocument(Document): - recursive_obj = EmbeddedDocumentField(RecursiveObject) - - info = RecursiveDocument.objects._collection.index_information() - self.assertEqual(info.keys(), ['_id_', '_types_1']) - - def test_geo_indexes_recursion(self): - - class Location(Document): - name = StringField() - location = GeoPointField() - - class Parent(Document): - name = StringField() - location = ReferenceField(Location) - - Location.drop_collection() - Parent.drop_collection() - - list(Parent.objects) - - collection = Parent._get_collection() - info = collection.index_information() - - self.assertFalse('location_2d' in info) - - self.assertEqual(len(Parent._geo_indices()), 0) - self.assertEqual(len(Location._geo_indices()), 1) - - def test_covered_index(self): - """Ensure that covered indexes can be used - """ - - class Test(Document): - a = IntField() - - meta = { - 'indexes': ['a'], - 'allow_inheritance': False - } - - Test.drop_collection() - - obj = Test(a=1) - obj.save() - - # Need to be explicit about covered indexes as mongoDB doesn't know if - # the documents returned might have more keys in that here. - query_plan = Test.objects(id=obj.id).exclude('a').explain() - self.assertFalse(query_plan['indexOnly']) - - query_plan = Test.objects(id=obj.id).only('id').explain() - self.assertTrue(query_plan['indexOnly']) - - query_plan = Test.objects(a=1).only('a').exclude('id').explain() - self.assertTrue(query_plan['indexOnly']) - - def test_index_on_id(self): - - class BlogPost(Document): - meta = { - 'indexes': [ - ['categories', 'id'] - ], - 'allow_inheritance': False - } - - title = StringField(required=True) - description = StringField(required=True) - categories = ListField() - - BlogPost.drop_collection() - - indexes = BlogPost.objects._collection.index_information() - self.assertEqual(indexes['categories_1__id_1']['key'], - [('categories', 1), ('_id', 1)]) - - def test_hint(self): - - class BlogPost(Document): - tags = ListField(StringField()) - meta = { - 'indexes': [ - 'tags', - ], - } - - BlogPost.drop_collection() - - for i in xrange(0, 10): - tags = [("tag %i" % n) for n in xrange(0, i % 2)] - BlogPost(tags=tags).save() - - self.assertEqual(BlogPost.objects.count(), 10) - self.assertEqual(BlogPost.objects.hint().count(), 10) - self.assertEqual(BlogPost.objects.hint([('tags', 1)]).count(), 10) - - self.assertEqual(BlogPost.objects.hint([('ZZ', 1)]).count(), 10) - - def invalid_index(): - BlogPost.objects.hint('tags') - self.assertRaises(TypeError, invalid_index) - - def invalid_index_2(): - return BlogPost.objects.hint(('tags', 1)) - self.assertRaises(TypeError, invalid_index_2) - - def test_unique(self): - """Ensure that uniqueness constraints are applied to fields. - """ - class BlogPost(Document): - title = StringField() - slug = StringField(unique=True) - - BlogPost.drop_collection() - - post1 = BlogPost(title='test1', slug='test') - post1.save() - - # Two posts with the same slug is not allowed - post2 = BlogPost(title='test2', slug='test') - self.assertRaises(NotUniqueError, post2.save) - - # Ensure backwards compatibilty for errors - self.assertRaises(OperationError, post2.save) - - def test_unique_with(self): - """Ensure that unique_with constraints are applied to fields. - """ - class Date(EmbeddedDocument): - year = IntField(db_field='yr') - - class BlogPost(Document): - title = StringField() - date = EmbeddedDocumentField(Date) - slug = StringField(unique_with='date.year') - - BlogPost.drop_collection() - - post1 = BlogPost(title='test1', date=Date(year=2009), slug='test') - post1.save() - - # day is different so won't raise exception - post2 = BlogPost(title='test2', date=Date(year=2010), slug='test') - post2.save() - - # Now there will be two docs with the same slug and the same day: fail - post3 = BlogPost(title='test3', date=Date(year=2010), slug='test') - self.assertRaises(OperationError, post3.save) - - BlogPost.drop_collection() - - def test_unique_embedded_document(self): - """Ensure that uniqueness constraints are applied to fields on embedded documents. - """ - class SubDocument(EmbeddedDocument): - year = IntField(db_field='yr') - slug = StringField(unique=True) - - class BlogPost(Document): - title = StringField() - sub = EmbeddedDocumentField(SubDocument) - - BlogPost.drop_collection() - - post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) - post1.save() - - # sub.slug is different so won't raise exception - post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) - post2.save() - - # Now there will be two docs with the same sub.slug - post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) - self.assertRaises(NotUniqueError, post3.save) - - BlogPost.drop_collection() - - def test_unique_with_embedded_document_and_embedded_unique(self): - """Ensure that uniqueness constraints are applied to fields on - embedded documents. And work with unique_with as well. - """ - class SubDocument(EmbeddedDocument): - year = IntField(db_field='yr') - slug = StringField(unique=True) - - class BlogPost(Document): - title = StringField(unique_with='sub.year') - sub = EmbeddedDocumentField(SubDocument) - - BlogPost.drop_collection() - - post1 = BlogPost(title='test1', sub=SubDocument(year=2009, slug="test")) - post1.save() - - # sub.slug is different so won't raise exception - post2 = BlogPost(title='test2', sub=SubDocument(year=2010, slug='another-slug')) - post2.save() - - # Now there will be two docs with the same sub.slug - post3 = BlogPost(title='test3', sub=SubDocument(year=2010, slug='test')) - self.assertRaises(NotUniqueError, post3.save) - - # Now there will be two docs with the same title and year - post3 = BlogPost(title='test1', sub=SubDocument(year=2009, slug='test-1')) - self.assertRaises(NotUniqueError, post3.save) - - BlogPost.drop_collection() - - def test_ttl_indexes(self): - - class Log(Document): - created = DateTimeField(default=datetime.now) - meta = { - 'indexes': [ - {'fields': ['created'], 'expireAfterSeconds': 3600} - ] - } - - Log.drop_collection() - - if pymongo.version_tuple[0] < 2 and pymongo.version_tuple[1] < 3: - raise SkipTest('pymongo needs to be 2.3 or higher for this test') - - connection = get_connection() - version_array = connection.server_info()['versionArray'] - if version_array[0] < 2 and version_array[1] < 2: - raise SkipTest('MongoDB needs to be 2.2 or higher for this test') - - # Indexes are lazy so use list() to perform query - list(Log.objects) - info = Log.objects._collection.index_information() - self.assertEqual(3600, - info['_types_1_created_1']['expireAfterSeconds']) - - def test_unique_and_indexes(self): - """Ensure that 'unique' constraints aren't overridden by - meta.indexes. - """ - class Customer(Document): - cust_id = IntField(unique=True, required=True) - meta = { - 'indexes': ['cust_id'], - 'allow_inheritance': False, - } - - Customer.drop_collection() - cust = Customer(cust_id=1) - cust.save() - - cust_dupe = Customer(cust_id=1) - try: - cust_dupe.save() - raise AssertionError, "We saved a dupe!" - except NotUniqueError: - pass - Customer.drop_collection() - - def test_unique_and_primary(self): - """If you set a field as primary, then unexpected behaviour can occur. - You won't create a duplicate but you will update an existing document. - """ - - class User(Document): - name = StringField(primary_key=True, unique=True) - password = StringField() - - User.drop_collection() - - user = User(name='huangz', password='secret') - user.save() - - user = User(name='huangz', password='secret2') - user.save() - - self.assertEqual(User.objects.count(), 1) - self.assertEqual(User.objects.get().password, 'secret2') - - User.drop_collection() def test_custom_id_field(self): """Ensure that documents may be created with custom primary keys. @@ -1876,7 +826,6 @@ class DocumentTest(unittest.TestCase): class Site(Document): page = EmbeddedDocumentField(Page) - Site.drop_collection() site = Site(page=Page(log_message="Warning: Dummy message")) site.save() @@ -1903,7 +852,6 @@ class DocumentTest(unittest.TestCase): class Site(Document): page = EmbeddedDocumentField(Page) - Site.drop_collection() site = Site(page=Page(log_message="Warning: Dummy message")) @@ -1917,519 +865,6 @@ class DocumentTest(unittest.TestCase): site = Site.objects.first() self.assertEqual(site.page.log_message, "Error: Dummy message") - def test_circular_reference_deltas(self): - - class Person(Document): - name = StringField() - owns = ListField(ReferenceField('Organization')) - - class Organization(Document): - name = StringField() - owner = ReferenceField('Person') - - Person.drop_collection() - Organization.drop_collection() - - person = Person(name="owner") - person.save() - organization = Organization(name="company") - organization.save() - - person.owns.append(organization) - organization.owner = person - - person.save() - organization.save() - - p = Person.objects[0].select_related() - o = Organization.objects.first() - self.assertEqual(p.owns[0], o) - self.assertEqual(o.owner, p) - - def test_circular_reference_deltas_2(self): - - class Person(Document): - name = StringField() - owns = ListField( ReferenceField( 'Organization' ) ) - employer = ReferenceField( 'Organization' ) - - class Organization( Document ): - name = StringField() - owner = ReferenceField( 'Person' ) - employees = ListField( ReferenceField( 'Person' ) ) - - Person.drop_collection() - Organization.drop_collection() - - person = Person( name="owner" ) - person.save() - - employee = Person( name="employee" ) - employee.save() - - organization = Organization( name="company" ) - organization.save() - - person.owns.append( organization ) - organization.owner = person - - organization.employees.append( employee ) - employee.employer = organization - - person.save() - organization.save() - employee.save() - - p = Person.objects.get(name="owner") - e = Person.objects.get(name="employee") - o = Organization.objects.first() - - self.assertEqual(p.owns[0], o) - self.assertEqual(o.owner, p) - self.assertEqual(e.employer, o) - - def test_delta(self): - - class Doc(Document): - string_field = StringField() - int_field = IntField() - dict_field = DictField() - list_field = ListField() - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['string_field']) - self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) - - doc._changed_fields = [] - doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['int_field']) - self.assertEqual(doc._delta(), ({'int_field': 1}, {})) - - doc._changed_fields = [] - dict_value = {'hello': 'world', 'ping': 'pong'} - doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) - - doc._changed_fields = [] - list_value = ['1', 2, {'hello': 'world'}] - doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) - - # Test unsetting - doc._changed_fields = [] - doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) - - doc._changed_fields = [] - doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({}, {'list_field': 1})) - - def test_delta_recursive(self): - - class Embedded(EmbeddedDocument): - string_field = StringField() - int_field = IntField() - dict_field = DictField() - list_field = ListField() - - class Doc(Document): - string_field = StringField() - int_field = IntField() - dict_field = DictField() - list_field = ListField() - embedded_field = EmbeddedDocumentField(Embedded) - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] - doc.embedded_field = embedded_1 - - self.assertEqual(doc._get_changed_fields(), ['embedded_field']) - - embedded_delta = { - 'string_field': 'hello', - 'int_field': 1, - 'dict_field': {'hello': 'world'}, - 'list_field': ['1', 2, {'hello': 'world'}] - } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - embedded_delta.update({ - '_types': ['Embedded'], - '_cls': 'Embedded', - }) - self.assertEqual(doc._delta(), ({'embedded_field': embedded_delta}, {})) - - doc.save() - doc = doc.reload(10) - - doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['embedded_field.dict_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) - self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.dict_field, {}) - - doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) - self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field, []) - - embedded_2 = Embedded() - embedded_2.string_field = 'hello' - embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] - - doc.embedded_field.list_field = ['1', 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - doc.save() - doc = doc.reload(10) - - self.assertEqual(doc.embedded_field.list_field[0], '1') - self.assertEqual(doc.embedded_field.list_field[1], 2) - for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) - - doc.embedded_field.list_field[2].string_field = 'world' - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field.2.string_field']) - self.assertEqual(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'world') - - # Test multiple assignments - doc.embedded_field.list_field[2].string_field = 'hello world' - doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}}]}, {})) - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}} - ]}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'hello world') - - # Test list native methods - doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {})) - doc.save() - doc = doc.reload(10) - - doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) - - doc.embedded_field.list_field[2].list_field.sort(key=str) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) - - del(doc.embedded_field.list_field[2].list_field[2]['hello']) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) - doc.save() - doc = doc.reload(10) - - del(doc.embedded_field.list_field[2].list_field) - self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) - - doc.save() - doc = doc.reload(10) - - doc.dict_field['Embedded'] = embedded_1 - doc.save() - doc = doc.reload(10) - - doc.dict_field['Embedded'].string_field = 'Hello World' - self.assertEqual(doc._get_changed_fields(), ['dict_field.Embedded.string_field']) - self.assertEqual(doc._delta(), ({'dict_field.Embedded.string_field': 'Hello World'}, {})) - - - def test_delta_db_field(self): - - class Doc(Document): - string_field = StringField(db_field='db_string_field') - int_field = IntField(db_field='db_int_field') - dict_field = DictField(db_field='db_dict_field') - list_field = ListField(db_field='db_list_field') - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['db_string_field']) - self.assertEqual(doc._delta(), ({'db_string_field': 'hello'}, {})) - - doc._changed_fields = [] - doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['db_int_field']) - self.assertEqual(doc._delta(), ({'db_int_field': 1}, {})) - - doc._changed_fields = [] - dict_value = {'hello': 'world', 'ping': 'pong'} - doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) - self.assertEqual(doc._delta(), ({'db_dict_field': dict_value}, {})) - - doc._changed_fields = [] - list_value = ['1', 2, {'hello': 'world'}] - doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) - self.assertEqual(doc._delta(), ({'db_list_field': list_value}, {})) - - # Test unsetting - doc._changed_fields = [] - doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['db_dict_field']) - self.assertEqual(doc._delta(), ({}, {'db_dict_field': 1})) - - doc._changed_fields = [] - doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['db_list_field']) - self.assertEqual(doc._delta(), ({}, {'db_list_field': 1})) - - # Test it saves that data - doc = Doc() - doc.save() - - doc.string_field = 'hello' - doc.int_field = 1 - doc.dict_field = {'hello': 'world'} - doc.list_field = ['1', 2, {'hello': 'world'}] - doc.save() - doc = doc.reload(10) - - self.assertEqual(doc.string_field, 'hello') - self.assertEqual(doc.int_field, 1) - self.assertEqual(doc.dict_field, {'hello': 'world'}) - self.assertEqual(doc.list_field, ['1', 2, {'hello': 'world'}]) - - def test_delta_recursive_db_field(self): - - class Embedded(EmbeddedDocument): - string_field = StringField(db_field='db_string_field') - int_field = IntField(db_field='db_int_field') - dict_field = DictField(db_field='db_dict_field') - list_field = ListField(db_field='db_list_field') - - class Doc(Document): - string_field = StringField(db_field='db_string_field') - int_field = IntField(db_field='db_int_field') - dict_field = DictField(db_field='db_dict_field') - list_field = ListField(db_field='db_list_field') - embedded_field = EmbeddedDocumentField(Embedded, db_field='db_embedded_field') - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] - doc.embedded_field = embedded_1 - - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field']) - - embedded_delta = { - 'db_string_field': 'hello', - 'db_int_field': 1, - 'db_dict_field': {'hello': 'world'}, - 'db_list_field': ['1', 2, {'hello': 'world'}] - } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - embedded_delta.update({ - '_types': ['Embedded'], - '_cls': 'Embedded', - }) - self.assertEqual(doc._delta(), ({'db_embedded_field': embedded_delta}, {})) - - doc.save() - doc = doc.reload(10) - - doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_dict_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'db_dict_field': 1})) - self.assertEqual(doc._delta(), ({}, {'db_embedded_field.db_dict_field': 1})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.dict_field, {}) - - doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'db_list_field': 1})) - self.assertEqual(doc._delta(), ({}, {'db_embedded_field.db_list_field': 1})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field, []) - - embedded_2 = Embedded() - embedded_2.string_field = 'hello' - embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] - - doc.embedded_field.list_field = ['1', 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'db_list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'db_string_field': 'hello', - 'db_dict_field': {'hello': 'world'}, - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - - self.assertEqual(doc._delta(), ({ - 'db_embedded_field.db_list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'db_string_field': 'hello', - 'db_dict_field': {'hello': 'world'}, - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - doc.save() - doc = doc.reload(10) - - self.assertEqual(doc.embedded_field.list_field[0], '1') - self.assertEqual(doc.embedded_field.list_field[1], 2) - for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) - - doc.embedded_field.list_field[2].string_field = 'world' - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_list_field.2.db_string_field']) - self.assertEqual(doc.embedded_field._delta(), ({'db_list_field.2.db_string_field': 'world'}, {})) - self.assertEqual(doc._delta(), ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'world') - - # Test multiple assignments - doc.embedded_field.list_field[2].string_field = 'hello world' - doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'db_list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'db_string_field': 'hello world', - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - 'db_dict_field': {'hello': 'world'}}]}, {})) - self.assertEqual(doc._delta(), ({ - 'db_embedded_field.db_list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'db_string_field': 'hello world', - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], - 'db_dict_field': {'hello': 'world'}} - ]}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'hello world') - - # Test list native methods - doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}]}, {})) - doc.save() - doc = doc.reload(10) - - doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}, 1]}, {})) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) - - doc.embedded_field.list_field[2].list_field.sort(key=str) - doc.save() - doc = doc.reload(10) - self.assertEqual(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) - - del(doc.embedded_field.list_field[2].list_field[2]['hello']) - self.assertEqual(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [1, 2, {}]}, {})) - doc.save() - doc = doc.reload(10) - - del(doc.embedded_field.list_field[2].list_field) - self.assertEqual(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1})) - def test_save_only_changed_fields(self): """Ensure save only sets / unsets changed fields """ @@ -2437,7 +872,6 @@ class DocumentTest(unittest.TestCase): class User(self.Person): active = BooleanField(default=True) - User.drop_collection() # Create person object and save it to the database @@ -2697,29 +1131,6 @@ class DocumentTest(unittest.TestCase): promoted_employee.reload() self.assertEqual(promoted_employee.details, None) - def test_mixins_dont_add_to_types(self): - - class Mixin(object): - name = StringField() - - class Person(Document, Mixin): - pass - - Person.drop_collection() - - self.assertEqual(Person._fields.keys(), ['name', 'id']) - - Person(name="Rozza").save() - - collection = self.db[Person._get_collection_name()] - obj = collection.find_one() - self.assertEqual(obj['_cls'], 'Person') - self.assertEqual(obj['_types'], ['Person']) - - self.assertEqual(Person.objects.count(), 1) - - Person.drop_collection() - def test_object_mixins(self): class NameMixin(object): @@ -2795,22 +1206,6 @@ class DocumentTest(unittest.TestCase): BlogPost.drop_collection() - def test_cannot_perform_joins_references(self): - - class BlogPost(Document): - author = ReferenceField(self.Person) - author2 = GenericReferenceField() - - def test_reference(): - list(BlogPost.objects(author__name="test")) - - self.assertRaises(InvalidQueryError, test_reference) - - def test_generic_reference(): - list(BlogPost.objects(author2__name="test")) - - self.assertRaises(InvalidQueryError, test_generic_reference) - def test_duplicate_db_fields_raise_invalid_document_error(self): """Ensure a InvalidDocumentError is thrown if duplicate fields declare the same db_field""" @@ -3082,17 +1477,17 @@ class DocumentTest(unittest.TestCase): for u in User.objects.all(): all_user_dic[u] = "OK" - self.assertEqual(all_user_dic.get(u1, False), "OK" ) - self.assertEqual(all_user_dic.get(u2, False), "OK" ) - self.assertEqual(all_user_dic.get(u3, False), "OK" ) - self.assertEqual(all_user_dic.get(u4, False), False ) # New object - self.assertEqual(all_user_dic.get(b1, False), False ) # Other object - self.assertEqual(all_user_dic.get(b2, False), False ) # Other object + self.assertEqual(all_user_dic.get(u1, False), "OK") + self.assertEqual(all_user_dic.get(u2, False), "OK") + self.assertEqual(all_user_dic.get(u3, False), "OK") + self.assertEqual(all_user_dic.get(u4, False), False) # New object + self.assertEqual(all_user_dic.get(b1, False), False) # Other object + self.assertEqual(all_user_dic.get(b2, False), False) # Other object # in Set all_user_set = set(User.objects.all()) - self.assertTrue(u1 in all_user_set ) + self.assertTrue(u1 in all_user_set) def test_picklable(self): @@ -3313,7 +1708,7 @@ class DocumentTest(unittest.TestCase): # Bob Book.objects.create(name="1", author=bob, extra={"a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]}) - Book.objects.create(name="2", author=bob, extra={"a": bob.to_dbref(), "b": karl.to_dbref()} ) + Book.objects.create(name="2", author=bob, extra={"a": bob.to_dbref(), "b": karl.to_dbref()}) Book.objects.create(name="3", author=bob, extra={"a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]}) Book.objects.create(name="4", author=bob) @@ -3325,20 +1720,20 @@ class DocumentTest(unittest.TestCase): Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()}) # Checks - self.assertEqual(u",".join([str(b) for b in Book.objects.all()] ) , "1,2,3,4,5,6,7,8,9" ) + self.assertEqual(u",".join([str(b) for b in Book.objects.all()]) , "1,2,3,4,5,6,7,8,9") # bob related books self.assertEqual(u",".join([str(b) for b in Book.objects.filter( - Q(extra__a=bob ) | + Q(extra__a=bob) | Q(author=bob) | Q(extra__b=bob))]) , "1,2,3,4") # Susan & Karl related books self.assertEqual(u",".join([str(b) for b in Book.objects.filter( - Q(extra__a__all=[karl, susan] ) | - Q(author__all=[karl, susan ] ) | - Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()] ) - ) ] ) , "1" ) + Q(extra__a__all=[karl, susan]) | + Q(author__all=[karl, susan ]) | + Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()]) + ) ]) , "1") # $Where self.assertEqual(u",".join([str(b) for b in Book.objects.filter( @@ -3348,7 +1743,7 @@ class DocumentTest(unittest.TestCase): return this.name == '1' || this.name == '2';}""" } - ) ]), "1,2") + ) ]), "1,2") class ValidatorErrorTest(unittest.TestCase): @@ -3504,5 +1899,6 @@ class ValidatorErrorTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) + if __name__ == '__main__': unittest.main() diff --git a/tests/mongoengine.png b/tests/document/mongoengine.png similarity index 100% rename from tests/mongoengine.png rename to tests/document/mongoengine.png diff --git a/tests/migration/__init__.py b/tests/migration/__init__.py new file mode 100644 index 00000000..882e7370 --- /dev/null +++ b/tests/migration/__init__.py @@ -0,0 +1,4 @@ +from turn_off_inheritance import * + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/migration/test_convert_to_new_inheritance_model.py b/tests/migration/test_convert_to_new_inheritance_model.py new file mode 100644 index 00000000..0ef37f74 --- /dev/null +++ b/tests/migration/test_convert_to_new_inheritance_model.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField + +__all__ = ('ConvertToNewInheritanceModel', ) + + +class ConvertToNewInheritanceModel(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_how_to_convert_to_the_new_inheritance_model(self): + """Demonstrates migrating from 0.7 to 0.8 + """ + + # 1. Declaration of the class + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': True, + 'indexes': ['name'] + } + + # 2. Remove _types + collection = Animal._get_collection() + collection.update({}, {"$unset": {"_types": 1}}, multi=True) + + # 3. Confirm extra data is removed + count = collection.find({'_types': {"$exists": True}}).count() + assert count == 0 + + # 4. Remove indexes + info = collection.index_information() + indexes_to_drop = [key for key, value in info.iteritems() + if '_types' in dict(value['key'])] + for index in indexes_to_drop: + collection.drop_index(index) + + # 5. Recreate indexes + Animal.objects._ensure_indexes() diff --git a/tests/migration/turn_off_inheritance.py b/tests/migration/turn_off_inheritance.py new file mode 100644 index 00000000..5d0f7d73 --- /dev/null +++ b/tests/migration/turn_off_inheritance.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +import unittest + +from mongoengine import Document, connect +from mongoengine.connection import get_db +from mongoengine.fields import StringField + +__all__ = ('TurnOffInheritanceTest', ) + + +class TurnOffInheritanceTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + self.db = get_db() + + def tearDown(self): + for collection in self.db.collection_names(): + if 'system.' in collection: + continue + self.db.drop_collection(collection) + + def test_how_to_turn_off_inheritance(self): + """Demonstrates migrating from allow_inheritance = True to False. + """ + + # 1. Old declaration of the class + + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': True, + 'indexes': ['name'] + } + + # 2. Turn off inheritance + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': False, + 'indexes': ['name'] + } + + # 3. Remove _types and _cls + collection = Animal._get_collection() + collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True) + + # 3. Confirm extra data is removed + count = collection.find({"$or": [{'_types': {"$exists": True}}, + {'_cls': {"$exists": True}}]}).count() + assert count == 0 + + # 4. Remove indexes + info = collection.index_information() + indexes_to_drop = [key for key, value in info.iteritems() + if '_types' in dict(value['key']) + or '_cls' in dict(value['key'])] + for index in indexes_to_drop: + collection.drop_index(index) + + # 5. Recreate indexes + Animal.objects._ensure_indexes() diff --git a/tests/test_dynamic_document.py b/tests/test_dynamic_document.py deleted file mode 100644 index 23762a34..00000000 --- a/tests/test_dynamic_document.py +++ /dev/null @@ -1,533 +0,0 @@ -import unittest - -from mongoengine import * -from mongoengine.connection import get_db - - -class DynamicDocTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - class Person(DynamicDocument): - name = StringField() - meta = {'allow_inheritance': True} - - Person.drop_collection() - - self.Person = Person - - def test_simple_dynamic_document(self): - """Ensures simple dynamic documents are saved correctly""" - - p = self.Person() - p.name = "James" - p.age = 34 - - self.assertEqual(p.to_mongo(), - {"_types": ["Person"], "_cls": "Person", - "name": "James", "age": 34} - ) - - p.save() - - self.assertEqual(self.Person.objects.first().age, 34) - - # Confirm no changes to self.Person - self.assertFalse(hasattr(self.Person, 'age')) - - def test_dynamic_document_delta(self): - """Ensures simple dynamic documents can delta correctly""" - p = self.Person(name="James", age=34) - self.assertEqual(p._delta(), ({'_types': ['Person'], 'age': 34, 'name': 'James', '_cls': 'Person'}, {})) - - p.doc = 123 - del(p.doc) - self.assertEqual(p._delta(), ({'_types': ['Person'], 'age': 34, 'name': 'James', '_cls': 'Person'}, {'doc': 1})) - - def test_change_scope_of_variable(self): - """Test changing the scope of a dynamic field has no adverse effects""" - p = self.Person() - p.name = "Dean" - p.misc = 22 - p.save() - - p = self.Person.objects.get() - p.misc = {'hello': 'world'} - p.save() - - p = self.Person.objects.get() - self.assertEqual(p.misc, {'hello': 'world'}) - - def test_delete_dynamic_field(self): - """Test deleting a dynamic field works""" - self.Person.drop_collection() - p = self.Person() - p.name = "Dean" - p.misc = 22 - p.save() - - p = self.Person.objects.get() - p.misc = {'hello': 'world'} - p.save() - - p = self.Person.objects.get() - self.assertEqual(p.misc, {'hello': 'world'}) - collection = self.db[self.Person._get_collection_name()] - obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ['_cls', '_id', '_types', 'misc', 'name']) - - del(p.misc) - p.save() - - p = self.Person.objects.get() - self.assertFalse(hasattr(p, 'misc')) - - obj = collection.find_one() - self.assertEqual(sorted(obj.keys()), ['_cls', '_id', '_types', 'name']) - - def test_dynamic_document_queries(self): - """Ensure we can query dynamic fields""" - p = self.Person() - p.name = "Dean" - p.age = 22 - p.save() - - self.assertEqual(1, self.Person.objects(age=22).count()) - p = self.Person.objects(age=22) - p = p.get() - self.assertEqual(22, p.age) - - def test_complex_dynamic_document_queries(self): - class Person(DynamicDocument): - name = StringField() - - Person.drop_collection() - - p = Person(name="test") - p.age = "ten" - p.save() - - p1 = Person(name="test1") - p1.age = "less then ten and a half" - p1.save() - - p2 = Person(name="test2") - p2.age = 10 - p2.save() - - self.assertEqual(Person.objects(age__icontains='ten').count(), 2) - self.assertEqual(Person.objects(age__gte=10).count(), 1) - - def test_complex_data_lookups(self): - """Ensure you can query dynamic document dynamic fields""" - p = self.Person() - p.misc = {'hello': 'world'} - p.save() - - self.assertEqual(1, self.Person.objects(misc__hello='world').count()) - - def test_inheritance(self): - """Ensure that dynamic document plays nice with inheritance""" - class Employee(self.Person): - salary = IntField() - - Employee.drop_collection() - - self.assertTrue('name' in Employee._fields) - self.assertTrue('salary' in Employee._fields) - self.assertEqual(Employee._get_collection_name(), - self.Person._get_collection_name()) - - joe_bloggs = Employee() - joe_bloggs.name = "Joe Bloggs" - joe_bloggs.salary = 10 - joe_bloggs.age = 20 - joe_bloggs.save() - - self.assertEqual(1, self.Person.objects(age=20).count()) - self.assertEqual(1, Employee.objects(age=20).count()) - - joe_bloggs = self.Person.objects.first() - self.assertTrue(isinstance(joe_bloggs, Employee)) - - def test_embedded_dynamic_document(self): - """Test dynamic embedded documents""" - class Embedded(DynamicEmbeddedDocument): - pass - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] - doc.embedded_field = embedded_1 - - self.assertEqual(doc.to_mongo(), {"_types": ['Doc'], "_cls": "Doc", - "embedded_field": { - "_types": ['Embedded'], "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ['1', 2, {'hello': 'world'}] - } - }) - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc.embedded_field.__class__, Embedded) - self.assertEqual(doc.embedded_field.string_field, "hello") - self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) - self.assertEqual(doc.embedded_field.list_field, ['1', 2, {'hello': 'world'}]) - - def test_complex_embedded_documents(self): - """Test complex dynamic embedded documents setups""" - class Embedded(DynamicEmbeddedDocument): - pass - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - - embedded_2 = Embedded() - embedded_2.string_field = 'hello' - embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] - - embedded_1.list_field = ['1', 2, embedded_2] - doc.embedded_field = embedded_1 - - self.assertEqual(doc.to_mongo(), {"_types": ['Doc'], "_cls": "Doc", - "embedded_field": { - "_types": ['Embedded'], "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ['1', 2, - {"_types": ['Embedded'], "_cls": "Embedded", - "string_field": "hello", - "int_field": 1, - "dict_field": {"hello": "world"}, - "list_field": ['1', 2, {'hello': 'world'}]} - ] - } - }) - doc.save() - doc = Doc.objects.first() - self.assertEqual(doc.embedded_field.__class__, Embedded) - self.assertEqual(doc.embedded_field.string_field, "hello") - self.assertEqual(doc.embedded_field.int_field, 1) - self.assertEqual(doc.embedded_field.dict_field, {'hello': 'world'}) - self.assertEqual(doc.embedded_field.list_field[0], '1') - self.assertEqual(doc.embedded_field.list_field[1], 2) - - embedded_field = doc.embedded_field.list_field[2] - - self.assertEqual(embedded_field.__class__, Embedded) - self.assertEqual(embedded_field.string_field, "hello") - self.assertEqual(embedded_field.int_field, 1) - self.assertEqual(embedded_field.dict_field, {'hello': 'world'}) - self.assertEqual(embedded_field.list_field, ['1', 2, {'hello': 'world'}]) - - def test_delta_for_dynamic_documents(self): - p = self.Person() - p.name = "Dean" - p.age = 22 - p.save() - - p.age = 24 - self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ['age']) - self.assertEqual(p._delta(), ({'age': 24}, {})) - - p = self.Person.objects(age=22).get() - p.age = 24 - self.assertEqual(p.age, 24) - self.assertEqual(p._get_changed_fields(), ['age']) - self.assertEqual(p._delta(), ({'age': 24}, {})) - - p.save() - self.assertEqual(1, self.Person.objects(age=24).count()) - - def test_delta(self): - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - doc.string_field = 'hello' - self.assertEqual(doc._get_changed_fields(), ['string_field']) - self.assertEqual(doc._delta(), ({'string_field': 'hello'}, {})) - - doc._changed_fields = [] - doc.int_field = 1 - self.assertEqual(doc._get_changed_fields(), ['int_field']) - self.assertEqual(doc._delta(), ({'int_field': 1}, {})) - - doc._changed_fields = [] - dict_value = {'hello': 'world', 'ping': 'pong'} - doc.dict_field = dict_value - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({'dict_field': dict_value}, {})) - - doc._changed_fields = [] - list_value = ['1', 2, {'hello': 'world'}] - doc.list_field = list_value - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({'list_field': list_value}, {})) - - # Test unsetting - doc._changed_fields = [] - doc.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['dict_field']) - self.assertEqual(doc._delta(), ({}, {'dict_field': 1})) - - doc._changed_fields = [] - doc.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['list_field']) - self.assertEqual(doc._delta(), ({}, {'list_field': 1})) - - def test_delta_recursive(self): - """Testing deltaing works with dynamic documents""" - class Embedded(DynamicEmbeddedDocument): - pass - - class Doc(DynamicDocument): - pass - - Doc.drop_collection() - doc = Doc() - doc.save() - - doc = Doc.objects.first() - self.assertEqual(doc._get_changed_fields(), []) - self.assertEqual(doc._delta(), ({}, {})) - - embedded_1 = Embedded() - embedded_1.string_field = 'hello' - embedded_1.int_field = 1 - embedded_1.dict_field = {'hello': 'world'} - embedded_1.list_field = ['1', 2, {'hello': 'world'}] - doc.embedded_field = embedded_1 - - self.assertEqual(doc._get_changed_fields(), ['embedded_field']) - - embedded_delta = { - 'string_field': 'hello', - 'int_field': 1, - 'dict_field': {'hello': 'world'}, - 'list_field': ['1', 2, {'hello': 'world'}] - } - self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - embedded_delta.update({ - '_types': ['Embedded'], - '_cls': 'Embedded', - }) - self.assertEqual(doc._delta(), ({'embedded_field': embedded_delta}, {})) - - doc.save() - doc.reload() - - doc.embedded_field.dict_field = {} - self.assertEqual(doc._get_changed_fields(), ['embedded_field.dict_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'dict_field': 1})) - - self.assertEqual(doc._delta(), ({}, {'embedded_field.dict_field': 1})) - doc.save() - doc.reload() - - doc.embedded_field.list_field = [] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({}, {'list_field': 1})) - self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field': 1})) - doc.save() - doc.reload() - - embedded_2 = Embedded() - embedded_2.string_field = 'hello' - embedded_2.int_field = 1 - embedded_2.dict_field = {'hello': 'world'} - embedded_2.list_field = ['1', 2, {'hello': 'world'}] - - doc.embedded_field.list_field = ['1', 2, embedded_2] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { - '_cls': 'Embedded', - '_types': ['Embedded'], - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - }] - }, {})) - doc.save() - doc.reload() - - self.assertEqual(doc.embedded_field.list_field[2]._changed_fields, []) - self.assertEqual(doc.embedded_field.list_field[0], '1') - self.assertEqual(doc.embedded_field.list_field[1], 2) - for k in doc.embedded_field.list_field[2]._fields: - self.assertEqual(doc.embedded_field.list_field[2][k], embedded_2[k]) - - doc.embedded_field.list_field[2].string_field = 'world' - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field.2.string_field']) - self.assertEqual(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {})) - doc.save() - doc.reload() - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'world') - - # Test multiple assignments - doc.embedded_field.list_field[2].string_field = 'hello world' - doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}}]}, {})) - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { - '_types': ['Embedded'], - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}} - ]}, {})) - doc.save() - doc.reload() - self.assertEqual(doc.embedded_field.list_field[2].string_field, 'hello world') - - # Test list native methods - doc.embedded_field.list_field[2].list_field.pop(0) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {})) - doc.save() - doc.reload() - - doc.embedded_field.list_field[2].list_field.append(1) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {})) - doc.save() - doc.reload() - self.assertEqual(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) - - doc.embedded_field.list_field[2].list_field.sort(key=str)# use str as a key to allow comparing uncomperable types - doc.save() - doc.reload() - self.assertEqual(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) - - del(doc.embedded_field.list_field[2].list_field[2]['hello']) - self.assertEqual(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) - doc.save() - doc.reload() - - del(doc.embedded_field.list_field[2].list_field) - self.assertEqual(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) - - doc.save() - doc.reload() - - doc.dict_field = {'embedded': embedded_1} - doc.save() - doc.reload() - - doc.dict_field['embedded'].string_field = 'Hello World' - self.assertEqual(doc._get_changed_fields(), ['dict_field.embedded.string_field']) - self.assertEqual(doc._delta(), ({'dict_field.embedded.string_field': 'Hello World'}, {})) - - def test_indexes(self): - """Ensure that indexes are used when meta[indexes] is specified. - """ - class BlogPost(DynamicDocument): - meta = { - 'indexes': [ - '-date', - ('category', '-date') - ], - } - - BlogPost.drop_collection() - - info = BlogPost.objects._collection.index_information() - # _id, '-date', ('cat', 'date') - # NB: there is no index on _types by itself, since - # the indices on -date and tags will both contain - # _types as first element in the key - self.assertEqual(len(info), 3) - - # Indexes are lazy so use list() to perform query - list(BlogPost.objects) - info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1), ('category', 1), ('date', -1)] - in info) - self.assertTrue([('_types', 1), ('date', -1)] in info) - - def test_dynamic_and_embedded(self): - """Ensure embedded documents play nicely""" - - class Address(EmbeddedDocument): - city = StringField() - - class Person(DynamicDocument): - name = StringField() - meta = {'allow_inheritance': True} - - Person.drop_collection() - - Person(name="Ross", address=Address(city="London")).save() - - person = Person.objects.first() - person.address.city = "Lundenne" - person.save() - - self.assertEqual(Person.objects.first().address.city, "Lundenne") - - person = Person.objects.first() - person.address = Address(city="Londinium") - person.save() - - self.assertEqual(Person.objects.first().address.city, "Londinium") - - person = Person.objects.first() - person.age = 35 - person.save() - self.assertEqual(Person.objects.first().age, 35) diff --git a/tests/test_fields.py b/tests/test_fields.py index 98065501..118521fd 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -14,10 +14,11 @@ import gridfs from nose.plugins.skip import SkipTest from mongoengine import * from mongoengine.connection import get_db -from mongoengine.base import _document_registry, NotRegistered +from mongoengine.base import _document_registry +from mongoengine.errors import NotRegistered from mongoengine.python_support import PY3, b, StringIO, bin_type -TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') +TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'document/mongoengine.png') class FieldTest(unittest.TestCase): diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 690df5eb..cdabadb4 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -11,9 +11,12 @@ from mongoengine import * from mongoengine.connection import get_connection from mongoengine.python_support import PY3 from mongoengine.tests import query_counter -from mongoengine.queryset import (QuerySet, QuerySetManager, +from mongoengine.queryset import (Q, QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, - QueryFieldList) + QueryFieldList, queryset_manager) +from mongoengine.queryset import transform +from mongoengine.errors import InvalidQueryError + class QuerySetTest(unittest.TestCase): @@ -40,19 +43,34 @@ class QuerySetTest(unittest.TestCase): def test_transform_query(self): """Ensure that the _transform_query function operates correctly. """ - self.assertEqual(QuerySet._transform_query(name='test', age=30), + self.assertEqual(transform.query(name='test', age=30), {'name': 'test', 'age': 30}) - self.assertEqual(QuerySet._transform_query(age__lt=30), + self.assertEqual(transform.query(age__lt=30), {'age': {'$lt': 30}}) - self.assertEqual(QuerySet._transform_query(age__gt=20, age__lt=50), + self.assertEqual(transform.query(age__gt=20, age__lt=50), {'age': {'$gt': 20, '$lt': 50}}) - self.assertEqual(QuerySet._transform_query(age=20, age__gt=50), + self.assertEqual(transform.query(age=20, age__gt=50), {'age': 20}) - self.assertEqual(QuerySet._transform_query(friend__age__gte=30), + self.assertEqual(transform.query(friend__age__gte=30), {'friend.age': {'$gte': 30}}) - self.assertEqual(QuerySet._transform_query(name__exists=True), + self.assertEqual(transform.query(name__exists=True), {'name': {'$exists': True}}) + def test_cannot_perform_joins_references(self): + + class BlogPost(Document): + author = ReferenceField(self.Person) + author2 = GenericReferenceField() + + def test_reference(): + list(BlogPost.objects(author__name="test")) + + self.assertRaises(InvalidQueryError, test_reference) + + def test_generic_reference(): + list(BlogPost.objects(author2__name="test")) + + def test_find(self): """Ensure that a query returns a valid set of results. """ @@ -921,10 +939,9 @@ class QuerySetTest(unittest.TestCase): # find all published blog posts before 2010-01-07 published_posts = BlogPost.published() published_posts = published_posts.filter( - published_date__lt=datetime(2010, 1, 7, 0, 0 ,0)) + published_date__lt=datetime(2010, 1, 7, 0, 0, 0)) self.assertEqual(published_posts.count(), 2) - blog_posts = BlogPost.objects blog_posts = blog_posts.filter(blog__in=[blog_1, blog_2]) blog_posts = blog_posts.filter(blog=blog_3) @@ -935,7 +952,7 @@ class QuerySetTest(unittest.TestCase): def test_raw_and_merging(self): class Doc(Document): - pass + meta = {'allow_inheritance': False} raw_query = Doc.objects(__raw__={'deleted': False, 'scraped': 'yes', @@ -943,7 +960,7 @@ class QuerySetTest(unittest.TestCase): {'attachments.views.extracted':'no'}] })._query - expected = {'deleted': False, '_types': 'Doc', 'scraped': 'yes', + expected = {'deleted': False, 'scraped': 'yes', '$nor': [{'views.extracted': 'no'}, {'attachments.views.extracted': 'no'}]} self.assertEqual(expected, raw_query) @@ -2598,68 +2615,6 @@ class QuerySetTest(unittest.TestCase): Group.drop_collection() - def test_types_index(self): - """Ensure that and index is used when '_types' is being used in a - query. - """ - class BlogPost(Document): - date = DateTimeField() - meta = {'indexes': ['-date']} - - # Indexes are lazy so use list() to perform query - list(BlogPost.objects) - info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1)] in info) - self.assertTrue([('_types', 1), ('date', -1)] in info) - - def test_dont_index_types(self): - """Ensure that index_types will, when disabled, prevent _types - being added to all indices. - """ - class BloggPost(Document): - date = DateTimeField() - meta = {'index_types': False, - 'indexes': ['-date']} - - # Indexes are lazy so use list() to perform query - list(BloggPost.objects) - info = BloggPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - self.assertTrue([('_types', 1)] not in info) - self.assertTrue([('date', -1)] in info) - - BloggPost.drop_collection() - - class BloggPost(Document): - title = StringField() - meta = {'allow_inheritance': False} - - # _types is not used on objects where allow_inheritance is False - list(BloggPost.objects) - info = BloggPost.objects._collection.index_information() - self.assertFalse([('_types', 1)] in info.values()) - - BloggPost.drop_collection() - - def test_types_index_with_pk(self): - - class Comment(EmbeddedDocument): - comment_id = IntField(required=True) - - try: - class BlogPost(Document): - comments = EmbeddedDocumentField(Comment) - meta = {'indexes': [{'fields': ['pk', 'comments.comment_id'], - 'unique': True}]} - except UnboundLocalError: - self.fail('Unbound local error at types index + pk definition') - - info = BlogPost.objects._collection.index_information() - info = [value['key'] for key, value in info.iteritems()] - index_item = [(u'_types', 1), (u'_id', 1), (u'comments.comment_id', 1)] - self.assertTrue(index_item in info) - def test_dict_with_custom_baseclass(self): """Ensure DictField working with custom base clases. """ @@ -3116,6 +3071,7 @@ class QuerySetTest(unittest.TestCase): """ class Comment(Document): message = StringField() + meta = {'allow_inheritance': True} Comment.objects.ensure_index('message') @@ -3124,7 +3080,7 @@ class QuerySetTest(unittest.TestCase): value.get('unique', False), value.get('sparse', False)) for key, value in info.iteritems()] - self.assertTrue(([('_types', 1), ('message', 1)], False, False) in info) + self.assertTrue(([('_cls', 1), ('message', 1)], False, False) in info) def test_where(self): """Ensure that where clauses work.