Merge remote branch 'upstream/dev' into dev

This commit is contained in:
Colin Howe 2011-06-10 16:50:44 +01:00
commit c1fadcac85
12 changed files with 1080 additions and 217 deletions

View File

@ -53,6 +53,8 @@ Fields
.. autoclass:: mongoengine.DateTimeField
.. autoclass:: mongoengine.ComplexDateTimeField
.. autoclass:: mongoengine.EmbeddedDocumentField
.. autoclass:: mongoengine.DictField

View File

@ -5,11 +5,15 @@ Changelog
Changes in dev
==============
- Added slave_okay kwarg to queryset
- Fixed saving so sets updated values rather than overwrites
- Added ComplexDateTimeField - Handles datetimes correctly with microseconds
- Added ComplexBaseField - for improved flexibility and performance
- Added get_FIELD_display() method for easy choice field displaying
- Added queryset.slave_okay(enabled) method
- Updated queryset.timeout(enabled) and queryset.snapshot(enabled) to be chainable
- Added insert method for bulk inserts
- Added blinker signal support
- Added query_counter context manager for tests
- Added DereferenceBaseField - for improved performance in field dereferencing
- Added optional map_reduce method item_frequencies
- Added inline_map_reduce option to map_reduce
- Updated connection exception so it provides more info on the cause.

View File

