Merge branch 'master' into pr/592

This commit is contained in:
Ross Lawley
2014-06-27 12:36:39 +01:00
34 changed files with 1526 additions and 262 deletions

View File

@@ -207,22 +207,21 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field[2].string_field = 'hello world'
doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2]
self.assertEqual(doc._get_changed_fields(),
['embedded_field.list_field'])
self.assertEqual(doc.embedded_field._delta(), ({
'list_field': ['1', 2, {
'_cls': 'Embedded',
'string_field': 'hello world',
'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}],
'dict_field': {'hello': 'world'}}]}, {}))
self.assertEqual(doc._delta(), ({
'embedded_field.list_field': ['1', 2, {
['embedded_field.list_field.2'])
self.assertEqual(doc.embedded_field._delta(), ({'list_field.2': {
'_cls': 'Embedded',
'string_field': 'hello world',
'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}],
'dict_field': {'hello': 'world'}}
]}, {}))
}, {}))
self.assertEqual(doc._delta(), ({'embedded_field.list_field.2': {
'_cls': 'Embedded',
'string_field': 'hello world',
'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}],
'dict_field': {'hello': 'world'}}
}, {}))
doc.save()
doc = doc.reload(10)
self.assertEqual(doc.embedded_field.list_field[2].string_field,
@@ -253,7 +252,7 @@ class DeltaTest(unittest.TestCase):
del(doc.embedded_field.list_field[2].list_field[2]['hello'])
self.assertEqual(doc._delta(),
({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {}))
({}, {'embedded_field.list_field.2.list_field.2.hello': 1}))
doc.save()
doc = doc.reload(10)
@@ -548,22 +547,21 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field[2].string_field = 'hello world'
doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2]
self.assertEqual(doc._get_changed_fields(),
['db_embedded_field.db_list_field'])
self.assertEqual(doc.embedded_field._delta(), ({
'db_list_field': ['1', 2, {
['db_embedded_field.db_list_field.2'])
self.assertEqual(doc.embedded_field._delta(), ({'db_list_field.2': {
'_cls': 'Embedded',
'db_string_field': 'hello world',
'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}],
'db_dict_field': {'hello': 'world'}}]}, {}))
'db_dict_field': {'hello': 'world'}}}, {}))
self.assertEqual(doc._delta(), ({
'db_embedded_field.db_list_field': ['1', 2, {
'db_embedded_field.db_list_field.2': {
'_cls': 'Embedded',
'db_string_field': 'hello world',
'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}],
'db_dict_field': {'hello': 'world'}}
]}, {}))
}, {}))
doc.save()
doc = doc.reload(10)
self.assertEqual(doc.embedded_field.list_field[2].string_field,
@@ -594,8 +592,7 @@ class DeltaTest(unittest.TestCase):
del(doc.embedded_field.list_field[2].list_field[2]['hello'])
self.assertEqual(doc._delta(),
({'db_embedded_field.db_list_field.2.db_list_field':
[1, 2, {}]}, {}))
({}, {'db_embedded_field.db_list_field.2.db_list_field.2.hello': 1}))
doc.save()
doc = doc.reload(10)
@@ -735,5 +732,47 @@ class DeltaTest(unittest.TestCase):
mydoc._clear_changed_fields()
self.assertEqual([], mydoc._get_changed_fields())
def test_referenced_object_changed_attributes(self):
"""Ensures that when you save a new reference to a field, the referenced object isn't altered"""
class Organization(Document):
name = StringField()
class User(Document):
name = StringField()
org = ReferenceField('Organization', required=True)
Organization.drop_collection()
User.drop_collection()
org1 = Organization(name='Org 1')
org1.save()
org2 = Organization(name='Org 2')
org2.save()
user = User(name='Fred', org=org1)
user.save()
org1.reload()
org2.reload()
user.reload()
self.assertEqual(org1.name, 'Org 1')
self.assertEqual(org2.name, 'Org 2')
self.assertEqual(user.name, 'Fred')
user.name = 'Harold'
user.org = org2
org2.name = 'New Org 2'
self.assertEqual(org2.name, 'New Org 2')
user.save()
org2.save()
self.assertEqual(org2.name, 'New Org 2')
org2.reload()
self.assertEqual(org2.name, 'New Org 2')
if __name__ == '__main__':
unittest.main()

View File

