From f41c5217c6f0155d4aa7909fb52a556a79c67aba Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 20 Jun 2011 11:48:12 +0100 Subject: [PATCH] Added a cleaner way to get collection names Also handles dynamic collection naming - refs #180. --- mongoengine/base.py | 26 +++++++++++--- mongoengine/document.py | 46 ++++++++++++++++++++++-- mongoengine/fields.py | 8 +++-- mongoengine/queryset.py | 40 ++------------------- tests/document.py | 79 +++++++++++++++++++++++++++++------------ tests/queryset.py | 4 +-- 6 files changed, 130 insertions(+), 73 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index f8d415b0..e59119eb 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -22,6 +22,7 @@ class ValidationError(Exception): _document_registry = {} + def get_document(name): doc = _document_registry.get(name, None) if not doc: @@ -195,7 +196,7 @@ class ComplexBaseField(BaseField): elif isinstance(v, (dict, pymongo.son.SON)): if '_ref' in v: # generic reference - collection = get_document(v['_cls'])._meta['collection'] + collection = get_document(v['_cls'])._get_collection_name() collections.setdefault(collection, []).append((k,v)) else: # Use BaseDict so can watch any changes @@ -257,7 +258,7 @@ class ComplexBaseField(BaseField): if v.pk is None: raise ValidationError('You can only reference documents once ' 'they have been saved to the database') - collection = v._meta['collection'] + collection = v._get_collection_name() value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) elif hasattr(v, 'to_python'): value_dict[k] = v.to_python() @@ -306,7 +307,7 @@ class ComplexBaseField(BaseField): from fields import GenericReferenceField value_dict[k] = GenericReferenceField().to_mongo(v) else: - collection = v._meta['collection'] + collection = v._get_collection_name() value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) elif hasattr(v, 'to_mongo'): value_dict[k] = v.to_mongo() @@ -500,9 +501,14 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Subclassed documents inherit collection from superclass for base in bases: if hasattr(base, '_meta'): - if 'collection' in base._meta: - collection = base._meta['collection'] + if 'collection' in attrs.get('meta', {}) and not base._meta.get('abstract', False): + import warnings + msg = "Trying to set a collection on a subclass (%s)" % name + warnings.warn(msg, SyntaxWarning) + del(attrs['meta']['collection']) + if base._get_collection_name(): + collection = base._get_collection_name() # Propagate index options. for key in ('index_background', 'index_drop_dups', 'index_opts'): if key in base._meta: @@ -539,6 +545,10 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # DocumentMetaclass before instantiating CollectionManager object new_class = super_new(cls, name, bases, attrs) + collection = attrs['_meta'].get('collection', None) + if callable(collection): + new_class._meta['collection'] = collection(new_class) + # Provide a default queryset unless one has been manually provided manager = attrs.get('objects', QuerySetManager()) if hasattr(manager, 'queryset_class'): @@ -675,6 +685,12 @@ class BaseDocument(object): elif field.required: raise ValidationError('Field "%s" is required' % field.name) + @classmethod + def _get_collection_name(cls): + """Returns the collection name for this class. + """ + return cls._meta.get('collection', None) + @classmethod def _get_subclasses(cls): """Return a dictionary of all subclasses (found recursively). diff --git a/mongoengine/document.py b/mongoengine/document.py index 0b408cc2..36bf4017 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -6,7 +6,12 @@ from connection import _get_db import pymongo -__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', 'OperationError'] +__all__ = ['Document', 'EmbeddedDocument', 'ValidationError', + 'OperationError', 'InvalidCollectionError'] + + +class InvalidCollectionError(Exception): + pass class EmbeddedDocument(BaseDocument): @@ -72,6 +77,41 @@ class Document(BaseDocument): """ __metaclass__ = TopLevelDocumentMetaclass + @classmethod + def _get_collection(self): + """Returns the collection for the document.""" + db = _get_db() + collection_name = self._get_collection_name() + + if not hasattr(self, '_collection') or self._collection is None: + # Create collection as a capped collection if specified + if self._meta['max_size'] or self._meta['max_documents']: + # Get max document limit and max byte size from meta + max_size = self._meta['max_size'] or 10000000 # 10MB default + max_documents = self._meta['max_documents'] + + if collection_name in db.collection_names(): + self._collection = db[collection_name] + # The collection already exists, check if its capped + # options match the specified capped options + options = self._collection.options() + if options.get('max') != max_documents or \ + options.get('size') != max_size: + msg = ('Cannot create collection "%s" as a capped ' + 'collection as it already exists') % self._collection + raise InvalidCollectionError(msg) + else: + # Create the collection as a capped collection + opts = {'capped': True, 'size': max_size} + if max_documents: + opts['max'] = max_documents + self._collection = db.create_collection( + collection_name, **opts + ) + else: + self._collection = db[collection_name] + return self._collection + def save(self, safe=True, force_insert=False, validate=True, write_options=None): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be @@ -173,7 +213,7 @@ class Document(BaseDocument): if not self.pk: msg = "Only saved documents can have a valid dbref" raise OperationError(msg) - return pymongo.dbref.DBRef(self.__class__._meta['collection'], self.pk) + return pymongo.dbref.DBRef(self.__class__._get_collection_name(), self.pk) @classmethod def register_delete_rule(cls, document_cls, field_name, rule): @@ -188,7 +228,7 @@ class Document(BaseDocument): :class:`~mongoengine.Document` type from the database. """ db = _get_db() - db.drop_collection(cls._meta['collection']) + db.drop_collection(cls._get_collection_name()) class MapReduceDocument(object): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index e1b43664..50a30a13 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -252,7 +252,7 @@ class DateTimeField(BaseField): return datetime.datetime(value.year, value.month, value.day) # Attempt to parse a datetime: - #value = smart_str(value) + # value = smart_str(value) # split usecs, because they are not recognized by strptime. if '.' in value: try: @@ -278,6 +278,7 @@ class DateTimeField(BaseField): return None + class ComplexDateTimeField(StringField): """ ComplexDateTimeField handles microseconds exactly instead of rounding @@ -526,6 +527,7 @@ class MapField(DictField): super(MapField, self).__init__(field=field, *args, **kwargs) + class ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on access (lazily). @@ -595,7 +597,7 @@ class ReferenceField(BaseField): id_ = document id_ = id_field.to_mongo(id_) - collection = self.document_type._meta['collection'] + collection = self.document_type._get_collection_name() return pymongo.dbref.DBRef(collection, id_) def prepare_query_value(self, op, value): @@ -664,7 +666,7 @@ class GenericReferenceField(BaseField): id_ = document id_ = id_field.to_mongo(id_) - collection = document._meta['collection'] + collection = document._get_collection_name() ref = pymongo.dbref.DBRef(collection, id_) return {'_cls': document._class_name, '_ref': ref} diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 79d24bba..2a5d3edb 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -11,7 +11,7 @@ import itertools import operator __all__ = ['queryset_manager', 'Q', 'InvalidQueryError', - 'InvalidCollectionError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] + 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] # The maximum number of items to display in a QuerySet.__repr__ @@ -40,10 +40,6 @@ class OperationError(Exception): pass -class InvalidCollectionError(Exception): - pass - - RE_TYPE = type(re.compile('')) @@ -1360,7 +1356,7 @@ class QuerySet(object): fields = [QuerySet._translate_field_name(self._document, f) for f in fields] - collection = self._document._meta['collection'] + collection = self._document._get_collection_name() scope = { 'collection': collection, @@ -1550,39 +1546,9 @@ class QuerySetManager(object): # Document class being used rather than a document object return self - db = _get_db() - collection = owner._meta['collection'] - if (db, collection) not in self._collections: - # Create collection as a capped collection if specified - if owner._meta['max_size'] or owner._meta['max_documents']: - # Get max document limit and max byte size from meta - max_size = owner._meta['max_size'] or 10000000 # 10MB default - max_documents = owner._meta['max_documents'] - - if collection in db.collection_names(): - self._collections[(db, collection)] = db[collection] - # The collection already exists, check if its capped - # options match the specified capped options - options = self._collections[(db, collection)].options() - if options.get('max') != max_documents or \ - options.get('size') != max_size: - msg = ('Cannot create collection "%s" as a capped ' - 'collection as it already exists') % collection - raise InvalidCollectionError(msg) - else: - # Create the collection as a capped collection - opts = {'capped': True, 'size': max_size} - if max_documents: - opts['max'] = max_documents - self._collections[(db, collection)] = db.create_collection( - collection, **opts - ) - else: - self._collections[(db, collection)] = db[collection] - # owner is the document that contains the QuerySetManager queryset_class = owner._meta['queryset_class'] or QuerySet - queryset = queryset_class(owner, self._collections[(db, collection)]) + queryset = queryset_class(owner, owner._get_collection()) if self.get_queryset: if self.get_queryset.func_code.co_argcount == 1: queryset = self.get_queryset(queryset) diff --git a/tests/document.py b/tests/document.py index c5aa6e89..c10c903f 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1,5 +1,10 @@ +import pickle +import pymongo import unittest +import warnings + from datetime import datetime + import pymongo import pickle import weakref @@ -30,7 +35,7 @@ class DocumentTest(unittest.TestCase): """ self.Person(name='Test').save() - collection = self.Person._meta['collection'] + collection = self.Person._get_collection_name() self.assertTrue(collection in self.db.collection_names()) self.Person.drop_collection() @@ -57,6 +62,23 @@ class DocumentTest(unittest.TestCase): # Ensure Document isn't treated like an actual document self.assertFalse(hasattr(Document, '_fields')) + def test_dynamic_collection_naming(self): + + def create_collection_name(cls): + return "PERSON" + + class DynamicPerson(Document): + name = StringField() + age = IntField() + + meta = {'collection': create_collection_name} + + collection = DynamicPerson._get_collection_name() + self.assertEquals(collection, 'PERSON') + + DynamicPerson(name='Test User', age=30).save() + self.assertTrue(collection in self.db.collection_names()) + def test_get_superclasses(self): """Ensure that the correct list of superclasses is assembled. """ @@ -225,8 +247,8 @@ class DocumentTest(unittest.TestCase): self.assertTrue('name' in Employee._fields) self.assertTrue('salary' in Employee._fields) - self.assertEqual(Employee._meta['collection'], - self.Person._meta['collection']) + self.assertEqual(Employee._get_collection_name(), + self.Person._get_collection_name()) # Ensure that MRO error is not raised class A(Document): pass @@ -251,7 +273,7 @@ class DocumentTest(unittest.TestCase): # Check that _cls etc aren't present on simple documents dog = Animal(name='dog') dog.save() - collection = self.db[Animal._meta['collection']] + collection = self.db[Animal._get_collection_name()] obj = collection.find_one() self.assertFalse('_cls' in obj) self.assertFalse('_types' in obj) @@ -297,7 +319,7 @@ class DocumentTest(unittest.TestCase): # Check that _cls etc aren't present on simple documents dog = Animal(name='dog') dog.save() - collection = self.db[Animal._meta['collection']] + collection = self.db[Animal._get_collection_name()] obj = collection.find_one() self.assertFalse('_cls' in obj) self.assertFalse('_types' in obj) @@ -318,7 +340,7 @@ class DocumentTest(unittest.TestCase): dog = Animal(name='dog') dog.save() - collection = self.db[Animal._meta['collection']] + collection = self.db[Animal._get_collection_name()] obj = collection.find_one() self.assertTrue('_cls' in obj) self.assertTrue('_types' in obj) @@ -381,9 +403,12 @@ class DocumentTest(unittest.TestCase): self.assertFalse('collection' in Animal._meta) self.assertFalse('collection' in Mammal._meta) - self.assertEqual(Fish._meta['collection'], 'fish') - self.assertEqual(Guppy._meta['collection'], 'fish') - self.assertEqual(Human._meta['collection'], 'human') + self.assertEqual(Animal._get_collection_name(), None) + self.assertEqual(Mammal._get_collection_name(), None) + + self.assertEqual(Fish._get_collection_name(), 'fish') + self.assertEqual(Guppy._get_collection_name(), 'fish') + self.assertEqual(Human._get_collection_name(), 'human') def create_bad_abstract(): class EvilHuman(Human): @@ -434,14 +459,21 @@ class DocumentTest(unittest.TestCase): def test_inherited_collections(self): """Ensure that subclassed documents don't override parents' collections. """ - class Drink(Document): - name = StringField() + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") - class AlcoholicDrink(Drink): - meta = {'collection': 'booze'} + class Drink(Document): + name = StringField() - class Drinker(Document): - drink = GenericReferenceField() + class AlcoholicDrink(Drink): + meta = {'collection': 'booze'} + + class Drinker(Document): + drink = GenericReferenceField() + + # Confirm we triggered a SyntaxWarning + assert issubclass(w[0].category, SyntaxWarning) Drink.drop_collection() AlcoholicDrink.drop_collection() @@ -455,7 +487,6 @@ class DocumentTest(unittest.TestCase): beer = AlcoholicDrink(name='Beer') beer.save() - real_person = Drinker(drink=beer) real_person.save() @@ -936,7 +967,7 @@ class DocumentTest(unittest.TestCase): person = self.Person(name='Test User', age=30) person.save() # Ensure that the object is in the database - collection = self.db[self.Person._meta['collection']] + collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['age'], 30) @@ -1279,7 +1310,7 @@ class DocumentTest(unittest.TestCase): id='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id - collection = self.db[self.Person._meta['collection']] + collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') @@ -1291,7 +1322,7 @@ class DocumentTest(unittest.TestCase): pk='497ce96f395f2f052a494fd4') person.save() # Ensure that the object is in the database with the correct _id - collection = self.db[self.Person._meta['collection']] + collection = self.db[self.Person._get_collection_name()] person_obj = collection.find_one({'name': 'Test User'}) self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') @@ -1314,7 +1345,7 @@ class DocumentTest(unittest.TestCase): post.comments = comments post.save() - collection = self.db[BlogPost._meta['collection']] + collection = self.db[BlogPost._get_collection_name()] post_obj = collection.find_one() self.assertEqual(post_obj['tags'], tags) for comment_obj, comment in zip(post_obj['comments'], comments): @@ -1339,7 +1370,7 @@ class DocumentTest(unittest.TestCase): employee.save() # Ensure that the object is in the database - collection = self.db[self.Person._meta['collection']] + collection = self.db[self.Person._get_collection_name()] employee_obj = collection.find_one({'name': 'Test Employee'}) self.assertEqual(employee_obj['name'], 'Test Employee') self.assertEqual(employee_obj['age'], 50) @@ -1370,6 +1401,7 @@ class DocumentTest(unittest.TestCase): promoted_employee.reload() self.assertEqual(promoted_employee.name, 'Test Employee') self.assertEqual(promoted_employee.age, 50) + # Ensure that the 'details' embedded object saved correctly self.assertEqual(promoted_employee.details.position, 'Senior Developer') @@ -1399,7 +1431,7 @@ class DocumentTest(unittest.TestCase): p.save() self.assertEquals(p._fields.keys(), ['name', 'id']) - collection = self.db[Person._meta['collection']] + collection = self.db[Person._get_collection_name()] obj = collection.find_one() self.assertEquals(obj['_cls'], 'Person') self.assertEquals(obj['_types'], ['Person']) @@ -1492,6 +1524,9 @@ class DocumentTest(unittest.TestCase): text = StringField() post = ReferenceField(BlogPost, reverse_delete_rule=CASCADE) + self.Person.drop_collection() + BlogPost.drop_collection() + Comment.drop_collection() author = self.Person(name='Test User') author.save() diff --git a/tests/queryset.py b/tests/queryset.py index 6f0098d5..c5f177c2 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- - - import unittest import pymongo from datetime import datetime, timedelta @@ -27,7 +25,7 @@ class QuerySetTest(unittest.TestCase): """ self.assertTrue(isinstance(self.Person.objects, QuerySet)) self.assertEqual(self.Person.objects._collection.name, - self.Person._meta['collection']) + self.Person._get_collection_name()) self.assertTrue(isinstance(self.Person.objects._collection, pymongo.collection.Collection))