@ -8,6 +8,7 @@ import sys
import pymongo
import pymongo.objectid
from operator import itemgetter
from functools import partial
class NotRegistered(Exception):
@ -61,6 +62,7 @@ class BaseField(object):
self.primary_key = primary_key
self.validation = validation
self.choices = choices
# Adjust the appropriate creation counter, and save our local copy.
if self.db_field == '_id':
self.creation_counter = BaseField.auto_creation_counter
@ -90,6 +92,9 @@ class BaseField(object):
"""Descriptor for assigning a value to a field in a document.
"""
instance._data[self.name] = value
# If the field set is in the _present_fields list add it so we can track
if hasattr(instance, '_present_fields') and self.name not in instance._present_fields:
instance._present_fields.append(self.name)
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
@ -130,15 +135,19 @@ class BaseField(object):
self.validate(value)
class DereferenceBaseField(BaseField):
"""Handles the lazy dereferencing of a queryset. Will dereference all
class ComplexBaseField(BaseField):
"""Handles complex fields, such as lists / dictionaries.
Allows for nesting of embedded documents inside complex types.
Handles the lazy dereferencing of a queryset by lazily dereferencing all
items in a list / dict rather than one at a time.
"""
field = None
def __get__(self, instance, owner):
"""Descriptor to automatically dereference references.
"""
from fields import ReferenceField, GenericReferenceField
from connection import _get_db
if instance is None:
@ -147,68 +156,175 @@ class DereferenceBaseField(BaseField):
# Get value from document instance if available
value_list = instance._data.get(self.name)
if not value_list:
return super(DereferenceBaseField, self).__get__(instance, owner)
if not value_list or isinstance(value_list, basestring):
return super(ComplexBaseField, self).__get__(instance, owner)
is_list = False
if not hasattr(value_list, 'items'):
is_list = True
value_list = dict([(k,v) for k,v in enumerate(value_list)])
if isinstance(self.field, ReferenceField) and value_list:
db = _get_db()
dbref = {}
collections = {}
for k,v in value_list.items():
if isinstance(v, dict) and '_cls' in v and '_ref' not in v:
value_list[k] = get_document(v['_cls'].split('.')[-1])._from_son(v)
for k, v in value_list.items():
dbref[k] = v
# Save any DBRefs
# Handle all dereferencing
db = _get_db()
dbref = {}
collections = {}
for k, v in value_list.items():
dbref[k] = v
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
# direct reference (DBRef)
collections.setdefault(v.collection, []).append((k, v))
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
# generic reference
collection = get_document(v['_cls'])._meta['collection']
collections.setdefault(collection, []).append((k, v))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = {}
for k, v in dbrefs:
if isinstance(v, (pymongo.dbref.DBRef)):
collections.setdefault(v.collection, []).append((k, v))
# direct reference (DBRef), has no _cls information
id_map[v.id] = (k, None)
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
# generic reference - includes _cls information
id_map[v['_ref'].id] = (k, get_document(v['_cls']))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = dict([(v.id, k) for k, v in dbrefs])
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key = id_map[ref['_id']]
dbref[key] = get_document(ref['_cls'])._from_son(ref)
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references:
key, doc_cls = id_map[ref['_id']]
if not doc_cls: # If no doc_cls get it from the referenced doc
doc_cls = get_document(ref['_cls'])
dbref[key] = doc_cls._from_son(ref)
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
instance._data[self.name] = dbref
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
instance._data[self.name] = dbref
return super(ComplexBaseField, self).__get__(instance, owner)
# Get value from document instance if available
if isinstance(self.field, GenericReferenceField) and value_list:
db = _get_db()
value_list = [(k,v) for k,v in value_list.items()]
dbref = {}
classes = {}
def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type.
"""
from mongoengine import Document
for k, v in value_list:
dbref[k] = v
# Save any DBRefs
if isinstance(v, (dict, pymongo.son.SON)):
classes.setdefault(v['_cls'], []).append((k, v))
if isinstance(value, basestring):
return value
# For each collection get the references
for doc_cls, dbrefs in classes.items():
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
doc_cls = get_document(doc_cls)
collection = doc_cls._meta['collection']
references = db[collection].find({'_id': {'$in': id_map.keys()}})
if hasattr(value, 'to_python'):
return value.to_python()
for ref in references:
key = id_map[ref['_id']]
dbref[key] = doc_cls._from_son(ref)
is_list = False
if not hasattr(value, 'items'):
try:
is_list = True
value = dict([(k,v) for k,v in enumerate(value)])
except TypeError: # Not iterable return the value
return value
if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
if self.field:
value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()])
else:
value_dict = {}
for k,v in value.items():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
collection = v._meta['collection']
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
elif hasattr(v, 'to_python'):
value_dict[k] = v.to_python()
else:
value_dict[k] = self.to_python(v)
instance._data[self.name] = dbref
if is_list: # Convert back to a list
return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))]
return value_dict
return super(DereferenceBaseField, self).__get__(instance, owner)
def to_mongo(self, value):
"""Convert a Python type to a MongoDB-compatible type.
"""
from mongoengine import Document
if isinstance(value, basestring):
return value
if hasattr(value, 'to_mongo'):
return value.to_mongo()
is_list = False
if not hasattr(value, 'items'):
try:
is_list = True
value = dict([(k,v) for k,v in enumerate(value)])
except TypeError: # Not iterable return the value
return value
if self.field:
value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()])
else:
value_dict = {}
for k,v in value.items():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
# If its a document that is not inheritable it won't have
# _types / _cls data so make it a generic reference allows
# us to dereference
meta = getattr(v, 'meta', getattr(v, '_meta', {}))
if meta and not meta['allow_inheritance'] and not self.field:
from fields import GenericReferenceField
value_dict[k] = GenericReferenceField().to_mongo(v)
else:
collection = v._meta['collection']
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'):
value_dict[k] = v.to_mongo()
else:
value_dict[k] = self.to_mongo(v)
if is_list: # Convert back to a list
return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))]
return value_dict
def validate(self, value):
"""If field provided ensure the value is valid.
"""
if self.field:
try:
if hasattr(value, 'iteritems'):
[self.field.validate(v) for k,v in value.iteritems()]
else:
[self.field.validate(v) for v in value]
except Exception, err:
raise ValidationError('Invalid %s item (%s)' % (
self.field.__class__.__name__, str(v)))
def prepare_query_value(self, op, value):
return self.to_mongo(value)
def lookup_member(self, member_name):
if self.field:
return self.field.lookup_member(member_name)
return None
def _set_owner_document(self, owner_document):
if self.field:
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
class ObjectIdField(BaseField):
@ -217,7 +333,6 @@ class ObjectIdField(BaseField):
def to_python(self, value):
return value
# return unicode(value)
def to_mongo(self, value):
if not isinstance(value, pymongo.objectid.ObjectId):
@ -261,7 +376,7 @@ class DocumentMetaclass(type):
superclasses[base._class_name] = base
superclasses.update(base._superclasses)
if hasattr(base, '_meta'):
if hasattr(base, '_meta') and not base._meta.get('abstract'):
# Ensure that the Document class may be subclassed -
# inheritance may be disabled to remove dependency on
# additional fields _cls and _types
@ -278,7 +393,7 @@ class DocumentMetaclass(type):
# Only simple classes - direct subclasses of Document - may set
# allow_inheritance to False
if not simple_class and not meta['allow_inheritance']:
if not simple_class and not meta['allow_inheritance'] and not meta['abstract']:
raise ValueError('Only direct subclasses of Document may set '
'"allow_inheritance" to False')
attrs['_meta'] = meta
@ -358,8 +473,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Subclassed documents inherit collection from superclass
for base in bases:
if hasattr(base, '_meta') and 'collection' in base._meta:
collection = base._meta['collection']
if hasattr(base, '_meta'):
if 'collection' in base._meta:
collection = base._meta['collection']
# Propagate index options.
for key in ('index_background', 'index_drop_dups', 'index_opts'):
@ -368,6 +484,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
id_field = id_field or base._meta.get('id_field')
base_indexes += base._meta.get('indexes', [])
# Propagate 'allow_inheritance'
if 'allow_inheritance' in base._meta:
base_meta['allow_inheritance'] = base._meta['allow_inheritance']
meta = {
'abstract': False,
@ -382,6 +501,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
'index_opts': {},
'queryset_class': QuerySet,
'delete_rules': {},
'allow_inheritance': True
}
meta.update(base_meta)
@ -471,19 +591,28 @@ class BaseDocument(object):
self._data = {}
# Assign default values to instance
for attr_name in self._fields.keys():
# Use default value if present
for attr_name, field in self._fields.items():
if field.choices: # dynamically adds a way to get the display value for a field with choices
setattr(self, 'get_%s_display' % attr_name, partial(self._get_FIELD_display, field=field))
value = getattr(self, attr_name, None)
setattr(self, attr_name, value)
# Assign initial values to instance
for attr_name in values.keys():
try:
setattr(self, attr_name, values.pop(attr_name))
value = values.pop(attr_name)
setattr(self, attr_name, value)
except AttributeError:
pass
signals.post_init.send(self)
def _get_FIELD_display(self, field):
"""Returns the display value for a choice field"""
value = getattr(self, field.name)
return dict(field.choices).get(value, value)
def validate(self):
"""Ensure that all fields' values are valid and that required fields
are present.
@ -614,7 +743,6 @@ class BaseDocument(object):
cls = subclasses[class_name]
present_fields = data.keys()
for field_name, field in cls._fields.items():
if field.db_field in data:
value = data[field.db_field]

View File

@ -53,6 +53,11 @@ class Document(BaseDocument):
dictionary. The value should be a list of field names or tuples of field
names. Index direction may be specified by prefixing the field names with
a **+** or **-** sign.
By default, _types will be added to the start of every index (that
doesn't contain a list) if allow_inheritence is True. This can be
disabled by either setting types to False on the specific index or
by setting index_types to False on the meta dictionary for the document.
"""
__metaclass__ = TopLevelDocumentMetaclass
@ -90,6 +95,16 @@ class Document(BaseDocument):
collection = self.__class__.objects._collection
if force_insert:
object_id = collection.insert(doc, safe=safe, **write_options)
elif '_id' in doc:
# Perform a set rather than a save - this will only save set fields
object_id = doc.pop('_id')
collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options)
# Find and unset any fields explicitly set to None
if hasattr(self, '_present_fields'):
removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id'])
if removals:
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
else:
object_id = collection.save(doc, safe=safe, **write_options)
except pymongo.errors.OperationFailure, err:

View File

@ -1,4 +1,4 @@
from base import (BaseField, DereferenceBaseField, ObjectIdField,
from base import (BaseField, ComplexBaseField, ObjectIdField,
ValidationError, get_document)
from queryset import DO_NOTHING
from document import Document, EmbeddedDocument
@ -18,8 +18,9 @@ import gridfs
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField',
'DecimalField', 'URLField', 'GenericReferenceField', 'FileField',
'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField']
'DecimalField', 'ComplexDateTimeField', 'URLField',
'GenericReferenceField', 'FileField', 'BinaryField',
'SortedListField', 'EmailField', 'GeoPointField']
RECURSIVE_REFERENCE_CONSTANT = 'self'
@ -273,6 +274,98 @@ class DateTimeField(BaseField):
return None
class ComplexDateTimeField(StringField):
"""
ComplexDateTimeField handles microseconds exactly instead of rounding
like DateTimeField does.
Derives from a StringField so you can do `gte` and `lte` filtering by
using lexicographical comparison when filtering / sorting strings.
The stored string has the following format:
YYYY,MM,DD,HH,MM,SS,NNNNNN
Where NNNNNN is the number of microseconds of the represented `datetime`.
The `,` as the separator can be easily modified by passing the `separator`
keyword when initializing the field.
"""
def __init__(self, separator=',', **kwargs):
self.names = ['year', 'month', 'day', 'hour', 'minute', 'second',
'microsecond']
self.separtor = separator
super(ComplexDateTimeField, self).__init__(**kwargs)
def _leading_zero(self, number):
"""
Converts the given number to a string.
If it has only one digit, a leading zero so as it has always at least
two digits.
"""
if int(number) < 10:
return "0%s" % number
else:
return str(number)
def _convert_from_datetime(self, val):
"""
Convert a `datetime` object to a string representation (which will be
stored in MongoDB). This is the reverse function of
`_convert_from_string`.
>>> a = datetime(2011, 6, 8, 20, 26, 24, 192284)
>>> RealDateTimeField()._convert_from_datetime(a)
'2011,06,08,20,26,24,192284'
"""
data = []
for name in self.names:
data.append(self._leading_zero(getattr(val, name)))
return ','.join(data)
def _convert_from_string(self, data):
"""
Convert a string representation to a `datetime` object (the object you
will manipulate). This is the reverse function of
`_convert_from_datetime`.
>>> a = '2011,06,08,20,26,24,192284'
>>> ComplexDateTimeField()._convert_from_string(a)
datetime.datetime(2011, 6, 8, 20, 26, 24, 192284)
"""
data = data.split(',')
data = map(int, data)
values = {}
for i in range(7):
values[self.names[i]] = data[i]
return datetime.datetime(**values)
def __get__(self, instance, owner):
data = super(ComplexDateTimeField, self).__get__(instance, owner)
if data == None:
return datetime.datetime.now()
return self._convert_from_string(data)
def __set__(self, obj, val):
data = self._convert_from_datetime(val)
return super(ComplexDateTimeField, self).__set__(obj, data)
def validate(self, value):
if not isinstance(value, datetime.datetime):
raise ValidationError('Only datetime objects may used in a \
ComplexDateTimeField')
def to_python(self, value):
return self._convert_from_string(value)
def to_mongo(self, value):
return self._convert_from_datetime(value)
def prepare_query_value(self, op, value):
return self._convert_from_datetime(value)
class EmbeddedDocumentField(BaseField):
"""An embedded document field. Only valid values are subclasses of
:class:`~mongoengine.EmbeddedDocument`.
@ -301,6 +394,8 @@ class EmbeddedDocumentField(BaseField):
return value
def to_mongo(self, value):
if isinstance(value, basestring):
return value
return self.document_type.to_mongo(value)
def validate(self, value):
@ -320,7 +415,7 @@ class EmbeddedDocumentField(BaseField):
return self.to_mongo(value)
class ListField(DereferenceBaseField):
class ListField(ComplexBaseField):
"""A list field that wraps a standard field, allowing multiple instances
of the field to be used as a list in the database.
"""
@ -328,48 +423,25 @@ class ListField(DereferenceBaseField):
# ListFields cannot be indexed with _types - MongoDB doesn't support this
_index_with_types = False
def __init__(self, field, **kwargs):
if not isinstance(field, BaseField):
raise ValidationError('Argument to ListField constructor must be '
'a valid field')
def __init__(self, field=None, **kwargs):
self.field = field
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs)
def to_python(self, value):
return [self.field.to_python(item) for item in value]
def to_mongo(self, value):
return [self.field.to_mongo(item) for item in value]
def validate(self, value):
"""Make sure that a list of valid fields is being used.
"""
if not isinstance(value, (list, tuple)):
raise ValidationError('Only lists and tuples may be used in a '
'list field')
try:
[self.field.validate(item) for item in value]
except Exception, err:
raise ValidationError('Invalid ListField item (%s)' % str(item))
super(ListField, self).validate(value)
def prepare_query_value(self, op, value):
if op in ('set', 'unset'):
return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value)
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
if self.field:
if op in ('set', 'unset') and not isinstance(value, basestring):
return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value)
return super(ListField, self).prepare_query_value(op, value)
class SortedListField(ListField):
@ -388,20 +460,21 @@ class SortedListField(ListField):
super(SortedListField, self).__init__(field, **kwargs)
def to_mongo(self, value):
value = super(SortedListField, self).to_mongo(value)
if self._ordering is not None:
return sorted([self.field.to_mongo(item) for item in value],
key=itemgetter(self._ordering))
return sorted([self.field.to_mongo(item) for item in value])
return sorted(value, key=itemgetter(self._ordering))
return sorted(value)
class DictField(BaseField):
class DictField(ComplexBaseField):
"""A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined.
.. versionadded:: 0.3
"""
def __init__(self, basecls=None, *args, **kwargs):
def __init__(self, basecls=None, field=None, *args, **kwargs):
self.field = field
self.basecls = basecls or BaseField
assert issubclass(self.basecls, BaseField)
kwargs.setdefault('default', lambda: {})
@ -417,6 +490,7 @@ class DictField(BaseField):
if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not '
'contain "." or "$" characters')
super(DictField, self).validate(value)
def lookup_member(self, member_name):
return DictField(basecls=self.basecls, db_field=member_name)
@ -432,7 +506,7 @@ class DictField(BaseField):
return super(DictField, self).prepare_query_value(op, value)
class MapField(DereferenceBaseField):
class MapField(DictField):
"""A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified
field type.
@ -444,50 +518,7 @@ class MapField(DereferenceBaseField):
if not isinstance(field, BaseField):
raise ValidationError('Argument to MapField constructor must be '
'a valid field')
self.field = field
kwargs.setdefault('default', lambda: {})
super(MapField, self).__init__(*args, **kwargs)
def validate(self, value):
"""Make sure that a list of valid fields is being used.
"""
if not isinstance(value, dict):
raise ValidationError('Only dictionaries may be used in a '
'DictField')
if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not '
'contain "." or "$" characters')
try:
[self.field.validate(item) for item in value.values()]
except Exception, err:
raise ValidationError('Invalid MapField item (%s)' % str(item))
def to_python(self, value):
return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
def to_mongo(self, value):
return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()])
def prepare_query_value(self, op, value):
if op not in ('set', 'unset'):
return self.field.prepare_query_value(op, value)
for key in value:
value[key] = self.field.prepare_query_value(op, value[key])
return value
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
super(MapField, self).__init__(field=field, *args, **kwargs)
class ReferenceField(BaseField):

