Added a cleaner way to get collection names
Also handles dynamic collection naming - refs #180.
This commit is contained in:
@@ -1,5 +1,10 @@
|
||||
import pickle
|
||||
import pymongo
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pymongo
|
||||
import pickle
|
||||
import weakref
|
||||
@@ -30,7 +35,7 @@ class DocumentTest(unittest.TestCase):
|
||||
"""
|
||||
self.Person(name='Test').save()
|
||||
|
||||
collection = self.Person._meta['collection']
|
||||
collection = self.Person._get_collection_name()
|
||||
self.assertTrue(collection in self.db.collection_names())
|
||||
|
||||
self.Person.drop_collection()
|
||||
@@ -57,6 +62,23 @@ class DocumentTest(unittest.TestCase):
|
||||
# Ensure Document isn't treated like an actual document
|
||||
self.assertFalse(hasattr(Document, '_fields'))
|
||||
|
||||
def test_dynamic_collection_naming(self):
|
||||
|
||||
def create_collection_name(cls):
|
||||
return "PERSON"
|
||||
|
||||
class DynamicPerson(Document):
|
||||
name = StringField()
|
||||
age = IntField()
|
||||
|
||||
meta = {'collection': create_collection_name}
|
||||
|
||||
collection = DynamicPerson._get_collection_name()
|
||||
self.assertEquals(collection, 'PERSON')
|
||||
|
||||
DynamicPerson(name='Test User', age=30).save()
|
||||
self.assertTrue(collection in self.db.collection_names())
|
||||
|
||||
def test_get_superclasses(self):
|
||||
"""Ensure that the correct list of superclasses is assembled.
|
||||
"""
|
||||
@@ -225,8 +247,8 @@ class DocumentTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue('name' in Employee._fields)
|
||||
self.assertTrue('salary' in Employee._fields)
|
||||
self.assertEqual(Employee._meta['collection'],
|
||||
self.Person._meta['collection'])
|
||||
self.assertEqual(Employee._get_collection_name(),
|
||||
self.Person._get_collection_name())
|
||||
|
||||
# Ensure that MRO error is not raised
|
||||
class A(Document): pass
|
||||
@@ -251,7 +273,7 @@ class DocumentTest(unittest.TestCase):
|
||||
# Check that _cls etc aren't present on simple documents
|
||||
dog = Animal(name='dog')
|
||||
dog.save()
|
||||
collection = self.db[Animal._meta['collection']]
|
||||
collection = self.db[Animal._get_collection_name()]
|
||||
obj = collection.find_one()
|
||||
self.assertFalse('_cls' in obj)
|
||||
self.assertFalse('_types' in obj)
|
||||
@@ -297,7 +319,7 @@ class DocumentTest(unittest.TestCase):
|
||||
# Check that _cls etc aren't present on simple documents
|
||||
dog = Animal(name='dog')
|
||||
dog.save()
|
||||
collection = self.db[Animal._meta['collection']]
|
||||
collection = self.db[Animal._get_collection_name()]
|
||||
obj = collection.find_one()
|
||||
self.assertFalse('_cls' in obj)
|
||||
self.assertFalse('_types' in obj)
|
||||
@@ -318,7 +340,7 @@ class DocumentTest(unittest.TestCase):
|
||||
dog = Animal(name='dog')
|
||||
dog.save()
|
||||
|
||||
collection = self.db[Animal._meta['collection']]
|
||||
collection = self.db[Animal._get_collection_name()]
|
||||
obj = collection.find_one()
|
||||
self.assertTrue('_cls' in obj)
|
||||
self.assertTrue('_types' in obj)
|
||||
@@ -381,9 +403,12 @@ class DocumentTest(unittest.TestCase):
|
||||
self.assertFalse('collection' in Animal._meta)
|
||||
self.assertFalse('collection' in Mammal._meta)
|
||||
|
||||
self.assertEqual(Fish._meta['collection'], 'fish')
|
||||
self.assertEqual(Guppy._meta['collection'], 'fish')
|
||||
self.assertEqual(Human._meta['collection'], 'human')
|
||||
self.assertEqual(Animal._get_collection_name(), None)
|
||||
self.assertEqual(Mammal._get_collection_name(), None)
|
||||
|
||||
self.assertEqual(Fish._get_collection_name(), 'fish')
|
||||
self.assertEqual(Guppy._get_collection_name(), 'fish')
|
||||
self.assertEqual(Human._get_collection_name(), 'human')
|
||||
|
||||
def create_bad_abstract():
|
||||
class EvilHuman(Human):
|
||||
@@ -434,14 +459,21 @@ class DocumentTest(unittest.TestCase):
|
||||
def test_inherited_collections(self):
|
||||
"""Ensure that subclassed documents don't override parents' collections.
|
||||
"""
|
||||
class Drink(Document):
|
||||
name = StringField()
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
# Cause all warnings to always be triggered.
|
||||
warnings.simplefilter("always")
|
||||
|
||||
class AlcoholicDrink(Drink):
|
||||
meta = {'collection': 'booze'}
|
||||
class Drink(Document):
|
||||
name = StringField()
|
||||
|
||||
class Drinker(Document):
|
||||
drink = GenericReferenceField()
|
||||
class AlcoholicDrink(Drink):
|
||||
meta = {'collection': 'booze'}
|
||||
|
||||
class Drinker(Document):
|
||||
drink = GenericReferenceField()
|
||||
|
||||
# Confirm we triggered a SyntaxWarning
|
||||
assert issubclass(w[0].category, SyntaxWarning)
|
||||
|
||||
Drink.drop_collection()
|
||||
AlcoholicDrink.drop_collection()
|
||||
@@ -455,7 +487,6 @@ class DocumentTest(unittest.TestCase):
|
||||
|
||||
beer = AlcoholicDrink(name='Beer')
|
||||
beer.save()
|
||||
|
||||
real_person = Drinker(drink=beer)
|
||||
real_person.save()
|
||||
|
||||
@@ -936,7 +967,7 @@ class DocumentTest(unittest.TestCase):
|
||||
person = self.Person(name='Test User', age=30)
|
||||
person.save()
|
||||
# Ensure that the object is in the database
|
||||
collection = self.db[self.Person._meta['collection']]
|
||||
collection = self.db[self.Person._get_collection_name()]
|
||||
person_obj = collection.find_one({'name': 'Test User'})
|
||||
self.assertEqual(person_obj['name'], 'Test User')
|
||||
self.assertEqual(person_obj['age'], 30)
|
||||
@@ -1279,7 +1310,7 @@ class DocumentTest(unittest.TestCase):
|
||||
id='497ce96f395f2f052a494fd4')
|
||||
person.save()
|
||||
# Ensure that the object is in the database with the correct _id
|
||||
collection = self.db[self.Person._meta['collection']]
|
||||
collection = self.db[self.Person._get_collection_name()]
|
||||
person_obj = collection.find_one({'name': 'Test User'})
|
||||
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
|
||||
|
||||
@@ -1291,7 +1322,7 @@ class DocumentTest(unittest.TestCase):
|
||||
pk='497ce96f395f2f052a494fd4')
|
||||
person.save()
|
||||
# Ensure that the object is in the database with the correct _id
|
||||
collection = self.db[self.Person._meta['collection']]
|
||||
collection = self.db[self.Person._get_collection_name()]
|
||||
person_obj = collection.find_one({'name': 'Test User'})
|
||||
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
|
||||
|
||||
@@ -1314,7 +1345,7 @@ class DocumentTest(unittest.TestCase):
|
||||
post.comments = comments
|
||||
post.save()
|
||||
|
||||
collection = self.db[BlogPost._meta['collection']]
|
||||
collection = self.db[BlogPost._get_collection_name()]
|
||||
post_obj = collection.find_one()
|
||||
self.assertEqual(post_obj['tags'], tags)
|
||||
for comment_obj, comment in zip(post_obj['comments'], comments):
|
||||
@@ -1339,7 +1370,7 @@ class DocumentTest(unittest.TestCase):
|
||||
employee.save()
|
||||
|
||||
# Ensure that the object is in the database
|
||||
collection = self.db[self.Person._meta['collection']]
|
||||
collection = self.db[self.Person._get_collection_name()]
|
||||
employee_obj = collection.find_one({'name': 'Test Employee'})
|
||||
self.assertEqual(employee_obj['name'], 'Test Employee')
|
||||
self.assertEqual(employee_obj['age'], 50)
|
||||
@@ -1370,6 +1401,7 @@ class DocumentTest(unittest.TestCase):
|
||||
promoted_employee.reload()
|
||||
self.assertEqual(promoted_employee.name, 'Test Employee')
|
||||
self.assertEqual(promoted_employee.age, 50)
|
||||
|
||||
# Ensure that the 'details' embedded object saved correctly
|
||||
self.assertEqual(promoted_employee.details.position, 'Senior Developer')
|
||||
|
||||
@@ -1399,7 +1431,7 @@ class DocumentTest(unittest.TestCase):
|
||||
p.save()
|
||||
self.assertEquals(p._fields.keys(), ['name', 'id'])
|
||||
|
||||
collection = self.db[Person._meta['collection']]
|
||||
collection = self.db[Person._get_collection_name()]
|
||||
obj = collection.find_one()
|
||||
self.assertEquals(obj['_cls'], 'Person')
|
||||
self.assertEquals(obj['_types'], ['Person'])
|
||||
@@ -1492,6 +1524,9 @@ class DocumentTest(unittest.TestCase):
|
||||
text = StringField()
|
||||
post = ReferenceField(BlogPost, reverse_delete_rule=CASCADE)
|
||||
|
||||
self.Person.drop_collection()
|
||||
BlogPost.drop_collection()
|
||||
Comment.drop_collection()
|
||||
|
||||
author = self.Person(name='Test User')
|
||||
author.save()
|
||||
|
||||
Reference in New Issue
Block a user