Merge branch 'master' into 0.8M

Conflicts:
	AUTHORS
	docs/django.rst
	mongoengine/base.py
	mongoengine/queryset.py
	tests/fields/fields.py
	tests/queryset/queryset.py
	tests/test_dereference.py
	tests/test_document.py
This commit is contained in:
Ross Lawley 2013-04-17 11:57:53 +00:00
commit 51e50bf0a9
17 changed files with 474 additions and 101 deletions

11
AUTHORS
View File

@ -129,6 +129,16 @@ that much better:
* Peter Teichman * Peter Teichman
* Jakub Kot * Jakub Kot
* Jorge Bastida * Jorge Bastida
* Aleksandr Sorokoumov
* Yohan Graterol
* bool-dev
* Russ Weeks
* Paul Swartz
* Sundar Raman
* Benoit Louy
* lraucy
* hellysmile
* Jaepil Jeong
* Stefan Wójcik * Stefan Wójcik
* Pete Campton * Pete Campton
* Martyn Smith * Martyn Smith
@ -145,4 +155,3 @@ that much better:
* Jared Forsyth * Jared Forsyth
* Kenneth Falck * Kenneth Falck
* Lukasz Balcerzak * Lukasz Balcerzak
* Aleksandr Sorokoumov

View File

@ -51,6 +51,19 @@ Changes in 0.8.X
Changes in 0.7.10 Changes in 0.7.10
================= =================
- Allow construction using positional parameters (#268)
- Updated EmailField length to support long domains (#243)
- Added 64-bit integer support (#251)
- Added Django sessions TTL support (#224)
- Fixed issue with numerical keys in MapField(EmbeddedDocumentField()) (#240)
- Fixed clearing _changed_fields for complex nested embedded documents (#237, #239, #242)
- Added "_id" to _data dictionary (#255)
- Only mark a field as changed if the value has changed (#258)
- Explicitly check for Document instances when dereferencing (#261)
- Fixed order_by chaining issue (#265)
- Added dereference support for tuples (#250)
- Resolve field name to db field name when using distinct(#260, #264, #269)
- Added kwargs to doc.save to help interop with django (#223, #270)
- Fixed cloning querysets in PY3 - Fixed cloning querysets in PY3
- Int fields no longer unset in save when changed to 0 (#272) - Int fields no longer unset in save when changed to 0 (#272)
- Fixed ReferenceField query chaining bug fixed (#254) - Fixed ReferenceField query chaining bug fixed (#254)

View File

@ -10,9 +10,15 @@ In your **settings.py** file, ignore the standard database settings (unless you
also plan to use the ORM in your project), and instead call also plan to use the ORM in your project), and instead call
:func:`~mongoengine.connect` somewhere in the settings module. :func:`~mongoengine.connect` somewhere in the settings module.
.. note:: If getting an ``ImproperlyConfigured: settings.DATABASES is .. note ::
improperly configured`` error you may need to remove If you are not using another Database backend you may need to add a dummy
``django.contrib.sites`` from ``INSTALLED_APPS`` in settings.py. database backend to ``settings.py`` eg::
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.dummy'
}
}
Authentication Authentication
============== ==============
@ -49,6 +55,9 @@ into you settings module::
SESSION_ENGINE = 'mongoengine.django.sessions' SESSION_ENGINE = 'mongoengine.django.sessions'
Django provides session cookie, which expires after ```SESSION_COOKIE_AGE``` seconds, but doesnt delete cookie at sessions backend, so ``'mongoengine.django.sessions'`` supports `mongodb TTL
<http://docs.mongodb.org/manual/tutorial/expire-data/>`_.
.. versionadded:: 0.2.1 .. versionadded:: 0.2.1
Storage Storage

View File

@ -30,13 +30,22 @@ class BaseDocument(object):
_dynamic_lock = True _dynamic_lock = True
_initialised = False _initialised = False
def __init__(self, __auto_convert=True, **values): def __init__(self, __auto_convert=True, *args, **values):
""" """
Initialise a document or embedded document Initialise a document or embedded document
:param __auto_convert: Try and will cast python objects to Object types :param __auto_convert: Try and will cast python objects to Object types
:param values: A dictionary of values for the document :param values: A dictionary of values for the document
""" """
if args:
# Combine positional arguments with named arguments.
# We only want named arguments.
field = iter(self._fields_ordered)
for value in args:
name = next(field)
if name in values:
raise TypeError("Multiple values for keyword argument '" + name + "'")
values[name] = value
signals.pre_init.send(self.__class__, document=self, values=values) signals.pre_init.send(self.__class__, document=self, values=values)
@ -143,7 +152,10 @@ class BaseDocument(object):
self.__set_field_display() self.__set_field_display()
def __iter__(self): def __iter__(self):
return iter(self._fields) if 'id' in self._fields and 'id' not in self._fields_ordered:
return iter(('id', ) + self._fields_ordered)
return iter(self._fields_ordered)
def __getitem__(self, name): def __getitem__(self, name):
"""Dictionary-style field access, return a field's value if present. """Dictionary-style field access, return a field's value if present.
@ -347,6 +359,22 @@ class BaseDocument(object):
key not in self._changed_fields): key not in self._changed_fields):
self._changed_fields.append(key) self._changed_fields.append(key)
def _clear_changed_fields(self):
self._changed_fields = []
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
for field_name, field in self._fields.iteritems():
if (isinstance(field, ComplexBaseField) and
isinstance(field.field, EmbeddedDocumentField)):
field_value = getattr(self, field_name, None)
if field_value:
for idx in (field_value if isinstance(field_value, dict)
else xrange(len(field_value))):
field_value[idx]._clear_changed_fields()
elif isinstance(field, EmbeddedDocumentField):
field_value = getattr(self, field_name, None)
if field_value:
field_value._clear_changed_fields()
def _get_changed_fields(self, key='', inspected=None): def _get_changed_fields(self, key='', inspected=None):
"""Returns a list of all fields that have explicitly been changed. """Returns a list of all fields that have explicitly been changed.
""" """
@ -418,7 +446,7 @@ class BaseDocument(object):
for p in parts: for p in parts:
if isinstance(d, DBRef): if isinstance(d, DBRef):
break break
elif p.isdigit(): elif isinstance(d, list) and p.isdigit():
d = d[int(p)] d = d[int(p)]
elif hasattr(d, 'get'): elif hasattr(d, 'get'):
d = d.get(p) d = d.get(p)
@ -449,7 +477,7 @@ class BaseDocument(object):
parts = path.split('.') parts = path.split('.')
db_field_name = parts.pop() db_field_name = parts.pop()
for p in parts: for p in parts:
if p.isdigit(): if isinstance(d, list) and p.isdigit():
d = d[int(p)] d = d[int(p)]
elif (hasattr(d, '__getattribute__') and elif (hasattr(d, '__getattribute__') and
not isinstance(d, dict)): not isinstance(d, dict)):

View File

@ -81,8 +81,12 @@ class BaseField(object):
def __set__(self, instance, value): def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document. """Descriptor for assigning a value to a field in a document.
""" """
changed = False
if (self.name not in instance._data or
instance._data[self.name] != value):
changed = True
instance._data[self.name] = value instance._data[self.name] = value
if instance._initialised: if changed and instance._initialised:
instance._mark_as_changed(self.name) instance._mark_as_changed(self.name)
def error(self, message="", errors=None, field_name=None): def error(self, message="", errors=None, field_name=None):

View File

@ -91,6 +91,9 @@ class DocumentMetaclass(type):
attrs['_fields'] = doc_fields attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k))
for k, v in doc_fields.iteritems()]) for k, v in doc_fields.iteritems()])
attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
(v.creation_counter, v.name)
for v in doc_fields.itervalues()))
attrs['_reverse_db_field_map'] = dict( attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems()) (v, k) for k, v in attrs['_db_field_map'].iteritems())
@ -110,7 +113,7 @@ class DocumentMetaclass(type):
# inheritance of classes where inheritance is set to False # inheritance of classes where inheritance is set to False
allow_inheritance = base._meta.get('allow_inheritance', allow_inheritance = base._meta.get('allow_inheritance',
ALLOW_INHERITANCE) ALLOW_INHERITANCE)
if (allow_inheritance != True and if (allow_inheritance is not True and
not base._meta.get('abstract')): not base._meta.get('abstract')):
raise ValueError('Document %s may not be subclassed' % raise ValueError('Document %s may not be subclassed' %
base.__name__) base.__name__)

View File

@ -33,7 +33,7 @@ class DeReference(object):
self.max_depth = max_depth self.max_depth = max_depth
doc_type = None doc_type = None
if instance and instance._fields: if instance and isinstance(instance, (Document, TopLevelDocumentMetaclass)):
doc_type = instance._fields.get(name) doc_type = instance._fields.get(name)
if hasattr(doc_type, 'field'): if hasattr(doc_type, 'field'):
doc_type = doc_type.field doc_type = doc_type.field
@ -84,7 +84,7 @@ class DeReference(object):
# Recursively find dbreferences # Recursively find dbreferences
depth += 1 depth += 1
for k, item in iterator: for k, item in iterator:
if hasattr(item, '_fields'): if isinstance(item, Document):
for field_name, field in item._fields.iteritems(): for field_name, field in item._fields.iteritems():
v = item._data.get(field_name, None) v = item._data.get(field_name, None)
if isinstance(v, (DBRef)): if isinstance(v, (DBRef)):
@ -174,6 +174,7 @@ class DeReference(object):
if not hasattr(items, 'items'): if not hasattr(items, 'items'):
is_list = True is_list = True
as_tuple = isinstance(items, tuple)
iterator = enumerate(items) iterator = enumerate(items)
data = [] data = []
else: else:
@ -190,7 +191,7 @@ class DeReference(object):
if k in self.object_map and not is_list: if k in self.object_map and not is_list:
data[k] = self.object_map[k] data[k] = self.object_map[k]
elif hasattr(v, '_fields'): elif isinstance(v, Document):
for field_name, field in v._fields.iteritems(): for field_name, field in v._fields.iteritems():
v = data[k]._data.get(field_name, None) v = data[k]._data.get(field_name, None)
if isinstance(v, (DBRef)): if isinstance(v, (DBRef)):
@ -208,7 +209,7 @@ class DeReference(object):
if instance and name: if instance and name:
if is_list: if is_list:
return BaseList(data, instance, name) return tuple(data) if as_tuple else BaseList(data, instance, name)
return BaseDict(data, instance, name) return BaseDict(data, instance, name)
depth += 1 depth += 1
return data return data

View File

@ -32,9 +32,17 @@ class MongoSession(Document):
else fields.DictField() else fields.DictField()
expire_date = fields.DateTimeField() expire_date = fields.DateTimeField()
meta = {'collection': MONGOENGINE_SESSION_COLLECTION, meta = {
'collection': MONGOENGINE_SESSION_COLLECTION,
'db_alias': MONGOENGINE_SESSION_DB_ALIAS, 'db_alias': MONGOENGINE_SESSION_DB_ALIAS,
'allow_inheritance': False} 'allow_inheritance': False,
'indexes': [
{
'fields': ['expire_date'],
'expireAfterSeconds': settings.SESSION_COOKIE_AGE
}
]
}
def get_decoded(self): def get_decoded(self):
return SessionStore().decode(self.session_data) return SessionStore().decode(self.session_data)

View File

@ -160,7 +160,7 @@ class Document(BaseDocument):
def save(self, safe=True, force_insert=False, validate=True, clean=True, def save(self, safe=True, force_insert=False, validate=True, clean=True,
write_options=None, cascade=None, cascade_kwargs=None, write_options=None, cascade=None, cascade_kwargs=None,
_refs=None): _refs=None, **kwargs):
"""Save the :class:`~mongoengine.Document` to the database. If the """Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be document already exists, it will be updated, otherwise it will be
created. created.
@ -278,7 +278,7 @@ class Document(BaseDocument):
if id_field not in self._meta.get('shard_key', []): if id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id) self[id_field] = self._fields[id_field].to_python(object_id)
self._changed_fields = [] self._clear_changed_fields()
self._created = False self._created = False
signals.post_save.send(self.__class__, document=self, created=created) signals.post_save.send(self.__class__, document=self, created=created)
return self return self

View File

@ -27,7 +27,7 @@ except ImportError:
Image = None Image = None
ImageOps = None ImageOps = None
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', __all__ = ['StringField', 'IntField', 'LongField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField',
'DecimalField', 'ComplexDateTimeField', 'URLField', 'DynamicField', 'DecimalField', 'ComplexDateTimeField', 'URLField', 'DynamicField',
@ -143,7 +143,7 @@ class EmailField(StringField):
EMAIL_REGEX = re.compile( EMAIL_REGEX = re.compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*" # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,253}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
) )
def validate(self, value): def validate(self, value):
@ -153,7 +153,7 @@ class EmailField(StringField):
class IntField(BaseField): class IntField(BaseField):
"""An integer field. """An 32-bit integer field.
""" """
def __init__(self, min_value=None, max_value=None, **kwargs): def __init__(self, min_value=None, max_value=None, **kwargs):
@ -186,6 +186,40 @@ class IntField(BaseField):
return int(value) return int(value)
class LongField(BaseField):
"""An 64-bit integer field.
"""
def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value
super(LongField, self).__init__(**kwargs)
def to_python(self, value):
try:
value = long(value)
except ValueError:
pass
return value
def validate(self, value):
try:
value = long(value)
except:
self.error('%s could not be converted to long' % value)
if self.min_value is not None and value < self.min_value:
self.error('Long value is too small')
if self.max_value is not None and value > self.max_value:
self.error('Long value is too large')
def prepare_query_value(self, op, value):
if value is None:
return value
return long(value)
class FloatField(BaseField): class FloatField(BaseField):
"""An floating point number field. """An floating point number field.
""" """

View File

@ -609,8 +609,11 @@ class QuerySet(object):
.. versionchanged:: 0.6 - Improved db_field refrence handling .. versionchanged:: 0.6 - Improved db_field refrence handling
""" """
queryset = self.clone() queryset = self.clone()
return queryset._dereference(queryset._cursor.distinct(field), 1, try:
name=field, instance=queryset._document) field = self._fields_to_dbfields([field]).pop()
finally:
return self._dereference(queryset._cursor.distinct(field), 1,
name=field, instance=self._document)
def only(self, *fields): def only(self, *fields):
"""Load only a subset of this document's fields. :: """Load only a subset of this document's fields. ::
@ -696,7 +699,7 @@ class QuerySet(object):
prefixed with **+** or **-** to determine the ordering direction prefixed with **+** or **-** to determine the ordering direction
""" """
queryset = self.clone() queryset = self.clone()
queryset._ordering = self._get_order_by(keys) queryset._ordering = queryset._get_order_by(keys)
return queryset return queryset
def explain(self, format=False): def explain(self, format=False):

View File

@ -32,7 +32,7 @@ class SimplificationVisitor(QNodeVisitor):
if combination.operation == combination.AND: if combination.operation == combination.AND:
# The simplification only applies to 'simple' queries # The simplification only applies to 'simple' queries
if all(isinstance(node, Q) for node in combination.children): if all(isinstance(node, Q) for node in combination.children):
queries = [node.query for node in combination.children] queries = [n.query for n in combination.children]
return Q(**self._query_conjunction(queries)) return Q(**self._query_conjunction(queries))
return combination return combination

View File

@ -827,8 +827,8 @@ class InstanceTest(unittest.TestCase):
float_field = FloatField(default=1.1) float_field = FloatField(default=1.1)
boolean_field = BooleanField(default=True) boolean_field = BooleanField(default=True)
datetime_field = DateTimeField(default=datetime.now) datetime_field = DateTimeField(default=datetime.now)
embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, embedded_document_field = EmbeddedDocumentField(
default=lambda: EmbeddedDoc()) EmbeddedDoc, default=lambda: EmbeddedDoc())
list_field = ListField(default=lambda: [1, 2, 3]) list_field = ListField(default=lambda: [1, 2, 3])
dict_field = DictField(default=lambda: {"hello": "world"}) dict_field = DictField(default=lambda: {"hello": "world"})
objectid_field = ObjectIdField(default=bson.ObjectId) objectid_field = ObjectIdField(default=bson.ObjectId)
@ -1127,20 +1127,20 @@ class InstanceTest(unittest.TestCase):
u3 = User(username="hmarr") u3 = User(username="hmarr")
u3.save() u3.save()
p1 = Page(comments = [Comment(user=u1, comment="Its very good"), p1 = Page(comments=[Comment(user=u1, comment="Its very good"),
Comment(user=u2, comment="Hello world"), Comment(user=u2, comment="Hello world"),
Comment(user=u3, comment="Ping Pong"), Comment(user=u3, comment="Ping Pong"),
Comment(user=u1, comment="I like a beer")]) Comment(user=u1, comment="I like a beer")])
p1.save() p1.save()
p2 = Page(comments = [Comment(user=u1, comment="Its very good"), p2 = Page(comments=[Comment(user=u1, comment="Its very good"),
Comment(user=u2, comment="Hello world")]) Comment(user=u2, comment="Hello world")])
p2.save() p2.save()
p3 = Page(comments = [Comment(user=u3, comment="Its very good")]) p3 = Page(comments=[Comment(user=u3, comment="Its very good")])
p3.save() p3.save()
p4 = Page(comments = [Comment(user=u2, comment="Heavy Metal song")]) p4 = Page(comments=[Comment(user=u2, comment="Heavy Metal song")])
p4.save() p4.save()
self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1)))
@ -1183,7 +1183,6 @@ class InstanceTest(unittest.TestCase):
class Site(Document): class Site(Document):
page = EmbeddedDocumentField(Page) page = EmbeddedDocumentField(Page)
Site.drop_collection() Site.drop_collection()
site = Site(page=Page(log_message="Warning: Dummy message")) site = Site(page=Page(log_message="Warning: Dummy message"))
site.save() site.save()
@ -1328,7 +1327,8 @@ class InstanceTest(unittest.TestCase):
occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) occurs = ListField(EmbeddedDocumentField(Occurrence), default=list)
def raise_invalid_document(): def raise_invalid_document():
Word._from_son({'stem': [1,2,3], 'forms': 1, 'count': 'one', 'occurs': {"hello": None}}) Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one',
'occurs': {"hello": None}})
self.assertRaises(InvalidDocumentError, raise_invalid_document) self.assertRaises(InvalidDocumentError, raise_invalid_document)
@ -1350,7 +1350,7 @@ class InstanceTest(unittest.TestCase):
reviewer = self.Person(name='Re Viewer') reviewer = self.Person(name='Re Viewer')
reviewer.save() reviewer.save()
post = BlogPost(content = 'Watched some TV') post = BlogPost(content='Watched some TV')
post.author = author post.author = author
post.reviewer = reviewer post.reviewer = reviewer
post.save() post.save()
@ -1432,7 +1432,6 @@ class InstanceTest(unittest.TestCase):
author.delete() author.delete()
self.assertEqual(len(BlogPost.objects), 0) self.assertEqual(len(BlogPost.objects), 0)
def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self): def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self):
''' ensure the pre_delete signal is triggered upon a cascading deletion ''' ensure the pre_delete signal is triggered upon a cascading deletion
setup a blog post with content, an author and editor setup a blog post with content, an author and editor
@ -1730,7 +1729,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Doc.objects(archived=False).count(), 1) self.assertEqual(Doc.objects(archived=False).count(), 1)
def test_can_save_false_values_dynamic(self): def test_can_save_false_values_dynamic(self):
"""Ensures you can save False values on dynamic docs""" """Ensures you can save False values on dynamic docs"""
class Doc(DynamicDocument): class Doc(DynamicDocument):
@ -1943,8 +1941,8 @@ class InstanceTest(unittest.TestCase):
function(){ function(){
return this.name == '1' || return this.name == '1' ||
this.name == '2';}""" this.name == '2';}"""
} })]),
)]), "1,2") "1,2")
def test_switch_db_instance(self): def test_switch_db_instance(self):
register_connection('testdb-1', 'mongoenginetest2') register_connection('testdb-1', 'mongoenginetest2')
@ -2020,7 +2018,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual("Bar", user._data["foo"]) self.assertEqual("Bar", user._data["foo"])
self.assertEqual([1, 2, 3], user._data["data"]) self.assertEqual([1, 2, 3], user._data["data"])
def test_spaces_in_keys(self): def test_spaces_in_keys(self):
class Embedded(DynamicEmbeddedDocument): class Embedded(DynamicEmbeddedDocument):
@ -2118,5 +2115,82 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(classic_doc, dict_doc) self.assertEqual(classic_doc, dict_doc)
self.assertEqual(classic_doc._data, dict_doc._data) self.assertEqual(classic_doc._data, dict_doc._data)
def test_positional_creation(self):
"""Ensure that document may be created using positional arguments.
"""
person = self.Person("Test User", 42)
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42)
def test_mixed_creation(self):
"""Ensure that document may be created using mixed arguments.
"""
person = self.Person("Test User", age=42)
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42)
def test_bad_mixed_creation(self):
"""Ensure that document gives correct error when duplicating arguments
"""
def construct_bad_instance():
return self.Person("Test User", 42, name="Bad User")
self.assertRaises(TypeError, construct_bad_instance)
def test_data_contains_id_field(self):
"""Ensure that asking for _data returns 'id'
"""
class Person(Document):
name = StringField()
Person.drop_collection()
Person(name="Harry Potter").save()
person = Person.objects.first()
self.assertTrue('id' in person._data.keys())
self.assertEqual(person._data.get('id'), person.id)
def test_complex_nesting_document_and_embedded_document(self):
class Macro(EmbeddedDocument):
value = DynamicField(default="UNDEFINED")
class Parameter(EmbeddedDocument):
macros = MapField(EmbeddedDocumentField(Macro))
def expand(self):
self.macros["test"] = Macro()
class Node(Document):
parameters = MapField(EmbeddedDocumentField(Parameter))
def expand(self):
self.flattened_parameter = {}
for parameter_name, parameter in self.parameters.iteritems():
parameter.expand()
class System(Document):
name = StringField(required=True)
nodes = MapField(ReferenceField(Node, dbref=False))
def save(self, *args, **kwargs):
for node_name, node in self.nodes.iteritems():
node.expand()
node.save(*args, **kwargs)
super(System, self).save(*args, **kwargs)
System.drop_collection()
Node.drop_collection()
system = System(name="system")
system.nodes["node"] = Node()
system.save()
system.nodes["node"].parameters["param"] = Parameter()
system.save()
system = System.objects.first()
self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -130,8 +130,8 @@ class ValidatorErrorTest(unittest.TestCase):
doc = Doc.objects.first() doc = Doc.objects.first()
keys = doc._data.keys() keys = doc._data.keys()
self.assertEqual(2, len(keys)) self.assertEqual(2, len(keys))
self.assertTrue('id' in keys)
self.assertTrue('e' in keys) self.assertTrue('e' in keys)
self.assertTrue('id' in keys)
doc.e.val = "OK" doc.e.val = "OK"
try: try:

