Merge remote branch 'upstream/dev' into dev

This commit is contained in:
Colin Howe 2011-06-10 16:50:44 +01:00
commit c1fadcac85
12 changed files with 1080 additions and 217 deletions

View File

@ -53,6 +53,8 @@ Fields
.. autoclass:: mongoengine.DateTimeField .. autoclass:: mongoengine.DateTimeField
.. autoclass:: mongoengine.ComplexDateTimeField
.. autoclass:: mongoengine.EmbeddedDocumentField .. autoclass:: mongoengine.EmbeddedDocumentField
.. autoclass:: mongoengine.DictField .. autoclass:: mongoengine.DictField

View File

@ -5,11 +5,15 @@ Changelog
Changes in dev Changes in dev
============== ==============
- Added slave_okay kwarg to queryset - Fixed saving so sets updated values rather than overwrites
- Added ComplexDateTimeField - Handles datetimes correctly with microseconds
- Added ComplexBaseField - for improved flexibility and performance
- Added get_FIELD_display() method for easy choice field displaying
- Added queryset.slave_okay(enabled) method
- Updated queryset.timeout(enabled) and queryset.snapshot(enabled) to be chainable
- Added insert method for bulk inserts - Added insert method for bulk inserts
- Added blinker signal support - Added blinker signal support
- Added query_counter context manager for tests - Added query_counter context manager for tests
- Added DereferenceBaseField - for improved performance in field dereferencing
- Added optional map_reduce method item_frequencies - Added optional map_reduce method item_frequencies
- Added inline_map_reduce option to map_reduce - Added inline_map_reduce option to map_reduce
- Updated connection exception so it provides more info on the cause. - Updated connection exception so it provides more info on the cause.

View File