@@ -292,6 +292,59 @@ class DynamicTest(unittest.TestCase):
person.save()
self.assertEqual(Person.objects.first().age, 35)
def test_dynamic_embedded_works_with_only(self):
"""Ensure custom fieldnames on a dynamic embedded document are found by qs.only()"""
class Address(DynamicEmbeddedDocument):
city = StringField()
class Person(DynamicDocument):
address = EmbeddedDocumentField(Address)
Person.drop_collection()
Person(name="Eric", address=Address(city="San Francisco", street_number="1337")).save()
self.assertEqual(Person.objects.first().address.street_number, '1337')
self.assertEqual(Person.objects.only('address__street_number').first().address.street_number, '1337')
def test_dynamic_and_embedded_dict_access(self):
"""Ensure embedded dynamic documents work with dict[] style access"""
class Address(EmbeddedDocument):
city = StringField()
class Person(DynamicDocument):
name = StringField()
Person.drop_collection()
Person(name="Ross", address=Address(city="London")).save()
person = Person.objects.first()
person.attrval = "This works"
person["phone"] = "555-1212" # but this should too
# Same thing two levels deep
person["address"]["city"] = "Lundenne"
person.save()
self.assertEqual(Person.objects.first().address.city, "Lundenne")
self.assertEqual(Person.objects.first().phone, "555-1212")
person = Person.objects.first()
person.address = Address(city="Londinium")
person.save()
self.assertEqual(Person.objects.first().address.city, "Londinium")
person = Person.objects.first()
person["age"] = 35
person.save()
self.assertEqual(Person.objects.first().age, 35)
if __name__ == '__main__':
unittest.main()

View File

