Inheritance is off by default (MongoEngine/mongoengine#122)

This commit is contained in:
Ross Lawley 2012-10-17 11:36:18 +00:00
parent 6f29d12386
commit 3d5b6ae332
20 changed files with 245 additions and 177 deletions

View File

@ -4,6 +4,7 @@ Changelog
Changes in 0.8
==============
- Inheritance is off by default (MongoEngine/mongoengine#122)
- Remove _types and just use _cls for inheritance (MongoEngine/mongoengine#148)

View File

@ -462,9 +462,10 @@ If a dictionary is passed then the following options are available:
The fields to index. Specified in the same format as described above.
:attr:`cls` (Default: True)
If you have polymorphic models that inherit and have `allow_inheritance`
turned on, you can configure whether the index should have the
:attr:`_cls` field added automatically to the start of the index.
If you have polymorphic models that inherit and have
:attr:`allow_inheritance` turned on, you can configure whether the index
should have the :attr:`_cls` field added automatically to the start of the
index.
:attr:`sparse` (Default: False)
Whether the index should be sparse.
@ -573,7 +574,9 @@ defined, you may subclass it and add any extra fields or methods you may need.
As this is new class is not a direct subclass of
:class:`~mongoengine.Document`, it will not be stored in its own collection; it
will use the same collection as its superclass uses. This allows for more
convenient and efficient retrieval of related documents::
convenient and efficient retrieval of related documents - all you need do is
set :attr:`allow_inheritance` to True in the :attr:`meta` data for a
document.::
# Stored in a collection named 'page'
class Page(Document):
@ -585,25 +588,20 @@ convenient and efficient retrieval of related documents::
class DatedPage(Page):
date = DateTimeField()
.. note:: From 0.7 onwards you must declare `allow_inheritance` in the document meta.
.. note:: From 0.8 onwards you must declare :attr:`allow_inheritance` defaults
to False, meaning you must set it to True to use inheritance.
Working with existing data
--------------------------
To enable correct retrieval of documents involved in this kind of heirarchy,
an extra attribute is stored on each document in the database: :attr:`_cls`.
These are hidden from the user through the MongoEngine interface, but may not
be present if you are trying to use MongoEngine with an existing database.
For this reason, you may disable this inheritance mechansim, removing the
dependency of :attr:`_cls`, enabling you to work with existing databases.
To disable inheritance on a document class, set :attr:`allow_inheritance` to
``False`` in the :attr:`meta` dictionary::
As MongoEngine no longer defaults to needing :attr:`_cls` you can quickly and
easily get working with existing data. Just define the document to match
the expected schema in your database. If you have wildly varying schemas then
a :class:`~mongoengine.DynamicDocument` might be more appropriate.
# Will work with data in an existing collection named 'cmsPage'
class Page(Document):
title = StringField(max_length=200, required=True)
meta = {
'collection': 'cmsPage',
'allow_inheritance': False,
'collection': 'cmsPage'
}

View File

@ -84,12 +84,15 @@ using* the new fields we need to support video posts. This fits with the
Object-Oriented principle of *inheritance* nicely. We can think of
:class:`Post` as a base class, and :class:`TextPost`, :class:`ImagePost` and
:class:`LinkPost` as subclasses of :class:`Post`. In fact, MongoEngine supports
this kind of modelling out of the box::
this kind of modelling out of the box - all you need do is turn on inheritance
by setting :attr:`allow_inheritance` to True in the :attr:`meta`::
class Post(Document):
title = StringField(max_length=120, required=True)
author = ReferenceField(User)
meta = {'allow_inheritance': True}
class TextPost(Post):
content = StringField()

View File

@ -8,10 +8,13 @@ Upgrading
Inheritance
-----------
Data Model
~~~~~~~~~~
The inheritance model has changed, we no longer need to store an array of
`types` with the model we can just use the classname in `_cls`. This means
that you will have to update your indexes for each of your inherited classes
like so:
:attr:`types` with the model we can just use the classname in :attr:`_cls`.
This means that you will have to update your indexes for each of your
inherited classes like so:
# 1. Declaration of the class
class Animal(Document):
@ -40,6 +43,19 @@ like so:
Animal.objects._ensure_indexes()
Document Definition
~~~~~~~~~~~~~~~~~~~
The default for inheritance has changed - its now off by default and
:attr:`_cls` will not be stored automatically with the class. So if you extend
your :class:`~mongoengine.Document` or :class:`~mongoengine.EmbeddedDocuments`
you will need to declare :attr:`allow_inheritance` in the meta data like so:
class Animal(Document):
name = StringField()
meta = {'allow_inheritance': True}
0.6 to 0.7
==========
@ -123,7 +139,7 @@ Document.objects.with_id - now raises an InvalidQueryError if used with a
filter.
FutureWarning - A future warning has been added to all inherited classes that
don't define `allow_inheritance` in their meta.
don't define :attr:`allow_inheritance` in their meta.
You may need to update pyMongo to 2.0 for use with Sharding.

View File

@ -2,7 +2,7 @@ from mongoengine.errors import NotRegistered
__all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry')
ALLOW_INHERITANCE = True
ALLOW_INHERITANCE = False
_document_registry = {}

View File

@ -50,7 +50,6 @@ class BaseDocument(object):
for key, value in values.iteritems():
key = self._reverse_db_field_map.get(key, key)
setattr(self, key, value)
# Set any get_fieldname_display methods
self.__set_field_display()
@ -83,6 +82,11 @@ class BaseDocument(object):
if hasattr(self, '_changed_fields'):
self._mark_as_changed(name)
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
if (self._is_document and not self._created and
name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value):
@ -171,14 +175,24 @@ class BaseDocument(object):
"""Return data dictionary ready for use with MongoDB.
"""
data = {}
for field_name, field in self._fields.items():
value = getattr(self, field_name, None)
for field_name, field in self._fields.iteritems():
value = self._data.get(field_name, None)
if value is not None:
data[field.db_field] = field.to_mongo(value)
# Only add _cls if allow_inheritance is not False
if not (hasattr(self, '_meta') and
self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == False):
value = field.to_mongo(value)
# Handle self generating fields
if value is None and field._auto_gen:
value = field.generate()
self._data[field_name] = value
if value is not None:
data[field.db_field] = value
# Only add _cls if allow_inheritance is True
if (hasattr(self, '_meta') and
self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == True):
data['_cls'] = self._class_name
if '_id' in data and data['_id'] is None:
del data['_id']
@ -194,7 +208,7 @@ class BaseDocument(object):
are present.
"""
# Get a list of tuples of field names and their current values
fields = [(field, getattr(self, name))
fields = [(field, self._data.get(name))
for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value
@ -207,7 +221,7 @@ class BaseDocument(object):
errors[field.name] = error.errors or error
except (ValueError, AttributeError, AssertionError), error:
errors[field.name] = error
elif field.required:
elif field.required and not getattr(field, '_auto_gen', False):
errors[field.name] = ValidationError('Field is required',
field_name=field.name)
if errors:
@ -313,6 +327,7 @@ class BaseDocument(object):
"""
# Handles cases where not loaded from_son but has _id
doc = self.to_mongo()
set_fields = self._get_changed_fields()
set_data = {}
unset_data = {}
@ -370,7 +385,6 @@ class BaseDocument(object):
if hasattr(d, '_fields'):
field_name = d._reverse_db_field_map.get(db_field_name,
db_field_name)
if field_name in d._fields:
default = d._fields.get(field_name).default
else:
@ -379,6 +393,7 @@ class BaseDocument(object):
if default is not None:
if callable(default):
default = default()
if default != value:
continue
@ -399,15 +414,12 @@ class BaseDocument(object):
# get the class name from the document, falling back to the given
# class if unavailable
class_name = son.get('_cls', cls._class_name)
data = dict(("%s" % key, value) for key, value in son.items())
data = dict(("%s" % key, value) for key, value in son.iteritems())
if not UNICODE_KWARGS:
# python 2.6.4 and lower cannot handle unicode keys
# passed to class constructor example: cls(**data)
to_str_keys_recursive(data)
if '_cls' in data:
del data['_cls']
# Return correct subclass for document type
if class_name != cls._class_name:
cls = get_document(class_name)
@ -415,7 +427,7 @@ class BaseDocument(object):
changed_fields = []
errors_dict = {}
for field_name, field in cls._fields.items():
for field_name, field in cls._fields.iteritems():
if field.db_field in data:
value = data[field.db_field]
try:

View File

@ -21,6 +21,7 @@ class BaseField(object):
name = None
_geo_index = False
_auto_gen = False # Call `generate` to generate a value
# These track each time a Field instance is created. Used to retain order.
# The auto_creation_counter is used for fields that MongoEngine implicitly
@ -36,7 +37,6 @@ class BaseField(object):
if name:
msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
warnings.warn(msg, DeprecationWarning)
self.name = None
self.required = required or primary_key
self.default = default
self.unique = bool(unique or unique_with)
@ -62,7 +62,6 @@ class BaseField(object):
if instance is None:
# Document class being used rather than a document object
return self
# Get value from document instance if available, if not use default
value = instance._data.get(self.name)
@ -241,12 +240,21 @@ class ComplexBaseField(BaseField):
"""Convert a Python type to a MongoDB-compatible type.
"""
Document = _import_class("Document")
EmbeddedDocument = _import_class("EmbeddedDocument")
GenericReferenceField = _import_class("GenericReferenceField")
if isinstance(value, basestring):
return value
if hasattr(value, 'to_mongo'):
return value.to_mongo()
if isinstance(value, Document):
return GenericReferenceField().to_mongo(value)
cls = value.__class__
val = value.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(value, EmbeddedDocument)):
val['_cls'] = cls.__name__
return val
is_list = False
if not hasattr(value, 'items'):
@ -258,10 +266,10 @@ class ComplexBaseField(BaseField):
if self.field:
value_dict = dict([(key, self.field.to_mongo(item))
for key, item in value.items()])
for key, item in value.iteritems()])
else:
value_dict = {}
for k, v in value.items():
for k, v in value.iteritems():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
@ -274,16 +282,19 @@ class ComplexBaseField(BaseField):
meta = getattr(v, '_meta', {})
allow_inheritance = (
meta.get('allow_inheritance', ALLOW_INHERITANCE)
== False)
if allow_inheritance and not self.field:
GenericReferenceField = _import_class(
"GenericReferenceField")
== True)
if not allow_inheritance and not self.field:
value_dict[k] = GenericReferenceField().to_mongo(v)
else:
collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'):
value_dict[k] = v.to_mongo()
cls = v.__class__
val = v.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(v, (Document, EmbeddedDocument))):
val['_cls'] = cls.__name__
value_dict[k] = val
else:
value_dict[k] = self.to_mongo(v)