@ -8,6 +8,7 @@ import sys
import pymongo import pymongo
import pymongo.objectid import pymongo.objectid
from operator import itemgetter from operator import itemgetter
from functools import partial
class NotRegistered(Exception): class NotRegistered(Exception):
@ -61,6 +62,7 @@ class BaseField(object):
self.primary_key = primary_key self.primary_key = primary_key
self.validation = validation self.validation = validation
self.choices = choices self.choices = choices
# Adjust the appropriate creation counter, and save our local copy. # Adjust the appropriate creation counter, and save our local copy.
if self.db_field == '_id': if self.db_field == '_id':
self.creation_counter = BaseField.auto_creation_counter self.creation_counter = BaseField.auto_creation_counter
@ -90,6 +92,9 @@ class BaseField(object):
"""Descriptor for assigning a value to a field in a document. """Descriptor for assigning a value to a field in a document.
""" """
instance._data[self.name] = value instance._data[self.name] = value
# If the field set is in the _present_fields list add it so we can track
if hasattr(instance, '_present_fields') and self.name not in instance._present_fields:
instance._present_fields.append(self.name)
def to_python(self, value): def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type.
@ -130,15 +135,19 @@ class BaseField(object):
self.validate(value) self.validate(value)
class DereferenceBaseField(BaseField): class ComplexBaseField(BaseField):
"""Handles the lazy dereferencing of a queryset. Will dereference all """Handles complex fields, such as lists / dictionaries.
Allows for nesting of embedded documents inside complex types.
Handles the lazy dereferencing of a queryset by lazily dereferencing all
items in a list / dict rather than one at a time. items in a list / dict rather than one at a time.
""" """
field = None
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor to automatically dereference references. """Descriptor to automatically dereference references.
""" """
from fields import ReferenceField, GenericReferenceField
from connection import _get_db from connection import _get_db
if instance is None: if instance is None:
@ -147,68 +156,175 @@ class DereferenceBaseField(BaseField):
# Get value from document instance if available # Get value from document instance if available
value_list = instance._data.get(self.name) value_list = instance._data.get(self.name)
if not value_list: if not value_list or isinstance(value_list, basestring):
return super(DereferenceBaseField, self).__get__(instance, owner) return super(ComplexBaseField, self).__get__(instance, owner)
is_list = False is_list = False
if not hasattr(value_list, 'items'): if not hasattr(value_list, 'items'):
is_list = True is_list = True
value_list = dict([(k,v) for k,v in enumerate(value_list)]) value_list = dict([(k,v) for k,v in enumerate(value_list)])
if isinstance(self.field, ReferenceField) and value_list: for k,v in value_list.items():
db = _get_db() if isinstance(v, dict) and '_cls' in v and '_ref' not in v:
dbref = {} value_list[k] = get_document(v['_cls'].split('.')[-1])._from_son(v)
collections = {}
for k, v in value_list.items(): # Handle all dereferencing
dbref[k] = v db = _get_db()
# Save any DBRefs dbref = {}
collections = {}
for k, v in value_list.items():
dbref[k] = v
# Save any DBRefs
if isinstance(v, (pymongo.dbref.DBRef)):
# direct reference (DBRef)
collections.setdefault(v.collection, []).append((k, v))
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
# generic reference
collection = get_document(v['_cls'])._meta['collection']
collections.setdefault(collection, []).append((k, v))
# For each collection get the references
for collection, dbrefs in collections.items():
id_map = {}
for k, v in dbrefs:
if isinstance(v, (pymongo.dbref.DBRef)): if isinstance(v, (pymongo.dbref.DBRef)):
collections.setdefault(v.collection, []).append((k, v)) # direct reference (DBRef), has no _cls information
id_map[v.id] = (k, None)
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
# generic reference - includes _cls information
id_map[v['_ref'].id] = (k, get_document(v['_cls']))
# For each collection get the references references = db[collection].find({'_id': {'$in': id_map.keys()}})
for collection, dbrefs in collections.items(): for ref in references:
id_map = dict([(v.id, k) for k, v in dbrefs]) key, doc_cls = id_map[ref['_id']]
references = db[collection].find({'_id': {'$in': id_map.keys()}}) if not doc_cls: # If no doc_cls get it from the referenced doc
for ref in references: doc_cls = get_document(ref['_cls'])
key = id_map[ref['_id']] dbref[key] = doc_cls._from_son(ref)
dbref[key] = get_document(ref['_cls'])._from_son(ref)
if is_list: if is_list:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))]
instance._data[self.name] = dbref instance._data[self.name] = dbref
return super(ComplexBaseField, self).__get__(instance, owner)
# Get value from document instance if available def to_python(self, value):
if isinstance(self.field, GenericReferenceField) and value_list: """Convert a MongoDB-compatible type to a Python type.
db = _get_db() """
value_list = [(k,v) for k,v in value_list.items()] from mongoengine import Document
dbref = {}
classes = {}
for k, v in value_list: if isinstance(value, basestring):
dbref[k] = v return value
# Save any DBRefs
if isinstance(v, (dict, pymongo.son.SON)):
classes.setdefault(v['_cls'], []).append((k, v))
# For each collection get the references if hasattr(value, 'to_python'):
for doc_cls, dbrefs in classes.items(): return value.to_python()
id_map = dict([(v['_ref'].id, k) for k, v in dbrefs])
doc_cls = get_document(doc_cls)
collection = doc_cls._meta['collection']
references = db[collection].find({'_id': {'$in': id_map.keys()}})
for ref in references: is_list = False
key = id_map[ref['_id']] if not hasattr(value, 'items'):
dbref[key] = doc_cls._from_son(ref) try:
is_list = True
value = dict([(k,v) for k,v in enumerate(value)])
except TypeError: # Not iterable return the value
return value
if is_list: if self.field:
dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()])
else:
value_dict = {}
for k,v in value.items():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
collection = v._meta['collection']
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
elif hasattr(v, 'to_python'):
value_dict[k] = v.to_python()
else:
value_dict[k] = self.to_python(v)
instance._data[self.name] = dbref if is_list: # Convert back to a list
return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))]
return value_dict
return super(DereferenceBaseField, self).__get__(instance, owner) def to_mongo(self, value):
"""Convert a Python type to a MongoDB-compatible type.
"""
from mongoengine import Document
if isinstance(value, basestring):
return value
if hasattr(value, 'to_mongo'):
return value.to_mongo()
is_list = False
if not hasattr(value, 'items'):
try:
is_list = True
value = dict([(k,v) for k,v in enumerate(value)])
except TypeError: # Not iterable return the value
return value
if self.field:
value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()])
else:
value_dict = {}
for k,v in value.items():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
# If its a document that is not inheritable it won't have
# _types / _cls data so make it a generic reference allows
# us to dereference
meta = getattr(v, 'meta', getattr(v, '_meta', {}))
if meta and not meta['allow_inheritance'] and not self.field:
from fields import GenericReferenceField
value_dict[k] = GenericReferenceField().to_mongo(v)
else:
collection = v._meta['collection']
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'):
value_dict[k] = v.to_mongo()
else:
value_dict[k] = self.to_mongo(v)
if is_list: # Convert back to a list
return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))]
return value_dict
def validate(self, value):
    """If field provided ensure the value is valid.

    Delegates per-item validation to the wrapped ``self.field``: dicts
    (anything with ``iteritems``) are validated on their values, any other
    iterable on its items. Without a wrapped field this is a no-op.
    """
    if self.field:
        try:
            if hasattr(value, 'iteritems'):
                [self.field.validate(v) for k,v in value.iteritems()]
            else:
                [self.field.validate(v) for v in value]
        except Exception, err:
            # NOTE(review): the message reads `v` via Python 2
            # list-comprehension variable leakage (the last item attempted);
            # `err` itself is captured but unused — confirm intended.
            raise ValidationError('Invalid %s item (%s)' % (
                self.field.__class__.__name__, str(v)))
def prepare_query_value(self, op, value):
    # Queries always compare against the MongoDB representation of the value.
    return self.to_mongo(value)

def lookup_member(self, member_name):
    # Delegate member lookup to the wrapped field when one is configured.
    if self.field:
        return self.field.lookup_member(member_name)
    return None

def _set_owner_document(self, owner_document):
    # Propagate the owning document class down to the wrapped field so
    # nested lookups resolve against the right document.
    if self.field:
        self.field.owner_document = owner_document
    self._owner_document = owner_document
def _get_owner_document(self):
    """Property getter for the owning document class.

    Fixed: the previous version took an extra ``owner_document`` argument
    and assigned it (a copy-paste of the setter), so reading the
    ``owner_document`` property raised a TypeError instead of returning
    the stored value.
    """
    return self._owner_document

owner_document = property(_get_owner_document, _set_owner_document)
class ObjectIdField(BaseField): class ObjectIdField(BaseField):
@ -217,7 +333,6 @@ class ObjectIdField(BaseField):
def to_python(self, value): def to_python(self, value):
return value return value
# return unicode(value)
def to_mongo(self, value): def to_mongo(self, value):
if not isinstance(value, pymongo.objectid.ObjectId): if not isinstance(value, pymongo.objectid.ObjectId):
@ -261,7 +376,7 @@ class DocumentMetaclass(type):
superclasses[base._class_name] = base superclasses[base._class_name] = base
superclasses.update(base._superclasses) superclasses.update(base._superclasses)
if hasattr(base, '_meta'): if hasattr(base, '_meta') and not base._meta.get('abstract'):
# Ensure that the Document class may be subclassed - # Ensure that the Document class may be subclassed -
# inheritance may be disabled to remove dependency on # inheritance may be disabled to remove dependency on
# additional fields _cls and _types # additional fields _cls and _types
@ -278,7 +393,7 @@ class DocumentMetaclass(type):
# Only simple classes - direct subclasses of Document - may set # Only simple classes - direct subclasses of Document - may set
# allow_inheritance to False # allow_inheritance to False
if not simple_class and not meta['allow_inheritance']: if not simple_class and not meta['allow_inheritance'] and not meta['abstract']:
raise ValueError('Only direct subclasses of Document may set ' raise ValueError('Only direct subclasses of Document may set '
'"allow_inheritance" to False') '"allow_inheritance" to False')
attrs['_meta'] = meta attrs['_meta'] = meta
@ -358,8 +473,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Subclassed documents inherit collection from superclass # Subclassed documents inherit collection from superclass
for base in bases: for base in bases:
if hasattr(base, '_meta') and 'collection' in base._meta: if hasattr(base, '_meta'):
collection = base._meta['collection'] if 'collection' in base._meta:
collection = base._meta['collection']
# Propagate index options. # Propagate index options.
for key in ('index_background', 'index_drop_dups', 'index_opts'): for key in ('index_background', 'index_drop_dups', 'index_opts'):
@ -368,6 +484,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
id_field = id_field or base._meta.get('id_field') id_field = id_field or base._meta.get('id_field')
base_indexes += base._meta.get('indexes', []) base_indexes += base._meta.get('indexes', [])
# Propagate 'allow_inheritance'
if 'allow_inheritance' in base._meta:
base_meta['allow_inheritance'] = base._meta['allow_inheritance']
meta = { meta = {
'abstract': False, 'abstract': False,
@ -382,6 +501,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
'index_opts': {}, 'index_opts': {},
'queryset_class': QuerySet, 'queryset_class': QuerySet,
'delete_rules': {}, 'delete_rules': {},
'allow_inheritance': True
} }
meta.update(base_meta) meta.update(base_meta)
@ -471,19 +591,28 @@ class BaseDocument(object):
self._data = {} self._data = {}
# Assign default values to instance # Assign default values to instance
for attr_name in self._fields.keys(): for attr_name, field in self._fields.items():
# Use default value if present if field.choices: # dynamically adds a way to get the display value for a field with choices
setattr(self, 'get_%s_display' % attr_name, partial(self._get_FIELD_display, field=field))
value = getattr(self, attr_name, None) value = getattr(self, attr_name, None)
setattr(self, attr_name, value) setattr(self, attr_name, value)
# Assign initial values to instance # Assign initial values to instance
for attr_name in values.keys(): for attr_name in values.keys():
try: try:
setattr(self, attr_name, values.pop(attr_name)) value = values.pop(attr_name)
setattr(self, attr_name, value)
except AttributeError: except AttributeError:
pass pass
signals.post_init.send(self) signals.post_init.send(self)
def _get_FIELD_display(self, field):
"""Returns the display value for a choice field"""
value = getattr(self, field.name)
return dict(field.choices).get(value, value)
def validate(self): def validate(self):
"""Ensure that all fields' values are valid and that required fields """Ensure that all fields' values are valid and that required fields
are present. are present.
@ -614,7 +743,6 @@ class BaseDocument(object):
cls = subclasses[class_name] cls = subclasses[class_name]
present_fields = data.keys() present_fields = data.keys()
for field_name, field in cls._fields.items(): for field_name, field in cls._fields.items():
if field.db_field in data: if field.db_field in data:
value = data[field.db_field] value = data[field.db_field]