@@ -15,7 +15,7 @@ from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
from mongoengine import *
from mongoengine.errors import (NotRegistered, InvalidDocumentError,
InvalidQueryError)
InvalidQueryError, NotUniqueError)
from mongoengine.queryset import NULLIFY, Q
from mongoengine.connection import get_db
from mongoengine.base import get_document
@@ -57,7 +57,7 @@ class InstanceTest(unittest.TestCase):
date = DateTimeField(default=datetime.now)
meta = {
'max_documents': 10,
'max_size': 90000,
'max_size': 4096,
}
Log.drop_collection()
@@ -75,7 +75,7 @@ class InstanceTest(unittest.TestCase):
options = Log.objects._collection.options()
self.assertEqual(options['capped'], True)
self.assertEqual(options['max'], 10)
self.assertEqual(options['size'], 90000)
self.assertTrue(options['size'] >= 4096)
# Check that the document cannot be redefined with different options
def recreate_log_document():
@@ -353,6 +353,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 20)
person.reload('age')
self.assertEqual(person.name, "Test User")
self.assertEqual(person.age, 21)
person.reload()
self.assertEqual(person.name, "Mr Test User")
self.assertEqual(person.age, 21)
person.reload()
self.assertEqual(person.name, "Mr Test User")
self.assertEqual(person.age, 21)
@@ -398,10 +406,11 @@ class InstanceTest(unittest.TestCase):
doc.embedded_field.dict_field['woot'] = "woot"
self.assertEqual(doc._get_changed_fields(), [
'list_field', 'dict_field', 'embedded_field.list_field',
'embedded_field.dict_field'])
'list_field', 'dict_field.woot', 'embedded_field.list_field',
'embedded_field.dict_field.woot'])
doc.save()
self.assertEqual(len(doc.list_field), 4)
doc = doc.reload(10)
self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(len(doc.list_field), 4)
@@ -409,6 +418,16 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(len(doc.embedded_field.list_field), 4)
self.assertEqual(len(doc.embedded_field.dict_field), 2)
doc.list_field.append(1)
doc.save()
doc.dict_field['extra'] = 1
doc = doc.reload(10, 'list_field')
self.assertEqual(doc._get_changed_fields(), [])
self.assertEqual(len(doc.list_field), 5)
self.assertEqual(len(doc.dict_field), 3)
self.assertEqual(len(doc.embedded_field.list_field), 4)
self.assertEqual(len(doc.embedded_field.dict_field), 2)
def test_reload_doesnt_exist(self):
class Foo(Document):
pass
@@ -515,9 +534,6 @@ class InstanceTest(unittest.TestCase):
class Email(EmbeddedDocument):
email = EmailField()
def clean(self):
print "instance:"
print self._instance
class Account(Document):
email = EmbeddedDocumentField(Email)
@@ -820,6 +836,80 @@ class InstanceTest(unittest.TestCase):
p1.reload()
self.assertEqual(p1.name, p.parent.name)
def test_save_atomicity_condition(self):
class Widget(Document):
toggle = BooleanField(default=False)
count = IntField(default=0)
save_id = UUIDField()
def flip(widget):
widget.toggle = not widget.toggle
widget.count += 1
def UUID(i):
return uuid.UUID(int=i)
Widget.drop_collection()
w1 = Widget(toggle=False, save_id=UUID(1))
# ignore save_condition on new record creation
w1.save(save_condition={'save_id':UUID(42)})
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.save_id, UUID(1))
self.assertEqual(w1.count, 0)
# mismatch in save_condition prevents save
flip(w1)
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)
w1.save(save_condition={'save_id':UUID(42)})
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.count, 0)
# matched save_condition allows save
flip(w1)
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)
w1.save(save_condition={'save_id':UUID(1)})
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)
# save_condition can be used to ensure atomic read & updates
# i.e., prevent interleaved reads and writes from separate contexts
w2 = Widget.objects.get()
self.assertEqual(w1, w2)
old_id = w1.save_id
flip(w1)
w1.save_id = UUID(2)
w1.save(save_condition={'save_id':old_id})
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.count, 2)
flip(w2)
flip(w2)
w2.save(save_condition={'save_id':old_id})
w2.reload()
self.assertFalse(w2.toggle)
self.assertEqual(w2.count, 2)
# save_condition uses mongoengine-style operator syntax
flip(w1)
w1.save(save_condition={'count__lt':w1.count})
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 3)
flip(w1)
w1.save(save_condition={'count__gte':w1.count})
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 3)
def test_update(self):
"""Ensure that an existing document is updated instead of be
overwritten."""
@@ -990,6 +1080,16 @@ class InstanceTest(unittest.TestCase):
self.assertRaises(InvalidQueryError, update_no_op_raises)
def test_update_unique_field(self):
class Doc(Document):
name = StringField(unique=True)
doc1 = Doc(name="first").save()
doc2 = Doc(name="second").save()
self.assertRaises(NotUniqueError, lambda:
doc2.update(set__name=doc1.name))
def test_embedded_update(self):
"""
Test update on `EmbeddedDocumentField` fields
@@ -2281,6 +2381,8 @@ class InstanceTest(unittest.TestCase):
log.machine = "Localhost"
log.save()
self.assertTrue(log.id is not None)
log.log = "Saving"
log.save()
@@ -2304,6 +2406,8 @@ class InstanceTest(unittest.TestCase):
log.machine = "Localhost"
log.save()
self.assertTrue(log.id is not None)
log.log = "Saving"
log.save()
@@ -2411,7 +2515,7 @@ class InstanceTest(unittest.TestCase):
for parameter_name, parameter in self.parameters.iteritems():
parameter.expand()
class System(Document):
class NodesSystem(Document):
name = StringField(required=True)
nodes = MapField(ReferenceField(Node, dbref=False))
@@ -2419,18 +2523,18 @@ class InstanceTest(unittest.TestCase):
for node_name, node in self.nodes.iteritems():
node.expand()
node.save(*args, **kwargs)
super(System, self).save(*args, **kwargs)
super(NodesSystem, self).save(*args, **kwargs)
System.drop_collection()
NodesSystem.drop_collection()
Node.drop_collection()
system = System(name="system")
system = NodesSystem(name="system")
system.nodes["node"] = Node()
system.save()
system.nodes["node"].parameters["param"] = Parameter()
system.save()
system = System.objects.first()
system = NodesSystem.objects.first()
self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value)
def test_embedded_document_equality(self):
@@ -2452,5 +2556,65 @@ class InstanceTest(unittest.TestCase):
f1.ref # Dereferences lazily
self.assertEqual(f1, f2)
def test_dbref_equality(self):
class Test2(Document):
name = StringField()
class Test3(Document):
name = StringField()
class Test(Document):
name = StringField()
test2 = ReferenceField('Test2')
test3 = ReferenceField('Test3')
Test.drop_collection()
Test2.drop_collection()
Test3.drop_collection()
t2 = Test2(name='a')
t2.save()
t3 = Test3(name='x')
t3.id = t2.id
t3.save()
t = Test(name='b', test2=t2, test3=t3)
f = Test._from_son(t.to_mongo())
dbref2 = f._data['test2']
obj2 = f.test2
self.assertTrue(isinstance(dbref2, DBRef))
self.assertTrue(isinstance(obj2, Test2))
self.assertTrue(obj2.id == dbref2.id)
self.assertTrue(obj2 == dbref2)
self.assertTrue(dbref2 == obj2)
dbref3 = f._data['test3']
obj3 = f.test3
self.assertTrue(isinstance(dbref3, DBRef))
self.assertTrue(isinstance(obj3, Test3))
self.assertTrue(obj3.id == dbref3.id)
self.assertTrue(obj3 == dbref3)
self.assertTrue(dbref3 == obj3)
self.assertTrue(obj2.id == obj3.id)
self.assertTrue(dbref2.id == dbref3.id)
self.assertFalse(dbref2 == dbref3)
self.assertFalse(dbref3 == dbref2)
self.assertTrue(dbref2 != dbref3)
self.assertTrue(dbref3 != dbref2)
self.assertFalse(obj2 == dbref3)
self.assertFalse(dbref3 == obj2)
self.assertTrue(obj2 != dbref3)
self.assertTrue(dbref3 != obj2)
self.assertFalse(obj3 == dbref2)
self.assertFalse(dbref2 == obj3)
self.assertTrue(obj3 != dbref2)
self.assertTrue(dbref2 != obj3)
if __name__ == '__main__':
unittest.main()

View File

@@ -279,7 +279,7 @@ class FileTest(unittest.TestCase):
t.image.put(f)
self.fail("Should have raised an invalidation error")
except ValidationError, e:
self.assertEqual("%s" % e, "Invalid image: cannot identify image file")
self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f)
t = TestImage()
t.image.put(open(TEST_IMAGE_PATH, 'rb'))

View File

@@ -3,3 +3,4 @@ from field_list import *
from queryset import *
from visitor import *
from geo import *
from modify import *

View File

@@ -5,6 +5,8 @@ import unittest
from datetime import datetime, timedelta
from mongoengine import *
from nose.plugins.skip import SkipTest
__all__ = ("GeoQueriesTest",)
@@ -139,6 +141,7 @@ class GeoQueriesTest(unittest.TestCase):
def test_spherical_geospatial_operators(self):
"""Ensure that spherical geospatial queries are working
"""
raise SkipTest("https://jira.mongodb.org/browse/SERVER-14039")
class Point(Document):
location = GeoPointField()

102
tests/queryset/modify.py Normal file
View File

@@ -0,0 +1,102 @@
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import connect, Document, IntField
__all__ = ("FindAndModifyTest",)
class Doc(Document):
id = IntField(primary_key=True)
value = IntField()
class FindAndModifyTest(unittest.TestCase):
def setUp(self):
connect(db="mongoenginetest")
Doc.drop_collection()
def assertDbEqual(self, docs):
self.assertEqual(list(Doc._collection.find().sort("id")), docs)
def test_modify(self):
Doc(id=0, value=0).save()
doc = Doc(id=1, value=1).save()
old_doc = Doc.objects(id=1).modify(set__value=-1)
self.assertEqual(old_doc.to_json(), doc.to_json())
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
def test_modify_with_new(self):
Doc(id=0, value=0).save()
doc = Doc(id=1, value=1).save()
new_doc = Doc.objects(id=1).modify(set__value=-1, new=True)
doc.value = -1
self.assertEqual(new_doc.to_json(), doc.to_json())
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
def test_modify_not_existing(self):
Doc(id=0, value=0).save()
self.assertEqual(Doc.objects(id=1).modify(set__value=-1), None)
self.assertDbEqual([{"_id": 0, "value": 0}])
def test_modify_with_upsert(self):
Doc(id=0, value=0).save()
old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True)
self.assertEqual(old_doc, None)
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}])
def test_modify_with_upsert_existing(self):
Doc(id=0, value=0).save()
doc = Doc(id=1, value=1).save()
old_doc = Doc.objects(id=1).modify(set__value=-1, upsert=True)
self.assertEqual(old_doc.to_json(), doc.to_json())
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
def test_modify_with_upsert_with_new(self):
Doc(id=0, value=0).save()
new_doc = Doc.objects(id=1).modify(upsert=True, new=True, set__value=1)
self.assertEqual(new_doc.to_mongo(), {"_id": 1, "value": 1})
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}])
def test_modify_with_remove(self):
Doc(id=0, value=0).save()
doc = Doc(id=1, value=1).save()
old_doc = Doc.objects(id=1).modify(remove=True)
self.assertEqual(old_doc.to_json(), doc.to_json())
self.assertDbEqual([{"_id": 0, "value": 0}])
def test_find_and_modify_with_remove_not_existing(self):
Doc(id=0, value=0).save()
self.assertEqual(Doc.objects(id=1).modify(remove=True), None)
self.assertDbEqual([{"_id": 0, "value": 0}])
def test_modify_with_order_by(self):
Doc(id=0, value=3).save()
Doc(id=1, value=2).save()
Doc(id=2, value=1).save()
doc = Doc(id=3, value=0).save()
old_doc = Doc.objects().order_by("-id").modify(set__value=-1)
self.assertEqual(old_doc.to_json(), doc.to_json())
self.assertDbEqual([
{"_id": 0, "value": 3}, {"_id": 1, "value": 2},
{"_id": 2, "value": 1}, {"_id": 3, "value": -1}])
def test_modify_with_fields(self):
Doc(id=0, value=0).save()
Doc(id=1, value=1).save()
old_doc = Doc.objects(id=1).only("id").modify(set__value=-1)
self.assertEqual(old_doc.to_mongo(), {"_id": 1})
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
if __name__ == '__main__':
unittest.main()

View File

@@ -14,9 +14,9 @@ from pymongo.read_preferences import ReadPreference
from bson import ObjectId
from mongoengine import *
from mongoengine.connection import get_connection
from mongoengine.connection import get_connection, get_db
from mongoengine.python_support import PY3
from mongoengine.context_managers import query_counter
from mongoengine.context_managers import query_counter, switch_db
from mongoengine.queryset import (QuerySet, QuerySetManager,
MultipleObjectsReturned, DoesNotExist,
queryset_manager)
@@ -25,10 +25,17 @@ from mongoengine.errors import InvalidQueryError
__all__ = ("QuerySetTest",)
class db_ops_tracker(query_counter):
def get_ops(self):
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}}
return list(self.db.system.profile.find(ignore_query))
class QuerySetTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
connect(db='mongoenginetest2', alias='test2')
class PersonMeta(EmbeddedDocument):
weight = IntField()
@@ -650,7 +657,10 @@ class QuerySetTest(unittest.TestCase):
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
Blog.objects.insert(blogs, load_bulk=False)
self.assertEqual(q, 1) # 1 for the insert
if (get_connection().max_wire_version <= 1):
self.assertEqual(q, 1)
else:
self.assertEqual(q, 99) # profiling logs each doc now in the bulk op
Blog.drop_collection()
Blog.ensure_indexes()
@@ -659,7 +669,10 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(q, 0)
Blog.objects.insert(blogs)
self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch
if (get_connection().max_wire_version <= 1):
self.assertEqual(q, 2) # 1 for insert, and 1 for in bulk fetch
else:
self.assertEqual(q, 100) # 99 for insert, and 1 for in bulk fetch
Blog.drop_collection()
@@ -1040,6 +1053,54 @@ class QuerySetTest(unittest.TestCase):
expected = [blog_post_1, blog_post_2, blog_post_3]
self.assertSequence(qs, expected)
def test_clear_ordering(self):
""" Ensure that the default ordering can be cleared by calling order_by().
"""
class BlogPost(Document):
title = StringField()
published_date = DateTimeField()
meta = {
'ordering': ['-published_date']
}
BlogPost.drop_collection()
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').first()
self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(q.get_ops()[0]['query']['$orderby'], {u'published_date': -1})
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by().first()
self.assertEqual(len(q.get_ops()), 1)
print q.get_ops()[0]['query']
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_no_ordering_for_get(self):
""" Ensure that Doc.objects.get doesn't use any ordering.
"""
class BlogPost(Document):
title = StringField()
published_date = DateTimeField()
meta = {
'ordering': ['-published_date']
}
BlogPost.objects.create(title='whatever', published_date=datetime.utcnow())
with db_ops_tracker() as q:
BlogPost.objects.get(title='whatever')
self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
# Ordering should be ignored for .get even if we set it explicitly
with db_ops_tracker() as q:
BlogPost.objects.order_by('-title').get(title='whatever')
self.assertEqual(len(q.get_ops()), 1)
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_find_embedded(self):
"""Ensure that an embedded document is properly returned from a query.
"""
@@ -1925,6 +1986,140 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
def test_map_reduce_custom_output(self):
"""
Test map/reduce custom output
"""
register_connection('test2', 'mongoenginetest2')
class Family(Document):
id = IntField(
primary_key=True)
log = StringField()
class Person(Document):
id = IntField(
primary_key=True)
name = StringField()
age = IntField()
family = ReferenceField(Family)
Family.drop_collection()
Person.drop_collection()
# creating first family
f1 = Family(id=1, log="Trav 02 de Julho")
f1.save()
# persons of first family
Person(id=1, family=f1, name=u"Wilson Jr", age=21).save()
Person(id=2, family=f1, name=u"Wilson Father", age=45).save()
Person(id=3, family=f1, name=u"Eliana Costa", age=40).save()
Person(id=4, family=f1, name=u"Tayza Mariana", age=17).save()
# creating second family
f2 = Family(id=2, log="Av prof frasc brunno")
f2.save()
#persons of second family
Person(id=5, family=f2, name="Isabella Luanna", age=16).save()
Person(id=6, family=f2, name="Sandra Mara", age=36).save()
Person(id=7, family=f2, name="Igor Gabriel", age=10).save()
# creating third family
f3 = Family(id=3, log="Av brazil")
f3.save()
#persons of thrird family
Person(id=8, family=f3, name="Arthur WA", age=30).save()
Person(id=9, family=f3, name="Paula Leonel", age=25).save()
# executing join map/reduce
map_person = """
function () {
emit(this.family, {
totalAge: this.age,
persons: [{
name: this.name,
age: this.age
}]});
}
"""
map_family = """
function () {
emit(this._id, {
totalAge: 0,
persons: []
});
}
"""
reduce_f = """
function (key, values) {
var family = {persons: [], totalAge: 0};
values.forEach(function(value) {
if (value.persons) {
value.persons.forEach(function (person) {
family.persons.push(person);
family.totalAge += person.age;
});
}
});
return family;
}
"""
cursor = Family.objects.map_reduce(
map_f=map_family,
reduce_f=reduce_f,
output={'replace': 'family_map', 'db_alias': 'test2'})
# start a map/reduce
cursor.next()
results = Person.objects.map_reduce(
map_f=map_person,
reduce_f=reduce_f,
output={'reduce': 'family_map', 'db_alias': 'test2'})
results = list(results)
collection = get_db('test2').family_map
self.assertEqual(
collection.find_one({'_id': 1}), {
'_id': 1,
'value': {
'persons': [
{'age': 21, 'name': u'Wilson Jr'},
{'age': 45, 'name': u'Wilson Father'},
{'age': 40, 'name': u'Eliana Costa'},
{'age': 17, 'name': u'Tayza Mariana'}],
'totalAge': 123}
})
self.assertEqual(
collection.find_one({'_id': 2}), {
'_id': 2,
'value': {
'persons': [
{'age': 16, 'name': u'Isabella Luanna'},
{'age': 36, 'name': u'Sandra Mara'},
{'age': 10, 'name': u'Igor Gabriel'}],
'totalAge': 62}
})
self.assertEqual(
collection.find_one({'_id': 3}), {
'_id': 3,
'value': {
'persons': [
{'age': 30, 'name': u'Arthur WA'},
{'age': 25, 'name': u'Paula Leonel'}],
'totalAge': 55}
})
def test_map_reduce_finalize(self):
"""Ensure that map, reduce, and finalize run and introduce "scope"
by simulating "hotness" ranking with Reddit algorithm.
@@ -2540,6 +2735,27 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(authors, [mark_twain, john_tolkien])
def test_distinct_ListField_ReferenceField(self):
class Foo(Document):
bar_lst = ListField(ReferenceField('Bar'))
class Bar(Document):
text = StringField()
Bar.drop_collection()
Foo.drop_collection()
bar_1 = Bar(text="hi")
bar_1.save()
bar_2 = Bar(text="bye")
bar_2.save()
foo = Foo(bar=bar_1, bar_lst=[bar_1, bar_2])
foo.save()
self.assertEqual(Foo.objects.distinct("bar_lst"), [bar_1, bar_2])
def test_custom_manager(self):
"""Ensure that custom QuerySetManager instances work as expected.
"""
@@ -2957,6 +3173,23 @@ class QuerySetTest(unittest.TestCase):
Number.drop_collection()
def test_using(self):
"""Ensure that switching databases for a queryset is possible
"""
class Number2(Document):
n = IntField()
Number2.drop_collection()
with switch_db(Number2, 'test2') as Number2:
Number2.drop_collection()
for i in range(1, 10):
t = Number2(n=i)
t.switch_db('test2')
t.save()
self.assertEqual(len(Number2.objects.using('test2')), 9)
def test_unset_reference(self):
class Comment(Document):
text = StringField()
@@ -3586,7 +3819,13 @@ class QuerySetTest(unittest.TestCase):
[x for x in people]
self.assertEqual(100, len(people._result_cache))
self.assertEqual(None, people._len)
import platform
if platform.python_implementation() != "PyPy":
# PyPy evaluates __len__ when iterating with list comprehensions while CPython does not.
# This may be a bug in PyPy (PyPy/#1802) but it does not affect the behavior of MongoEngine.
self.assertEqual(None, people._len)
self.assertEqual(q, 1)
list(people)
@@ -3814,6 +4053,111 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Example.objects(size=instance_size).count(), 1)
self.assertEqual(Example.objects(size__in=[instance_size]).count(), 1)
def test_cursor_in_an_if_stmt(self):
class Test(Document):
test_field = StringField()
Test.drop_collection()
queryset = Test.objects
if queryset:
raise AssertionError('Empty cursor returns True')
test = Test()
test.test_field = 'test'
test.save()
queryset = Test.objects
if not test:
raise AssertionError('Cursor has data and returned False')
queryset.next()
if not queryset:
raise AssertionError('Cursor has data and it must returns True,'
' even in the last item.')
def test_bool_performance(self):
class Person(Document):
name = StringField()
Person.drop_collection()
for i in xrange(100):
Person(name="No: %s" % i).save()
with query_counter() as q:
if Person.objects:
pass
self.assertEqual(q, 1)
op = q.db.system.profile.find({"ns":
{"$ne": "%s.system.indexes" % q.db.name}})[0]
self.assertEqual(op['nreturned'], 1)
def test_bool_with_ordering(self):
class Person(Document):
name = StringField()
Person.drop_collection()
Person(name="Test").save()
qs = Person.objects.order_by('name')
with query_counter() as q:
if qs:
pass
op = q.db.system.profile.find({"ns":
{"$ne": "%s.system.indexes" % q.db.name}})[0]
self.assertFalse('$orderby' in op['query'],
'BaseQuerySet cannot use orderby in if stmt')
with query_counter() as p:
for x in qs:
pass
op = p.db.system.profile.find({"ns":
{"$ne": "%s.system.indexes" % q.db.name}})[0]
self.assertTrue('$orderby' in op['query'],
'BaseQuerySet cannot remove orderby in for loop')
def test_bool_with_ordering_from_meta_dict(self):
class Person(Document):
name = StringField()
meta = {
'ordering': ['name']
}
Person.drop_collection()
Person(name="B").save()
Person(name="C").save()
Person(name="A").save()
with query_counter() as q:
if Person.objects:
pass
op = q.db.system.profile.find({"ns":
{"$ne": "%s.system.indexes" % q.db.name}})[0]
self.assertFalse('$orderby' in op['query'],
'BaseQuerySet must remove orderby from meta in boolen test')
self.assertEqual(Person.objects.first().name, 'A')
self.assertTrue(Person.objects._has_data(),
'Cursor has data and returned False')
if __name__ == '__main__':
unittest.main()