View File

@ -418,8 +418,9 @@ class QuerySet(object):
use_types = False
# If _types is being used, prepend it to every specified index
if (spec.get('types', True) and doc_cls._meta.get('allow_inheritance')
and use_types):
index_types = doc_cls._meta.get('index_types', True)
allow_inheritance = doc_cls._meta.get('allow_inheritance')
if spec.get('types', index_types) and allow_inheritance and use_types:
index_list.insert(0, ('_types', 1))
spec['fields'] = index_list
@ -474,6 +475,7 @@ class QuerySet(object):
background = self._document._meta.get('index_background', False)
drop_dups = self._document._meta.get('index_drop_dups', False)
index_opts = self._document._meta.get('index_options', {})
index_types = self._document._meta.get('index_types', True)
# Ensure indexes created by uniqueness constraints
for index in self._document._meta['unique_indexes']:
@ -490,7 +492,7 @@ class QuerySet(object):
background=background, **opts)
# If _types is being used (for polymorphism), it needs an index
if '_types' in self._query:
if index_types and '_types' in self._query:
self._collection.ensure_index('_types',
background=background, **index_opts)
@ -547,11 +549,12 @@ class QuerySet(object):
parts = [parts]
fields = []
field = None
for field_name in parts:
# Handle ListField indexing:
if field_name.isdigit():
try:
field = field.field
new_field = field.field
except AttributeError, err:
raise InvalidQueryError(
"Can't use index on unsubscriptable field (%s)" % err)
@ -565,11 +568,17 @@ class QuerySet(object):
field = document._fields[field_name]
else:
# Look up subfield on the previous field
field = field.lookup_member(field_name)
if field is None:
new_field = field.lookup_member(field_name)
from base import ComplexBaseField
if not new_field and isinstance(field, ComplexBaseField):
fields.append(field_name)
continue
elif not new_field:
raise InvalidQueryError('Cannot resolve field "%s"'
% field_name)
% field_name)
field = new_field # update field to the new field type
fields.append(field)
return fields
@classmethod
@ -613,14 +622,33 @@ class QuerySet(object):
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [field.db_field for field in fields]
parts = []
cleaned_fields = []
append_field = True
for field in fields:
if isinstance(field, str):
parts.append(field)
append_field = False
else:
parts.append(field.db_field)
if append_field:
cleaned_fields.append(field)
# Convert value to proper value
field = fields[-1]
field = cleaned_fields[-1]
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops += match_operators
if op in singular_ops:
value = field.prepare_query_value(op, value)
if isinstance(field, basestring):
if op in match_operators and isinstance(value, basestring):
from mongoengine import StringField
value = StringField().prepare_query_value(op, value)
else:
value = field
else:
value = field.prepare_query_value(op, value)
elif op in ('in', 'nin', 'all', 'near'):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value]
@ -1168,14 +1196,19 @@ class QuerySet(object):
fields = QuerySet._lookup_field(_doc_cls, parts)
parts = []
cleaned_fields = []
append_field = True
for field in fields:
if isinstance(field, str):
parts.append(field)
append_field = False
else:
parts.append(field.db_field)
if append_field:
cleaned_fields.append(field)
# Convert value to proper value
field = fields[-1]
field = cleaned_fields[-1]
if op in (None, 'set', 'push', 'pull', 'addToSet'):
value = field.prepare_query_value(op, value)