View File

@ -53,6 +53,11 @@ class Document(BaseDocument):
dictionary. The value should be a list of field names or tuples of field dictionary. The value should be a list of field names or tuples of field
names. Index direction may be specified by prefixing the field names with names. Index direction may be specified by prefixing the field names with
a **+** or **-** sign. a **+** or **-** sign.
By default, _types will be added to the start of every index (that
doesn't contain a list) if allow_inheritance is True. This can be
disabled by either setting types to False on the specific index or
by setting index_types to False on the meta dictionary for the document.
""" """
__metaclass__ = TopLevelDocumentMetaclass __metaclass__ = TopLevelDocumentMetaclass
@ -90,6 +95,16 @@ class Document(BaseDocument):
collection = self.__class__.objects._collection collection = self.__class__.objects._collection
if force_insert: if force_insert:
object_id = collection.insert(doc, safe=safe, **write_options) object_id = collection.insert(doc, safe=safe, **write_options)
elif '_id' in doc:
# Perform a set rather than a save - this will only save set fields
object_id = doc.pop('_id')
collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options)
# Find and unset any fields explicitly set to None
if hasattr(self, '_present_fields'):
removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id'])
if removals:
collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options)
else: else:
object_id = collection.save(doc, safe=safe, **write_options) object_id = collection.save(doc, safe=safe, **write_options)
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:

View File

@ -1,4 +1,4 @@
from base import (BaseField, DereferenceBaseField, ObjectIdField, from base import (BaseField, ComplexBaseField, ObjectIdField,
ValidationError, get_document) ValidationError, get_document)
from queryset import DO_NOTHING from queryset import DO_NOTHING
from document import Document, EmbeddedDocument from document import Document, EmbeddedDocument
@ -18,8 +18,9 @@ import gridfs
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField',
'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', 'DecimalField', 'ComplexDateTimeField', 'URLField',
'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] 'GenericReferenceField', 'FileField', 'BinaryField',
'SortedListField', 'EmailField', 'GeoPointField']
RECURSIVE_REFERENCE_CONSTANT = 'self' RECURSIVE_REFERENCE_CONSTANT = 'self'
@ -273,6 +274,98 @@ class DateTimeField(BaseField):
return None return None
class ComplexDateTimeField(StringField):
    """
    ComplexDateTimeField handles microseconds exactly instead of rounding
    like DateTimeField does.

    Derives from a StringField so you can do `gte` and `lte` filtering by
    using lexicographical comparison when filtering / sorting strings.

    The stored string has the following format:

        YYYY,MM,DD,HH,MM,SS,NNNNNN

    Where NNNNNN is the number of microseconds of the represented `datetime`.
    The `,` as the separator can be easily modified by passing the `separator`
    keyword when initializing the field.
    """

    def __init__(self, separator=',', **kwargs):
        """
        :param separator: character placed between the datetime components
            in the stored string (defaults to ``,``).
        """
        # Attribute names on datetime, most- to least-significant, so that
        # lexicographic string comparison matches chronological order.
        self.names = ['year', 'month', 'day', 'hour', 'minute', 'second',
                      'microsecond']
        # Fixed: was misspelled `separtor`, which left the keyword argument
        # dead and the separator hard-coded to ',' in the converters below.
        self.separator = separator
        super(ComplexDateTimeField, self).__init__(**kwargs)

    def _leading_zero(self, number):
        """
        Converts the given number to a string.

        If it has only one digit, a leading zero is added so the result
        always has at least two digits (keeps string ordering correct).
        """
        if int(number) < 10:
            return "0%s" % number
        else:
            return str(number)

    def _convert_from_datetime(self, val):
        """
        Convert a `datetime` object to a string representation (which will be
        stored in MongoDB). This is the reverse function of
        `_convert_from_string`.

        >>> a = datetime(2011, 6, 8, 20, 26, 24, 192284)
        >>> ComplexDateTimeField()._convert_from_datetime(a)
        '2011,06,08,20,26,24,192284'
        """
        data = []
        for name in self.names:
            data.append(self._leading_zero(getattr(val, name)))
        # Fixed: join with the configured separator (was hard-coded ',').
        return self.separator.join(data)

    def _convert_from_string(self, data):
        """
        Convert a string representation to a `datetime` object (the object you
        will manipulate). This is the reverse function of
        `_convert_from_datetime`.

        >>> a = '2011,06,08,20,26,24,192284'
        >>> ComplexDateTimeField()._convert_from_string(a)
        datetime.datetime(2011, 6, 8, 20, 26, 24, 192284)
        """
        # Fixed: split on the configured separator (was hard-coded ',');
        # list comprehension so the pieces are indexable on all versions.
        pieces = [int(p) for p in data.split(self.separator)]
        values = {}
        for i in range(7):
            values[self.names[i]] = pieces[i]
        return datetime.datetime(**values)

    def __get__(self, instance, owner):
        """Descriptor access: return the stored string as a `datetime`.

        NOTE(review): falls back to `datetime.now()` when nothing is stored,
        i.e. the *current* time rather than a fixed default — confirm callers
        expect that.
        """
        data = super(ComplexDateTimeField, self).__get__(instance, owner)
        if data is None:  # fixed: identity check instead of `== None`
            return datetime.datetime.now()
        return self._convert_from_string(data)

    def __set__(self, obj, val):
        # Store the lexicographically-ordered string form.
        data = self._convert_from_datetime(val)
        return super(ComplexDateTimeField, self).__set__(obj, data)

    def validate(self, value):
        """Ensure only `datetime` instances are assigned."""
        if not isinstance(value, datetime.datetime):
            raise ValidationError('Only datetime objects may be used in a '
                                  'ComplexDateTimeField')

    def to_python(self, value):
        return self._convert_from_string(value)

    def to_mongo(self, value):
        return self._convert_from_datetime(value)

    def prepare_query_value(self, op, value):
        # Queries compare against the stored string representation.
        return self._convert_from_datetime(value)