View File

@ -145,6 +145,17 @@ class FieldTest(unittest.TestCase):
self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count()) self.assertEqual(1, TestDocument.objects(int_fld__ne=None).count())
self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count()) self.assertEqual(1, TestDocument.objects(float_fld__ne=None).count())
def test_long_ne_operator(self):
class TestDocument(Document):
long_fld = LongField()
TestDocument.drop_collection()
TestDocument(long_fld=None).save()
TestDocument(long_fld=1).save()
self.assertEqual(1, TestDocument.objects(long_fld__ne=None).count())
def test_object_id_validation(self): def test_object_id_validation(self):
"""Ensure that invalid values cannot be assigned to string fields. """Ensure that invalid values cannot be assigned to string fields.
""" """
@ -218,6 +229,23 @@ class FieldTest(unittest.TestCase):
person.age = 'ten' person.age = 'ten'
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
def test_long_validation(self):
"""Ensure that invalid values cannot be assigned to long fields.
"""
class TestDocument(Document):
value = LongField(min_value=0, max_value=110)
doc = TestDocument()
doc.value = 50
doc.validate()
doc.value = -1
self.assertRaises(ValidationError, doc.validate)
doc.age = 120
self.assertRaises(ValidationError, doc.validate)
doc.age = 'ten'
self.assertRaises(ValidationError, doc.validate)
def test_float_validation(self): def test_float_validation(self):
"""Ensure that invalid values cannot be assigned to float fields. """Ensure that invalid values cannot be assigned to float fields.
""" """
@ -971,6 +999,24 @@ class FieldTest(unittest.TestCase):
doc = self.db.test.find_one() doc = self.db.test.find_one()
self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2) self.assertEqual(doc['x']['DICTIONARY_KEY']['i'], 2)
def test_mapfield_numerical_index(self):
"""Ensure that MapField accept numeric strings as indexes."""
class Embedded(EmbeddedDocument):
name = StringField()
class Test(Document):
my_map = MapField(EmbeddedDocumentField(Embedded))
Test.drop_collection()
test = Test()
test.my_map['1'] = Embedded(name='test')
test.save()
test.my_map['1'].name = 'test updated'
test.save()
Test.drop_collection()
def test_map_field_lookup(self): def test_map_field_lookup(self):
"""Ensure MapField lookups succeed on Fields without a lookup method""" """Ensure MapField lookups succeed on Fields without a lookup method"""
@ -2404,6 +2450,21 @@ class FieldTest(unittest.TestCase):
post.comments[1].content = 'here we go' post.comments[1].content = 'here we go'
post.validate() post.validate()
def test_email_field(self):
class User(Document):
email = EmailField()
user = User(email="ross@example.com")
self.assertTrue(user.validate() is None)
user = User(email=("Kofq@rhom0e4klgauOhpbpNdogawnyIKvQS0wk2mjqrgGQ5S"
"ucictfqpdkK9iS1zeFw8sg7s7cwAF7suIfUfeyueLpfosjn3"
"aJIazqqWkm7.net"))
self.assertTrue(user.validate() is None)
user = User(email='me@localhost')
self.assertRaises(ValidationError, user.validate)
def test_email_field_honors_regex(self): def test_email_field_honors_regex(self):
class User(Document): class User(Document):
email = EmailField(regex=r'\w+@example.com') email = EmailField(regex=r'\w+@example.com')