View File

@ -45,6 +45,6 @@ setup(name='mongoengine',
long_description=LONG_DESCRIPTION,
platforms=['any'],
classifiers=CLASSIFIERS,
install_requires=['pymongo', 'blinker'],
install_requires=['pymongo', 'blinker', 'django==1.3'],
test_suite='tests',
)

View File

@ -122,6 +122,64 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members]
self.assertEqual(q, 4)
for m in group_obj.members:
self.assertTrue('User' in m.__class__.__name__)
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
def test_list_field_complex(self):
class UserA(Document):
name = StringField()
class UserB(Document):
name = StringField()
class UserC(Document):
name = StringField()
class Group(Document):
members = ListField()
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
members = []
for i in xrange(1, 51):
a = UserA(name='User A %s' % i)
a.save()
b = UserB(name='User B %s' % i)
b.save()
c = UserC(name='User C %s' % i)
c.save()
members += [a, b, c]
group = Group(members=members)
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 4)
[m for m in group_obj.members]
self.assertEqual(q, 4)
for m in group_obj.members:
self.assertTrue('User' in m.__class__.__name__)
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
@ -156,10 +214,13 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members]
self.assertEqual(q, 2)
for k, m in group_obj.members.iteritems():
self.assertTrue(isinstance(m, User))
User.drop_collection()
Group.drop_collection()
def ztest_generic_reference_dict_field(self):
def test_dict_field(self):
class UserA(Document):
name = StringField()
@ -206,6 +267,9 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members]
self.assertEqual(q, 4)
for k, m in group_obj.members.iteritems():
self.assertTrue('User' in m.__class__.__name__)
group.members = {}
group.save()
@ -218,11 +282,54 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members]
self.assertEqual(q, 1)
for k, m in group_obj.members.iteritems():
self.assertTrue('User' in m.__class__.__name__)
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
def test_dict_field_no_field_inheritance(self):
class UserA(Document):
name = StringField()
meta = {'allow_inheritance': False}
class Group(Document):
members = DictField()
UserA.drop_collection()
Group.drop_collection()
members = []
for i in xrange(1, 51):
a = UserA(name='User A %s' % i)
a.save()
members += [a]
group = Group(members=dict([(str(u.id), u) for u in members]))
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 2)
[m for m in group_obj.members]
self.assertEqual(q, 2)
for k, m in group_obj.members.iteritems():
self.assertTrue(isinstance(m, UserA))
UserA.drop_collection()
Group.drop_collection()
def test_generic_reference_map_field(self):
class UserA(Document):
@ -270,6 +377,9 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members]
self.assertEqual(q, 4)
for k, m in group_obj.members.iteritems():
self.assertTrue('User' in m.__class__.__name__)
group.members = {}
group.save()

57
tests/django_tests.py Normal file
View File