class EmbeddedDocumentField(BaseField): class EmbeddedDocumentField(BaseField):
"""An embedded document field. Only valid values are subclasses of """An embedded document field. Only valid values are subclasses of
:class:`~mongoengine.EmbeddedDocument`. :class:`~mongoengine.EmbeddedDocument`.
@ -301,6 +394,8 @@ class EmbeddedDocumentField(BaseField):
return value return value
def to_mongo(self, value): def to_mongo(self, value):
if isinstance(value, basestring):
return value
return self.document_type.to_mongo(value) return self.document_type.to_mongo(value)
def validate(self, value): def validate(self, value):
@ -320,7 +415,7 @@ class EmbeddedDocumentField(BaseField):
return self.to_mongo(value) return self.to_mongo(value)
class ListField(DereferenceBaseField): class ListField(ComplexBaseField):
"""A list field that wraps a standard field, allowing multiple instances """A list field that wraps a standard field, allowing multiple instances
of the field to be used as a list in the database. of the field to be used as a list in the database.
""" """
@ -328,48 +423,25 @@ class ListField(DereferenceBaseField):
# ListFields cannot be indexed with _types - MongoDB doesn't support this # ListFields cannot be indexed with _types - MongoDB doesn't support this
_index_with_types = False _index_with_types = False
def __init__(self, field, **kwargs): def __init__(self, field=None, **kwargs):
if not isinstance(field, BaseField):
raise ValidationError('Argument to ListField constructor must be '
'a valid field')
self.field = field self.field = field
kwargs.setdefault('default', lambda: []) kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs) super(ListField, self).__init__(**kwargs)
def to_python(self, value):
return [self.field.to_python(item) for item in value]
def to_mongo(self, value):
return [self.field.to_mongo(item) for item in value]
def validate(self, value): def validate(self, value):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used.
""" """
if not isinstance(value, (list, tuple)): if not isinstance(value, (list, tuple)):
raise ValidationError('Only lists and tuples may be used in a ' raise ValidationError('Only lists and tuples may be used in a '
'list field') 'list field')
super(ListField, self).validate(value)
try:
[self.field.validate(item) for item in value]
except Exception, err:
raise ValidationError('Invalid ListField item (%s)' % str(item))
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if op in ('set', 'unset'): if self.field:
return [self.field.prepare_query_value(op, v) for v in value] if op in ('set', 'unset') and not isinstance(value, basestring):
return self.field.prepare_query_value(op, value) return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value)
def lookup_member(self, member_name): return super(ListField, self).prepare_query_value(op, value)
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
class SortedListField(ListField): class SortedListField(ListField):
@ -388,20 +460,21 @@ class SortedListField(ListField):
super(SortedListField, self).__init__(field, **kwargs) super(SortedListField, self).__init__(field, **kwargs)
def to_mongo(self, value): def to_mongo(self, value):
value = super(SortedListField, self).to_mongo(value)
if self._ordering is not None: if self._ordering is not None:
return sorted([self.field.to_mongo(item) for item in value], return sorted(value, key=itemgetter(self._ordering))
key=itemgetter(self._ordering)) return sorted(value)
return sorted([self.field.to_mongo(item) for item in value])
class DictField(BaseField): class DictField(ComplexBaseField):
"""A dictionary field that wraps a standard Python dictionary. This is """A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined. similar to an embedded document, but the structure is not defined.
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
def __init__(self, basecls=None, *args, **kwargs): def __init__(self, basecls=None, field=None, *args, **kwargs):
self.field = field
self.basecls = basecls or BaseField self.basecls = basecls or BaseField
assert issubclass(self.basecls, BaseField) assert issubclass(self.basecls, BaseField)
kwargs.setdefault('default', lambda: {}) kwargs.setdefault('default', lambda: {})
@ -417,6 +490,7 @@ class DictField(BaseField):
if any(('.' in k or '$' in k) for k in value): if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not ' raise ValidationError('Invalid dictionary key name - keys may not '
'contain "." or "$" characters') 'contain "." or "$" characters')
super(DictField, self).validate(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return DictField(basecls=self.basecls, db_field=member_name) return DictField(basecls=self.basecls, db_field=member_name)
@ -432,7 +506,7 @@ class DictField(BaseField):
return super(DictField, self).prepare_query_value(op, value) return super(DictField, self).prepare_query_value(op, value)
class MapField(DereferenceBaseField): class MapField(DictField):
"""A field that maps a name to a specified field type. Similar to """A field that maps a name to a specified field type. Similar to
a DictField, except the 'value' of each item must match the specified a DictField, except the 'value' of each item must match the specified
field type. field type.
@ -444,50 +518,7 @@ class MapField(DereferenceBaseField):
if not isinstance(field, BaseField): if not isinstance(field, BaseField):
raise ValidationError('Argument to MapField constructor must be ' raise ValidationError('Argument to MapField constructor must be '
'a valid field') 'a valid field')
self.field = field super(MapField, self).__init__(field=field, *args, **kwargs)
kwargs.setdefault('default', lambda: {})
super(MapField, self).__init__(*args, **kwargs)
def validate(self, value):
"""Make sure that a list of valid fields is being used.
"""
if not isinstance(value, dict):
raise ValidationError('Only dictionaries may be used in a '
'DictField')
if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not '
'contain "." or "$" characters')
try:
[self.field.validate(item) for item in value.values()]
except Exception, err:
raise ValidationError('Invalid MapField item (%s)' % str(item))
def to_python(self, value):
return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()])
def to_mongo(self, value):
return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()])
def prepare_query_value(self, op, value):
if op not in ('set', 'unset'):
return self.field.prepare_query_value(op, value)
for key in value:
value[key] = self.field.prepare_query_value(op, value[key])
return value
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
class ReferenceField(BaseField): class ReferenceField(BaseField):