View File

@@ -1,6 +1,11 @@
import sys
sys.path[0:0] = [""]
import unittest
try:
import unittest2 as unittest
except ImportError:
import unittest
import datetime
import pymongo
@@ -34,6 +39,17 @@ class ConnectionTest(unittest.TestCase):
conn = get_connection('testdb')
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
def test_sharing_connections(self):
"""Ensure that connections are shared when the connection settings are exactly the same
"""
connect('mongoenginetest', alias='testdb1')
expected_connection = get_connection('testdb1')
connect('mongoenginetest', alias='testdb2')
actual_connection = get_connection('testdb2')
self.assertEqual(expected_connection, actual_connection)
def test_connect_uri(self):
"""Ensure that the connect() method works properly with uri's
"""

View File

@@ -0,0 +1,107 @@
import unittest
from mongoengine.base.datastructures import StrictDict, SemiStrictDict
class TestStrictDict(unittest.TestCase):
def strict_dict_class(self, *args, **kwargs):
return StrictDict.create(*args, **kwargs)
def setUp(self):
self.dtype = self.strict_dict_class(("a", "b", "c"))
def test_init(self):
d = self.dtype(a=1, b=1, c=1)
self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
def test_init_fails_on_nonexisting_attrs(self):
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))
def test_eq(self):
d = self.dtype(a=1, b=1, c=1)
dd = self.dtype(a=1, b=1, c=1)
e = self.dtype(a=1, b=1, c=3)
f = self.dtype(a=1, b=1)
g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1)
h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1)
i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2)
self.assertEqual(d, dd)
self.assertNotEqual(d, e)
self.assertNotEqual(d, f)
self.assertNotEqual(d, g)
self.assertNotEqual(f, d)
self.assertEqual(d, h)
self.assertNotEqual(d, i)
def test_setattr_getattr(self):
d = self.dtype()
d.a = 1
self.assertEqual(d.a, 1)
self.assertRaises(AttributeError, lambda: d.b)
def test_setattr_raises_on_nonexisting_attr(self):
d = self.dtype()
def _f():
d.x=1
self.assertRaises(AttributeError, _f)
def test_setattr_getattr_special(self):
d = self.strict_dict_class(["items"])
d.items = 1
self.assertEqual(d.items, 1)
def test_get(self):
d = self.dtype(a=1)
self.assertEqual(d.get('a'), 1)
self.assertEqual(d.get('b', 'bla'), 'bla')
def test_items(self):
d = self.dtype(a=1)
self.assertEqual(d.items(), [('a', 1)])
d = self.dtype(a=1, b=2)
self.assertEqual(d.items(), [('a', 1), ('b', 2)])
def test_mappings_protocol(self):
d = self.dtype(a=1, b=2)
assert dict(d) == {'a': 1, 'b': 2}
assert dict(**d) == {'a': 1, 'b': 2}
class TestSemiSrictDict(TestStrictDict):
def strict_dict_class(self, *args, **kwargs):
return SemiStrictDict.create(*args, **kwargs)
def test_init_fails_on_nonexisting_attrs(self):
# disable irrelevant test
pass
def test_setattr_raises_on_nonexisting_attr(self):
# disable irrelevant test
pass
def test_setattr_getattr_nonexisting_attr_succeeds(self):
d = self.dtype()
d.x = 1
self.assertEqual(d.x, 1)
def test_init_succeeds_with_nonexisting_attrs(self):
d = self.dtype(a=1, b=1, c=1, x=2)
self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2))
def test_iter_with_nonexisting_attrs(self):
d = self.dtype(a=1, b=1, c=1, x=2)
self.assertEqual(list(d), ['a', 'b', 'c', 'x'])
def test_iteritems_with_nonexisting_attrs(self):
d = self.dtype(a=1, b=1, c=1, x=2)
self.assertEqual(list(d.iteritems()), [('a', 1), ('b', 1), ('c', 1), ('x', 2)])
def tets_cmp_with_strict_dicts(self):
d = self.dtype(a=1, b=1, c=1)
dd = StrictDict.create(("a", "b", "c"))(a=1, b=1, c=1)
self.assertEqual(d, dd)
def test_cmp_with_strict_dict_with_nonexisting_attrs(self):
d = self.dtype(a=1, b=1, c=1, x=2)
dd = StrictDict.create(("a", "b", "c", "x"))(a=1, b=1, c=1, x=2)
self.assertEqual(d, dd)
if __name__ == '__main__':
unittest.main()