@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
import unittest
from mongoengine import *
from django.template import Context, Template
from django.conf import settings
settings.configure()
class QuerySetTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
class Person(Document):
name = StringField()
age = IntField()
self.Person = Person
def test_order_by_in_django_template(self):
"""Ensure that QuerySets are properly ordered in Django template.
"""
self.Person.drop_collection()
self.Person(name="A", age=20).save()
self.Person(name="D", age=10).save()
self.Person(name="B", age=40).save()
self.Person(name="C", age=30).save()
t = Template("{% for o in ol %}{{ o.name }}-{{ o.age }}:{% endfor %}")
d = {"ol": self.Person.objects.order_by('-name')}
self.assertEqual(t.render(Context(d)), u'D-10:C-30:B-40:A-20:')
d = {"ol": self.Person.objects.order_by('+name')}
self.assertEqual(t.render(Context(d)), u'A-20:B-40:C-30:D-10:')
d = {"ol": self.Person.objects.order_by('-age')}
self.assertEqual(t.render(Context(d)), u'B-40:C-30:A-20:D-10:')
d = {"ol": self.Person.objects.order_by('+age')}
self.assertEqual(t.render(Context(d)), u'D-10:A-20:C-30:B-40:')
self.Person.drop_collection()
def test_q_object_filter_in_template(self):
self.Person.drop_collection()
self.Person(name="A", age=20).save()
self.Person(name="D", age=10).save()
self.Person(name="B", age=40).save()
self.Person(name="C", age=30).save()
t = Template("{% for o in ol %}{{ o.name }}-{{ o.age }}:{% endfor %}")
d = {"ol": self.Person.objects.filter(Q(age=10) | Q(name="C"))}
self.assertEqual(t.render(Context(d)), u'D-10:C-30:')

View File

@ -151,12 +151,12 @@ class DocumentTest(unittest.TestCase):
"""Ensure that inheritance may be disabled on simple classes and that
_cls and _types will not be used.
"""
class Animal(Document):
meta = {'allow_inheritance': False}
name = StringField()
meta = {'allow_inheritance': False}
Animal.drop_collection()
def create_dog_class():
class Dog(Animal):
pass
@ -191,6 +191,92 @@ class DocumentTest(unittest.TestCase):
self.assertFalse('_cls' in comment.to_mongo())
self.assertFalse('_types' in comment.to_mongo())
def test_allow_inheritance_abstract_document(self):
"""Ensure that abstract documents can set inheritance rules and that
_cls and _types will not be used.
"""
class FinalDocument(Document):
meta = {'abstract': True,
'allow_inheritance': False}
class Animal(FinalDocument):
name = StringField()
Animal.drop_collection()
def create_dog_class():
class Dog(Animal):
pass
self.assertRaises(ValueError, create_dog_class)
# Check that _cls etc aren't present on simple documents
dog = Animal(name='dog')
dog.save()
collection = self.db[Animal._meta['collection']]
obj = collection.find_one()
self.assertFalse('_cls' in obj)
self.assertFalse('_types' in obj)
Animal.drop_collection()
def test_how_to_turn_off_inheritance(self):
"""Demonstrates migrating from allow_inheritance = True to False.
"""
class Animal(Document):
name = StringField()
meta = {
'indexes': ['name']
}
Animal.drop_collection()
dog = Animal(name='dog')
dog.save()
collection = self.db[Animal._meta['collection']]
obj = collection.find_one()
self.assertTrue('_cls' in obj)
self.assertTrue('_types' in obj)
info = collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertEquals([[(u'_id', 1)], [(u'_types', 1)], [(u'_types', 1), (u'name', 1)]], info)
# Turn off inheritance
class Animal(Document):
name = StringField()
meta = {
'allow_inheritance': False,
'indexes': ['name']
}
collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, False, True)
# Confirm extra data is removed
obj = collection.find_one()
self.assertFalse('_cls' in obj)
self.assertFalse('_types' in obj)
info = collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertEquals([[(u'_id', 1)], [(u'_types', 1)], [(u'_types', 1), (u'name', 1)]], info)
info = collection.index_information()
indexes_to_drop = [key for key, value in info.iteritems() if '_types' in dict(value['key'])]
for index in indexes_to_drop:
collection.drop_index(index)
info = collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertEquals([[(u'_id', 1)]], info)
# Recreate indexes
dog = Animal.objects.first()
dog.save()
info = collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertEquals([[(u'_id', 1)], [(u'name', 1),]], info)
Animal.drop_collection()
def test_abstract_documents(self):
"""Ensure that a document superclass can be marked as abstract
thereby not using it as the name for the collection."""
@ -703,6 +789,90 @@ class DocumentTest(unittest.TestCase):
except ValidationError:
self.fail()
def test_update(self):
"""Ensure that an existing document is updated instead of be overwritten.
"""
# Create person object and save it to the database
person = self.Person(name='Test User', age=30)
person.save()
# Create same person object, with same id, without age
same_person = self.Person(name='Test')
same_person.id = person.id
same_person.save()
# Confirm only one object
self.assertEquals(self.Person.objects.count(), 1)
# reload
person.reload()
same_person.reload()
# Confirm the same
self.assertEqual(person, same_person)
self.assertEqual(person.name, same_person.name)
self.assertEqual(person.age, same_person.age)
# Confirm the saved values
self.assertEqual(person.name, 'Test')
self.assertEqual(person.age, 30)
# Test only / exclude only updates included fields
person = self.Person.objects.only('name').get()
person.name = 'User'
person.save()
person.reload()
self.assertEqual(person.name, 'User')
self.assertEqual(person.age, 30)
# test exclude only updates set fields
person = self.Person.objects.exclude('name').get()
person.age = 21
person.save()
person.reload()
self.assertEqual(person.name, 'User')
self.assertEqual(person.age, 21)
# Test only / exclude can set non excluded / included fields
person = self.Person.objects.only('name').get()
person.name = 'Test'
person.age = 30
person.save()
person.reload()
self.assertEqual(person.name, 'Test')
self.assertEqual(person.age, 30)
# test exclude only updates set fields
person = self.Person.objects.exclude('name').get()
person.name = 'User'
person.age = 21
person.save()
person.reload()
self.assertEqual(person.name, 'User')
self.assertEqual(person.age, 21)
# Confirm does remove unrequired fields
person = self.Person.objects.exclude('name').get()
person.age = None
person.save()
person.reload()
self.assertEqual(person.name, 'User')
self.assertEqual(person.age, None)
person = self.Person.objects.get()
person.name = None
person.age = None
person.save()
person.reload()
self.assertEqual(person.name, None)
self.assertEqual(person.age, None)
def test_delete(self):
"""Ensure that document may be deleted using the delete method.
"""