View File

@ -418,8 +418,9 @@ class QuerySet(object):
use_types = False use_types = False
# If _types is being used, prepend it to every specified index # If _types is being used, prepend it to every specified index
if (spec.get('types', True) and doc_cls._meta.get('allow_inheritance') index_types = doc_cls._meta.get('index_types', True)
and use_types): allow_inheritance = doc_cls._meta.get('allow_inheritance')
if spec.get('types', index_types) and allow_inheritance and use_types:
index_list.insert(0, ('_types', 1)) index_list.insert(0, ('_types', 1))
spec['fields'] = index_list spec['fields'] = index_list
@ -474,6 +475,7 @@ class QuerySet(object):
background = self._document._meta.get('index_background', False) background = self._document._meta.get('index_background', False)
drop_dups = self._document._meta.get('index_drop_dups', False) drop_dups = self._document._meta.get('index_drop_dups', False)
index_opts = self._document._meta.get('index_options', {}) index_opts = self._document._meta.get('index_options', {})
index_types = self._document._meta.get('index_types', True)
# Ensure indexes created by uniqueness constraints # Ensure indexes created by uniqueness constraints
for index in self._document._meta['unique_indexes']: for index in self._document._meta['unique_indexes']:
@ -490,7 +492,7 @@ class QuerySet(object):
background=background, **opts) background=background, **opts)
# If _types is being used (for polymorphism), it needs an index # If _types is being used (for polymorphism), it needs an index
if '_types' in self._query: if index_types and '_types' in self._query:
self._collection.ensure_index('_types', self._collection.ensure_index('_types',
background=background, **index_opts) background=background, **index_opts)
@ -547,11 +549,12 @@ class QuerySet(object):
parts = [parts] parts = [parts]
fields = [] fields = []
field = None field = None
for field_name in parts: for field_name in parts:
# Handle ListField indexing: # Handle ListField indexing:
if field_name.isdigit(): if field_name.isdigit():
try: try:
field = field.field new_field = field.field
except AttributeError, err: except AttributeError, err:
raise InvalidQueryError( raise InvalidQueryError(
"Can't use index on unsubscriptable field (%s)" % err) "Can't use index on unsubscriptable field (%s)" % err)
@ -565,11 +568,17 @@ class QuerySet(object):
field = document._fields[field_name] field = document._fields[field_name]
else: else:
# Look up subfield on the previous field # Look up subfield on the previous field
field = field.lookup_member(field_name) new_field = field.lookup_member(field_name)
if field is None: from base import ComplexBaseField
if not new_field and isinstance(field, ComplexBaseField):
fields.append(field_name)
continue
elif not new_field:
raise InvalidQueryError('Cannot resolve field "%s"' raise InvalidQueryError('Cannot resolve field "%s"'
% field_name) % field_name)
field = new_field # update field to the new field type
fields.append(field) fields.append(field)
return fields return fields
@classmethod @classmethod
@ -613,14 +622,33 @@ class QuerySet(object):
if _doc_cls: if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')] # Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts) fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [field.db_field for field in fields] parts = []
cleaned_fields = []
append_field = True
for field in fields:
if isinstance(field, str):
parts.append(field)
append_field = False
else:
parts.append(field.db_field)
if append_field:
cleaned_fields.append(field)
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = cleaned_fields[-1]
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops += match_operators singular_ops += match_operators
if op in singular_ops: if op in singular_ops:
value = field.prepare_query_value(op, value) if isinstance(field, basestring):
if op in match_operators and isinstance(value, basestring):
from mongoengine import StringField
value = StringField().prepare_query_value(op, value)
else:
value = field
else:
value = field.prepare_query_value(op, value)
elif op in ('in', 'nin', 'all', 'near'): elif op in ('in', 'nin', 'all', 'near'):
# 'in', 'nin' and 'all' require a list of values # 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
@ -1168,14 +1196,19 @@ class QuerySet(object):
fields = QuerySet._lookup_field(_doc_cls, parts) fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [] parts = []
cleaned_fields = []
append_field = True
for field in fields: for field in fields:
if isinstance(field, str): if isinstance(field, str):
parts.append(field) parts.append(field)
append_field = False
else: else:
parts.append(field.db_field) parts.append(field.db_field)
if append_field:
cleaned_fields.append(field)
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = cleaned_fields[-1]
if op in (None, 'set', 'push', 'pull', 'addToSet'): if op in (None, 'set', 'push', 'pull', 'addToSet'):
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)

View File

@ -45,6 +45,6 @@ setup(name='mongoengine',
long_description=LONG_DESCRIPTION, long_description=LONG_DESCRIPTION,
platforms=['any'], platforms=['any'],
classifiers=CLASSIFIERS, classifiers=CLASSIFIERS,
install_requires=['pymongo', 'blinker'], install_requires=['pymongo', 'blinker', 'django==1.3'],
test_suite='tests', test_suite='tests',
) )

View File

