Improve the health of this package (#1428)

This commit is contained in:
Stefan Wójcik
2016-12-11 18:49:21 -05:00
committed by GitHub
parent 3135b456be
commit 835d3c3d18
60 changed files with 1564 additions and 1893 deletions

View File

@@ -2,4 +2,3 @@ from all_warnings import AllWarnings
from document import *
from queryset import *
from fields import *
from migration import *

View File

@@ -3,8 +3,6 @@ This test has been put into a module. This is because it tests warnings that
only get triggered on first hit. This way we can ensure its imported into the
top level and called first by the test suite.
"""
import sys
sys.path[0:0] = [""]
import unittest
import warnings

View File

@@ -1,5 +1,3 @@
import sys
sys.path[0:0] = [""]
import unittest
from class_methods import *

View File

@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import *

View File

@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest
from bson import SON

View File

@@ -1,6 +1,4 @@
import unittest
import sys
sys.path[0:0] = [""]
from mongoengine import *
from mongoengine.connection import get_db
@@ -143,11 +141,9 @@ class DynamicTest(unittest.TestCase):
def test_three_level_complex_data_lookups(self):
"""Ensure you can query three level document dynamic fields"""
p = self.Person()
p.misc = {'hello': {'hello2': 'world'}}
p.save()
# from pprint import pprint as pp; import pdb; pdb.set_trace();
print self.Person.objects(misc__hello__hello2='world')
p = self.Person.objects.create(
misc={'hello': {'hello2': 'world'}}
)
self.assertEqual(1, self.Person.objects(misc__hello__hello2='world').count())
def test_complex_embedded_document_validation(self):

View File

@@ -556,8 +556,8 @@ class IndexesTest(unittest.TestCase):
BlogPost.drop_collection()
for i in xrange(0, 10):
tags = [("tag %i" % n) for n in xrange(0, i % 2)]
for i in range(0, 10):
tags = [("tag %i" % n) for n in range(0, i % 2)]
BlogPost(tags=tags).save()
self.assertEqual(BlogPost.objects.count(), 10)

View File

@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest
import warnings
@@ -253,19 +251,17 @@ class InheritanceTest(unittest.TestCase):
self.assertEqual(classes, [Human])
def test_allow_inheritance(self):
"""Ensure that inheritance may be disabled on simple classes and that
_cls and _subclasses will not be used.
"""Ensure that inheritance is disabled by default on simple
classes and that _cls will not be used.
"""
class Animal(Document):
name = StringField()
def create_dog_class():
# can't inherit because Animal didn't explicitly allow inheritance
with self.assertRaises(ValueError):
class Dog(Animal):
pass
self.assertRaises(ValueError, create_dog_class)
# Check that _cls etc aren't present on simple documents
dog = Animal(name='dog').save()
self.assertEqual(dog.to_mongo().keys(), ['_id', 'name'])
@@ -275,17 +271,15 @@ class InheritanceTest(unittest.TestCase):
self.assertFalse('_cls' in obj)
def test_cant_turn_off_inheritance_on_subclass(self):
"""Ensure if inheritance is on in a subclass you cant turn it off
"""Ensure if inheritance is on in a subclass you cant turn it off.
"""
class Animal(Document):
name = StringField()
meta = {'allow_inheritance': True}
def create_mammal_class():
with self.assertRaises(ValueError):
class Mammal(Animal):
meta = {'allow_inheritance': False}
self.assertRaises(ValueError, create_mammal_class)
def test_allow_inheritance_abstract_document(self):
"""Ensure that abstract documents can set inheritance rules and that
@@ -298,10 +292,9 @@ class InheritanceTest(unittest.TestCase):
class Animal(FinalDocument):
name = StringField()
def create_mammal_class():
with self.assertRaises(ValueError):
class Mammal(Animal):
pass
self.assertRaises(ValueError, create_mammal_class)
# Check that _cls isn't present in simple documents
doc = Animal(name='dog')
@@ -360,29 +353,26 @@ class InheritanceTest(unittest.TestCase):
self.assertEqual(berlin.pk, berlin.auto_id_0)
def test_abstract_document_creation_does_not_fail(self):
class City(Document):
continent = StringField()
meta = {'abstract': True,
'allow_inheritance': False}
bkk = City(continent='asia')
self.assertEqual(None, bkk.pk)
# TODO: expected error? Shouldn't we create a new error type?
self.assertRaises(KeyError, lambda: setattr(bkk, 'pk', 1))
with self.assertRaises(KeyError):
setattr(bkk, 'pk', 1)
def test_allow_inheritance_embedded_document(self):
"""Ensure embedded documents respect inheritance
"""
"""Ensure embedded documents respect inheritance."""
class Comment(EmbeddedDocument):
content = StringField()
def create_special_comment():
with self.assertRaises(ValueError):
class SpecialComment(Comment):
pass
self.assertRaises(ValueError, create_special_comment)
doc = Comment(content='test')
self.assertFalse('_cls' in doc.to_mongo())
@@ -454,11 +444,11 @@ class InheritanceTest(unittest.TestCase):
self.assertEqual(Guppy._get_collection_name(), 'fish')
self.assertEqual(Human._get_collection_name(), 'human')
def create_bad_abstract():
# ensure that a subclass of a non-abstract class can't be abstract
with self.assertRaises(ValueError):
class EvilHuman(Human):
evil = BooleanField(default=True)
meta = {'abstract': True}
self.assertRaises(ValueError, create_bad_abstract)
def test_abstract_embedded_documents(self):
# 789: EmbeddedDocument shouldn't inherit abstract

View File

@@ -1,7 +1,4 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import bson
import os
import pickle
@@ -16,12 +13,12 @@ from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
PickleDynamicEmbedded, PickleDynamicTest)
from mongoengine import *
from mongoengine.base import get_document, _document_registry
from mongoengine.connection import get_db
from mongoengine.errors import (NotRegistered, InvalidDocumentError,
InvalidQueryError, NotUniqueError,
FieldDoesNotExist, SaveConditionError)
from mongoengine.queryset import NULLIFY, Q
from mongoengine.connection import get_db
from mongoengine.base import get_document
from mongoengine.context_managers import switch_db, query_counter
from mongoengine import signals
@@ -102,21 +99,18 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(options['size'], 4096)
# Check that the document cannot be redefined with different options
def recreate_log_document():
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_documents': 11,
}
# Create the collection by accessing Document.objects
Log.objects
self.assertRaises(InvalidCollectionError, recreate_log_document)
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_documents': 11,
}
Log.drop_collection()
# Accessing Document.objects creates the collection
with self.assertRaises(InvalidCollectionError):
Log.objects
def test_capped_collection_default(self):
"""Ensure that capped collections defaults work properly.
"""
"""Ensure that capped collections defaults work properly."""
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
@@ -134,16 +128,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(options['size'], 10 * 2**20)
# Check that the document with default value can be recreated
def recreate_log_document():
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_documents': 10,
}
# Create the collection by accessing Document.objects
Log.objects
recreate_log_document()
Log.drop_collection()
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_documents': 10,
}
# Create the collection by accessing Document.objects
Log.objects
def test_capped_collection_no_max_size_problems(self):
"""Ensure that capped collections with odd max_size work properly.
@@ -166,16 +158,14 @@ class InstanceTest(unittest.TestCase):
self.assertTrue(options['size'] >= 10000)
# Check that the document with odd max_size value can be recreated
def recreate_log_document():
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_size': 10000,
}
# Create the collection by accessing Document.objects
Log.objects
recreate_log_document()
Log.drop_collection()
class Log(Document):
date = DateTimeField(default=datetime.now)
meta = {
'max_size': 10000,
}
# Create the collection by accessing Document.objects
Log.objects
def test_repr(self):
"""Ensure that unicode representation works
@@ -286,7 +276,7 @@ class InstanceTest(unittest.TestCase):
list_stats = []
for i in xrange(10):
for i in range(10):
s = Stats()
s.save()
list_stats.append(s)
@@ -356,14 +346,14 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(User._fields['username'].db_field, '_id')
self.assertEqual(User._meta['id_field'], 'username')
def create_invalid_user():
User(name='test').save() # no primary key field
self.assertRaises(ValidationError, create_invalid_user)
# test no primary key field
self.assertRaises(ValidationError, User(name='test').save)
def define_invalid_user():
# define a subclass with a different primary key field than the
# parent
with self.assertRaises(ValueError):
class EmailUser(User):
email = StringField(primary_key=True)
self.assertRaises(ValueError, define_invalid_user)
class EmailUser(User):
email = StringField()
@@ -411,12 +401,10 @@ class InstanceTest(unittest.TestCase):
# Mimic Place and NicePlace definitions being in a different file
# and the NicePlace model not being imported in at query time.
from mongoengine.base import _document_registry
del(_document_registry['Place.NicePlace'])
def query_without_importing_nice_place():
print Place.objects.all()
self.assertRaises(NotRegistered, query_without_importing_nice_place)
with self.assertRaises(NotRegistered):
list(Place.objects.all())
def test_document_registry_regressions(self):
@@ -745,7 +733,7 @@ class InstanceTest(unittest.TestCase):
try:
t.save()
except ValidationError, e:
except ValidationError as e:
expect_msg = "Draft entries may not have a publication date."
self.assertTrue(expect_msg in e.message)
self.assertEqual(e.to_dict(), {'__all__': expect_msg})
@@ -784,7 +772,7 @@ class InstanceTest(unittest.TestCase):
t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15))
try:
t.save()
except ValidationError, e:
except ValidationError as e:
expect_msg = "Value of z != x + y"
self.assertTrue(expect_msg in e.message)
self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}})
@@ -798,8 +786,10 @@ class InstanceTest(unittest.TestCase):
def test_modify_empty(self):
doc = self.Person(name="bob", age=10).save()
self.assertRaises(
InvalidDocumentError, lambda: self.Person().modify(set__age=10))
with self.assertRaises(InvalidDocumentError):
self.Person().modify(set__age=10)
self.assertDbEqual([dict(doc.to_mongo())])
def test_modify_invalid_query(self):
@@ -807,9 +797,8 @@ class InstanceTest(unittest.TestCase):
doc2 = self.Person(name="jim", age=20).save()
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
self.assertRaises(
InvalidQueryError,
lambda: doc1.modify(dict(id=doc2.id), set__value=20))
with self.assertRaises(InvalidQueryError):
doc1.modify({'id': doc2.id}, set__value=20)
self.assertDbEqual(docs)
@@ -818,7 +807,7 @@ class InstanceTest(unittest.TestCase):
doc2 = self.Person(name="jim", age=20).save()
docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())]
assert not doc1.modify(dict(name=doc2.name), set__age=100)
assert not doc1.modify({'name': doc2.name}, set__age=100)
self.assertDbEqual(docs)
@@ -827,7 +816,7 @@ class InstanceTest(unittest.TestCase):
doc2 = self.Person(id=ObjectId(), name="jim", age=20)
docs = [dict(doc1.to_mongo())]
assert not doc2.modify(dict(name=doc2.name), set__age=100)
assert not doc2.modify({'name': doc2.name}, set__age=100)
self.assertDbEqual(docs)
@@ -1293,12 +1282,11 @@ class InstanceTest(unittest.TestCase):
def test_document_update(self):
def update_not_saved_raises():
# try updating a non-saved document
with self.assertRaises(OperationError):
person = self.Person(name='dcrosta')
person.update(set__name='Dan Crosta')
self.assertRaises(OperationError, update_not_saved_raises)
author = self.Person(name='dcrosta')
author.save()
@@ -1308,19 +1296,17 @@ class InstanceTest(unittest.TestCase):
p1 = self.Person.objects.first()
self.assertEqual(p1.name, author.name)
def update_no_value_raises():
# try sending an empty update
with self.assertRaises(OperationError):
person = self.Person.objects.first()
person.update()
self.assertRaises(OperationError, update_no_value_raises)
def update_no_op_should_default_to_set():
person = self.Person.objects.first()
person.update(name="Dan")
person.reload()
return person.name
self.assertEqual("Dan", update_no_op_should_default_to_set())
# update that doesn't explicitly specify an operator should default
# to 'set__'
person = self.Person.objects.first()
person.update(name="Dan")
person.reload()
self.assertEqual("Dan", person.name)
def test_update_unique_field(self):
class Doc(Document):
@@ -1329,8 +1315,8 @@ class InstanceTest(unittest.TestCase):
doc1 = Doc(name="first").save()
doc2 = Doc(name="second").save()
self.assertRaises(NotUniqueError, lambda:
doc2.update(set__name=doc1.name))
with self.assertRaises(NotUniqueError):
doc2.update(set__name=doc1.name)
def test_embedded_update(self):
"""
@@ -1848,15 +1834,13 @@ class InstanceTest(unittest.TestCase):
def test_duplicate_db_fields_raise_invalid_document_error(self):
"""Ensure a InvalidDocumentError is thrown if duplicate fields
declare the same db_field"""
def throw_invalid_document_error():
declare the same db_field.
"""
with self.assertRaises(InvalidDocumentError):
class Foo(Document):
name = StringField()
name2 = StringField(db_field='name')
self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
def test_invalid_son(self):
"""Raise an error if loading invalid data"""
class Occurrence(EmbeddedDocument):
@@ -1868,11 +1852,13 @@ class InstanceTest(unittest.TestCase):
forms = ListField(StringField(), default=list)
occurs = ListField(EmbeddedDocumentField(Occurrence), default=list)
def raise_invalid_document():
Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one',
'occurs': {"hello": None}})
self.assertRaises(InvalidDocumentError, raise_invalid_document)
with self.assertRaises(InvalidDocumentError):
Word._from_son({
'stem': [1, 2, 3],
'forms': 1,
'count': 'one',
'occurs': {"hello": None}
})
def test_reverse_delete_rule_cascade_and_nullify(self):
"""Ensure that a referenced document is also deleted upon deletion.
@@ -2103,8 +2089,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(Bar.objects.get().foo, None)
def test_invalid_reverse_delete_rule_raise_errors(self):
def throw_invalid_document_error():
with self.assertRaises(InvalidDocumentError):
class Blog(Document):
content = StringField()
authors = MapField(ReferenceField(
@@ -2114,21 +2099,15 @@ class InstanceTest(unittest.TestCase):
self.Person,
reverse_delete_rule=NULLIFY))
self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
def throw_invalid_document_error_embedded():
with self.assertRaises(InvalidDocumentError):
class Parents(EmbeddedDocument):
father = ReferenceField('Person', reverse_delete_rule=DENY)
mother = ReferenceField('Person', reverse_delete_rule=DENY)
self.assertRaises(
InvalidDocumentError, throw_invalid_document_error_embedded)
def test_reverse_delete_rule_cascade_recurs(self):
"""Ensure that a chain of documents is also deleted upon cascaded
deletion.
"""
class BlogPost(Document):
content = StringField()
author = ReferenceField(self.Person, reverse_delete_rule=CASCADE)
@@ -2344,15 +2323,14 @@ class InstanceTest(unittest.TestCase):
pickle_doc.save()
pickle_doc.delete()
def test_throw_invalid_document_error(self):
# test handles people trying to upsert
def throw_invalid_document_error():
def test_override_method_with_field(self):
"""Test creating a field with a field name that would override
the "validate" method.
"""
with self.assertRaises(InvalidDocumentError):
class Blog(Document):
validate = DictField()
self.assertRaises(InvalidDocumentError, throw_invalid_document_error)
def test_mutating_documents(self):
class B(EmbeddedDocument):
@@ -2815,11 +2793,10 @@ class InstanceTest(unittest.TestCase):
log.log = "Saving"
log.save()
def change_shard_key():
# try to change the shard key
with self.assertRaises(OperationError):
log.machine = "127.0.0.1"
self.assertRaises(OperationError, change_shard_key)
def test_shard_key_in_embedded_document(self):
class Foo(EmbeddedDocument):
foo = StringField()
@@ -2840,12 +2817,11 @@ class InstanceTest(unittest.TestCase):
bar_doc.bar = 'baz'
bar_doc.save()
def change_shard_key():
# try to change the shard key
with self.assertRaises(OperationError):
bar_doc.foo.foo = 'something'
bar_doc.save()
self.assertRaises(OperationError, change_shard_key)
def test_shard_key_primary(self):
class LogEntry(Document):
machine = StringField(primary_key=True)
@@ -2866,11 +2842,10 @@ class InstanceTest(unittest.TestCase):
log.log = "Saving"
log.save()
def change_shard_key():
# try to change the shard key
with self.assertRaises(OperationError):
log.machine = "127.0.0.1"
self.assertRaises(OperationError, change_shard_key)
def test_kwargs_simple(self):
class Embedded(EmbeddedDocument):
@@ -2955,11 +2930,9 @@ class InstanceTest(unittest.TestCase):
def test_bad_mixed_creation(self):
"""Ensure that document gives correct error when duplicating arguments
"""
def construct_bad_instance():
with self.assertRaises(TypeError):
return self.Person("Test User", 42, name="Bad User")
self.assertRaises(TypeError, construct_bad_instance)
def test_data_contains_id_field(self):
"""Ensure that asking for _data returns 'id'
"""
@@ -3118,17 +3091,17 @@ class InstanceTest(unittest.TestCase):
p4 = Person.objects()[0]
p4.save()
self.assertEquals(p4.height, 189)
# However the default will not be fixed in DB
self.assertEquals(Person.objects(height=189).count(), 0)
# alter DB for the new default
coll = Person._get_collection()
for person in Person.objects.as_pymongo():
if 'height' not in person:
person['height'] = 189
coll.save(person)
self.assertEquals(Person.objects(height=189).count(), 1)
def test_from_son(self):

View File

@@ -1,6 +1,3 @@
import sys
sys.path[0:0] = [""]
import unittest
import uuid

View File

@@ -1,7 +1,4 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest
from datetime import datetime
@@ -60,7 +57,7 @@ class ValidatorErrorTest(unittest.TestCase):
try:
User().validate()
except ValidationError, e:
except ValidationError as e:
self.assertTrue("User:None" in e.message)
self.assertEqual(e.to_dict(), {
'username': 'Field is required',
@@ -70,7 +67,7 @@ class ValidatorErrorTest(unittest.TestCase):
user.name = None
try:
user.save()
except ValidationError, e:
except ValidationError as e:
self.assertTrue("User:RossC0" in e.message)
self.assertEqual(e.to_dict(), {
'name': 'Field is required'})
@@ -118,7 +115,7 @@ class ValidatorErrorTest(unittest.TestCase):
try:
Doc(id="bad").validate()
except ValidationError, e:
except ValidationError as e:
self.assertTrue("SubDoc:None" in e.message)
self.assertEqual(e.to_dict(), {
"e": {'val': 'OK could not be converted to int'}})
@@ -136,7 +133,7 @@ class ValidatorErrorTest(unittest.TestCase):
doc.e.val = "OK"
try:
doc.save()
except ValidationError, e:
except ValidationError as e:
self.assertTrue("Doc:test" in e.message)
self.assertEqual(e.to_dict(), {
"e": {'val': 'OK could not be converted to int'}})
@@ -156,14 +153,14 @@ class ValidatorErrorTest(unittest.TestCase):
s = SubDoc()
self.assertRaises(ValidationError, lambda: s.validate())
self.assertRaises(ValidationError, s.validate)
d1.e = s
d2.e = s
del d1
self.assertRaises(ValidationError, lambda: d2.validate())
self.assertRaises(ValidationError, d2.validate)
def test_parent_reference_in_child_document(self):
"""

