Merge pull request #1 from MongoEngine/master
pull new changes from original
This commit is contained in:
@@ -5,6 +5,7 @@ import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
import pymongo
|
||||
from random import randint
|
||||
|
||||
from nose.plugins.skip import SkipTest
|
||||
from datetime import datetime
|
||||
@@ -16,9 +17,11 @@ __all__ = ("IndexesTest", )
|
||||
|
||||
|
||||
class IndexesTest(unittest.TestCase):
|
||||
_MAX_RAND = 10 ** 10
|
||||
|
||||
def setUp(self):
|
||||
self.connection = connect(db='mongoenginetest')
|
||||
self.db_name = 'mongoenginetest_IndexesTest_' + str(randint(0, self._MAX_RAND))
|
||||
self.connection = connect(db=self.db_name)
|
||||
self.db = get_db()
|
||||
|
||||
class Person(Document):
|
||||
@@ -32,10 +35,7 @@ class IndexesTest(unittest.TestCase):
|
||||
self.Person = Person
|
||||
|
||||
def tearDown(self):
|
||||
for collection in self.db.collection_names():
|
||||
if 'system.' in collection:
|
||||
continue
|
||||
self.db.drop_collection(collection)
|
||||
self.connection.drop_database(self.db)
|
||||
|
||||
def test_indexes_document(self):
|
||||
"""Ensure that indexes are used when meta[indexes] is specified for
|
||||
@@ -822,33 +822,29 @@ class IndexesTest(unittest.TestCase):
|
||||
name = StringField(required=True)
|
||||
term = StringField(required=True)
|
||||
|
||||
class Report(Document):
|
||||
class ReportEmbedded(Document):
|
||||
key = EmbeddedDocumentField(CompoundKey, primary_key=True)
|
||||
text = StringField()
|
||||
|
||||
Report.drop_collection()
|
||||
|
||||
my_key = CompoundKey(name="n", term="ok")
|
||||
report = Report(text="OK", key=my_key).save()
|
||||
report = ReportEmbedded(text="OK", key=my_key).save()
|
||||
|
||||
self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}},
|
||||
report.to_mongo())
|
||||
self.assertEqual(report, Report.objects.get(pk=my_key))
|
||||
self.assertEqual(report, ReportEmbedded.objects.get(pk=my_key))
|
||||
|
||||
def test_compound_key_dictfield(self):
|
||||
|
||||
class Report(Document):
|
||||
class ReportDictField(Document):
|
||||
key = DictField(primary_key=True)
|
||||
text = StringField()
|
||||
|
||||
Report.drop_collection()
|
||||
|
||||
my_key = {"name": "n", "term": "ok"}
|
||||
report = Report(text="OK", key=my_key).save()
|
||||
report = ReportDictField(text="OK", key=my_key).save()
|
||||
|
||||
self.assertEqual({'text': 'OK', '_id': {'term': 'ok', 'name': 'n'}},
|
||||
report.to_mongo())
|
||||
self.assertEqual(report, Report.objects.get(pk=my_key))
|
||||
self.assertEqual(report, ReportDictField.objects.get(pk=my_key))
|
||||
|
||||
def test_string_indexes(self):
|
||||
|
||||
@@ -863,6 +859,20 @@ class IndexesTest(unittest.TestCase):
|
||||
self.assertTrue([('provider_ids.foo', 1)] in info)
|
||||
self.assertTrue([('provider_ids.bar', 1)] in info)
|
||||
|
||||
def test_sparse_compound_indexes(self):
|
||||
|
||||
class MyDoc(Document):
|
||||
provider_ids = DictField()
|
||||
meta = {
|
||||
"indexes": [{'fields': ("provider_ids.foo", "provider_ids.bar"),
|
||||
'sparse': True}],
|
||||
}
|
||||
|
||||
info = MyDoc.objects._collection.index_information()
|
||||
self.assertEqual([('provider_ids.foo', 1), ('provider_ids.bar', 1)],
|
||||
info['provider_ids.foo_1_provider_ids.bar_1']['key'])
|
||||
self.assertTrue(info['provider_ids.foo_1_provider_ids.bar_1']['sparse'])
|
||||
|
||||
def test_text_indexes(self):
|
||||
|
||||
class Book(Document):
|
||||
@@ -895,26 +905,38 @@ class IndexesTest(unittest.TestCase):
|
||||
|
||||
Issue #812
|
||||
"""
|
||||
# Use a new connection and database since dropping the database could
|
||||
# cause concurrent tests to fail.
|
||||
connection = connect(db='tempdatabase',
|
||||
alias='test_indexes_after_database_drop')
|
||||
|
||||
class BlogPost(Document):
|
||||
title = StringField()
|
||||
slug = StringField(unique=True)
|
||||
|
||||
BlogPost.drop_collection()
|
||||
meta = {'db_alias': 'test_indexes_after_database_drop'}
|
||||
|
||||
# Create Post #1
|
||||
post1 = BlogPost(title='test1', slug='test')
|
||||
post1.save()
|
||||
try:
|
||||
BlogPost.drop_collection()
|
||||
|
||||
# Drop the Database
|
||||
self.connection.drop_database(BlogPost._get_db().name)
|
||||
# Create Post #1
|
||||
post1 = BlogPost(title='test1', slug='test')
|
||||
post1.save()
|
||||
|
||||
# Re-create Post #1
|
||||
post1 = BlogPost(title='test1', slug='test')
|
||||
post1.save()
|
||||
# Drop the Database
|
||||
connection.drop_database('tempdatabase')
|
||||
|
||||
# Re-create Post #1
|
||||
post1 = BlogPost(title='test1', slug='test')
|
||||
post1.save()
|
||||
|
||||
# Create Post #2
|
||||
post2 = BlogPost(title='test2', slug='test')
|
||||
self.assertRaises(NotUniqueError, post2.save)
|
||||
finally:
|
||||
# Drop the temporary database at the end
|
||||
connection.drop_database('tempdatabase')
|
||||
|
||||
# Create Post #2
|
||||
post2 = BlogPost(title='test2', slug='test')
|
||||
self.assertRaises(NotUniqueError, post2.save)
|
||||
|
||||
def test_index_dont_send_cls_option(self):
|
||||
"""
|
||||
|
||||
@@ -411,7 +411,7 @@ class InheritanceTest(unittest.TestCase):
|
||||
try:
|
||||
class MyDocument(DateCreatedDocument, DateUpdatedDocument):
|
||||
pass
|
||||
except:
|
||||
except Exception:
|
||||
self.assertTrue(False, "Couldn't create MyDocument class")
|
||||
|
||||
def test_abstract_documents(self):
|
||||
|
||||
@@ -7,12 +7,13 @@ import os
|
||||
import pickle
|
||||
import unittest
|
||||
import uuid
|
||||
import weakref
|
||||
|
||||
from datetime import datetime
|
||||
from bson import DBRef, ObjectId
|
||||
from tests import fixtures
|
||||
from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
|
||||
PickleDyanmicEmbedded, PickleDynamicTest)
|
||||
PickleDynamicEmbedded, PickleDynamicTest)
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.errors import (NotRegistered, InvalidDocumentError,
|
||||
@@ -30,6 +31,8 @@ TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
|
||||
__all__ = ("InstanceTest",)
|
||||
|
||||
|
||||
|
||||
|
||||
class InstanceTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -63,6 +66,14 @@ class InstanceTest(unittest.TestCase):
|
||||
list(self.Person._get_collection().find().sort("id")),
|
||||
sorted(docs, key=lambda doc: doc["_id"]))
|
||||
|
||||
def assertHasInstance(self, field, instance):
|
||||
self.assertTrue(hasattr(field, "_instance"))
|
||||
self.assertTrue(field._instance is not None)
|
||||
if isinstance(field._instance, weakref.ProxyType):
|
||||
self.assertTrue(field._instance.__eq__(instance))
|
||||
else:
|
||||
self.assertEqual(field._instance, instance)
|
||||
|
||||
def test_capped_collection(self):
|
||||
"""Ensure that capped collections work properly.
|
||||
"""
|
||||
@@ -473,6 +484,20 @@ class InstanceTest(unittest.TestCase):
|
||||
doc.reload()
|
||||
Animal.drop_collection()
|
||||
|
||||
def test_reload_sharded_nested(self):
|
||||
class SuperPhylum(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
class Animal(Document):
|
||||
superphylum = EmbeddedDocumentField(SuperPhylum)
|
||||
meta = {'shard_key': ('superphylum.name',)}
|
||||
|
||||
Animal.drop_collection()
|
||||
doc = Animal(superphylum=SuperPhylum(name='Deuterostomia'))
|
||||
doc.save()
|
||||
doc.reload()
|
||||
Animal.drop_collection()
|
||||
|
||||
def test_reload_referencing(self):
|
||||
"""Ensures reloading updates weakrefs correctly
|
||||
"""
|
||||
@@ -546,6 +571,28 @@ class InstanceTest(unittest.TestCase):
|
||||
except Exception:
|
||||
self.assertFalse("Threw wrong exception")
|
||||
|
||||
def test_reload_of_non_strict_with_special_field_name(self):
|
||||
"""Ensures reloading works for documents with meta strict == False
|
||||
"""
|
||||
class Post(Document):
|
||||
meta = {
|
||||
'strict': False
|
||||
}
|
||||
title = StringField()
|
||||
items = ListField()
|
||||
|
||||
Post.drop_collection()
|
||||
|
||||
Post._get_collection().insert({
|
||||
"title": "Items eclipse",
|
||||
"items": ["more lorem", "even more ipsum"]
|
||||
})
|
||||
|
||||
post = Post.objects.first()
|
||||
post.reload()
|
||||
self.assertEqual(post.title, "Items eclipse")
|
||||
self.assertEqual(post.items, ["more lorem", "even more ipsum"])
|
||||
|
||||
def test_dictionary_access(self):
|
||||
"""Ensure that dictionary-style field access works properly.
|
||||
"""
|
||||
@@ -608,10 +655,12 @@ class InstanceTest(unittest.TestCase):
|
||||
embedded_field = EmbeddedDocumentField(Embedded)
|
||||
|
||||
Doc.drop_collection()
|
||||
Doc(embedded_field=Embedded(string="Hi")).save()
|
||||
doc = Doc(embedded_field=Embedded(string="Hi"))
|
||||
self.assertHasInstance(doc.embedded_field, doc)
|
||||
|
||||
doc.save()
|
||||
doc = Doc.objects.get()
|
||||
self.assertEqual(doc, doc.embedded_field._instance)
|
||||
self.assertHasInstance(doc.embedded_field, doc)
|
||||
|
||||
def test_embedded_document_complex_instance(self):
|
||||
"""Ensure that embedded documents in complex fields can reference
|
||||
@@ -623,10 +672,25 @@ class InstanceTest(unittest.TestCase):
|
||||
embedded_field = ListField(EmbeddedDocumentField(Embedded))
|
||||
|
||||
Doc.drop_collection()
|
||||
Doc(embedded_field=[Embedded(string="Hi")]).save()
|
||||
doc = Doc(embedded_field=[Embedded(string="Hi")])
|
||||
self.assertHasInstance(doc.embedded_field[0], doc)
|
||||
|
||||
doc.save()
|
||||
doc = Doc.objects.get()
|
||||
self.assertEqual(doc, doc.embedded_field[0]._instance)
|
||||
self.assertHasInstance(doc.embedded_field[0], doc)
|
||||
|
||||
def test_embedded_document_complex_instance_no_use_db_field(self):
|
||||
"""Ensure that use_db_field is propagated to list of Emb Docs
|
||||
"""
|
||||
class Embedded(EmbeddedDocument):
|
||||
string = StringField(db_field='s')
|
||||
|
||||
class Doc(Document):
|
||||
embedded_field = ListField(EmbeddedDocumentField(Embedded))
|
||||
|
||||
d = Doc(embedded_field=[Embedded(string="Hi")]).to_mongo(
|
||||
use_db_field=False).to_dict()
|
||||
self.assertEqual(d['embedded_field'], [{'string': 'Hi'}])
|
||||
|
||||
def test_instance_is_set_on_setattr(self):
|
||||
|
||||
@@ -639,11 +703,28 @@ class InstanceTest(unittest.TestCase):
|
||||
Account.drop_collection()
|
||||
acc = Account()
|
||||
acc.email = Email(email='test@example.com')
|
||||
self.assertTrue(hasattr(acc._data["email"], "_instance"))
|
||||
self.assertHasInstance(acc._data["email"], acc)
|
||||
acc.save()
|
||||
|
||||
acc1 = Account.objects.first()
|
||||
self.assertTrue(hasattr(acc1._data["email"], "_instance"))
|
||||
self.assertHasInstance(acc1._data["email"], acc1)
|
||||
|
||||
def test_instance_is_set_on_setattr_on_embedded_document_list(self):
|
||||
|
||||
class Email(EmbeddedDocument):
|
||||
email = EmailField()
|
||||
|
||||
class Account(Document):
|
||||
emails = EmbeddedDocumentListField(Email)
|
||||
|
||||
Account.drop_collection()
|
||||
acc = Account()
|
||||
acc.emails = [Email(email='test@example.com')]
|
||||
self.assertHasInstance(acc._data["emails"][0], acc)
|
||||
acc.save()
|
||||
|
||||
acc1 = Account.objects.first()
|
||||
self.assertHasInstance(acc1._data["emails"][0], acc1)
|
||||
|
||||
def test_document_clean(self):
|
||||
class TestDocument(Document):
|
||||
@@ -1825,6 +1906,62 @@ class InstanceTest(unittest.TestCase):
|
||||
author.delete()
|
||||
self.assertEqual(BlogPost.objects.count(), 0)
|
||||
|
||||
def test_reverse_delete_rule_with_custom_id_field(self):
|
||||
"""Ensure that a referenced document with custom primary key
|
||||
is also deleted upon deletion.
|
||||
"""
|
||||
class User(Document):
|
||||
name = StringField(primary_key=True)
|
||||
|
||||
class Book(Document):
|
||||
author = ReferenceField(User, reverse_delete_rule=CASCADE)
|
||||
reviewer = ReferenceField(User, reverse_delete_rule=NULLIFY)
|
||||
|
||||
User.drop_collection()
|
||||
Book.drop_collection()
|
||||
|
||||
user = User(name='Mike').save()
|
||||
reviewer = User(name='John').save()
|
||||
book = Book(author=user, reviewer=reviewer).save()
|
||||
|
||||
reviewer.delete()
|
||||
self.assertEqual(Book.objects.count(), 1)
|
||||
self.assertEqual(Book.objects.get().reviewer, None)
|
||||
|
||||
user.delete()
|
||||
self.assertEqual(Book.objects.count(), 0)
|
||||
|
||||
def test_reverse_delete_rule_with_shared_id_among_collections(self):
|
||||
"""Ensure that cascade delete rule doesn't mix id among collections.
|
||||
"""
|
||||
class User(Document):
|
||||
id = IntField(primary_key=True)
|
||||
|
||||
class Book(Document):
|
||||
id = IntField(primary_key=True)
|
||||
author = ReferenceField(User, reverse_delete_rule=CASCADE)
|
||||
|
||||
User.drop_collection()
|
||||
Book.drop_collection()
|
||||
|
||||
user_1 = User(id=1).save()
|
||||
user_2 = User(id=2).save()
|
||||
book_1 = Book(id=1, author=user_2).save()
|
||||
book_2 = Book(id=2, author=user_1).save()
|
||||
|
||||
user_2.delete()
|
||||
# Deleting user_2 should also delete book_1 but not book_2
|
||||
self.assertEqual(Book.objects.count(), 1)
|
||||
self.assertEqual(Book.objects.get(), book_2)
|
||||
|
||||
user_3 = User(id=3).save()
|
||||
book_3 = Book(id=3, author=user_3).save()
|
||||
|
||||
user_3.delete()
|
||||
# Deleting user_3 should also delete book_3
|
||||
self.assertEqual(Book.objects.count(), 1)
|
||||
self.assertEqual(Book.objects.get(), book_2)
|
||||
|
||||
def test_reverse_delete_rule_with_document_inheritance(self):
|
||||
"""Ensure that a referenced document is also deleted upon deletion
|
||||
of a child document.
|
||||
@@ -2180,7 +2317,7 @@ class InstanceTest(unittest.TestCase):
|
||||
|
||||
pickle_doc = PickleDynamicTest(
|
||||
name="test", number=1, string="One", lists=['1', '2'])
|
||||
pickle_doc.embedded = PickleDyanmicEmbedded(foo="Bar")
|
||||
pickle_doc.embedded = PickleDynamicEmbedded(foo="Bar")
|
||||
pickled_doc = pickle.dumps(pickle_doc) # make sure pickling works even before the doc is saved
|
||||
|
||||
pickle_doc.save()
|
||||
@@ -2683,6 +2820,32 @@ class InstanceTest(unittest.TestCase):
|
||||
|
||||
self.assertRaises(OperationError, change_shard_key)
|
||||
|
||||
def test_shard_key_in_embedded_document(self):
|
||||
class Foo(EmbeddedDocument):
|
||||
foo = StringField()
|
||||
|
||||
class Bar(Document):
|
||||
meta = {
|
||||
'shard_key': ('foo.foo',)
|
||||
}
|
||||
foo = EmbeddedDocumentField(Foo)
|
||||
bar = StringField()
|
||||
|
||||
foo_doc = Foo(foo='hello')
|
||||
bar_doc = Bar(foo=foo_doc, bar='world')
|
||||
bar_doc.save()
|
||||
|
||||
self.assertTrue(bar_doc.id is not None)
|
||||
|
||||
bar_doc.bar = 'baz'
|
||||
bar_doc.save()
|
||||
|
||||
def change_shard_key():
|
||||
bar_doc.foo.foo = 'something'
|
||||
bar_doc.save()
|
||||
|
||||
self.assertRaises(OperationError, change_shard_key)
|
||||
|
||||
def test_shard_key_primary(self):
|
||||
class LogEntry(Document):
|
||||
machine = StringField(primary_key=True)
|
||||
@@ -2765,6 +2928,20 @@ class InstanceTest(unittest.TestCase):
|
||||
self.assertEqual(person.name, "Test User")
|
||||
self.assertEqual(person.age, 42)
|
||||
|
||||
def test_positional_creation_embedded(self):
|
||||
"""Ensure that embedded document may be created using positional arguments.
|
||||
"""
|
||||
job = self.Job("Test Job", 4)
|
||||
self.assertEqual(job.name, "Test Job")
|
||||
self.assertEqual(job.years, 4)
|
||||
|
||||
def test_mixed_creation_embedded(self):
|
||||
"""Ensure that embedded document may be created using mixed arguments.
|
||||
"""
|
||||
job = self.Job("Test Job", years=4)
|
||||
self.assertEqual(job.name, "Test Job")
|
||||
self.assertEqual(job.years, 4)
|
||||
|
||||
def test_mixed_creation_dynamic(self):
|
||||
"""Ensure that document may be created using mixed arguments.
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
|
||||
import six
|
||||
from nose.plugins.skip import SkipTest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@@ -10,6 +12,7 @@ import uuid
|
||||
import math
|
||||
import itertools
|
||||
import re
|
||||
import six
|
||||
|
||||
try:
|
||||
import dateutil
|
||||
@@ -19,6 +22,10 @@ except ImportError:
|
||||
from decimal import Decimal
|
||||
|
||||
from bson import Binary, DBRef, ObjectId
|
||||
try:
|
||||
from bson.int64 import Int64
|
||||
except ImportError:
|
||||
Int64 = long
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import get_db
|
||||
@@ -399,20 +406,37 @@ class FieldTest(unittest.TestCase):
|
||||
class Person(Document):
|
||||
height = FloatField(min_value=0.1, max_value=3.5)
|
||||
|
||||
class BigPerson(Document):
|
||||
height = FloatField()
|
||||
|
||||
person = Person()
|
||||
person.height = 1.89
|
||||
person.validate()
|
||||
|
||||
person.height = '2.0'
|
||||
self.assertRaises(ValidationError, person.validate)
|
||||
|
||||
person.height = 0.01
|
||||
self.assertRaises(ValidationError, person.validate)
|
||||
|
||||
person.height = 4.0
|
||||
self.assertRaises(ValidationError, person.validate)
|
||||
|
||||
person_2 = Person(height='something invalid')
|
||||
self.assertRaises(ValidationError, person_2.validate)
|
||||
|
||||
big_person = BigPerson()
|
||||
|
||||
for value, value_type in enumerate(six.integer_types):
|
||||
big_person.height = value_type(value)
|
||||
big_person.validate()
|
||||
|
||||
big_person.height = 2 ** 500
|
||||
big_person.validate()
|
||||
|
||||
big_person.height = 2 ** 100000 # Too big for a float value
|
||||
self.assertRaises(ValidationError, big_person.validate)
|
||||
|
||||
def test_decimal_validation(self):
|
||||
"""Ensure that invalid values cannot be assigned to decimal fields.
|
||||
"""
|
||||
@@ -1184,6 +1208,19 @@ class FieldTest(unittest.TestCase):
|
||||
simple = simple.reload()
|
||||
self.assertEqual(simple.widgets, [4])
|
||||
|
||||
def test_list_field_with_negative_indices(self):
|
||||
|
||||
class Simple(Document):
|
||||
widgets = ListField()
|
||||
|
||||
simple = Simple(widgets=[1, 2, 3, 4]).save()
|
||||
simple.widgets[-1] = 5
|
||||
self.assertEqual(['widgets.3'], simple._changed_fields)
|
||||
simple.save()
|
||||
|
||||
simple = simple.reload()
|
||||
self.assertEqual(simple.widgets, [1, 2, 3, 5])
|
||||
|
||||
def test_list_field_complex(self):
|
||||
"""Ensure that the list fields can handle the complex types."""
|
||||
|
||||
@@ -1563,6 +1600,29 @@ class FieldTest(unittest.TestCase):
|
||||
actions__friends__operation='drink',
|
||||
actions__friends__object='beer').count())
|
||||
|
||||
def test_map_field_unicode(self):
|
||||
|
||||
class Info(EmbeddedDocument):
|
||||
description = StringField()
|
||||
value_list = ListField(field=StringField())
|
||||
|
||||
class BlogPost(Document):
|
||||
info_dict = MapField(field=EmbeddedDocumentField(Info))
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
tree = BlogPost(info_dict={
|
||||
u"éééé": {
|
||||
'description': u"VALUE: éééé"
|
||||
}
|
||||
})
|
||||
|
||||
tree.save()
|
||||
|
||||
self.assertEqual(BlogPost.objects.get(id=tree.id).info_dict[u"éééé"].description, u"VALUE: éééé")
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
def test_embedded_db_field(self):
|
||||
|
||||
class Embedded(EmbeddedDocument):
|
||||
@@ -1599,6 +1659,8 @@ class FieldTest(unittest.TestCase):
|
||||
name = StringField()
|
||||
preferences = EmbeddedDocumentField(PersonPreferences)
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
person = Person(name='Test User')
|
||||
person.preferences = 'My Preferences'
|
||||
self.assertRaises(ValidationError, person.validate)
|
||||
@@ -1631,12 +1693,39 @@ class FieldTest(unittest.TestCase):
|
||||
content = StringField()
|
||||
author = EmbeddedDocumentField(User)
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
post = BlogPost(content='What I did today...')
|
||||
post.author = PowerUser(name='Test User', power=47)
|
||||
post.save()
|
||||
|
||||
self.assertEqual(47, BlogPost.objects.first().author.power)
|
||||
|
||||
def test_embedded_document_inheritance_with_list(self):
|
||||
"""Ensure that nested list of subclassed embedded documents is
|
||||
handled correctly.
|
||||
"""
|
||||
|
||||
class Group(EmbeddedDocument):
|
||||
name = StringField()
|
||||
content = ListField(StringField())
|
||||
|
||||
class Basedoc(Document):
|
||||
groups = ListField(EmbeddedDocumentField(Group))
|
||||
meta = {'abstract': True}
|
||||
|
||||
class User(Basedoc):
|
||||
doctype = StringField(require=True, default='userdata')
|
||||
|
||||
User.drop_collection()
|
||||
|
||||
content = ['la', 'le', 'lu']
|
||||
group = Group(name='foo', content=content)
|
||||
foobar = User(groups=[group])
|
||||
foobar.save()
|
||||
|
||||
self.assertEqual(content, User.objects.first().groups[0].content)
|
||||
|
||||
def test_reference_validation(self):
|
||||
"""Ensure that invalid docment objects cannot be assigned to reference
|
||||
fields.
|
||||
@@ -2329,6 +2418,91 @@ class FieldTest(unittest.TestCase):
|
||||
Member.drop_collection()
|
||||
BlogPost.drop_collection()
|
||||
|
||||
def test_drop_abstract_document(self):
|
||||
"""Ensure that an abstract document cannot be dropped given it
|
||||
has no underlying collection.
|
||||
"""
|
||||
class AbstractDoc(Document):
|
||||
name = StringField()
|
||||
meta = {"abstract": True}
|
||||
|
||||
self.assertRaises(OperationError, AbstractDoc.drop_collection)
|
||||
|
||||
def test_reference_class_with_abstract_parent(self):
|
||||
"""Ensure that a class with an abstract parent can be referenced.
|
||||
"""
|
||||
class Sibling(Document):
|
||||
name = StringField()
|
||||
meta = {"abstract": True}
|
||||
|
||||
class Sister(Sibling):
|
||||
pass
|
||||
|
||||
class Brother(Sibling):
|
||||
sibling = ReferenceField(Sibling)
|
||||
|
||||
Sister.drop_collection()
|
||||
Brother.drop_collection()
|
||||
|
||||
sister = Sister(name="Alice")
|
||||
sister.save()
|
||||
brother = Brother(name="Bob", sibling=sister)
|
||||
brother.save()
|
||||
|
||||
self.assertEquals(Brother.objects[0].sibling.name, sister.name)
|
||||
|
||||
Sister.drop_collection()
|
||||
Brother.drop_collection()
|
||||
|
||||
def test_reference_abstract_class(self):
|
||||
"""Ensure that an abstract class instance cannot be used in the
|
||||
reference of that abstract class.
|
||||
"""
|
||||
class Sibling(Document):
|
||||
name = StringField()
|
||||
meta = {"abstract": True}
|
||||
|
||||
class Sister(Sibling):
|
||||
pass
|
||||
|
||||
class Brother(Sibling):
|
||||
sibling = ReferenceField(Sibling)
|
||||
|
||||
Sister.drop_collection()
|
||||
Brother.drop_collection()
|
||||
|
||||
sister = Sibling(name="Alice")
|
||||
brother = Brother(name="Bob", sibling=sister)
|
||||
self.assertRaises(ValidationError, brother.save)
|
||||
|
||||
Sister.drop_collection()
|
||||
Brother.drop_collection()
|
||||
|
||||
def test_abstract_reference_base_type(self):
|
||||
"""Ensure that an an abstract reference fails validation when given a
|
||||
Document that does not inherit from the abstract type.
|
||||
"""
|
||||
class Sibling(Document):
|
||||
name = StringField()
|
||||
meta = {"abstract": True}
|
||||
|
||||
class Brother(Sibling):
|
||||
sibling = ReferenceField(Sibling)
|
||||
|
||||
class Mother(Document):
|
||||
name = StringField()
|
||||
|
||||
Brother.drop_collection()
|
||||
Mother.drop_collection()
|
||||
|
||||
mother = Mother(name="Carol")
|
||||
mother.save()
|
||||
brother = Brother(name="Bob", sibling=mother)
|
||||
self.assertRaises(ValidationError, brother.save)
|
||||
|
||||
Brother.drop_collection()
|
||||
Mother.drop_collection()
|
||||
|
||||
def test_generic_reference(self):
|
||||
"""Ensure that a GenericReferenceField properly dereferences items.
|
||||
"""
|
||||
@@ -3353,7 +3527,7 @@ class FieldTest(unittest.TestCase):
|
||||
def __init__(self, **kwargs):
|
||||
super(EnumField, self).__init__(**kwargs)
|
||||
|
||||
def to_mongo(self, value):
|
||||
def to_mongo(self, value, **kwargs):
|
||||
return value
|
||||
|
||||
def to_python(self, value):
|
||||
@@ -3520,6 +3694,19 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
self.assertRaises(FieldDoesNotExist, test)
|
||||
|
||||
def test_long_field_is_considered_as_int64(self):
|
||||
"""
|
||||
Tests that long fields are stored as long in mongo, even if long value
|
||||
is small enough to be an int.
|
||||
"""
|
||||
class TestLongFieldConsideredAsInt64(Document):
|
||||
some_long = LongField()
|
||||
|
||||
doc = TestLongFieldConsideredAsInt64(some_long=42).save()
|
||||
db = get_db()
|
||||
self.assertTrue(isinstance(db.test_long_field_considered_as_int64.find()[0]['some_long'], Int64))
|
||||
self.assertTrue(isinstance(doc.some_long, six.integer_types))
|
||||
|
||||
|
||||
class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
||||
|
||||
@@ -3907,6 +4094,17 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
|
||||
# modified
|
||||
self.assertEqual(number, 2)
|
||||
|
||||
def test_unicode(self):
|
||||
"""
|
||||
Tests that unicode strings handled correctly
|
||||
"""
|
||||
post = self.BlogPost(comments=[
|
||||
self.Comments(author='user1', message=u'сообщение'),
|
||||
self.Comments(author='user2', message=u'хабарлама')
|
||||
]).save()
|
||||
self.assertEqual(post.comments.get(message=u'сообщение').author,
|
||||
'user1')
|
||||
|
||||
def test_save(self):
|
||||
"""
|
||||
Tests the save method of a List of Embedded Documents.
|
||||
|
||||
@@ -26,7 +26,7 @@ class NewDocumentPickleTest(Document):
|
||||
new_field = StringField()
|
||||
|
||||
|
||||
class PickleDyanmicEmbedded(DynamicEmbeddedDocument):
|
||||
class PickleDynamicEmbedded(DynamicEmbeddedDocument):
|
||||
date = DateTimeField(default=datetime.now)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import unittest
|
||||
|
||||
from convert_to_new_inheritance_model import *
|
||||
from decimalfield_as_float import *
|
||||
from refrencefield_dbref_to_object_id import *
|
||||
from referencefield_dbref_to_object_id import *
|
||||
from turn_off_inheritance import *
|
||||
from uuidfield_to_binary import *
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -680,12 +680,21 @@ class QuerySetTest(unittest.TestCase):
|
||||
def test_upsert_one(self):
|
||||
self.Person.drop_collection()
|
||||
|
||||
self.Person.objects(name="Bob", age=30).update_one(upsert=True)
|
||||
bob = self.Person.objects(name="Bob", age=30).upsert_one()
|
||||
|
||||
bob = self.Person.objects.first()
|
||||
self.assertEqual("Bob", bob.name)
|
||||
self.assertEqual(30, bob.age)
|
||||
|
||||
bob.name = "Bobby"
|
||||
bob.save()
|
||||
|
||||
bobby = self.Person.objects(name="Bobby", age=30).upsert_one()
|
||||
|
||||
self.assertEqual("Bobby", bobby.name)
|
||||
self.assertEqual(30, bobby.age)
|
||||
self.assertEqual(bob.id, bobby.id)
|
||||
|
||||
|
||||
def test_set_on_insert(self):
|
||||
self.Person.drop_collection()
|
||||
|
||||
@@ -2757,25 +2766,15 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0
|
||||
self.assertAlmostEqual(int(self.Person.objects.average('age')), avg)
|
||||
self.assertAlmostEqual(
|
||||
int(self.Person.objects.aggregate_average('age')), avg
|
||||
)
|
||||
|
||||
self.Person(name='ageless person').save()
|
||||
self.assertEqual(int(self.Person.objects.average('age')), avg)
|
||||
self.assertEqual(
|
||||
int(self.Person.objects.aggregate_average('age')), avg
|
||||
)
|
||||
|
||||
# dot notation
|
||||
self.Person(
|
||||
name='person meta', person_meta=self.PersonMeta(weight=0)).save()
|
||||
self.assertAlmostEqual(
|
||||
int(self.Person.objects.average('person_meta.weight')), 0)
|
||||
self.assertAlmostEqual(
|
||||
int(self.Person.objects.aggregate_average('person_meta.weight')),
|
||||
0
|
||||
)
|
||||
|
||||
for i, weight in enumerate(ages):
|
||||
self.Person(
|
||||
@@ -2784,19 +2783,11 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertAlmostEqual(
|
||||
int(self.Person.objects.average('person_meta.weight')), avg
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
int(self.Person.objects.aggregate_average('person_meta.weight')),
|
||||
avg
|
||||
)
|
||||
|
||||
self.Person(name='test meta none').save()
|
||||
self.assertEqual(
|
||||
int(self.Person.objects.average('person_meta.weight')), avg
|
||||
)
|
||||
self.assertEqual(
|
||||
int(self.Person.objects.aggregate_average('person_meta.weight')),
|
||||
avg
|
||||
)
|
||||
|
||||
# test summing over a filtered queryset
|
||||
over_50 = [a for a in ages if a >= 50]
|
||||
@@ -2805,10 +2796,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.Person.objects.filter(age__gte=50).average('age'),
|
||||
avg
|
||||
)
|
||||
self.assertEqual(
|
||||
self.Person.objects.filter(age__gte=50).aggregate_average('age'),
|
||||
avg
|
||||
)
|
||||
|
||||
def test_sum(self):
|
||||
"""Ensure that field can be summed over correctly.
|
||||
@@ -2818,15 +2805,9 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.Person(name='test%s' % i, age=age).save()
|
||||
|
||||
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
|
||||
self.assertEqual(
|
||||
self.Person.objects.aggregate_sum('age'), sum(ages)
|
||||
)
|
||||
|
||||
self.Person(name='ageless person').save()
|
||||
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
|
||||
self.assertEqual(
|
||||
self.Person.objects.aggregate_sum('age'), sum(ages)
|
||||
)
|
||||
|
||||
for i, age in enumerate(ages):
|
||||
self.Person(name='test meta%s' %
|
||||
@@ -2835,26 +2816,15 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
self.Person.objects.sum('person_meta.weight'), sum(ages)
|
||||
)
|
||||
self.assertEqual(
|
||||
self.Person.objects.aggregate_sum('person_meta.weight'),
|
||||
sum(ages)
|
||||
)
|
||||
|
||||
self.Person(name='weightless person').save()
|
||||
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
|
||||
self.assertEqual(
|
||||
self.Person.objects.aggregate_sum('age'), sum(ages)
|
||||
)
|
||||
|
||||
# test summing over a filtered queryset
|
||||
self.assertEqual(
|
||||
self.Person.objects.filter(age__gte=50).sum('age'),
|
||||
sum([a for a in ages if a >= 50])
|
||||
)
|
||||
self.assertEqual(
|
||||
self.Person.objects.filter(age__gte=50).aggregate_sum('age'),
|
||||
sum([a for a in ages if a >= 50])
|
||||
)
|
||||
|
||||
def test_embedded_average(self):
|
||||
class Pay(EmbeddedDocument):
|
||||
@@ -2867,21 +2837,12 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
Doc.drop_collection()
|
||||
|
||||
Doc(name=u"Wilson Junior",
|
||||
pay=Pay(value=150)).save()
|
||||
Doc(name='Wilson Junior', pay=Pay(value=150)).save()
|
||||
Doc(name='Isabella Luanna', pay=Pay(value=530)).save()
|
||||
Doc(name='Tayza mariana', pay=Pay(value=165)).save()
|
||||
Doc(name='Eliana Costa', pay=Pay(value=115)).save()
|
||||
|
||||
Doc(name=u"Isabella Luanna",
|
||||
pay=Pay(value=530)).save()
|
||||
|
||||
Doc(name=u"Tayza mariana",
|
||||
pay=Pay(value=165)).save()
|
||||
|
||||
Doc(name=u"Eliana Costa",
|
||||
pay=Pay(value=115)).save()
|
||||
|
||||
self.assertEqual(
|
||||
Doc.objects.average('pay.value'),
|
||||
240)
|
||||
self.assertEqual(Doc.objects.average('pay.value'), 240)
|
||||
|
||||
def test_embedded_array_average(self):
|
||||
class Pay(EmbeddedDocument):
|
||||
@@ -2889,26 +2850,16 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
class Doc(Document):
|
||||
name = StringField()
|
||||
pay = EmbeddedDocumentField(
|
||||
Pay)
|
||||
pay = EmbeddedDocumentField(Pay)
|
||||
|
||||
Doc.drop_collection()
|
||||
|
||||
Doc(name=u"Wilson Junior",
|
||||
pay=Pay(values=[150, 100])).save()
|
||||
Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save()
|
||||
Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save()
|
||||
Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save()
|
||||
Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save()
|
||||
|
||||
Doc(name=u"Isabella Luanna",
|
||||
pay=Pay(values=[530, 100])).save()
|
||||
|
||||
Doc(name=u"Tayza mariana",
|
||||
pay=Pay(values=[165, 100])).save()
|
||||
|
||||
Doc(name=u"Eliana Costa",
|
||||
pay=Pay(values=[115, 100])).save()
|
||||
|
||||
self.assertEqual(
|
||||
Doc.objects.average('pay.values'),
|
||||
170)
|
||||
self.assertEqual(Doc.objects.average('pay.values'), 170)
|
||||
|
||||
def test_array_average(self):
|
||||
class Doc(Document):
|
||||
@@ -2921,9 +2872,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
Doc(values=[165, 100]).save()
|
||||
Doc(values=[115, 100]).save()
|
||||
|
||||
self.assertEqual(
|
||||
Doc.objects.average('values'),
|
||||
170)
|
||||
self.assertEqual(Doc.objects.average('values'), 170)
|
||||
|
||||
def test_embedded_sum(self):
|
||||
class Pay(EmbeddedDocument):
|
||||
@@ -2931,26 +2880,16 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
class Doc(Document):
|
||||
name = StringField()
|
||||
pay = EmbeddedDocumentField(
|
||||
Pay)
|
||||
pay = EmbeddedDocumentField(Pay)
|
||||
|
||||
Doc.drop_collection()
|
||||
|
||||
Doc(name=u"Wilson Junior",
|
||||
pay=Pay(value=150)).save()
|
||||
Doc(name='Wilson Junior', pay=Pay(value=150)).save()
|
||||
Doc(name='Isabella Luanna', pay=Pay(value=530)).save()
|
||||
Doc(name='Tayza mariana', pay=Pay(value=165)).save()
|
||||
Doc(name='Eliana Costa', pay=Pay(value=115)).save()
|
||||
|
||||
Doc(name=u"Isabella Luanna",
|
||||
pay=Pay(value=530)).save()
|
||||
|
||||
Doc(name=u"Tayza mariana",
|
||||
pay=Pay(value=165)).save()
|
||||
|
||||
Doc(name=u"Eliana Costa",
|
||||
pay=Pay(value=115)).save()
|
||||
|
||||
self.assertEqual(
|
||||
Doc.objects.sum('pay.value'),
|
||||
960)
|
||||
self.assertEqual(Doc.objects.sum('pay.value'), 960)
|
||||
|
||||
def test_embedded_array_sum(self):
|
||||
class Pay(EmbeddedDocument):
|
||||
@@ -2958,26 +2897,16 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
class Doc(Document):
|
||||
name = StringField()
|
||||
pay = EmbeddedDocumentField(
|
||||
Pay)
|
||||
pay = EmbeddedDocumentField(Pay)
|
||||
|
||||
Doc.drop_collection()
|
||||
|
||||
Doc(name=u"Wilson Junior",
|
||||
pay=Pay(values=[150, 100])).save()
|
||||
Doc(name='Wilson Junior', pay=Pay(values=[150, 100])).save()
|
||||
Doc(name='Isabella Luanna', pay=Pay(values=[530, 100])).save()
|
||||
Doc(name='Tayza mariana', pay=Pay(values=[165, 100])).save()
|
||||
Doc(name='Eliana Costa', pay=Pay(values=[115, 100])).save()
|
||||
|
||||
Doc(name=u"Isabella Luanna",
|
||||
pay=Pay(values=[530, 100])).save()
|
||||
|
||||
Doc(name=u"Tayza mariana",
|
||||
pay=Pay(values=[165, 100])).save()
|
||||
|
||||
Doc(name=u"Eliana Costa",
|
||||
pay=Pay(values=[115, 100])).save()
|
||||
|
||||
self.assertEqual(
|
||||
Doc.objects.sum('pay.values'),
|
||||
1360)
|
||||
self.assertEqual(Doc.objects.sum('pay.values'), 1360)
|
||||
|
||||
def test_array_sum(self):
|
||||
class Doc(Document):
|
||||
@@ -2990,9 +2919,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
Doc(values=[165, 100]).save()
|
||||
Doc(values=[115, 100]).save()
|
||||
|
||||
self.assertEqual(
|
||||
Doc.objects.sum('values'),
|
||||
1360)
|
||||
self.assertEqual(Doc.objects.sum('values'), 1360)
|
||||
|
||||
def test_distinct(self):
|
||||
"""Ensure that the QuerySet.distinct method works.
|
||||
@@ -3604,6 +3531,15 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(MyDoc.objects.count(), 10)
|
||||
self.assertEqual(MyDoc.objects.none().count(), 0)
|
||||
|
||||
def test_count_list_embedded(self):
|
||||
class B(EmbeddedDocument):
|
||||
c = StringField()
|
||||
|
||||
class A(Document):
|
||||
b = ListField(EmbeddedDocumentField(B))
|
||||
|
||||
self.assertEqual(A.objects(b=[{'c': 'c'}]).count(), 0)
|
||||
|
||||
def test_call_after_limits_set(self):
|
||||
"""Ensure that re-filtering after slicing works
|
||||
"""
|
||||
@@ -4105,6 +4041,10 @@ class QuerySetTest(unittest.TestCase):
|
||||
Foo(shape="circle", color="purple", thick=False)])
|
||||
b2.save()
|
||||
|
||||
b3 = Bar(foo=[Foo(shape="square", thick=True),
|
||||
Foo(shape="circle", color="purple", thick=False)])
|
||||
b3.save()
|
||||
|
||||
ak = list(
|
||||
Bar.objects(foo__match={'shape': "square", "color": "purple"}))
|
||||
self.assertEqual([b1], ak)
|
||||
@@ -4116,6 +4056,22 @@ class QuerySetTest(unittest.TestCase):
|
||||
ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple")))
|
||||
self.assertEqual([b1], ak)
|
||||
|
||||
ak = list(
|
||||
Bar.objects(foo__elemMatch={'shape': "square", "color__exists": True}))
|
||||
self.assertEqual([b1, b2], ak)
|
||||
|
||||
ak = list(
|
||||
Bar.objects(foo__match={'shape': "square", "color__exists": True}))
|
||||
self.assertEqual([b1, b2], ak)
|
||||
|
||||
ak = list(
|
||||
Bar.objects(foo__elemMatch={'shape': "square", "color__exists": False}))
|
||||
self.assertEqual([b3], ak)
|
||||
|
||||
ak = list(
|
||||
Bar.objects(foo__match={'shape': "square", "color__exists": False}))
|
||||
self.assertEqual([b3], ak)
|
||||
|
||||
def test_upsert_includes_cls(self):
|
||||
"""Upserts should include _cls information for inheritable classes
|
||||
"""
|
||||
@@ -4156,7 +4112,11 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
def test_read_preference(self):
|
||||
class Bar(Document):
|
||||
pass
|
||||
txt = StringField()
|
||||
|
||||
meta = {
|
||||
'indexes': [ 'txt' ]
|
||||
}
|
||||
|
||||
Bar.drop_collection()
|
||||
bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY))
|
||||
@@ -4168,9 +4128,51 @@ class QuerySetTest(unittest.TestCase):
|
||||
error_class = TypeError
|
||||
self.assertRaises(error_class, Bar.objects, read_preference='Primary')
|
||||
|
||||
# read_preference as a kwarg
|
||||
bars = Bar.objects(read_preference=ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(
|
||||
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(bars._cursor._Cursor__read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
# read_preference as a query set method
|
||||
bars = Bar.objects.read_preference(ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(
|
||||
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(bars._cursor._Cursor__read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
# read_preference after skip
|
||||
bars = Bar.objects.skip(1) \
|
||||
.read_preference(ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(
|
||||
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(bars._cursor._Cursor__read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
# read_preference after limit
|
||||
bars = Bar.objects.limit(1) \
|
||||
.read_preference(ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(
|
||||
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(bars._cursor._Cursor__read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
# read_preference after order_by
|
||||
bars = Bar.objects.order_by('txt') \
|
||||
.read_preference(ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(
|
||||
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(bars._cursor._Cursor__read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
# read_preference after hint
|
||||
bars = Bar.objects.hint([('txt', 1)]) \
|
||||
.read_preference(ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(
|
||||
bars._read_preference, ReadPreference.SECONDARY_PREFERRED)
|
||||
self.assertEqual(bars._cursor._Cursor__read_preference,
|
||||
ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
def test_json_simple(self):
|
||||
|
||||
@@ -4824,5 +4826,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(1, Doc.objects(item__type__="axe").count())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -224,6 +224,10 @@ class TransformTest(unittest.TestCase):
|
||||
self.assertEqual(1, Doc.objects(item__type__="axe").count())
|
||||
self.assertEqual(1, Doc.objects(item__name__="Heroic axe").count())
|
||||
|
||||
Doc.objects(id=doc.id).update(set__item__type__='sword')
|
||||
self.assertEqual(1, Doc.objects(item__type__="sword").count())
|
||||
self.assertEqual(0, Doc.objects(item__type__="axe").count())
|
||||
|
||||
def test_understandable_error_raised(self):
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
|
||||
@@ -8,6 +8,7 @@ try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest
|
||||
from nose.plugins.skip import SkipTest
|
||||
|
||||
import pymongo
|
||||
from bson.tz_util import utc
|
||||
@@ -51,6 +52,42 @@ class ConnectionTest(unittest.TestCase):
|
||||
conn = get_connection('testdb')
|
||||
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
|
||||
|
||||
def test_connect_in_mocking(self):
|
||||
"""Ensure that the connect() method works properly in mocking.
|
||||
"""
|
||||
try:
|
||||
import mongomock
|
||||
except ImportError:
|
||||
raise SkipTest('you need mongomock installed to run this testcase')
|
||||
|
||||
connect('mongoenginetest', host='mongomock://localhost')
|
||||
conn = get_connection()
|
||||
self.assertTrue(isinstance(conn, mongomock.MongoClient))
|
||||
|
||||
connect('mongoenginetest2', host='mongomock://localhost', alias='testdb2')
|
||||
conn = get_connection('testdb2')
|
||||
self.assertTrue(isinstance(conn, mongomock.MongoClient))
|
||||
|
||||
connect('mongoenginetest3', host='mongodb://localhost', is_mock=True, alias='testdb3')
|
||||
conn = get_connection('testdb3')
|
||||
self.assertTrue(isinstance(conn, mongomock.MongoClient))
|
||||
|
||||
connect('mongoenginetest4', is_mock=True, alias='testdb4')
|
||||
conn = get_connection('testdb4')
|
||||
self.assertTrue(isinstance(conn, mongomock.MongoClient))
|
||||
|
||||
connect(host='mongodb://localhost:27017/mongoenginetest5', is_mock=True, alias='testdb5')
|
||||
conn = get_connection('testdb5')
|
||||
self.assertTrue(isinstance(conn, mongomock.MongoClient))
|
||||
|
||||
connect(host='mongomock://localhost:27017/mongoenginetest6', alias='testdb6')
|
||||
conn = get_connection('testdb6')
|
||||
self.assertTrue(isinstance(conn, mongomock.MongoClient))
|
||||
|
||||
connect(host='mongomock://localhost:27017/mongoenginetest7', is_mock=True, alias='testdb7')
|
||||
conn = get_connection('testdb7')
|
||||
self.assertTrue(isinstance(conn, mongomock.MongoClient))
|
||||
|
||||
def test_disconnect(self):
|
||||
"""Ensure that the disconnect() method works properly
|
||||
"""
|
||||
@@ -151,7 +188,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
self.assertRaises(ConnectionError, get_db, 'test1')
|
||||
|
||||
# Authentication succeeds with "authSource"
|
||||
test_conn2 = connect(
|
||||
connect(
|
||||
'mongoenginetest', alias='test2',
|
||||
host=('mongodb://username2:password@localhost/'
|
||||
'mongoenginetest?authSource=admin')
|
||||
|
||||
@@ -12,9 +12,13 @@ from mongoengine.context_managers import query_counter
|
||||
|
||||
class FieldTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
self.db = get_db()
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.db = connect(db='mongoenginetest')
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.db.drop_database('mongoenginetest')
|
||||
|
||||
def test_list_item_dereference(self):
|
||||
"""Ensure that DBRef items in ListFields are dereferenced.
|
||||
@@ -304,6 +308,7 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
User.drop_collection()
|
||||
Post.drop_collection()
|
||||
SimpleList.drop_collection()
|
||||
|
||||
u1 = User.objects.create(name='u1')
|
||||
u2 = User.objects.create(name='u2')
|
||||
|
||||
@@ -25,6 +25,8 @@ class SignalTests(unittest.TestCase):
|
||||
connect(db='mongoenginetest')
|
||||
|
||||
class Author(Document):
|
||||
# Make the id deterministic for easier testing
|
||||
id = SequenceField(primary_key=True)
|
||||
name = StringField()
|
||||
|
||||
def __unicode__(self):
|
||||
@@ -33,7 +35,7 @@ class SignalTests(unittest.TestCase):
|
||||
@classmethod
|
||||
def pre_init(cls, sender, document, *args, **kwargs):
|
||||
signal_output.append('pre_init signal, %s' % cls.__name__)
|
||||
signal_output.append(str(kwargs['values']))
|
||||
signal_output.append(kwargs['values'])
|
||||
|
||||
@classmethod
|
||||
def post_init(cls, sender, document, **kwargs):
|
||||
@@ -43,48 +45,55 @@ class SignalTests(unittest.TestCase):
|
||||
@classmethod
|
||||
def pre_save(cls, sender, document, **kwargs):
|
||||
signal_output.append('pre_save signal, %s' % document)
|
||||
signal_output.append(kwargs)
|
||||
|
||||
@classmethod
|
||||
def pre_save_post_validation(cls, sender, document, **kwargs):
|
||||
signal_output.append('pre_save_post_validation signal, %s' % document)
|
||||
if 'created' in kwargs:
|
||||
if kwargs['created']:
|
||||
signal_output.append('Is created')
|
||||
else:
|
||||
signal_output.append('Is updated')
|
||||
if kwargs.pop('created', False):
|
||||
signal_output.append('Is created')
|
||||
else:
|
||||
signal_output.append('Is updated')
|
||||
signal_output.append(kwargs)
|
||||
|
||||
@classmethod
|
||||
def post_save(cls, sender, document, **kwargs):
|
||||
dirty_keys = document._delta()[0].keys() + document._delta()[1].keys()
|
||||
signal_output.append('post_save signal, %s' % document)
|
||||
signal_output.append('post_save dirty keys, %s' % dirty_keys)
|
||||
if 'created' in kwargs:
|
||||
if kwargs['created']:
|
||||
signal_output.append('Is created')
|
||||
else:
|
||||
signal_output.append('Is updated')
|
||||
if kwargs.pop('created', False):
|
||||
signal_output.append('Is created')
|
||||
else:
|
||||
signal_output.append('Is updated')
|
||||
signal_output.append(kwargs)
|
||||
|
||||
@classmethod
|
||||
def pre_delete(cls, sender, document, **kwargs):
|
||||
signal_output.append('pre_delete signal, %s' % document)
|
||||
signal_output.append(kwargs)
|
||||
|
||||
@classmethod
|
||||
def post_delete(cls, sender, document, **kwargs):
|
||||
signal_output.append('post_delete signal, %s' % document)
|
||||
signal_output.append(kwargs)
|
||||
|
||||
@classmethod
|
||||
def pre_bulk_insert(cls, sender, documents, **kwargs):
|
||||
signal_output.append('pre_bulk_insert signal, %s' % documents)
|
||||
signal_output.append(kwargs)
|
||||
|
||||
@classmethod
|
||||
def post_bulk_insert(cls, sender, documents, **kwargs):
|
||||
signal_output.append('post_bulk_insert signal, %s' % documents)
|
||||
if kwargs.get('loaded', False):
|
||||
if kwargs.pop('loaded', False):
|
||||
signal_output.append('Is loaded')
|
||||
else:
|
||||
signal_output.append('Not loaded')
|
||||
signal_output.append(kwargs)
|
||||
|
||||
self.Author = Author
|
||||
Author.drop_collection()
|
||||
Author.id.set_next_value(0)
|
||||
|
||||
class Another(Document):
|
||||
|
||||
@@ -96,10 +105,12 @@ class SignalTests(unittest.TestCase):
|
||||
@classmethod
|
||||
def pre_delete(cls, sender, document, **kwargs):
|
||||
signal_output.append('pre_delete signal, %s' % document)
|
||||
signal_output.append(kwargs)
|
||||
|
||||
@classmethod
|
||||
def post_delete(cls, sender, document, **kwargs):
|
||||
signal_output.append('post_delete signal, %s' % document)
|
||||
signal_output.append(kwargs)
|
||||
|
||||
self.Another = Another
|
||||
Another.drop_collection()
|
||||
@@ -118,6 +129,41 @@ class SignalTests(unittest.TestCase):
|
||||
self.ExplicitId = ExplicitId
|
||||
ExplicitId.drop_collection()
|
||||
|
||||
class Post(Document):
|
||||
title = StringField()
|
||||
content = StringField()
|
||||
active = BooleanField(default=False)
|
||||
|
||||
def __unicode__(self):
|
||||
return self.title
|
||||
|
||||
@classmethod
|
||||
def pre_bulk_insert(cls, sender, documents, **kwargs):
|
||||
signal_output.append('pre_bulk_insert signal, %s' %
|
||||
[(doc, {'active': documents[n].active})
|
||||
for n, doc in enumerate(documents)])
|
||||
|
||||
# make changes here, this is just an example -
|
||||
# it could be anything that needs pre-validation or looks-ups before bulk bulk inserting
|
||||
for document in documents:
|
||||
if not document.active:
|
||||
document.active = True
|
||||
signal_output.append(kwargs)
|
||||
|
||||
@classmethod
|
||||
def post_bulk_insert(cls, sender, documents, **kwargs):
|
||||
signal_output.append('post_bulk_insert signal, %s' %
|
||||
[(doc, {'active': documents[n].active})
|
||||
for n, doc in enumerate(documents)])
|
||||
if kwargs.pop('loaded', False):
|
||||
signal_output.append('Is loaded')
|
||||
else:
|
||||
signal_output.append('Not loaded')
|
||||
signal_output.append(kwargs)
|
||||
|
||||
self.Post = Post
|
||||
Post.drop_collection()
|
||||
|
||||
# Save up the number of connected signals so that we can check at the
|
||||
# end that all the signals we register get properly unregistered
|
||||
self.pre_signals = (
|
||||
@@ -147,6 +193,9 @@ class SignalTests(unittest.TestCase):
|
||||
|
||||
signals.post_save.connect(ExplicitId.post_save, sender=ExplicitId)
|
||||
|
||||
signals.pre_bulk_insert.connect(Post.pre_bulk_insert, sender=Post)
|
||||
signals.post_bulk_insert.connect(Post.post_bulk_insert, sender=Post)
|
||||
|
||||
def tearDown(self):
|
||||
signals.pre_init.disconnect(self.Author.pre_init)
|
||||
signals.post_init.disconnect(self.Author.post_init)
|
||||
@@ -163,6 +212,9 @@ class SignalTests(unittest.TestCase):
|
||||
|
||||
signals.post_save.disconnect(self.ExplicitId.post_save)
|
||||
|
||||
signals.pre_bulk_insert.disconnect(self.Post.pre_bulk_insert)
|
||||
signals.post_bulk_insert.disconnect(self.Post.post_bulk_insert)
|
||||
|
||||
# Check that all our signals got disconnected properly.
|
||||
post_signals = (
|
||||
len(signals.pre_init.receivers),
|
||||
@@ -199,66 +251,121 @@ class SignalTests(unittest.TestCase):
|
||||
a.save()
|
||||
self.get_signal_output(lambda: None) # eliminate signal output
|
||||
a1 = self.Author.objects(name='Bill Shakespeare')[0]
|
||||
|
||||
|
||||
self.assertEqual(self.get_signal_output(create_author), [
|
||||
"pre_init signal, Author",
|
||||
"{'name': 'Bill Shakespeare'}",
|
||||
{'name': 'Bill Shakespeare'},
|
||||
"post_init signal, Bill Shakespeare, document._created = True",
|
||||
])
|
||||
|
||||
a1 = self.Author(name='Bill Shakespeare')
|
||||
self.assertEqual(self.get_signal_output(a1.save), [
|
||||
"pre_save signal, Bill Shakespeare",
|
||||
{},
|
||||
"pre_save_post_validation signal, Bill Shakespeare",
|
||||
"Is created",
|
||||
{},
|
||||
"post_save signal, Bill Shakespeare",
|
||||
"post_save dirty keys, ['name']",
|
||||
"Is created"
|
||||
"Is created",
|
||||
{}
|
||||
])
|
||||
|
||||
a1.reload()
|
||||
a1.name = 'William Shakespeare'
|
||||
self.assertEqual(self.get_signal_output(a1.save), [
|
||||
"pre_save signal, William Shakespeare",
|
||||
{},
|
||||
"pre_save_post_validation signal, William Shakespeare",
|
||||
"Is updated",
|
||||
{},
|
||||
"post_save signal, William Shakespeare",
|
||||
"post_save dirty keys, ['name']",
|
||||
"Is updated"
|
||||
"Is updated",
|
||||
{}
|
||||
])
|
||||
|
||||
self.assertEqual(self.get_signal_output(a1.delete), [
|
||||
'pre_delete signal, William Shakespeare',
|
||||
{},
|
||||
'post_delete signal, William Shakespeare',
|
||||
{}
|
||||
])
|
||||
|
||||
signal_output = self.get_signal_output(load_existing_author)
|
||||
# test signal_output lines separately, because of random ObjectID after object load
|
||||
self.assertEqual(signal_output[0],
|
||||
self.assertEqual(self.get_signal_output(load_existing_author), [
|
||||
"pre_init signal, Author",
|
||||
)
|
||||
self.assertEqual(signal_output[2],
|
||||
"post_init signal, Bill Shakespeare, document._created = False",
|
||||
)
|
||||
{'id': 2, 'name': 'Bill Shakespeare'},
|
||||
"post_init signal, Bill Shakespeare, document._created = False"
|
||||
])
|
||||
|
||||
|
||||
signal_output = self.get_signal_output(bulk_create_author_with_load)
|
||||
|
||||
# The output of this signal is not entirely deterministic. The reloaded
|
||||
# object will have an object ID. Hence, we only check part of the output
|
||||
self.assertEqual(signal_output[3], "pre_bulk_insert signal, [<Author: Bill Shakespeare>]"
|
||||
)
|
||||
self.assertEqual(signal_output[-2:],
|
||||
["post_bulk_insert signal, [<Author: Bill Shakespeare>]",
|
||||
"Is loaded",])
|
||||
self.assertEqual(self.get_signal_output(bulk_create_author_with_load), [
|
||||
'pre_init signal, Author',
|
||||
{'name': 'Bill Shakespeare'},
|
||||
'post_init signal, Bill Shakespeare, document._created = True',
|
||||
'pre_bulk_insert signal, [<Author: Bill Shakespeare>]',
|
||||
{},
|
||||
'pre_init signal, Author',
|
||||
{'id': 3, 'name': 'Bill Shakespeare'},
|
||||
'post_init signal, Bill Shakespeare, document._created = False',
|
||||
'post_bulk_insert signal, [<Author: Bill Shakespeare>]',
|
||||
'Is loaded',
|
||||
{}
|
||||
])
|
||||
|
||||
self.assertEqual(self.get_signal_output(bulk_create_author_without_load), [
|
||||
"pre_init signal, Author",
|
||||
"{'name': 'Bill Shakespeare'}",
|
||||
{'name': 'Bill Shakespeare'},
|
||||
"post_init signal, Bill Shakespeare, document._created = True",
|
||||
"pre_bulk_insert signal, [<Author: Bill Shakespeare>]",
|
||||
{},
|
||||
"post_bulk_insert signal, [<Author: Bill Shakespeare>]",
|
||||
"Not loaded",
|
||||
{}
|
||||
])
|
||||
|
||||
def test_signal_kwargs(self):
|
||||
""" Make sure signal_kwargs is passed to signals calls. """
|
||||
|
||||
def live_and_let_die():
|
||||
a = self.Author(name='Bill Shakespeare')
|
||||
a.save(signal_kwargs={'live': True, 'die': False})
|
||||
a.delete(signal_kwargs={'live': False, 'die': True})
|
||||
|
||||
self.assertEqual(self.get_signal_output(live_and_let_die), [
|
||||
"pre_init signal, Author",
|
||||
{'name': 'Bill Shakespeare'},
|
||||
"post_init signal, Bill Shakespeare, document._created = True",
|
||||
"pre_save signal, Bill Shakespeare",
|
||||
{'die': False, 'live': True},
|
||||
"pre_save_post_validation signal, Bill Shakespeare",
|
||||
"Is created",
|
||||
{'die': False, 'live': True},
|
||||
"post_save signal, Bill Shakespeare",
|
||||
"post_save dirty keys, ['name']",
|
||||
"Is created",
|
||||
{'die': False, 'live': True},
|
||||
'pre_delete signal, Bill Shakespeare',
|
||||
{'die': True, 'live': False},
|
||||
'post_delete signal, Bill Shakespeare',
|
||||
{'die': True, 'live': False}
|
||||
])
|
||||
|
||||
def bulk_create_author():
|
||||
a1 = self.Author(name='Bill Shakespeare')
|
||||
self.Author.objects.insert([a1], signal_kwargs={'key': True})
|
||||
|
||||
self.assertEqual(self.get_signal_output(bulk_create_author), [
|
||||
'pre_init signal, Author',
|
||||
{'name': 'Bill Shakespeare'},
|
||||
'post_init signal, Bill Shakespeare, document._created = True',
|
||||
'pre_bulk_insert signal, [<Author: Bill Shakespeare>]',
|
||||
{'key': True},
|
||||
'pre_init signal, Author',
|
||||
{'id': 2, 'name': 'Bill Shakespeare'},
|
||||
'post_init signal, Bill Shakespeare, document._created = False',
|
||||
'post_bulk_insert signal, [<Author: Bill Shakespeare>]',
|
||||
'Is loaded',
|
||||
{'key': True}
|
||||
])
|
||||
|
||||
def test_queryset_delete_signals(self):
|
||||
@@ -267,7 +374,9 @@ class SignalTests(unittest.TestCase):
|
||||
self.Another(name='Bill Shakespeare').save()
|
||||
self.assertEqual(self.get_signal_output(self.Another.objects.delete), [
|
||||
'pre_delete signal, Bill Shakespeare',
|
||||
{},
|
||||
'post_delete signal, Bill Shakespeare',
|
||||
{}
|
||||
])
|
||||
|
||||
def test_signals_with_explicit_doc_ids(self):
|
||||
@@ -306,6 +415,23 @@ class SignalTests(unittest.TestCase):
|
||||
ei.switch_db("testdb-1", keep_created=False)
|
||||
self.assertEqual(self.get_signal_output(ei.save), ['Is created'])
|
||||
|
||||
def test_signals_bulk_insert(self):
|
||||
def bulk_set_active_post():
|
||||
posts = [
|
||||
self.Post(title='Post 1'),
|
||||
self.Post(title='Post 2'),
|
||||
self.Post(title='Post 3')
|
||||
]
|
||||
self.Post.objects.insert(posts)
|
||||
|
||||
results = self.get_signal_output(bulk_set_active_post)
|
||||
self.assertEqual(results, [
|
||||
"pre_bulk_insert signal, [(<Post: Post 1>, {'active': False}), (<Post: Post 2>, {'active': False}), (<Post: Post 3>, {'active': False})]",
|
||||
{},
|
||||
"post_bulk_insert signal, [(<Post: Post 1>, {'active': True}), (<Post: Post 2>, {'active': True}), (<Post: Post 3>, {'active': True})]",
|
||||
'Is loaded',
|
||||
{}
|
||||
])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user