@ -122,6 +122,64 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members] [m for m in group_obj.members]
self.assertEqual(q, 4) self.assertEqual(q, 4)
for m in group_obj.members:
self.assertTrue('User' in m.__class__.__name__)
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
def test_list_field_complex(self):
class UserA(Document):
name = StringField()
class UserB(Document):
name = StringField()
class UserC(Document):
name = StringField()
class Group(Document):
members = ListField()
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
members = []
for i in xrange(1, 51):
a = UserA(name='User A %s' % i)
a.save()
b = UserB(name='User B %s' % i)
b.save()
c = UserC(name='User C %s' % i)
c.save()
members += [a, b, c]
group = Group(members=members)
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 4)
[m for m in group_obj.members]
self.assertEqual(q, 4)
for m in group_obj.members:
self.assertTrue('User' in m.__class__.__name__)
UserA.drop_collection() UserA.drop_collection()
UserB.drop_collection() UserB.drop_collection()
UserC.drop_collection() UserC.drop_collection()
@ -156,10 +214,13 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members] [m for m in group_obj.members]
self.assertEqual(q, 2) self.assertEqual(q, 2)
for k, m in group_obj.members.iteritems():
self.assertTrue(isinstance(m, User))
User.drop_collection() User.drop_collection()
Group.drop_collection() Group.drop_collection()
def ztest_generic_reference_dict_field(self): def test_dict_field(self):
class UserA(Document): class UserA(Document):
name = StringField() name = StringField()
@ -206,6 +267,9 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members] [m for m in group_obj.members]
self.assertEqual(q, 4) self.assertEqual(q, 4)
for k, m in group_obj.members.iteritems():
self.assertTrue('User' in m.__class__.__name__)
group.members = {} group.members = {}
group.save() group.save()
@ -218,11 +282,54 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members] [m for m in group_obj.members]
self.assertEqual(q, 1) self.assertEqual(q, 1)
for k, m in group_obj.members.iteritems():
self.assertTrue('User' in m.__class__.__name__)
UserA.drop_collection() UserA.drop_collection()
UserB.drop_collection() UserB.drop_collection()
UserC.drop_collection() UserC.drop_collection()
Group.drop_collection() Group.drop_collection()
def test_dict_field_no_field_inheritance(self):
class UserA(Document):
name = StringField()
meta = {'allow_inheritance': False}
class Group(Document):
members = DictField()
UserA.drop_collection()
Group.drop_collection()
members = []
for i in xrange(1, 51):
a = UserA(name='User A %s' % i)
a.save()
members += [a]
group = Group(members=dict([(str(u.id), u) for u in members]))
group.save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
[m for m in group_obj.members]
self.assertEqual(q, 2)
[m for m in group_obj.members]
self.assertEqual(q, 2)
for k, m in group_obj.members.iteritems():
self.assertTrue(isinstance(m, UserA))
UserA.drop_collection()
Group.drop_collection()
def test_generic_reference_map_field(self): def test_generic_reference_map_field(self):
class UserA(Document): class UserA(Document):
@ -270,6 +377,9 @@ class FieldTest(unittest.TestCase):
[m for m in group_obj.members] [m for m in group_obj.members]
self.assertEqual(q, 4) self.assertEqual(q, 4)
for k, m in group_obj.members.iteritems():
self.assertTrue('User' in m.__class__.__name__)
group.members = {} group.members = {}
group.save() group.save()

57
tests/django_tests.py Normal file
View File

@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
import unittest
from mongoengine import *
from django.template import Context, Template
from django.conf import settings
settings.configure()
class QuerySetTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
class Person(Document):
name = StringField()
age = IntField()
self.Person = Person
def test_order_by_in_django_template(self):
"""Ensure that QuerySets are properly ordered in Django template.
"""
self.Person.drop_collection()
self.Person(name="A", age=20).save()
self.Person(name="D", age=10).save()
self.Person(name="B", age=40).save()
self.Person(name="C", age=30).save()
t = Template("{% for o in ol %}{{ o.name }}-{{ o.age }}:{% endfor %}")
d = {"ol": self.Person.objects.order_by('-name')}
self.assertEqual(t.render(Context(d)), u'D-10:C-30:B-40:A-20:')
d = {"ol": self.Person.objects.order_by('+name')}
self.assertEqual(t.render(Context(d)), u'A-20:B-40:C-30:D-10:')
d = {"ol": self.Person.objects.order_by('-age')}
self.assertEqual(t.render(Context(d)), u'B-40:C-30:A-20:D-10:')
d = {"ol": self.Person.objects.order_by('+age')}
self.assertEqual(t.render(Context(d)), u'D-10:A-20:C-30:B-40:')
self.Person.drop_collection()
def test_q_object_filter_in_template(self):
self.Person.drop_collection()
self.Person(name="A", age=20).save()
self.Person(name="D", age=10).save()
self.Person(name="B", age=40).save()
self.Person(name="C", age=30).save()
t = Template("{% for o in ol %}{{ o.name }}-{{ o.age }}:{% endfor %}")
d = {"ol": self.Person.objects.filter(Q(age=10) | Q(name="C"))}
self.assertEqual(t.render(Context(d)), u'D-10:C-30:')

View File

