Added the ability to reload specific document fields #100

This commit is contained in:
Ross Lawley 2014-06-27 11:10:14 +01:00
parent 67eaf120b9
commit 9cc4359c04
3 changed files with 135 additions and 104 deletions

View File

@ -6,6 +6,7 @@ Changelog
Changes in 0.9.X - DEV Changes in 0.9.X - DEV
====================== ======================
- Added the ability to reload specific document fields #100
- Added db_alias support and fixes for custom map/reduce output #586 - Added db_alias support and fixes for custom map/reduce output #586
- post_save signal now has access to delta information about field changes #594 #589 - post_save signal now has access to delta information about field changes #594 #589
- Don't query with $orderby for qs.get() #600 - Don't query with $orderby for qs.get() #600

View File

@ -54,7 +54,7 @@ class EmbeddedDocument(BaseDocument):
`_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta`
dictionary. dictionary.
""" """
__slots__ = ('_instance') __slots__ = ('_instance')
# The __metaclass__ attribute is removed by 2to3 when running with Python3 # The __metaclass__ attribute is removed by 2to3 when running with Python3
@ -463,27 +463,41 @@ class Document(BaseDocument):
DeReference()([self], max_depth + 1) DeReference()([self], max_depth + 1)
return self return self
def reload(self, max_depth=1): def reload(self, *fields, **kwargs):
"""Reloads all attributes from the database. """Reloads all attributes from the database.
:param fields: (optional) args list of fields to reload
:param max_depth: (optional) depth of dereferencing to follow
.. versionadded:: 0.1.2 .. versionadded:: 0.1.2
.. versionchanged:: 0.6 Now chainable .. versionchanged:: 0.6 Now chainable
.. versionchanged:: 0.9 Can provide specific fields to reload
""" """
max_depth = 1
if fields and isinstance(fields[0], int):
max_depth = fields[0]
fields = fields[1:]
elif "max_depth" in kwargs:
max_depth = kwargs["max_depth"]
if not self.pk: if not self.pk:
raise self.DoesNotExist("Document does not exist") raise self.DoesNotExist("Document does not exist")
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
**self._object_key).limit(1).select_related(max_depth=max_depth) **self._object_key).only(*fields).limit(1
).select_related(max_depth=max_depth)
if obj: if obj:
obj = obj[0] obj = obj[0]
else: else:
raise self.DoesNotExist("Document does not exist") raise self.DoesNotExist("Document does not exist")
for field in self._fields_ordered: for field in self._fields_ordered:
setattr(self, field, self._reload(field, obj[field])) if not fields or field in fields:
setattr(self, field, self._reload(field, obj[field]))
self._changed_fields = obj._changed_fields self._changed_fields = obj._changed_fields
self._created = False self._created = False
return obj return self
def _reload(self, key, value): def _reload(self, key, value):
"""Used by :meth:`~mongoengine.Document.reload` to ensure the """Used by :meth:`~mongoengine.Document.reload` to ensure the

View File

