Added a cleaner way to get collection names
Also handles dynamic collection naming - refs #180.
This commit is contained in:
parent
1b0323bc22
commit
f41c5217c6
@ -22,6 +22,7 @@ class ValidationError(Exception):
|
|||||||
|
|
||||||
_document_registry = {}
|
_document_registry = {}
|
||||||
|
|
||||||
|
|
||||||
def get_document(name):
|
def get_document(name):
|
||||||
doc = _document_registry.get(name, None)
|
doc = _document_registry.get(name, None)
|
||||||
if not doc:
|
if not doc:
|
||||||
@ -195,7 +196,7 @@ class ComplexBaseField(BaseField):
|
|||||||
elif isinstance(v, (dict, pymongo.son.SON)):
|
elif isinstance(v, (dict, pymongo.son.SON)):
|
||||||
if '_ref' in v:
|
if '_ref' in v:
|
||||||
# generic reference
|
# generic reference
|
||||||
collection = get_document(v['_cls'])._meta['collection']
|
collection = get_document(v['_cls'])._get_collection_name()
|
||||||
collections.setdefault(collection, []).append((k,v))
|
collections.setdefault(collection, []).append((k,v))
|
||||||
else:
|
else:
|
||||||
# Use BaseDict so can watch any changes
|
# Use BaseDict so can watch any changes
|
||||||
@ -257,7 +258,7 @@ class ComplexBaseField(BaseField):
|
|||||||
if v.pk is None:
|
if v.pk is None:
|
||||||
raise ValidationError('You can only reference documents once '
|
raise ValidationError('You can only reference documents once '
|
||||||
'they have been saved to the database')
|
'they have been saved to the database')
|
||||||
collection = v._meta['collection']
|
collection = v._get_collection_name()
|
||||||
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
|
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
|
||||||
elif hasattr(v, 'to_python'):
|
elif hasattr(v, 'to_python'):
|
||||||
value_dict[k] = v.to_python()
|
value_dict[k] = v.to_python()
|
||||||
@ -306,7 +307,7 @@ class ComplexBaseField(BaseField):
|
|||||||
from fields import GenericReferenceField
|
from fields import GenericReferenceField
|
||||||
value_dict[k] = GenericReferenceField().to_mongo(v)
|
value_dict[k] = GenericReferenceField().to_mongo(v)
|
||||||
else:
|
else:
|
||||||
collection = v._meta['collection']
|
collection = v._get_collection_name()
|
||||||
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
|
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
|
||||||
elif hasattr(v, 'to_mongo'):
|
elif hasattr(v, 'to_mongo'):
|
||||||
value_dict[k] = v.to_mongo()
|
value_dict[k] = v.to_mongo()
|
||||||
@ -500,9 +501,14 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
# Subclassed documents inherit collection from superclass
|
# Subclassed documents inherit collection from superclass
|
||||||
for base in bases:
|
for base in bases:
|
||||||
if hasattr(base, '_meta'):
|
if hasattr(base, '_meta'):
|
||||||
if 'collection' in base._meta:
|
|
||||||
collection = base._meta['collection']
|
|
||||||
|
|
||||||
|
if 'collection' in attrs.get('meta', {}) and not base._meta.get('abstract', False):
|
||||||
|
import warnings
|
||||||
|
msg = "Trying to set a collection on a subclass (%s)" % name
|
||||||
|
warnings.warn(msg, SyntaxWarning)
|
||||||
|
del(attrs['meta']['collection'])
|
||||||
|
if base._get_collection_name():
|
||||||
|
collection = base._get_collection_name()
|
||||||
# Propagate index options.
|
# Propagate index options.
|
||||||
for key in ('index_background', 'index_drop_dups', 'index_opts'):
|
for key in ('index_background', 'index_drop_dups', 'index_opts'):
|
||||||
if key in base._meta:
|
if key in base._meta:
|
||||||
@ -539,6 +545,10 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
# DocumentMetaclass before instantiating CollectionManager object
|
# DocumentMetaclass before instantiating CollectionManager object
|
||||||
new_class = super_new(cls, name, bases, attrs)
|
new_class = super_new(cls, name, bases, attrs)
|
||||||
|
|
||||||
|
collection = attrs['_meta'].get('collection', None)
|
||||||
|
if callable(collection):
|
||||||
|
new_class._meta['collection'] = collection(new_class)
|
||||||
|
|
||||||
# Provide a default queryset unless one has been manually provided
|
# Provide a default queryset unless one has been manually provided
|
||||||
manager = attrs.get('objects', QuerySetManager())
|
manager = attrs.get('objects', QuerySetManager())
|
||||||
if hasattr(manager, 'queryset_class'):
|
if hasattr(manager, 'queryset_class'):
|
||||||
@ -675,6 +685,12 @@ class BaseDocument(object):
|
|||||||
elif field.required:
|
elif field.required:
|
||||||
raise ValidationError('Field "%s" is required' % field.name)
|
raise ValidationError('Field "%s" is required' % field.name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_collection_name(cls):
|
||||||
|
"""Returns the collection name for this class.
|
||||||
|
"""
|
||||||
|
return cls._meta.get('collection', None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_subclasses(cls):
|
def _get_subclasses(cls):
|
||||||
"""Return a dictionary of all subclasses (found recursively).
|
"""Return a dictionary of all subclasses (found recursively).
|
||||||
|
@ -6,7 +6,12 @@ from connection import _get_db
|
|||||||
|
|
||||||
import pymongo
|
import pymongo
|
||||||
|
|
||||||
__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError']
|
__all__ = ['Document', 'EmbeddedDocument', 'ValidationError',
|
||||||
|
'OperationError', 'InvalidCollectionError']
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidCollectionError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EmbeddedDocument(BaseDocument):
|
class EmbeddedDocument(BaseDocument):
|
||||||
@ -72,6 +77,41 @@ class Document(BaseDocument):
|
|||||||
"""
|
"""
|
||||||
__metaclass__ = TopLevelDocumentMetaclass
|
__metaclass__ = TopLevelDocumentMetaclass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_collection(self):
|
||||||
|
"""Returns the collection for the document."""
|
||||||
|
db = _get_db()
|
||||||
|
collection_name = self._get_collection_name()
|
||||||
|
|
||||||
|
if not hasattr(self, '_collection') or self._collection is None:
|
||||||
|
# Create collection as a capped collection if specified
|
||||||
|
if self._meta['max_size'] or self._meta['max_documents']:
|
||||||
|
# Get max document limit and max byte size from meta
|
||||||
|
max_size = self._meta['max_size'] or 10000000 # 10MB default
|
||||||
|
max_documents = self._meta['max_documents']
|
||||||
|
|
||||||
|
if collection_name in db.collection_names():
|
||||||
|
self._collection = db[collection_name]
|
||||||
|
# The collection already exists, check if its capped
|
||||||
|
# options match the specified capped options
|
||||||
|
options = self._collection.options()
|
||||||
|
if options.get('max') != max_documents or \
|
||||||
|
options.get('size') != max_size:
|
||||||
|
msg = ('Cannot create collection "%s" as a capped '
|
||||||
|
'collection as it already exists') % self._collection
|
||||||
|
raise InvalidCollectionError(msg)
|
||||||
|
else:
|
||||||
|
# Create the collection as a capped collection
|
||||||
|
opts = {'capped': True, 'size': max_size}
|
||||||
|
if max_documents:
|
||||||
|
opts['max'] = max_documents
|
||||||
|
self._collection = db.create_collection(
|
||||||
|
collection_name, **opts
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._collection = db[collection_name]
|
||||||
|
return self._collection
|
||||||
|
|
||||||
def save(self, safe=True, force_insert=False, validate=True, write_options=None):
|
def save(self, safe=True, force_insert=False, validate=True, write_options=None):
|
||||||
"""Save the :class:`~mongoengine.Document` to the database. If the
|
"""Save the :class:`~mongoengine.Document` to the database. If the
|
||||||
document already exists, it will be updated, otherwise it will be
|
document already exists, it will be updated, otherwise it will be
|
||||||
@ -173,7 +213,7 @@ class Document(BaseDocument):
|
|||||||
if not self.pk:
|
if not self.pk:
|
||||||
msg = "Only saved documents can have a valid dbref"
|
msg = "Only saved documents can have a valid dbref"
|
||||||
raise OperationError(msg)
|
raise OperationError(msg)
|
||||||
return pymongo.dbref.DBRef(self.__class__._meta['collection'], self.pk)
|
return pymongo.dbref.DBRef(self.__class__._get_collection_name(), self.pk)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_delete_rule(cls, document_cls, field_name, rule):
|
def register_delete_rule(cls, document_cls, field_name, rule):
|
||||||
@ -188,7 +228,7 @@ class Document(BaseDocument):
|
|||||||
:class:`~mongoengine.Document` type from the database.
|
:class:`~mongoengine.Document` type from the database.
|
||||||
"""
|
"""
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
db.drop_collection(cls._meta['collection'])
|
db.drop_collection(cls._get_collection_name())
|
||||||
|
|
||||||
|
|
||||||
class MapReduceDocument(object):
|
class MapReduceDocument(object):
|
||||||
|
@ -252,7 +252,7 @@ class DateTimeField(BaseField):
|
|||||||
return datetime.datetime(value.year, value.month, value.day)
|
return datetime.datetime(value.year, value.month, value.day)
|
||||||
|
|
||||||
# Attempt to parse a datetime:
|
# Attempt to parse a datetime:
|
||||||
#value = smart_str(value)
|
# value = smart_str(value)
|
||||||
# split usecs, because they are not recognized by strptime.
|
# split usecs, because they are not recognized by strptime.
|
||||||
if '.' in value:
|
if '.' in value:
|
||||||
try:
|
try:
|
||||||
@ -278,6 +278,7 @@ class DateTimeField(BaseField):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ComplexDateTimeField(StringField):
|
class ComplexDateTimeField(StringField):
|
||||||
"""
|
"""
|
||||||
ComplexDateTimeField handles microseconds exactly instead of rounding
|
ComplexDateTimeField handles microseconds exactly instead of rounding
|
||||||
@ -526,6 +527,7 @@ class MapField(DictField):
|
|||||||
super(MapField, self).__init__(field=field, *args, **kwargs)
|
super(MapField, self).__init__(field=field, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ReferenceField(BaseField):
|
class ReferenceField(BaseField):
|
||||||
"""A reference to a document that will be automatically dereferenced on
|
"""A reference to a document that will be automatically dereferenced on
|
||||||
access (lazily).
|
access (lazily).
|
||||||
@ -595,7 +597,7 @@ class ReferenceField(BaseField):
|
|||||||
id_ = document
|
id_ = document
|
||||||
|
|
||||||
id_ = id_field.to_mongo(id_)
|
id_ = id_field.to_mongo(id_)
|
||||||
collection = self.document_type._meta['collection']
|
collection = self.document_type._get_collection_name()
|
||||||
return pymongo.dbref.DBRef(collection, id_)
|
return pymongo.dbref.DBRef(collection, id_)
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
@ -664,7 +666,7 @@ class GenericReferenceField(BaseField):
|
|||||||
id_ = document
|
id_ = document
|
||||||
|
|
||||||
id_ = id_field.to_mongo(id_)
|
id_ = id_field.to_mongo(id_)
|
||||||
collection = document._meta['collection']
|
collection = document._get_collection_name()
|
||||||
ref = pymongo.dbref.DBRef(collection, id_)
|
ref = pymongo.dbref.DBRef(collection, id_)
|
||||||
return {'_cls': document._class_name, '_ref': ref}
|
return {'_cls': document._class_name, '_ref': ref}
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ import itertools
|
|||||||
import operator
|
import operator
|
||||||
|
|
||||||
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
|
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
|
||||||
'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
|
'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
|
||||||
|
|
||||||
|
|
||||||
# The maximum number of items to display in a QuerySet.__repr__
|
# The maximum number of items to display in a QuerySet.__repr__
|
||||||
@ -40,10 +40,6 @@ class OperationError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidCollectionError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
RE_TYPE = type(re.compile(''))
|
RE_TYPE = type(re.compile(''))
|
||||||
|
|
||||||
|
|
||||||
@ -1360,7 +1356,7 @@ class QuerySet(object):
|
|||||||
|
|
||||||
fields = [QuerySet._translate_field_name(self._document, f)
|
fields = [QuerySet._translate_field_name(self._document, f)
|
||||||
for f in fields]
|
for f in fields]
|
||||||
collection = self._document._meta['collection']
|
collection = self._document._get_collection_name()
|
||||||
|
|
||||||
scope = {
|
scope = {
|
||||||
'collection': collection,
|
'collection': collection,
|
||||||
@ -1550,39 +1546,9 @@ class QuerySetManager(object):
|
|||||||
# Document class being used rather than a document object
|
# Document class being used rather than a document object
|
||||||
return self
|
return self
|
||||||
|
|
||||||
db = _get_db()
|
|
||||||
collection = owner._meta['collection']
|
|
||||||
if (db, collection) not in self._collections:
|
|
||||||
# Create collection as a capped collection if specified
|
|
||||||
if owner._meta['max_size'] or owner._meta['max_documents']:
|
|
||||||
# Get max document limit and max byte size from meta
|
|
||||||
max_size = owner._meta['max_size'] or 10000000 # 10MB default
|
|
||||||
max_documents = owner._meta['max_documents']
|
|
||||||
|
|
||||||
if collection in db.collection_names():
|
|
||||||
self._collections[(db, collection)] = db[collection]
|
|
||||||
# The collection already exists, check if its capped
|
|
||||||
# options match the specified capped options
|
|
||||||
options = self._collections[(db, collection)].options()
|
|
||||||
if options.get('max') != max_documents or \
|
|
||||||
options.get('size') != max_size:
|
|
||||||
msg = ('Cannot create collection "%s" as a capped '
|
|
||||||
'collection as it already exists') % collection
|
|
||||||
raise InvalidCollectionError(msg)
|
|
||||||
else:
|
|
||||||
# Create the collection as a capped collection
|
|
||||||
opts = {'capped': True, 'size': max_size}
|
|
||||||
if max_documents:
|
|
||||||
opts['max'] = max_documents
|
|
||||||
self._collections[(db, collection)] = db.create_collection(
|
|
||||||
collection, **opts
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._collections[(db, collection)] = db[collection]
|
|
||||||
|
|
||||||
# owner is the document that contains the QuerySetManager
|
# owner is the document that contains the QuerySetManager
|
||||||
queryset_class = owner._meta['queryset_class'] or QuerySet
|
queryset_class = owner._meta['queryset_class'] or QuerySet
|
||||||
queryset = queryset_class(owner, self._collections[(db, collection)])
|
queryset = queryset_class(owner, owner._get_collection())
|
||||||
if self.get_queryset:
|
if self.get_queryset:
|
||||||
if self.get_queryset.func_code.co_argcount == 1:
|
if self.get_queryset.func_code.co_argcount == 1:
|
||||||
queryset = self.get_queryset(queryset)
|
queryset = self.get_queryset(queryset)
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
|
import pickle
|
||||||
|
import pymongo
|
||||||
import unittest
|
import unittest
|
||||||
|
import warnings
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import pymongo
|
import pymongo
|
||||||
import pickle
|
import pickle
|
||||||
import weakref
|
import weakref
|
||||||
@ -30,7 +35,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
self.Person(name='Test').save()
|
self.Person(name='Test').save()
|
||||||
|
|
||||||
collection = self.Person._meta['collection']
|
collection = self.Person._get_collection_name()
|
||||||
self.assertTrue(collection in self.db.collection_names())
|
self.assertTrue(collection in self.db.collection_names())
|
||||||
|
|
||||||
self.Person.drop_collection()
|
self.Person.drop_collection()
|
||||||
@ -57,6 +62,23 @@ class DocumentTest(unittest.TestCase):
|
|||||||
# Ensure Document isn't treated like an actual document
|
# Ensure Document isn't treated like an actual document
|
||||||
self.assertFalse(hasattr(Document, '_fields'))
|
self.assertFalse(hasattr(Document, '_fields'))
|
||||||
|
|
||||||
|
def test_dynamic_collection_naming(self):
|
||||||
|
|
||||||
|
def create_collection_name(cls):
|
||||||
|
return "PERSON"
|
||||||
|
|
||||||
|
class DynamicPerson(Document):
|
||||||
|
name = StringField()
|
||||||
|
age = IntField()
|
||||||
|
|
||||||
|
meta = {'collection': create_collection_name}
|
||||||
|
|
||||||
|
collection = DynamicPerson._get_collection_name()
|
||||||
|
self.assertEquals(collection, 'PERSON')
|
||||||
|
|
||||||
|
DynamicPerson(name='Test User', age=30).save()
|
||||||
|
self.assertTrue(collection in self.db.collection_names())
|
||||||
|
|
||||||
def test_get_superclasses(self):
|
def test_get_superclasses(self):
|
||||||
"""Ensure that the correct list of superclasses is assembled.
|
"""Ensure that the correct list of superclasses is assembled.
|
||||||
"""
|
"""
|
||||||
@ -225,8 +247,8 @@ class DocumentTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue('name' in Employee._fields)
|
self.assertTrue('name' in Employee._fields)
|
||||||
self.assertTrue('salary' in Employee._fields)
|
self.assertTrue('salary' in Employee._fields)
|
||||||
self.assertEqual(Employee._meta['collection'],
|
self.assertEqual(Employee._get_collection_name(),
|
||||||
self.Person._meta['collection'])
|
self.Person._get_collection_name())
|
||||||
|
|
||||||
# Ensure that MRO error is not raised
|
# Ensure that MRO error is not raised
|
||||||
class A(Document): pass
|
class A(Document): pass
|
||||||
@ -251,7 +273,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
# Check that _cls etc aren't present on simple documents
|
# Check that _cls etc aren't present on simple documents
|
||||||
dog = Animal(name='dog')
|
dog = Animal(name='dog')
|
||||||
dog.save()
|
dog.save()
|
||||||
collection = self.db[Animal._meta['collection']]
|
collection = self.db[Animal._get_collection_name()]
|
||||||
obj = collection.find_one()
|
obj = collection.find_one()
|
||||||
self.assertFalse('_cls' in obj)
|
self.assertFalse('_cls' in obj)
|
||||||
self.assertFalse('_types' in obj)
|
self.assertFalse('_types' in obj)
|
||||||
@ -297,7 +319,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
# Check that _cls etc aren't present on simple documents
|
# Check that _cls etc aren't present on simple documents
|
||||||
dog = Animal(name='dog')
|
dog = Animal(name='dog')
|
||||||
dog.save()
|
dog.save()
|
||||||
collection = self.db[Animal._meta['collection']]
|
collection = self.db[Animal._get_collection_name()]
|
||||||
obj = collection.find_one()
|
obj = collection.find_one()
|
||||||
self.assertFalse('_cls' in obj)
|
self.assertFalse('_cls' in obj)
|
||||||
self.assertFalse('_types' in obj)
|
self.assertFalse('_types' in obj)
|
||||||
@ -318,7 +340,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
dog = Animal(name='dog')
|
dog = Animal(name='dog')
|
||||||
dog.save()
|
dog.save()
|
||||||
|
|
||||||
collection = self.db[Animal._meta['collection']]
|
collection = self.db[Animal._get_collection_name()]
|
||||||
obj = collection.find_one()
|
obj = collection.find_one()
|
||||||
self.assertTrue('_cls' in obj)
|
self.assertTrue('_cls' in obj)
|
||||||
self.assertTrue('_types' in obj)
|
self.assertTrue('_types' in obj)
|
||||||
@ -381,9 +403,12 @@ class DocumentTest(unittest.TestCase):
|
|||||||
self.assertFalse('collection' in Animal._meta)
|
self.assertFalse('collection' in Animal._meta)
|
||||||
self.assertFalse('collection' in Mammal._meta)
|
self.assertFalse('collection' in Mammal._meta)
|
||||||
|
|
||||||
self.assertEqual(Fish._meta['collection'], 'fish')
|
self.assertEqual(Animal._get_collection_name(), None)
|
||||||
self.assertEqual(Guppy._meta['collection'], 'fish')
|
self.assertEqual(Mammal._get_collection_name(), None)
|
||||||
self.assertEqual(Human._meta['collection'], 'human')
|
|
||||||
|
self.assertEqual(Fish._get_collection_name(), 'fish')
|
||||||
|
self.assertEqual(Guppy._get_collection_name(), 'fish')
|
||||||
|
self.assertEqual(Human._get_collection_name(), 'human')
|
||||||
|
|
||||||
def create_bad_abstract():
|
def create_bad_abstract():
|
||||||
class EvilHuman(Human):
|
class EvilHuman(Human):
|
||||||
@ -434,14 +459,21 @@ class DocumentTest(unittest.TestCase):
|
|||||||
def test_inherited_collections(self):
|
def test_inherited_collections(self):
|
||||||
"""Ensure that subclassed documents don't override parents' collections.
|
"""Ensure that subclassed documents don't override parents' collections.
|
||||||
"""
|
"""
|
||||||
class Drink(Document):
|
with warnings.catch_warnings(record=True) as w:
|
||||||
name = StringField()
|
# Cause all warnings to always be triggered.
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
|
||||||
class AlcoholicDrink(Drink):
|
class Drink(Document):
|
||||||
meta = {'collection': 'booze'}
|
name = StringField()
|
||||||
|
|
||||||
class Drinker(Document):
|
class AlcoholicDrink(Drink):
|
||||||
drink = GenericReferenceField()
|
meta = {'collection': 'booze'}
|
||||||
|
|
||||||
|
class Drinker(Document):
|
||||||
|
drink = GenericReferenceField()
|
||||||
|
|
||||||
|
# Confirm we triggered a SyntaxWarning
|
||||||
|
assert issubclass(w[0].category, SyntaxWarning)
|
||||||
|
|
||||||
Drink.drop_collection()
|
Drink.drop_collection()
|
||||||
AlcoholicDrink.drop_collection()
|
AlcoholicDrink.drop_collection()
|
||||||
@ -455,7 +487,6 @@ class DocumentTest(unittest.TestCase):
|
|||||||
|
|
||||||
beer = AlcoholicDrink(name='Beer')
|
beer = AlcoholicDrink(name='Beer')
|
||||||
beer.save()
|
beer.save()
|
||||||
|
|
||||||
real_person = Drinker(drink=beer)
|
real_person = Drinker(drink=beer)
|
||||||
real_person.save()
|
real_person.save()
|
||||||
|
|
||||||
@ -936,7 +967,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
person = self.Person(name='Test User', age=30)
|
person = self.Person(name='Test User', age=30)
|
||||||
person.save()
|
person.save()
|
||||||
# Ensure that the object is in the database
|
# Ensure that the object is in the database
|
||||||
collection = self.db[self.Person._meta['collection']]
|
collection = self.db[self.Person._get_collection_name()]
|
||||||
person_obj = collection.find_one({'name': 'Test User'})
|
person_obj = collection.find_one({'name': 'Test User'})
|
||||||
self.assertEqual(person_obj['name'], 'Test User')
|
self.assertEqual(person_obj['name'], 'Test User')
|
||||||
self.assertEqual(person_obj['age'], 30)
|
self.assertEqual(person_obj['age'], 30)
|
||||||
@ -1279,7 +1310,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
id='497ce96f395f2f052a494fd4')
|
id='497ce96f395f2f052a494fd4')
|
||||||
person.save()
|
person.save()
|
||||||
# Ensure that the object is in the database with the correct _id
|
# Ensure that the object is in the database with the correct _id
|
||||||
collection = self.db[self.Person._meta['collection']]
|
collection = self.db[self.Person._get_collection_name()]
|
||||||
person_obj = collection.find_one({'name': 'Test User'})
|
person_obj = collection.find_one({'name': 'Test User'})
|
||||||
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
|
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
|
||||||
|
|
||||||
@ -1291,7 +1322,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
pk='497ce96f395f2f052a494fd4')
|
pk='497ce96f395f2f052a494fd4')
|
||||||
person.save()
|
person.save()
|
||||||
# Ensure that the object is in the database with the correct _id
|
# Ensure that the object is in the database with the correct _id
|
||||||
collection = self.db[self.Person._meta['collection']]
|
collection = self.db[self.Person._get_collection_name()]
|
||||||
person_obj = collection.find_one({'name': 'Test User'})
|
person_obj = collection.find_one({'name': 'Test User'})
|
||||||
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
|
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
|
||||||
|
|
||||||
@ -1314,7 +1345,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
post.comments = comments
|
post.comments = comments
|
||||||
post.save()
|
post.save()
|
||||||
|
|
||||||
collection = self.db[BlogPost._meta['collection']]
|
collection = self.db[BlogPost._get_collection_name()]
|
||||||
post_obj = collection.find_one()
|
post_obj = collection.find_one()
|
||||||
self.assertEqual(post_obj['tags'], tags)
|
self.assertEqual(post_obj['tags'], tags)
|
||||||
for comment_obj, comment in zip(post_obj['comments'], comments):
|
for comment_obj, comment in zip(post_obj['comments'], comments):
|
||||||
@ -1339,7 +1370,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
employee.save()
|
employee.save()
|
||||||
|
|
||||||
# Ensure that the object is in the database
|
# Ensure that the object is in the database
|
||||||
collection = self.db[self.Person._meta['collection']]
|
collection = self.db[self.Person._get_collection_name()]
|
||||||
employee_obj = collection.find_one({'name': 'Test Employee'})
|
employee_obj = collection.find_one({'name': 'Test Employee'})
|
||||||
self.assertEqual(employee_obj['name'], 'Test Employee')
|
self.assertEqual(employee_obj['name'], 'Test Employee')
|
||||||
self.assertEqual(employee_obj['age'], 50)
|
self.assertEqual(employee_obj['age'], 50)
|
||||||
@ -1370,6 +1401,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
promoted_employee.reload()
|
promoted_employee.reload()
|
||||||
self.assertEqual(promoted_employee.name, 'Test Employee')
|
self.assertEqual(promoted_employee.name, 'Test Employee')
|
||||||
self.assertEqual(promoted_employee.age, 50)
|
self.assertEqual(promoted_employee.age, 50)
|
||||||
|
|
||||||
# Ensure that the 'details' embedded object saved correctly
|
# Ensure that the 'details' embedded object saved correctly
|
||||||
self.assertEqual(promoted_employee.details.position, 'Senior Developer')
|
self.assertEqual(promoted_employee.details.position, 'Senior Developer')
|
||||||
|
|
||||||
@ -1399,7 +1431,7 @@ class DocumentTest(unittest.TestCase):
|
|||||||
p.save()
|
p.save()
|
||||||
self.assertEquals(p._fields.keys(), ['name', 'id'])
|
self.assertEquals(p._fields.keys(), ['name', 'id'])
|
||||||
|
|
||||||
collection = self.db[Person._meta['collection']]
|
collection = self.db[Person._get_collection_name()]
|
||||||
obj = collection.find_one()
|
obj = collection.find_one()
|
||||||
self.assertEquals(obj['_cls'], 'Person')
|
self.assertEquals(obj['_cls'], 'Person')
|
||||||
self.assertEquals(obj['_types'], ['Person'])
|
self.assertEquals(obj['_types'], ['Person'])
|
||||||
@ -1492,6 +1524,9 @@ class DocumentTest(unittest.TestCase):
|
|||||||
text = StringField()
|
text = StringField()
|
||||||
post = ReferenceField(BlogPost, reverse_delete_rule=CASCADE)
|
post = ReferenceField(BlogPost, reverse_delete_rule=CASCADE)
|
||||||
|
|
||||||
|
self.Person.drop_collection()
|
||||||
|
BlogPost.drop_collection()
|
||||||
|
Comment.drop_collection()
|
||||||
|
|
||||||
author = self.Person(name='Test User')
|
author = self.Person(name='Test User')
|
||||||
author.save()
|
author.save()
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import pymongo
|
import pymongo
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
@ -27,7 +25,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
self.assertTrue(isinstance(self.Person.objects, QuerySet))
|
self.assertTrue(isinstance(self.Person.objects, QuerySet))
|
||||||
self.assertEqual(self.Person.objects._collection.name,
|
self.assertEqual(self.Person.objects._collection.name,
|
||||||
self.Person._meta['collection'])
|
self.Person._get_collection_name())
|
||||||
self.assertTrue(isinstance(self.Person.objects._collection,
|
self.assertTrue(isinstance(self.Person.objects._collection,
|
||||||
pymongo.collection.Collection))
|
pymongo.collection.Collection))
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user