Inheritance is off by default (MongoEngine/mongoengine#122)

This commit is contained in:
Ross Lawley 2012-10-17 11:36:18 +00:00
parent 6f29d12386
commit 3d5b6ae332
20 changed files with 245 additions and 177 deletions

View File

@ -4,6 +4,7 @@ Changelog
Changes in 0.8 Changes in 0.8
============== ==============
- Inheritance is off by default (MongoEngine/mongoengine#122)
- Remove _types and just use _cls for inheritance (MongoEngine/mongoengine#148) - Remove _types and just use _cls for inheritance (MongoEngine/mongoengine#148)

View File

@ -462,9 +462,10 @@ If a dictionary is passed then the following options are available:
The fields to index. Specified in the same format as described above. The fields to index. Specified in the same format as described above.
:attr:`cls` (Default: True) :attr:`cls` (Default: True)
If you have polymorphic models that inherit and have `allow_inheritance` If you have polymorphic models that inherit and have
turned on, you can configure whether the index should have the :attr:`allow_inheritance` turned on, you can configure whether the index
:attr:`_cls` field added automatically to the start of the index. should have the :attr:`_cls` field added automatically to the start of the
index.
:attr:`sparse` (Default: False) :attr:`sparse` (Default: False)
Whether the index should be sparse. Whether the index should be sparse.
@ -573,7 +574,9 @@ defined, you may subclass it and add any extra fields or methods you may need.
As this is new class is not a direct subclass of As this is new class is not a direct subclass of
:class:`~mongoengine.Document`, it will not be stored in its own collection; it :class:`~mongoengine.Document`, it will not be stored in its own collection; it
will use the same collection as its superclass uses. This allows for more will use the same collection as its superclass uses. This allows for more
convenient and efficient retrieval of related documents:: convenient and efficient retrieval of related documents - all you need do is
set :attr:`allow_inheritance` to True in the :attr:`meta` data for a
document.::
# Stored in a collection named 'page' # Stored in a collection named 'page'
class Page(Document): class Page(Document):
@ -585,25 +588,20 @@ convenient and efficient retrieval of related documents::
class DatedPage(Page): class DatedPage(Page):
date = DateTimeField() date = DateTimeField()
.. note:: From 0.7 onwards you must declare `allow_inheritance` in the document meta. .. note:: From 0.8 onwards you must declare :attr:`allow_inheritance` defaults
to False, meaning you must set it to True to use inheritance.
Working with existing data Working with existing data
-------------------------- --------------------------
To enable correct retrieval of documents involved in this kind of heirarchy, As MongoEngine no longer defaults to needing :attr:`_cls` you can quickly and
an extra attribute is stored on each document in the database: :attr:`_cls`. easily get working with existing data. Just define the document to match
These are hidden from the user through the MongoEngine interface, but may not the expected schema in your database. If you have wildly varying schemas then
be present if you are trying to use MongoEngine with an existing database. a :class:`~mongoengine.DynamicDocument` might be more appropriate.
For this reason, you may disable this inheritance mechansim, removing the
dependency of :attr:`_cls`, enabling you to work with existing databases.
To disable inheritance on a document class, set :attr:`allow_inheritance` to
``False`` in the :attr:`meta` dictionary::
# Will work with data in an existing collection named 'cmsPage' # Will work with data in an existing collection named 'cmsPage'
class Page(Document): class Page(Document):
title = StringField(max_length=200, required=True) title = StringField(max_length=200, required=True)
meta = { meta = {
'collection': 'cmsPage', 'collection': 'cmsPage'
'allow_inheritance': False,
} }

View File

@ -84,12 +84,15 @@ using* the new fields we need to support video posts. This fits with the
Object-Oriented principle of *inheritance* nicely. We can think of Object-Oriented principle of *inheritance* nicely. We can think of
:class:`Post` as a base class, and :class:`TextPost`, :class:`ImagePost` and :class:`Post` as a base class, and :class:`TextPost`, :class:`ImagePost` and
:class:`LinkPost` as subclasses of :class:`Post`. In fact, MongoEngine supports :class:`LinkPost` as subclasses of :class:`Post`. In fact, MongoEngine supports
this kind of modelling out of the box:: this kind of modelling out of the box - all you need do is turn on inheritance
by setting :attr:`allow_inheritance` to True in the :attr:`meta`::
class Post(Document): class Post(Document):
title = StringField(max_length=120, required=True) title = StringField(max_length=120, required=True)
author = ReferenceField(User) author = ReferenceField(User)
meta = {'allow_inheritance': True}
class TextPost(Post): class TextPost(Post):
content = StringField() content = StringField()

View File

@ -8,10 +8,13 @@ Upgrading
Inheritance Inheritance
----------- -----------
Data Model
~~~~~~~~~~
The inheritance model has changed, we no longer need to store an array of The inheritance model has changed, we no longer need to store an array of
`types` with the model we can just use the classname in `_cls`. This means :attr:`types` with the model we can just use the classname in :attr:`_cls`.
that you will have to update your indexes for each of your inherited classes This means that you will have to update your indexes for each of your
like so: inherited classes like so:
# 1. Declaration of the class # 1. Declaration of the class
class Animal(Document): class Animal(Document):
@ -40,6 +43,19 @@ like so:
Animal.objects._ensure_indexes() Animal.objects._ensure_indexes()
Document Definition
~~~~~~~~~~~~~~~~~~~
The default for inheritance has changed - its now off by default and
:attr:`_cls` will not be stored automatically with the class. So if you extend
your :class:`~mongoengine.Document` or :class:`~mongoengine.EmbeddedDocuments`
you will need to declare :attr:`allow_inheritance` in the meta data like so:
class Animal(Document):
name = StringField()
meta = {'allow_inheritance': True}
0.6 to 0.7 0.6 to 0.7
========== ==========
@ -123,7 +139,7 @@ Document.objects.with_id - now raises an InvalidQueryError if used with a
filter. filter.
FutureWarning - A future warning has been added to all inherited classes that FutureWarning - A future warning has been added to all inherited classes that
don't define `allow_inheritance` in their meta. don't define :attr:`allow_inheritance` in their meta.
You may need to update pyMongo to 2.0 for use with Sharding. You may need to update pyMongo to 2.0 for use with Sharding.

View File

@ -2,7 +2,7 @@ from mongoengine.errors import NotRegistered
__all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry') __all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry')
ALLOW_INHERITANCE = True ALLOW_INHERITANCE = False
_document_registry = {} _document_registry = {}

View File

@ -50,7 +50,6 @@ class BaseDocument(object):
for key, value in values.iteritems(): for key, value in values.iteritems():
key = self._reverse_db_field_map.get(key, key) key = self._reverse_db_field_map.get(key, key)
setattr(self, key, value) setattr(self, key, value)
# Set any get_fieldname_display methods # Set any get_fieldname_display methods
self.__set_field_display() self.__set_field_display()
@ -83,6 +82,11 @@ class BaseDocument(object):
if hasattr(self, '_changed_fields'): if hasattr(self, '_changed_fields'):
self._mark_as_changed(name) self._mark_as_changed(name)
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
if (self._is_document and not self._created and if (self._is_document and not self._created and
name in self._meta.get('shard_key', tuple()) and name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value): self._data.get(name) != value):
@ -171,14 +175,24 @@ class BaseDocument(object):
"""Return data dictionary ready for use with MongoDB. """Return data dictionary ready for use with MongoDB.
""" """
data = {} data = {}
for field_name, field in self._fields.items(): for field_name, field in self._fields.iteritems():
value = getattr(self, field_name, None) value = self._data.get(field_name, None)
if value is not None: if value is not None:
data[field.db_field] = field.to_mongo(value) value = field.to_mongo(value)
# Only add _cls if allow_inheritance is not False
if not (hasattr(self, '_meta') and # Handle self generating fields
self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == False): if value is None and field._auto_gen:
value = field.generate()
self._data[field_name] = value
if value is not None:
data[field.db_field] = value
# Only add _cls if allow_inheritance is True
if (hasattr(self, '_meta') and
self._meta.get('allow_inheritance', ALLOW_INHERITANCE) == True):
data['_cls'] = self._class_name data['_cls'] = self._class_name
if '_id' in data and data['_id'] is None: if '_id' in data and data['_id'] is None:
del data['_id'] del data['_id']
@ -194,7 +208,7 @@ class BaseDocument(object):
are present. are present.
""" """
# Get a list of tuples of field names and their current values # Get a list of tuples of field names and their current values
fields = [(field, getattr(self, name)) fields = [(field, self._data.get(name))
for name, field in self._fields.items()] for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value # Ensure that each field is matched to a valid value
@ -207,7 +221,7 @@ class BaseDocument(object):
errors[field.name] = error.errors or error errors[field.name] = error.errors or error
except (ValueError, AttributeError, AssertionError), error: except (ValueError, AttributeError, AssertionError), error:
errors[field.name] = error errors[field.name] = error
elif field.required: elif field.required and not getattr(field, '_auto_gen', False):
errors[field.name] = ValidationError('Field is required', errors[field.name] = ValidationError('Field is required',
field_name=field.name) field_name=field.name)
if errors: if errors:
@ -313,6 +327,7 @@ class BaseDocument(object):
""" """
# Handles cases where not loaded from_son but has _id # Handles cases where not loaded from_son but has _id
doc = self.to_mongo() doc = self.to_mongo()
set_fields = self._get_changed_fields() set_fields = self._get_changed_fields()
set_data = {} set_data = {}
unset_data = {} unset_data = {}
@ -370,7 +385,6 @@ class BaseDocument(object):
if hasattr(d, '_fields'): if hasattr(d, '_fields'):
field_name = d._reverse_db_field_map.get(db_field_name, field_name = d._reverse_db_field_map.get(db_field_name,
db_field_name) db_field_name)
if field_name in d._fields: if field_name in d._fields:
default = d._fields.get(field_name).default default = d._fields.get(field_name).default
else: else:
@ -379,6 +393,7 @@ class BaseDocument(object):
if default is not None: if default is not None:
if callable(default): if callable(default):
default = default() default = default()
if default != value: if default != value:
continue continue
@ -399,15 +414,12 @@ class BaseDocument(object):
# get the class name from the document, falling back to the given # get the class name from the document, falling back to the given
# class if unavailable # class if unavailable
class_name = son.get('_cls', cls._class_name) class_name = son.get('_cls', cls._class_name)
data = dict(("%s" % key, value) for key, value in son.items()) data = dict(("%s" % key, value) for key, value in son.iteritems())
if not UNICODE_KWARGS: if not UNICODE_KWARGS:
# python 2.6.4 and lower cannot handle unicode keys # python 2.6.4 and lower cannot handle unicode keys
# passed to class constructor example: cls(**data) # passed to class constructor example: cls(**data)
to_str_keys_recursive(data) to_str_keys_recursive(data)
if '_cls' in data:
del data['_cls']
# Return correct subclass for document type # Return correct subclass for document type
if class_name != cls._class_name: if class_name != cls._class_name:
cls = get_document(class_name) cls = get_document(class_name)
@ -415,7 +427,7 @@ class BaseDocument(object):
changed_fields = [] changed_fields = []
errors_dict = {} errors_dict = {}
for field_name, field in cls._fields.items(): for field_name, field in cls._fields.iteritems():
if field.db_field in data: if field.db_field in data:
value = data[field.db_field] value = data[field.db_field]
try: try:

View File

@ -21,6 +21,7 @@ class BaseField(object):
name = None name = None
_geo_index = False _geo_index = False
_auto_gen = False # Call `generate` to generate a value
# These track each time a Field instance is created. Used to retain order. # These track each time a Field instance is created. Used to retain order.
# The auto_creation_counter is used for fields that MongoEngine implicitly # The auto_creation_counter is used for fields that MongoEngine implicitly
@ -36,7 +37,6 @@ class BaseField(object):
if name: if name:
msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" msg = "Fields' 'name' attribute deprecated in favour of 'db_field'"
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
self.name = None
self.required = required or primary_key self.required = required or primary_key
self.default = default self.default = default
self.unique = bool(unique or unique_with) self.unique = bool(unique or unique_with)
@ -62,7 +62,6 @@ class BaseField(object):
if instance is None: if instance is None:
# Document class being used rather than a document object # Document class being used rather than a document object
return self return self
# Get value from document instance if available, if not use default # Get value from document instance if available, if not use default
value = instance._data.get(self.name) value = instance._data.get(self.name)
@ -241,12 +240,21 @@ class ComplexBaseField(BaseField):
"""Convert a Python type to a MongoDB-compatible type. """Convert a Python type to a MongoDB-compatible type.
""" """
Document = _import_class("Document") Document = _import_class("Document")
EmbeddedDocument = _import_class("EmbeddedDocument")
GenericReferenceField = _import_class("GenericReferenceField")
if isinstance(value, basestring): if isinstance(value, basestring):
return value return value
if hasattr(value, 'to_mongo'): if hasattr(value, 'to_mongo'):
return value.to_mongo() if isinstance(value, Document):
return GenericReferenceField().to_mongo(value)
cls = value.__class__
val = value.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(value, EmbeddedDocument)):
val['_cls'] = cls.__name__
return val
is_list = False is_list = False
if not hasattr(value, 'items'): if not hasattr(value, 'items'):
@ -258,10 +266,10 @@ class ComplexBaseField(BaseField):
if self.field: if self.field:
value_dict = dict([(key, self.field.to_mongo(item)) value_dict = dict([(key, self.field.to_mongo(item))
for key, item in value.items()]) for key, item in value.iteritems()])
else: else:
value_dict = {} value_dict = {}
for k, v in value.items(): for k, v in value.iteritems():
if isinstance(v, Document): if isinstance(v, Document):
# We need the id from the saved object to create the DBRef # We need the id from the saved object to create the DBRef
if v.pk is None: if v.pk is None:
@ -274,16 +282,19 @@ class ComplexBaseField(BaseField):
meta = getattr(v, '_meta', {}) meta = getattr(v, '_meta', {})
allow_inheritance = ( allow_inheritance = (
meta.get('allow_inheritance', ALLOW_INHERITANCE) meta.get('allow_inheritance', ALLOW_INHERITANCE)
== False) == True)
if allow_inheritance and not self.field: if not allow_inheritance and not self.field:
GenericReferenceField = _import_class(
"GenericReferenceField")
value_dict[k] = GenericReferenceField().to_mongo(v) value_dict[k] = GenericReferenceField().to_mongo(v)
else: else:
collection = v._get_collection_name() collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk) value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'): elif hasattr(v, 'to_mongo'):
value_dict[k] = v.to_mongo() cls = v.__class__
val = v.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(v, (Document, EmbeddedDocument))):
val['_cls'] = cls.__name__
value_dict[k] = val
else: else:
value_dict[k] = self.to_mongo(v) value_dict[k] = self.to_mongo(v)