View File

@@ -1,11 +1,7 @@
# -*- coding: utf-8 -*-
import sys
import six
from nose.plugins.skip import SkipTest
sys.path[0:0] = [""]
import datetime
import unittest
import uuid
@@ -29,10 +25,9 @@ except ImportError:
from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.base import _document_registry
from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList
from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList,
_document_registry)
from mongoengine.errors import NotRegistered, DoesNotExist
from mongoengine.python_support import PY3, b, bin_type
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
@@ -653,8 +648,8 @@ class FieldTest(unittest.TestCase):
# Post UTC - microseconds are rounded (down) nearest millisecond and
# dropped
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01)
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
d2 = datetime.datetime(1970, 1, 1, 0, 0, 1)
log = LogEntry()
log.date = d1
log.save()
@@ -663,15 +658,15 @@ class FieldTest(unittest.TestCase):
self.assertEqual(log.date, d2)
# Post UTC - microseconds are rounded (down) nearest millisecond
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
d2 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9000)
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999)
d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000)
log.date = d1
log.save()
log.reload()
self.assertNotEqual(log.date, d1)
self.assertEqual(log.date, d2)
if not PY3:
if not six.PY3:
# Pre UTC dates microseconds below 1000 are dropped
# This does not seem to be true in PY3
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999)
@@ -691,7 +686,7 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection()
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01)
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1)
log = LogEntry()
log.date = d1
log.validate()
@@ -708,8 +703,8 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection()
# create 60 log entries
for i in xrange(1950, 2010):
d = datetime.datetime(i, 01, 01, 00, 00, 01)
for i in range(1950, 2010):
d = datetime.datetime(i, 1, 1, 0, 0, 1)
LogEntry(date=d).save()
self.assertEqual(LogEntry.objects.count(), 60)
@@ -756,7 +751,7 @@ class FieldTest(unittest.TestCase):
# Post UTC - microseconds are rounded (down) nearest millisecond and
# dropped - with default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
log = LogEntry()
log.date = d1
log.save()
@@ -765,7 +760,7 @@ class FieldTest(unittest.TestCase):
# Post UTC - microseconds are rounded (down) nearest millisecond - with
# default datetimefields
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999)
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999)
log.date = d1
log.save()
log.reload()
@@ -782,7 +777,7 @@ class FieldTest(unittest.TestCase):
# Pre UTC microseconds above 1000 is wonky - with default datetimefields
# log.date has an invalid microsecond value so I can't construct
# a date to compare.
for i in xrange(1001, 3113, 33):
for i in range(1001, 3113, 33):
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i)
log.date = d1
log.save()
@@ -792,7 +787,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(log, log1)
# Test string padding
microsecond = map(int, [math.pow(10, x) for x in xrange(6)])
microsecond = map(int, [math.pow(10, x) for x in range(6)])
mm = dd = hh = ii = ss = [1, 10]
for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond):
@@ -814,7 +809,7 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection()
d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999)
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999)
log = LogEntry()
log.date = d1
log.save()
@@ -825,8 +820,8 @@ class FieldTest(unittest.TestCase):
LogEntry.drop_collection()
# create 60 log entries
for i in xrange(1950, 2010):
d = datetime.datetime(i, 01, 01, 00, 00, 01, 999)
for i in range(1950, 2010):
d = datetime.datetime(i, 1, 1, 0, 0, 1, 999)
LogEntry(date=d).save()
self.assertEqual(LogEntry.objects.count(), 60)
@@ -1134,12 +1129,11 @@ class FieldTest(unittest.TestCase):
e.mapping = [1]
e.save()
def create_invalid_mapping():
# try creating an invalid mapping
with self.assertRaises(ValidationError):
e.mapping = ["abc"]
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Simple.drop_collection()
def test_list_field_rejects_strings(self):
@@ -1406,12 +1400,11 @@ class FieldTest(unittest.TestCase):
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
# try creating an invalid mapping
with self.assertRaises(ValidationError):
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Simple.drop_collection()
def test_dictfield_complex(self):
@@ -1484,11 +1477,10 @@ class FieldTest(unittest.TestCase):
self.assertEqual(BaseDict, type(e.mapping))
self.assertEqual({"ints": [3, 4]}, e.mapping)
def create_invalid_mapping():
# try creating an invalid mapping
with self.assertRaises(ValueError):
e.update(set__mapping={"somestrings": ["foo", "bar", ]})
self.assertRaises(ValueError, create_invalid_mapping)
Simple.drop_collection()
def test_mapfield(self):
@@ -1503,18 +1495,14 @@ class FieldTest(unittest.TestCase):
e.mapping['someint'] = 1
e.save()
def create_invalid_mapping():
with self.assertRaises(ValidationError):
e.mapping['somestring'] = "abc"
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
def create_invalid_class():
with self.assertRaises(ValidationError):
class NoDeclaredType(Document):
mapping = MapField()
self.assertRaises(ValidationError, create_invalid_class)
Simple.drop_collection()
def test_complex_mapfield(self):
@@ -1543,14 +1531,10 @@ class FieldTest(unittest.TestCase):
self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting))
self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting))
def create_invalid_mapping():
with self.assertRaises(ValidationError):
e.mapping['someint'] = 123
e.save()
self.assertRaises(ValidationError, create_invalid_mapping)
Extensible.drop_collection()
def test_embedded_mapfield_db_field(self):
class Embedded(EmbeddedDocument):
@@ -1760,8 +1744,8 @@ class FieldTest(unittest.TestCase):
# Reference is no longer valid
foo.delete()
bar = Bar.objects.get()
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref'))
self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref'))
self.assertRaises(DoesNotExist, getattr, bar, 'ref')
self.assertRaises(DoesNotExist, getattr, bar, 'generic_ref')
# When auto_dereference is disabled, there is no trouble returning DBRef
bar = Bar.objects.get()
@@ -2036,7 +2020,7 @@ class FieldTest(unittest.TestCase):
})
def test_cached_reference_fields_on_embedded_documents(self):
def build():
with self.assertRaises(InvalidDocumentError):
class Test(Document):
name = StringField()
@@ -2045,8 +2029,6 @@ class FieldTest(unittest.TestCase):
'test': CachedReferenceField(Test)
})
self.assertRaises(InvalidDocumentError, build)
def test_cached_reference_auto_sync(self):
class Person(Document):
TYPES = (
@@ -2863,7 +2845,7 @@ class FieldTest(unittest.TestCase):
content_type = StringField()
blob = BinaryField()
BLOB = b('\xe6\x00\xc4\xff\x07')
BLOB = six.b('\xe6\x00\xc4\xff\x07')
MIME_TYPE = 'application/octet-stream'
Attachment.drop_collection()
@@ -2873,7 +2855,7 @@ class FieldTest(unittest.TestCase):
attachment_1 = Attachment.objects().first()
self.assertEqual(MIME_TYPE, attachment_1.content_type)
self.assertEqual(BLOB, bin_type(attachment_1.blob))
self.assertEqual(BLOB, six.binary_type(attachment_1.blob))
Attachment.drop_collection()
@@ -2900,13 +2882,13 @@ class FieldTest(unittest.TestCase):
attachment_required = AttachmentRequired()
self.assertRaises(ValidationError, attachment_required.validate)
attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07'))
attachment_required.blob = Binary(six.b('\xe6\x00\xc4\xff\x07'))
attachment_required.validate()
attachment_size_limit = AttachmentSizeLimit(
blob=b('\xe6\x00\xc4\xff\x07'))
blob=six.b('\xe6\x00\xc4\xff\x07'))
self.assertRaises(ValidationError, attachment_size_limit.validate)
attachment_size_limit.blob = b('\xe6\x00\xc4\xff')
attachment_size_limit.blob = six.b('\xe6\x00\xc4\xff')
attachment_size_limit.validate()
Attachment.drop_collection()
@@ -3152,7 +3134,7 @@ class FieldTest(unittest.TestCase):
try:
shirt.validate()
except ValidationError, error:
except ValidationError as error:
# get the validation rules
error_dict = error.to_dict()
self.assertEqual(error_dict['size'], SIZE_MESSAGE)
@@ -3181,7 +3163,7 @@ class FieldTest(unittest.TestCase):
self.db['mongoengine.counters'].drop()
Person.drop_collection()
for x in xrange(10):
for x in range(10):
Person(name="Person %s" % x).save()
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
@@ -3205,7 +3187,7 @@ class FieldTest(unittest.TestCase):
self.db['mongoengine.counters'].drop()
Person.drop_collection()
for x in xrange(10):
for x in range(10):
Person(name="Person %s" % x).save()
self.assertEqual(Person.id.get_next_value(), 11)
@@ -3220,7 +3202,7 @@ class FieldTest(unittest.TestCase):
self.db['mongoengine.counters'].drop()
Person.drop_collection()
for x in xrange(10):
for x in range(10):
Person(name="Person %s" % x).save()
self.assertEqual(Person.id.get_next_value(), '11')
@@ -3236,7 +3218,7 @@ class FieldTest(unittest.TestCase):
self.db['mongoengine.counters'].drop()
Person.drop_collection()
for x in xrange(10):
for x in range(10):
Person(name="Person %s" % x).save()
c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'})
@@ -3261,7 +3243,7 @@ class FieldTest(unittest.TestCase):
self.db['mongoengine.counters'].drop()
Person.drop_collection()
for x in xrange(10):
for x in range(10):
Person(name="Person %s" % x).save()
c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'})
@@ -3323,7 +3305,7 @@ class FieldTest(unittest.TestCase):
Animal.drop_collection()
Person.drop_collection()
for x in xrange(10):
for x in range(10):
Animal(name="Animal %s" % x).save()
Person(name="Person %s" % x).save()
@@ -3353,7 +3335,7 @@ class FieldTest(unittest.TestCase):
self.db['mongoengine.counters'].drop()
Person.drop_collection()
for x in xrange(10):
for x in range(10):
p = Person(name="Person %s" % x)
p.save()
@@ -3540,7 +3522,7 @@ class FieldTest(unittest.TestCase):
self.assertRaises(ValidationError, post.validate)
try:
post.validate()
except ValidationError, error:
except ValidationError as error:
# ValidationError.errors property
self.assertTrue(hasattr(error, 'errors'))
self.assertTrue(isinstance(error.errors, dict))
@@ -3601,8 +3583,6 @@ class FieldTest(unittest.TestCase):
Ensure that tuples remain tuples when they are
inside a ComplexBaseField
"""
from mongoengine.base import BaseField
class EnumField(BaseField):
def __init__(self, **kwargs):
@@ -3836,9 +3816,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
filtered = self.post1.comments.filter()
# Ensure nothing was changed
# < 2.6 Incompatible >
# self.assertListEqual(filtered, self.post1.comments)
self.assertEqual(filtered, self.post1.comments)
self.assertListEqual(filtered, self.post1.comments)
def test_single_keyword_filter(self):
"""
@@ -3889,10 +3867,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
Tests the filter method of a List of Embedded Documents
when the keyword is not a known keyword.
"""
# < 2.6 Incompatible >
# with self.assertRaises(AttributeError):
# self.post2.comments.filter(year=2)
self.assertRaises(AttributeError, self.post2.comments.filter, year=2)
with self.assertRaises(AttributeError):
self.post2.comments.filter(year=2)
def test_no_keyword_exclude(self):
"""
@@ -3902,9 +3878,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
filtered = self.post1.comments.exclude()
# Ensure everything was removed
# < 2.6 Incompatible >
# self.assertListEqual(filtered, [])
self.assertEqual(filtered, [])
self.assertListEqual(filtered, [])
def test_single_keyword_exclude(self):
"""
@@ -3950,10 +3924,8 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
Tests the exclude method of a List of Embedded Documents
when the keyword is not a known keyword.
"""
# < 2.6 Incompatible >
# with self.assertRaises(AttributeError):
# self.post2.comments.exclude(year=2)
self.assertRaises(AttributeError, self.post2.comments.exclude, year=2)
with self.assertRaises(AttributeError):
self.post2.comments.exclude(year=2)
def test_chained_filter_exclude(self):
"""
@@ -3991,10 +3963,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
single keyword.
"""
comment = self.post1.comments.get(author='user1')
# < 2.6 Incompatible >
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertIsInstance(comment, self.Comments)
self.assertEqual(comment.author, 'user1')
def test_multi_keyword_get(self):
@@ -4003,10 +3972,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
multiple keywords.
"""
comment = self.post2.comments.get(author='user2', message='message2')
# < 2.6 Incompatible >
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertIsInstance(comment, self.Comments)
self.assertEqual(comment.author, 'user2')
self.assertEqual(comment.message, 'message2')
@@ -4015,44 +3981,32 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
Tests the get method of a List of Embedded Documents without
a keyword to return multiple documents.
"""
# < 2.6 Incompatible >
# with self.assertRaises(MultipleObjectsReturned):
# self.post1.comments.get()
self.assertRaises(MultipleObjectsReturned, self.post1.comments.get)
with self.assertRaises(MultipleObjectsReturned):
self.post1.comments.get()
def test_keyword_multiple_return_get(self):
"""
Tests the get method of a List of Embedded Documents with a keyword
to return multiple documents.
"""
# < 2.6 Incompatible >
# with self.assertRaises(MultipleObjectsReturned):
# self.post2.comments.get(author='user2')
self.assertRaises(
MultipleObjectsReturned, self.post2.comments.get, author='user2'
)
with self.assertRaises(MultipleObjectsReturned):
self.post2.comments.get(author='user2')
def test_unknown_keyword_get(self):
"""
Tests the get method of a List of Embedded Documents with an
unknown keyword.
"""
# < 2.6 Incompatible >
# with self.assertRaises(AttributeError):
# self.post2.comments.get(year=2020)
self.assertRaises(AttributeError, self.post2.comments.get, year=2020)
with self.assertRaises(AttributeError):
self.post2.comments.get(year=2020)
def test_no_result_get(self):
"""
Tests the get method of a List of Embedded Documents where get
returns no results.
"""
# < 2.6 Incompatible >
# with self.assertRaises(DoesNotExist):
# self.post1.comments.get(author='user3')
self.assertRaises(
DoesNotExist, self.post1.comments.get, author='user3'
)
with self.assertRaises(DoesNotExist):
self.post1.comments.get(author='user3')
def test_first(self):
"""
@@ -4062,9 +4016,7 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
comment = self.post1.comments.first()
# Ensure a Comment object was returned.
# < 2.6 Incompatible >
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertIsInstance(comment, self.Comments)
self.assertEqual(comment, self.post1.comments[0])
def test_create(self):
@@ -4077,22 +4029,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.post1.save()
# Ensure the returned value is the comment object.
# < 2.6 Incompatible >
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertIsInstance(comment, self.Comments)
self.assertEqual(comment.author, 'user4')
self.assertEqual(comment.message, 'message1')
# Ensure the new comment was actually saved to the database.
# < 2.6 Incompatible >
# self.assertIn(
# comment,
# self.BlogPost.objects(comments__author='user4')[0].comments
# )
self.assertTrue(
comment in self.BlogPost.objects(
comments__author='user4'
)[0].comments
self.assertIn(
comment,
self.BlogPost.objects(comments__author='user4')[0].comments
)
def test_filtered_create(self):
@@ -4107,22 +4051,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.post1.save()
# Ensure the returned value is the comment object.
# < 2.6 Incompatible >
# self.assertIsInstance(comment, self.Comments)
self.assertTrue(isinstance(comment, self.Comments))
self.assertIsInstance(comment, self.Comments)
self.assertEqual(comment.author, 'user4')
self.assertEqual(comment.message, 'message1')
# Ensure the new comment was actually saved to the database.
# < 2.6 Incompatible >
# self.assertIn(
# comment,
# self.BlogPost.objects(comments__author='user4')[0].comments
# )
self.assertTrue(
comment in self.BlogPost.objects(
comments__author='user4'
)[0].comments
self.assertIn(
comment,
self.BlogPost.objects(comments__author='user4')[0].comments
)
def test_no_keyword_update(self):
@@ -4135,22 +4071,14 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.post1.save()
# Ensure that nothing was altered.
# < 2.6 Incompatible >
# self.assertIn(
# original[0],
# self.BlogPost.objects(id=self.post1.id)[0].comments
# )
self.assertTrue(
original[0] in self.BlogPost.objects(id=self.post1.id)[0].comments
self.assertIn(
original[0],
self.BlogPost.objects(id=self.post1.id)[0].comments
)
# < 2.6 Incompatible >
# self.assertIn(
# original[1],
# self.BlogPost.objects(id=self.post1.id)[0].comments
# )
self.assertTrue(
original[1] in self.BlogPost.objects(id=self.post1.id)[0].comments
self.assertIn(
original[1],
self.BlogPost.objects(id=self.post1.id)[0].comments
)
# Ensure the method returned 0 as the number of entries
@@ -4196,13 +4124,9 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
comments.save()
# Ensure that the new comment has been added to the database.
# < 2.6 Incompatible >
# self.assertIn(
# new_comment,
# self.BlogPost.objects(id=self.post1.id)[0].comments
# )
self.assertTrue(
new_comment in self.BlogPost.objects(id=self.post1.id)[0].comments
self.assertIn(
new_comment,
self.BlogPost.objects(id=self.post1.id)[0].comments
)
def test_delete(self):
@@ -4214,23 +4138,15 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
# Ensure that all the comments under post1 were deleted in the
# database.
# < 2.6 Incompatible >
# self.assertListEqual(
# self.BlogPost.objects(id=self.post1.id)[0].comments, []
# )
self.assertEqual(
self.assertListEqual(
self.BlogPost.objects(id=self.post1.id)[0].comments, []
)
# Ensure that post1 comments were deleted from the list.
# < 2.6 Incompatible >
# self.assertListEqual(self.post1.comments, [])
self.assertEqual(self.post1.comments, [])
self.assertListEqual(self.post1.comments, [])
# Ensure that comments still returned a EmbeddedDocumentList object.
# < 2.6 Incompatible >
# self.assertIsInstance(self.post1.comments, EmbeddedDocumentList)
self.assertTrue(isinstance(self.post1.comments, EmbeddedDocumentList))
self.assertIsInstance(self.post1.comments, EmbeddedDocumentList)
# Ensure that the delete method returned 2 as the number of entries
# deleted from the database
@@ -4270,21 +4186,15 @@ class EmbeddedDocumentListFieldTestCase(unittest.TestCase):
self.post1.save()
# Ensure that only the user2 comment was deleted.
# < 2.6 Incompatible >
# self.assertNotIn(
# comment, self.BlogPost.objects(id=self.post1.id)[0].comments
# )
self.assertTrue(
comment not in self.BlogPost.objects(id=self.post1.id)[0].comments
self.assertNotIn(
comment, self.BlogPost.objects(id=self.post1.id)[0].comments
)
self.assertEqual(
len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1
)
# Ensure that the user2 comment no longer exists in the list.
# < 2.6 Incompatible >
# self.assertNotIn(comment, self.post1.comments)
self.assertTrue(comment not in self.post1.comments)
self.assertNotIn(comment, self.post1.comments)
self.assertEqual(len(self.post1.comments), 1)
# Ensure that the delete method returned 1 as the number of entries

View File

@@ -1,18 +1,16 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import copy
import os
import unittest
import tempfile
import gridfs
import six
from nose.plugins.skip import SkipTest
from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.python_support import b, StringIO
from mongoengine.python_support import StringIO
try:
from PIL import Image
@@ -49,7 +47,7 @@ class FileTest(unittest.TestCase):
PutFile.drop_collection()
text = b('Hello, World!')
text = six.b('Hello, World!')
content_type = 'text/plain'
putfile = PutFile()
@@ -88,8 +86,8 @@ class FileTest(unittest.TestCase):
StreamFile.drop_collection()
text = b('Hello, World!')
more_text = b('Foo Bar')
text = six.b('Hello, World!')
more_text = six.b('Foo Bar')
content_type = 'text/plain'
streamfile = StreamFile()
@@ -123,8 +121,8 @@ class FileTest(unittest.TestCase):
StreamFile.drop_collection()
text = b('Hello, World!')
more_text = b('Foo Bar')
text = six.b('Hello, World!')
more_text = six.b('Foo Bar')
content_type = 'text/plain'
streamfile = StreamFile()
@@ -155,8 +153,8 @@ class FileTest(unittest.TestCase):
class SetFile(Document):
the_file = FileField()
text = b('Hello, World!')
more_text = b('Foo Bar')
text = six.b('Hello, World!')
more_text = six.b('Foo Bar')
SetFile.drop_collection()
@@ -185,7 +183,7 @@ class FileTest(unittest.TestCase):
GridDocument.drop_collection()
with tempfile.TemporaryFile() as f:
f.write(b("Hello World!"))
f.write(six.b("Hello World!"))
f.flush()
# Test without default
@@ -202,7 +200,7 @@ class FileTest(unittest.TestCase):
self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id)
# Test with default
doc_d = GridDocument(the_file=b(''))
doc_d = GridDocument(the_file=six.b(''))
doc_d.save()
doc_e = GridDocument.objects.with_id(doc_d.id)
@@ -228,7 +226,7 @@ class FileTest(unittest.TestCase):
# First instance
test_file = TestFile()
test_file.name = "Hello, World!"
test_file.the_file.put(b('Hello, World!'))
test_file.the_file.put(six.b('Hello, World!'))
test_file.save()
# Second instance
@@ -282,7 +280,7 @@ class FileTest(unittest.TestCase):
test_file = TestFile()
self.assertFalse(bool(test_file.the_file))
test_file.the_file.put(b('Hello, World!'), content_type='text/plain')
test_file.the_file.put(six.b('Hello, World!'), content_type='text/plain')
test_file.save()
self.assertTrue(bool(test_file.the_file))
@@ -297,66 +295,66 @@ class FileTest(unittest.TestCase):
test_file = TestFile()
self.assertFalse(test_file.the_file in [{"test": 1}])
def test_file_disk_space(self):
""" Test disk space usage when we delete/replace a file """
def test_file_disk_space(self):
""" Test disk space usage when we delete/replace a file """
class TestFile(Document):
the_file = FileField()
text = b('Hello, World!')
text = six.b('Hello, World!')
content_type = 'text/plain'
testfile = TestFile()
testfile.the_file.put(text, content_type=content_type, filename="hello")
testfile.save()
# Now check fs.files and fs.chunks
# Now check fs.files and fs.chunks
db = TestFile._get_db()
files = db.fs.files.find()
chunks = db.fs.chunks.find()
self.assertEquals(len(list(files)), 1)
self.assertEquals(len(list(chunks)), 1)
# Deleting the docoument should delete the files
# Deleting the docoument should delete the files
testfile.delete()
files = db.fs.files.find()
chunks = db.fs.chunks.find()
self.assertEquals(len(list(files)), 0)
self.assertEquals(len(list(chunks)), 0)
# Test case where we don't store a file in the first place
# Test case where we don't store a file in the first place
testfile = TestFile()
testfile.save()
files = db.fs.files.find()
chunks = db.fs.chunks.find()
self.assertEquals(len(list(files)), 0)
self.assertEquals(len(list(chunks)), 0)
testfile.delete()
files = db.fs.files.find()
chunks = db.fs.chunks.find()
self.assertEquals(len(list(files)), 0)
self.assertEquals(len(list(chunks)), 0)
# Test case where we overwrite the file
# Test case where we overwrite the file
testfile = TestFile()
testfile.the_file.put(text, content_type=content_type, filename="hello")
testfile.save()
text = b('Bonjour, World!')
text = six.b('Bonjour, World!')
testfile.the_file.replace(text, content_type=content_type, filename="hello")
testfile.save()
files = db.fs.files.find()
chunks = db.fs.chunks.find()
self.assertEquals(len(list(files)), 1)
self.assertEquals(len(list(chunks)), 1)
testfile.delete()
files = db.fs.files.find()
chunks = db.fs.chunks.find()
self.assertEquals(len(list(files)), 0)
@@ -372,14 +370,14 @@ class FileTest(unittest.TestCase):
TestImage.drop_collection()
with tempfile.TemporaryFile() as f:
f.write(b("Hello World!"))
f.write(six.b("Hello World!"))
f.flush()
t = TestImage()
try:
t.image.put(f)
self.fail("Should have raised an invalidation error")
except ValidationError, e:
except ValidationError as e:
self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f)
t = TestImage()
@@ -496,7 +494,7 @@ class FileTest(unittest.TestCase):
# First instance
test_file = TestFile()
test_file.name = "Hello, World!"
test_file.the_file.put(b('Hello, World!'),
test_file.the_file.put(six.b('Hello, World!'),
name="hello.txt")
test_file.save()
@@ -504,16 +502,15 @@ class FileTest(unittest.TestCase):
self.assertEqual(data.get('name'), 'hello.txt')
test_file = TestFile.objects.first()
self.assertEqual(test_file.the_file.read(),
b('Hello, World!'))
self.assertEqual(test_file.the_file.read(), six.b('Hello, World!'))
test_file = TestFile.objects.first()
test_file.the_file = b('HELLO, WORLD!')
test_file.the_file = six.b('HELLO, WORLD!')
test_file.save()
test_file = TestFile.objects.first()
self.assertEqual(test_file.the_file.read(),
b('HELLO, WORLD!'))
six.b('HELLO, WORLD!'))
def test_copyable(self):
class PutFile(Document):
@@ -521,7 +518,7 @@ class FileTest(unittest.TestCase):
PutFile.drop_collection()
text = b('Hello, World!')
text = six.b('Hello, World!')
content_type = 'text/plain'
putfile = PutFile()

View File

@@ -1,7 +1,4 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import *

View File

@@ -1,11 +0,0 @@
import unittest
from convert_to_new_inheritance_model import *
from decimalfield_as_float import *
from referencefield_dbref_to_object_id import *
from turn_off_inheritance import *
from uuidfield_to_binary import *
if __name__ == '__main__':
unittest.main()

View File

@@ -1,51 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
from mongoengine import Document, connect
from mongoengine.connection import get_db
from mongoengine.fields import StringField
__all__ = ('ConvertToNewInheritanceModel', )
class ConvertToNewInheritanceModel(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
def tearDown(self):
for collection in self.db.collection_names():
if 'system.' in collection:
continue
self.db.drop_collection(collection)
def test_how_to_convert_to_the_new_inheritance_model(self):
"""Demonstrates migrating from 0.7 to 0.8
"""
# 1. Declaration of the class
class Animal(Document):
name = StringField()
meta = {
'allow_inheritance': True,
'indexes': ['name']
}
# 2. Remove _types
collection = Animal._get_collection()
collection.update({}, {"$unset": {"_types": 1}}, multi=True)
# 3. Confirm extra data is removed
count = collection.find({'_types': {"$exists": True}}).count()
self.assertEqual(0, count)
# 4. Remove indexes
info = collection.index_information()
indexes_to_drop = [key for key, value in info.iteritems()
if '_types' in dict(value['key'])]
for index in indexes_to_drop:
collection.drop_index(index)
# 5. Recreate indexes
Animal.ensure_indexes()

View File

@@ -1,50 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
import decimal
from decimal import Decimal
from mongoengine import Document, connect
from mongoengine.connection import get_db
from mongoengine.fields import StringField, DecimalField, ListField
__all__ = ('ConvertDecimalField', )
class ConvertDecimalField(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
def test_how_to_convert_decimal_fields(self):
"""Demonstrates migrating from 0.7 to 0.8
"""
# 1. Old definition - using dbrefs
class Person(Document):
name = StringField()
money = DecimalField(force_string=True)
monies = ListField(DecimalField(force_string=True))
Person.drop_collection()
Person(name="Wilson Jr", money=Decimal("2.50"),
monies=[Decimal("2.10"), Decimal("5.00")]).save()
# 2. Start the migration by changing the schema
# Change DecimalField - add precision and rounding settings
class Person(Document):
name = StringField()
money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP)
monies = ListField(DecimalField(precision=2,
rounding=decimal.ROUND_HALF_UP))
# 3. Loop all the objects and mark parent as changed
for p in Person.objects:
p._mark_as_changed('money')
p._mark_as_changed('monies')
p.save()
# 4. Confirmation of the fix!
wilson = Person.objects(name="Wilson Jr").as_pymongo()[0]
self.assertTrue(isinstance(wilson['money'], float))
self.assertTrue(all([isinstance(m, float) for m in wilson['monies']]))

View File

@@ -1,52 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
from mongoengine import Document, connect
from mongoengine.connection import get_db
from mongoengine.fields import StringField, ReferenceField, ListField
__all__ = ('ConvertToObjectIdsModel', )
class ConvertToObjectIdsModel(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
def test_how_to_convert_to_object_id_reference_fields(self):
"""Demonstrates migrating from 0.7 to 0.8
"""
# 1. Old definition - using dbrefs
class Person(Document):
name = StringField()
parent = ReferenceField('self', dbref=True)
friends = ListField(ReferenceField('self', dbref=True))
Person.drop_collection()
p1 = Person(name="Wilson", parent=None).save()
f1 = Person(name="John", parent=None).save()
f2 = Person(name="Paul", parent=None).save()
f3 = Person(name="George", parent=None).save()
f4 = Person(name="Ringo", parent=None).save()
Person(name="Wilson Jr", parent=p1, friends=[f1, f2, f3, f4]).save()
# 2. Start the migration by changing the schema
# Change ReferenceField as now dbref defaults to False
class Person(Document):
name = StringField()
parent = ReferenceField('self')
friends = ListField(ReferenceField('self'))
# 3. Loop all the objects and mark parent as changed
for p in Person.objects:
p._mark_as_changed('parent')
p._mark_as_changed('friends')
p.save()
# 4. Confirmation of the fix!
wilson = Person.objects(name="Wilson Jr").as_pymongo()[0]
self.assertEqual(p1.id, wilson['parent'])
self.assertEqual([f1.id, f2.id, f3.id, f4.id], wilson['friends'])

View File

@@ -1,62 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
from mongoengine import Document, connect
from mongoengine.connection import get_db
from mongoengine.fields import StringField
__all__ = ('TurnOffInheritanceTest', )
class TurnOffInheritanceTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
def tearDown(self):
for collection in self.db.collection_names():
if 'system.' in collection:
continue
self.db.drop_collection(collection)
def test_how_to_turn_off_inheritance(self):
"""Demonstrates migrating from allow_inheritance = True to False.
"""
# 1. Old declaration of the class
class Animal(Document):
name = StringField()
meta = {
'allow_inheritance': True,
'indexes': ['name']
}
# 2. Turn off inheritance
class Animal(Document):
name = StringField()
meta = {
'allow_inheritance': False,
'indexes': ['name']
}
# 3. Remove _types and _cls
collection = Animal._get_collection()
collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True)
# 3. Confirm extra data is removed
count = collection.find({"$or": [{'_types': {"$exists": True}},
{'_cls': {"$exists": True}}]}).count()
assert count == 0
# 4. Remove indexes
info = collection.index_information()
indexes_to_drop = [key for key, value in info.iteritems()
if '_types' in dict(value['key'])
or '_cls' in dict(value['key'])]
for index in indexes_to_drop:
collection.drop_index(index)
# 5. Recreate indexes
Animal.ensure_indexes()

View File

@@ -1,48 +0,0 @@
# -*- coding: utf-8 -*-
import unittest
import uuid
from mongoengine import Document, connect
from mongoengine.connection import get_db
from mongoengine.fields import StringField, UUIDField, ListField
__all__ = ('ConvertToBinaryUUID', )
class ConvertToBinaryUUID(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
def test_how_to_convert_to_binary_uuid_fields(self):
"""Demonstrates migrating from 0.7 to 0.8
"""
# 1. Old definition - using dbrefs
class Person(Document):
name = StringField()
uuid = UUIDField(binary=False)
uuids = ListField(UUIDField(binary=False))
Person.drop_collection()
Person(name="Wilson Jr", uuid=uuid.uuid4(),
uuids=[uuid.uuid4(), uuid.uuid4()]).save()
# 2. Start the migration by changing the schema
# Change UUIDFIeld as now binary defaults to True
class Person(Document):
name = StringField()
uuid = UUIDField()
uuids = ListField(UUIDField())
# 3. Loop all the objects and mark parent as changed
for p in Person.objects:
p._mark_as_changed('uuid')
p._mark_as_changed('uuids')
p.save()
# 4. Confirmation of the fix!
wilson = Person.objects(name="Wilson Jr").as_pymongo()[0]
self.assertTrue(isinstance(wilson['uuid'], uuid.UUID))
self.assertTrue(all([isinstance(u, uuid.UUID) for u in wilson['uuids']]))

View File

@@ -1,6 +1,3 @@
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import *
@@ -95,7 +92,7 @@ class OnlyExcludeAllTest(unittest.TestCase):
exclude = ['d', 'e']
only = ['b', 'c']
qs = MyDoc.objects.fields(**dict(((i, 1) for i in include)))
qs = MyDoc.objects.fields(**{i: 1 for i in include})
self.assertEqual(qs._loaded_fields.as_dict(),
{'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1})
qs = qs.only(*only)
@@ -103,14 +100,14 @@ class OnlyExcludeAllTest(unittest.TestCase):
qs = qs.exclude(*exclude)
self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
qs = MyDoc.objects.fields(**dict(((i, 1) for i in include)))
qs = MyDoc.objects.fields(**{i: 1 for i in include})
qs = qs.exclude(*exclude)
self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1})
qs = qs.only(*only)
self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
qs = MyDoc.objects.exclude(*exclude)
qs = qs.fields(**dict(((i, 1) for i in include)))
qs = qs.fields(**{i: 1 for i in include})
self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1})
qs = qs.only(*only)
self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1})
@@ -129,7 +126,7 @@ class OnlyExcludeAllTest(unittest.TestCase):
exclude = ['d', 'e']
only = ['b', 'c']
qs = MyDoc.objects.fields(**dict(((i, 1) for i in include)))
qs = MyDoc.objects.fields(**{i: 1 for i in include})
qs = qs.exclude(*exclude)
qs = qs.only(*only)
qs = qs.fields(slice__b=5)

View File

@@ -1,9 +1,5 @@
import sys
sys.path[0:0] = [""]
import unittest
from datetime import datetime, timedelta
import unittest
from pymongo.errors import OperationFailure
from mongoengine import *

View File

@@ -1,6 +1,3 @@
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import connect, Document, IntField
@@ -99,4 +96,4 @@ class FindAndModifyTest(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -9,13 +9,13 @@ from nose.plugins.skip import SkipTest
import pymongo
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference
import six
from mongoengine import *
from mongoengine.connection import get_connection, get_db
from mongoengine.context_managers import query_counter, switch_db
from mongoengine.errors import InvalidQueryError
from mongoengine.python_support import IS_PYMONGO_3, PY3
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned,
QuerySet, QuerySetManager, queryset_manager)
@@ -25,7 +25,10 @@ __all__ = ("QuerySetTest",)
class db_ops_tracker(query_counter):
def get_ops(self):
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}}
ignore_query = {
'ns': {'$ne': '%s.system.indexes' % self.db.name},
'command.count': {'$ne': 'system.profile'}
}
return list(self.db.system.profile.find(ignore_query))
@@ -94,12 +97,12 @@ class QuerySetTest(unittest.TestCase):
author = ReferenceField(self.Person)
author2 = GenericReferenceField()
def test_reference():
# test addressing a field from a reference
with self.assertRaises(InvalidQueryError):
list(BlogPost.objects(author__name="test"))
self.assertRaises(InvalidQueryError, test_reference)
def test_generic_reference():
# should fail for a generic reference as well
with self.assertRaises(InvalidQueryError):
list(BlogPost.objects(author2__name="test"))
def test_find(self):
@@ -174,7 +177,7 @@ class QuerySetTest(unittest.TestCase):
# Test larger slice __repr__
self.Person.objects.delete()
for i in xrange(55):
for i in range(55):
self.Person(name='A%s' % i, age=i).save()
self.assertEqual(self.Person.objects.count(), 55)
@@ -218,14 +221,15 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects[1]
self.assertEqual(person.name, "User B")
self.assertRaises(IndexError, self.Person.objects.__getitem__, 2)
with self.assertRaises(IndexError):
self.Person.objects[2]
# Find a document using just the object id
person = self.Person.objects.with_id(person1.id)
self.assertEqual(person.name, "User A")
self.assertRaises(
InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id)
with self.assertRaises(InvalidQueryError):
self.Person.objects(name="User A").with_id(person1.id)
def test_find_only_one(self):
"""Ensure that a query using ``get`` returns at most one result.
@@ -363,7 +367,8 @@ class QuerySetTest(unittest.TestCase):
# test invalid batch size
qs = A.objects.batch_size(-1)
self.assertRaises(ValueError, lambda: list(qs))
with self.assertRaises(ValueError):
list(qs)
def test_update_write_concern(self):
"""Test that passing write_concern works"""
@@ -392,18 +397,14 @@ class QuerySetTest(unittest.TestCase):
"""Test to ensure that update is passed a value to update to"""
self.Person.drop_collection()
author = self.Person(name='Test User')
author.save()
author = self.Person.objects.create(name='Test User')
def update_raises():
with self.assertRaises(OperationError):
self.Person.objects(pk=author.pk).update({})
def update_one_raises():
with self.assertRaises(OperationError):
self.Person.objects(pk=author.pk).update_one({})
self.assertRaises(OperationError, update_raises)
self.assertRaises(OperationError, update_one_raises)
def test_update_array_position(self):
"""Ensure that updating by array position works.
@@ -431,8 +432,8 @@ class QuerySetTest(unittest.TestCase):
Blog.objects.create(posts=[post2, post1])
# Update all of the first comments of second posts of all blogs
Blog.objects().update(set__posts__1__comments__0__name="testc")
testc_blogs = Blog.objects(posts__1__comments__0__name="testc")
Blog.objects().update(set__posts__1__comments__0__name='testc')
testc_blogs = Blog.objects(posts__1__comments__0__name='testc')
self.assertEqual(testc_blogs.count(), 2)
Blog.drop_collection()
@@ -441,14 +442,13 @@ class QuerySetTest(unittest.TestCase):
# Update only the first blog returned by the query
Blog.objects().update_one(
set__posts__1__comments__1__name="testc")
testc_blogs = Blog.objects(posts__1__comments__1__name="testc")
set__posts__1__comments__1__name='testc')
testc_blogs = Blog.objects(posts__1__comments__1__name='testc')
self.assertEqual(testc_blogs.count(), 1)
# Check that using this indexing syntax on a non-list fails
def non_list_indexing():
Blog.objects().update(set__posts__1__comments__0__name__1="asdf")
self.assertRaises(InvalidQueryError, non_list_indexing)
with self.assertRaises(InvalidQueryError):
Blog.objects().update(set__posts__1__comments__0__name__1='asdf')
Blog.drop_collection()
@@ -516,15 +516,12 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4])
# Nested updates arent supported yet..
def update_nested():
with self.assertRaises(OperationError):
Simple.drop_collection()
Simple(x=[{'test': [1, 2, 3, 4]}]).save()
Simple.objects(x__test=2).update(set__x__S__test__S=3)
self.assertEqual(simple.x, [1, 2, 3, 4])
self.assertRaises(OperationError, update_nested)
Simple.drop_collection()
def test_update_using_positional_operator_embedded_document(self):
"""Ensure that the embedded documents can be updated using the positional
operator."""
@@ -617,11 +614,11 @@ class QuerySetTest(unittest.TestCase):
members = DictField()
club = Club()
club.members['John'] = dict(gender="M", age=13)
club.members['John'] = {'gender': 'M', 'age': 13}
club.save()
Club.objects().update(
set__members={"John": dict(gender="F", age=14)})
set__members={"John": {'gender': 'F', 'age': 14}})
club = Club.objects().first()
self.assertEqual(club.members['John']['gender'], "F")
@@ -802,7 +799,7 @@ class QuerySetTest(unittest.TestCase):
post2 = Post(comments=[comment2, comment2])
blogs = []
for i in xrange(1, 100):
for i in range(1, 100):
blogs.append(Blog(title="post %s" % i, posts=[post1, post2]))
Blog.objects.insert(blogs, load_bulk=False)
@@ -839,30 +836,31 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(Blog.objects.count(), 2)
# test handles people trying to upsert
def throw_operation_error():
# test inserting an existing document (shouldn't be allowed)
with self.assertRaises(OperationError):
blog = Blog.objects.first()
Blog.objects.insert(blog)
# test inserting a query set
with self.assertRaises(OperationError):
blogs = Blog.objects
Blog.objects.insert(blogs)
self.assertRaises(OperationError, throw_operation_error)
# Test can insert new doc
# insert a new doc
new_post = Blog(title="code123", id=ObjectId())
Blog.objects.insert(new_post)
# test handles other classes being inserted
def throw_operation_error_wrong_doc():
class Author(Document):
pass
class Author(Document):
pass
# try inserting a different document class
with self.assertRaises(OperationError):
Blog.objects.insert(Author())
self.assertRaises(OperationError, throw_operation_error_wrong_doc)
def throw_operation_error_not_a_document():
# try inserting a non-document
with self.assertRaises(OperationError):
Blog.objects.insert("HELLO WORLD")
self.assertRaises(OperationError, throw_operation_error_not_a_document)
Blog.drop_collection()
blog1 = Blog(title="code", posts=[post1, post2])
@@ -882,14 +880,13 @@ class QuerySetTest(unittest.TestCase):
blog3 = Blog(title="baz", posts=[post1, post2])
Blog.objects.insert([blog1, blog2])
def throw_operation_error_not_unique():
with self.assertRaises(NotUniqueError):
Blog.objects.insert([blog2, blog3])
self.assertRaises(NotUniqueError, throw_operation_error_not_unique)
self.assertEqual(Blog.objects.count(), 2)
Blog.objects.insert([blog2, blog3], write_concern={"w": 0,
'continue_on_error': True})
Blog.objects.insert([blog2, blog3],
write_concern={"w": 0, 'continue_on_error': True})
self.assertEqual(Blog.objects.count(), 3)
def test_get_changed_fields_query_count(self):
@@ -1022,7 +1019,7 @@ class QuerySetTest(unittest.TestCase):
Doc.drop_collection()
for i in xrange(1000):
for i in range(1000):
Doc(number=i).save()
docs = Doc.objects.order_by('number')
@@ -1176,7 +1173,7 @@ class QuerySetTest(unittest.TestCase):
qs = list(qs)
expected = list(expected)
self.assertEqual(len(qs), len(expected))
for i in xrange(len(qs)):
for i in range(len(qs)):
self.assertEqual(qs[i], expected[i])
def test_ordering(self):
@@ -1216,7 +1213,8 @@ class QuerySetTest(unittest.TestCase):
self.assertSequence(qs, expected)
def test_clear_ordering(self):
""" Ensure that the default ordering can be cleared by calling order_by().
"""Ensure that the default ordering can be cleared by calling
order_by() w/o any arguments.
"""
class BlogPost(Document):
title = StringField()
@@ -1232,12 +1230,13 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects.filter(title='whatever').first()
self.assertEqual(len(q.get_ops()), 1)
self.assertEqual(
q.get_ops()[0]['query']['$orderby'], {u'published_date': -1})
q.get_ops()[0]['query']['$orderby'],
{'published_date': -1}
)
with db_ops_tracker() as q:
BlogPost.objects.filter(title='whatever').order_by().first()
self.assertEqual(len(q.get_ops()), 1)
print q.get_ops()[0]['query']
self.assertFalse('$orderby' in q.get_ops()[0]['query'])
def test_no_ordering_for_get(self):
@@ -1710,7 +1709,7 @@ class QuerySetTest(unittest.TestCase):
Log.drop_collection()
for i in xrange(10):
for i in range(10):
Log().save()
Log.objects()[3:5].delete()
@@ -1910,12 +1909,10 @@ class QuerySetTest(unittest.TestCase):
Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban')
self.assertEqual(Site.objects.first().collaborators, [])
def pull_all():
with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one(
pull_all__collaborators__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_embedded(self):
class User(EmbeddedDocument):
@@ -1946,12 +1943,10 @@ class QuerySetTest(unittest.TestCase):
pull__collaborators__unhelpful={'name': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all():
with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one(
pull_all__collaborators__helpful__name=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_pull_from_nested_mapfield(self):
class Collaborator(EmbeddedDocument):
@@ -1980,12 +1975,10 @@ class QuerySetTest(unittest.TestCase):
pull__collaborators__unhelpful={'user': 'Frank'})
self.assertEqual(Site.objects.first().collaborators['unhelpful'], [])
def pull_all():
with self.assertRaises(InvalidQueryError):
Site.objects(id=s.id).update_one(
pull_all__collaborators__helpful__user=['Ross'])
self.assertRaises(InvalidQueryError, pull_all)
def test_update_one_pop_generic_reference(self):
class BlogTag(Document):
@@ -2610,7 +2603,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost(hits=2, tags=['music', 'actors']).save()
def test_assertions(f):
f = dict((key, int(val)) for key, val in f.items())
f = {key: int(val) for key, val in f.items()}
self.assertEqual(
set(['music', 'film', 'actors', 'watch']), set(f.keys()))
self.assertEqual(f['music'], 3)
@@ -2625,7 +2618,7 @@ class QuerySetTest(unittest.TestCase):
# Ensure query is taken into account
def test_assertions(f):
f = dict((key, int(val)) for key, val in f.items())
f = {key: int(val) for key, val in f.items()}
self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys()))
self.assertEqual(f['music'], 2)
self.assertEqual(f['actors'], 1)
@@ -2689,7 +2682,7 @@ class QuerySetTest(unittest.TestCase):
doc.save()
def test_assertions(f):
f = dict((key, int(val)) for key, val in f.items())
f = {key: int(val) for key, val in f.items()}
self.assertEqual(
set(['62-3331-1656', '62-3332-1656']), set(f.keys()))
self.assertEqual(f['62-3331-1656'], 2)
@@ -2703,7 +2696,7 @@ class QuerySetTest(unittest.TestCase):
# Ensure query is taken into account
def test_assertions(f):
f = dict((key, int(val)) for key, val in f.items())
f = {key: int(val) for key, val in f.items()}
self.assertEqual(set(['62-3331-1656']), set(f.keys()))
self.assertEqual(f['62-3331-1656'], 2)
@@ -2810,10 +2803,10 @@ class QuerySetTest(unittest.TestCase):
Test.drop_collection()
for i in xrange(50):
for i in range(50):
Test(val=1).save()
for i in xrange(20):
for i in range(20):
Test(val=2).save()
freqs = Test.objects.item_frequencies(
@@ -3603,7 +3596,7 @@ class QuerySetTest(unittest.TestCase):
Post.drop_collection()
for i in xrange(10):
for i in range(10):
Post(title="Post %s" % i).save()
self.assertEqual(5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True))
@@ -3618,7 +3611,7 @@ class QuerySetTest(unittest.TestCase):
pass
MyDoc.drop_collection()
for i in xrange(0, 10):
for i in range(0, 10):
MyDoc().save()
self.assertEqual(MyDoc.objects.count(), 10)
@@ -3674,7 +3667,7 @@ class QuerySetTest(unittest.TestCase):
Number.drop_collection()
for i in xrange(1, 101):
for i in range(1, 101):
t = Number(n=i)
t.save()
@@ -3821,11 +3814,9 @@ class QuerySetTest(unittest.TestCase):
self.assertTrue(a in results)
self.assertTrue(c in results)
def invalid_where():
with self.assertRaises(TypeError):
list(IntPair.objects.where(fielda__gte=3))
self.assertRaises(TypeError, invalid_where)
def test_scalar(self):
class Organization(Document):
@@ -4081,7 +4072,7 @@ class QuerySetTest(unittest.TestCase):
# Test larger slice __repr__
self.Person.objects.delete()
for i in xrange(55):
for i in range(55):
self.Person(name='A%s' % i, age=i).save()
self.assertEqual(self.Person.objects.scalar('name').count(), 55)
@@ -4089,7 +4080,7 @@ class QuerySetTest(unittest.TestCase):
"A0", "%s" % self.Person.objects.order_by('name').scalar('name').first())
self.assertEqual(
"A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0])
if PY3:
if six.PY3:
self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by(
'age').scalar('name')[1:3])
self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by(
@@ -4107,7 +4098,7 @@ class QuerySetTest(unittest.TestCase):
pks = self.Person.objects.order_by('age').scalar('pk')[1:3]
names = self.Person.objects.scalar('name').in_bulk(list(pks)).values()
if PY3:
if six.PY3:
expected = "['A1', 'A2']"
else:
expected = "[u'A1', u'A2']"
@@ -4463,7 +4454,7 @@ class QuerySetTest(unittest.TestCase):
name = StringField()
Person.drop_collection()
for i in xrange(100):
for i in range(100):
Person(name="No: %s" % i).save()
with query_counter() as q:
@@ -4494,7 +4485,7 @@ class QuerySetTest(unittest.TestCase):
name = StringField()
Person.drop_collection()
for i in xrange(100):
for i in range(100):
Person(name="No: %s" % i).save()
with query_counter() as q:
@@ -4538,7 +4529,7 @@ class QuerySetTest(unittest.TestCase):
fields = DictField()
Noddy.drop_collection()
for i in xrange(100):
for i in range(100):
noddy = Noddy()
for j in range(20):
noddy.fields["key" + str(j)] = "value " + str(j)
@@ -4550,7 +4541,9 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(counter, 100)
self.assertEqual(len(list(docs)), 100)
self.assertRaises(TypeError, lambda: len(docs))
with self.assertRaises(TypeError):
len(docs)
with query_counter() as q:
self.assertEqual(q, 0)
@@ -4739,7 +4732,7 @@ class QuerySetTest(unittest.TestCase):
name = StringField()
Person.drop_collection()
for i in xrange(100):
for i in range(100):
Person(name="No: %s" % i).save()
with query_counter() as q:
@@ -4863,10 +4856,10 @@ class QuerySetTest(unittest.TestCase):
])
def test_delete_count(self):
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)]
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)]
self.assertEqual(self.Person.objects().delete(), 3) # test ordinary QuerySey delete count
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)]
[self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)]
self.assertEqual(self.Person.objects().skip(1).delete(), 2) # test Document delete with existing documents
@@ -4875,12 +4868,14 @@ class QuerySetTest(unittest.TestCase):
def test_max_time_ms(self):
# 778: max_time_ms can get only int or None as input
self.assertRaises(TypeError, self.Person.objects(name="name").max_time_ms, "not a number")
self.assertRaises(TypeError,
self.Person.objects(name="name").max_time_ms,
'not a number')
def test_subclass_field_query(self):
class Animal(Document):
is_mamal = BooleanField()
meta = dict(allow_inheritance=True)
meta = {'allow_inheritance': True}
class Cat(Animal):
whiskers_length = FloatField()
@@ -4925,7 +4920,7 @@ class QuerySetTest(unittest.TestCase):
class Data(Document):
pass
for i in xrange(300):
for i in range(300):
Data().save()
records = Data.objects.limit(250)
@@ -4957,7 +4952,7 @@ class QuerySetTest(unittest.TestCase):
class Data(Document):
pass
for i in xrange(300):
for i in range(300):
Data().save()
qs = Data.objects.limit(250)