@ -151,12 +151,12 @@ class DocumentTest(unittest.TestCase):
"""Ensure that inheritance may be disabled on simple classes and that """Ensure that inheritance may be disabled on simple classes and that
_cls and _types will not be used. _cls and _types will not be used.
""" """
class Animal(Document): class Animal(Document):
meta = {'allow_inheritance': False}
name = StringField() name = StringField()
meta = {'allow_inheritance': False}
Animal.drop_collection() Animal.drop_collection()
def create_dog_class(): def create_dog_class():
class Dog(Animal): class Dog(Animal):
pass pass
@ -191,6 +191,92 @@ class DocumentTest(unittest.TestCase):
self.assertFalse('_cls' in comment.to_mongo()) self.assertFalse('_cls' in comment.to_mongo())
self.assertFalse('_types' in comment.to_mongo()) self.assertFalse('_types' in comment.to_mongo())
def test_allow_inheritance_abstract_document(self):
"""Ensure that abstract documents can set inheritance rules and that
_cls and _types will not be used.
"""
class FinalDocument(Document):
meta = {'abstract': True,
'allow_inheritance': False}
class Animal(FinalDocument):
name = StringField()
Animal.drop_collection()
def create_dog_class():
class Dog(Animal):
pass
self.assertRaises(ValueError, create_dog_class)
# Check that _cls etc aren't present on simple documents
dog = Animal(name='dog')
dog.save()
collection = self.db[Animal._meta['collection']]
obj = collection.find_one()
self.assertFalse('_cls' in obj)
self.assertFalse('_types' in obj)
Animal.drop_collection()
def test_how_to_turn_off_inheritance(self):
"""Demonstrates migrating from allow_inheritance = True to False.
"""
class Animal(Document):
name = StringField()
meta = {
'indexes': ['name']
}
Animal.drop_collection()
dog = Animal(name='dog')
dog.save()
collection = self.db[Animal._meta['collection']]
obj = collection.find_one()
self.assertTrue('_cls' in obj)
self.assertTrue('_types' in obj)
info = collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertEquals([[(u'_id', 1)], [(u'_types', 1)], [(u'_types', 1), (u'name', 1)]], info)
# Turn off inheritance
class Animal(Document):
name = StringField()
meta = {
'allow_inheritance': False,
'indexes': ['name']
}
collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, False, True)
# Confirm extra data is removed
obj = collection.find_one()
self.assertFalse('_cls' in obj)
self.assertFalse('_types' in obj)
info = collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertEquals([[(u'_id', 1)], [(u'_types', 1)], [(u'_types', 1), (u'name', 1)]], info)
info = collection.index_information()
indexes_to_drop = [key for key, value in info.iteritems() if '_types' in dict(value['key'])]
for index in indexes_to_drop:
collection.drop_index(index)
info = collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertEquals([[(u'_id', 1)]], info)
# Recreate indexes
dog = Animal.objects.first()
dog.save()
info = collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertEquals([[(u'_id', 1)], [(u'name', 1),]], info)
Animal.drop_collection()
def test_abstract_documents(self): def test_abstract_documents(self):
"""Ensure that a document superclass can be marked as abstract """Ensure that a document superclass can be marked as abstract
thereby not using it as the name for the collection.""" thereby not using it as the name for the collection."""
@ -703,6 +789,90 @@ class DocumentTest(unittest.TestCase):
except ValidationError: except ValidationError:
self.fail() self.fail()
def test_update(self):
"""Ensure that an existing document is updated instead of be overwritten.
"""
# Create person object and save it to the database
person = self.Person(name='Test User', age=30)
person.save()
# Create same person object, with same id, without age
same_person = self.Person(name='Test')
same_person.id = person.id
same_person.save()
# Confirm only one object
self.assertEquals(self.Person.objects.count(), 1)
# reload
person.reload()
same_person.reload()
# Confirm the same
self.assertEqual(person, same_person)
self.assertEqual(person.name, same_person.name)
self.assertEqual(person.age, same_person.age)
# Confirm the saved values
self.assertEqual(person.name, 'Test')
self.assertEqual(person.age, 30)
# Test only / exclude only updates included fields
person = self.Person.objects.only('name').get()
person.name = 'User'
person.save()
person.reload()
self.assertEqual(person.name, 'User')
self.assertEqual(person.age, 30)
# test exclude only updates set fields
person = self.Person.objects.exclude('name').get()
person.age = 21
person.save()
person.reload()
self.assertEqual(person.name, 'User')
self.assertEqual(person.age, 21)
# Test only / exclude can set non excluded / included fields
person = self.Person.objects.only('name').get()
person.name = 'Test'
person.age = 30
person.save()
person.reload()
self.assertEqual(person.name, 'Test')
self.assertEqual(person.age, 30)
# test exclude only updates set fields
person = self.Person.objects.exclude('name').get()
person.name = 'User'
person.age = 21
person.save()
person.reload()
self.assertEqual(person.name, 'User')
self.assertEqual(person.age, 21)
# Confirm does remove unrequired fields
person = self.Person.objects.exclude('name').get()
person.age = None
person.save()
person.reload()
self.assertEqual(person.name, 'User')
self.assertEqual(person.age, None)
person = self.Person.objects.get()
person.name = None
person.age = None
person.save()
person.reload()
self.assertEqual(person.name, None)
self.assertEqual(person.age, None)
def test_delete(self): def test_delete(self):
"""Ensure that document may be deleted using the delete method. """Ensure that document may be deleted using the delete method.
""" """

View File