View File

@ -34,6 +34,17 @@ class DocumentMetaclass(type):
if 'meta' in attrs: if 'meta' in attrs:
attrs['_meta'] = attrs.pop('meta') attrs['_meta'] = attrs.pop('meta')
# EmbeddedDocuments should inherit meta data
if '_meta' not in attrs:
meta = MetaDict()
for base in flattened_bases[::-1]:
# Add any mixin metadata from plain objects
if hasattr(base, 'meta'):
meta.merge(base.meta)
elif hasattr(base, '_meta'):
meta.merge(base._meta)
attrs['_meta'] = meta
# Handle document Fields # Handle document Fields
# Merge all fields from subclasses # Merge all fields from subclasses
@ -52,6 +63,7 @@ class DocumentMetaclass(type):
if not attr_value.db_field: if not attr_value.db_field:
attr_value.db_field = attr_name attr_value.db_field = attr_name
base_fields[attr_name] = attr_value base_fields[attr_name] = attr_value
doc_fields.update(base_fields) doc_fields.update(base_fields)
# Discover any document fields # Discover any document fields
@ -98,15 +110,7 @@ class DocumentMetaclass(type):
# inheritance of classes where inheritance is set to False # inheritance of classes where inheritance is set to False
allow_inheritance = base._meta.get('allow_inheritance', allow_inheritance = base._meta.get('allow_inheritance',
ALLOW_INHERITANCE) ALLOW_INHERITANCE)
if (not getattr(base, '_is_base_cls', True) if (allow_inheritance != True and
and allow_inheritance is None):
warnings.warn(
"%s uses inheritance, the default for "
"allow_inheritance is changing to off by default. "
"Please add it to the document meta." % name,
FutureWarning
)
elif (allow_inheritance == False and
not base._meta.get('abstract')): not base._meta.get('abstract')):
raise ValueError('Document %s may not be subclassed' % raise ValueError('Document %s may not be subclassed' %
base.__name__) base.__name__)
@ -353,6 +357,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
if not new_class._meta.get('id_field'): if not new_class._meta.get('id_field'):
new_class._meta['id_field'] = 'id' new_class._meta['id_field'] = 'id'
new_class._fields['id'] = ObjectIdField(db_field='_id') new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class._fields['id'].name = 'id'
new_class.id = new_class._fields['id'] new_class.id = new_class._fields['id']
# Merge in exceptions with parent hierarchy # Merge in exceptions with parent hierarchy

