# -*- coding: utf-8 -*- from __future__ import with_statement import sys sys.path[0:0] = [""] import bson import os import pickle import unittest import uuid from datetime import datetime from tests.fixtures import PickleEmbedded, PickleTest from mongoengine import * from mongoengine.errors import (NotRegistered, InvalidDocumentError, InvalidQueryError) from mongoengine.queryset import NULLIFY, Q from mongoengine.connection import get_db TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') __all__ = ("InstanceTest",) class InstanceTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') self.db = get_db() class Person(Document): name = StringField() age = IntField() non_field = True meta = {"allow_inheritance": True} self.Person = Person def tearDown(self): for collection in self.db.collection_names(): if 'system.' in collection: continue self.db.drop_collection(collection) def test_capped_collection(self): """Ensure that capped collections work properly. """ class Log(Document): date = DateTimeField(default=datetime.now) meta = { 'max_documents': 10, 'max_size': 90000, } Log.drop_collection() # Ensure that the collection handles up to its maximum for _ in range(10): Log().save() self.assertEqual(len(Log.objects), 10) # Check that extra documents don't increase the size Log().save() self.assertEqual(len(Log.objects), 10) options = Log.objects._collection.options() self.assertEqual(options['capped'], True) self.assertEqual(options['max'], 10) self.assertEqual(options['size'], 90000) # Check that the document cannot be redefined with different options def recreate_log_document(): class Log(Document): date = DateTimeField(default=datetime.now) meta = { 'max_documents': 11, } # Create the collection by accessing Document.objects Log.objects self.assertRaises(InvalidCollectionError, recreate_log_document) Log.drop_collection() def test_repr(self): """Ensure that unicode representation works """ class Article(Document): title = StringField() def __unicode__(self): return self.title doc = Article(title=u'привет мир') self.assertEqual('', repr(doc)) def test_queryset_resurrects_dropped_collection(self): self.Person.drop_collection() self.assertEqual([], list(self.Person.objects())) class Actor(self.Person): pass # Ensure works correctly with inhertited classes Actor.objects() self.Person.drop_collection() self.assertEqual([], list(Actor.objects())) def test_polymorphic_references(self): """Ensure that the correct subclasses are returned from a query when using references / generic references """ class Animal(Document): meta = {'allow_inheritance': True} class Fish(Animal): pass class Mammal(Animal): pass class Dog(Mammal): pass class Human(Mammal): pass class Zoo(Document): animals = ListField(ReferenceField(Animal)) Zoo.drop_collection() Animal.drop_collection() Animal().save() Fish().save() Mammal().save() Dog().save() Human().save() # Save a reference to each animal zoo = Zoo(animals=Animal.objects) zoo.save() zoo.reload() classes = [a.__class__ for a in Zoo.objects.first().animals] self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) Zoo.drop_collection() class Zoo(Document): animals = ListField(GenericReferenceField(Animal)) # Save a reference to each animal zoo = Zoo(animals=Animal.objects) zoo.save() zoo.reload() classes = [a.__class__ for a in Zoo.objects.first().animals] self.assertEqual(classes, [Animal, Fish, Mammal, Dog, Human]) Zoo.drop_collection() Animal.drop_collection() def test_reference_inheritance(self): class Stats(Document): created = DateTimeField(default=datetime.now) meta = {'allow_inheritance': False} class CompareStats(Document): generated = DateTimeField(default=datetime.now) stats = ListField(ReferenceField(Stats)) Stats.drop_collection() CompareStats.drop_collection() list_stats = [] for i in xrange(10): s = Stats() s.save() list_stats.append(s) cmp_stats = CompareStats(stats=list_stats) cmp_stats.save() self.assertEqual(list_stats, CompareStats.objects.first().stats) def test_db_field_load(self): """Ensure we load data correctly """ class Person(Document): name = StringField(required=True) _rank = StringField(required=False, db_field="rank") @property def rank(self): return self._rank or "Private" Person.drop_collection() Person(name="Jack", _rank="Corporal").save() Person(name="Fred").save() self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") self.assertEqual(Person.objects.get(name="Fred").rank, "Private") def test_db_embedded_doc_field_load(self): """Ensure we load embedded document data correctly """ class Rank(EmbeddedDocument): title = StringField(required=True) class Person(Document): name = StringField(required=True) rank_ = EmbeddedDocumentField(Rank, required=False, db_field='rank') @property def rank(self): return self.rank_.title if self.rank_ is not None else "Private" Person.drop_collection() Person(name="Jack", rank_=Rank(title="Corporal")).save() Person(name="Fred").save() self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal") self.assertEqual(Person.objects.get(name="Fred").rank, "Private") def test_custom_id_field(self): """Ensure that documents may be created with custom primary keys. """ class User(Document): username = StringField(primary_key=True) name = StringField() meta = {'allow_inheritance': True} User.drop_collection() self.assertEqual(User._fields['username'].db_field, '_id') self.assertEqual(User._meta['id_field'], 'username') def create_invalid_user(): User(name='test').save() # no primary key field self.assertRaises(ValidationError, create_invalid_user) def define_invalid_user(): class EmailUser(User): email = StringField(primary_key=True) self.assertRaises(ValueError, define_invalid_user) class EmailUser(User): email = StringField() user = User(username='test', name='test user') user.save() user_obj = User.objects.first() self.assertEqual(user_obj.id, 'test') self.assertEqual(user_obj.pk, 'test') user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'test') self.assertTrue('username' not in user_son['_id']) User.drop_collection() user = User(pk='mongo', name='mongo user') user.save() user_obj = User.objects.first() self.assertEqual(user_obj.id, 'mongo') self.assertEqual(user_obj.pk, 'mongo') user_son = User.objects._collection.find_one() self.assertEqual(user_son['_id'], 'mongo') self.assertTrue('username' not in user_son['_id']) User.drop_collection() def test_document_not_registered(self): class Place(Document): name = StringField() meta = {'allow_inheritance': True} class NicePlace(Place): pass Place.drop_collection() Place(name="London").save() NicePlace(name="Buckingham Palace").save() # Mimic Place and NicePlace definitions being in a different file # and the NicePlace model not being imported in at query time. from mongoengine.base import _document_registry del(_document_registry['Place.NicePlace']) def query_without_importing_nice_place(): print Place.objects.all() self.assertRaises(NotRegistered, query_without_importing_nice_place) def test_creation(self): """Ensure that document may be created using keyword arguments. """ person = self.Person(name="Test User", age=30) self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 30) def test_to_dbref(self): """Ensure that you can get a dbref of a document""" person = self.Person(name="Test User", age=30) self.assertRaises(OperationError, person.to_dbref) person.save() person.to_dbref() def test_reload(self): """Ensure that attributes may be reloaded. """ person = self.Person(name="Test User", age=20) person.save() person_obj = self.Person.objects.first() person_obj.name = "Mr Test User" person_obj.age = 21 person_obj.save() self.assertEqual(person.name, "Test User") self.assertEqual(person.age, 20) person.reload() self.assertEqual(person.name, "Mr Test User") self.assertEqual(person.age, 21) def test_reload_sharded(self): class Animal(Document): superphylum = StringField() meta = {'shard_key': ('superphylum',)} Animal.drop_collection() doc = Animal(superphylum = 'Deuterostomia') doc.save() doc.reload() Animal.drop_collection() def test_reload_referencing(self): """Ensures reloading updates weakrefs correctly """ class Embedded(EmbeddedDocument): dict_field = DictField() list_field = ListField() class Doc(Document): dict_field = DictField() list_field = ListField() embedded_field = EmbeddedDocumentField(Embedded) Doc.drop_collection() doc = Doc() doc.dict_field = {'hello': 'world'} doc.list_field = ['1', 2, {'hello': 'world'}] embedded_1 = Embedded() embedded_1.dict_field = {'hello': 'world'} embedded_1.list_field = ['1', 2, {'hello': 'world'}] doc.embedded_field = embedded_1 doc.save() doc = doc.reload(10) doc.list_field.append(1) doc.dict_field['woot'] = "woot" doc.embedded_field.list_field.append(1) doc.embedded_field.dict_field['woot'] = "woot" self.assertEqual(doc._get_changed_fields(), [ 'list_field', 'dict_field', 'embedded_field.list_field', 'embedded_field.dict_field']) doc.save() doc = doc.reload(10) self.assertEqual(doc._get_changed_fields(), []) self.assertEqual(len(doc.list_field), 4) self.assertEqual(len(doc.dict_field), 2) self.assertEqual(len(doc.embedded_field.list_field), 4) self.assertEqual(len(doc.embedded_field.dict_field), 2) def test_dictionary_access(self): """Ensure that dictionary-style field access works properly. """ person = self.Person(name='Test User', age=30) self.assertEqual(person['name'], 'Test User') self.assertRaises(KeyError, person.__getitem__, 'salary') self.assertRaises(KeyError, person.__setitem__, 'salary', 50) person['name'] = 'Another User' self.assertEqual(person['name'], 'Another User') # Length = length(assigned fields + id) self.assertEqual(len(person), 3) self.assertTrue('age' in person) person.age = None self.assertFalse('age' in person) self.assertFalse('nationality' in person) def test_embedded_document(self): """Ensure that embedded documents are set up correctly. """ class Comment(EmbeddedDocument): content = StringField() self.assertTrue('content' in Comment._fields) self.assertFalse('id' in Comment._fields) def test_embedded_document_validation(self): """Ensure that embedded documents may be validated. """ class Comment(EmbeddedDocument): date = DateTimeField() content = StringField(required=True) comment = Comment() self.assertRaises(ValidationError, comment.validate) comment.content = 'test' comment.validate() comment.date = 4 self.assertRaises(ValidationError, comment.validate) comment.date = datetime.now() comment.validate() def test_embedded_db_field_validate(self): class SubDoc(EmbeddedDocument): val = IntField() class Doc(Document): e = EmbeddedDocumentField(SubDoc, db_field='eb') Doc.drop_collection() Doc(e=SubDoc(val=15)).save() doc = Doc.objects.first() doc.validate() keys = doc._data.keys() self.assertEqual(2, len(keys)) self.assertTrue('id' in keys) self.assertTrue('e' in keys) def test_save(self): """Ensure that a document may be saved in the database. """ # Create person object and save it to the database person = self.Person(name='Test User', age=30) person.save() # Ensure that the object is in the database collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['age'], 30) self.assertEqual(person_obj['_id'], person.id) # Test skipping validation on save class Recipient(Document): email = EmailField(required=True) recipient = Recipient(email='root@localhost') self.assertRaises(ValidationError, recipient.save) try: recipient.save(validate=False) except ValidationError: self.fail() def test_save_to_a_value_that_equates_to_false(self): class Thing(EmbeddedDocument): count = IntField() class User(Document): thing = EmbeddedDocumentField(Thing) User.drop_collection() user = User(thing=Thing(count=1)) user.save() user.reload() user.thing.count = 0 user.save() user.reload() self.assertEqual(user.thing.count, 0) def test_save_max_recursion_not_hit(self): class Person(Document): name = StringField() parent = ReferenceField('self') friend = ReferenceField('self') Person.drop_collection() p1 = Person(name="Wilson Snr") p1.parent = None p1.save() p2 = Person(name="Wilson Jr") p2.parent = p1 p2.save() p1.friend = p2 p1.save() # Confirm can save and it resets the changed fields without hitting # max recursion error p0 = Person.objects.first() p0.name = 'wpjunior' p0.save() def test_save_max_recursion_not_hit_with_file_field(self): class Foo(Document): name = StringField() picture = FileField() bar = ReferenceField('self') Foo.drop_collection() a = Foo(name='hello') a.save() a.bar = a with open(TEST_IMAGE_PATH, 'rb') as test_image: a.picture = test_image a.save() # Confirm can save and it resets the changed fields without hitting # max recursion error b = Foo.objects.with_id(a.id) b.name='world' b.save() self.assertEqual(b.picture, b.bar.picture, b.bar.bar.picture) def test_save_cascades(self): class Person(Document): name = StringField() parent = ReferenceField('self') Person.drop_collection() p1 = Person(name="Wilson Snr") p1.parent = None p1.save() p2 = Person(name="Wilson Jr") p2.parent = p1 p2.save() p = Person.objects(name="Wilson Jr").get() p.parent.name = "Daddy Wilson" p.save() p1.reload() self.assertEqual(p1.name, p.parent.name) def test_save_cascade_kwargs(self): class Person(Document): name = StringField() parent = ReferenceField('self') Person.drop_collection() p1 = Person(name="Wilson Snr") p1.parent = None p1.save() p2 = Person(name="Wilson Jr") p2.parent = p1 p2.save(force_insert=True, cascade_kwargs={"force_insert": False}) p = Person.objects(name="Wilson Jr").get() p.parent.name = "Daddy Wilson" p.save() p1.reload() self.assertEqual(p1.name, p.parent.name) def test_save_cascade_meta(self): class Person(Document): name = StringField() parent = ReferenceField('self') meta = {'cascade': False} Person.drop_collection() p1 = Person(name="Wilson Snr") p1.parent = None p1.save() p2 = Person(name="Wilson Jr") p2.parent = p1 p2.save() p = Person.objects(name="Wilson Jr").get() p.parent.name = "Daddy Wilson" p.save() p1.reload() self.assertNotEqual(p1.name, p.parent.name) p.save(cascade=True) p1.reload() self.assertEqual(p1.name, p.parent.name) def test_save_cascades_generically(self): class Person(Document): name = StringField() parent = GenericReferenceField() Person.drop_collection() p1 = Person(name="Wilson Snr") p1.save() p2 = Person(name="Wilson Jr") p2.parent = p1 p2.save() p = Person.objects(name="Wilson Jr").get() p.parent.name = "Daddy Wilson" p.save() p1.reload() self.assertEqual(p1.name, p.parent.name) def test_update(self): """Ensure that an existing document is updated instead of be overwritten.""" # Create person object and save it to the database person = self.Person(name='Test User', age=30) person.save() # Create same person object, with same id, without age same_person = self.Person(name='Test') same_person.id = person.id same_person.save() # Confirm only one object self.assertEqual(self.Person.objects.count(), 1) # reload person.reload() same_person.reload() # Confirm the same self.assertEqual(person, same_person) self.assertEqual(person.name, same_person.name) self.assertEqual(person.age, same_person.age) # Confirm the saved values self.assertEqual(person.name, 'Test') self.assertEqual(person.age, 30) # Test only / exclude only updates included fields person = self.Person.objects.only('name').get() person.name = 'User' person.save() person.reload() self.assertEqual(person.name, 'User') self.assertEqual(person.age, 30) # test exclude only updates set fields person = self.Person.objects.exclude('name').get() person.age = 21 person.save() person.reload() self.assertEqual(person.name, 'User') self.assertEqual(person.age, 21) # Test only / exclude can set non excluded / included fields person = self.Person.objects.only('name').get() person.name = 'Test' person.age = 30 person.save() person.reload() self.assertEqual(person.name, 'Test') self.assertEqual(person.age, 30) # test exclude only updates set fields person = self.Person.objects.exclude('name').get() person.name = 'User' person.age = 21 person.save() person.reload() self.assertEqual(person.name, 'User') self.assertEqual(person.age, 21) # Confirm does remove unrequired fields person = self.Person.objects.exclude('name').get() person.age = None person.save() person.reload() self.assertEqual(person.name, 'User') self.assertEqual(person.age, None) person = self.Person.objects.get() person.name = None person.age = None person.save() person.reload() self.assertEqual(person.name, None) self.assertEqual(person.age, None) def test_can_save_if_not_included(self): class EmbeddedDoc(EmbeddedDocument): pass class Simple(Document): pass class Doc(Document): string_field = StringField(default='1') int_field = IntField(default=1) float_field = FloatField(default=1.1) boolean_field = BooleanField(default=True) datetime_field = DateTimeField(default=datetime.now) embedded_document_field = EmbeddedDocumentField(EmbeddedDoc, default=lambda: EmbeddedDoc()) list_field = ListField(default=lambda: [1, 2, 3]) dict_field = DictField(default=lambda: {"hello": "world"}) objectid_field = ObjectIdField(default=bson.ObjectId) reference_field = ReferenceField(Simple, default=lambda: Simple().save()) map_field = MapField(IntField(), default=lambda: {"simple": 1}) decimal_field = DecimalField(default=1.0) complex_datetime_field = ComplexDateTimeField(default=datetime.now) url_field = URLField(default="http://mongoengine.org") dynamic_field = DynamicField(default=1) generic_reference_field = GenericReferenceField( default=lambda: Simple().save()) sorted_list_field = SortedListField(IntField(), default=lambda: [1, 2, 3]) email_field = EmailField(default="ross@example.com") geo_point_field = GeoPointField(default=lambda: [1, 2]) sequence_field = SequenceField() uuid_field = UUIDField(default=uuid.uuid4) generic_embedded_document_field = GenericEmbeddedDocumentField( default=lambda: EmbeddedDoc()) Simple.drop_collection() Doc.drop_collection() Doc().save() my_doc = Doc.objects.only("string_field").first() my_doc.string_field = "string" my_doc.save() my_doc = Doc.objects.get(string_field="string") self.assertEqual(my_doc.string_field, "string") self.assertEqual(my_doc.int_field, 1) def test_document_update(self): def update_not_saved_raises(): person = self.Person(name='dcrosta') person.update(set__name='Dan Crosta') self.assertRaises(OperationError, update_not_saved_raises) author = self.Person(name='dcrosta') author.save() author.update(set__name='Dan Crosta') author.reload() p1 = self.Person.objects.first() self.assertEqual(p1.name, author.name) def update_no_value_raises(): person = self.Person.objects.first() person.update() self.assertRaises(OperationError, update_no_value_raises) def update_no_op_raises(): person = self.Person.objects.first() person.update(name="Dan") self.assertRaises(InvalidQueryError, update_no_op_raises) def test_embedded_update(self): """ Test update on `EmbeddedDocumentField` fields """ class Page(EmbeddedDocument): log_message = StringField(verbose_name="Log message", required=True) class Site(Document): page = EmbeddedDocumentField(Page) Site.drop_collection() site = Site(page=Page(log_message="Warning: Dummy message")) site.save() # Update site = Site.objects.first() site.page.log_message = "Error: Dummy message" site.save() site = Site.objects.first() self.assertEqual(site.page.log_message, "Error: Dummy message") def test_embedded_update_db_field(self): """ Test update on `EmbeddedDocumentField` fields when db_field is other than default. """ class Page(EmbeddedDocument): log_message = StringField(verbose_name="Log message", db_field="page_log_message", required=True) class Site(Document): page = EmbeddedDocumentField(Page) Site.drop_collection() site = Site(page=Page(log_message="Warning: Dummy message")) site.save() # Update site = Site.objects.first() site.page.log_message = "Error: Dummy message" site.save() site = Site.objects.first() self.assertEqual(site.page.log_message, "Error: Dummy message") def test_save_only_changed_fields(self): """Ensure save only sets / unsets changed fields """ class User(self.Person): active = BooleanField(default=True) User.drop_collection() # Create person object and save it to the database user = User(name='Test User', age=30, active=True) user.save() user.reload() # Simulated Race condition same_person = self.Person.objects.get() same_person.active = False user.age = 21 user.save() same_person.name = 'User' same_person.save() person = self.Person.objects.get() self.assertEqual(person.name, 'User') self.assertEqual(person.age, 21) self.assertEqual(person.active, False) def test_save_only_changed_fields_recursive(self): """Ensure save only sets / unsets changed fields """ class Comment(EmbeddedDocument): published = BooleanField(default=True) class User(self.Person): comments_dict = DictField() comments = ListField(EmbeddedDocumentField(Comment)) active = BooleanField(default=True) User.drop_collection() # Create person object and save it to the database person = User(name='Test User', age=30, active=True) person.comments.append(Comment()) person.save() person.reload() person = self.Person.objects.get() self.assertTrue(person.comments[0].published) person.comments[0].published = False person.save() person = self.Person.objects.get() self.assertFalse(person.comments[0].published) # Simple dict w person.comments_dict['first_post'] = Comment() person.save() person = self.Person.objects.get() self.assertTrue(person.comments_dict['first_post'].published) person.comments_dict['first_post'].published = False person.save() person = self.Person.objects.get() self.assertFalse(person.comments_dict['first_post'].published) def test_delete(self): """Ensure that document may be deleted using the delete method. """ person = self.Person(name="Test User", age=30) person.save() self.assertEqual(len(self.Person.objects), 1) person.delete() self.assertEqual(len(self.Person.objects), 0) def test_save_custom_id(self): """Ensure that a document may be saved with a custom _id. """ # Create person object and save it to the database person = self.Person(name='Test User', age=30, id='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') def test_save_custom_pk(self): """Ensure that a document may be saved with a custom _id using pk alias. """ # Create person object and save it to the database person = self.Person(name='Test User', age=30, pk='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') def test_save_list(self): """Ensure that a list field may be properly saved. """ class Comment(EmbeddedDocument): content = StringField() class BlogPost(Document): content = StringField() comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField()) BlogPost.drop_collection() post = BlogPost(content='Went for a walk today...') post.tags = tags = ['fun', 'leisure'] comments = [Comment(content='Good for you'), Comment(content='Yay.')] post.comments = comments post.save() collection = self.db[BlogPost._get_collection_name()] post_obj = collection.find_one() self.assertEqual(post_obj['tags'], tags) for comment_obj, comment in zip(post_obj['comments'], comments): self.assertEqual(comment_obj['content'], comment['content']) BlogPost.drop_collection() def test_list_search_by_embedded(self): class User(Document): username = StringField(required=True) meta = {'allow_inheritance': False} class Comment(EmbeddedDocument): comment = StringField() user = ReferenceField(User, required=True) meta = {'allow_inheritance': False} class Page(Document): comments = ListField(EmbeddedDocumentField(Comment)) meta = {'allow_inheritance': False, 'indexes': [ {'fields': ['comments.user']} ]} User.drop_collection() Page.drop_collection() u1 = User(username="wilson") u1.save() u2 = User(username="rozza") u2.save() u3 = User(username="hmarr") u3.save() p1 = Page(comments = [Comment(user=u1, comment="Its very good"), Comment(user=u2, comment="Hello world"), Comment(user=u3, comment="Ping Pong"), Comment(user=u1, comment="I like a beer")]) p1.save() p2 = Page(comments = [Comment(user=u1, comment="Its very good"), Comment(user=u2, comment="Hello world")]) p2.save() p3 = Page(comments = [Comment(user=u3, comment="Its very good")]) p3.save() p4 = Page(comments = [Comment(user=u2, comment="Heavy Metal song")]) p4.save() self.assertEqual([p1, p2], list(Page.objects.filter(comments__user=u1))) self.assertEqual([p1, p2, p4], list(Page.objects.filter(comments__user=u2))) self.assertEqual([p1, p3], list(Page.objects.filter(comments__user=u3))) def test_save_embedded_document(self): """Ensure that a document with an embedded document field may be saved in the database. """ class EmployeeDetails(EmbeddedDocument): position = StringField() class Employee(self.Person): salary = IntField() details = EmbeddedDocumentField(EmployeeDetails) # Create employee object and save it to the database employee = Employee(name='Test Employee', age=50, salary=20000) employee.details = EmployeeDetails(position='Developer') employee.save() # Ensure that the object is in the database collection = self.db[self.Person._get_collection_name()] employee_obj = collection.find_one({'name': 'Test Employee'}) self.assertEqual(employee_obj['name'], 'Test Employee') self.assertEqual(employee_obj['age'], 50) # Ensure that the 'details' embedded object saved correctly self.assertEqual(employee_obj['details']['position'], 'Developer') def test_embedded_update_after_save(self): """ Test update of `EmbeddedDocumentField` attached to a newly saved document. """ class Page(EmbeddedDocument): log_message = StringField(verbose_name="Log message", required=True) class Site(Document): page = EmbeddedDocumentField(Page) Site.drop_collection() site = Site(page=Page(log_message="Warning: Dummy message")) site.save() # Update site.page.log_message = "Error: Dummy message" site.save() site = Site.objects.first() self.assertEqual(site.page.log_message, "Error: Dummy message") def test_updating_an_embedded_document(self): """Ensure that a document with an embedded document field may be saved in the database. """ class EmployeeDetails(EmbeddedDocument): position = StringField() class Employee(self.Person): salary = IntField() details = EmbeddedDocumentField(EmployeeDetails) # Create employee object and save it to the database employee = Employee(name='Test Employee', age=50, salary=20000) employee.details = EmployeeDetails(position='Developer') employee.save() # Test updating an embedded document promoted_employee = Employee.objects.get(name='Test Employee') promoted_employee.details.position = 'Senior Developer' promoted_employee.save() promoted_employee.reload() self.assertEqual(promoted_employee.name, 'Test Employee') self.assertEqual(promoted_employee.age, 50) # Ensure that the 'details' embedded object saved correctly self.assertEqual(promoted_employee.details.position, 'Senior Developer') # Test removal promoted_employee.details = None promoted_employee.save() promoted_employee.reload() self.assertEqual(promoted_employee.details, None) def test_object_mixins(self): class NameMixin(object): name = StringField() class Foo(EmbeddedDocument, NameMixin): quantity = IntField() self.assertEqual(['name', 'quantity'], sorted(Foo._fields.keys())) class Bar(Document, NameMixin): widgets = StringField() self.assertEqual(['id', 'name', 'widgets'], sorted(Bar._fields.keys())) def test_mixin_inheritance(self): class BaseMixIn(object): count = IntField() data = StringField() class DoubleMixIn(BaseMixIn): comment = StringField() class TestDoc(Document, DoubleMixIn): age = IntField() TestDoc.drop_collection() t = TestDoc(count=12, data="test", comment="great!", age=19) t.save() t = TestDoc.objects.first() self.assertEqual(t.age, 19) self.assertEqual(t.comment, "great!") self.assertEqual(t.data, "test") self.assertEqual(t.count, 12) def test_save_reference(self): """Ensure that a document reference field may be saved in the database. """ class BlogPost(Document): meta = {'collection': 'blogpost_1'} content = StringField() author = ReferenceField(self.Person) BlogPost.drop_collection() author = self.Person(name='Test User') author.save() post = BlogPost(content='Watched some TV today... how exciting.') # Should only reference author when saving post.author = author post.save() post_obj = BlogPost.objects.first() # Test laziness self.assertTrue(isinstance(post_obj._data['author'], bson.DBRef)) self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertEqual(post_obj.author.name, 'Test User') # Ensure that the dereferenced object may be changed and saved post_obj.author.age = 25 post_obj.author.save() author = list(self.Person.objects(name='Test User'))[-1] self.assertEqual(author.age, 25) BlogPost.drop_collection() def test_duplicate_db_fields_raise_invalid_document_error(self): """Ensure a InvalidDocumentError is thrown if duplicate fields declare the same db_field""" def throw_invalid_document_error(): class Foo(Document): name = StringField() name2 = StringField(db_field='name') self.assertRaises(InvalidDocumentError, throw_invalid_document_error) def test_invalid_son(self): """Raise an error if loading invalid data""" class Occurrence(EmbeddedDocument): number = IntField() class Word(Document): stem = StringField() count = IntField(default=1) forms = ListField(StringField(), default=list) occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) def raise_invalid_document(): Word._from_son({'stem': [1,2,3], 'forms': 1, 'count': 'one', 'occurs': {"hello": None}}) self.assertRaises(InvalidDocumentError, raise_invalid_document) def test_reverse_delete_rule_cascade_and_nullify(self): """Ensure that a referenced document is also deleted upon deletion. """ class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) reviewer = ReferenceField(self.Person, reverse_delete_rule=NULLIFY) self.Person.drop_collection() BlogPost.drop_collection() author = self.Person(name='Test User') author.save() reviewer = self.Person(name='Re Viewer') reviewer.save() post = BlogPost(content = 'Watched some TV') post.author = author post.reviewer = reviewer post.save() reviewer.delete() self.assertEqual(len(BlogPost.objects), 1) # No effect on the BlogPost self.assertEqual(BlogPost.objects.get().reviewer, None) # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() self.assertEqual(len(BlogPost.objects), 0) def test_reverse_delete_rule_cascade_and_nullify_complex_field(self): """Ensure that a referenced document is also deleted upon deletion for complex fields. """ class BlogPost(Document): content = StringField() authors = ListField(ReferenceField(self.Person, reverse_delete_rule=CASCADE)) reviewers = ListField(ReferenceField(self.Person, reverse_delete_rule=NULLIFY)) self.Person.drop_collection() BlogPost.drop_collection() author = self.Person(name='Test User') author.save() reviewer = self.Person(name='Re Viewer') reviewer.save() post = BlogPost(content='Watched some TV') post.authors = [author] post.reviewers = [reviewer] post.save() # Deleting the reviewer should have no effect on the BlogPost reviewer.delete() self.assertEqual(len(BlogPost.objects), 1) self.assertEqual(BlogPost.objects.get().reviewers, []) # Delete the Person, which should lead to deletion of the BlogPost, too author.delete() self.assertEqual(len(BlogPost.objects), 0) def test_two_way_reverse_delete_rule(self): """Ensure that Bi-Directional relationships work with reverse_delete_rule """ class Bar(Document): content = StringField() foo = ReferenceField('Foo') class Foo(Document): content = StringField() bar = ReferenceField(Bar) Bar.register_delete_rule(Foo, 'bar', NULLIFY) Foo.register_delete_rule(Bar, 'foo', NULLIFY) Bar.drop_collection() Foo.drop_collection() b = Bar(content="Hello") b.save() f = Foo(content="world", bar=b) f.save() b.foo = f b.save() f.delete() self.assertEqual(len(Bar.objects), 1) # No effect on the BlogPost self.assertEqual(Bar.objects.get().foo, None) def test_invalid_reverse_delete_rules_raise_errors(self): def throw_invalid_document_error(): class Blog(Document): content = StringField() authors = MapField(ReferenceField(self.Person, reverse_delete_rule=CASCADE)) reviewers = DictField(field=ReferenceField(self.Person, reverse_delete_rule=NULLIFY)) self.assertRaises(InvalidDocumentError, throw_invalid_document_error) def throw_invalid_document_error_embedded(): class Parents(EmbeddedDocument): father = ReferenceField('Person', reverse_delete_rule=DENY) mother = ReferenceField('Person', reverse_delete_rule=DENY) self.assertRaises(InvalidDocumentError, throw_invalid_document_error_embedded) def test_reverse_delete_rule_cascade_recurs(self): """Ensure that a chain of documents is also deleted upon cascaded deletion. """ class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) class Comment(Document): text = StringField() post = ReferenceField(BlogPost, reverse_delete_rule=CASCADE) self.Person.drop_collection() BlogPost.drop_collection() Comment.drop_collection() author = self.Person(name='Test User') author.save() post = BlogPost(content = 'Watched some TV') post.author = author post.save() comment = Comment(text = 'Kudos.') comment.post = post comment.save() # Delete the Person, which should lead to deletion of the BlogPost, and, # recursively to the Comment, too author.delete() self.assertEqual(len(Comment.objects), 0) self.Person.drop_collection() BlogPost.drop_collection() Comment.drop_collection() def test_reverse_delete_rule_deny(self): """Ensure that a document cannot be referenced if there are still documents referring to it. """ class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=DENY) self.Person.drop_collection() BlogPost.drop_collection() author = self.Person(name='Test User') author.save() post = BlogPost(content = 'Watched some TV') post.author = author post.save() # Delete the Person should be denied self.assertRaises(OperationError, author.delete) # Should raise denied error self.assertEqual(len(BlogPost.objects), 1) # No objects may have been deleted self.assertEqual(len(self.Person.objects), 1) # Other users, that don't have BlogPosts must be removable, like normal author = self.Person(name='Another User') author.save() self.assertEqual(len(self.Person.objects), 2) author.delete() self.assertEqual(len(self.Person.objects), 1) self.Person.drop_collection() BlogPost.drop_collection() def subclasses_and_unique_keys_works(self): class A(Document): pass class B(A): foo = BooleanField(unique=True) A.drop_collection() B.drop_collection() A().save() A().save() B(foo=True).save() self.assertEqual(A.objects.count(), 2) self.assertEqual(B.objects.count(), 1) A.drop_collection() B.drop_collection() def test_document_hash(self): """Test document in list, dict, set """ class User(Document): pass class BlogPost(Document): pass # Clear old datas User.drop_collection() BlogPost.drop_collection() u1 = User.objects.create() u2 = User.objects.create() u3 = User.objects.create() u4 = User() # New object b1 = BlogPost.objects.create() b2 = BlogPost.objects.create() # in List all_user_list = list(User.objects.all()) self.assertTrue(u1 in all_user_list) self.assertTrue(u2 in all_user_list) self.assertTrue(u3 in all_user_list) self.assertFalse(u4 in all_user_list) # New object self.assertFalse(b1 in all_user_list) # Other object self.assertFalse(b2 in all_user_list) # Other object # in Dict all_user_dic = {} for u in User.objects.all(): all_user_dic[u] = "OK" self.assertEqual(all_user_dic.get(u1, False), "OK") self.assertEqual(all_user_dic.get(u2, False), "OK") self.assertEqual(all_user_dic.get(u3, False), "OK") self.assertEqual(all_user_dic.get(u4, False), False) # New object self.assertEqual(all_user_dic.get(b1, False), False) # Other object self.assertEqual(all_user_dic.get(b2, False), False) # Other object # in Set all_user_set = set(User.objects.all()) self.assertTrue(u1 in all_user_set) def test_picklable(self): pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc.embedded = PickleEmbedded() pickle_doc.save() pickled_doc = pickle.dumps(pickle_doc) resurrected = pickle.loads(pickled_doc) self.assertEqual(resurrected, pickle_doc) resurrected.string = "Two" resurrected.save() pickle_doc = pickle_doc.reload() self.assertEqual(resurrected, pickle_doc) def test_throw_invalid_document_error(self): # test handles people trying to upsert def throw_invalid_document_error(): class Blog(Document): validate = DictField() self.assertRaises(InvalidDocumentError, throw_invalid_document_error) def test_mutating_documents(self): class B(EmbeddedDocument): field1 = StringField(default='field1') class A(Document): b = EmbeddedDocumentField(B, default=lambda: B()) A.drop_collection() a = A() a.save() a.reload() self.assertEqual(a.b.field1, 'field1') class C(EmbeddedDocument): c_field = StringField(default='cfield') class B(EmbeddedDocument): field1 = StringField(default='field1') field2 = EmbeddedDocumentField(C, default=lambda: C()) class A(Document): b = EmbeddedDocumentField(B, default=lambda: B()) a = A.objects()[0] a.b.field2.c_field = 'new value' a.save() a.reload() self.assertEqual(a.b.field2.c_field, 'new value') def test_can_save_false_values(self): """Ensures you can save False values on save""" class Doc(Document): foo = StringField() archived = BooleanField(default=False, required=True) Doc.drop_collection() d = Doc() d.save() d.archived = False d.save() self.assertEqual(Doc.objects(archived=False).count(), 1) def test_can_save_false_values_dynamic(self): """Ensures you can save False values on dynamic docs""" class Doc(DynamicDocument): foo = StringField() Doc.drop_collection() d = Doc() d.save() d.archived = False d.save() self.assertEqual(Doc.objects(archived=False).count(), 1) def test_do_not_save_unchanged_references(self): """Ensures cascading saves dont auto update""" class Job(Document): name = StringField() class Person(Document): name = StringField() age = IntField() job = ReferenceField(Job) Job.drop_collection() Person.drop_collection() job = Job(name="Job 1") # job should not have any changed fields after the save job.save() person = Person(name="name", age=10, job=job) from pymongo.collection import Collection orig_update = Collection.update try: def fake_update(*args, **kwargs): self.fail("Unexpected update for %s" % args[0].name) return orig_update(*args, **kwargs) Collection.update = fake_update person.save() finally: Collection.update = orig_update def test_db_alias_tests(self): """ DB Alias tests """ # mongoenginetest - Is default connection alias from setUp() # Register Aliases register_connection('testdb-1', 'mongoenginetest2') register_connection('testdb-2', 'mongoenginetest3') register_connection('testdb-3', 'mongoenginetest4') class User(Document): name = StringField() meta = {"db_alias": "testdb-1"} class Book(Document): name = StringField() meta = {"db_alias": "testdb-2"} # Drops User.drop_collection() Book.drop_collection() # Create bob = User.objects.create(name="Bob") hp = Book.objects.create(name="Harry Potter") # Selects self.assertEqual(User.objects.first(), bob) self.assertEqual(Book.objects.first(), hp) # DeReference class AuthorBooks(Document): author = ReferenceField(User) book = ReferenceField(Book) meta = {"db_alias": "testdb-3"} # Drops AuthorBooks.drop_collection() ab = AuthorBooks.objects.create(author=bob, book=hp) # select self.assertEqual(AuthorBooks.objects.first(), ab) self.assertEqual(AuthorBooks.objects.first().book, hp) self.assertEqual(AuthorBooks.objects.first().author, bob) self.assertEqual(AuthorBooks.objects.filter(author=bob).first(), ab) self.assertEqual(AuthorBooks.objects.filter(book=hp).first(), ab) # DB Alias self.assertEqual(User._get_db(), get_db("testdb-1")) self.assertEqual(Book._get_db(), get_db("testdb-2")) self.assertEqual(AuthorBooks._get_db(), get_db("testdb-3")) # Collections self.assertEqual(User._get_collection(), get_db("testdb-1")[User._get_collection_name()]) self.assertEqual(Book._get_collection(), get_db("testdb-2")[Book._get_collection_name()]) self.assertEqual(AuthorBooks._get_collection(), get_db("testdb-3")[AuthorBooks._get_collection_name()]) def test_db_alias_propagates(self): """db_alias propagates? """ class A(Document): name = StringField() meta = {"db_alias": "testdb-1", "allow_inheritance": True} class B(A): pass self.assertEqual('testdb-1', B._meta.get('db_alias')) def test_db_ref_usage(self): """ DB Ref usage in dict_fields""" class User(Document): name = StringField() class Book(Document): name = StringField() author = ReferenceField(User) extra = DictField() meta = { 'ordering': ['+name'] } def __unicode__(self): return self.name def __str__(self): return self.name # Drops User.drop_collection() Book.drop_collection() # Authors bob = User.objects.create(name="Bob") jon = User.objects.create(name="Jon") # Redactors karl = User.objects.create(name="Karl") susan = User.objects.create(name="Susan") peter = User.objects.create(name="Peter") # Bob Book.objects.create(name="1", author=bob, extra={ "a": bob.to_dbref(), "b": [karl.to_dbref(), susan.to_dbref()]}) Book.objects.create(name="2", author=bob, extra={ "a": bob.to_dbref(), "b": karl.to_dbref()}) Book.objects.create(name="3", author=bob, extra={ "a": bob.to_dbref(), "c": [jon.to_dbref(), peter.to_dbref()]}) Book.objects.create(name="4", author=bob) # Jon Book.objects.create(name="5", author=jon) Book.objects.create(name="6", author=peter) Book.objects.create(name="7", author=jon) Book.objects.create(name="8", author=jon) Book.objects.create(name="9", author=jon, extra={"a": peter.to_dbref()}) # Checks self.assertEqual(",".join([str(b) for b in Book.objects.all()]), "1,2,3,4,5,6,7,8,9") # bob related books self.assertEqual(",".join([str(b) for b in Book.objects.filter( Q(extra__a=bob) | Q(author=bob) | Q(extra__b=bob))]), "1,2,3,4") # Susan & Karl related books self.assertEqual(",".join([str(b) for b in Book.objects.filter( Q(extra__a__all=[karl, susan]) | Q(author__all=[karl, susan]) | Q(extra__b__all=[ karl.to_dbref(), susan.to_dbref()])) ]), "1") # $Where self.assertEqual(u",".join([str(b) for b in Book.objects.filter( __raw__={ "$where": """ function(){ return this.name == '1' || this.name == '2';}""" } )]), "1,2") class ValidatorErrorTest(unittest.TestCase): def test_to_dict(self): """Ensure a ValidationError handles error to_dict correctly. """ error = ValidationError('root') self.assertEqual(error.to_dict(), {}) # 1st level error schema error.errors = {'1st': ValidationError('bad 1st'), } self.assertTrue('1st' in error.to_dict()) self.assertEqual(error.to_dict()['1st'], 'bad 1st') # 2nd level error schema error.errors = {'1st': ValidationError('bad 1st', errors={ '2nd': ValidationError('bad 2nd'), })} self.assertTrue('1st' in error.to_dict()) self.assertTrue(isinstance(error.to_dict()['1st'], dict)) self.assertTrue('2nd' in error.to_dict()['1st']) self.assertEqual(error.to_dict()['1st']['2nd'], 'bad 2nd') # moar levels error.errors = {'1st': ValidationError('bad 1st', errors={ '2nd': ValidationError('bad 2nd', errors={ '3rd': ValidationError('bad 3rd', errors={ '4th': ValidationError('Inception'), }), }), })} self.assertTrue('1st' in error.to_dict()) self.assertTrue('2nd' in error.to_dict()['1st']) self.assertTrue('3rd' in error.to_dict()['1st']['2nd']) self.assertTrue('4th' in error.to_dict()['1st']['2nd']['3rd']) self.assertEqual(error.to_dict()['1st']['2nd']['3rd']['4th'], 'Inception') self.assertEqual(error.message, "root(2nd.3rd.4th.Inception: ['1st'])") def test_model_validation(self): class User(Document): username = StringField(primary_key=True) name = StringField(required=True) try: User().validate() except ValidationError, e: expected_error_message = """ValidationError(Field is required: ['username', 'name'])""" self.assertEqual(e.message, expected_error_message) self.assertEqual(e.to_dict(), { 'username': 'Field is required', 'name': 'Field is required'}) def test_spaces_in_keys(self): class Embedded(DynamicEmbeddedDocument): pass class Doc(DynamicDocument): pass Doc.drop_collection() doc = Doc() setattr(doc, 'hello world', 1) doc.save() one = Doc.objects.filter(**{'hello world': 1}).count() self.assertEqual(1, one) def test_fields_rewrite(self): class BasePerson(Document): name = StringField() age = IntField() meta = {'abstract': True} class Person(BasePerson): name = StringField(required=True) p = Person(age=15) self.assertRaises(ValidationError, p.validate) def test_cascaded_save_wrong_reference(self): class ADocument(Document): val = IntField() class BDocument(Document): a = ReferenceField(ADocument) ADocument.drop_collection() BDocument.drop_collection() a = ADocument() a.val = 15 a.save() b = BDocument() b.a = a b.save() a.delete() b = BDocument.objects.first() b.save(cascade=True) def test_shard_key(self): class LogEntry(Document): machine = StringField() log = StringField() meta = { 'shard_key': ('machine',) } LogEntry.drop_collection() log = LogEntry() log.machine = "Localhost" log.save() log.log = "Saving" log.save() def change_shard_key(): log.machine = "127.0.0.1" self.assertRaises(OperationError, change_shard_key) def test_shard_key_primary(self): class LogEntry(Document): machine = StringField(primary_key=True) log = StringField() meta = { 'shard_key': ('machine',) } LogEntry.drop_collection() log = LogEntry() log.machine = "Localhost" log.save() log.log = "Saving" log.save() def change_shard_key(): log.machine = "127.0.0.1" self.assertRaises(OperationError, change_shard_key) if __name__ == '__main__': unittest.main()