@ -247,6 +247,107 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection() LogEntry.drop_collection()
def test_complexdatetime_storage(self):
"""Tests for complex datetime fields - which can handle microseconds
without rounding.
"""
class LogEntry(Document):
date = ComplexDateTimeField()
LogEntry.drop_collection()
# Post UTC - microseconds are rounded (down) nearest millisecond and dropped - with default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
log = LogEntry()
log.date = d1
log.save()
log.reload()
self.assertEquals(log.date, d1)
# Post UTC - microseconds are rounded (down) nearest millisecond - with default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
log.date = d1
log.save()
log.reload()
self.assertEquals(log.date, d1)
# Pre UTC dates microseconds below 1000 are dropped - with default datetimefields
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
log.date = d1
log.save()
log.reload()
self.assertEquals(log.date, d1)
# Pre UTC microseconds above 1000 is wonky - with default datetimefields
# log.date has an invalid microsecond value so I can't construct
# a date to compare.
for i in xrange(1001, 3113, 33):
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
log.date = d1
log.save()
log.reload()
self.assertEquals(log.date, d1)
log1 = LogEntry.objects.get(date=d1)
self.assertEqual(log, log1)
LogEntry.drop_collection()
def test_complexdatetime_usage(self):
"""Tests for complex datetime fields - which can handle microseconds
without rounding.
"""
class LogEntry(Document):
date = ComplexDateTimeField()
LogEntry.drop_collection()
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
log = LogEntry()
log.date = d1
log.save()
log1 = LogEntry.objects.get(date=d1)
self.assertEquals(log, log1)
LogEntry.drop_collection()
# create 60 log entries
for i in xrange(1950, 2010):
d = datetime.datetime(i, 01, 01, 00, 00, 01, 999)
LogEntry(date=d).save()
self.assertEqual(LogEntry.objects.count(), 60)
# Test ordering
logs = LogEntry.objects.order_by("date")
count = logs.count()
i = 0
while i == count-1:
self.assertTrue(logs[i].date <= logs[i+1].date)
i +=1
logs = LogEntry.objects.order_by("-date")
count = logs.count()
i = 0
while i == count-1:
self.assertTrue(logs[i].date >= logs[i+1].date)
i +=1
# Test searching
logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980,1,1))
self.assertEqual(logs.count(), 30)
logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980,1,1))
self.assertEqual(logs.count(), 30)
logs = LogEntry.objects.filter(
date__lte=datetime.datetime(2011,1,1),
date__gte=datetime.datetime(2000,1,1),
)
self.assertEqual(logs.count(), 10)
LogEntry.drop_collection()
def test_list_validation(self): def test_list_validation(self):
"""Ensure that a list field only accepts lists with valid elements. """Ensure that a list field only accepts lists with valid elements.
""" """
@ -322,6 +423,108 @@ class FieldTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_list_field(self):
"""Ensure that list types work as expected.
"""
class BlogPost(Document):
info = ListField()
BlogPost.drop_collection()
post = BlogPost()
post.info = 'my post'
self.assertRaises(ValidationError, post.validate)
post.info = {'title': 'test'}
self.assertRaises(ValidationError, post.validate)
post.info = ['test']
post.save()
post = BlogPost()
post.info = [{'test': 'test'}]
post.save()
post = BlogPost()
post.info = [{'test': 3}]
post.save()
self.assertEquals(BlogPost.objects.count(), 3)
self.assertEquals(BlogPost.objects.filter(info__exact='test').count(), 1)
self.assertEquals(BlogPost.objects.filter(info__0__test='test').count(), 1)
# Confirm handles non strings or non existing keys
self.assertEquals(BlogPost.objects.filter(info__0__test__exact='5').count(), 0)
self.assertEquals(BlogPost.objects.filter(info__100__test__exact='test').count(), 0)
BlogPost.drop_collection()
def test_list_field_strict(self):
"""Ensure that list field handles validation if provided a strict field type."""
class Simple(Document):
mapping = ListField(field=IntField())
Simple.drop_collection()
e = Simple()
e.mapping = [1]
e.save()
def create_invalid_mapping():
e.mapping = ["abc"]
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Simple.drop_collection()
def test_list_field_complex(self):
"""Ensure that the list fields can handle the complex types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Simple(Document):
mapping = ListField()
Simple.drop_collection()
e = Simple()
e.mapping.append(StringSetting(value='foo'))
e.mapping.append(IntegerSetting(value=42))
e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001,
'complex': IntegerSetting(value=42), 'list':
[IntegerSetting(value=42), StringSetting(value='foo')]})
e.save()
e2 = Simple.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping[0], StringSetting))
self.assertTrue(isinstance(e2.mapping[1], IntegerSetting))
# Test querying
self.assertEquals(Simple.objects.filter(mapping__1__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__2__number=1).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__2__complex__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1)
# Confirm can update
Simple.objects().update(set__mapping__1=IntegerSetting(value=10))
self.assertEquals(Simple.objects.filter(mapping__1__value=10).count(), 1)
Simple.objects().update(
set__mapping__2__list__1=StringSetting(value='Boo'))
self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0)
self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1)
Simple.drop_collection()
def test_dict_field(self): def test_dict_field(self):
"""Ensure that dict types work as expected. """Ensure that dict types work as expected.
""" """
@ -363,6 +566,131 @@ class FieldTest(unittest.TestCase):
self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_dictfield_strict(self):
"""Ensure that dict field handles validation if provided a strict field type."""
class Simple(Document):
mapping = DictField(field=IntField())
Simple.drop_collection()
e = Simple()
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Simple.drop_collection()
def test_dictfield_complex(self):
"""Ensure that the dict field can handle the complex types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Simple(Document):
mapping = DictField()
Simple.drop_collection()
e = Simple()
e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42)
e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001,
'complex': IntegerSetting(value=42), 'list':
[IntegerSetting(value=42), StringSetting(value='foo')]}
e.save()
e2 = Simple.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
# Test querying
self.assertEquals(Simple.objects.filter(mapping__someint__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1)
# Confirm can update
Simple.objects().update(
set__mapping={"someint": IntegerSetting(value=10)})
Simple.objects().update(
set__mapping__nested_dict__list__1=StringSetting(value='Boo'))
self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0)
self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1)
Simple.drop_collection()
def test_mapfield(self):
"""Ensure that the MapField handles the declared type."""
class Simple(Document):
mapping = MapField(IntField())
Simple.drop_collection()
e = Simple()
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
def create_invalid_class():
class NoDeclaredType(Document):
mapping = MapField()
self.assertRaises(ValidationError, create_invalid_class)
Simple.drop_collection()
def test_complex_mapfield(self):
"""Ensure that the MapField can handle complex declared types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Extensible(Document):
mapping = MapField(EmbeddedDocumentField(SettingBase))
Extensible.drop_collection()
e = Extensible()
e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42)
e.save()
e2 = Extensible.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
def create_invalid_mapping():
e.mapping['someint'] = 123
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Extensible.drop_collection()
def test_embedded_document_validation(self): def test_embedded_document_validation(self):
"""Ensure that invalid embedded documents cannot be assigned to """Ensure that invalid embedded documents cannot be assigned to
embedded document fields. embedded document fields.
@ -773,6 +1101,35 @@ class FieldTest(unittest.TestCase):
Shirt.drop_collection() Shirt.drop_collection()
def test_choices_get_field_display(self):
"""Test dynamic helper for returning the display value of a choices field.
"""
class Shirt(Document):
size = StringField(max_length=3, choices=(('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
style = StringField(max_length=3, choices=(('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S')
Shirt.drop_collection()
shirt = Shirt()
self.assertEqual(shirt.get_size_display(), None)
self.assertEqual(shirt.get_style_display(), 'Small')
shirt.size = "XXL"
shirt.style = "B"
self.assertEqual(shirt.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt.get_style_display(), 'Baggy')
# Set as Z - an invalid choice
shirt.size = "Z"
shirt.style = "Z"
self.assertEqual(shirt.get_size_display(), 'Z')
self.assertEqual(shirt.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt.validate)
Shirt.drop_collection()
def test_file_fields(self): def test_file_fields(self):
"""Ensure that file fields can be written to and their data retrieved """Ensure that file fields can be written to and their data retrieved
""" """
@ -904,66 +1261,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(d2.data, {}) self.assertEqual(d2.data, {})
self.assertEqual(d2.data2, {}) self.assertEqual(d2.data2, {})
def test_mapfield(self):
"""Ensure that the MapField handles the declared type."""
class Simple(Document):
mapping = MapField(IntField())
Simple.drop_collection()
e = Simple()
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
def create_invalid_class():
class NoDeclaredType(Document):
mapping = MapField()
self.assertRaises(ValidationError, create_invalid_class)
Simple.drop_collection()
def test_complex_mapfield(self):
"""Ensure that the MapField can handle complex declared types."""
class SettingBase(EmbeddedDocument):
pass
class StringSetting(SettingBase):
value = StringField()
class IntegerSetting(SettingBase):
value = IntField()
class Extensible(Document):
mapping = MapField(EmbeddedDocumentField(SettingBase))
Extensible.drop_collection()
e = Extensible()
e.mapping['somestring'] = StringSetting(value='foo')
e.mapping['someint'] = IntegerSetting(value=42)
e.save()
e2 = Extensible.objects.get(id=e.id)
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
def create_invalid_mapping():
e.mapping['someint'] = 123
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Extensible.drop_collection()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1830,6 +1830,22 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue([('_types', 1)] in info) self.assertTrue([('_types', 1)] in info)
self.assertTrue([('_types', 1), ('date', -1)] in info) self.assertTrue([('_types', 1), ('date', -1)] in info)
def test_dont_index_types(self):
"""Ensure that index_types will, when disabled, prevent _types
being added to all indices.
"""
class BlogPost(Document):
date = DateTimeField()
meta = {'index_types': False,
'indexes': ['-date']}
# Indexes are lazy so use list() to perform query
list(BlogPost.objects)
info = BlogPost.objects._collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1)] not in info)
self.assertTrue([('date', -1)] in info)
BlogPost.drop_collection() BlogPost.drop_collection()
class BlogPost(Document): class BlogPost(Document):