diff --git a/docs/changelog.rst b/docs/changelog.rst index 8388b05a..1970bf02 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8 ============== +- Inheritance is off by default (MongoEngine/mongoengine#122) - Remove _types and just use _cls for inheritance (MongoEngine/mongoengine#148) diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index cf3b5a6f..ea8e05b2 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -462,9 +462,10 @@ If a dictionary is passed then the following options are available: The fields to index. Specified in the same format as described above. :attr:`cls` (Default: True) - If you have polymorphic models that inherit and have `allow_inheritance` - turned on, you can configure whether the index should have the - :attr:`_cls` field added automatically to the start of the index. + If you have polymorphic models that inherit and have + :attr:`allow_inheritance` turned on, you can configure whether the index + should have the :attr:`_cls` field added automatically to the start of the + index. :attr:`sparse` (Default: False) Whether the index should be sparse. @@ -573,7 +574,9 @@ defined, you may subclass it and add any extra fields or methods you may need. As this is new class is not a direct subclass of :class:`~mongoengine.Document`, it will not be stored in its own collection; it will use the same collection as its superclass uses. This allows for more -convenient and efficient retrieval of related documents:: +convenient and efficient retrieval of related documents - all you need do is +set :attr:`allow_inheritance` to True in the :attr:`meta` data for a +document.:: # Stored in a collection named 'page' class Page(Document): @@ -585,25 +588,20 @@ convenient and efficient retrieval of related documents:: class DatedPage(Page): date = DateTimeField() -.. note:: From 0.7 onwards you must declare `allow_inheritance` in the document meta. +.. note:: From 0.8 onwards you must declare :attr:`allow_inheritance` defaults + to False, meaning you must set it to True to use inheritance. Working with existing data -------------------------- -To enable correct retrieval of documents involved in this kind of heirarchy, -an extra attribute is stored on each document in the database: :attr:`_cls`. -These are hidden from the user through the MongoEngine interface, but may not -be present if you are trying to use MongoEngine with an existing database. - -For this reason, you may disable this inheritance mechansim, removing the -dependency of :attr:`_cls`, enabling you to work with existing databases. -To disable inheritance on a document class, set :attr:`allow_inheritance` to -``False`` in the :attr:`meta` dictionary:: +As MongoEngine no longer defaults to needing :attr:`_cls` you can quickly and +easily get working with existing data. Just define the document to match +the expected schema in your database. If you have wildly varying schemas then +a :class:`~mongoengine.DynamicDocument` might be more appropriate. # Will work with data in an existing collection named 'cmsPage' class Page(Document): title = StringField(max_length=200, required=True) meta = { - 'collection': 'cmsPage', - 'allow_inheritance': False, + 'collection': 'cmsPage' } diff --git a/docs/tutorial.rst b/docs/tutorial.rst index a5284c8f..c2fb5b91 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -84,12 +84,15 @@ using* the new fields we need to support video posts. This fits with the Object-Oriented principle of *inheritance* nicely. We can think of :class:`Post` as a base class, and :class:`TextPost`, :class:`ImagePost` and :class:`LinkPost` as subclasses of :class:`Post`. In fact, MongoEngine supports -this kind of modelling out of the box:: +this kind of modelling out of the box - all you need do is turn on inheritance +by setting :attr:`allow_inheritance` to True in the :attr:`meta`:: class Post(Document): title = StringField(max_length=120, required=True) author = ReferenceField(User) + meta = {'allow_inheritance': True} + class TextPost(Post): content = StringField() diff --git a/docs/upgrade.rst b/docs/upgrade.rst index 99e3078c..bf0a8421 100644 --- a/docs/upgrade.rst +++ b/docs/upgrade.rst @@ -8,10 +8,13 @@ Upgrading Inheritance ----------- +Data Model +~~~~~~~~~~ + The inheritance model has changed, we no longer need to store an array of -`types` with the model we can just use the classname in `_cls`. This means -that you will have to update your indexes for each of your inherited classes -like so: +:attr:`types` with the model we can just use the classname in :attr:`_cls`. +This means that you will have to update your indexes for each of your +inherited classes like so: # 1. Declaration of the class class Animal(Document): @@ -40,6 +43,19 @@ like so: Animal.objects._ensure_indexes() +Document Definition +~~~~~~~~~~~~~~~~~~~ + +The default for inheritance has changed - its now off by default and +:attr:`_cls` will not be stored automatically with the class. So if you extend +your :class:`~mongoengine.Document` or :class:`~mongoengine.EmbeddedDocuments` +you will need to declare :attr:`allow_inheritance` in the meta data like so: + + class Animal(Document): + name = StringField() + + meta = {'allow_inheritance': True} + 0.6 to 0.7 ========== @@ -123,7 +139,7 @@ Document.objects.with_id - now raises an InvalidQueryError if used with a filter. FutureWarning - A future warning has been added to all inherited classes that -don't define `allow_inheritance` in their meta. +don't define :attr:`allow_inheritance` in their meta. You may need to update pyMongo to 2.0 for use with Sharding. diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py index 648561be..82728d1e 100644 --- a/mongoengine/base/common.py +++ b/mongoengine/base/common.py @@ -2,7 +2,7 @@ from mongoengine.errors import NotRegistered __all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry') -ALLOW_INHERITANCE = True +ALLOW_INHERITANCE = False _document_registry = {} diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index af97e1f2..bc509af2 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -50,7 +50,6 @@ class BaseDocument(object): for key, value in values.iteritems(): key = self._reverse_db_field_map.get(key, key) setattr(self, key, value) - # Set any get_fieldname_display methods self.__set_field_display() @@ -83,6 +82,11 @@ class BaseDocument(object): if hasattr(self, '_changed_fields'): self._mark_as_changed(name) + # Check if the user has created a new instance of a class + if (self._is_document and self._initialised + and self._created and name == self._meta['id_field']): + super(BaseDocument, self).__setattr__('_created', False) + if (self._is_document and not self._created and name in self._meta.get('shard_key', tuple()) and self._data.get(name) != value): @@ -171,14 +175,24 @@ class BaseDocument(object): """Return data dictionary ready for use with MongoDB. """ data = {} - for field_name, field in self._fields.items(): - value = getattr(self, field_name, None) + for field_name, field in self._fields.iteritems(): + value = self._data.get(field_name, None) if value is not None: - data[field.db_field] = field.to_mongo(value) - # Only add _cls if allow_inheritance is not False - if not (hasattr(self, '_meta') and - self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == False): + value = field.to_mongo(value) + + # Handle self generating fields + if value is None and field._auto_gen: + value = field.generate() + self._data[field_name] = value + + if value is not None: + data[field.db_field] = value + + # Only add _cls if allow_inheritance is True + if (hasattr(self, '_meta') and + self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == True): data['_cls'] = self._class_name + if '_id' in data and data['_id'] is None: del data['_id'] @@ -194,7 +208,7 @@ class BaseDocument(object): are present. """ # Get a list of tuples of field names and their current values - fields = [(field, getattr(self, name)) + fields = [(field, self._data.get(name)) for name, field in self._fields.items()] # Ensure that each field is matched to a valid value @@ -207,7 +221,7 @@ class BaseDocument(object): errors[field.name] = error.errors or error except (ValueError, AttributeError, AssertionError), error: errors[field.name] = error - elif field.required: + elif field.required and not getattr(field, '_auto_gen', False): errors[field.name] = ValidationError('Field is required', field_name=field.name) if errors: @@ -313,6 +327,7 @@ class BaseDocument(object): """ # Handles cases where not loaded from_son but has _id doc = self.to_mongo() + set_fields = self._get_changed_fields() set_data = {} unset_data = {} @@ -370,7 +385,6 @@ class BaseDocument(object): if hasattr(d, '_fields'): field_name = d._reverse_db_field_map.get(db_field_name, db_field_name) - if field_name in d._fields: default = d._fields.get(field_name).default else: @@ -379,6 +393,7 @@ class BaseDocument(object): if default is not None: if callable(default): default = default() + if default != value: continue @@ -399,15 +414,12 @@ class BaseDocument(object): # get the class name from the document, falling back to the given # class if unavailable class_name = son.get('_cls', cls._class_name) - data = dict(("%s" % key, value) for key, value in son.items()) + data = dict(("%s" % key, value) for key, value in son.iteritems()) if not UNICODE_KWARGS: # python 2.6.4 and lower cannot handle unicode keys # passed to class constructor example: cls(**data) to_str_keys_recursive(data) - if '_cls' in data: - del data['_cls'] - # Return correct subclass for document type if class_name != cls._class_name: cls = get_document(class_name) @@ -415,7 +427,7 @@ class BaseDocument(object): changed_fields = [] errors_dict = {} - for field_name, field in cls._fields.items(): + for field_name, field in cls._fields.iteritems(): if field.db_field in data: value = data[field.db_field] try: diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 44f5e131..00e040ca 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -21,6 +21,7 @@ class BaseField(object): name = None _geo_index = False + _auto_gen = False # Call `generate` to generate a value # These track each time a Field instance is created. Used to retain order. # The auto_creation_counter is used for fields that MongoEngine implicitly @@ -36,7 +37,6 @@ class BaseField(object): if name: msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" warnings.warn(msg, DeprecationWarning) - self.name = None self.required = required or primary_key self.default = default self.unique = bool(unique or unique_with) @@ -62,7 +62,6 @@ class BaseField(object): if instance is None: # Document class being used rather than a document object return self - # Get value from document instance if available, if not use default value = instance._data.get(self.name) @@ -241,12 +240,21 @@ class ComplexBaseField(BaseField): """Convert a Python type to a MongoDB-compatible type. """ Document = _import_class("Document") + EmbeddedDocument = _import_class("EmbeddedDocument") + GenericReferenceField = _import_class("GenericReferenceField") if isinstance(value, basestring): return value if hasattr(value, 'to_mongo'): - return value.to_mongo() + if isinstance(value, Document): + return GenericReferenceField().to_mongo(value) + cls = value.__class__ + val = value.to_mongo() + # If we its a document thats not inherited add _cls + if (isinstance(value, EmbeddedDocument)): + val['_cls'] = cls.__name__ + return val is_list = False if not hasattr(value, 'items'): @@ -258,10 +266,10 @@ class ComplexBaseField(BaseField): if self.field: value_dict = dict([(key, self.field.to_mongo(item)) - for key, item in value.items()]) + for key, item in value.iteritems()]) else: value_dict = {} - for k, v in value.items(): + for k, v in value.iteritems(): if isinstance(v, Document): # We need the id from the saved object to create the DBRef if v.pk is None: @@ -274,16 +282,19 @@ class ComplexBaseField(BaseField): meta = getattr(v, '_meta', {}) allow_inheritance = ( meta.get('allow_inheritance', ALLOW_INHERITANCE) - == False) - if allow_inheritance and not self.field: - GenericReferenceField = _import_class( - "GenericReferenceField") + == True) + if not allow_inheritance and not self.field: value_dict[k] = GenericReferenceField().to_mongo(v) else: collection = v._get_collection_name() value_dict[k] = DBRef(collection, v.pk) elif hasattr(v, 'to_mongo'): - value_dict[k] = v.to_mongo() + cls = v.__class__ + val = v.to_mongo() + # If we its a document thats not inherited add _cls + if (isinstance(v, (Document, EmbeddedDocument))): + val['_cls'] = cls.__name__ + value_dict[k] = val else: value_dict[k] = self.to_mongo(v) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index f87b03e4..e68ec13d 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -34,6 +34,17 @@ class DocumentMetaclass(type): if 'meta' in attrs: attrs['_meta'] = attrs.pop('meta') + # EmbeddedDocuments should inherit meta data + if '_meta' not in attrs: + meta = MetaDict() + for base in flattened_bases[::-1]: + # Add any mixin metadata from plain objects + if hasattr(base, 'meta'): + meta.merge(base.meta) + elif hasattr(base, '_meta'): + meta.merge(base._meta) + attrs['_meta'] = meta + # Handle document Fields # Merge all fields from subclasses @@ -52,6 +63,7 @@ class DocumentMetaclass(type): if not attr_value.db_field: attr_value.db_field = attr_name base_fields[attr_name] = attr_value + doc_fields.update(base_fields) # Discover any document fields @@ -98,15 +110,7 @@ class DocumentMetaclass(type): # inheritance of classes where inheritance is set to False allow_inheritance = base._meta.get('allow_inheritance', ALLOW_INHERITANCE) - if (not getattr(base, '_is_base_cls', True) - and allow_inheritance is None): - warnings.warn( - "%s uses inheritance, the default for " - "allow_inheritance is changing to off by default. " - "Please add it to the document meta." % name, - FutureWarning - ) - elif (allow_inheritance == False and + if (allow_inheritance != True and not base._meta.get('abstract')): raise ValueError('Document %s may not be subclassed' % base.__name__) @@ -353,6 +357,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): if not new_class._meta.get('id_field'): new_class._meta['id_field'] = 'id' new_class._fields['id'] = ObjectIdField(db_field='_id') + new_class._fields['id'].name = 'id' new_class.id = new_class._fields['id'] # Merge in exceptions with parent hierarchy diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 59cc0a58..25d46b46 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -121,7 +121,10 @@ class DeReference(object): for key, doc in references.iteritems(): object_map[key] = doc else: # Generic reference: use the refs data to convert to document - if doc_type and not isinstance(doc_type, (ListField, DictField, MapField,) ): + if isinstance(doc_type, (ListField, DictField, MapField,)): + continue + + if doc_type: references = doc_type._get_db()[col].find({'_id': {'$in': refs}}) for ref in references: doc = doc_type._from_son(ref) diff --git a/mongoengine/document.py b/mongoengine/document.py index b1ce13ad..95dd6246 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -117,6 +117,7 @@ class Document(BaseDocument): """ def fget(self): return getattr(self, self._meta['id_field']) + def fset(self, value): return setattr(self, self._meta['id_field'], value) return property(fget, fset) @@ -125,7 +126,7 @@ class Document(BaseDocument): @classmethod def _get_db(cls): """Some Model using other db_alias""" - return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME )) + return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)) @classmethod def _get_collection(cls): @@ -212,11 +213,11 @@ class Document(BaseDocument): doc = self.to_mongo() - created = force_insert or '_id' not in doc + find_delta = ('_id' not in doc or self._created or force_insert) try: collection = self.__class__.objects._collection - if created: + if find_delta: if force_insert: object_id = collection.insert(doc, safe=safe, **write_options) @@ -271,7 +272,8 @@ class Document(BaseDocument): self._changed_fields = [] self._created = False - signals.post_save.send(self.__class__, document=self, created=created) + signals.post_save.send(self.__class__, document=self, + created=find_delta) return self def cascade_save(self, warn_cascade=None, *args, **kwargs): @@ -373,6 +375,7 @@ class Document(BaseDocument): for name in self._dynamic_fields.keys(): setattr(self, name, self._reload(name, obj._data[name])) self._changed_fields = obj._changed_fields + self._created = False return obj def _reload(self, key, value): @@ -464,7 +467,13 @@ class DynamicEmbeddedDocument(EmbeddedDocument): """Deletes the attribute by setting to None and allowing _delta to unset it""" field_name = args[0] - setattr(self, field_name, None) + if field_name in self._fields: + default = self._fields[field_name].default + if callable(default): + default = default() + setattr(self, field_name, default) + else: + setattr(self, field_name, None) class MapReduceDocument(object): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 9bcba9f1..15e1626f 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -16,12 +16,11 @@ from mongoengine.errors import ValidationError from mongoengine.python_support import (PY3, bin_type, txt_type, str_types, StringIO) from base import (BaseField, ComplexBaseField, ObjectIdField, - get_document, BaseDocument) + get_document, BaseDocument, ALLOW_INHERITANCE) from queryset import DO_NOTHING, QuerySet from document import Document, EmbeddedDocument from connection import get_db, DEFAULT_CONNECTION_NAME - try: from PIL import Image, ImageOps except ImportError: @@ -314,16 +313,16 @@ class DateTimeField(BaseField): usecs = 0 kwargs = {'microsecond': usecs} try: # Seconds are optional, so try converting seconds first. - return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], - **kwargs) + return datetime.datetime(*time.strptime(value, + '%Y-%m-%d %H:%M:%S')[:6], **kwargs) except ValueError: try: # Try without seconds. - return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], - **kwargs) + return datetime.datetime(*time.strptime(value, + '%Y-%m-%d %H:%M')[:5], **kwargs) except ValueError: # Try without hour/minutes/seconds. try: - return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3], - **kwargs) + return datetime.datetime(*time.strptime(value, + '%Y-%m-%d')[:3], **kwargs) except ValueError: return None @@ -410,6 +409,7 @@ class ComplexDateTimeField(StringField): return super(ComplexDateTimeField, self).__set__(instance, value) def validate(self, value): + value = self.to_python(value) if not isinstance(value, datetime.datetime): self.error('Only datetime objects may used in a ' 'ComplexDateTimeField') @@ -422,6 +422,7 @@ class ComplexDateTimeField(StringField): return original_value def to_mongo(self, value): + value = self.to_python(value) return self._convert_from_datetime(value) def prepare_query_value(self, op, value): @@ -529,7 +530,12 @@ class DynamicField(BaseField): return value if hasattr(value, 'to_mongo'): - return value.to_mongo() + cls = value.__class__ + val = value.to_mongo() + # If we its a document thats not inherited add _cls + if (isinstance(value, (Document, EmbeddedDocument))): + val['_cls'] = cls.__name__ + return val if not isinstance(value, (dict, list, tuple)): return value @@ -540,13 +546,12 @@ class DynamicField(BaseField): value = dict([(k, v) for k, v in enumerate(value)]) data = {} - for k, v in value.items(): + for k, v in value.iteritems(): data[k] = self.to_mongo(v) + value = data if is_list: # Convert back to a list - value = [v for k, v in sorted(data.items(), key=itemgetter(0))] - else: - value = data + value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))] return value def lookup_member(self, member_name): @@ -666,7 +671,6 @@ class DictField(ComplexBaseField): if op in match_operators and isinstance(value, basestring): return StringField().prepare_query_value(op, value) - return super(DictField, self).prepare_query_value(op, value) @@ -1323,7 +1327,8 @@ class GeoPointField(BaseField): class SequenceField(IntField): - """Provides a sequental counter (see http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers) + """Provides a sequental counter see: + http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers .. note:: @@ -1335,17 +1340,21 @@ class SequenceField(IntField): .. versionadded:: 0.5 """ - def __init__(self, collection_name=None, db_alias = None, sequence_name = None, *args, **kwargs): + _auto_gen = True + + def __init__(self, collection_name=None, db_alias=None, + sequence_name=None, *args, **kwargs): self.collection_name = collection_name or 'mongoengine.counters' self.db_alias = db_alias or DEFAULT_CONNECTION_NAME self.sequence_name = sequence_name return super(SequenceField, self).__init__(*args, **kwargs) - def generate_new_value(self): + def generate(self): """ Generate and Increment the counter """ - sequence_name = self.sequence_name or self.owner_document._get_collection_name() + sequence_name = (self.sequence_name or + self.owner_document._get_collection_name()) sequence_id = "%s.%s" % (sequence_name, self.name) collection = get_db(alias=self.db_alias)[self.collection_name] counter = collection.find_and_modify(query={"_id": sequence_id}, @@ -1365,7 +1374,7 @@ class SequenceField(IntField): value = instance._data.get(self.name) if not value and instance._initialised: - value = self.generate_new_value() + value = self.generate() instance._data[self.name] = value instance._mark_as_changed(self.name) @@ -1374,13 +1383,13 @@ class SequenceField(IntField): def __set__(self, instance, value): if value is None and instance._initialised: - value = self.generate_new_value() + value = self.generate() return super(SequenceField, self).__set__(instance, value) def to_python(self, value): if value is None: - value = self.generate_new_value() + value = self.generate() return value diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 51080663..dd7200b2 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -58,7 +58,7 @@ class QuerySet(object): # If inheritance is allowed, only return instances and instances of # subclasses of the class being used - if document._meta.get('allow_inheritance') != False: + if document._meta.get('allow_inheritance') == True: self._initial_query = {"_cls": {"$in": self._document._subclasses}} self._loaded_fields = QueryFieldList(always_include=['_cls']) self._cursor_obj = None diff --git a/tests/all_warnings/__init__.py b/tests/all_warnings/__init__.py index 72de8222..4609c5a5 100644 --- a/tests/all_warnings/__init__.py +++ b/tests/all_warnings/__init__.py @@ -29,22 +29,6 @@ class AllWarnings(unittest.TestCase): # restore default handling of warnings warnings.showwarning = self.showwarning_default - def test_allow_inheritance_future_warning(self): - """Add FutureWarning for future allow_inhertiance default change. - """ - - class SimpleBase(Document): - a = IntField() - - class InheritedClass(SimpleBase): - b = IntField() - - InheritedClass() - self.assertEqual(len(self.warning_list), 1) - warning = self.warning_list[0] - self.assertEqual(FutureWarning, warning["category"]) - self.assertTrue("InheritedClass" in str(warning["message"])) - def test_dbref_reference_field_future_warning(self): class Person(Document): @@ -93,7 +77,7 @@ class AllWarnings(unittest.TestCase): def test_document_collection_syntax_warning(self): class NonAbstractBase(Document): - pass + meta = {'allow_inheritance': True} class InheritedDocumentFailTest(NonAbstractBase): meta = {'collection': 'fail'} diff --git a/tests/document/delta.py b/tests/document/delta.py index f8a071d6..c6191d9b 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import sys +sys.path[0:0] = [""] import unittest from mongoengine import * @@ -126,9 +128,6 @@ class DeltaTest(unittest.TestCase): 'list_field': ['1', 2, {'hello': 'world'}] } self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - embedded_delta.update({ - '_cls': 'Embedded', - }) self.assertEqual(doc._delta(), ({'embedded_field': embedded_delta}, {})) @@ -162,6 +161,7 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field = ['1', 2, embedded_2] self.assertEqual(doc._get_changed_fields(), ['embedded_field.list_field']) + self.assertEqual(doc.embedded_field._delta(), ({ 'list_field': ['1', 2, { '_cls': 'Embedded', @@ -175,10 +175,10 @@ class DeltaTest(unittest.TestCase): self.assertEqual(doc._delta(), ({ 'embedded_field.list_field': ['1', 2, { '_cls': 'Embedded', - 'string_field': 'hello', - 'dict_field': {'hello': 'world'}, - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], + 'string_field': 'hello', + 'dict_field': {'hello': 'world'}, + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], }] }, {})) doc.save() @@ -467,9 +467,6 @@ class DeltaTest(unittest.TestCase): 'db_list_field': ['1', 2, {'hello': 'world'}] } self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) - embedded_delta.update({ - '_cls': 'Embedded', - }) self.assertEqual(doc._delta(), ({'db_embedded_field': embedded_delta}, {})) @@ -520,10 +517,10 @@ class DeltaTest(unittest.TestCase): self.assertEqual(doc._delta(), ({ 'db_embedded_field.db_list_field': ['1', 2, { '_cls': 'Embedded', - 'db_string_field': 'hello', - 'db_dict_field': {'hello': 'world'}, - 'db_int_field': 1, - 'db_list_field': ['1', 2, {'hello': 'world'}], + 'db_string_field': 'hello', + 'db_dict_field': {'hello': 'world'}, + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], }] }, {})) doc.save() @@ -686,3 +683,7 @@ class DeltaTest(unittest.TestCase): doc.list_field = [] self.assertEqual(doc._get_changed_fields(), ['list_field']) self.assertEqual(doc._delta(), ({}, {'list_field': 1})) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py index ef279179..d879b54c 100644 --- a/tests/document/dynamic.py +++ b/tests/document/dynamic.py @@ -1,4 +1,7 @@ import unittest +import sys + +sys.path[0:0] = [""] from mongoengine import * from mongoengine.connection import get_db @@ -161,7 +164,7 @@ class DynamicTest(unittest.TestCase): embedded_1.list_field = ['1', 2, {'hello': 'world'}] doc.embedded_field = embedded_1 - self.assertEqual(doc.to_mongo(), {"_cls": "Doc", + self.assertEqual(doc.to_mongo(), { "embedded_field": { "_cls": "Embedded", "string_field": "hello", @@ -205,7 +208,7 @@ class DynamicTest(unittest.TestCase): embedded_1.list_field = ['1', 2, embedded_2] doc.embedded_field = embedded_1 - self.assertEqual(doc.to_mongo(), {"_cls": "Doc", + self.assertEqual(doc.to_mongo(), { "embedded_field": { "_cls": "Embedded", "string_field": "hello", @@ -246,7 +249,6 @@ class DynamicTest(unittest.TestCase): class Person(DynamicDocument): name = StringField() - meta = {'allow_inheritance': True} Person.drop_collection() @@ -268,3 +270,7 @@ class DynamicTest(unittest.TestCase): person.age = 35 person.save() self.assertEqual(Person.objects.first().age, 35) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index d269ac0e..08e29048 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -203,7 +203,6 @@ class InheritanceTest(unittest.TestCase): class Animal(Document): name = StringField() - meta = {'allow_inheritance': False} def create_dog_class(): class Dog(Animal): @@ -258,7 +257,6 @@ class InheritanceTest(unittest.TestCase): class Comment(EmbeddedDocument): content = StringField() - meta = {'allow_inheritance': False} def create_special_comment(): class SpecialComment(Comment): diff --git a/tests/document/instance.py b/tests/document/instance.py index 95f37d9e..fcc43bad 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -1,24 +1,22 @@ # -*- coding: utf-8 -*- from __future__ import with_statement +import sys +sys.path[0:0] = [""] + import bson import os import pickle -import pymongo -import sys import unittest import uuid -import warnings -from nose.plugins.skip import SkipTest from datetime import datetime - -from tests.fixtures import Base, Mixin, PickleEmbedded, PickleTest +from tests.fixtures import PickleEmbedded, PickleTest from mongoengine import * from mongoengine.errors import (NotRegistered, InvalidDocumentError, InvalidQueryError) from mongoengine.queryset import NULLIFY, Q -from mongoengine.connection import get_db, get_connection +from mongoengine.connection import get_db TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') @@ -461,7 +459,7 @@ class InstanceTest(unittest.TestCase): doc.validate() keys = doc._data.keys() self.assertEqual(2, len(keys)) - self.assertTrue(None in keys) + self.assertTrue('id' in keys) self.assertTrue('e' in keys) def test_save(self): @@ -656,8 +654,8 @@ class InstanceTest(unittest.TestCase): self.assertEqual(p1.name, p.parent.name) def test_update(self): - """Ensure that an existing document is updated instead of be overwritten. - """ + """Ensure that an existing document is updated instead of be + overwritten.""" # Create person object and save it to the database person = self.Person(name='Test User', age=30) person.save() @@ -753,30 +751,33 @@ class InstanceTest(unittest.TestCase): float_field = FloatField(default=1.1) boolean_field = BooleanField(default=True) datetime_field = DateTimeField(default=datetime.now) - embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, default=lambda: EmbeddedDoc()) + embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, + default=lambda: EmbeddedDoc()) list_field = ListField(default=lambda: [1, 2, 3]) dict_field = DictField(default=lambda: {"hello": "world"}) objectid_field = ObjectIdField(default=bson.ObjectId) - reference_field = ReferenceField(Simple, default=lambda: Simple().save()) + reference_field = ReferenceField(Simple, default=lambda: + Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) - generic_reference_field = GenericReferenceField(default=lambda: Simple().save()) - sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3]) + generic_reference_field = GenericReferenceField( + default=lambda: Simple().save()) + sorted_list_field = SortedListField(IntField(), + default=lambda: [1, 2, 3]) email_field = EmailField(default="ross@example.com") geo_point_field = GeoPointField(default=lambda: [1, 2]) sequence_field = SequenceField() uuid_field = UUIDField(default=uuid.uuid4) - generic_embedded_document_field = GenericEmbeddedDocumentField(default=lambda: EmbeddedDoc()) - + generic_embedded_document_field = GenericEmbeddedDocumentField( + default=lambda: EmbeddedDoc()) Simple.drop_collection() Doc.drop_collection() Doc().save() - my_doc = Doc.objects.only("string_field").first() my_doc.string_field = "string" my_doc.save() @@ -1707,9 +1708,12 @@ class InstanceTest(unittest.TestCase): peter = User.objects.create(name="Peter") # Bob - Book.objects.create(name="1", author=bob, extra={"a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]}) - Book.objects.create(name="2", author=bob, extra={"a": bob.to_dbref(), "b": karl.to_dbref()}) - Book.objects.create(name="3", author=bob, extra={"a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]}) + Book.objects.create(name="1", author=bob, extra={ + "a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]}) + Book.objects.create(name="2", author=bob, extra={ + "a": bob.to_dbref(), "b": karl.to_dbref()}) + Book.objects.create(name="3", author=bob, extra={ + "a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]}) Book.objects.create(name="4", author=bob) # Jon @@ -1717,23 +1721,26 @@ class InstanceTest(unittest.TestCase): Book.objects.create(name="6", author=peter) Book.objects.create(name="7", author=jon) Book.objects.create(name="8", author=jon) - Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()}) + Book.objects.create(name="9", author=jon, + extra={"a": peter.to_dbref()}) # Checks - self.assertEqual(u",".join([str(b) for b in Book.objects.all()]) , "1,2,3,4,5,6,7,8,9") + self.assertEqual(",".join([str(b) for b in Book.objects.all()]), + "1,2,3,4,5,6,7,8,9") # bob related books - self.assertEqual(u",".join([str(b) for b in Book.objects.filter( + self.assertEqual(",".join([str(b) for b in Book.objects.filter( Q(extra__a=bob) | Q(author=bob) | - Q(extra__b=bob))]) , + Q(extra__b=bob))]), "1,2,3,4") # Susan & Karl related books - self.assertEqual(u",".join([str(b) for b in Book.objects.filter( + self.assertEqual(",".join([str(b) for b in Book.objects.filter( Q(extra__a__all=[karl, susan]) | - Q(author__all=[karl, susan ]) | - Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()]) - ) ]) , "1") + Q(author__all=[karl, susan]) | + Q(extra__b__all=[ + karl.to_dbref(), susan.to_dbref()])) + ]), "1") # $Where self.assertEqual(u",".join([str(b) for b in Book.objects.filter( @@ -1743,7 +1750,7 @@ class InstanceTest(unittest.TestCase): return this.name == '1' || this.name == '2';}""" } - ) ]), "1,2") + )]), "1,2") class ValidatorErrorTest(unittest.TestCase): diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 7b149dbd..c9631ebb 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -331,14 +331,10 @@ class FieldTest(unittest.TestCase): return "" % self.name Person.drop_collection() - paul = Person(name="Paul") - paul.save() - maria = Person(name="Maria") - maria.save() - julia = Person(name='Julia') - julia.save() - anna = Person(name='Anna') - anna.save() + paul = Person(name="Paul").save() + maria = Person(name="Maria").save() + julia = Person(name='Julia').save() + anna = Person(name='Anna').save() paul.other.friends = [maria, julia, anna] paul.other.name = "Paul's friends" diff --git a/tests/test_fields.py b/tests/test_fields.py index 118521fd..1c13a58c 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -727,7 +727,7 @@ class FieldTest(unittest.TestCase): """Ensure that the list fields can handle the complex types.""" class SettingBase(EmbeddedDocument): - pass + meta = {'allow_inheritance': True} class StringSetting(SettingBase): value = StringField() @@ -743,8 +743,9 @@ class FieldTest(unittest.TestCase): e.mapping.append(StringSetting(value='foo')) e.mapping.append(IntegerSetting(value=42)) e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001, - 'complex': IntegerSetting(value=42), 'list': - [IntegerSetting(value=42), StringSetting(value='foo')]}) + 'complex': IntegerSetting(value=42), + 'list': [IntegerSetting(value=42), + StringSetting(value='foo')]}) e.save() e2 = Simple.objects.get(id=e.id) @@ -844,7 +845,7 @@ class FieldTest(unittest.TestCase): """Ensure that the dict field can handle the complex types.""" class SettingBase(EmbeddedDocument): - pass + meta = {'allow_inheritance': True} class StringSetting(SettingBase): value = StringField() @@ -859,9 +860,11 @@ class FieldTest(unittest.TestCase): e = Simple() e.mapping['somestring'] = StringSetting(value='foo') e.mapping['someint'] = IntegerSetting(value=42) - e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001, - 'complex': IntegerSetting(value=42), 'list': - [IntegerSetting(value=42), StringSetting(value='foo')]} + e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', + 'float': 1.001, + 'complex': IntegerSetting(value=42), + 'list': [IntegerSetting(value=42), + StringSetting(value='foo')]} e.save() e2 = Simple.objects.get(id=e.id) @@ -915,7 +918,7 @@ class FieldTest(unittest.TestCase): """Ensure that the MapField can handle complex declared types.""" class SettingBase(EmbeddedDocument): - pass + meta = {"allow_inheritance": True} class StringSetting(SettingBase): value = StringField() @@ -951,7 +954,8 @@ class FieldTest(unittest.TestCase): number = IntField(default=0, db_field='i') class Test(Document): - my_map = MapField(field=EmbeddedDocumentField(Embedded), db_field='x') + my_map = MapField(field=EmbeddedDocumentField(Embedded), + db_field='x') Test.drop_collection() @@ -1038,6 +1042,8 @@ class FieldTest(unittest.TestCase): class User(EmbeddedDocument): name = StringField() + meta = {'allow_inheritance': True} + class PowerUser(User): power = IntField() @@ -1046,8 +1052,10 @@ class FieldTest(unittest.TestCase): author = EmbeddedDocumentField(User) post = BlogPost(content='What I did today...') - post.author = User(name='Test User') post.author = PowerUser(name='Test User', power=47) + post.save() + + self.assertEqual(47, BlogPost.objects.first().author.power) def test_reference_validation(self): """Ensure that invalid docment objects cannot be assigned to reference @@ -2117,12 +2125,12 @@ class FieldTest(unittest.TestCase): def test_sequence_fields_reload(self): class Animal(Document): counter = SequenceField() - type = StringField() + name = StringField() self.db['mongoengine.counters'].drop() Animal.drop_collection() - a = Animal(type="Boi") + a = Animal(name="Boi") a.save() self.assertEqual(a.counter, 1) diff --git a/tests/test_queryset.py b/tests/test_queryset.py index cdabadb4..e9e78b4f 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -647,7 +647,8 @@ class QuerySetTest(unittest.TestCase): self.assertRaises(NotUniqueError, throw_operation_error_not_unique) self.assertEqual(Blog.objects.count(), 2) - Blog.objects.insert([blog2, blog3], write_options={'continue_on_error': True}) + Blog.objects.insert([blog2, blog3], write_options={ + 'continue_on_error': True}) self.assertEqual(Blog.objects.count(), 3) def test_get_changed_fields_query_count(self): @@ -673,7 +674,7 @@ class QuerySetTest(unittest.TestCase): r2 = Project(name="r2").save() r3 = Project(name="r3").save() p1 = Person(name="p1", projects=[r1, r2]).save() - p2 = Person(name="p2", projects=[r2]).save() + p2 = Person(name="p2", projects=[r2, r3]).save() o1 = Organization(name="o1", employees=[p1]).save() with query_counter() as q: @@ -688,24 +689,24 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(q, 0) fresh_o1 = Organization.objects.get(id=o1.id) - fresh_o1.save() + fresh_o1.save() # No changes, does nothing - self.assertEqual(q, 2) + self.assertEqual(q, 1) with query_counter() as q: self.assertEqual(q, 0) fresh_o1 = Organization.objects.get(id=o1.id) - fresh_o1.save(cascade=False) + fresh_o1.save(cascade=False) # No changes, does nothing - self.assertEqual(q, 2) + self.assertEqual(q, 1) with query_counter() as q: self.assertEqual(q, 0) fresh_o1 = Organization.objects.get(id=o1.id) - fresh_o1.employees.append(p2) - fresh_o1.save(cascade=False) + fresh_o1.employees.append(p2) # Dereferences + fresh_o1.save(cascade=False) # Saves self.assertEqual(q, 3)