Merge branch 'master' into 0.8M

Conflicts:
	AUTHORS
	docs/django.rst
	mongoengine/base.py
	mongoengine/queryset.py
	tests/fields/fields.py
	tests/queryset/queryset.py
	tests/test_dereference.py
	tests/test_document.py
Commit 51e50bf0a9 by Ross Lawley, 2013-04-17 11:57:53 +00:00
17 changed files with 474 additions and 101 deletions

AUTHORS
View File

@@ -129,6 +129,16 @@ that much better:
 * Peter Teichman
 * Jakub Kot
 * Jorge Bastida
+* Aleksandr Sorokoumov
+* Yohan Graterol
+* bool-dev
+* Russ Weeks
+* Paul Swartz
+* Sundar Raman
+* Benoit Louy
+* lraucy
+* hellysmile
+* Jaepil Jeong
 * Stefan Wójcik
 * Pete Campton
 * Martyn Smith
@@ -145,4 +155,3 @@ that much better:
 * Jared Forsyth
 * Kenneth Falck
 * Lukasz Balcerzak
-* Aleksandr Sorokoumov

View File

@@ -51,6 +51,19 @@ Changes in 0.8.X
 Changes in 0.7.10
 =================
+- Allow construction using positional parameters (#268)
+- Updated EmailField length to support long domains (#243)
+- Added 64-bit integer support (#251)
+- Added Django sessions TTL support (#224)
+- Fixed issue with numerical keys in MapField(EmbeddedDocumentField()) (#240)
+- Fixed clearing _changed_fields for complex nested embedded documents (#237, #239, #242)
+- Added "_id" to _data dictionary (#255)
+- Only mark a field as changed if the value has changed (#258)
+- Explicitly check for Document instances when dereferencing (#261)
+- Fixed order_by chaining issue (#265)
+- Added dereference support for tuples (#250)
+- Resolve field name to db field name when using distinct (#260, #264, #269)
+- Added kwargs to doc.save to help interop with django (#223, #270)
 - Fixed cloning querysets in PY3
 - Int fields no longer unset in save when changed to 0 (#272)
 - Fixed ReferenceField query chaining bug fixed (#254)

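Taken together, the headline 0.7.10 entries above can be exercised roughly as follows (a minimal sketch; the Person document, its fields and the database name are illustrative, not taken from the release notes):

    from mongoengine import Document, StringField, LongField, connect

    class Person(Document):
        name = StringField()
        visits = LongField()                  # 64-bit integer support (#251)

    connect('example_db')                     # placeholder database name
    Person("Ross", 42).save()                 # positional construction (#268)
    Person.objects.distinct('visits')         # field name resolved to its db_field (#260)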
View File

@@ -10,9 +10,15 @@ In your **settings.py** file, ignore the standard database settings (unless you
 also plan to use the ORM in your project), and instead call
 :func:`~mongoengine.connect` somewhere in the settings module.
 
-.. note:: If getting an ``ImproperlyConfigured: settings.DATABASES is
-   improperly configured`` error you may need to remove
-   ``django.contrib.sites`` from ``INSTALLED_APPS`` in settings.py.
+.. note ::
+   If you are not using another database backend you may need to add a dummy
+   database backend to ``settings.py``, e.g.::
+
+        DATABASES = {
+            'default': {
+                'ENGINE': 'django.db.backends.dummy'
+            }
+        }
 
 Authentication
 ==============
@@ -49,6 +55,9 @@ into you settings module::
 
     SESSION_ENGINE = 'mongoengine.django.sessions'
 
+Django provides a session cookie, which expires after ``SESSION_COOKIE_AGE``
+seconds, but doesn't delete the cookie at the sessions backend, so
+``'mongoengine.django.sessions'`` supports `MongoDB TTL
+<http://docs.mongodb.org/manual/tutorial/expire-data/>`_.
+
 .. versionadded:: 0.2.1
 
 Storage

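For reference, a minimal settings.py fragment combining the pieces described in this file might look like the following (a sketch only; the database name and cookie age are placeholders):

    # settings.py (sketch)
    from mongoengine import connect

    connect('myapp_db')                            # placeholder database name

    DATABASES = {
        'default': {
            'ENGINE': 'django.db.backends.dummy'   # dummy backend, no relational DB
        }
    }

    SESSION_ENGINE = 'mongoengine.django.sessions'
    SESSION_COOKIE_AGE = 60 * 60 * 24 * 14         # two weeks; also drives the TTL index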
View File

@@ -30,13 +30,22 @@ class BaseDocument(object):
     _dynamic_lock = True
     _initialised = False
 
-    def __init__(self, __auto_convert=True, **values):
+    def __init__(self, __auto_convert=True, *args, **values):
         """
         Initialise a document or embedded document
 
         :param __auto_convert: Try and will cast python objects to Object types
         :param values: A dictionary of values for the document
         """
+        if args:
+            # Combine positional arguments with named arguments.
+            # We only want named arguments.
+            field = iter(self._fields_ordered)
+            for value in args:
+                name = next(field)
+                if name in values:
+                    raise TypeError("Multiple values for keyword argument '" + name + "'")
+                values[name] = value
+
         signals.pre_init.send(self.__class__, document=self, values=values)
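In practice the new *args handling maps positional values onto fields in ``_fields_ordered`` (declaration) order, so the following calls are equivalent (a sketch; ``Person`` is a hypothetical document, mirroring the tests added later in this commit):

    class Person(Document):
        name = StringField()
        age = IntField()

    Person("Test User", 42)             # positional, follows field declaration order
    Person("Test User", age=42)         # mixed positional and keyword
    Person(name="Test User", age=42)    # keyword only
    # Person("Test User", 42, name="X") would raise TypeError (duplicate value for 'name')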
@@ -117,15 +126,15 @@ class BaseDocument(object):
             self._mark_as_changed(name)
 
         if (self._is_document and not self._created and
                 name in self._meta.get('shard_key', tuple()) and
                 self._data.get(name) != value):
             OperationError = _import_class('OperationError')
             msg = "Shard Keys are immutable. Tried to update %s" % name
             raise OperationError(msg)
 
         # Check if the user has created a new instance of a class
         if (self._is_document and self._initialised
                 and self._created and name == self._meta['id_field']):
             super(BaseDocument, self).__setattr__('_created', False)
 
         super(BaseDocument, self).__setattr__(name, value)
@@ -143,7 +152,10 @@ class BaseDocument(object):
             self.__set_field_display()
 
     def __iter__(self):
-        return iter(self._fields)
+        if 'id' in self._fields and 'id' not in self._fields_ordered:
+            return iter(('id', ) + self._fields_ordered)
+        return iter(self._fields_ordered)
 
     def __getitem__(self, name):
         """Dictionary-style field access, return a field's value if present.
@@ -264,7 +276,7 @@ class BaseDocument(object):
                   for name, field in self._fields.items()]
         if self._dynamic:
             fields += [(field, self._data.get(name))
                        for name, field in self._dynamic_fields.items()]
 
         EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
         GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
@@ -273,7 +285,7 @@ class BaseDocument(object):
             if value is not None:
                 try:
                     if isinstance(field, (EmbeddedDocumentField,
                                           GenericEmbeddedDocumentField)):
                         field._validate(value, clean=clean)
                     else:
                         field._validate(value)
@@ -330,7 +342,7 @@ class BaseDocument(object):
         # Convert lists / values so we can watch for any changes on them
         if (isinstance(value, (list, tuple)) and
                 not isinstance(value, BaseList)):
             value = BaseList(value, self, name)
         elif isinstance(value, dict) and not isinstance(value, BaseDict):
             value = BaseDict(value, self, name)
@@ -344,9 +356,25 @@ class BaseDocument(object):
             return
         key = self._db_field_map.get(key, key)
         if (hasattr(self, '_changed_fields') and
                 key not in self._changed_fields):
             self._changed_fields.append(key)
 
+    def _clear_changed_fields(self):
+        self._changed_fields = []
+        EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
+        for field_name, field in self._fields.iteritems():
+            if (isinstance(field, ComplexBaseField) and
+               isinstance(field.field, EmbeddedDocumentField)):
+                field_value = getattr(self, field_name, None)
+                if field_value:
+                    for idx in (field_value if isinstance(field_value, dict)
+                                else xrange(len(field_value))):
+                        field_value[idx]._clear_changed_fields()
+            elif isinstance(field, EmbeddedDocumentField):
+                field_value = getattr(self, field_name, None)
+                if field_value:
+                    field_value._clear_changed_fields()
+
     def _get_changed_fields(self, key='', inspected=None):
         """Returns a list of all fields that have explicitly been changed.
         """
@@ -418,7 +446,7 @@ class BaseDocument(object):
             for p in parts:
                 if isinstance(d, DBRef):
                     break
-                elif p.isdigit():
+                elif isinstance(d, list) and p.isdigit():
                     d = d[int(p)]
                 elif hasattr(d, 'get'):
                     d = d.get(p)
@@ -449,7 +477,7 @@ class BaseDocument(object):
             parts = path.split('.')
             db_field_name = parts.pop()
             for p in parts:
-                if p.isdigit():
+                if isinstance(d, list) and p.isdigit():
                     d = d[int(p)]
                 elif (hasattr(d, '__getattribute__') and
                       not isinstance(d, dict)):
@@ -514,7 +542,7 @@ class BaseDocument(object):
                 value = data[field.db_field]
                 try:
                     data[field_name] = (value if value is None
                                         else field.to_python(value))
                     if field_name != field.db_field:
                         del data[field.db_field]
                 except (AttributeError, ValueError), e:
@@ -548,14 +576,14 @@ class BaseDocument(object):
         geo_indices = cls._geo_indices()
         unique_indices = cls._unique_with_indexes()
         index_specs = [cls._build_index_spec(spec)
                        for spec in meta_indexes]
 
         def merge_index_specs(index_specs, indices):
             if not indices:
                 return index_specs
 
             spec_fields = [v['fields']
                            for k, v in enumerate(index_specs)]
             # Merge unqiue_indexes with existing specs
             for k, v in enumerate(indices):
                 if v['fields'] in spec_fields:
@@ -727,7 +755,7 @@ class BaseDocument(object):
                     field = DynamicField(db_field=field_name)
                 else:
                     raise LookUpError('Cannot resolve field "%s"'
                                       % field_name)
             else:
                 ReferenceField = _import_class('ReferenceField')
                 GenericReferenceField = _import_class('GenericReferenceField')
@@ -744,7 +772,7 @@ class BaseDocument(object):
                     continue
                 elif not new_field:
                     raise LookUpError('Cannot resolve field "%s"'
                                       % field_name)
                 field = new_field  # update field to the new field type
             fields.append(field)
         return fields

View File

@@ -81,8 +81,12 @@ class BaseField(object):
     def __set__(self, instance, value):
         """Descriptor for assigning a value to a field in a document.
         """
-        instance._data[self.name] = value
-        if instance._initialised:
+        changed = False
+        if (self.name not in instance._data or
+                instance._data[self.name] != value):
+            changed = True
+        instance._data[self.name] = value
+        if changed and instance._initialised:
             instance._mark_as_changed(self.name)
 
     def error(self, message="", errors=None, field_name=None):

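The effect of the guard above is that re-assigning an unchanged value no longer dirties the field (#258). Roughly, with a hypothetical Person document:

    person = Person.objects.first()
    person.name = person.name        # same value: not added to _changed_fields
    assert 'name' not in person._changed_fields
    person.name = "Someone Else"     # different value: marked as changed
    assert 'name' in person._changed_fields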
View File

@@ -78,7 +78,7 @@ class DocumentMetaclass(type):
 
             # Count names to ensure no db_field redefinitions
             field_names[attr_value.db_field] = field_names.get(
                 attr_value.db_field, 0) + 1
 
         # Ensure no duplicate db_fields
         duplicate_db_fields = [k for k, v in field_names.items() if v > 1]
@@ -90,9 +90,12 @@ class DocumentMetaclass(type):
 
         # Set _fields and db_field maps
         attrs['_fields'] = doc_fields
         attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k))
                                        for k, v in doc_fields.iteritems()])
+        attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
+                                         (v.creation_counter, v.name)
+                                         for v in doc_fields.itervalues()))
         attrs['_reverse_db_field_map'] = dict(
             (v, k) for k, v in attrs['_db_field_map'].iteritems())
 
         #
         # Set document hierarchy
@@ -101,7 +104,7 @@ class DocumentMetaclass(type):
         class_name = [name]
         for base in flattened_bases:
             if (not getattr(base, '_is_base_cls', True) and
                     not getattr(base, '_meta', {}).get('abstract', True)):
                 # Collate heirarchy for _cls and _subclasses
                 class_name.append(base.__name__)
@@ -109,11 +112,11 @@ class DocumentMetaclass(type):
                 # Warn if allow_inheritance isn't set and prevent
                 # inheritance of classes where inheritance is set to False
                 allow_inheritance = base._meta.get('allow_inheritance',
                                                    ALLOW_INHERITANCE)
-                if (allow_inheritance != True and
+                if (allow_inheritance is not True and
                         not base._meta.get('abstract')):
                     raise ValueError('Document %s may not be subclassed' %
                                      base.__name__)
 
         # Get superclasses from last base superclass
         document_bases = [b for b in flattened_bases

View File

@@ -33,7 +33,7 @@ class DeReference(object):
         self.max_depth = max_depth
         doc_type = None
 
-        if instance and instance._fields:
+        if instance and isinstance(instance, (Document, TopLevelDocumentMetaclass)):
             doc_type = instance._fields.get(name)
             if hasattr(doc_type, 'field'):
                 doc_type = doc_type.field
@@ -84,7 +84,7 @@ class DeReference(object):
         # Recursively find dbreferences
         depth += 1
         for k, item in iterator:
-            if hasattr(item, '_fields'):
+            if isinstance(item, Document):
                 for field_name, field in item._fields.iteritems():
                     v = item._data.get(field_name, None)
                     if isinstance(v, (DBRef)):
@@ -174,6 +174,7 @@ class DeReference(object):
         if not hasattr(items, 'items'):
             is_list = True
+            as_tuple = isinstance(items, tuple)
             iterator = enumerate(items)
             data = []
         else:
@@ -190,7 +191,7 @@ class DeReference(object):
             if k in self.object_map and not is_list:
                 data[k] = self.object_map[k]
-            elif hasattr(v, '_fields'):
+            elif isinstance(v, Document):
                 for field_name, field in v._fields.iteritems():
                     v = data[k]._data.get(field_name, None)
                     if isinstance(v, (DBRef)):
@@ -208,7 +209,7 @@ class DeReference(object):
         if instance and name:
             if is_list:
-                return BaseList(data, instance, name)
+                return tuple(data) if as_tuple else BaseList(data, instance, name)
             return BaseDict(data, instance, name)
         depth += 1
         return data

View File

@@ -32,9 +32,17 @@ class MongoSession(Document):
                     else fields.DictField()
     expire_date = fields.DateTimeField()
 
-    meta = {'collection': MONGOENGINE_SESSION_COLLECTION,
-            'db_alias': MONGOENGINE_SESSION_DB_ALIAS,
-            'allow_inheritance': False}
+    meta = {
+        'collection': MONGOENGINE_SESSION_COLLECTION,
+        'db_alias': MONGOENGINE_SESSION_DB_ALIAS,
+        'allow_inheritance': False,
+        'indexes': [
+            {
+                'fields': ['expire_date'],
+                'expireAfterSeconds': settings.SESSION_COOKIE_AGE
+            }
+        ]
+    }
 
     def get_decoded(self):
         return SessionStore().decode(self.session_data)

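The expireAfterSeconds entry above makes MongoDB build a TTL index on expire_date, so expired session documents are removed by the server itself. Roughly the same index could be created by hand with pymongo (a sketch; the database and collection names are placeholders and the exact index-creation call depends on your pymongo version):

    from pymongo import MongoClient

    db = MongoClient()['myapp_db']                    # placeholder database name
    db['django_session'].ensure_index(
        'expire_date', expireAfterSeconds=60 * 60 * 24 * 14)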
View File

@@ -160,7 +160,7 @@ class Document(BaseDocument):
     def save(self, safe=True, force_insert=False, validate=True, clean=True,
              write_options=None, cascade=None, cascade_kwargs=None,
-             _refs=None):
+             _refs=None, **kwargs):
         """Save the :class:`~mongoengine.Document` to the database. If the
         document already exists, it will be updated, otherwise it will be
         created.
@@ -278,7 +278,7 @@ class Document(BaseDocument):
         if id_field not in self._meta.get('shard_key', []):
             self[id_field] = self._fields[id_field].to_python(object_id)
 
-        self._changed_fields = []
+        self._clear_changed_fields()
         self._created = False
         signals.post_save.send(self.__class__, document=self, created=created)
         return self

View File

@@ -27,7 +27,7 @@ except ImportError:
     Image = None
     ImageOps = None
 
-__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
+__all__ = ['StringField', 'IntField', 'LongField', 'FloatField', 'BooleanField',
            'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
            'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField',
            'DecimalField', 'ComplexDateTimeField', 'URLField', 'DynamicField',
@@ -143,7 +143,7 @@ class EmailField(StringField):
     EMAIL_REGEX = re.compile(
         r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*"  # dot-atom
         r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"'  # quoted-string
-        r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE  # domain
+        r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE  # domain
     )
 
     def validate(self, value):
@@ -153,7 +153,7 @@ class EmailField(StringField):
 
 class IntField(BaseField):
-    """An integer field.
+    """An 32-bit integer field.
     """
 
     def __init__(self, min_value=None, max_value=None, **kwargs):
@@ -186,6 +186,40 @@ class IntField(BaseField):
         return int(value)
 
 
+class LongField(BaseField):
+    """An 64-bit integer field.
+    """
+
+    def __init__(self, min_value=None, max_value=None, **kwargs):
+        self.min_value, self.max_value = min_value, max_value
+        super(LongField, self).__init__(**kwargs)
+
+    def to_python(self, value):
+        try:
+            value = long(value)
+        except ValueError:
+            pass
+        return value
+
+    def validate(self, value):
+        try:
+            value = long(value)
+        except:
+            self.error('%s could not be converted to long' % value)
+
+        if self.min_value is not None and value < self.min_value:
+            self.error('Long value is too small')
+
+        if self.max_value is not None and value > self.max_value:
+            self.error('Long value is too large')
+
+    def prepare_query_value(self, op, value):
+        if value is None:
+            return value
+
+        return long(value)
+
+
 class FloatField(BaseField):
     """An floating point number field.
     """

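LongField mirrors IntField but stores 64-bit integers; a minimal usage sketch (the Counter document is hypothetical):

    class Counter(Document):
        hits = LongField(min_value=0)

    counter = Counter(hits=2 ** 40)    # larger than a 32-bit int
    counter.validate()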
View File

@@ -609,8 +609,11 @@ class QuerySet(object):
         .. versionchanged:: 0.6 - Improved db_field refrence handling
         """
         queryset = self.clone()
-        return queryset._dereference(queryset._cursor.distinct(field), 1,
-                                     name=field, instance=queryset._document)
+        try:
+            field = self._fields_to_dbfields([field]).pop()
+        finally:
+            return self._dereference(queryset._cursor.distinct(field), 1,
+                                     name=field, instance=self._document)
 
     def only(self, *fields):
         """Load only a subset of this document's fields. ::
@@ -696,7 +699,7 @@ class QuerySet(object):
             prefixed with **+** or **-** to determine the ordering direction
         """
         queryset = self.clone()
-        queryset._ordering = self._get_order_by(keys)
+        queryset._ordering = queryset._get_order_by(keys)
         return queryset
 
     def explain(self, format=False):

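With this change distinct accepts the Python attribute name and resolves it to the underlying db_field before querying, e.g. (hypothetical document, mirroring the test added later in this commit):

    class Product(Document):
        product_id = IntField(db_field='pid')

    Product.objects.distinct('product_id')    # now resolved to the 'pid' db field (#260)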
View File

@@ -32,7 +32,7 @@ class SimplificationVisitor(QNodeVisitor):
         if combination.operation == combination.AND:
             # The simplification only applies to 'simple' queries
             if all(isinstance(node, Q) for node in combination.children):
-                queries = [node.query for node in combination.children]
+                queries = [n.query for n in combination.children]
                 return Q(**self._query_conjunction(queries))
 
         return combination

View File

@@ -827,20 +827,20 @@ class InstanceTest(unittest.TestCase):
             float_field = FloatField(default=1.1)
             boolean_field = BooleanField(default=True)
             datetime_field = DateTimeField(default=datetime.now)
-            embedded_document_field = EmbeddedDocumentField(EmbeddedDoc,
-                default=lambda: EmbeddedDoc())
+            embedded_document_field = EmbeddedDocumentField(
+                EmbeddedDoc, default=lambda: EmbeddedDoc())
             list_field = ListField(default=lambda: [1, 2, 3])
             dict_field = DictField(default=lambda: {"hello": "world"})
             objectid_field = ObjectIdField(default=bson.ObjectId)
             reference_field = ReferenceField(Simple, default=lambda:
                                              Simple().save())
             map_field = MapField(IntField(), default=lambda: {"simple": 1})
             decimal_field = DecimalField(default=1.0)
             complex_datetime_field = ComplexDateTimeField(default=datetime.now)
             url_field = URLField(default="http://mongoengine.org")
             dynamic_field = DynamicField(default=1)
             generic_reference_field = GenericReferenceField(
                 default=lambda: Simple().save())
             sorted_list_field = SortedListField(IntField(),
                                                 default=lambda: [1, 2, 3])
             email_field = EmailField(default="ross@example.com")
@@ -848,7 +848,7 @@ class InstanceTest(unittest.TestCase):
             sequence_field = SequenceField()
             uuid_field = UUIDField(default=uuid.uuid4)
             generic_embedded_document_field = GenericEmbeddedDocumentField(
                 default=lambda: EmbeddedDoc())
 
         Simple.drop_collection()
         Doc.drop_collection()
@@ -1127,20 +1127,20 @@ class InstanceTest(unittest.TestCase):
         u3 = User(username="hmarr")
         u3.save()
 
-        p1 = Page(comments = [Comment(user=u1, comment="Its very good"),
+        p1 = Page(comments=[Comment(user=u1, comment="Its very good"),
                               Comment(user=u2, comment="Hello world"),
                               Comment(user=u3, comment="Ping Pong"),
                               Comment(user=u1, comment="I like a beer")])
         p1.save()
 
-        p2 = Page(comments = [Comment(user=u1, comment="Its very good"),
+        p2 = Page(comments=[Comment(user=u1, comment="Its very good"),
                               Comment(user=u2, comment="Hello world")])
         p2.save()
 
-        p3 = Page(comments = [Comment(user=u3, comment="Its very good")])
+        p3 = Page(comments=[Comment(user=u3, comment="Its very good")])
         p3.save()
 
-        p4 = Page(comments = [Comment(user=u2, comment="Heavy Metal song")])
+        p4 = Page(comments=[Comment(user=u2, comment="Heavy Metal song")])
         p4.save()
 
         self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1)))
@@ -1183,7 +1183,6 @@ class InstanceTest(unittest.TestCase):
         class Site(Document):
             page = EmbeddedDocumentField(Page)
 
         Site.drop_collection()
         site = Site(page=Page(log_message="Warning: Dummy message"))
         site.save()
@@ -1328,7 +1327,8 @@ class InstanceTest(unittest.TestCase):
             occurs = ListField(EmbeddedDocumentField(Occurrence), default=list)
 
         def raise_invalid_document():
-            Word._from_son({'stem': [1,2,3], 'forms': 1, 'count': 'one', 'occurs': {"hello": None}})
+            Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one',
+                            'occurs': {"hello": None}})
 
         self.assertRaises(InvalidDocumentError, raise_invalid_document)
@@ -1350,7 +1350,7 @@ class InstanceTest(unittest.TestCase):
         reviewer = self.Person(name='Re Viewer')
         reviewer.save()
 
-        post = BlogPost(content = 'Watched some TV')
+        post = BlogPost(content='Watched some TV')
         post.author = author
         post.reviewer = reviewer
         post.save()
@@ -1432,7 +1432,6 @@ class InstanceTest(unittest.TestCase):
         author.delete()
         self.assertEqual(len(BlogPost.objects), 0)
 
     def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self):
         ''' ensure the pre_delete signal is triggered upon a cascading deletion
         setup a blog post with content, an author and editor
@@ -1627,7 +1626,7 @@ class InstanceTest(unittest.TestCase):
         u1 = User.objects.create()
         u2 = User.objects.create()
         u3 = User.objects.create()
         u4 = User()  # New object
         b1 = BlogPost.objects.create()
         b2 = BlogPost.objects.create()
@@ -1638,9 +1637,9 @@ class InstanceTest(unittest.TestCase):
         self.assertTrue(u1 in all_user_list)
         self.assertTrue(u2 in all_user_list)
         self.assertTrue(u3 in all_user_list)
         self.assertFalse(u4 in all_user_list)  # New object
         self.assertFalse(b1 in all_user_list)  # Other object
         self.assertFalse(b2 in all_user_list)  # Other object
 
         # in Dict
         all_user_dic = {}
@@ -1650,9 +1649,9 @@ class InstanceTest(unittest.TestCase):
         self.assertEqual(all_user_dic.get(u1, False), "OK")
         self.assertEqual(all_user_dic.get(u2, False), "OK")
         self.assertEqual(all_user_dic.get(u3, False), "OK")
         self.assertEqual(all_user_dic.get(u4, False), False)  # New object
         self.assertEqual(all_user_dic.get(b1, False), False)  # Other object
         self.assertEqual(all_user_dic.get(b2, False), False)  # Other object
 
         # in Set
         all_user_set = set(User.objects.all())
@@ -1730,7 +1729,6 @@ class InstanceTest(unittest.TestCase):
         self.assertEqual(Doc.objects(archived=False).count(), 1)
 
     def test_can_save_false_values_dynamic(self):
         """Ensures you can save False values on dynamic docs"""
         class Doc(DynamicDocument):
@@ -1852,9 +1850,9 @@ class InstanceTest(unittest.TestCase):
         self.assertEquals('testdb-2', B._meta.get('db_alias'))
         self.assertEquals('mongoenginetest',
                           A._get_collection().database.name)
         self.assertEquals('mongoenginetest2',
                           B._get_collection().database.name)
 
     def test_db_alias_propagates(self):
         """db_alias propagates?
@@ -1920,21 +1918,21 @@ class InstanceTest(unittest.TestCase):
         # Checks
         self.assertEqual(",".join([str(b) for b in Book.objects.all()]),
                          "1,2,3,4,5,6,7,8,9")
 
         # bob related books
         self.assertEqual(",".join([str(b) for b in Book.objects.filter(
                          Q(extra__a=bob) |
                          Q(author=bob) |
                          Q(extra__b=bob))]),
                          "1,2,3,4")
 
         # Susan & Karl related books
         self.assertEqual(",".join([str(b) for b in Book.objects.filter(
                          Q(extra__a__all=[karl, susan]) |
                          Q(author__all=[karl, susan]) |
                          Q(extra__b__all=[
                              karl.to_dbref(), susan.to_dbref()]))
                          ]), "1")
 
         # $Where
         self.assertEqual(u",".join([str(b) for b in Book.objects.filter(
@@ -1943,8 +1941,8 @@ class InstanceTest(unittest.TestCase):
                     function(){
                         return this.name == '1' ||
                                this.name == '2';}"""
-                }
-            )]), "1,2")
+                })]),
+            "1,2")
 
     def test_switch_db_instance(self):
         register_connection('testdb-1', 'mongoenginetest2')
@@ -2020,7 +2018,6 @@ class InstanceTest(unittest.TestCase):
         self.assertEqual("Bar", user._data["foo"])
         self.assertEqual([1, 2, 3], user._data["data"])
 
     def test_spaces_in_keys(self):
         class Embedded(DynamicEmbeddedDocument):
@@ -2109,8 +2106,8 @@ class InstanceTest(unittest.TestCase):
             docs = ListField(EmbeddedDocumentField(Embedded))
 
         classic_doc = Doc(doc_name="my doc", docs=[
             Embedded(name="embedded doc1"),
             Embedded(name="embedded doc2")])
         dict_doc = Doc(**{"doc_name": "my doc",
                           "docs": [{"name": "embedded doc1"},
                                    {"name": "embedded doc2"}]})
@@ -2118,5 +2115,82 @@ class InstanceTest(unittest.TestCase):
         self.assertEqual(classic_doc, dict_doc)
         self.assertEqual(classic_doc._data, dict_doc._data)
 
+    def test_positional_creation(self):
+        """Ensure that document may be created using positional arguments.
+        """
+        person = self.Person("Test User", 42)
+        self.assertEqual(person.name, "Test User")
+        self.assertEqual(person.age, 42)
+
+    def test_mixed_creation(self):
+        """Ensure that document may be created using mixed arguments.
+        """
+        person = self.Person("Test User", age=42)
+        self.assertEqual(person.name, "Test User")
+        self.assertEqual(person.age, 42)
+
+    def test_bad_mixed_creation(self):
+        """Ensure that document gives correct error when duplicating arguments
+        """
+        def construct_bad_instance():
+            return self.Person("Test User", 42, name="Bad User")
+
+        self.assertRaises(TypeError, construct_bad_instance)
+
+    def test_data_contains_id_field(self):
+        """Ensure that asking for _data returns 'id'
+        """
+        class Person(Document):
+            name = StringField()
+
+        Person.drop_collection()
+        Person(name="Harry Potter").save()
+
+        person = Person.objects.first()
+        self.assertTrue('id' in person._data.keys())
+        self.assertEqual(person._data.get('id'), person.id)
+
+    def test_complex_nesting_document_and_embedded_document(self):
+
+        class Macro(EmbeddedDocument):
+            value = DynamicField(default="UNDEFINED")
+
+        class Parameter(EmbeddedDocument):
+            macros = MapField(EmbeddedDocumentField(Macro))
+
+            def expand(self):
+                self.macros["test"] = Macro()
+
+        class Node(Document):
+            parameters = MapField(EmbeddedDocumentField(Parameter))
+
+            def expand(self):
+                self.flattened_parameter = {}
+                for parameter_name, parameter in self.parameters.iteritems():
+                    parameter.expand()
+
+        class System(Document):
+            name = StringField(required=True)
+            nodes = MapField(ReferenceField(Node, dbref=False))
+
+            def save(self, *args, **kwargs):
+                for node_name, node in self.nodes.iteritems():
+                    node.expand()
+                    node.save(*args, **kwargs)
+                super(System, self).save(*args, **kwargs)
+
+        System.drop_collection()
+        Node.drop_collection()
+
+        system = System(name="system")
+        system.nodes["node"] = Node()
+        system.save()
+
+        system.nodes["node"].parameters["param"] = Parameter()
+        system.save()
+
+        system = System.objects.first()
+        self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value)
+
 
 if __name__ == '__main__':
     unittest.main()

View File

@@ -130,8 +130,8 @@ class ValidatorErrorTest(unittest.TestCase):
         doc = Doc.objects.first()
         keys = doc._data.keys()
         self.assertEqual(2, len(keys))
-        self.assertTrue('id' in keys)
         self.assertTrue('e' in keys)
+        self.assertTrue('id' in keys)
 
         doc.e.val = "OK"
         try:

View File

@@ -145,6 +145,17 @@ class FieldTest(unittest.TestCase):
         self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count())
         self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count())
 
+    def test_long_ne_operator(self):
+        class TestDocument(Document):
+            long_fld = LongField()
+
+        TestDocument.drop_collection()
+
+        TestDocument(long_fld=None).save()
+        TestDocument(long_fld=1).save()
+
+        self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count())
+
     def test_object_id_validation(self):
         """Ensure that invalid values cannot be assigned to string fields.
         """
@@ -218,6 +229,23 @@ class FieldTest(unittest.TestCase):
         person.age = 'ten'
         self.assertRaises(ValidationError, person.validate)
 
+    def test_long_validation(self):
+        """Ensure that invalid values cannot be assigned to long fields.
+        """
+        class TestDocument(Document):
+            value = LongField(min_value=0, max_value=110)
+
+        doc = TestDocument()
+        doc.value = 50
+        doc.validate()
+
+        doc.value = -1
+        self.assertRaises(ValidationError, doc.validate)
+        doc.age = 120
+        self.assertRaises(ValidationError, doc.validate)
+        doc.age = 'ten'
+        self.assertRaises(ValidationError, doc.validate)
+
     def test_float_validation(self):
         """Ensure that invalid values cannot be assigned to float fields.
         """
@@ -971,6 +999,24 @@ class FieldTest(unittest.TestCase):
         doc = self.db.test.find_one()
         self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2)
 
+    def test_mapfield_numerical_index(self):
+        """Ensure that MapField accept numeric strings as indexes."""
+        class Embedded(EmbeddedDocument):
+            name = StringField()
+
+        class Test(Document):
+            my_map = MapField(EmbeddedDocumentField(Embedded))
+
+        Test.drop_collection()
+
+        test = Test()
+        test.my_map['1'] = Embedded(name='test')
+        test.save()
+        test.my_map['1'].name = 'test updated'
+        test.save()
+
+        Test.drop_collection()
+
     def test_map_field_lookup(self):
         """Ensure MapField lookups succeed on Fields without a lookup method"""
@@ -2399,11 +2445,26 @@ class FieldTest(unittest.TestCase):
         self.assertTrue(1 in error_dict['comments'])
         self.assertTrue('content' in error_dict['comments'][1])
         self.assertEqual(error_dict['comments'][1]['content'],
                          u'Field is required')
 
         post.comments[1].content = 'here we go'
         post.validate()
 
+    def test_email_field(self):
+        class User(Document):
+            email = EmailField()
+
+        user = User(email="ross@example.com")
+        self.assertTrue(user.validate() is None)
+
+        user = User(email=("Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S"
+                           "ucictfqpdkK9iS1zeFw8sg7s7cwAF7suIfUfeyueLpfosjn3"
+                           "aJIazqqWkm7.net"))
+        self.assertTrue(user.validate() is None)
+
+        user = User(email='me@localhost')
+        self.assertRaises(ValidationError, user.validate)
+
     def test_email_field_honors_regex(self):
         class User(Document):
             email = EmailField(regex=r'\w+@example.com')

View File

@@ -931,6 +931,11 @@ class QuerySetTest(unittest.TestCase):
         BlogPost.drop_collection()
         Blog.drop_collection()
 
+    def assertSequence(self, qs, expected):
+        self.assertEqual(len(qs), len(expected))
+        for i in range(len(qs)):
+            self.assertEqual(qs[i], expected[i])
+
     def test_ordering(self):
         """Ensure default ordering is applied and can be overridden.
         """
@@ -957,14 +962,13 @@ class QuerySetTest(unittest.TestCase):
 
         # get the "first" BlogPost using default ordering
         # from BlogPost.meta.ordering
-        latest_post = BlogPost.objects.first()
-        self.assertEqual(latest_post.title, "Blog Post #3")
+        expected = [blog_post_3, blog_post_2, blog_post_1]
+        self.assertSequence(BlogPost.objects.all(), expected)
 
         # override default ordering, order BlogPosts by "published_date"
-        first_post = BlogPost.objects.order_by("+published_date").first()
-        self.assertEqual(first_post.title, "Blog Post #1")
-
-        BlogPost.drop_collection()
+        qs = BlogPost.objects.order_by("+published_date")
+        expected = [blog_post_1, blog_post_2, blog_post_3]
+        self.assertSequence(qs, expected)
 
     def test_find_embedded(self):
         """Ensure that an embedded document is properly returned from a query.
@@ -1505,8 +1509,8 @@ class QuerySetTest(unittest.TestCase):
     def test_order_by(self):
         """Ensure that QuerySets may be ordered.
         """
-        self.Person(name="User A", age=20).save()
         self.Person(name="User B", age=40).save()
+        self.Person(name="User A", age=20).save()
         self.Person(name="User C", age=30).save()
 
         names = [p.name for p in self.Person.objects.order_by('-age')]
@@ -1521,11 +1525,67 @@ class QuerySetTest(unittest.TestCase):
         ages = [p.age for p in self.Person.objects.order_by('-name')]
         self.assertEqual(ages, [30, 40, 20])
 
+    def test_order_by_optional(self):
+        class BlogPost(Document):
+            title = StringField()
+            published_date = DateTimeField(required=False)
+
+        BlogPost.drop_collection()
+
+        blog_post_3 = BlogPost(title="Blog Post #3",
+                               published_date=datetime(2010, 1, 6, 0, 0, 0))
+        blog_post_2 = BlogPost(title="Blog Post #2",
+                               published_date=datetime(2010, 1, 5, 0, 0, 0))
+        blog_post_4 = BlogPost(title="Blog Post #4",
+                               published_date=datetime(2010, 1, 7, 0, 0, 0))
+        blog_post_1 = BlogPost(title="Blog Post #1", published_date=None)
+
+        blog_post_3.save()
+        blog_post_1.save()
+        blog_post_4.save()
+        blog_post_2.save()
+
+        expected = [blog_post_1, blog_post_2, blog_post_3, blog_post_4]
+        self.assertSequence(BlogPost.objects.order_by('published_date'),
+                            expected)
+        self.assertSequence(BlogPost.objects.order_by('+published_date'),
+                            expected)
+
+        expected.reverse()
+        self.assertSequence(BlogPost.objects.order_by('-published_date'),
+                            expected)
+
+    def test_order_by_list(self):
+        class BlogPost(Document):
+            title = StringField()
+            published_date = DateTimeField(required=False)
+
+        BlogPost.drop_collection()
+
+        blog_post_1 = BlogPost(title="A",
+                               published_date=datetime(2010, 1, 6, 0, 0, 0))
+        blog_post_2 = BlogPost(title="B",
+                               published_date=datetime(2010, 1, 6, 0, 0, 0))
+        blog_post_3 = BlogPost(title="C",
+                               published_date=datetime(2010, 1, 7, 0, 0, 0))
+
+        blog_post_2.save()
+        blog_post_3.save()
+        blog_post_1.save()
+
+        qs = BlogPost.objects.order_by('published_date', 'title')
+        expected = [blog_post_1, blog_post_2, blog_post_3]
+        self.assertSequence(qs, expected)
+
+        qs = BlogPost.objects.order_by('-published_date', '-title')
+        expected.reverse()
+        self.assertSequence(qs, expected)
+
     def test_order_by_chaining(self):
         """Ensure that an order_by query chains properly and allows .only()
         """
-        self.Person(name="User A", age=20).save()
         self.Person(name="User B", age=40).save()
+        self.Person(name="User A", age=20).save()
         self.Person(name="User C", age=30).save()
 
         only_age = self.Person.objects.order_by('-age').only('age')
@@ -1537,6 +1597,21 @@ class QuerySetTest(unittest.TestCase):
         self.assertEqual(names, [None, None, None])
         self.assertEqual(ages, [40, 30, 20])
 
+        qs = self.Person.objects.all().order_by('-age')
+        qs = qs.limit(10)
+        ages = [p.age for p in qs]
+        self.assertEqual(ages, [40, 30, 20])
+
+        qs = self.Person.objects.all().limit(10)
+        qs = qs.order_by('-age')
+        ages = [p.age for p in qs]
+        self.assertEqual(ages, [40, 30, 20])
+
+        qs = self.Person.objects.all().skip(0)
+        qs = qs.order_by('-age')
+        ages = [p.age for p in qs]
+        self.assertEqual(ages, [40, 30, 20])
+
     def test_confirm_order_by_reference_wont_work(self):
         """Ordering by reference is not possible. Use map / reduce.. or
         denormalise"""
@@ -2065,6 +2140,25 @@ class QuerySetTest(unittest.TestCase):
         self.assertEqual(Foo.objects.distinct("bar"), [bar])
 
+    def test_distinct_handles_db_field(self):
+        """Ensure that distinct resolves field name to db_field as expected.
+        """
+        class Product(Document):
+            product_id = IntField(db_field='pid')
+
+        Product.drop_collection()
+
+        Product(product_id=1).save()
+        Product(product_id=2).save()
+        Product(product_id=1).save()
+
+        self.assertEqual(set(Product.objects.distinct('product_id')),
+                         set([1, 2]))
+        self.assertEqual(set(Product.objects.distinct('pid')),
+                         set([1, 2]))
+
+        Product.drop_collection()
+
     def test_custom_manager(self):
         """Ensure that custom QuerySetManager instances work as expected.
         """

View File

@@ -7,7 +7,7 @@ from bson import DBRef, ObjectId
 
 from mongoengine import *
 from mongoengine.connection import get_db
-from mongoengine.context_managers import query_counter, no_dereference
+from mongoengine.context_managers import query_counter
 
 
 class FieldTest(unittest.TestCase):
@@ -212,8 +212,9 @@ class FieldTest(unittest.TestCase):
 
         # Migrate the data
         for g in Group.objects():
-            g.author = g.author
-            g.members = g.members
+            # Explicitly mark as changed so resets
+            g._mark_as_changed('author')
+            g._mark_as_changed('members')
             g.save()
 
         group = Group.objects.first()
@@ -1120,6 +1121,37 @@ class FieldTest(unittest.TestCase):
 
         self.assertEqual(q, 2)
 
+    def test_tuples_as_tuples(self):
+        """
+        Ensure that tuples remain tuples when they are
+        inside a ComplexBaseField
+        """
+        from mongoengine.base import BaseField
+
+        class EnumField(BaseField):
+
+            def __init__(self, **kwargs):
+                super(EnumField, self).__init__(**kwargs)
+
+            def to_mongo(self, value):
+                return value
+
+            def to_python(self, value):
+                return tuple(value)
+
+        class TestDoc(Document):
+            items = ListField(EnumField())
+
+        TestDoc.drop_collection()
+        tuples = [(100, 'Testing')]
+        doc = TestDoc()
+        doc.items = tuples
+        doc.save()
+        x = TestDoc.objects().get()
+        self.assertTrue(x is not None)
+        self.assertTrue(len(x.items) == 1)
+        self.assertTrue(tuple(x.items[0]) in tuples)
+        self.assertTrue(x.items[0] in tuples)
+
 
 if __name__ == '__main__':
     unittest.main()