View File

@ -121,7 +121,10 @@ class DeReference(object):
for key, doc in references.iteritems(): for key, doc in references.iteritems():
object_map[key] = doc object_map[key] = doc
else: # Generic reference: use the refs data to convert to document else: # Generic reference: use the refs data to convert to document
if doc_type and not isinstance(doc_type, (ListField, DictField, MapField,) ): if isinstance(doc_type, (ListField, DictField, MapField,)):
continue
if doc_type:
references = doc_type._get_db()[col].find({'_id': {'$in': refs}}) references = doc_type._get_db()[col].find({'_id': {'$in': refs}})
for ref in references: for ref in references:
doc = doc_type._from_son(ref) doc = doc_type._from_son(ref)

View File

@ -117,6 +117,7 @@ class Document(BaseDocument):
""" """
def fget(self): def fget(self):
return getattr(self, self._meta['id_field']) return getattr(self, self._meta['id_field'])
def fset(self, value): def fset(self, value):
return setattr(self, self._meta['id_field'], value) return setattr(self, self._meta['id_field'], value)
return property(fget, fset) return property(fget, fset)
@ -125,7 +126,7 @@ class Document(BaseDocument):
@classmethod @classmethod
def _get_db(cls): def _get_db(cls):
"""Some Model using other db_alias""" """Some Model using other db_alias"""
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME )) return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
@classmethod @classmethod
def _get_collection(cls): def _get_collection(cls):
@ -212,11 +213,11 @@ class Document(BaseDocument):
doc = self.to_mongo() doc = self.to_mongo()
created = force_insert or '_id' not in doc find_delta = ('_id' not in doc or self._created or force_insert)
try: try:
collection = self.__class__.objects._collection collection = self.__class__.objects._collection
if created: if find_delta:
if force_insert: if force_insert:
object_id = collection.insert(doc, safe=safe, object_id = collection.insert(doc, safe=safe,
**write_options) **write_options)
@ -271,7 +272,8 @@ class Document(BaseDocument):
self._changed_fields = [] self._changed_fields = []
self._created = False self._created = False
signals.post_save.send(self.__class__, document=self, created=created) signals.post_save.send(self.__class__, document=self,
created=find_delta)
return self return self
def cascade_save(self, warn_cascade=None, *args, **kwargs): def cascade_save(self, warn_cascade=None, *args, **kwargs):
@ -373,6 +375,7 @@ class Document(BaseDocument):
for name in self._dynamic_fields.keys(): for name in self._dynamic_fields.keys():
setattr(self, name, self._reload(name, obj._data[name])) setattr(self, name, self._reload(name, obj._data[name]))
self._changed_fields = obj._changed_fields self._changed_fields = obj._changed_fields
self._created = False
return obj return obj
def _reload(self, key, value): def _reload(self, key, value):
@ -464,7 +467,13 @@ class DynamicEmbeddedDocument(EmbeddedDocument):
"""Deletes the attribute by setting to None and allowing _delta to unset """Deletes the attribute by setting to None and allowing _delta to unset
it""" it"""
field_name = args[0] field_name = args[0]
setattr(self, field_name, None) if field_name in self._fields:
default = self._fields[field_name].default
if callable(default):
default = default()
setattr(self, field_name, default)
else:
setattr(self, field_name, None)
class MapReduceDocument(object): class MapReduceDocument(object):

