Merge branch 'master' of https://github.com/MongoEngine/mongoengine
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
import unittest
|
||||
|
||||
from mongoengine import *
|
||||
|
||||
from mongoengine.queryset import NULLIFY
|
||||
from mongoengine.queryset import NULLIFY, PULL
|
||||
from mongoengine.connection import get_db
|
||||
|
||||
__all__ = ("ClassMethodsTest", )
|
||||
@@ -86,6 +85,25 @@ class ClassMethodsTest(unittest.TestCase):
|
||||
self.assertEqual(self.Person._meta['delete_rules'],
|
||||
{(Job, 'employee'): NULLIFY})
|
||||
|
||||
def test_register_delete_rule_inherited(self):
|
||||
|
||||
class Vaccine(Document):
|
||||
name = StringField(required=True)
|
||||
|
||||
meta = {"indexes": ["name"]}
|
||||
|
||||
class Animal(Document):
|
||||
family = StringField(required=True)
|
||||
vaccine_made = ListField(ReferenceField("Vaccine", reverse_delete_rule=PULL))
|
||||
|
||||
meta = {"allow_inheritance": True, "indexes": ["family"]}
|
||||
|
||||
class Cat(Animal):
|
||||
name = StringField(required=True)
|
||||
|
||||
self.assertEqual(Vaccine._meta['delete_rules'][(Animal, 'vaccine_made')], PULL)
|
||||
self.assertEqual(Vaccine._meta['delete_rules'][(Cat, 'vaccine_made')], PULL)
|
||||
|
||||
def test_collection_naming(self):
|
||||
"""Ensure that a collection with a specified name may be used.
|
||||
"""
|
||||
|
||||
@@ -31,8 +31,9 @@ class DynamicTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(p.to_mongo(), {"_cls": "Person", "name": "James",
|
||||
"age": 34})
|
||||
|
||||
self.assertEqual(p.to_mongo().keys(), ["_cls", "name", "age"])
|
||||
p.save()
|
||||
self.assertEqual(p.to_mongo().keys(), ["_id", "_cls", "name", "age"])
|
||||
|
||||
self.assertEqual(self.Person.objects.first().age, 34)
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import with_statement
|
||||
import unittest
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
@@ -217,7 +216,7 @@ class IndexesTest(unittest.TestCase):
|
||||
}
|
||||
|
||||
self.assertEqual([{'fields': [('location.point', '2d')]}],
|
||||
Place._meta['index_specs'])
|
||||
Place._meta['index_specs'])
|
||||
|
||||
Place.ensure_indexes()
|
||||
info = Place._get_collection().index_information()
|
||||
@@ -231,8 +230,7 @@ class IndexesTest(unittest.TestCase):
|
||||
location = DictField()
|
||||
|
||||
class Place(Document):
|
||||
current = DictField(
|
||||
field=EmbeddedDocumentField('EmbeddedLocation'))
|
||||
current = DictField(field=EmbeddedDocumentField('EmbeddedLocation'))
|
||||
meta = {
|
||||
'allow_inheritance': True,
|
||||
'indexes': [
|
||||
@@ -241,7 +239,7 @@ class IndexesTest(unittest.TestCase):
|
||||
}
|
||||
|
||||
self.assertEqual([{'fields': [('current.location.point', '2d')]}],
|
||||
Place._meta['index_specs'])
|
||||
Place._meta['index_specs'])
|
||||
|
||||
Place.ensure_indexes()
|
||||
info = Place._get_collection().index_information()
|
||||
@@ -264,7 +262,7 @@ class IndexesTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual([{'fields': [('addDate', -1)], 'unique': True,
|
||||
'sparse': True}],
|
||||
BlogPost._meta['index_specs'])
|
||||
BlogPost._meta['index_specs'])
|
||||
|
||||
BlogPost.drop_collection()
|
||||
|
||||
@@ -382,8 +380,7 @@ class IndexesTest(unittest.TestCase):
|
||||
self.assertEqual(sorted(info.keys()), ['_id_', 'tags.tag_1'])
|
||||
|
||||
post1 = BlogPost(title="Embedded Indexes tests in place",
|
||||
tags=[Tag(name="about"), Tag(name="time")]
|
||||
)
|
||||
tags=[Tag(name="about"), Tag(name="time")])
|
||||
post1.save()
|
||||
BlogPost.drop_collection()
|
||||
|
||||
@@ -400,29 +397,6 @@ class IndexesTest(unittest.TestCase):
|
||||
info = RecursiveDocument._get_collection().index_information()
|
||||
self.assertEqual(sorted(info.keys()), ['_cls_1', '_id_'])
|
||||
|
||||
def test_geo_indexes_recursion(self):
|
||||
|
||||
class Location(Document):
|
||||
name = StringField()
|
||||
location = GeoPointField()
|
||||
|
||||
class Parent(Document):
|
||||
name = StringField()
|
||||
location = ReferenceField(Location, dbref=False)
|
||||
|
||||
Location.drop_collection()
|
||||
Parent.drop_collection()
|
||||
|
||||
list(Parent.objects)
|
||||
|
||||
collection = Parent._get_collection()
|
||||
info = collection.index_information()
|
||||
|
||||
self.assertFalse('location_2d' in info)
|
||||
|
||||
self.assertEqual(len(Parent._geo_indices()), 0)
|
||||
self.assertEqual(len(Location._geo_indices()), 1)
|
||||
|
||||
def test_covered_index(self):
|
||||
"""Ensure that covered indexes can be used
|
||||
"""
|
||||
@@ -433,7 +407,7 @@ class IndexesTest(unittest.TestCase):
|
||||
meta = {
|
||||
'indexes': ['a'],
|
||||
'allow_inheritance': False
|
||||
}
|
||||
}
|
||||
|
||||
Test.drop_collection()
|
||||
|
||||
@@ -633,7 +607,7 @@ class IndexesTest(unittest.TestCase):
|
||||
list(Log.objects)
|
||||
info = Log.objects._collection.index_information()
|
||||
self.assertEqual(3600,
|
||||
info['created_1']['expireAfterSeconds'])
|
||||
info['created_1']['expireAfterSeconds'])
|
||||
|
||||
def test_unique_and_indexes(self):
|
||||
"""Ensure that 'unique' constraints aren't overridden by
|
||||
|
||||
@@ -143,7 +143,7 @@ class InheritanceTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(Animal._superclasses, ())
|
||||
self.assertEqual(Animal._subclasses, ('Animal', 'Animal.Fish',
|
||||
'Animal.Fish.Pike'))
|
||||
'Animal.Fish.Pike'))
|
||||
|
||||
self.assertEqual(Fish._superclasses, ('Animal', ))
|
||||
self.assertEqual(Fish._subclasses, ('Animal.Fish', 'Animal.Fish.Pike'))
|
||||
@@ -168,6 +168,26 @@ class InheritanceTest(unittest.TestCase):
|
||||
self.assertEqual(Employee._get_collection_name(),
|
||||
Person._get_collection_name())
|
||||
|
||||
def test_inheritance_to_mongo_keys(self):
|
||||
"""Ensure that document may inherit fields from a superclass document.
|
||||
"""
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
meta = {'allow_inheritance': True}
|
||||
|
||||
class Employee(Person):
|
||||
salary = IntField()
|
||||
|
||||
self.assertEqual(['age', 'id', 'name', 'salary'],
|
||||
sorted(Employee._fields.keys()))
|
||||
self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(),
|
||||
['_cls', 'name', 'age'])
|
||||
self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(),
|
||||
['_cls', 'name', 'age', 'salary'])
|
||||
self.assertEqual(Employee._get_collection_name(),
|
||||
Person._get_collection_name())
|
||||
|
||||
def test_polymorphic_queries(self):
|
||||
"""Ensure that the correct subclasses are returned from a query
|
||||
@@ -197,7 +217,6 @@ class InheritanceTest(unittest.TestCase):
|
||||
classes = [obj.__class__ for obj in Human.objects]
|
||||
self.assertEqual(classes, [Human])
|
||||
|
||||
|
||||
def test_allow_inheritance(self):
|
||||
"""Ensure that inheritance may be disabled on simple classes and that
|
||||
_cls and _subclasses will not be used.
|
||||
@@ -213,8 +232,8 @@ class InheritanceTest(unittest.TestCase):
|
||||
self.assertRaises(ValueError, create_dog_class)
|
||||
|
||||
# Check that _cls etc aren't present on simple documents
|
||||
dog = Animal(name='dog')
|
||||
dog.save()
|
||||
dog = Animal(name='dog').save()
|
||||
self.assertEqual(dog.to_mongo().keys(), ['_id', 'name'])
|
||||
|
||||
collection = self.db[Animal._get_collection_name()]
|
||||
obj = collection.find_one()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@@ -320,8 +319,8 @@ class InstanceTest(unittest.TestCase):
|
||||
|
||||
Location.drop_collection()
|
||||
|
||||
self.assertEquals(Area, get_document("Area"))
|
||||
self.assertEquals(Area, get_document("Location.Area"))
|
||||
self.assertEqual(Area, get_document("Area"))
|
||||
self.assertEqual(Area, get_document("Location.Area"))
|
||||
|
||||
def test_creation(self):
|
||||
"""Ensure that document may be created using keyword arguments.
|
||||
@@ -428,6 +427,21 @@ class InstanceTest(unittest.TestCase):
|
||||
self.assertFalse('age' in person)
|
||||
self.assertFalse('nationality' in person)
|
||||
|
||||
def test_embedded_document_to_mongo(self):
|
||||
class Person(EmbeddedDocument):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
meta = {"allow_inheritance": True}
|
||||
|
||||
class Employee(Person):
|
||||
salary = IntField()
|
||||
|
||||
self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(),
|
||||
['_cls', 'name', 'age'])
|
||||
self.assertEqual(Employee(name="Bob", age=35, salary=0).to_mongo().keys(),
|
||||
['_cls', 'name', 'age', 'salary'])
|
||||
|
||||
def test_embedded_document(self):
|
||||
"""Ensure that embedded documents are set up correctly.
|
||||
"""
|
||||
@@ -494,12 +508,12 @@ class InstanceTest(unittest.TestCase):
|
||||
t = TestDocument(status="published")
|
||||
t.save(clean=False)
|
||||
|
||||
self.assertEquals(t.pub_date, None)
|
||||
self.assertEqual(t.pub_date, None)
|
||||
|
||||
t = TestDocument(status="published")
|
||||
t.save(clean=True)
|
||||
|
||||
self.assertEquals(type(t.pub_date), datetime)
|
||||
self.assertEqual(type(t.pub_date), datetime)
|
||||
|
||||
def test_document_embedded_clean(self):
|
||||
class TestEmbeddedDocument(EmbeddedDocument):
|
||||
@@ -531,7 +545,7 @@ class InstanceTest(unittest.TestCase):
|
||||
self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}})
|
||||
|
||||
t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save()
|
||||
self.assertEquals(t.doc.z, 35)
|
||||
self.assertEqual(t.doc.z, 35)
|
||||
|
||||
# Asserts not raises
|
||||
t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5))
|
||||
@@ -838,6 +852,14 @@ class InstanceTest(unittest.TestCase):
|
||||
self.assertEqual(person.name, None)
|
||||
self.assertEqual(person.age, None)
|
||||
|
||||
def test_inserts_if_you_set_the_pk(self):
|
||||
p1 = self.Person(name='p1', id=bson.ObjectId()).save()
|
||||
p2 = self.Person(name='p2')
|
||||
p2.id = bson.ObjectId()
|
||||
p2.save()
|
||||
|
||||
self.assertEqual(2, self.Person.objects.count())
|
||||
|
||||
def test_can_save_if_not_included(self):
|
||||
|
||||
class EmbeddedDoc(EmbeddedDocument):
|
||||
@@ -1881,11 +1903,11 @@ class InstanceTest(unittest.TestCase):
|
||||
|
||||
A.objects.all()
|
||||
|
||||
self.assertEquals('testdb-2', B._meta.get('db_alias'))
|
||||
self.assertEquals('mongoenginetest',
|
||||
A._get_collection().database.name)
|
||||
self.assertEquals('mongoenginetest2',
|
||||
B._get_collection().database.name)
|
||||
self.assertEqual('testdb-2', B._meta.get('db_alias'))
|
||||
self.assertEqual('mongoenginetest',
|
||||
A._get_collection().database.name)
|
||||
self.assertEqual('mongoenginetest2',
|
||||
B._get_collection().database.name)
|
||||
|
||||
def test_db_alias_propagates(self):
|
||||
"""db_alias propagates?
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from fields import *
|
||||
from file_tests import *
|
||||
from file_tests import *
|
||||
from geo import *
|
||||
@@ -1,5 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@@ -409,6 +408,27 @@ class FieldTest(unittest.TestCase):
|
||||
log.time = '1pm'
|
||||
self.assertRaises(ValidationError, log.validate)
|
||||
|
||||
def test_datetime_tz_aware_mark_as_changed(self):
|
||||
from mongoengine import connection
|
||||
|
||||
# Reset the connections
|
||||
connection._connection_settings = {}
|
||||
connection._connections = {}
|
||||
connection._dbs = {}
|
||||
|
||||
connect(db='mongoenginetest', tz_aware=True)
|
||||
|
||||
class LogEntry(Document):
|
||||
time = DateTimeField()
|
||||
|
||||
LogEntry.drop_collection()
|
||||
|
||||
LogEntry(time=datetime.datetime(2013, 1, 1, 0, 0, 0)).save()
|
||||
|
||||
log = LogEntry.objects.first()
|
||||
log.time = datetime.datetime(2013, 1, 1, 0, 0, 0)
|
||||
self.assertEqual(['time'], log._changed_fields)
|
||||
|
||||
def test_datetime(self):
|
||||
"""Tests showing pymongo datetime fields handling of microseconds.
|
||||
Microseconds are rounded to the nearest millisecond and pre UTC
|
||||
@@ -1841,45 +1861,6 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
Shirt.drop_collection()
|
||||
|
||||
def test_geo_indexes(self):
|
||||
"""Ensure that indexes are created automatically for GeoPointFields.
|
||||
"""
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
location = GeoPointField()
|
||||
|
||||
Event.drop_collection()
|
||||
event = Event(title="Coltrane Motion @ Double Door",
|
||||
location=[41.909889, -87.677137])
|
||||
event.save()
|
||||
|
||||
info = Event.objects._collection.index_information()
|
||||
self.assertTrue(u'location_2d' in info)
|
||||
self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')])
|
||||
|
||||
Event.drop_collection()
|
||||
|
||||
def test_geo_embedded_indexes(self):
|
||||
"""Ensure that indexes are created automatically for GeoPointFields on
|
||||
embedded documents.
|
||||
"""
|
||||
class Venue(EmbeddedDocument):
|
||||
location = GeoPointField()
|
||||
name = StringField()
|
||||
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
venue = EmbeddedDocumentField(Venue)
|
||||
|
||||
Event.drop_collection()
|
||||
venue = Venue(name="Double Door", location=[41.909889, -87.677137])
|
||||
event = Event(title="Coltrane Motion", venue=venue)
|
||||
event.save()
|
||||
|
||||
info = Event.objects._collection.index_information()
|
||||
self.assertTrue(u'location_2d' in info)
|
||||
self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')])
|
||||
|
||||
def test_ensure_unique_default_instances(self):
|
||||
"""Ensure that every field has it's own unique default instance."""
|
||||
class D(Document):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
|
||||
274
tests/fields/geo.py
Normal file
274
tests/fields/geo.py
Normal file
@@ -0,0 +1,274 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
import unittest
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import get_db
|
||||
|
||||
__all__ = ("GeoFieldTest", )
|
||||
|
||||
|
||||
class GeoFieldTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
self.db = get_db()
|
||||
|
||||
def _test_for_expected_error(self, Cls, loc, expected):
|
||||
try:
|
||||
Cls(loc=loc).validate()
|
||||
self.fail()
|
||||
except ValidationError, e:
|
||||
self.assertEqual(expected, e.to_dict()['loc'])
|
||||
|
||||
def test_geopoint_validation(self):
|
||||
class Location(Document):
|
||||
loc = GeoPointField()
|
||||
|
||||
invalid_coords = [{"x": 1, "y": 2}, 5, "a"]
|
||||
expected = 'GeoPointField can only accept tuples or lists of (x, y)'
|
||||
|
||||
for coord in invalid_coords:
|
||||
self._test_for_expected_error(Location, coord, expected)
|
||||
|
||||
invalid_coords = [[], [1], [1, 2, 3]]
|
||||
for coord in invalid_coords:
|
||||
expected = "Value (%s) must be a two-dimensional point" % repr(coord)
|
||||
self._test_for_expected_error(Location, coord, expected)
|
||||
|
||||
invalid_coords = [[{}, {}], ("a", "b")]
|
||||
for coord in invalid_coords:
|
||||
expected = "Both values (%s) in point must be float or int" % repr(coord)
|
||||
self._test_for_expected_error(Location, coord, expected)
|
||||
|
||||
def test_point_validation(self):
|
||||
class Location(Document):
|
||||
loc = PointField()
|
||||
|
||||
invalid_coords = {"x": 1, "y": 2}
|
||||
expected = 'PointField can only accept a valid GeoJson dictionary or lists of (x, y)'
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = {"type": "MadeUp", "coordinates": []}
|
||||
expected = 'PointField type must be "Point"'
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = {"type": "Point", "coordinates": [1, 2, 3]}
|
||||
expected = "Value ([1, 2, 3]) must be a two-dimensional point"
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [5, "a"]
|
||||
expected = "PointField can only accept lists of [x, y]"
|
||||
for coord in invalid_coords:
|
||||
self._test_for_expected_error(Location, coord, expected)
|
||||
|
||||
invalid_coords = [[], [1], [1, 2, 3]]
|
||||
for coord in invalid_coords:
|
||||
expected = "Value (%s) must be a two-dimensional point" % repr(coord)
|
||||
self._test_for_expected_error(Location, coord, expected)
|
||||
|
||||
invalid_coords = [[{}, {}], ("a", "b")]
|
||||
for coord in invalid_coords:
|
||||
expected = "Both values (%s) in point must be float or int" % repr(coord)
|
||||
self._test_for_expected_error(Location, coord, expected)
|
||||
|
||||
Location(loc=[1, 2]).validate()
|
||||
|
||||
def test_linestring_validation(self):
|
||||
class Location(Document):
|
||||
loc = LineStringField()
|
||||
|
||||
invalid_coords = {"x": 1, "y": 2}
|
||||
expected = 'LineStringField can only accept a valid GeoJson dictionary or lists of (x, y)'
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = {"type": "MadeUp", "coordinates": [[]]}
|
||||
expected = 'LineStringField type must be "LineString"'
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = {"type": "LineString", "coordinates": [[1, 2, 3]]}
|
||||
expected = "Invalid LineString:\nValue ([1, 2, 3]) must be a two-dimensional point"
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [5, "a"]
|
||||
expected = "Invalid LineString must contain at least one valid point"
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [[1]]
|
||||
expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0])
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [[1, 2, 3]]
|
||||
expected = "Invalid LineString:\nValue (%s) must be a two-dimensional point" % repr(invalid_coords[0])
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [[[{}, {}]], [("a", "b")]]
|
||||
for coord in invalid_coords:
|
||||
expected = "Invalid LineString:\nBoth values (%s) in point must be float or int" % repr(coord[0])
|
||||
self._test_for_expected_error(Location, coord, expected)
|
||||
|
||||
Location(loc=[[1, 2], [3, 4], [5, 6], [1,2]]).validate()
|
||||
|
||||
def test_polygon_validation(self):
|
||||
class Location(Document):
|
||||
loc = PolygonField()
|
||||
|
||||
invalid_coords = {"x": 1, "y": 2}
|
||||
expected = 'PolygonField can only accept a valid GeoJson dictionary or lists of (x, y)'
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = {"type": "MadeUp", "coordinates": [[]]}
|
||||
expected = 'PolygonField type must be "Polygon"'
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = {"type": "Polygon", "coordinates": [[[1, 2, 3]]]}
|
||||
expected = "Invalid Polygon:\nValue ([1, 2, 3]) must be a two-dimensional point"
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [[[5, "a"]]]
|
||||
expected = "Invalid Polygon:\nBoth values ([5, 'a']) in point must be float or int"
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [[[]]]
|
||||
expected = "Invalid Polygon must contain at least one valid linestring"
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [[[1, 2, 3]]]
|
||||
expected = "Invalid Polygon:\nValue ([1, 2, 3]) must be a two-dimensional point"
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [[[{}, {}]], [("a", "b")]]
|
||||
expected = "Invalid Polygon:\nBoth values ([{}, {}]) in point must be float or int, Both values (('a', 'b')) in point must be float or int"
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
invalid_coords = [[[1, 2], [3, 4]]]
|
||||
expected = "Invalid Polygon:\nLineStrings must start and end at the same point"
|
||||
self._test_for_expected_error(Location, invalid_coords, expected)
|
||||
|
||||
Location(loc=[[[1, 2], [3, 4], [5, 6], [1, 2]]]).validate()
|
||||
|
||||
def test_indexes_geopoint(self):
|
||||
"""Ensure that indexes are created automatically for GeoPointFields.
|
||||
"""
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
location = GeoPointField()
|
||||
|
||||
geo_indicies = Event._geo_indices()
|
||||
self.assertEqual(geo_indicies, [{'fields': [('location', '2d')]}])
|
||||
|
||||
def test_geopoint_embedded_indexes(self):
|
||||
"""Ensure that indexes are created automatically for GeoPointFields on
|
||||
embedded documents.
|
||||
"""
|
||||
class Venue(EmbeddedDocument):
|
||||
location = GeoPointField()
|
||||
name = StringField()
|
||||
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
venue = EmbeddedDocumentField(Venue)
|
||||
|
||||
geo_indicies = Event._geo_indices()
|
||||
self.assertEqual(geo_indicies, [{'fields': [('venue.location', '2d')]}])
|
||||
|
||||
def test_indexes_2dsphere(self):
|
||||
"""Ensure that indexes are created automatically for GeoPointFields.
|
||||
"""
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
point = PointField()
|
||||
line = LineStringField()
|
||||
polygon = PolygonField()
|
||||
|
||||
geo_indicies = Event._geo_indices()
|
||||
self.assertTrue({'fields': [('line', '2dsphere')]} in geo_indicies)
|
||||
self.assertTrue({'fields': [('polygon', '2dsphere')]} in geo_indicies)
|
||||
self.assertTrue({'fields': [('point', '2dsphere')]} in geo_indicies)
|
||||
|
||||
def test_indexes_2dsphere_embedded(self):
|
||||
"""Ensure that indexes are created automatically for GeoPointFields.
|
||||
"""
|
||||
class Venue(EmbeddedDocument):
|
||||
name = StringField()
|
||||
point = PointField()
|
||||
line = LineStringField()
|
||||
polygon = PolygonField()
|
||||
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
venue = EmbeddedDocumentField(Venue)
|
||||
|
||||
geo_indicies = Event._geo_indices()
|
||||
self.assertTrue({'fields': [('venue.line', '2dsphere')]} in geo_indicies)
|
||||
self.assertTrue({'fields': [('venue.polygon', '2dsphere')]} in geo_indicies)
|
||||
self.assertTrue({'fields': [('venue.point', '2dsphere')]} in geo_indicies)
|
||||
|
||||
def test_geo_indexes_recursion(self):
|
||||
|
||||
class Location(Document):
|
||||
name = StringField()
|
||||
location = GeoPointField()
|
||||
|
||||
class Parent(Document):
|
||||
name = StringField()
|
||||
location = ReferenceField(Location)
|
||||
|
||||
Location.drop_collection()
|
||||
Parent.drop_collection()
|
||||
|
||||
list(Parent.objects)
|
||||
|
||||
collection = Parent._get_collection()
|
||||
info = collection.index_information()
|
||||
|
||||
self.assertFalse('location_2d' in info)
|
||||
|
||||
self.assertEqual(len(Parent._geo_indices()), 0)
|
||||
self.assertEqual(len(Location._geo_indices()), 1)
|
||||
|
||||
def test_geo_indexes_auto_index(self):
|
||||
|
||||
# Test just listing the fields
|
||||
class Log(Document):
|
||||
location = PointField(auto_index=False)
|
||||
datetime = DateTimeField()
|
||||
|
||||
meta = {
|
||||
'indexes': [[("location", "2dsphere"), ("datetime", 1)]]
|
||||
}
|
||||
|
||||
self.assertEqual([], Log._geo_indices())
|
||||
|
||||
Log.drop_collection()
|
||||
Log.ensure_indexes()
|
||||
|
||||
info = Log._get_collection().index_information()
|
||||
self.assertEqual(info["location_2dsphere_datetime_1"]["key"],
|
||||
[('location', '2dsphere'), ('datetime', 1)])
|
||||
|
||||
# Test listing explicitly
|
||||
class Log(Document):
|
||||
location = PointField(auto_index=False)
|
||||
datetime = DateTimeField()
|
||||
|
||||
meta = {
|
||||
'indexes': [
|
||||
{'fields': [("location", "2dsphere"), ("datetime", 1)]}
|
||||
]
|
||||
}
|
||||
|
||||
self.assertEqual([], Log._geo_indices())
|
||||
|
||||
Log.drop_collection()
|
||||
Log.ensure_indexes()
|
||||
|
||||
info = Log._get_collection().index_information()
|
||||
self.assertEqual(info["location_2dsphere_datetime_1"]["key"],
|
||||
[('location', '2dsphere'), ('datetime', 1)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
from transform import *
|
||||
from field_list import *
|
||||
from queryset import *
|
||||
from visitor import *
|
||||
from visitor import *
|
||||
from geo import *
|
||||
|
||||
418
tests/queryset/geo.py
Normal file
418
tests/queryset/geo.py
Normal file
@@ -0,0 +1,418 @@
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from mongoengine import *
|
||||
|
||||
__all__ = ("GeoQueriesTest",)
|
||||
|
||||
|
||||
class GeoQueriesTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
|
||||
def test_geospatial_operators(self):
|
||||
"""Ensure that geospatial queries are working.
|
||||
"""
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
date = DateTimeField()
|
||||
location = GeoPointField()
|
||||
|
||||
def __unicode__(self):
|
||||
return self.title
|
||||
|
||||
Event.drop_collection()
|
||||
|
||||
event1 = Event(title="Coltrane Motion @ Double Door",
|
||||
date=datetime.now() - timedelta(days=1),
|
||||
location=[-87.677137, 41.909889]).save()
|
||||
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
|
||||
date=datetime.now() - timedelta(days=10),
|
||||
location=[-122.4194155, 37.7749295]).save()
|
||||
event3 = Event(title="Coltrane Motion @ Empty Bottle",
|
||||
date=datetime.now(),
|
||||
location=[-87.686638, 41.900474]).save()
|
||||
|
||||
# find all events "near" pitchfork office, chicago.
|
||||
# note that "near" will show the san francisco event, too,
|
||||
# although it sorts to last.
|
||||
events = Event.objects(location__near=[-87.67892, 41.9120459])
|
||||
self.assertEqual(events.count(), 3)
|
||||
self.assertEqual(list(events), [event1, event3, event2])
|
||||
|
||||
# find events within 5 degrees of pitchfork office, chicago
|
||||
point_and_distance = [[-87.67892, 41.9120459], 5]
|
||||
events = Event.objects(location__within_distance=point_and_distance)
|
||||
self.assertEqual(events.count(), 2)
|
||||
events = list(events)
|
||||
self.assertTrue(event2 not in events)
|
||||
self.assertTrue(event1 in events)
|
||||
self.assertTrue(event3 in events)
|
||||
|
||||
# ensure ordering is respected by "near"
|
||||
events = Event.objects(location__near=[-87.67892, 41.9120459])
|
||||
events = events.order_by("-date")
|
||||
self.assertEqual(events.count(), 3)
|
||||
self.assertEqual(list(events), [event3, event1, event2])
|
||||
|
||||
# find events within 10 degrees of san francisco
|
||||
point = [-122.415579, 37.7566023]
|
||||
events = Event.objects(location__near=point, location__max_distance=10)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0], event2)
|
||||
|
||||
# find events within 10 degrees of san francisco
|
||||
point_and_distance = [[-122.415579, 37.7566023], 10]
|
||||
events = Event.objects(location__within_distance=point_and_distance)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0], event2)
|
||||
|
||||
# find events within 1 degree of greenpoint, broolyn, nyc, ny
|
||||
point_and_distance = [[-73.9509714, 40.7237134], 1]
|
||||
events = Event.objects(location__within_distance=point_and_distance)
|
||||
self.assertEqual(events.count(), 0)
|
||||
|
||||
# ensure ordering is respected by "within_distance"
|
||||
point_and_distance = [[-87.67892, 41.9120459], 10]
|
||||
events = Event.objects(location__within_distance=point_and_distance)
|
||||
events = events.order_by("-date")
|
||||
self.assertEqual(events.count(), 2)
|
||||
self.assertEqual(events[0], event3)
|
||||
|
||||
# check that within_box works
|
||||
box = [(-125.0, 35.0), (-100.0, 40.0)]
|
||||
events = Event.objects(location__within_box=box)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0].id, event2.id)
|
||||
|
||||
polygon = [
|
||||
(-87.694445, 41.912114),
|
||||
(-87.69084, 41.919395),
|
||||
(-87.681742, 41.927186),
|
||||
(-87.654276, 41.911731),
|
||||
(-87.656164, 41.898061),
|
||||
]
|
||||
events = Event.objects(location__within_polygon=polygon)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0].id, event1.id)
|
||||
|
||||
polygon2 = [
|
||||
(-1.742249, 54.033586),
|
||||
(-1.225891, 52.792797),
|
||||
(-4.40094, 53.389881)
|
||||
]
|
||||
events = Event.objects(location__within_polygon=polygon2)
|
||||
self.assertEqual(events.count(), 0)
|
||||
|
||||
def test_geo_spatial_embedded(self):
|
||||
|
||||
class Venue(EmbeddedDocument):
|
||||
location = GeoPointField()
|
||||
name = StringField()
|
||||
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
venue = EmbeddedDocumentField(Venue)
|
||||
|
||||
Event.drop_collection()
|
||||
|
||||
venue1 = Venue(name="The Rock", location=[-87.677137, 41.909889])
|
||||
venue2 = Venue(name="The Bridge", location=[-122.4194155, 37.7749295])
|
||||
|
||||
event1 = Event(title="Coltrane Motion @ Double Door",
|
||||
venue=venue1).save()
|
||||
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
|
||||
venue=venue2).save()
|
||||
event3 = Event(title="Coltrane Motion @ Empty Bottle",
|
||||
venue=venue1).save()
|
||||
|
||||
# find all events "near" pitchfork office, chicago.
|
||||
# note that "near" will show the san francisco event, too,
|
||||
# although it sorts to last.
|
||||
events = Event.objects(venue__location__near=[-87.67892, 41.9120459])
|
||||
self.assertEqual(events.count(), 3)
|
||||
self.assertEqual(list(events), [event1, event3, event2])
|
||||
|
||||
def test_spherical_geospatial_operators(self):
|
||||
"""Ensure that spherical geospatial queries are working
|
||||
"""
|
||||
class Point(Document):
|
||||
location = GeoPointField()
|
||||
|
||||
Point.drop_collection()
|
||||
|
||||
# These points are one degree apart, which (according to Google Maps)
|
||||
# is about 110 km apart at this place on the Earth.
|
||||
north_point = Point(location=[-122, 38]).save() # Near Concord, CA
|
||||
south_point = Point(location=[-122, 37]).save() # Near Santa Cruz, CA
|
||||
|
||||
earth_radius = 6378.009 # in km (needs to be a float for dividing by)
|
||||
|
||||
# Finds both points because they are within 60 km of the reference
|
||||
# point equidistant between them.
|
||||
points = Point.objects(location__near_sphere=[-122, 37.5])
|
||||
self.assertEqual(points.count(), 2)
|
||||
|
||||
# Same behavior for _within_spherical_distance
|
||||
points = Point.objects(
|
||||
location__within_spherical_distance=[[-122, 37.5], 60/earth_radius]
|
||||
)
|
||||
self.assertEqual(points.count(), 2)
|
||||
|
||||
points = Point.objects(location__near_sphere=[-122, 37.5],
|
||||
location__max_distance=60 / earth_radius)
|
||||
self.assertEqual(points.count(), 2)
|
||||
|
||||
# Finds both points, but orders the north point first because it's
|
||||
# closer to the reference point to the north.
|
||||
points = Point.objects(location__near_sphere=[-122, 38.5])
|
||||
self.assertEqual(points.count(), 2)
|
||||
self.assertEqual(points[0].id, north_point.id)
|
||||
self.assertEqual(points[1].id, south_point.id)
|
||||
|
||||
# Finds both points, but orders the south point first because it's
|
||||
# closer to the reference point to the south.
|
||||
points = Point.objects(location__near_sphere=[-122, 36.5])
|
||||
self.assertEqual(points.count(), 2)
|
||||
self.assertEqual(points[0].id, south_point.id)
|
||||
self.assertEqual(points[1].id, north_point.id)
|
||||
|
||||
# Finds only one point because only the first point is within 60km of
|
||||
# the reference point to the south.
|
||||
points = Point.objects(
|
||||
location__within_spherical_distance=[[-122, 36.5], 60/earth_radius])
|
||||
self.assertEqual(points.count(), 1)
|
||||
self.assertEqual(points[0].id, south_point.id)
|
||||
|
||||
def test_2dsphere_point(self):
|
||||
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
date = DateTimeField()
|
||||
location = PointField()
|
||||
|
||||
def __unicode__(self):
|
||||
return self.title
|
||||
|
||||
Event.drop_collection()
|
||||
|
||||
event1 = Event(title="Coltrane Motion @ Double Door",
|
||||
date=datetime.now() - timedelta(days=1),
|
||||
location=[-87.677137, 41.909889])
|
||||
event1.save()
|
||||
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
|
||||
date=datetime.now() - timedelta(days=10),
|
||||
location=[-122.4194155, 37.7749295]).save()
|
||||
event3 = Event(title="Coltrane Motion @ Empty Bottle",
|
||||
date=datetime.now(),
|
||||
location=[-87.686638, 41.900474]).save()
|
||||
|
||||
# find all events "near" pitchfork office, chicago.
|
||||
# note that "near" will show the san francisco event, too,
|
||||
# although it sorts to last.
|
||||
events = Event.objects(location__near=[-87.67892, 41.9120459])
|
||||
self.assertEqual(events.count(), 3)
|
||||
self.assertEqual(list(events), [event1, event3, event2])
|
||||
|
||||
# find events within 5 degrees of pitchfork office, chicago
|
||||
point_and_distance = [[-87.67892, 41.9120459], 2]
|
||||
events = Event.objects(location__geo_within_center=point_and_distance)
|
||||
self.assertEqual(events.count(), 2)
|
||||
events = list(events)
|
||||
self.assertTrue(event2 not in events)
|
||||
self.assertTrue(event1 in events)
|
||||
self.assertTrue(event3 in events)
|
||||
|
||||
# ensure ordering is respected by "near"
|
||||
events = Event.objects(location__near=[-87.67892, 41.9120459])
|
||||
events = events.order_by("-date")
|
||||
self.assertEqual(events.count(), 3)
|
||||
self.assertEqual(list(events), [event3, event1, event2])
|
||||
|
||||
# find events within 10km of san francisco
|
||||
point = [-122.415579, 37.7566023]
|
||||
events = Event.objects(location__near=point, location__max_distance=10000)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0], event2)
|
||||
|
||||
# find events within 1km of greenpoint, broolyn, nyc, ny
|
||||
events = Event.objects(location__near=[-73.9509714, 40.7237134], location__max_distance=1000)
|
||||
self.assertEqual(events.count(), 0)
|
||||
|
||||
# ensure ordering is respected by "near"
|
||||
events = Event.objects(location__near=[-87.67892, 41.9120459],
|
||||
location__max_distance=10000).order_by("-date")
|
||||
self.assertEqual(events.count(), 2)
|
||||
self.assertEqual(events[0], event3)
|
||||
|
||||
# check that within_box works
|
||||
box = [(-125.0, 35.0), (-100.0, 40.0)]
|
||||
events = Event.objects(location__geo_within_box=box)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0].id, event2.id)
|
||||
|
||||
polygon = [
|
||||
(-87.694445, 41.912114),
|
||||
(-87.69084, 41.919395),
|
||||
(-87.681742, 41.927186),
|
||||
(-87.654276, 41.911731),
|
||||
(-87.656164, 41.898061),
|
||||
]
|
||||
events = Event.objects(location__geo_within_polygon=polygon)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0].id, event1.id)
|
||||
|
||||
polygon2 = [
|
||||
(-1.742249, 54.033586),
|
||||
(-1.225891, 52.792797),
|
||||
(-4.40094, 53.389881)
|
||||
]
|
||||
events = Event.objects(location__geo_within_polygon=polygon2)
|
||||
self.assertEqual(events.count(), 0)
|
||||
|
||||
def test_2dsphere_point_embedded(self):
|
||||
|
||||
class Venue(EmbeddedDocument):
|
||||
location = GeoPointField()
|
||||
name = StringField()
|
||||
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
venue = EmbeddedDocumentField(Venue)
|
||||
|
||||
Event.drop_collection()
|
||||
|
||||
venue1 = Venue(name="The Rock", location=[-87.677137, 41.909889])
|
||||
venue2 = Venue(name="The Bridge", location=[-122.4194155, 37.7749295])
|
||||
|
||||
event1 = Event(title="Coltrane Motion @ Double Door",
|
||||
venue=venue1).save()
|
||||
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
|
||||
venue=venue2).save()
|
||||
event3 = Event(title="Coltrane Motion @ Empty Bottle",
|
||||
venue=venue1).save()
|
||||
|
||||
# find all events "near" pitchfork office, chicago.
|
||||
# note that "near" will show the san francisco event, too,
|
||||
# although it sorts to last.
|
||||
events = Event.objects(venue__location__near=[-87.67892, 41.9120459])
|
||||
self.assertEqual(events.count(), 3)
|
||||
self.assertEqual(list(events), [event1, event3, event2])
|
||||
|
||||
def test_linestring(self):
|
||||
|
||||
class Road(Document):
|
||||
name = StringField()
|
||||
line = LineStringField()
|
||||
|
||||
Road.drop_collection()
|
||||
|
||||
Road(name="66", line=[[40, 5], [41, 6]]).save()
|
||||
|
||||
# near
|
||||
point = {"type": "Point", "coordinates": [40, 5]}
|
||||
roads = Road.objects.filter(line__near=point["coordinates"]).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(line__near=point).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(line__near={"$geometry": point}).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
# Within
|
||||
polygon = {"type": "Polygon",
|
||||
"coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
|
||||
roads = Road.objects.filter(line__geo_within=polygon["coordinates"]).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(line__geo_within=polygon).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(line__geo_within={"$geometry": polygon}).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
# Intersects
|
||||
line = {"type": "LineString",
|
||||
"coordinates": [[40, 5], [40, 6]]}
|
||||
roads = Road.objects.filter(line__geo_intersects=line["coordinates"]).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(line__geo_intersects=line).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(line__geo_intersects={"$geometry": line}).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
polygon = {"type": "Polygon",
|
||||
"coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
|
||||
roads = Road.objects.filter(line__geo_intersects=polygon["coordinates"]).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(line__geo_intersects=polygon).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(line__geo_intersects={"$geometry": polygon}).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
def test_polygon(self):
|
||||
|
||||
class Road(Document):
|
||||
name = StringField()
|
||||
poly = PolygonField()
|
||||
|
||||
Road.drop_collection()
|
||||
|
||||
Road(name="66", poly=[[[40, 5], [40, 6], [41, 6], [40, 5]]]).save()
|
||||
|
||||
# near
|
||||
point = {"type": "Point", "coordinates": [40, 5]}
|
||||
roads = Road.objects.filter(poly__near=point["coordinates"]).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(poly__near=point).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(poly__near={"$geometry": point}).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
# Within
|
||||
polygon = {"type": "Polygon",
|
||||
"coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
|
||||
roads = Road.objects.filter(poly__geo_within=polygon["coordinates"]).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(poly__geo_within=polygon).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(poly__geo_within={"$geometry": polygon}).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
# Intersects
|
||||
line = {"type": "LineString",
|
||||
"coordinates": [[40, 5], [41, 6]]}
|
||||
roads = Road.objects.filter(poly__geo_intersects=line["coordinates"]).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(poly__geo_intersects=line).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(poly__geo_intersects={"$geometry": line}).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
polygon = {"type": "Polygon",
|
||||
"coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]}
|
||||
roads = Road.objects.filter(poly__geo_intersects=polygon["coordinates"]).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(poly__geo_intersects=polygon).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
roads = Road.objects.filter(poly__geo_intersects={"$geometry": polygon}).count()
|
||||
self.assertEqual(1, roads)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@@ -116,6 +115,15 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(len(people), 1)
|
||||
self.assertEqual(people[0].name, 'User B')
|
||||
|
||||
# Test slice limit and skip cursor reset
|
||||
qs = self.Person.objects[1:2]
|
||||
# fetch then delete the cursor
|
||||
qs._cursor
|
||||
qs._cursor_obj = None
|
||||
people = list(qs)
|
||||
self.assertEqual(len(people), 1)
|
||||
self.assertEqual(people[0].name, 'User B')
|
||||
|
||||
people = list(self.Person.objects[1:1])
|
||||
self.assertEqual(len(people), 0)
|
||||
|
||||
@@ -274,7 +282,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
a_objects = A.objects(s='test1')
|
||||
query = B.objects(ref__in=a_objects)
|
||||
query = query.filter(boolfield=True)
|
||||
self.assertEquals(query.count(), 1)
|
||||
self.assertEqual(query.count(), 1)
|
||||
|
||||
def test_update_write_concern(self):
|
||||
"""Test that passing write_concern works"""
|
||||
@@ -287,15 +295,19 @@ class QuerySetTest(unittest.TestCase):
|
||||
name='Test User', write_concern=write_concern)
|
||||
author.save(write_concern=write_concern)
|
||||
|
||||
self.Person.objects.update(set__name='Ross',
|
||||
write_concern=write_concern)
|
||||
result = self.Person.objects.update(
|
||||
set__name='Ross', write_concern={"w": 1})
|
||||
self.assertEqual(result, 1)
|
||||
result = self.Person.objects.update(
|
||||
set__name='Ross', write_concern={"w": 0})
|
||||
self.assertEqual(result, None)
|
||||
|
||||
author = self.Person.objects.first()
|
||||
self.assertEqual(author.name, 'Ross')
|
||||
|
||||
self.Person.objects.update_one(set__name='Test User', write_concern=write_concern)
|
||||
author = self.Person.objects.first()
|
||||
self.assertEqual(author.name, 'Test User')
|
||||
result = self.Person.objects.update_one(
|
||||
set__name='Test User', write_concern={"w": 1})
|
||||
self.assertEqual(result, 1)
|
||||
result = self.Person.objects.update_one(
|
||||
set__name='Test User', write_concern={"w": 0})
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_update_update_has_a_value(self):
|
||||
"""Test to ensure that update is passed a value to update to"""
|
||||
@@ -524,6 +536,24 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(club.members['John']['gender'], "F")
|
||||
self.assertEqual(club.members['John']['age'], 14)
|
||||
|
||||
def test_upsert(self):
|
||||
self.Person.drop_collection()
|
||||
|
||||
self.Person.objects(pk=ObjectId(), name="Bob", age=30).update(upsert=True)
|
||||
|
||||
bob = self.Person.objects.first()
|
||||
self.assertEqual("Bob", bob.name)
|
||||
self.assertEqual(30, bob.age)
|
||||
|
||||
def test_set_on_insert(self):
|
||||
self.Person.drop_collection()
|
||||
|
||||
self.Person.objects(pk=ObjectId()).update(set__name='Bob', set_on_insert__age=30, upsert=True)
|
||||
|
||||
bob = self.Person.objects.first()
|
||||
self.assertEqual("Bob", bob.name)
|
||||
self.assertEqual(30, bob.age)
|
||||
|
||||
def test_get_or_create(self):
|
||||
"""Ensure that ``get_or_create`` returns one result or creates a new
|
||||
document.
|
||||
@@ -763,7 +793,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
p = p.snapshot(True).slave_okay(True).timeout(True)
|
||||
self.assertEqual(p._cursor_args,
|
||||
{'snapshot': True, 'slave_okay': True, 'timeout': True})
|
||||
{'snapshot': True, 'slave_okay': True, 'timeout': True})
|
||||
|
||||
def test_repeated_iteration(self):
|
||||
"""Ensure that QuerySet rewinds itself one iteration finishes.
|
||||
@@ -805,6 +835,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertTrue("Doc: 0" in docs_string)
|
||||
|
||||
self.assertEqual(docs.count(), 1000)
|
||||
self.assertTrue('(remaining elements truncated)' in "%s" % docs)
|
||||
|
||||
# Limit and skip
|
||||
docs = docs[1:4]
|
||||
@@ -1233,7 +1264,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
class BlogPost(Document):
|
||||
content = StringField()
|
||||
authors = ListField(ReferenceField(self.Person,
|
||||
reverse_delete_rule=PULL))
|
||||
reverse_delete_rule=PULL))
|
||||
|
||||
BlogPost.drop_collection()
|
||||
self.Person.drop_collection()
|
||||
@@ -1291,6 +1322,49 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.Person.objects()[:1].delete()
|
||||
self.assertEqual(1, BlogPost.objects.count())
|
||||
|
||||
|
||||
def test_reference_field_find(self):
|
||||
"""Ensure cascading deletion of referring documents from the database.
|
||||
"""
|
||||
class BlogPost(Document):
|
||||
content = StringField()
|
||||
author = ReferenceField(self.Person)
|
||||
|
||||
BlogPost.drop_collection()
|
||||
self.Person.drop_collection()
|
||||
|
||||
me = self.Person(name='Test User').save()
|
||||
BlogPost(content="test 123", author=me).save()
|
||||
|
||||
self.assertEqual(1, BlogPost.objects(author=me).count())
|
||||
self.assertEqual(1, BlogPost.objects(author=me.pk).count())
|
||||
self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count())
|
||||
|
||||
self.assertEqual(1, BlogPost.objects(author__in=[me]).count())
|
||||
self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count())
|
||||
self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count())
|
||||
|
||||
def test_reference_field_find_dbref(self):
|
||||
"""Ensure cascading deletion of referring documents from the database.
|
||||
"""
|
||||
class BlogPost(Document):
|
||||
content = StringField()
|
||||
author = ReferenceField(self.Person, dbref=True)
|
||||
|
||||
BlogPost.drop_collection()
|
||||
self.Person.drop_collection()
|
||||
|
||||
me = self.Person(name='Test User').save()
|
||||
BlogPost(content="test 123", author=me).save()
|
||||
|
||||
self.assertEqual(1, BlogPost.objects(author=me).count())
|
||||
self.assertEqual(1, BlogPost.objects(author=me.pk).count())
|
||||
self.assertEqual(1, BlogPost.objects(author="%s" % me.pk).count())
|
||||
|
||||
self.assertEqual(1, BlogPost.objects(author__in=[me]).count())
|
||||
self.assertEqual(1, BlogPost.objects(author__in=[me.pk]).count())
|
||||
self.assertEqual(1, BlogPost.objects(author__in=["%s" % me.pk]).count())
|
||||
|
||||
def test_update(self):
|
||||
"""Ensure that atomic updates work properly.
|
||||
"""
|
||||
@@ -2380,167 +2454,6 @@ class QuerySetTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
self.Person.drop_collection()
|
||||
|
||||
def test_geospatial_operators(self):
|
||||
"""Ensure that geospatial queries are working.
|
||||
"""
|
||||
class Event(Document):
|
||||
title = StringField()
|
||||
date = DateTimeField()
|
||||
location = GeoPointField()
|
||||
|
||||
def __unicode__(self):
|
||||
return self.title
|
||||
|
||||
Event.drop_collection()
|
||||
|
||||
event1 = Event(title="Coltrane Motion @ Double Door",
|
||||
date=datetime.now() - timedelta(days=1),
|
||||
location=[41.909889, -87.677137])
|
||||
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
|
||||
date=datetime.now() - timedelta(days=10),
|
||||
location=[37.7749295, -122.4194155])
|
||||
event3 = Event(title="Coltrane Motion @ Empty Bottle",
|
||||
date=datetime.now(),
|
||||
location=[41.900474, -87.686638])
|
||||
|
||||
event1.save()
|
||||
event2.save()
|
||||
event3.save()
|
||||
|
||||
# find all events "near" pitchfork office, chicago.
|
||||
# note that "near" will show the san francisco event, too,
|
||||
# although it sorts to last.
|
||||
events = Event.objects(location__near=[41.9120459, -87.67892])
|
||||
self.assertEqual(events.count(), 3)
|
||||
self.assertEqual(list(events), [event1, event3, event2])
|
||||
|
||||
# find events within 5 degrees of pitchfork office, chicago
|
||||
point_and_distance = [[41.9120459, -87.67892], 5]
|
||||
events = Event.objects(location__within_distance=point_and_distance)
|
||||
self.assertEqual(events.count(), 2)
|
||||
events = list(events)
|
||||
self.assertTrue(event2 not in events)
|
||||
self.assertTrue(event1 in events)
|
||||
self.assertTrue(event3 in events)
|
||||
|
||||
# ensure ordering is respected by "near"
|
||||
events = Event.objects(location__near=[41.9120459, -87.67892])
|
||||
events = events.order_by("-date")
|
||||
self.assertEqual(events.count(), 3)
|
||||
self.assertEqual(list(events), [event3, event1, event2])
|
||||
|
||||
# find events within 10 degrees of san francisco
|
||||
point = [37.7566023, -122.415579]
|
||||
events = Event.objects(location__near=point, location__max_distance=10)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0], event2)
|
||||
|
||||
# find events within 10 degrees of san francisco
|
||||
point_and_distance = [[37.7566023, -122.415579], 10]
|
||||
events = Event.objects(location__within_distance=point_and_distance)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0], event2)
|
||||
|
||||
# find events within 1 degree of greenpoint, broolyn, nyc, ny
|
||||
point_and_distance = [[40.7237134, -73.9509714], 1]
|
||||
events = Event.objects(location__within_distance=point_and_distance)
|
||||
self.assertEqual(events.count(), 0)
|
||||
|
||||
# ensure ordering is respected by "within_distance"
|
||||
point_and_distance = [[41.9120459, -87.67892], 10]
|
||||
events = Event.objects(location__within_distance=point_and_distance)
|
||||
events = events.order_by("-date")
|
||||
self.assertEqual(events.count(), 2)
|
||||
self.assertEqual(events[0], event3)
|
||||
|
||||
# check that within_box works
|
||||
box = [(35.0, -125.0), (40.0, -100.0)]
|
||||
events = Event.objects(location__within_box=box)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0].id, event2.id)
|
||||
|
||||
# check that polygon works for users who have a server >= 1.9
|
||||
server_version = tuple(
|
||||
get_connection().server_info()['version'].split('.')
|
||||
)
|
||||
required_version = tuple("1.9.0".split("."))
|
||||
if server_version >= required_version:
|
||||
polygon = [
|
||||
(41.912114,-87.694445),
|
||||
(41.919395,-87.69084),
|
||||
(41.927186,-87.681742),
|
||||
(41.911731,-87.654276),
|
||||
(41.898061,-87.656164),
|
||||
]
|
||||
events = Event.objects(location__within_polygon=polygon)
|
||||
self.assertEqual(events.count(), 1)
|
||||
self.assertEqual(events[0].id, event1.id)
|
||||
|
||||
polygon2 = [
|
||||
(54.033586,-1.742249),
|
||||
(52.792797,-1.225891),
|
||||
(53.389881,-4.40094)
|
||||
]
|
||||
events = Event.objects(location__within_polygon=polygon2)
|
||||
self.assertEqual(events.count(), 0)
|
||||
|
||||
Event.drop_collection()
|
||||
|
||||
def test_spherical_geospatial_operators(self):
|
||||
"""Ensure that spherical geospatial queries are working
|
||||
"""
|
||||
class Point(Document):
|
||||
location = GeoPointField()
|
||||
|
||||
Point.drop_collection()
|
||||
|
||||
# These points are one degree apart, which (according to Google Maps)
|
||||
# is about 110 km apart at this place on the Earth.
|
||||
north_point = Point(location=[-122, 38]) # Near Concord, CA
|
||||
south_point = Point(location=[-122, 37]) # Near Santa Cruz, CA
|
||||
north_point.save()
|
||||
south_point.save()
|
||||
|
||||
earth_radius = 6378.009; # in km (needs to be a float for dividing by)
|
||||
|
||||
# Finds both points because they are within 60 km of the reference
|
||||
# point equidistant between them.
|
||||
points = Point.objects(location__near_sphere=[-122, 37.5])
|
||||
self.assertEqual(points.count(), 2)
|
||||
|
||||
# Same behavior for _within_spherical_distance
|
||||
points = Point.objects(
|
||||
location__within_spherical_distance=[[-122, 37.5], 60/earth_radius]
|
||||
);
|
||||
self.assertEqual(points.count(), 2)
|
||||
|
||||
points = Point.objects(location__near_sphere=[-122, 37.5],
|
||||
location__max_distance=60 / earth_radius);
|
||||
self.assertEqual(points.count(), 2)
|
||||
|
||||
# Finds both points, but orders the north point first because it's
|
||||
# closer to the reference point to the north.
|
||||
points = Point.objects(location__near_sphere=[-122, 38.5])
|
||||
self.assertEqual(points.count(), 2)
|
||||
self.assertEqual(points[0].id, north_point.id)
|
||||
self.assertEqual(points[1].id, south_point.id)
|
||||
|
||||
# Finds both points, but orders the south point first because it's
|
||||
# closer to the reference point to the south.
|
||||
points = Point.objects(location__near_sphere=[-122, 36.5])
|
||||
self.assertEqual(points.count(), 2)
|
||||
self.assertEqual(points[0].id, south_point.id)
|
||||
self.assertEqual(points[1].id, north_point.id)
|
||||
|
||||
# Finds only one point because only the first point is within 60km of
|
||||
# the reference point to the south.
|
||||
points = Point.objects(
|
||||
location__within_spherical_distance=[[-122, 36.5], 60/earth_radius])
|
||||
self.assertEqual(points.count(), 1)
|
||||
self.assertEqual(points[0].id, south_point.id)
|
||||
|
||||
Point.drop_collection()
|
||||
|
||||
def test_custom_querysets(self):
|
||||
"""Ensure that custom QuerySet classes may be used.
|
||||
"""
|
||||
@@ -3276,6 +3189,28 @@ class QuerySetTest(unittest.TestCase):
|
||||
self.assertEqual(results[1]['name'], 'Barack Obama')
|
||||
self.assertEqual(results[1]['price'], Decimal('2.22'))
|
||||
|
||||
def test_as_pymongo_json_limit_fields(self):
|
||||
|
||||
class User(Document):
|
||||
email = EmailField(unique=True, required=True)
|
||||
password_hash = StringField(db_field='password_hash', required=True)
|
||||
password_salt = StringField(db_field='password_salt', required=True)
|
||||
|
||||
User.drop_collection()
|
||||
User(email="ross@example.com", password_salt="SomeSalt", password_hash="SomeHash").save()
|
||||
|
||||
serialized_user = User.objects.exclude('password_salt', 'password_hash').as_pymongo()[0]
|
||||
self.assertEqual(set(['_id', 'email']), set(serialized_user.keys()))
|
||||
|
||||
serialized_user = User.objects.exclude('id', 'password_salt', 'password_hash').to_json()
|
||||
self.assertEqual('[{"email": "ross@example.com"}]', serialized_user)
|
||||
|
||||
serialized_user = User.objects.exclude('password_salt').only('email').as_pymongo()[0]
|
||||
self.assertEqual(set(['email']), set(serialized_user.keys()))
|
||||
|
||||
serialized_user = User.objects.exclude('password_salt').only('email').to_json()
|
||||
self.assertEqual('[{"email": "ross@example.com"}]', serialized_user)
|
||||
|
||||
def test_no_dereference(self):
|
||||
|
||||
class Organization(Document):
|
||||
@@ -3297,6 +3232,51 @@ class QuerySetTest(unittest.TestCase):
|
||||
Organization))
|
||||
self.assertTrue(isinstance(qs.first().organization, Organization))
|
||||
|
||||
def test_cached_queryset(self):
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
|
||||
Person.drop_collection()
|
||||
for i in xrange(100):
|
||||
Person(name="No: %s" % i).save()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
people = Person.objects
|
||||
|
||||
[x for x in people]
|
||||
self.assertEqual(100, len(people._result_cache))
|
||||
self.assertEqual(None, people._len)
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
list(people)
|
||||
self.assertEqual(100, people._len) # Caused by list calling len
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
people.count() # count is cached
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
def test_cache_not_cloned(self):
|
||||
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
def __unicode__(self):
|
||||
return self.name
|
||||
|
||||
User.drop_collection()
|
||||
|
||||
User(name="Alice").save()
|
||||
User(name="Bob").save()
|
||||
|
||||
users = User.objects.all().order_by('name')
|
||||
self.assertEqual("%s" % users, "[<User: Alice>, <User: Bob>]")
|
||||
self.assertEqual(2, len(users._result_cache))
|
||||
|
||||
users = users.filter(name="Bob")
|
||||
self.assertEqual("%s" % users, "[<User: Bob>]")
|
||||
self.assertEqual(1, len(users._result_cache))
|
||||
|
||||
def test_nested_queryset_iterator(self):
|
||||
# Try iterating the same queryset twice, nested.
|
||||
names = ['Alice', 'Bob', 'Chuck', 'David', 'Eric', 'Francis', 'George']
|
||||
@@ -3313,30 +3293,73 @@ class QuerySetTest(unittest.TestCase):
|
||||
User(name=name).save()
|
||||
|
||||
users = User.objects.all().order_by('name')
|
||||
|
||||
outer_count = 0
|
||||
inner_count = 0
|
||||
inner_total_count = 0
|
||||
|
||||
self.assertEqual(users.count(), 7)
|
||||
with query_counter() as q:
|
||||
self.assertEqual(q, 0)
|
||||
|
||||
for i, outer_user in enumerate(users):
|
||||
self.assertEqual(outer_user.name, names[i])
|
||||
outer_count += 1
|
||||
inner_count = 0
|
||||
|
||||
# Calling len might disrupt the inner loop if there are bugs
|
||||
self.assertEqual(users.count(), 7)
|
||||
|
||||
for j, inner_user in enumerate(users):
|
||||
self.assertEqual(inner_user.name, names[j])
|
||||
inner_count += 1
|
||||
inner_total_count += 1
|
||||
for i, outer_user in enumerate(users):
|
||||
self.assertEqual(outer_user.name, names[i])
|
||||
outer_count += 1
|
||||
inner_count = 0
|
||||
|
||||
self.assertEqual(inner_count, 7) # inner loop should always be executed seven times
|
||||
# Calling len might disrupt the inner loop if there are bugs
|
||||
self.assertEqual(users.count(), 7)
|
||||
|
||||
self.assertEqual(outer_count, 7) # outer loop should be executed seven times total
|
||||
self.assertEqual(inner_total_count, 7 * 7) # inner loop should be executed fourtynine times total
|
||||
for j, inner_user in enumerate(users):
|
||||
self.assertEqual(inner_user.name, names[j])
|
||||
inner_count += 1
|
||||
inner_total_count += 1
|
||||
|
||||
self.assertEqual(inner_count, 7) # inner loop should always be executed seven times
|
||||
|
||||
self.assertEqual(outer_count, 7) # outer loop should be executed seven times total
|
||||
self.assertEqual(inner_total_count, 7 * 7) # inner loop should be executed fourtynine times total
|
||||
|
||||
self.assertEqual(q, 2)
|
||||
|
||||
def test_no_sub_classes(self):
|
||||
class A(Document):
|
||||
x = IntField()
|
||||
y = IntField()
|
||||
|
||||
meta = {'allow_inheritance': True}
|
||||
|
||||
class B(A):
|
||||
z = IntField()
|
||||
|
||||
class C(B):
|
||||
zz = IntField()
|
||||
|
||||
A.drop_collection()
|
||||
|
||||
A(x=10, y=20).save()
|
||||
A(x=15, y=30).save()
|
||||
B(x=20, y=40).save()
|
||||
B(x=30, y=50).save()
|
||||
C(x=40, y=60).save()
|
||||
|
||||
self.assertEqual(A.objects.no_sub_classes().count(), 2)
|
||||
self.assertEqual(A.objects.count(), 5)
|
||||
|
||||
self.assertEqual(B.objects.no_sub_classes().count(), 2)
|
||||
self.assertEqual(B.objects.count(), 3)
|
||||
|
||||
self.assertEqual(C.objects.no_sub_classes().count(), 1)
|
||||
self.assertEqual(C.objects.count(), 1)
|
||||
|
||||
for obj in A.objects.no_sub_classes():
|
||||
self.assertEqual(obj.__class__, A)
|
||||
|
||||
for obj in B.objects.no_sub_classes():
|
||||
self.assertEqual(obj.__class__, B)
|
||||
|
||||
for obj in C.objects.no_sub_classes():
|
||||
self.assertEqual(obj.__class__, C)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
import unittest
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
import unittest
|
||||
@@ -6,7 +5,8 @@ import unittest
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import get_db
|
||||
from mongoengine.context_managers import (switch_db, switch_collection,
|
||||
no_dereference, query_counter)
|
||||
no_sub_classes, no_dereference,
|
||||
query_counter)
|
||||
|
||||
|
||||
class ContextManagersTest(unittest.TestCase):
|
||||
@@ -139,6 +139,54 @@ class ContextManagersTest(unittest.TestCase):
|
||||
self.assertTrue(isinstance(group.ref, User))
|
||||
self.assertTrue(isinstance(group.generic, User))
|
||||
|
||||
def test_no_sub_classes(self):
|
||||
class A(Document):
|
||||
x = IntField()
|
||||
y = IntField()
|
||||
|
||||
meta = {'allow_inheritance': True}
|
||||
|
||||
class B(A):
|
||||
z = IntField()
|
||||
|
||||
class C(B):
|
||||
zz = IntField()
|
||||
|
||||
A.drop_collection()
|
||||
|
||||
A(x=10, y=20).save()
|
||||
A(x=15, y=30).save()
|
||||
B(x=20, y=40).save()
|
||||
B(x=30, y=50).save()
|
||||
C(x=40, y=60).save()
|
||||
|
||||
self.assertEqual(A.objects.count(), 5)
|
||||
self.assertEqual(B.objects.count(), 3)
|
||||
self.assertEqual(C.objects.count(), 1)
|
||||
|
||||
with no_sub_classes(A) as A:
|
||||
self.assertEqual(A.objects.count(), 2)
|
||||
|
||||
for obj in A.objects:
|
||||
self.assertEqual(obj.__class__, A)
|
||||
|
||||
with no_sub_classes(B) as B:
|
||||
self.assertEqual(B.objects.count(), 2)
|
||||
|
||||
for obj in B.objects:
|
||||
self.assertEqual(obj.__class__, B)
|
||||
|
||||
with no_sub_classes(C) as C:
|
||||
self.assertEqual(C.objects.count(), 1)
|
||||
|
||||
for obj in C.objects:
|
||||
self.assertEqual(obj.__class__, C)
|
||||
|
||||
# Confirm context manager exit correctly
|
||||
self.assertEqual(A.objects.count(), 5)
|
||||
self.assertEqual(B.objects.count(), 3)
|
||||
self.assertEqual(C.objects.count(), 1)
|
||||
|
||||
def test_query_counter(self):
|
||||
connect('mongoenginetest')
|
||||
db = get_db()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
import unittest
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from __future__ import with_statement
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
import unittest
|
||||
@@ -151,22 +150,74 @@ class QuerySetTest(unittest.TestCase):
|
||||
# Try iterating the same queryset twice, nested, in a Django template.
|
||||
names = ['A', 'B', 'C', 'D']
|
||||
|
||||
class User(Document):
|
||||
class CustomUser(Document):
|
||||
name = StringField()
|
||||
|
||||
def __unicode__(self):
|
||||
return self.name
|
||||
|
||||
User.drop_collection()
|
||||
CustomUser.drop_collection()
|
||||
|
||||
for name in names:
|
||||
User(name=name).save()
|
||||
CustomUser(name=name).save()
|
||||
|
||||
users = User.objects.all().order_by('name')
|
||||
users = CustomUser.objects.all().order_by('name')
|
||||
template = Template("{% for user in users %}{{ user.name }}{% ifequal forloop.counter 2 %} {% for inner_user in users %}{{ inner_user.name }}{% endfor %} {% endifequal %}{% endfor %}")
|
||||
rendered = template.render(Context({'users': users}))
|
||||
self.assertEqual(rendered, 'AB ABCD CD')
|
||||
|
||||
def test_filter(self):
|
||||
"""Ensure that a queryset and filters work as expected
|
||||
"""
|
||||
|
||||
class Note(Document):
|
||||
text = StringField()
|
||||
|
||||
for i in xrange(1, 101):
|
||||
Note(name="Note: %s" % i).save()
|
||||
|
||||
# Check the count
|
||||
self.assertEqual(Note.objects.count(), 100)
|
||||
|
||||
# Get the first 10 and confirm
|
||||
notes = Note.objects[:10]
|
||||
self.assertEqual(notes.count(), 10)
|
||||
|
||||
# Test djangos template filters
|
||||
# self.assertEqual(length(notes), 10)
|
||||
t = Template("{{ notes.count }}")
|
||||
c = Context({"notes": notes})
|
||||
self.assertEqual(t.render(c), "10")
|
||||
|
||||
# Test with skip
|
||||
notes = Note.objects.skip(90)
|
||||
self.assertEqual(notes.count(), 10)
|
||||
|
||||
# Test djangos template filters
|
||||
self.assertEqual(notes.count(), 10)
|
||||
t = Template("{{ notes.count }}")
|
||||
c = Context({"notes": notes})
|
||||
self.assertEqual(t.render(c), "10")
|
||||
|
||||
# Test with limit
|
||||
notes = Note.objects.skip(90)
|
||||
self.assertEqual(notes.count(), 10)
|
||||
|
||||
# Test djangos template filters
|
||||
self.assertEqual(notes.count(), 10)
|
||||
t = Template("{{ notes.count }}")
|
||||
c = Context({"notes": notes})
|
||||
self.assertEqual(t.render(c), "10")
|
||||
|
||||
# Test with skip and limit
|
||||
notes = Note.objects.skip(10).limit(10)
|
||||
|
||||
# Test djangos template filters
|
||||
self.assertEqual(notes.count(), 10)
|
||||
t = Template("{{ notes.count }}")
|
||||
c = Context({"notes": notes})
|
||||
self.assertEqual(t.render(c), "10")
|
||||
|
||||
|
||||
class MongoDBSessionTest(SessionTestsMixin, unittest.TestCase):
|
||||
backend = SessionStore
|
||||
|
||||
47
tests/test_jinja.py
Normal file
47
tests/test_jinja.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
import unittest
|
||||
|
||||
from mongoengine import *
|
||||
|
||||
import jinja2
|
||||
|
||||
|
||||
class TemplateFilterTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
|
||||
def test_jinja2(self):
|
||||
env = jinja2.Environment()
|
||||
|
||||
class TestData(Document):
|
||||
title = StringField()
|
||||
description = StringField()
|
||||
|
||||
TestData.drop_collection()
|
||||
|
||||
examples = [('A', '1'),
|
||||
('B', '2'),
|
||||
('C', '3')]
|
||||
|
||||
for title, description in examples:
|
||||
TestData(title=title, description=description).save()
|
||||
|
||||
tmpl = """
|
||||
{%- for record in content -%}
|
||||
{%- if loop.first -%}{ {%- endif -%}
|
||||
"{{ record.title }}": "{{ record.description }}"
|
||||
{%- if loop.last -%} }{%- else -%},{% endif -%}
|
||||
{%- endfor -%}
|
||||
"""
|
||||
ctx = {'content': TestData.objects}
|
||||
template = env.from_string(tmpl)
|
||||
rendered = template.render(**ctx)
|
||||
|
||||
self.assertEqual('{"A": "1","B": "2","C": "3"}', rendered)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user