View File

@ -247,6 +247,107 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection()
def test_complexdatetime_storage(self):
"""Tests for complex datetime fields - which can handle microseconds
without rounding.
"""
class LogEntry(Document):
date = ComplexDateTimeField()
LogEntry.drop_collection()
# Post UTC - microseconds are rounded (down) nearest millisecond and dropped - with default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
log = LogEntry()
log.date = d1
log.save()
log.reload()
self.assertEquals(log.date, d1)
# Post UTC - microseconds are rounded (down) nearest millisecond - with default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
log.date = d1
log.save()
log.reload()
self.assertEquals(log.date, d1)
# Pre UTC dates microseconds below 1000 are dropped - with default datetimefields
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
log.date = d1
log.save()
log.reload()
self.assertEquals(log.date, d1)
# Pre UTC microseconds above 1000 is wonky - with default datetimefields
# log.date has an invalid microsecond value so I can't construct
# a date to compare.
for i in xrange(1001, 3113, 33):
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
log.date = d1
log.save()
log.reload()
self.assertEquals(log.date, d1)
log1 = LogEntry.objects.get(date=d1)
self.assertEqual(log, log1)
LogEntry.drop_collection()
def test_complexdatetime_usage(self):
"""Tests for complex datetime fields - which can handle microseconds
without rounding.
"""
class LogEntry(Document):
date = ComplexDateTimeField()
LogEntry.drop_collection()
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
log = LogEntry()
log.date = d1
log.save()
log1 = LogEntry.objects.get(date=d1)
self.assertEquals(log, log1)
LogEntry.drop_collection()
# create 60 log entries
for i in xrange(1950, 2010):
d = datetime.datetime(i, 01, 01, 00, 00, 01, 999)
LogEntry(date=d).save()
self.assertEqual(LogEntry.objects.count(), 60)
# Test ordering
logs = LogEntry.objects.order_by("date")
count = logs.count()
i = 0
while i == count-1:
self.assertTrue(logs[i].date <= logs[i+1].date)
i +=1
logs = LogEntry.objects.order_by("-date")
count = logs.count()
i = 0
while i == count-1:
self.assertTrue(logs[i].date >= logs[i+1].date)
i +=1
# Test searching
logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980,1,1))
self.assertEqual(logs.count(), 30)
logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980,1,1))
self.assertEqual(logs.count(), 30)
logs = LogEntry.objects.filter(
date__lte=datetime.datetime(2011,1,1),
date__gte=datetime.datetime(2000,1,1),
)
self.assertEqual(logs.count(), 10)
LogEntry.drop_collection()
def test_list_validation(self):
"""Ensure that a list field only accepts lists with valid elements.
"""
@ -322,6 +423,108 @@ class FieldTest(unittest.TestCase):
BlogPost.drop_collection()
def test_list_field(self):
"""Ensure that list types work as expected.
"""
class BlogPost(Document):
info = ListField()
BlogPost.drop_collection()
post = BlogPost()
post.info = 'my post'
self.assertRaises(ValidationError, post.validate)
post.info = {'title': 'test'}
self.assertRaises(ValidationError, post.validate)
post.info = ['test']
post.save()
post = BlogPost()
post.info = [{'test': 'test'}]
post.save()
post = BlogPost()
post.info = [{'test': 3}]
post.save()
self.assertEquals(BlogPost.objects.count(), 3)
self.assertEquals(BlogPost.objects.filter(info__exact='test').count(), 1)
self.assertEquals(BlogPost.objects.filter(info__0__test='test').count(), 1)
# Confirm handles non strings or non existing keys
self.assertEquals(BlogPost.objects.filter(info__0__test__exact='5').count(), 0)
self.assertEquals(BlogPost.objects.filter(info__100__test__exact='test').count(), 0)
BlogPost.drop_collection()
def test_list_field_strict(self):
"""Ensure that list field handles validation if provided a strict field type."""
class Simple(Document):
mapping = ListField(field=IntField())
Simple.drop_collection()
e = Simple()
e.mapping = [1]
e.save()
def create_invalid_mapping():
e.mapping = ["abc"]
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Simple.drop_collection()
def test_list_field_complex(self):
"""Ensure that the list fields can handle the complex types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Simple(Document):
mapping = ListField()
Simple.drop_collection()
e = Simple()
e.mapping.append(StringSetting(value='foo'))
e.mapping.append(IntegerSetting(value=42))
e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001,
'complex': IntegerSetting(value=42), 'list':
[IntegerSetting(value=42), StringSetting(value='foo')]})
e.save()
e2 = Simple.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping[0], StringSetting))
self.assertTrue(isinstance(e2.mapping[1], IntegerSetting))
# Test querying
self.assertEquals(Simple.objects.filter(mapping__1__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__2__number=1).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__2__complex__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1)
# Confirm can update
Simple.objects().update(set__mapping__1=IntegerSetting(value=10))
self.assertEquals(Simple.objects.filter(mapping__1__value=10).count(), 1)
Simple.objects().update(
set__mapping__2__list__1=StringSetting(value='Boo'))
self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0)
self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1)
Simple.drop_collection()
def test_dict_field(self):
"""Ensure that dict types work as expected.
"""
@ -363,6 +566,131 @@ class FieldTest(unittest.TestCase):
self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0)
BlogPost.drop_collection()
def test_dictfield_strict(self):
"""Ensure that dict field handles validation if provided a strict field type."""
class Simple(Document):
mapping = DictField(field=IntField())
Simple.drop_collection()
e = Simple()
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Simple.drop_collection()
def test_dictfield_complex(self):
"""Ensure that the dict field can handle the complex types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Simple(Document):
mapping = DictField()
Simple.drop_collection()
e = Simple()
e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42)
e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001,
'complex': IntegerSetting(value=42), 'list':
[IntegerSetting(value=42), StringSetting(value='foo')]}
e.save()
e2 = Simple.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
# Test querying
self.assertEquals(Simple.objects.filter(mapping__someint__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1)
# Confirm can update
Simple.objects().update(
set__mapping={"someint": IntegerSetting(value=10)})
Simple.objects().update(
set__mapping__nested_dict__list__1=StringSetting(value='Boo'))
self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1)
Simple.drop_collection()
def test_mapfield(self):
"""Ensure that the MapField handles the declared type."""
class Simple(Document):
mapping = MapField(IntField())
Simple.drop_collection()
e = Simple()
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
def create_invalid_class():
class NoDeclaredType(Document):
mapping = MapField()
self.assertRaises(ValidationError, create_invalid_class)
Simple.drop_collection()
def test_complex_mapfield(self):
"""Ensure that the MapField can handle complex declared types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Extensible(Document):
mapping = MapField(EmbeddedDocumentField(SettingBase))
Extensible.drop_collection()
e = Extensible()
e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42)
e.save()
e2 = Extensible.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
def create_invalid_mapping():
e.mapping['someint'] = 123
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Extensible.drop_collection()
def test_embedded_document_validation(self):
"""Ensure that invalid embedded documents cannot be assigned to
embedded document fields.
@ -773,6 +1101,35 @@ class FieldTest(unittest.TestCase):
Shirt.drop_collection()
def test_choices_get_field_display(self):
"""Test dynamic helper for returning the display value of a choices field.
"""
class Shirt(Document):
size = StringField(max_length=3, choices=(('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
style = StringField(max_length=3, choices=(('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S')
Shirt.drop_collection()
shirt = Shirt()
self.assertEqual(shirt.get_size_display(), None)
self.assertEqual(shirt.get_style_display(), 'Small')
shirt.size = "XXL"
shirt.style = "B"
self.assertEqual(shirt.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt.get_style_display(), 'Baggy')
# Set as Z - an invalid choice
shirt.size = "Z"
shirt.style = "Z"
self.assertEqual(shirt.get_size_display(), 'Z')
self.assertEqual(shirt.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt.validate)
Shirt.drop_collection()
def test_file_fields(self):
"""Ensure that file fields can be written to and their data retrieved
"""
@ -904,66 +1261,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(d2.data, {})
self.assertEqual(d2.data2, {})
def test_mapfield(self):
"""Ensure that the MapField handles the declared type."""
class Simple(Document):
mapping = MapField(IntField())
Simple.drop_collection()
e = Simple()
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
def create_invalid_class():
class NoDeclaredType(Document):
mapping = MapField()
self.assertRaises(ValidationError, create_invalid_class)
Simple.drop_collection()
def test_complex_mapfield(self):
"""Ensure that the MapField can handle complex declared types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Extensible(Document):
mapping = MapField(EmbeddedDocumentField(SettingBase))
Extensible.drop_collection()
e = Extensible()
e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42)
e.save()
e2 = Extensible.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
def create_invalid_mapping():
e.mapping['someint'] = 123
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Extensible.drop_collection()
if __name__ == '__main__':
unittest.main()

View File

@ -1830,6 +1830,22 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue([('_types', 1)] in info)
self.assertTrue([('_types', 1), ('date', -1)] in info)
def test_dont_index_types(self):
"""Ensure that index_types will, when disabled, prevent _types
being added to all indices.
"""
class BlogPost(Document):
date = DateTimeField()
meta = {'index_types': False,
'indexes': ['-date']}
# Indexes are lazy so use list() to perform query
list(BlogPost.objects)
info = BlogPost.objects._collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1)] not in info)
self.assertTrue([('date', -1)] in info)
BlogPost.drop_collection()
class BlogPost(Document):