View File

@ -16,12 +16,11 @@ from mongoengine.errors import ValidationError
from mongoengine.python_support import (PY3, bin_type, txt_type, from mongoengine.python_support import (PY3, bin_type, txt_type,
str_types, StringIO) str_types, StringIO)
from base import (BaseField, ComplexBaseField, ObjectIdField, from base import (BaseField, ComplexBaseField, ObjectIdField,
get_document, BaseDocument) get_document, BaseDocument, ALLOW_INHERITANCE)
from queryset import DO_NOTHING, QuerySet from queryset import DO_NOTHING, QuerySet
from document import Document, EmbeddedDocument from document import Document, EmbeddedDocument
from connection import get_db, DEFAULT_CONNECTION_NAME from connection import get_db, DEFAULT_CONNECTION_NAME
try: try:
from PIL import Image, ImageOps from PIL import Image, ImageOps
except ImportError: except ImportError:
@ -314,16 +313,16 @@ class DateTimeField(BaseField):
usecs = 0 usecs = 0
kwargs = {'microsecond': usecs} kwargs = {'microsecond': usecs}
try: # Seconds are optional, so try converting seconds first. try: # Seconds are optional, so try converting seconds first.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], return datetime.datetime(*time.strptime(value,
**kwargs) '%Y-%m-%d %H:%M:%S')[:6], **kwargs)
except ValueError: except ValueError:
try: # Try without seconds. try: # Try without seconds.
return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], return datetime.datetime(*time.strptime(value,
**kwargs) '%Y-%m-%d %H:%M')[:5], **kwargs)
except ValueError: # Try without hour/minutes/seconds. except ValueError: # Try without hour/minutes/seconds.
try: try:
return datetime.datetime(*time.strptime(value, '%Y-%m-%d')[:3], return datetime.datetime(*time.strptime(value,
**kwargs) '%Y-%m-%d')[:3], **kwargs)
except ValueError: except ValueError:
return None return None
@ -410,6 +409,7 @@ class ComplexDateTimeField(StringField):
return super(ComplexDateTimeField, self).__set__(instance, value) return super(ComplexDateTimeField, self).__set__(instance, value)
def validate(self, value): def validate(self, value):
value = self.to_python(value)
if not isinstance(value, datetime.datetime): if not isinstance(value, datetime.datetime):
self.error('Only datetime objects may used in a ' self.error('Only datetime objects may used in a '
'ComplexDateTimeField') 'ComplexDateTimeField')
@ -422,6 +422,7 @@ class ComplexDateTimeField(StringField):
return original_value return original_value
def to_mongo(self, value): def to_mongo(self, value):
value = self.to_python(value)
return self._convert_from_datetime(value) return self._convert_from_datetime(value)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
@ -529,7 +530,12 @@ class DynamicField(BaseField):
return value return value
if hasattr(value, 'to_mongo'): if hasattr(value, 'to_mongo'):
return value.to_mongo() cls = value.__class__
val = value.to_mongo()
# If we its a document thats not inherited add _cls
if (isinstance(value, (Document, EmbeddedDocument))):
val['_cls'] = cls.__name__
return val
if not isinstance(value, (dict, list, tuple)): if not isinstance(value, (dict, list, tuple)):
return value return value
@ -540,13 +546,12 @@ class DynamicField(BaseField):
value = dict([(k, v) for k, v in enumerate(value)]) value = dict([(k, v) for k, v in enumerate(value)])
data = {} data = {}
for k, v in value.items(): for k, v in value.iteritems():
data[k] = self.to_mongo(v) data[k] = self.to_mongo(v)
value = data
if is_list: # Convert back to a list if is_list: # Convert back to a list
value = [v for k, v in sorted(data.items(), key=itemgetter(0))] value = [v for k, v in sorted(data.iteritems(), key=itemgetter(0))]
else:
value = data
return value return value
def lookup_member(self, member_name): def lookup_member(self, member_name):
@ -666,7 +671,6 @@ class DictField(ComplexBaseField):
if op in match_operators and isinstance(value, basestring): if op in match_operators and isinstance(value, basestring):
return StringField().prepare_query_value(op, value) return StringField().prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value) return super(DictField, self).prepare_query_value(op, value)
@ -1323,7 +1327,8 @@ class GeoPointField(BaseField):
class SequenceField(IntField): class SequenceField(IntField):
"""Provides a sequental counter (see http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers) """Provides a sequental counter see:
http://www.mongodb.org/display/DOCS/Object+IDs#ObjectIDs-SequenceNumbers
.. note:: .. note::
@ -1335,17 +1340,21 @@ class SequenceField(IntField):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
def __init__(self, collection_name=None, db_alias = None, sequence_name = None, *args, **kwargs): _auto_gen = True
def __init__(self, collection_name=None, db_alias=None,
sequence_name=None, *args, **kwargs):
self.collection_name = collection_name or 'mongoengine.counters' self.collection_name = collection_name or 'mongoengine.counters'
self.db_alias = db_alias or DEFAULT_CONNECTION_NAME self.db_alias = db_alias or DEFAULT_CONNECTION_NAME
self.sequence_name = sequence_name self.sequence_name = sequence_name
return super(SequenceField, self).__init__(*args, **kwargs) return super(SequenceField, self).__init__(*args, **kwargs)
def generate_new_value(self): def generate(self):
""" """
Generate and Increment the counter Generate and Increment the counter
""" """
sequence_name = self.sequence_name or self.owner_document._get_collection_name() sequence_name = (self.sequence_name or
self.owner_document._get_collection_name())
sequence_id = "%s.%s" % (sequence_name, self.name) sequence_id = "%s.%s" % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name] collection = get_db(alias=self.db_alias)[self.collection_name]
counter = collection.find_and_modify(query={"_id": sequence_id}, counter = collection.find_and_modify(query={"_id": sequence_id},
@ -1365,7 +1374,7 @@ class SequenceField(IntField):
value = instance._data.get(self.name) value = instance._data.get(self.name)
if not value and instance._initialised: if not value and instance._initialised:
value = self.generate_new_value() value = self.generate()
instance._data[self.name] = value instance._data[self.name] = value
instance._mark_as_changed(self.name) instance._mark_as_changed(self.name)
@ -1374,13 +1383,13 @@ class SequenceField(IntField):
def __set__(self, instance, value): def __set__(self, instance, value):
if value is None and instance._initialised: if value is None and instance._initialised:
value = self.generate_new_value() value = self.generate()
return super(SequenceField, self).__set__(instance, value) return super(SequenceField, self).__set__(instance, value)
def to_python(self, value): def to_python(self, value):
if value is None: if value is None:
value = self.generate_new_value() value = self.generate()
return value return value

View File

@ -58,7 +58,7 @@ class QuerySet(object):
# If inheritance is allowed, only return instances and instances of # If inheritance is allowed, only return instances and instances of
# subclasses of the class being used # subclasses of the class being used
if document._meta.get('allow_inheritance') != False: if document._meta.get('allow_inheritance') == True:
self._initial_query = {"_cls": {"$in": self._document._subclasses}} self._initial_query = {"_cls": {"$in": self._document._subclasses}}
self._loaded_fields = QueryFieldList(always_include=['_cls']) self._loaded_fields = QueryFieldList(always_include=['_cls'])
self._cursor_obj = None self._cursor_obj = None

View File

@ -29,22 +29,6 @@ class AllWarnings(unittest.TestCase):
# restore default handling of warnings # restore default handling of warnings
warnings.showwarning = self.showwarning_default warnings.showwarning = self.showwarning_default
def test_allow_inheritance_future_warning(self):
"""Add FutureWarning for future allow_inhertiance default change.
"""
class SimpleBase(Document):
a = IntField()
class InheritedClass(SimpleBase):
b = IntField()
InheritedClass()
self.assertEqual(len(self.warning_list), 1)
warning = self.warning_list[0]
self.assertEqual(FutureWarning, warning["category"])
self.assertTrue("InheritedClass" in str(warning["message"]))
def test_dbref_reference_field_future_warning(self): def test_dbref_reference_field_future_warning(self):
class Person(Document): class Person(Document):
@ -93,7 +77,7 @@ class AllWarnings(unittest.TestCase):
def test_document_collection_syntax_warning(self): def test_document_collection_syntax_warning(self):
class NonAbstractBase(Document): class NonAbstractBase(Document):
pass meta = {'allow_inheritance': True}
class InheritedDocumentFailTest(NonAbstractBase): class InheritedDocumentFailTest(NonAbstractBase):
meta = {'collection': 'fail'} meta = {'collection': 'fail'}

View File

@ -1,4 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest import unittest
from mongoengine import * from mongoengine import *
@ -126,9 +128,6 @@ class DeltaTest(unittest.TestCase):
'list_field': ['1', 2, {'hello': 'world'}] 'list_field': ['1', 2, {'hello': 'world'}]
} }
self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {}))
embedded_delta.update({
'_cls': 'Embedded',
})
self.assertEqual(doc._delta(), self.assertEqual(doc._delta(),
({'embedded_field': embedded_delta}, {})) ({'embedded_field': embedded_delta}, {}))
@ -162,6 +161,7 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field = ['1', 2, embedded_2] doc.embedded_field.list_field = ['1', 2, embedded_2]
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
['embedded_field.list_field']) ['embedded_field.list_field'])
self.assertEqual(doc.embedded_field._delta(), ({ self.assertEqual(doc.embedded_field._delta(), ({
'list_field': ['1', 2, { 'list_field': ['1', 2, {
'_cls': 'Embedded', '_cls': 'Embedded',
@ -175,10 +175,10 @@ class DeltaTest(unittest.TestCase):
self.assertEqual(doc._delta(), ({ self.assertEqual(doc._delta(), ({
'embedded_field.list_field': ['1', 2, { 'embedded_field.list_field': ['1', 2, {
'_cls': 'Embedded', '_cls': 'Embedded',
'string_field': 'hello', 'string_field': 'hello',
'dict_field': {'hello': 'world'}, 'dict_field': {'hello': 'world'},
'int_field': 1, 'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}], 'list_field': ['1', 2, {'hello': 'world'}],
}] }]
}, {})) }, {}))
doc.save() doc.save()
@ -467,9 +467,6 @@ class DeltaTest(unittest.TestCase):
'db_list_field': ['1', 2, {'hello': 'world'}] 'db_list_field': ['1', 2, {'hello': 'world'}]
} }
self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {})) self.assertEqual(doc.embedded_field._delta(), (embedded_delta, {}))
embedded_delta.update({
'_cls': 'Embedded',
})
self.assertEqual(doc._delta(), self.assertEqual(doc._delta(),
({'db_embedded_field': embedded_delta}, {})) ({'db_embedded_field': embedded_delta}, {}))
@ -520,10 +517,10 @@ class DeltaTest(unittest.TestCase):
self.assertEqual(doc._delta(), ({ self.assertEqual(doc._delta(), ({
'db_embedded_field.db_list_field': ['1', 2, { 'db_embedded_field.db_list_field': ['1', 2, {
'_cls': 'Embedded', '_cls': 'Embedded',
'db_string_field': 'hello', 'db_string_field': 'hello',
'db_dict_field': {'hello': 'world'}, 'db_dict_field': {'hello': 'world'},
'db_int_field': 1, 'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}], 'db_list_field': ['1', 2, {'hello': 'world'}],
}] }]
}, {})) }, {}))
doc.save() doc.save()
@ -686,3 +683,7 @@ class DeltaTest(unittest.TestCase):
doc.list_field = [] doc.list_field = []
self.assertEqual(doc._get_changed_fields(), ['list_field']) self.assertEqual(doc._get_changed_fields(), ['list_field'])
self.assertEqual(doc._delta(), ({}, {'list_field': 1})) self.assertEqual(doc._delta(), ({}, {'list_field': 1}))
if __name__ == '__main__':
unittest.main()

View File

@ -1,4 +1,7 @@
import unittest import unittest
import sys
sys.path[0:0] = [""]
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
@ -161,7 +164,7 @@ class DynamicTest(unittest.TestCase):
embedded_1.list_field = ['1', 2, {'hello': 'world'}] embedded_1.list_field = ['1', 2, {'hello': 'world'}]
doc.embedded_field = embedded_1 doc.embedded_field = embedded_1
self.assertEqual(doc.to_mongo(), {"_cls": "Doc", self.assertEqual(doc.to_mongo(), {
"embedded_field": { "embedded_field": {
"_cls": "Embedded", "_cls": "Embedded",
"string_field": "hello", "string_field": "hello",
@ -205,7 +208,7 @@ class DynamicTest(unittest.TestCase):
embedded_1.list_field = ['1', 2, embedded_2] embedded_1.list_field = ['1', 2, embedded_2]
doc.embedded_field = embedded_1 doc.embedded_field = embedded_1
self.assertEqual(doc.to_mongo(), {"_cls": "Doc", self.assertEqual(doc.to_mongo(), {
"embedded_field": { "embedded_field": {
"_cls": "Embedded", "_cls": "Embedded",
"string_field": "hello", "string_field": "hello",
@ -246,7 +249,6 @@ class DynamicTest(unittest.TestCase):
class Person(DynamicDocument): class Person(DynamicDocument):
name = StringField() name = StringField()
meta = {'allow_inheritance': True}
Person.drop_collection() Person.drop_collection()
@ -268,3 +270,7 @@ class DynamicTest(unittest.TestCase):
person.age = 35 person.age = 35
person.save() person.save()
self.assertEqual(Person.objects.first().age, 35) self.assertEqual(Person.objects.first().age, 35)
if __name__ == '__main__':
unittest.main()

View File

@ -203,7 +203,6 @@ class InheritanceTest(unittest.TestCase):
class Animal(Document): class Animal(Document):
name = StringField() name = StringField()
meta = {'allow_inheritance': False}
def create_dog_class(): def create_dog_class():
class Dog(Animal): class Dog(Animal):
@ -258,7 +257,6 @@ class InheritanceTest(unittest.TestCase):
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
content = StringField() content = StringField()
meta = {'allow_inheritance': False}
def create_special_comment(): def create_special_comment():
class SpecialComment(Comment): class SpecialComment(Comment):

View File

@ -1,24 +1,22 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import with_statement from __future__ import with_statement
import sys
sys.path[0:0] = [""]
import bson import bson
import os import os
import pickle import pickle
import pymongo
import sys
import unittest import unittest
import uuid import uuid
import warnings
from nose.plugins.skip import SkipTest
from datetime import datetime from datetime import datetime
from tests.fixtures import PickleEmbedded, PickleTest
from tests.fixtures import Base, Mixin, PickleEmbedded, PickleTest
from mongoengine import * from mongoengine import *
from mongoengine.errors import (NotRegistered, InvalidDocumentError, from mongoengine.errors import (NotRegistered, InvalidDocumentError,
InvalidQueryError) InvalidQueryError)
from mongoengine.queryset import NULLIFY, Q from mongoengine.queryset import NULLIFY, Q
from mongoengine.connection import get_db, get_connection from mongoengine.connection import get_db
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
@ -461,7 +459,7 @@ class InstanceTest(unittest.TestCase):
doc.validate() doc.validate()
keys = doc._data.keys() keys = doc._data.keys()
self.assertEqual(2, len(keys)) self.assertEqual(2, len(keys))
self.assertTrue(None in keys) self.assertTrue('id' in keys)
self.assertTrue('e' in keys) self.assertTrue('e' in keys)
def test_save(self): def test_save(self):
@ -656,8 +654,8 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_update(self): def test_update(self):
"""Ensure that an existing document is updated instead of be overwritten. """Ensure that an existing document is updated instead of be
""" overwritten."""
# Create person object and save it to the database # Create person object and save it to the database
person = self.Person(name='Test User', age=30) person = self.Person(name='Test User', age=30)
person.save() person.save()
@ -753,30 +751,33 @@ class InstanceTest(unittest.TestCase):
float_field = FloatField(default=1.1) float_field = FloatField(default=1.1)
boolean_field = BooleanField(default=True) boolean_field = BooleanField(default=True)
datetime_field = DateTimeField(default=datetime.now) datetime_field = DateTimeField(default=datetime.now)
embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, default=lambda: EmbeddedDoc()) embedded_document_field = EmbeddedDocumentField(EmbeddedDoc,
default=lambda: EmbeddedDoc())
list_field = ListField(default=lambda: [1, 2, 3]) list_field = ListField(default=lambda: [1, 2, 3])
dict_field = DictField(default=lambda: {"hello": "world"}) dict_field = DictField(default=lambda: {"hello": "world"})
objectid_field = ObjectIdField(default=bson.ObjectId) objectid_field = ObjectIdField(default=bson.ObjectId)
reference_field = ReferenceField(Simple, default=lambda: Simple().save()) reference_field = ReferenceField(Simple, default=lambda:
Simple().save())
map_field = MapField(IntField(), default=lambda: {"simple": 1}) map_field = MapField(IntField(), default=lambda: {"simple": 1})
decimal_field = DecimalField(default=1.0) decimal_field = DecimalField(default=1.0)
complex_datetime_field = ComplexDateTimeField(default=datetime.now) complex_datetime_field = ComplexDateTimeField(default=datetime.now)
url_field = URLField(default="http://mongoengine.org") url_field = URLField(default="http://mongoengine.org")
dynamic_field = DynamicField(default=1) dynamic_field = DynamicField(default=1)
generic_reference_field = GenericReferenceField(default=lambda: Simple().save()) generic_reference_field = GenericReferenceField(
sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3]) default=lambda: Simple().save())
sorted_list_field = SortedListField(IntField(),
default=lambda: [1, 2, 3])
email_field = EmailField(default="ross@example.com") email_field = EmailField(default="ross@example.com")
geo_point_field = GeoPointField(default=lambda: [1, 2]) geo_point_field = GeoPointField(default=lambda: [1, 2])
sequence_field = SequenceField() sequence_field = SequenceField()
uuid_field = UUIDField(default=uuid.uuid4) uuid_field = UUIDField(default=uuid.uuid4)
generic_embedded_document_field = GenericEmbeddedDocumentField(default=lambda: EmbeddedDoc()) generic_embedded_document_field = GenericEmbeddedDocumentField(
default=lambda: EmbeddedDoc())
Simple.drop_collection() Simple.drop_collection()
Doc.drop_collection() Doc.drop_collection()
Doc().save() Doc().save()
my_doc = Doc.objects.only("string_field").first() my_doc = Doc.objects.only("string_field").first()
my_doc.string_field = "string" my_doc.string_field = "string"
my_doc.save() my_doc.save()
@ -1707,9 +1708,12 @@ class InstanceTest(unittest.TestCase):
peter = User.objects.create(name="Peter") peter = User.objects.create(name="Peter")
# Bob # Bob
Book.objects.create(name="1", author=bob, extra={"a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]}) Book.objects.create(name="1", author=bob, extra={
Book.objects.create(name="2", author=bob, extra={"a": bob.to_dbref(), "b": karl.to_dbref()}) "a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]})
Book.objects.create(name="3", author=bob, extra={"a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]}) Book.objects.create(name="2", author=bob, extra={
"a": bob.to_dbref(), "b": karl.to_dbref()})
Book.objects.create(name="3", author=bob, extra={
"a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]})
Book.objects.create(name="4", author=bob) Book.objects.create(name="4", author=bob)
# Jon # Jon
@ -1717,23 +1721,26 @@ class InstanceTest(unittest.TestCase):
Book.objects.create(name="6", author=peter) Book.objects.create(name="6", author=peter)
Book.objects.create(name="7", author=jon) Book.objects.create(name="7", author=jon)
Book.objects.create(name="8", author=jon) Book.objects.create(name="8", author=jon)
Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()}) Book.objects.create(name="9", author=jon,
extra={"a": peter.to_dbref()})
# Checks # Checks
self.assertEqual(u",".join([str(b) for b in Book.objects.all()]) , "1,2,3,4,5,6,7,8,9") self.assertEqual(",".join([str(b) for b in Book.objects.all()]),
"1,2,3,4,5,6,7,8,9")
# bob related books # bob related books
self.assertEqual(u",".join([str(b) for b in Book.objects.filter( self.assertEqual(",".join([str(b) for b in Book.objects.filter(
Q(extra__a=bob) | Q(extra__a=bob) |
Q(author=bob) | Q(author=bob) |
Q(extra__b=bob))]) , Q(extra__b=bob))]),
"1,2,3,4") "1,2,3,4")
# Susan & Karl related books # Susan & Karl related books
self.assertEqual(u",".join([str(b) for b in Book.objects.filter( self.assertEqual(",".join([str(b) for b in Book.objects.filter(
Q(extra__a__all=[karl, susan]) | Q(extra__a__all=[karl, susan]) |
Q(author__all=[karl, susan ]) | Q(author__all=[karl, susan]) |
Q(extra__b__all=[karl.to_dbref(), susan.to_dbref()]) Q(extra__b__all=[
) ]) , "1") karl.to_dbref(), susan.to_dbref()]))
]), "1")
# $Where # $Where
self.assertEqual(u",".join([str(b) for b in Book.objects.filter( self.assertEqual(u",".join([str(b) for b in Book.objects.filter(
@ -1743,7 +1750,7 @@ class InstanceTest(unittest.TestCase):
return this.name == '1' || return this.name == '1' ||
this.name == '2';}""" this.name == '2';}"""
} }
) ]), "1,2") )]), "1,2")
class ValidatorErrorTest(unittest.TestCase): class ValidatorErrorTest(unittest.TestCase):

View File

@ -331,14 +331,10 @@ class FieldTest(unittest.TestCase):
return "<Person: %s>" % self.name return "<Person: %s>" % self.name
Person.drop_collection() Person.drop_collection()
paul = Person(name="Paul") paul = Person(name="Paul").save()
paul.save() maria = Person(name="Maria").save()
maria = Person(name="Maria") julia = Person(name='Julia').save()
maria.save() anna = Person(name='Anna').save()
julia = Person(name='Julia')
julia.save()
anna = Person(name='Anna')
anna.save()
paul.other.friends = [maria, julia, anna] paul.other.friends = [maria, julia, anna]
paul.other.name = "Paul's friends" paul.other.name = "Paul's friends"

View File

@ -727,7 +727,7 @@ class FieldTest(unittest.TestCase):
"""Ensure that the list fields can handle the complex types.""" """Ensure that the list fields can handle the complex types."""
class SettingBase(EmbeddedDocument): class SettingBase(EmbeddedDocument):
pass meta = {'allow_inheritance': True}
class StringSetting(SettingBase): class StringSetting(SettingBase):
value = StringField() value = StringField()
@ -743,8 +743,9 @@ class FieldTest(unittest.TestCase):
e.mapping.append(StringSetting(value='foo')) e.mapping.append(StringSetting(value='foo'))
e.mapping.append(IntegerSetting(value=42)) e.mapping.append(IntegerSetting(value=42))
e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001, e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001,
'complex': IntegerSetting(value=42), 'list': 'complex': IntegerSetting(value=42),
[IntegerSetting(value=42), StringSetting(value='foo')]}) 'list': [IntegerSetting(value=42),
StringSetting(value='foo')]})
e.save() e.save()
e2 = Simple.objects.get(id=e.id) e2 = Simple.objects.get(id=e.id)
@ -844,7 +845,7 @@ class FieldTest(unittest.TestCase):
"""Ensure that the dict field can handle the complex types.""" """Ensure that the dict field can handle the complex types."""
class SettingBase(EmbeddedDocument): class SettingBase(EmbeddedDocument):
pass meta = {'allow_inheritance': True}
class StringSetting(SettingBase): class StringSetting(SettingBase):
value = StringField() value = StringField()
@ -859,9 +860,11 @@ class FieldTest(unittest.TestCase):
e = Simple() e = Simple()
e.mapping['somestring'] = StringSetting(value='foo') e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42) e.mapping['someint'] = IntegerSetting(value=42)
e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001, e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!',
'complex': IntegerSetting(value=42), 'list': 'float': 1.001,
[IntegerSetting(value=42), StringSetting(value='foo')]} 'complex': IntegerSetting(value=42),
'list': [IntegerSetting(value=42),
StringSetting(value='foo')]}
e.save() e.save()
e2 = Simple.objects.get(id=e.id) e2 = Simple.objects.get(id=e.id)
@ -915,7 +918,7 @@ class FieldTest(unittest.TestCase):
"""Ensure that the MapField can handle complex declared types.""" """Ensure that the MapField can handle complex declared types."""
class SettingBase(EmbeddedDocument): class SettingBase(EmbeddedDocument):
pass meta = {"allow_inheritance": True}
class StringSetting(SettingBase): class StringSetting(SettingBase):
value = StringField() value = StringField()
@ -951,7 +954,8 @@ class FieldTest(unittest.TestCase):
number = IntField(default=0, db_field='i') number = IntField(default=0, db_field='i')
class Test(Document): class Test(Document):
my_map = MapField(field=EmbeddedDocumentField(Embedded), db_field='x') my_map = MapField(field=EmbeddedDocumentField(Embedded),
db_field='x')
Test.drop_collection() Test.drop_collection()
@ -1038,6 +1042,8 @@ class FieldTest(unittest.TestCase):
class User(EmbeddedDocument): class User(EmbeddedDocument):
name = StringField() name = StringField()
meta = {'allow_inheritance': True}
class PowerUser(User): class PowerUser(User):
power = IntField() power = IntField()
@ -1046,8 +1052,10 @@ class FieldTest(unittest.TestCase):
author = EmbeddedDocumentField(User) author = EmbeddedDocumentField(User)
post = BlogPost(content='What I did today...') post = BlogPost(content='What I did today...')
post.author = User(name='Test User')
post.author = PowerUser(name='Test User', power=47) post.author = PowerUser(name='Test User', power=47)
post.save()
self.assertEqual(47, BlogPost.objects.first().author.power)
def test_reference_validation(self): def test_reference_validation(self):
"""Ensure that invalid docment objects cannot be assigned to reference """Ensure that invalid docment objects cannot be assigned to reference
@ -2117,12 +2125,12 @@ class FieldTest(unittest.TestCase):
def test_sequence_fields_reload(self): def test_sequence_fields_reload(self):
class Animal(Document): class Animal(Document):
counter = SequenceField() counter = SequenceField()
type = StringField() name = StringField()
self.db['mongoengine.counters'].drop() self.db['mongoengine.counters'].drop()
Animal.drop_collection() Animal.drop_collection()
a = Animal(type="Boi") a = Animal(name="Boi")
a.save() a.save()
self.assertEqual(a.counter, 1) self.assertEqual(a.counter, 1)

View File

@ -647,7 +647,8 @@ class QuerySetTest(unittest.TestCase):
self.assertRaises(NotUniqueError, throw_operation_error_not_unique) self.assertRaises(NotUniqueError, throw_operation_error_not_unique)
self.assertEqual(Blog.objects.count(), 2) self.assertEqual(Blog.objects.count(), 2)
Blog.objects.insert([blog2, blog3], write_options={'continue_on_error': True}) Blog.objects.insert([blog2, blog3], write_options={
'continue_on_error': True})
self.assertEqual(Blog.objects.count(), 3) self.assertEqual(Blog.objects.count(), 3)
def test_get_changed_fields_query_count(self): def test_get_changed_fields_query_count(self):
@ -673,7 +674,7 @@ class QuerySetTest(unittest.TestCase):
r2 = Project(name="r2").save() r2 = Project(name="r2").save()
r3 = Project(name="r3").save() r3 = Project(name="r3").save()
p1 = Person(name="p1", projects=[r1, r2]).save() p1 = Person(name="p1", projects=[r1, r2]).save()
p2 = Person(name="p2", projects=[r2]).save() p2 = Person(name="p2", projects=[r2, r3]).save()
o1 = Organization(name="o1", employees=[p1]).save() o1 = Organization(name="o1", employees=[p1]).save()
with query_counter() as q: with query_counter() as q:
@ -688,24 +689,24 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(q, 0) self.assertEqual(q, 0)
fresh_o1 = Organization.objects.get(id=o1.id) fresh_o1 = Organization.objects.get(id=o1.id)
fresh_o1.save() fresh_o1.save() # No changes, does nothing
self.assertEqual(q, 2) self.assertEqual(q, 1)
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0) self.assertEqual(q, 0)
fresh_o1 = Organization.objects.get(id=o1.id) fresh_o1 = Organization.objects.get(id=o1.id)
fresh_o1.save(cascade=False) fresh_o1.save(cascade=False) # No changes, does nothing
self.assertEqual(q, 2) self.assertEqual(q, 1)
with query_counter() as q: with query_counter() as q:
self.assertEqual(q, 0) self.assertEqual(q, 0)
fresh_o1 = Organization.objects.get(id=o1.id) fresh_o1 = Organization.objects.get(id=o1.id)
fresh_o1.employees.append(p2) fresh_o1.employees.append(p2) # Dereferences
fresh_o1.save(cascade=False) fresh_o1.save(cascade=False) # Saves
self.assertEqual(q, 3) self.assertEqual(q, 3)