View File

@@ -291,6 +291,30 @@ class FieldTest(unittest.TestCase):
self.assertEqual(employee.friends, friends)
self.assertEqual(q, 2)
def test_list_of_lists_of_references(self):
class User(Document):
name = StringField()
class Post(Document):
user_lists = ListField(ListField(ReferenceField(User)))
class SimpleList(Document):
users = ListField(ReferenceField(User))
User.drop_collection()
Post.drop_collection()
u1 = User.objects.create(name='u1')
u2 = User.objects.create(name='u2')
u3 = User.objects.create(name='u3')
SimpleList.objects.create(users=[u1, u2, u3])
self.assertEqual(SimpleList.objects.all()[0].users, [u1, u2, u3])
Post.objects.create(user_lists=[[u1, u2], [u3]])
self.assertEqual(Post.objects.all()[0].user_lists, [[u1, u2], [u3]])
def test_circular_reference(self):
"""Ensure you can handle circular references
"""

View File

@@ -54,7 +54,9 @@ class SignalTests(unittest.TestCase):
@classmethod
def post_save(cls, sender, document, **kwargs):
dirty_keys = document._delta()[0].keys() + document._delta()[1].keys()
signal_output.append('post_save signal, %s' % document)
signal_output.append('post_save dirty keys, %s' % dirty_keys)
if 'created' in kwargs:
if kwargs['created']:
signal_output.append('Is created')
@@ -203,6 +205,7 @@ class SignalTests(unittest.TestCase):
"pre_save_post_validation signal, Bill Shakespeare",
"Is created",
"post_save signal, Bill Shakespeare",
"post_save dirty keys, ['name']",
"Is created"
])
@@ -213,6 +216,7 @@ class SignalTests(unittest.TestCase):
"pre_save_post_validation signal, William Shakespeare",
"Is updated",
"post_save signal, William Shakespeare",
"post_save dirty keys, ['name']",
"Is updated"
])