View File

@ -34,6 +34,17 @@ class DocumentMetaclass(type):
if 'meta' in attrs:
attrs['_meta'] = attrs.pop('meta')
# EmbeddedDocuments should inherit meta data
if '_meta' not in attrs:
meta = MetaDict()
for base in flattened_bases[::-1]:
# Add any mixin metadata from plain objects
if hasattr(base, 'meta'):
meta.merge(base.meta)
elif hasattr(base, '_meta'):
meta.merge(base._meta)
attrs['_meta'] = meta
# Handle document Fields
# Merge all fields from subclasses
@ -52,6 +63,7 @@ class DocumentMetaclass(type):
if not attr_value.db_field:
attr_value.db_field = attr_name
base_fields[attr_name] = attr_value
doc_fields.update(base_fields)
# Discover any document fields
@ -98,15 +110,7 @@ class DocumentMetaclass(type):
# inheritance of classes where inheritance is set to False
allow_inheritance = base._meta.get('allow_inheritance',
ALLOW_INHERITANCE)
if (not getattr(base, '_is_base_cls', True)
and allow_inheritance is None):
warnings.warn(
"%s uses inheritance, the default for "
"allow_inheritance is changing to off by default. "
"Please add it to the document meta." % name,
FutureWarning
)
elif (allow_inheritance == False and
if (allow_inheritance != True and
not base._meta.get('abstract')):
raise ValueError('Document %s may not be subclassed' %
base.__name__)
@ -353,6 +357,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
if not new_class._meta.get('id_field'):
new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class._fields['id'].name = 'id'
new_class.id = new_class._fields['id']
# Merge in exceptions with parent hierarchy

View File

@ -121,7 +121,10 @@ class DeReference(object):
for key, doc in references.iteritems():
object_map[key] = doc
else: # Generic reference: use the refs data to convert to document
if doc_type and not isinstance(doc_type, (ListField, DictField, MapField,) ):
if isinstance(doc_type, (ListField, DictField, MapField,)):
continue
if doc_type:
references = doc_type._get_db()[col].find({'_id': {'$in': refs}})
for ref in references:
doc = doc_type._from_son(ref)

View File

@ -117,6 +117,7 @@ class Document(BaseDocument):
"""
def fget(self):
return getattr(self, self._meta['id_field'])
def fset(self, value):
return setattr(self, self._meta['id_field'], value)
return property(fget, fset)
@ -125,7 +126,7 @@ class Document(BaseDocument):
@classmethod
def _get_db(cls):
"""Some Model using other db_alias"""
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME ))
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
@classmethod
def _get_collection(cls):
@ -212,11 +213,11 @@ class Document(BaseDocument):
doc = self.to_mongo()
created = force_insert or '_id' not in doc
find_delta = ('_id' not in doc or self._created or force_insert)
try:
collection = self.__class__.objects._collection
if created:
if find_delta:
if force_insert:
object_id = collection.insert(doc, safe=safe,
**write_options)
@ -271,7 +272,8 @@ class Document(BaseDocument):
self._changed_fields = []
self._created = False
signals.post_save.send(self.__class__, document=self, created=created)
signals.post_save.send(self.__class__, document=self,
created=find_delta)
return self
def cascade_save(self, warn_cascade=None, *args, **kwargs):
@ -373,6 +375,7 @@ class Document(BaseDocument):
for name in self._dynamic_fields.keys():
setattr(self, name, self._reload(name, obj._data[name]))
self._changed_fields = obj._changed_fields
self._created = False
return obj
def _reload(self, key, value):
@ -464,7 +467,13 @@ class DynamicEmbeddedDocument(EmbeddedDocument):
"""Deletes the attribute by setting to None and allowing _delta to unset
it"""
field_name = args[0]
setattr(self, field_name, None)
if field_name in self._fields:
default = self._fields[field_name].default
if callable(default):
default = default()
setattr(self, field_name, default)
else:
setattr(self, field_name, None)
class MapReduceDocument(object):

View File

@ -16,12 +16,11 @@ from mongoengine.errors import ValidationError
from mongoengine.python_support import (PY3, bin_type, txt_type,
str_types, StringIO)
from base import (BaseField, ComplexBaseField, ObjectIdField,
get_document, BaseDocument)
get_document, BaseDocument, ALLOW_INHERITANCE)
from queryset import DO_NOTHING, QuerySet
from document import Document, EmbeddedDocument
from connection import get_db, DEFAULT_CONNECTION_NAME
try:
from PIL import Image, ImageOps
except ImportError:
@ -314,16 +313,16 @@ class DateTimeField(BaseField):
usecs = 0
kwargs = {'microsecond': usecs}
try: # Seconds are optional, so try converting seconds first.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6],
**kwargs)
return datetime.datetime(*time.strptime(value,
'%Y-%m-%d %H:%M:%S')[:6], **kwargs)
except ValueError:
try: # Try without seconds.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5],
**kwargs)
return datetime.datetime(*time.strptime(value,
'%Y-%m-%d %H:%M')[:5], **kwargs)
except ValueError: # Try without hour/minutes/seconds.
try:
return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3],
**kwargs)
return datetime.datetime(*time.strptime(value,
'%Y-%m-%d')[:3], **kwargs)
except ValueError:
return None
@ -410,6 +409,7 @@ class ComplexDateTimeField(StringField):
return super(ComplexDateTimeField, self).__set__(instance, value)
def validate(self, value):
value = self.to_python(value)
if not isinstance(value, datetime.datetime):
self.error('Only datetime objects may used in a '
'ComplexDateTimeField')
@ -422,6 +422,7 @@ class ComplexDateTimeField(StringField):
return original_value
def to_mongo(self, value):
value = self.to_python(value)
return self._convert_from_datetime(value)
def prepare_query_value(self, op, value):
@ -529,7 +530,12 @@ class DynamicField(BaseField):
return value
if hasattr(value, 'to_mongo'):
return value.to_mongo()
cls = value.__class__
val = value.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(value, (Document, EmbeddedDocument))):
val['_cls'] = cls.__name__
return val
if not isinstance(value, (dict, list, tuple)):
return value
@ -540,13 +546,12 @@ class DynamicField(BaseField):
value = dict([(k, v) for k, v in enumerate(value)])
data = {}
for k, v in value.items():
for k, v in value.iteritems():
data[k] = self.to_mongo(v)
value = data
if is_list: # Convert back to a list
value = [v for k, v in sorted(data.items(), key=itemgetter(0))]
else:
value = data
value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))]
return value
def lookup_member(self, member_name):
@ -666,7 +671,6 @@ class DictField(ComplexBaseField):
if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value)
@ -1323,7 +1327,8 @@ class GeoPointField(BaseField):
class SequenceField(IntField):
"""Provides a sequental counter (see http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers)
"""Provides a sequental counter see:
http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers
.. note::
@ -1335,17 +1340,21 @@ class SequenceField(IntField):
.. versionadded:: 0.5
"""
def __init__(self, collection_name=None, db_alias = None, sequence_name = None, *args, **kwargs):
_auto_gen = True
def __init__(self, collection_name=None, db_alias=None,
sequence_name=None, *args, **kwargs):
self.collection_name = collection_name or 'mongoengine.counters'
self.db_alias = db_alias or DEFAULT_CONNECTION_NAME
self.sequence_name = sequence_name
return super(SequenceField, self).__init__(*args, **kwargs)
def generate_new_value(self):
def generate(self):
"""
Generate and Increment the counter
"""
sequence_name = self.sequence_name or self.owner_document._get_collection_name()
sequence_name = (self.sequence_name or
self.owner_document._get_collection_name())
sequence_id = "%s.%s" % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
counter = collection.find_and_modify(query={"_id": sequence_id},
@ -1365,7 +1374,7 @@ class SequenceField(IntField):
value = instance._data.get(self.name)
if not value and instance._initialised:
value = self.generate_new_value()
value = self.generate()
instance._data[self.name] = value
instance._mark_as_changed(self.name)
@ -1374,13 +1383,13 @@ class SequenceField(IntField):
def __set__(self, instance, value):
if value is None and instance._initialised:
value = self.generate_new_value()
value = self.generate()
return super(SequenceField, self).__set__(instance, value)
def to_python(self, value):
if value is None:
value = self.generate_new_value()
value = self.generate()
return value

View File

@ -58,7 +58,7 @@ class QuerySet(object):
# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
if document._meta.get('allow_inheritance') != False:
if document._meta.get('allow_inheritance') == True:
self._initial_query = {"_cls": {"$in": self._document._subclasses}}
self._loaded_fields = QueryFieldList(always_include=['_cls'])
self._cursor_obj = None

View File

@ -29,22 +29,6 @@ class AllWarnings(unittest.TestCase):
# restore default handling of warnings
warnings.showwarning = self.showwarning_default
def test_allow_inheritance_future_warning(self):
"""Add FutureWarning for future allow_inhertiance default change.
"""
class SimpleBase(Document):
a = IntField()
class InheritedClass(SimpleBase):
b = IntField()
InheritedClass()
self.assertEqual(len(self.warning_list), 1)
warning = self.warning_list[0]
self.assertEqual(FutureWarning, warning["category"])
self.assertTrue("InheritedClass" in str(warning["message"]))
def test_dbref_reference_field_future_warning(self):
class Person(Document):
@ -93,7 +77,7 @@ class AllWarnings(unittest.TestCase):
def test_document_collection_syntax_warning(self):
class NonAbstractBase(Document):
pass
meta = {'allow_inheritance': True}
class InheritedDocumentFailTest(NonAbstractBase):
meta = {'collection': 'fail'}

View File

@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import *
@ -126,9 +128,6 @@ class DeltaTest(unittest.TestCase):
'list_field': ['1', 2, {'hello': 'world'}]
}
self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {}))
embedded_delta.update({
'_cls': 'Embedded',
})
self.assertEqual(doc._delta(),
({'embedded_field': embedded_delta}, {}))
@ -162,6 +161,7 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field = ['1', 2, embedded_2]
self.assertEqual(doc._get_changed_fields(),
['embedded_field.list_field'])
self.assertEqual(doc.embedded_field._delta(), ({
'list_field': ['1', 2, {
'_cls': 'Embedded',
@ -175,10 +175,10 @@ class DeltaTest(unittest.TestCase):
self.assertEqual(doc._delta(), ({
'embedded_field.list_field': ['1', 2, {
'_cls': 'Embedded',
'string_field': 'hello',
'dict_field': {'hello': 'world'},
'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}],
'string_field': 'hello',
'dict_field': {'hello': 'world'},
'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}],
}]
}, {}))
doc.save()
@ -467,9 +467,6 @@ class DeltaTest(unittest.TestCase):
'db_list_field': ['1', 2, {'hello': 'world'}]
}
self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {}))
embedded_delta.update({
'_cls': 'Embedded',
})
self.assertEqual(doc._delta(),
({'db_embedded_field': embedded_delta}, {}))
@ -520,10 +517,10 @@ class DeltaTest(unittest.TestCase):
self.assertEqual(doc._delta(), ({
'db_embedded_field.db_list_field': ['1', 2, {
'_cls': 'Embedded',
'db_string_field': 'hello',
'db_dict_field': {'hello': 'world'},
'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}],
'db_string_field': 'hello',
'db_dict_field': {'hello': 'world'},
'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}],
}]
}, {}))
doc.save()
@ -686,3 +683,7 @@ class DeltaTest(unittest.TestCase):
doc.list_field = []
self.assertEqual(doc._get_changed_fields(), ['list_field'])
self.assertEqual(doc._delta(), ({}, {'list_field': 1}))
if __name__ == '__main__':
unittest.main()

View File

@ -1,4 +1,7 @@
import unittest
import sys
sys.path[0:0] = [""]
from mongoengine import *
from mongoengine.connection import get_db
@ -161,7 +164,7 @@ class DynamicTest(unittest.TestCase):
embedded_1.list_field = ['1', 2, {'hello': 'world'}]
doc.embedded_field = embedded_1
self.assertEqual(doc.to_mongo(), {"_cls": "Doc",
self.assertEqual(doc.to_mongo(), {
"embedded_field": {
"_cls": "Embedded",
"string_field": "hello",
@ -205,7 +208,7 @@ class DynamicTest(unittest.TestCase):
embedded_1.list_field = ['1', 2, embedded_2]
doc.embedded_field = embedded_1
self.assertEqual(doc.to_mongo(), {"_cls": "Doc",
self.assertEqual(doc.to_mongo(), {
"embedded_field": {
"_cls": "Embedded",
"string_field": "hello",
@ -246,7 +249,6 @@ class DynamicTest(unittest.TestCase):
class Person(DynamicDocument):
name = StringField()
meta = {'allow_inheritance': True}
Person.drop_collection()
@ -268,3 +270,7 @@ class DynamicTest(unittest.TestCase):
person.age = 35
person.save()
self.assertEqual(Person.objects.first().age, 35)
if __name__ == '__main__':
unittest.main()

View File

@ -203,7 +203,6 @@ class InheritanceTest(unittest.TestCase):
class Animal(Document):
name = StringField()
meta = {'allow_inheritance': False}
def create_dog_class():
class Dog(Animal):
@ -258,7 +257,6 @@ class InheritanceTest(unittest.TestCase):
class Comment(EmbeddedDocument):
content = StringField()
meta = {'allow_inheritance': False}
def create_special_comment():
class SpecialComment(Comment):

View File

@ -1,24 +1,22 @@
# -*- coding: utf-8 -*-
from __future__ import with_statement
import sys
sys.path[0:0] = [""]
import bson
import os
import pickle
import pymongo
import sys
import unittest
import uuid
import warnings
from nose.plugins.skip import SkipTest
from datetime import datetime
from tests.fixtures import Base, Mixin, PickleEmbedded, PickleTest
from tests.fixtures import PickleEmbedded, PickleTest
from mongoengine import *
from mongoengine.errors import (NotRegistered, InvalidDocumentError,
InvalidQueryError)
from mongoengine.queryset import NULLIFY, Q
from mongoengine.connection import get_db, get_connection
from mongoengine.connection import get_db
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
@ -461,7 +459,7 @@ class InstanceTest(unittest.TestCase):
doc.validate()
keys = doc._data.keys()
self.assertEqual(2, len(keys))
self.assertTrue(None in keys)
self.assertTrue('id' in keys)
self.assertTrue('e' in keys)
def test_save(self):
@ -656,8 +654,8 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p.parent.name)
def test_update(self):
"""Ensure that an existing document is updated instead of be overwritten.
"""
"""Ensure that an existing document is updated instead of be
overwritten."""
# Create person object and save it to the database
person = self.Person(name='Test User', age=30)
person.save()
@ -753,30 +751,33 @@ class InstanceTest(unittest.TestCase):
float_field = FloatField(default=1.1)
boolean_field = BooleanField(default=True)
datetime_field = DateTimeField(default=datetime.now)
embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, default=lambda: EmbeddedDoc())
embedded_document_field = EmbeddedDocumentField(EmbeddedDoc,
default=lambda: EmbeddedDoc())
list_field = ListField(default=lambda: [1, 2, 3])
dict_field = DictField(default=lambda: {"hello": "world"})
objectid_field = ObjectIdField(default=bson.ObjectId)
reference_field = ReferenceField(Simple, default=lambda: Simple().save())
reference_field = ReferenceField(Simple, default=lambda:
Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.now)
url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1)
generic_reference_field = GenericReferenceField(default=lambda: Simple().save())
sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3])
generic_reference_field = GenericReferenceField(
default=lambda: Simple().save())
sorted_list_field = SortedListField(IntField(),
default=lambda: [1, 2, 3])
email_field = EmailField(default="ross@example.com")
geo_point_field = GeoPointField(default=lambda: [1, 2])
sequence_field = SequenceField()
uuid_field = UUIDField(default=uuid.uuid4)
generic_embedded_document_field = GenericEmbeddedDocumentField(default=lambda: EmbeddedDoc())
generic_embedded_document_field = GenericEmbeddedDocumentField(
default=lambda: EmbeddedDoc())
Simple.drop_collection()
Doc.drop_collection()
Doc().save()
my_doc = Doc.objects.only("string_field").first()
my_doc.string_field = "string"
my_doc.save()
@ -1707,9 +1708,12 @@ class InstanceTest(unittest.TestCase):
peter = User.objects.create(name="Peter")
# Bob
Book.objects.create(name="1", author=bob, extra={"a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]})
Book.objects.create(name="2", author=bob, extra={"a": bob.to_dbref(), "b": karl.to_dbref()})
Book.objects.create(name="3", author=bob, extra={"a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]})
Book.objects.create(name="1", author=bob, extra={
"a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]})
Book.objects.create(name="2", author=bob, extra={
"a": bob.to_dbref(), "b": karl.to_dbref()})
Book.objects.create(name="3", author=bob, extra={
"a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]})
Book.objects.create(name="4", author=bob)
# Jon
@ -1717,23 +1721,26 @@ class InstanceTest(unittest.TestCase):
Book.objects.create(name="6", author=peter)
Book.objects.create(name="7", author=jon)
Book.objects.create(name="8", author=jon)
Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()})
Book.objects.create(name="9", author=jon,
extra={"a": peter.to_dbref()})
# Checks
self.assertEqual(u",".join([str(b) for b in Book.objects.all()]) , "1,2,3,4,5,6,7,8,9")
self.assertEqual(",".join([str(b) for b in Book.objects.all()]),
"1,2,3,4,5,6,7,8,9")
# bob related books
self.assertEqual(u",".join([str(b) for b in Book.objects.filter(
self.assertEqual(",".join([str(b) for b in Book.objects.filter(
Q(extra__a=bob) |
Q(author=bob) |
Q(extra__b=bob))]) ,
Q(extra__b=bob))]),
"1,2,3,4")
# Susan & Karl related books
self.assertEqual(u",".join([str(b) for b in Book.objects.filter(
self.assertEqual(",".join([str(b) for b in Book.objects.filter(
Q(extra__a__all=[karl, susan]) |
Q(author__all=[karl, susan ]) |
Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()])
) ]) , "1")
Q(author__all=[karl, susan]) |
Q(extra__b__all=[
karl.to_dbref(), susan.to_dbref()]))
]), "1")
# $Where
self.assertEqual(u",".join([str(b) for b in Book.objects.filter(
@ -1743,7 +1750,7 @@ class InstanceTest(unittest.TestCase):
return this.name == '1' ||
this.name == '2';}"""
}
) ]), "1,2")
)]), "1,2")
class ValidatorErrorTest(unittest.TestCase):

View File

@ -331,14 +331,10 @@ class FieldTest(unittest.TestCase):
return "<Person: %s>" % self.name
Person.drop_collection()
paul = Person(name="Paul")
paul.save()
maria = Person(name="Maria")
maria.save()
julia = Person(name='Julia')
julia.save()
anna = Person(name='Anna')
anna.save()
paul = Person(name="Paul").save()
maria = Person(name="Maria").save()
julia = Person(name='Julia').save()
anna = Person(name='Anna').save()
paul.other.friends = [maria, julia, anna]
paul.other.name = "Paul's friends"

View File

@ -727,7 +727,7 @@ class FieldTest(unittest.TestCase):
"""Ensure that the list fields can handle the complex types."""
class SettingBase(EmbeddedDocument):
pass
meta = {'allow_inheritance': True}
class StringSetting(SettingBase):
value = StringField()
@ -743,8 +743,9 @@ class FieldTest(unittest.TestCase):
e.mapping.append(StringSetting(value='foo'))
e.mapping.append(IntegerSetting(value=42))
e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001,
'complex': IntegerSetting(value=42), 'list':
[IntegerSetting(value=42), StringSetting(value='foo')]})
'complex': IntegerSetting(value=42),
'list': [IntegerSetting(value=42),
StringSetting(value='foo')]})
e.save()
e2 = Simple.objects.get(id=e.id)
@ -844,7 +845,7 @@ class FieldTest(unittest.TestCase):
"""Ensure that the dict field can handle the complex types."""
class SettingBase(EmbeddedDocument):
pass
meta = {'allow_inheritance': True}
class StringSetting(SettingBase):
value = StringField()
@ -859,9 +860,11 @@ class FieldTest(unittest.TestCase):
e = Simple()
e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42)
e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001,
'complex': IntegerSetting(value=42), 'list':
[IntegerSetting(value=42), StringSetting(value='foo')]}
e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!',
'float': 1.001,
'complex': IntegerSetting(value=42),
'list': [IntegerSetting(value=42),
StringSetting(value='foo')]}
e.save()
e2 = Simple.objects.get(id=e.id)
@ -915,7 +918,7 @@ class FieldTest(unittest.TestCase):
"""Ensure that the MapField can handle complex declared types."""
class SettingBase(EmbeddedDocument):
pass
meta = {"allow_inheritance": True}
class StringSetting(SettingBase):
value = StringField()
@ -951,7 +954,8 @@ class FieldTest(unittest.TestCase):
number = IntField(default=0, db_field='i')
class Test(Document):
my_map = MapField(field=EmbeddedDocumentField(Embedded), db_field='x')
my_map = MapField(field=EmbeddedDocumentField(Embedded),
db_field='x')
Test.drop_collection()
@ -1038,6 +1042,8 @@ class FieldTest(unittest.TestCase):
class User(EmbeddedDocument):
name = StringField()
meta = {'allow_inheritance': True}
class PowerUser(User):
power = IntField()
@ -1046,8 +1052,10 @@ class FieldTest(unittest.TestCase):
author = EmbeddedDocumentField(User)
post = BlogPost(content='What I did today...')
post.author = User(name='Test User')
post.author = PowerUser(name='Test User', power=47)
post.save()
self.assertEqual(47, BlogPost.objects.first().author.power)
def test_reference_validation(self):
"""Ensure that invalid docment objects cannot be assigned to reference
@ -2117,12 +2125,12 @@ class FieldTest(unittest.TestCase):
def test_sequence_fields_reload(self):
class Animal(Document):
counter = SequenceField()
type = StringField()
name = StringField()
self.db['mongoengine.counters'].drop()
Animal.drop_collection()
a = Animal(type="Boi")
a = Animal(name="Boi")
a.save()
self.assertEqual(a.counter, 1)

View File

@ -647,7 +647,8 @@ class QuerySetTest(unittest.TestCase):
self.assertRaises(NotUniqueError, throw_operation_error_not_unique)
self.assertEqual(Blog.objects.count(), 2)
Blog.objects.insert([blog2, blog3], write_options={'continue_on_error': True})
Blog.objects.insert([blog2, blog3], write_options={
'continue_on_error': True})
self.assertEqual(Blog.objects.count(), 3)
def test_get_changed_fields_query_count(self):
@ -673,7 +674,7 @@ class QuerySetTest(unittest.TestCase):
r2 = Project(name="r2").save()
r3 = Project(name="r3").save()
p1 = Person(name="p1", projects=[r1, r2]).save()
p2 = Person(name="p2", projects=[r2]).save()
p2 = Person(name="p2", projects=[r2, r3]).save()
o1 = Organization(name="o1", employees=[p1]).save()
with query_counter() as q:
@ -688,24 +689,24 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(q, 0)
fresh_o1 = Organization.objects.get(id=o1.id)
fresh_o1.save()
fresh_o1.save() # No changes, does nothing
self.assertEqual(q, 2)
self.assertEqual(q, 1)
with query_counter() as q:
self.assertEqual(q, 0)
fresh_o1 = Organization.objects.get(id=o1.id)
fresh_o1.save(cascade=False)
fresh_o1.save(cascade=False) # No changes, does nothing
self.assertEqual(q, 2)
self.assertEqual(q, 1)
with query_counter() as q:
self.assertEqual(q, 0)
fresh_o1 = Organization.objects.get(id=o1.id)
fresh_o1.employees.append(p2)
fresh_o1.save(cascade=False)
fresh_o1.employees.append(p2) # Dereferences
fresh_o1.save(cascade=False) # Saves
self.assertEqual(q, 3)