View File

@@ -238,7 +238,8 @@ class TransformTest(unittest.TestCase):
box = [(35.0, -125.0), (40.0, -100.0)]
# I *meant* to execute location__within_box=box
events = Event.objects(location__within=box)
self.assertRaises(InvalidQueryError, lambda: events.count())
with self.assertRaises(InvalidQueryError):
events.count()
if __name__ == '__main__':

View File

@@ -185,7 +185,7 @@ class QTest(unittest.TestCase):
x = IntField()
TestDoc.drop_collection()
for i in xrange(1, 101):
for i in range(1, 101):
t = TestDoc(x=i)
t.save()
@@ -268,14 +268,13 @@ class QTest(unittest.TestCase):
self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3)
# Test invalid query objs
def wrong_query_objs():
with self.assertRaises(InvalidQueryError):
self.Person.objects('user1')
def wrong_query_objs_filter():
self.Person.objects('user1')
# filter should fail, too
with self.assertRaises(InvalidQueryError):
self.Person.objects.filter('user1')
self.assertRaises(InvalidQueryError, wrong_query_objs)
self.assertRaises(InvalidQueryError, wrong_query_objs_filter)
def test_q_regex(self):
"""Ensure that Q objects can be queried using regexes.

View File

@@ -1,9 +1,6 @@
import sys
import datetime
from pymongo.errors import OperationFailure
sys.path[0:0] = [""]
try:
import unittest2 as unittest
except ImportError:
@@ -19,7 +16,8 @@ from mongoengine import (
)
from mongoengine.python_support import IS_PYMONGO_3
import mongoengine.connection
from mongoengine.connection import get_db, get_connection, ConnectionError
from mongoengine.connection import (MongoEngineConnectionError, get_db,
get_connection)
def get_tz_awareness(connection):
@@ -159,7 +157,10 @@ class ConnectionTest(unittest.TestCase):
c.mongoenginetest.add_user("username", "password")
if not IS_PYMONGO_3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
self.assertRaises(
MongoEngineConnectionError, connect, 'testdb_uri_bad',
host='mongodb://test:password@localhost'
)
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
@@ -229,10 +230,11 @@ class ConnectionTest(unittest.TestCase):
self.assertRaises(OperationFailure, test_conn.server_info)
else:
self.assertRaises(
ConnectionError, connect, 'mongoenginetest', alias='test1',
MongoEngineConnectionError, connect, 'mongoenginetest',
alias='test1',
host='mongodb://username2:password@localhost/mongoenginetest'
)
self.assertRaises(ConnectionError, get_db, 'test1')
self.assertRaises(MongoEngineConnectionError, get_db, 'test1')
# Authentication succeeds with "authSource"
connect(
@@ -253,7 +255,7 @@ class ConnectionTest(unittest.TestCase):
"""
register_connection('testdb', 'mongoenginetest2')
self.assertRaises(ConnectionError, get_connection)
self.assertRaises(MongoEngineConnectionError, get_connection)
conn = get_connection('testdb')
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))