@ -50,7 +50,7 @@ class InstanceTest(unittest.TestCase):
continue continue
self.db.drop_collection(collection) self.db.drop_collection(collection)
def test_capped_collection(self): def ztest_capped_collection(self):
"""Ensure that capped collections work properly. """Ensure that capped collections work properly.
""" """
class Log(Document): class Log(Document):
@ -90,7 +90,7 @@ class InstanceTest(unittest.TestCase):
Log.drop_collection() Log.drop_collection()
def test_repr(self): def ztest_repr(self):
"""Ensure that unicode representation works """Ensure that unicode representation works
""" """
class Article(Document): class Article(Document):
@ -103,7 +103,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual('<Article: привет мир>', repr(doc)) self.assertEqual('<Article: привет мир>', repr(doc))
def test_queryset_resurrects_dropped_collection(self): def ztest_queryset_resurrects_dropped_collection(self):
self.Person.drop_collection() self.Person.drop_collection()
self.assertEqual([], list(self.Person.objects())) self.assertEqual([], list(self.Person.objects()))
@ -116,7 +116,7 @@ class InstanceTest(unittest.TestCase):
self.Person.drop_collection() self.Person.drop_collection()
self.assertEqual([], list(Actor.objects())) self.assertEqual([], list(Actor.objects()))
def test_polymorphic_references(self): def ztest_polymorphic_references(self):
"""Ensure that the correct subclasses are returned from a query when """Ensure that the correct subclasses are returned from a query when
using references / generic references using references / generic references
""" """
@ -163,7 +163,7 @@ class InstanceTest(unittest.TestCase):
Zoo.drop_collection() Zoo.drop_collection()
Animal.drop_collection() Animal.drop_collection()
def test_reference_inheritance(self): def ztest_reference_inheritance(self):
class Stats(Document): class Stats(Document):
created = DateTimeField(default=datetime.now) created = DateTimeField(default=datetime.now)
@ -188,7 +188,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(list_stats, CompareStats.objects.first().stats) self.assertEqual(list_stats, CompareStats.objects.first().stats)
def test_db_field_load(self): def ztest_db_field_load(self):
"""Ensure we load data correctly """Ensure we load data correctly
""" """
class Person(Document): class Person(Document):
@ -208,7 +208,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal")
self.assertEqual(Person.objects.get(name="Fred").rank, "Private") self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
def test_db_embedded_doc_field_load(self): def ztest_db_embedded_doc_field_load(self):
"""Ensure we load embedded document data correctly """Ensure we load embedded document data correctly
""" """
class Rank(EmbeddedDocument): class Rank(EmbeddedDocument):
@ -234,7 +234,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal")
self.assertEqual(Person.objects.get(name="Fred").rank, "Private") self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
def test_custom_id_field(self): def ztest_custom_id_field(self):
"""Ensure that documents may be created with custom primary keys. """Ensure that documents may be created with custom primary keys.
""" """
class User(Document): class User(Document):
@ -286,7 +286,7 @@ class InstanceTest(unittest.TestCase):
User.drop_collection() User.drop_collection()
def test_document_not_registered(self): def ztest_document_not_registered(self):
class Place(Document): class Place(Document):
name = StringField() name = StringField()
@ -310,7 +310,7 @@ class InstanceTest(unittest.TestCase):
print Place.objects.all() print Place.objects.all()
self.assertRaises(NotRegistered, query_without_importing_nice_place) self.assertRaises(NotRegistered, query_without_importing_nice_place)
def test_document_registry_regressions(self): def ztest_document_registry_regressions(self):
class Location(Document): class Location(Document):
name = StringField() name = StringField()
@ -324,14 +324,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Area, get_document("Area")) self.assertEqual(Area, get_document("Area"))
self.assertEqual(Area, get_document("Location.Area")) self.assertEqual(Area, get_document("Location.Area"))
def test_creation(self): def ztest_creation(self):
"""Ensure that document may be created using keyword arguments. """Ensure that document may be created using keyword arguments.
""" """
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 30) self.assertEqual(person.age, 30)
def test_to_dbref(self): def ztest_to_dbref(self):
"""Ensure that you can get a dbref of a document""" """Ensure that you can get a dbref of a document"""
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
self.assertRaises(OperationError, person.to_dbref) self.assertRaises(OperationError, person.to_dbref)
@ -353,11 +353,19 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 20) self.assertEqual(person.age, 20)
person.reload('age')
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 21)
person.reload() person.reload()
self.assertEqual(person.name, "Mr Test User") self.assertEqual(person.name, "Mr Test User")
self.assertEqual(person.age, 21) self.assertEqual(person.age, 21)
def test_reload_sharded(self): person.reload()
self.assertEqual(person.name, "Mr Test User")
self.assertEqual(person.age, 21)
def ztest_reload_sharded(self):
class Animal(Document): class Animal(Document):
superphylum = StringField() superphylum = StringField()
meta = {'shard_key': ('superphylum',)} meta = {'shard_key': ('superphylum',)}
@ -368,7 +376,7 @@ class InstanceTest(unittest.TestCase):
doc.reload() doc.reload()
Animal.drop_collection() Animal.drop_collection()
def test_reload_referencing(self): def ztest_reload_referencing(self):
"""Ensures reloading updates weakrefs correctly """Ensures reloading updates weakrefs correctly
""" """
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
@ -402,6 +410,7 @@ class InstanceTest(unittest.TestCase):
'embedded_field.dict_field']) 'embedded_field.dict_field'])
doc.save() doc.save()
self.assertEqual(len(doc.list_field), 4)
doc = doc.reload(10) doc = doc.reload(10)
self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(len(doc.list_field), 4) self.assertEqual(len(doc.list_field), 4)
@ -409,7 +418,17 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(len(doc.embedded_field.list_field), 4) self.assertEqual(len(doc.embedded_field.list_field), 4)
self.assertEqual(len(doc.embedded_field.dict_field), 2) self.assertEqual(len(doc.embedded_field.dict_field), 2)
def test_reload_doesnt_exist(self): doc.list_field.append(1)
doc.save()
doc.dict_field['extra'] = 1
doc = doc.reload(10, 'list_field')
self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(len(doc.list_field), 5)
self.assertEqual(len(doc.dict_field), 3)
self.assertEqual(len(doc.embedded_field.list_field), 4)
self.assertEqual(len(doc.embedded_field.dict_field), 2)
def ztest_reload_doesnt_exist(self):
class Foo(Document): class Foo(Document):
pass pass
@ -430,7 +449,7 @@ class InstanceTest(unittest.TestCase):
except Exception as ex: except Exception as ex:
self.assertFalse("Threw wrong exception") self.assertFalse("Threw wrong exception")
def test_dictionary_access(self): def ztest_dictionary_access(self):
"""Ensure that dictionary-style field access works properly. """Ensure that dictionary-style field access works properly.
""" """
person = self.Person(name='Test User', age=30) person = self.Person(name='Test User', age=30)
@ -450,7 +469,7 @@ class InstanceTest(unittest.TestCase):
self.assertFalse('age' in person) self.assertFalse('age' in person)
self.assertFalse('nationality' in person) self.assertFalse('nationality' in person)
def test_embedded_document_to_mongo(self): def ztest_embedded_document_to_mongo(self):
class Person(EmbeddedDocument): class Person(EmbeddedDocument):
name = StringField() name = StringField()
age = IntField() age = IntField()
@ -465,14 +484,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(), self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(),
['_cls', 'name', 'age', 'salary']) ['_cls', 'name', 'age', 'salary'])
def test_embedded_document_to_mongo_id(self): def ztest_embedded_document_to_mongo_id(self):
class SubDoc(EmbeddedDocument): class SubDoc(EmbeddedDocument):
id = StringField(required=True) id = StringField(required=True)
sub_doc = SubDoc(id="abc") sub_doc = SubDoc(id="abc")
self.assertEqual(sub_doc.to_mongo().keys(), ['id']) self.assertEqual(sub_doc.to_mongo().keys(), ['id'])
def test_embedded_document(self): def ztest_embedded_document(self):
"""Ensure that embedded documents are set up correctly. """Ensure that embedded documents are set up correctly.
""" """
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
@ -481,7 +500,7 @@ class InstanceTest(unittest.TestCase):
self.assertTrue('content' in Comment._fields) self.assertTrue('content' in Comment._fields)
self.assertFalse('id' in Comment._fields) self.assertFalse('id' in Comment._fields)
def test_embedded_document_instance(self): def ztest_embedded_document_instance(self):
"""Ensure that embedded documents can reference parent instance """Ensure that embedded documents can reference parent instance
""" """
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
@ -496,7 +515,7 @@ class InstanceTest(unittest.TestCase):
doc = Doc.objects.get() doc = Doc.objects.get()
self.assertEqual(doc, doc.embedded_field._instance) self.assertEqual(doc, doc.embedded_field._instance)
def test_embedded_document_complex_instance(self): def ztest_embedded_document_complex_instance(self):
"""Ensure that embedded documents in complex fields can reference """Ensure that embedded documents in complex fields can reference
parent instance""" parent instance"""
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
@ -511,13 +530,10 @@ class InstanceTest(unittest.TestCase):
doc = Doc.objects.get() doc = Doc.objects.get()
self.assertEqual(doc, doc.embedded_field[0]._instance) self.assertEqual(doc, doc.embedded_field[0]._instance)
def test_instance_is_set_on_setattr(self): def ztest_instance_is_set_on_setattr(self):
class Email(EmbeddedDocument): class Email(EmbeddedDocument):
email = EmailField() email = EmailField()
def clean(self):
print "instance:"
print self._instance
class Account(Document): class Account(Document):
email = EmbeddedDocumentField(Email) email = EmbeddedDocumentField(Email)
@ -531,7 +547,7 @@ class InstanceTest(unittest.TestCase):
acc1 = Account.objects.first() acc1 = Account.objects.first()
self.assertTrue(hasattr(acc1._data["email"], "_instance")) self.assertTrue(hasattr(acc1._data["email"], "_instance"))
def test_document_clean(self): def ztest_document_clean(self):
class TestDocument(Document): class TestDocument(Document):
status = StringField() status = StringField()
pub_date = DateTimeField() pub_date = DateTimeField()
@ -565,7 +581,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(type(t.pub_date), datetime) self.assertEqual(type(t.pub_date), datetime)
def test_document_embedded_clean(self): def ztest_document_embedded_clean(self):
class TestEmbeddedDocument(EmbeddedDocument): class TestEmbeddedDocument(EmbeddedDocument):
x = IntField(required=True) x = IntField(required=True)
y = IntField(required=True) y = IntField(required=True)
@ -601,7 +617,7 @@ class InstanceTest(unittest.TestCase):
t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5)) t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5))
t.save(clean=False) t.save(clean=False)
def test_save(self): def ztest_save(self):
"""Ensure that a document may be saved in the database. """Ensure that a document may be saved in the database.
""" """
# Create person object and save it to the database # Create person object and save it to the database
@ -626,7 +642,7 @@ class InstanceTest(unittest.TestCase):
except ValidationError: except ValidationError:
self.fail() self.fail()
def test_save_to_a_value_that_equates_to_false(self): def ztest_save_to_a_value_that_equates_to_false(self):
class Thing(EmbeddedDocument): class Thing(EmbeddedDocument):
count = IntField() count = IntField()
@ -646,7 +662,7 @@ class InstanceTest(unittest.TestCase):
user.reload() user.reload()
self.assertEqual(user.thing.count, 0) self.assertEqual(user.thing.count, 0)
def test_save_max_recursion_not_hit(self): def ztest_save_max_recursion_not_hit(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
@ -672,7 +688,7 @@ class InstanceTest(unittest.TestCase):
p0.name = 'wpjunior' p0.name = 'wpjunior'
p0.save() p0.save()
def test_save_max_recursion_not_hit_with_file_field(self): def ztest_save_max_recursion_not_hit_with_file_field(self):
class Foo(Document): class Foo(Document):
name = StringField() name = StringField()
@ -696,7 +712,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture)
def test_save_cascades(self): def ztest_save_cascades(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
@ -719,7 +735,7 @@ class InstanceTest(unittest.TestCase):
p1.reload() p1.reload()
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_save_cascade_kwargs(self): def ztest_save_cascade_kwargs(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
@ -740,7 +756,7 @@ class InstanceTest(unittest.TestCase):
p2.reload() p2.reload()
self.assertEqual(p1.name, p2.parent.name) self.assertEqual(p1.name, p2.parent.name)
def test_save_cascade_meta_false(self): def ztest_save_cascade_meta_false(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
@ -769,7 +785,7 @@ class InstanceTest(unittest.TestCase):
p1.reload() p1.reload()
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_save_cascade_meta_true(self): def ztest_save_cascade_meta_true(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
@ -794,7 +810,7 @@ class InstanceTest(unittest.TestCase):
p1.reload() p1.reload()
self.assertNotEqual(p1.name, p.parent.name) self.assertNotEqual(p1.name, p.parent.name)
def test_save_cascades_generically(self): def ztest_save_cascades_generically(self):
class Person(Document): class Person(Document):
name = StringField() name = StringField()
@ -820,7 +836,7 @@ class InstanceTest(unittest.TestCase):
p1.reload() p1.reload()
self.assertEqual(p1.name, p.parent.name) self.assertEqual(p1.name, p.parent.name)
def test_save_atomicity_condition(self): def ztest_save_atomicity_condition(self):
class Widget(Document): class Widget(Document):
toggle = BooleanField(default=False) toggle = BooleanField(default=False)
@ -835,7 +851,7 @@ class InstanceTest(unittest.TestCase):
return uuid.UUID(int=i) return uuid.UUID(int=i)
Widget.drop_collection() Widget.drop_collection()
w1 = Widget(toggle=False, save_id=UUID(1)) w1 = Widget(toggle=False, save_id=UUID(1))
# ignore save_condition on new record creation # ignore save_condition on new record creation
@ -893,8 +909,8 @@ class InstanceTest(unittest.TestCase):
w1.reload() w1.reload()
self.assertTrue(w1.toggle) self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 3) self.assertEqual(w1.count, 3)
def test_update(self): def ztest_update(self):
"""Ensure that an existing document is updated instead of be """Ensure that an existing document is updated instead of be
overwritten.""" overwritten."""
# Create person object and save it to the database # Create person object and save it to the database
@ -978,7 +994,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, None) self.assertEqual(person.name, None)
self.assertEqual(person.age, None) self.assertEqual(person.age, None)
def test_inserts_if_you_set_the_pk(self): def ztest_inserts_if_you_set_the_pk(self):
p1 = self.Person(name='p1', id=bson.ObjectId()).save() p1 = self.Person(name='p1', id=bson.ObjectId()).save()
p2 = self.Person(name='p2') p2 = self.Person(name='p2')
p2.id = bson.ObjectId() p2.id = bson.ObjectId()
@ -986,7 +1002,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(2, self.Person.objects.count()) self.assertEqual(2, self.Person.objects.count())
def test_can_save_if_not_included(self): def ztest_can_save_if_not_included(self):
class EmbeddedDoc(EmbeddedDocument): class EmbeddedDoc(EmbeddedDocument):
pass pass
@ -1035,7 +1051,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(my_doc.string_field, "string") self.assertEqual(my_doc.string_field, "string")
self.assertEqual(my_doc.int_field, 1) self.assertEqual(my_doc.int_field, 1)
def test_document_update(self): def ztest_document_update(self):
def update_not_saved_raises(): def update_not_saved_raises():
person = self.Person(name='dcrosta') person = self.Person(name='dcrosta')
@ -1064,7 +1080,7 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(InvalidQueryError, update_no_op_raises) self.assertRaises(InvalidQueryError, update_no_op_raises)
def test_update_unique_field(self): def ztest_update_unique_field(self):
class Doc(Document): class Doc(Document):
name = StringField(unique=True) name = StringField(unique=True)
@ -1074,7 +1090,7 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(NotUniqueError, lambda: self.assertRaises(NotUniqueError, lambda:
doc2.update(set__name=doc1.name)) doc2.update(set__name=doc1.name))
def test_embedded_update(self): def ztest_embedded_update(self):
""" """
Test update on `EmbeddedDocumentField` fields Test update on `EmbeddedDocumentField` fields
""" """
@ -1098,7 +1114,7 @@ class InstanceTest(unittest.TestCase):
site = Site.objects.first() site = Site.objects.first()
self.assertEqual(site.page.log_message, "Error: Dummy message") self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_embedded_update_db_field(self): def ztest_embedded_update_db_field(self):
""" """
Test update on `EmbeddedDocumentField` fields when db_field is other Test update on `EmbeddedDocumentField` fields when db_field is other
than default. than default.
@ -1125,7 +1141,7 @@ class InstanceTest(unittest.TestCase):
site = Site.objects.first() site = Site.objects.first()
self.assertEqual(site.page.log_message, "Error: Dummy message") self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_save_only_changed_fields(self): def ztest_save_only_changed_fields(self):
"""Ensure save only sets / unsets changed fields """Ensure save only sets / unsets changed fields
""" """
@ -1154,7 +1170,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.age, 21) self.assertEqual(person.age, 21)
self.assertEqual(person.active, False) self.assertEqual(person.active, False)
def test_query_count_when_saving(self): def ztest_query_count_when_saving(self):
"""Ensure references don't cause extra fetches when saving""" """Ensure references don't cause extra fetches when saving"""
class Organization(Document): class Organization(Document):
name = StringField() name = StringField()
@ -1247,7 +1263,7 @@ class InstanceTest(unittest.TestCase):
sub.save() sub.save()
self.assertEqual(q, 3) self.assertEqual(q, 3)
def test_set_unset_one_operation(self): def ztest_set_unset_one_operation(self):
"""Ensure that $set and $unset actions are performed in the same """Ensure that $set and $unset actions are performed in the same
operation. operation.
""" """
@ -1269,7 +1285,7 @@ class InstanceTest(unittest.TestCase):
foo.save() foo.save()
self.assertEqual(1, q) self.assertEqual(1, q)
def test_save_only_changed_fields_recursive(self): def ztest_save_only_changed_fields_recursive(self):
"""Ensure save only sets / unsets changed fields """Ensure save only sets / unsets changed fields
""" """
@ -1311,7 +1327,7 @@ class InstanceTest(unittest.TestCase):
person = self.Person.objects.get() person = self.Person.objects.get()
self.assertFalse(person.comments_dict['first_post'].published) self.assertFalse(person.comments_dict['first_post'].published)
def test_delete(self): def ztest_delete(self):
"""Ensure that document may be deleted using the delete method. """Ensure that document may be deleted using the delete method.
""" """
person = self.Person(name="Test User", age=30) person = self.Person(name="Test User", age=30)
@ -1320,7 +1336,7 @@ class InstanceTest(unittest.TestCase):
person.delete() person.delete()
self.assertEqual(self.Person.objects.count(), 0) self.assertEqual(self.Person.objects.count(), 0)
def test_save_custom_id(self): def ztest_save_custom_id(self):
"""Ensure that a document may be saved with a custom _id. """Ensure that a document may be saved with a custom _id.
""" """
# Create person object and save it to the database # Create person object and save it to the database
@ -1332,7 +1348,7 @@ class InstanceTest(unittest.TestCase):
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_custom_pk(self): def ztest_save_custom_pk(self):
"""Ensure that a document may be saved with a custom _id using pk alias. """Ensure that a document may be saved with a custom _id using pk alias.
""" """
# Create person object and save it to the database # Create person object and save it to the database
@ -1344,7 +1360,7 @@ class InstanceTest(unittest.TestCase):
person_obj = collection.find_one({'name': 'Test User'}) person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_list(self): def ztest_save_list(self):
"""Ensure that a list field may be properly saved. """Ensure that a list field may be properly saved.
""" """
class Comment(EmbeddedDocument): class Comment(EmbeddedDocument):
@ -1371,7 +1387,7 @@ class InstanceTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_list_search_by_embedded(self): def ztest_list_search_by_embedded(self):
class User(Document): class User(Document):
username = StringField(required=True) username = StringField(required=True)
@ -1423,7 +1439,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2))) self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2)))
self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3))) self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3)))
def test_save_embedded_document(self): def ztest_save_embedded_document(self):
"""Ensure that a document with an embedded document field may be """Ensure that a document with an embedded document field may be
saved in the database. saved in the database.
""" """
@ -1447,7 +1463,7 @@ class InstanceTest(unittest.TestCase):
# Ensure that the 'details' embedded object saved correctly # Ensure that the 'details' embedded object saved correctly
self.assertEqual(employee_obj['details']['position'], 'Developer') self.assertEqual(employee_obj['details']['position'], 'Developer')
def test_embedded_update_after_save(self): def ztest_embedded_update_after_save(self):
""" """
Test update of `EmbeddedDocumentField` attached to a newly saved Test update of `EmbeddedDocumentField` attached to a newly saved
document. document.
@ -1470,7 +1486,7 @@ class InstanceTest(unittest.TestCase):
site = Site.objects.first() site = Site.objects.first()
self.assertEqual(site.page.log_message, "Error: Dummy message") self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_updating_an_embedded_document(self): def ztest_updating_an_embedded_document(self):
"""Ensure that a document with an embedded document field may be """Ensure that a document with an embedded document field may be
saved in the database. saved in the database.
""" """
@ -1505,7 +1521,7 @@ class InstanceTest(unittest.TestCase):
promoted_employee.reload() promoted_employee.reload()
self.assertEqual(promoted_employee.details, None) self.assertEqual(promoted_employee.details, None)
def test_object_mixins(self): def ztest_object_mixins(self):
class NameMixin(object): class NameMixin(object):
name = StringField() name = StringField()
@ -1520,7 +1536,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(['id', 'name', 'widgets'], sorted(Bar._fields.keys())) self.assertEqual(['id', 'name', 'widgets'], sorted(Bar._fields.keys()))
def test_mixin_inheritance(self): def ztest_mixin_inheritance(self):
class BaseMixIn(object): class BaseMixIn(object):
count = IntField() count = IntField()
data = StringField() data = StringField()
@ -1544,7 +1560,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(t.data, "test") self.assertEqual(t.data, "test")
self.assertEqual(t.count, 12) self.assertEqual(t.count, 12)
def test_save_reference(self): def ztest_save_reference(self):
"""Ensure that a document reference field may be saved in the database. """Ensure that a document reference field may be saved in the database.
""" """
@ -1580,7 +1596,7 @@ class InstanceTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_duplicate_db_fields_raise_invalid_document_error(self): def ztest_duplicate_db_fields_raise_invalid_document_error(self):
"""Ensure a InvalidDocumentError is thrown if duplicate fields """Ensure a InvalidDocumentError is thrown if duplicate fields
declare the same db_field""" declare the same db_field"""
@ -1591,7 +1607,7 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(InvalidDocumentError, throw_invalid_document_error) self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
def test_invalid_son(self): def ztest_invalid_son(self):
"""Raise an error if loading invalid data""" """Raise an error if loading invalid data"""
class Occurrence(EmbeddedDocument): class Occurrence(EmbeddedDocument):
number = IntField() number = IntField()
@ -1608,7 +1624,7 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(InvalidDocumentError, raise_invalid_document) self.assertRaises(InvalidDocumentError, raise_invalid_document)
def test_reverse_delete_rule_cascade_and_nullify(self): def ztest_reverse_delete_rule_cascade_and_nullify(self):
"""Ensure that a referenced document is also deleted upon deletion. """Ensure that a referenced document is also deleted upon deletion.
""" """
@ -1639,7 +1655,7 @@ class InstanceTest(unittest.TestCase):
author.delete() author.delete()
self.assertEqual(BlogPost.objects.count(), 0) self.assertEqual(BlogPost.objects.count(), 0)
def test_reverse_delete_rule_with_document_inheritance(self): def ztest_reverse_delete_rule_with_document_inheritance(self):
"""Ensure that a referenced document is also deleted upon deletion """Ensure that a referenced document is also deleted upon deletion
of a child document. of a child document.
""" """
@ -1674,7 +1690,7 @@ class InstanceTest(unittest.TestCase):
author.delete() author.delete()
self.assertEqual(BlogPost.objects.count(), 0) self.assertEqual(BlogPost.objects.count(), 0)
def test_reverse_delete_rule_cascade_and_nullify_complex_field(self): def ztest_reverse_delete_rule_cascade_and_nullify_complex_field(self):
"""Ensure that a referenced document is also deleted upon deletion for """Ensure that a referenced document is also deleted upon deletion for
complex fields. complex fields.
""" """
@ -1708,7 +1724,7 @@ class InstanceTest(unittest.TestCase):
author.delete() author.delete()
self.assertEqual(BlogPost.objects.count(), 0) self.assertEqual(BlogPost.objects.count(), 0)
def test_reverse_delete_rule_cascade_triggers_pre_delete_signal(self): def ztest_reverse_delete_rule_cascade_triggers_pre_delete_signal(self):
''' ensure the pre_delete signal is triggered upon a cascading deletion ''' ensure the pre_delete signal is triggered upon a cascading deletion
setup a blog post with content, an author and editor setup a blog post with content, an author and editor
delete the author which triggers deletion of blogpost via cascade delete the author which triggers deletion of blogpost via cascade
@ -1744,7 +1760,7 @@ class InstanceTest(unittest.TestCase):
editor = Editor.objects(name='Max P.').get() editor = Editor.objects(name='Max P.').get()
self.assertEqual(editor.review_queue, 0) self.assertEqual(editor.review_queue, 0)
def test_two_way_reverse_delete_rule(self): def ztest_two_way_reverse_delete_rule(self):
"""Ensure that Bi-Directional relationships work with """Ensure that Bi-Directional relationships work with
reverse_delete_rule reverse_delete_rule
""" """
@ -1777,7 +1793,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Bar.objects.count(), 1) # No effect on the BlogPost self.assertEqual(Bar.objects.count(), 1) # No effect on the BlogPost
self.assertEqual(Bar.objects.get().foo, None) self.assertEqual(Bar.objects.get().foo, None)
def test_invalid_reverse_delete_rules_raise_errors(self): def ztest_invalid_reverse_delete_rules_raise_errors(self):
def throw_invalid_document_error(): def throw_invalid_document_error():
class Blog(Document): class Blog(Document):
@ -1794,7 +1810,7 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(InvalidDocumentError, throw_invalid_document_error_embedded) self.assertRaises(InvalidDocumentError, throw_invalid_document_error_embedded)
def test_reverse_delete_rule_cascade_recurs(self): def ztest_reverse_delete_rule_cascade_recurs(self):
"""Ensure that a chain of documents is also deleted upon cascaded """Ensure that a chain of documents is also deleted upon cascaded
deletion. deletion.
""" """
@ -1831,7 +1847,7 @@ class InstanceTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
Comment.drop_collection() Comment.drop_collection()
def test_reverse_delete_rule_deny(self): def ztest_reverse_delete_rule_deny(self):
"""Ensure that a document cannot be referenced if there are still """Ensure that a document cannot be referenced if there are still
documents referring to it. documents referring to it.
""" """
@ -1886,7 +1902,7 @@ class InstanceTest(unittest.TestCase):
A.drop_collection() A.drop_collection()
B.drop_collection() B.drop_collection()
def test_document_hash(self): def ztest_document_hash(self):
"""Test document in list, dict, set """Test document in list, dict, set
""" """
class User(Document): class User(Document):
@ -1934,7 +1950,7 @@ class InstanceTest(unittest.TestCase):
self.assertTrue(u1 in all_user_set) self.assertTrue(u1 in all_user_set)
def test_picklable(self): def ztest_picklable(self):
pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleTest(number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded() pickle_doc.embedded = PickleEmbedded()
@ -1960,7 +1976,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(pickle_doc.string, "Two") self.assertEqual(pickle_doc.string, "Two")
self.assertEqual(pickle_doc.lists, ["1", "2", "3"]) self.assertEqual(pickle_doc.lists, ["1", "2", "3"])
def test_dynamic_document_pickle(self): def ztest_dynamic_document_pickle(self):
pickle_doc = PickleDynamicTest(name="test", number=1, string="One", lists=['1', '2']) pickle_doc = PickleDynamicTest(name="test", number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar") pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar")
@ -1983,13 +1999,13 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(resurrected.embedded._dynamic_fields.keys(), self.assertEqual(resurrected.embedded._dynamic_fields.keys(),
pickle_doc.embedded._dynamic_fields.keys()) pickle_doc.embedded._dynamic_fields.keys())
def test_picklable_on_signals(self): def ztest_picklable_on_signals(self):
pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2']) pickle_doc = PickleSignalsTest(number=1, string="One", lists=['1', '2'])
pickle_doc.embedded = PickleEmbedded() pickle_doc.embedded = PickleEmbedded()
pickle_doc.save() pickle_doc.save()
pickle_doc.delete() pickle_doc.delete()
def test_throw_invalid_document_error(self): def ztest_throw_invalid_document_error(self):
# test handles people trying to upsert # test handles people trying to upsert
def throw_invalid_document_error(): def throw_invalid_document_error():
@ -1998,7 +2014,7 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(InvalidDocumentError, throw_invalid_document_error) self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
def test_mutating_documents(self): def ztest_mutating_documents(self):
class B(EmbeddedDocument): class B(EmbeddedDocument):
field1 = StringField(default='field1') field1 = StringField(default='field1')
@ -2029,7 +2045,7 @@ class InstanceTest(unittest.TestCase):
a.reload() a.reload()
self.assertEqual(a.b.field2.c_field, 'new value') self.assertEqual(a.b.field2.c_field, 'new value')
def test_can_save_false_values(self): def ztest_can_save_false_values(self):
"""Ensures you can save False values on save""" """Ensures you can save False values on save"""
class Doc(Document): class Doc(Document):
foo = StringField() foo = StringField()
@ -2043,7 +2059,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Doc.objects(archived=False).count(), 1) self.assertEqual(Doc.objects(archived=False).count(), 1)
def test_can_save_false_values_dynamic(self): def ztest_can_save_false_values_dynamic(self):
"""Ensures you can save False values on dynamic docs""" """Ensures you can save False values on dynamic docs"""
class Doc(DynamicDocument): class Doc(DynamicDocument):
foo = StringField() foo = StringField()
@ -2056,7 +2072,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Doc.objects(archived=False).count(), 1) self.assertEqual(Doc.objects(archived=False).count(), 1)
def test_do_not_save_unchanged_references(self): def ztest_do_not_save_unchanged_references(self):
"""Ensures cascading saves dont auto update""" """Ensures cascading saves dont auto update"""
class Job(Document): class Job(Document):
name = StringField() name = StringField()
@ -2087,7 +2103,7 @@ class InstanceTest(unittest.TestCase):
finally: finally:
Collection.update = orig_update Collection.update = orig_update
def test_db_alias_tests(self): def ztest_db_alias_tests(self):
""" DB Alias tests """ """ DB Alias tests """
# mongoenginetest - Is default connection alias from setUp() # mongoenginetest - Is default connection alias from setUp()
# Register Aliases # Register Aliases
@ -2143,7 +2159,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()]) self.assertEqual(Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()])
self.assertEqual(AuthorBooks._get_collection(), get_db("testdb-3")[AuthorBooks._get_collection_name()]) self.assertEqual(AuthorBooks._get_collection(), get_db("testdb-3")[AuthorBooks._get_collection_name()])
def test_db_alias_overrides(self): def ztest_db_alias_overrides(self):
"""db_alias can be overriden """db_alias can be overriden
""" """
# Register a connection with db_alias testdb-2 # Register a connection with db_alias testdb-2
@ -2168,7 +2184,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual('mongoenginetest2', self.assertEqual('mongoenginetest2',
B._get_collection().database.name) B._get_collection().database.name)
def test_db_alias_propagates(self): def ztest_db_alias_propagates(self):
"""db_alias propagates? """db_alias propagates?
""" """
register_connection('testdb-1', 'mongoenginetest2') register_connection('testdb-1', 'mongoenginetest2')
@ -2182,7 +2198,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual('testdb-1', B._meta.get('db_alias')) self.assertEqual('testdb-1', B._meta.get('db_alias'))
def test_db_ref_usage(self): def ztest_db_ref_usage(self):
""" DB Ref usage in dict_fields""" """ DB Ref usage in dict_fields"""
class User(Document): class User(Document):
@ -2260,7 +2276,7 @@ class InstanceTest(unittest.TestCase):
})]), })]),
"1,2") "1,2")
def test_switch_db_instance(self): def ztest_switch_db_instance(self):
register_connection('testdb-1', 'mongoenginetest2') register_connection('testdb-1', 'mongoenginetest2')
class Group(Document): class Group(Document):
@ -2310,7 +2326,7 @@ class InstanceTest(unittest.TestCase):
group = Group.objects.first() group = Group.objects.first()
self.assertEqual("hello - default", group.name) self.assertEqual("hello - default", group.name)
def test_no_overwritting_no_data_loss(self): def ztest_no_overwritting_no_data_loss(self):
class User(Document): class User(Document):
username = StringField(primary_key=True) username = StringField(primary_key=True)
@ -2334,7 +2350,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual("Bar", user._data["foo"]) self.assertEqual("Bar", user._data["foo"])
self.assertEqual([1, 2, 3], user._data["data"]) self.assertEqual([1, 2, 3], user._data["data"])
def test_spaces_in_keys(self): def ztest_spaces_in_keys(self):
class Embedded(DynamicEmbeddedDocument): class Embedded(DynamicEmbeddedDocument):
pass pass
@ -2350,7 +2366,7 @@ class InstanceTest(unittest.TestCase):
one = Doc.objects.filter(**{'hello world': 1}).count() one = Doc.objects.filter(**{'hello world': 1}).count()
self.assertEqual(1, one) self.assertEqual(1, one)
def test_shard_key(self): def ztest_shard_key(self):
class LogEntry(Document): class LogEntry(Document):
machine = StringField() machine = StringField()
log = StringField() log = StringField()
@ -2375,7 +2391,7 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(OperationError, change_shard_key) self.assertRaises(OperationError, change_shard_key)
def test_shard_key_primary(self): def ztest_shard_key_primary(self):
class LogEntry(Document): class LogEntry(Document):
machine = StringField(primary_key=True) machine = StringField(primary_key=True)
log = StringField() log = StringField()
@ -2400,7 +2416,7 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(OperationError, change_shard_key) self.assertRaises(OperationError, change_shard_key)
def test_kwargs_simple(self): def ztest_kwargs_simple(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
name = StringField() name = StringField()
@ -2416,7 +2432,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(classic_doc, dict_doc) self.assertEqual(classic_doc, dict_doc)
self.assertEqual(classic_doc._data, dict_doc._data) self.assertEqual(classic_doc._data, dict_doc._data)
def test_kwargs_complex(self): def ztest_kwargs_complex(self):
class Embedded(EmbeddedDocument): class Embedded(EmbeddedDocument):
name = StringField() name = StringField()
@ -2435,21 +2451,21 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(classic_doc, dict_doc) self.assertEqual(classic_doc, dict_doc)
self.assertEqual(classic_doc._data, dict_doc._data) self.assertEqual(classic_doc._data, dict_doc._data)
def test_positional_creation(self): def ztest_positional_creation(self):
"""Ensure that document may be created using positional arguments. """Ensure that document may be created using positional arguments.
""" """
person = self.Person("Test User", 42) person = self.Person("Test User", 42)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_mixed_creation(self): def ztest_mixed_creation(self):
"""Ensure that document may be created using mixed arguments. """Ensure that document may be created using mixed arguments.
""" """
person = self.Person("Test User", age=42) person = self.Person("Test User", age=42)
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_mixed_creation_dynamic(self): def ztest_mixed_creation_dynamic(self):
"""Ensure that document may be created using mixed arguments. """Ensure that document may be created using mixed arguments.
""" """
class Person(DynamicDocument): class Person(DynamicDocument):
@ -2459,7 +2475,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, "Test User") self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 42) self.assertEqual(person.age, 42)
def test_bad_mixed_creation(self): def ztest_bad_mixed_creation(self):
"""Ensure that document gives correct error when duplicating arguments """Ensure that document gives correct error when duplicating arguments
""" """
def construct_bad_instance(): def construct_bad_instance():
@ -2467,7 +2483,7 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(TypeError, construct_bad_instance) self.assertRaises(TypeError, construct_bad_instance)
def test_data_contains_id_field(self): def ztest_data_contains_id_field(self):
"""Ensure that asking for _data returns 'id' """Ensure that asking for _data returns 'id'
""" """
class Person(Document): class Person(Document):
@ -2480,7 +2496,7 @@ class InstanceTest(unittest.TestCase):
self.assertTrue('id' in person._data.keys()) self.assertTrue('id' in person._data.keys())
self.assertEqual(person._data.get('id'), person.id) self.assertEqual(person._data.get('id'), person.id)
def test_complex_nesting_document_and_embedded_document(self): def ztest_complex_nesting_document_and_embedded_document(self):
class Macro(EmbeddedDocument): class Macro(EmbeddedDocument):
value = DynamicField(default="UNDEFINED") value = DynamicField(default="UNDEFINED")
@ -2521,7 +2537,7 @@ class InstanceTest(unittest.TestCase):
system = NodesSystem.objects.first() system = NodesSystem.objects.first()
self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value) self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value)
def test_embedded_document_equality(self): def ztest_embedded_document_equality(self):
class Test(Document): class Test(Document):
field = StringField(required=True) field = StringField(required=True)