View File

@ -931,6 +931,11 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
Blog.drop_collection() Blog.drop_collection()
def assertSequence(self, qs, expected):
self.assertEqual(len(qs), len(expected))
for i in range(len(qs)):
self.assertEqual(qs[i], expected[i])
def test_ordering(self): def test_ordering(self):
"""Ensure default ordering is applied and can be overridden. """Ensure default ordering is applied and can be overridden.
""" """
@ -957,14 +962,13 @@ class QuerySetTest(unittest.TestCase):
# get the "first" BlogPost using default ordering # get the "first" BlogPost using default ordering
# from BlogPost.meta.ordering # from BlogPost.meta.ordering
latest_post = BlogPost.objects.first() expected = [blog_post_3, blog_post_2, blog_post_1]
self.assertEqual(latest_post.title, "Blog Post #3") self.assertSequence(BlogPost.objects.all(), expected)
# override default ordering, order BlogPosts by "published_date" # override default ordering, order BlogPosts by "published_date"
first_post = BlogPost.objects.order_by("+published_date").first() qs = BlogPost.objects.order_by("+published_date")
self.assertEqual(first_post.title, "Blog Post #1") expected = [blog_post_1, blog_post_2, blog_post_3]
self.assertSequence(qs, expected)
BlogPost.drop_collection()
def test_find_embedded(self): def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query. """Ensure that an embedded document is properly returned from a query.
@ -1505,8 +1509,8 @@ class QuerySetTest(unittest.TestCase):
def test_order_by(self): def test_order_by(self):
"""Ensure that QuerySets may be ordered. """Ensure that QuerySets may be ordered.
""" """
self.Person(name="User A", age=20).save()
self.Person(name="User B", age=40).save() self.Person(name="User B", age=40).save()
self.Person(name="User A", age=20).save()
self.Person(name="User C", age=30).save() self.Person(name="User C", age=30).save()
names = [p.name for p in self.Person.objects.order_by('-age')] names = [p.name for p in self.Person.objects.order_by('-age')]
@ -1521,11 +1525,67 @@ class QuerySetTest(unittest.TestCase):
ages = [p.age for p in self.Person.objects.order_by('-name')] ages = [p.age for p in self.Person.objects.order_by('-name')]
self.assertEqual(ages, [30, 40, 20]) self.assertEqual(ages, [30, 40, 20])
def test_order_by_optional(self):
class BlogPost(Document):
title = StringField()
published_date = DateTimeField(required=False)
BlogPost.drop_collection()
blog_post_3 = BlogPost(title="Blog Post #3",
published_date=datetime(2010, 1, 6, 0, 0 ,0))
blog_post_2 = BlogPost(title="Blog Post #2",
published_date=datetime(2010, 1, 5, 0, 0 ,0))
blog_post_4 = BlogPost(title="Blog Post #4",
published_date=datetime(2010, 1, 7, 0, 0 ,0))
blog_post_1 = BlogPost(title="Blog Post #1", published_date=None)
blog_post_3.save()
blog_post_1.save()
blog_post_4.save()
blog_post_2.save()
expected = [blog_post_1, blog_post_2, blog_post_3, blog_post_4]
self.assertSequence(BlogPost.objects.order_by('published_date'),
expected)
self.assertSequence(BlogPost.objects.order_by('+published_date'),
expected)
expected.reverse()
self.assertSequence(BlogPost.objects.order_by('-published_date'),
expected)
def test_order_by_list(self):
class BlogPost(Document):
title = StringField()
published_date = DateTimeField(required=False)
BlogPost.drop_collection()
blog_post_1 = BlogPost(title="A",
published_date=datetime(2010, 1, 6, 0, 0 ,0))
blog_post_2 = BlogPost(title="B",
published_date=datetime(2010, 1, 6, 0, 0 ,0))
blog_post_3 = BlogPost(title="C",
published_date=datetime(2010, 1, 7, 0, 0 ,0))
blog_post_2.save()
blog_post_3.save()
blog_post_1.save()
qs = BlogPost.objects.order_by('published_date', 'title')
expected = [blog_post_1, blog_post_2, blog_post_3]
self.assertSequence(qs, expected)
qs = BlogPost.objects.order_by('-published_date', '-title')
expected.reverse()
self.assertSequence(qs, expected)
def test_order_by_chaining(self): def test_order_by_chaining(self):
"""Ensure that an order_by query chains properly and allows .only() """Ensure that an order_by query chains properly and allows .only()
""" """
self.Person(name="User A", age=20).save()
self.Person(name="User B", age=40).save() self.Person(name="User B", age=40).save()
self.Person(name="User A", age=20).save()
self.Person(name="User C", age=30).save() self.Person(name="User C", age=30).save()
only_age = self.Person.objects.order_by('-age').only('age') only_age = self.Person.objects.order_by('-age').only('age')
@ -1537,6 +1597,21 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(names, [None, None, None]) self.assertEqual(names, [None, None, None])
self.assertEqual(ages, [40, 30, 20]) self.assertEqual(ages, [40, 30, 20])
qs = self.Person.objects.all().order_by('-age')
qs = qs.limit(10)
ages = [p.age for p in qs]
self.assertEqual(ages, [40, 30, 20])
qs = self.Person.objects.all().limit(10)
qs = qs.order_by('-age')
ages = [p.age for p in qs]
self.assertEqual(ages, [40, 30, 20])
qs = self.Person.objects.all().skip(0)
qs = qs.order_by('-age')
ages = [p.age for p in qs]
self.assertEqual(ages, [40, 30, 20])
def test_confirm_order_by_reference_wont_work(self): def test_confirm_order_by_reference_wont_work(self):
"""Ordering by reference is not possible. Use map / reduce.. or """Ordering by reference is not possible. Use map / reduce.. or
denormalise""" denormalise"""
@ -2065,6 +2140,25 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Foo.objects.distinct("bar"), [bar]) self.assertEqual(Foo.objects.distinct("bar"), [bar])
def test_distinct_handles_db_field(self):
"""Ensure that distinct resolves field name to db_field as expected.
"""
class Product(Document):
product_id = IntField(db_field='pid')
Product.drop_collection()
Product(product_id=1).save()
Product(product_id=2).save()
Product(product_id=1).save()
self.assertEqual(set(Product.objects.distinct('product_id')),
set([1, 2]))
self.assertEqual(set(Product.objects.distinct('pid')),
set([1, 2]))
Product.drop_collection()
def test_custom_manager(self): def test_custom_manager(self):
"""Ensure that custom QuerySetManager instances work as expected. """Ensure that custom QuerySetManager instances work as expected.
""" """

View File

@ -7,7 +7,7 @@ from bson import DBRef, ObjectId
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.context_managers import query_counter, no_dereference from mongoengine.context_managers import query_counter
class FieldTest(unittest.TestCase): class FieldTest(unittest.TestCase):
@ -212,8 +212,9 @@ class FieldTest(unittest.TestCase):
# Migrate the data # Migrate the data
for g in Group.objects(): for g in Group.objects():
g.author = g.author # Explicitly mark as changed so resets
g.members = g.members g._mark_as_changed('author')
g._mark_as_changed('members')
g.save() g.save()
group = Group.objects.first() group = Group.objects.first()
@ -1120,6 +1121,37 @@ class FieldTest(unittest.TestCase):
self.assertEqual(q, 2) self.assertEqual(q, 2)
def test_tuples_as_tuples(self):
"""
Ensure that tuples remain tuples when they are
inside a ComplexBaseField
"""
from mongoengine.base import BaseField
class EnumField(BaseField):
def __init__(self, **kwargs):
super(EnumField, self).__init__(**kwargs)
def to_mongo(self, value):
return value
def to_python(self, value):
return tuple(value)
class TestDoc(Document):
items = ListField(EnumField())
TestDoc.drop_collection()
tuples = [(100, 'Testing')]
doc = TestDoc()
doc.items = tuples
doc.save()
x = TestDoc.objects().get()
self.assertTrue(x is not None)
self.assertTrue(len(x.items) == 1)
self.assertTrue(tuple(x.items[0]) in tuples)
self.assertTrue(x.items[0] in tuples)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()