View File

@@ -1,5 +1,3 @@
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import *
@@ -79,7 +77,7 @@ class ContextManagersTest(unittest.TestCase):
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
for i in range(1, 51):
User(name='user %s' % i).save()
user = User.objects.first()
@@ -117,7 +115,7 @@ class ContextManagersTest(unittest.TestCase):
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
for i in range(1, 51):
User(name='user %s' % i).save()
user = User.objects.first()
@@ -195,7 +193,7 @@ class ContextManagersTest(unittest.TestCase):
with query_counter() as q:
self.assertEqual(0, q)
for i in xrange(1, 51):
for i in range(1, 51):
db.test.find({}).count()
self.assertEqual(50, q)

View File

@@ -23,7 +23,8 @@ class TestStrictDict(unittest.TestCase):
self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
def test_init_fails_on_nonexisting_attrs(self):
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))
with self.assertRaises(AttributeError):
self.dtype(a=1, b=2, d=3)
def test_eq(self):
d = self.dtype(a=1, b=1, c=1)
@@ -46,14 +47,12 @@ class TestStrictDict(unittest.TestCase):
d = self.dtype()
d.a = 1
self.assertEqual(d.a, 1)
self.assertRaises(AttributeError, lambda: d.b)
self.assertRaises(AttributeError, getattr, d, 'b')
def test_setattr_raises_on_nonexisting_attr(self):
d = self.dtype()
def _f():
with self.assertRaises(AttributeError):
d.x = 1
self.assertRaises(AttributeError, _f)
def test_setattr_getattr_special(self):
d = self.strict_dict_class(["items"])

View File

@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest
from bson import DBRef, ObjectId
@@ -32,7 +30,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
for i in range(1, 51):
user = User(name='user %s' % i)
user.save()
@@ -90,7 +88,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
for i in range(1, 51):
user = User(name='user %s' % i)
user.save()
@@ -162,7 +160,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 26):
for i in range(1, 26):
user = User(name='user %s' % i)
user.save()
@@ -440,7 +438,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
members = []
for i in xrange(1, 51):
for i in range(1, 51):
a = UserA(name='User A %s' % i)
a.save()
@@ -531,7 +529,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
members = []
for i in xrange(1, 51):
for i in range(1, 51):
a = UserA(name='User A %s' % i)
a.save()
@@ -614,15 +612,15 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
members = []
for i in xrange(1, 51):
for i in range(1, 51):
user = User(name='user %s' % i)
user.save()
members.append(user)
group = Group(members=dict([(str(u.id), u) for u in members]))
group = Group(members={str(u.id): u for u in members})
group.save()
group = Group(members=dict([(str(u.id), u) for u in members]))
group = Group(members={str(u.id): u for u in members})
group.save()
with query_counter() as q:
@@ -687,7 +685,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
members = []
for i in xrange(1, 51):
for i in range(1, 51):
a = UserA(name='User A %s' % i)
a.save()
@@ -699,9 +697,9 @@ class FieldTest(unittest.TestCase):
members += [a, b, c]
group = Group(members=dict([(str(u.id), u) for u in members]))
group = Group(members={str(u.id): u for u in members})
group.save()
group = Group(members=dict([(str(u.id), u) for u in members]))
group = Group(members={str(u.id): u for u in members})
group.save()
with query_counter() as q:
@@ -783,16 +781,16 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
members = []
for i in xrange(1, 51):
for i in range(1, 51):
a = UserA(name='User A %s' % i)
a.save()
members += [a]
group = Group(members=dict([(str(u.id), u) for u in members]))
group = Group(members={str(u.id): u for u in members})
group.save()
group = Group(members=dict([(str(u.id), u) for u in members]))
group = Group(members={str(u.id): u for u in members})
group.save()
with query_counter() as q:
@@ -866,7 +864,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
members = []
for i in xrange(1, 51):
for i in range(1, 51):
a = UserA(name='User A %s' % i)
a.save()
@@ -878,9 +876,9 @@ class FieldTest(unittest.TestCase):
members += [a, b, c]
group = Group(members=dict([(str(u.id), u) for u in members]))
group = Group(members={str(u.id): u for u in members})
group.save()
group = Group(members=dict([(str(u.id), u) for u in members]))
group = Group(members={str(u.id): u for u in members})
group.save()
with query_counter() as q:
@@ -1103,7 +1101,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
for i in range(1, 51):
User(name='user %s' % i).save()
Group(name="Test", members=User.objects).save()
@@ -1132,7 +1130,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
for i in range(1, 51):
User(name='user %s' % i).save()
Group(name="Test", members=User.objects).save()
@@ -1169,7 +1167,7 @@ class FieldTest(unittest.TestCase):
Group.drop_collection()
members = []
for i in xrange(1, 51):
for i in range(1, 51):
a = UserA(name='User A %s' % i).save()
b = UserB(name='User B %s' % i).save()
c = UserC(name='User C %s' % i).save()

View File

@@ -1,6 +1,3 @@
import sys
sys.path[0:0] = [""]
import unittest
from pymongo import ReadPreference
@@ -18,7 +15,7 @@ else:
import mongoengine
from mongoengine import *
from mongoengine.connection import ConnectionError
from mongoengine.connection import MongoEngineConnectionError
class ConnectionTest(unittest.TestCase):
@@ -41,7 +38,7 @@ class ConnectionTest(unittest.TestCase):
conn = connect(db='mongoenginetest',
host="mongodb://localhost/mongoenginetest?replicaSet=rs",
read_preference=READ_PREF)
except ConnectionError, e:
except MongoEngineConnectionError as e:
return
if not isinstance(conn, CONN_CLASS):

View File

@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
import sys
sys.path[0:0] = [""]
import unittest
from mongoengine import *