diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index a8d739d9..7422bb8f 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -1146,9 +1146,8 @@ class QuerySetManager(object): return self db = _get_db() - if db not in self._collections: - collection = owner._meta['collection'] - + collection = owner._meta['collection'] + if (db, collection) not in self._collections: # Create collection as a capped collection if specified if owner._meta['max_size'] or owner._meta['max_documents']: # Get max document limit and max byte size from meta @@ -1156,10 +1155,10 @@ class QuerySetManager(object): max_documents = owner._meta['max_documents'] if collection in db.collection_names(): - self._collections[db] = db[collection] + self._collections[(db, collection)] = db[collection] # The collection already exists, check if its capped # options match the specified capped options - options = self._collections[db].options() + options = self._collections[(db, collection)].options() if options.get('max') != max_documents or \ options.get('size') != max_size: msg = ('Cannot create collection "%s" as a capped ' @@ -1170,15 +1169,15 @@ class QuerySetManager(object): opts = {'capped': True, 'size': max_size} if max_documents: opts['max'] = max_documents - self._collections[db] = db.create_collection( + self._collections[(db, collection)] = db.create_collection( collection, **opts ) else: - self._collections[db] = db[collection] + self._collections[(db, collection)] = db[collection] # owner is the document that contains the QuerySetManager queryset_class = owner._meta['queryset_class'] or QuerySet - queryset = queryset_class(owner, self._collections[db]) + queryset = queryset_class(owner, self._collections[(db, collection)]) if self._manager_func: if self._manager_func.func_code.co_argcount == 1: queryset = self._manager_func(queryset) diff --git a/tests/document.py b/tests/document.py index c2e0d6f9..aa901813 100644 --- a/tests/document.py +++ b/tests/document.py @@ -200,6 +200,37 @@ class DocumentTest(unittest.TestCase): Person.drop_collection() self.assertFalse(collection in self.db.collection_names()) + def test_inherited_collections(self): + """Ensure that subclassed documents don't override parents' collections. + """ + class Drink(Document): + name = StringField() + + class AlcoholicDrink(Drink): + meta = {'collection': 'booze'} + + class Drinker(Document): + drink = GenericReferenceField() + + Drink.drop_collection() + AlcoholicDrink.drop_collection() + Drinker.drop_collection() + + red_bull = Drink(name='Red Bull') + red_bull.save() + + programmer = Drinker(drink=red_bull) + programmer.save() + + beer = AlcoholicDrink(name='Beer') + beer.save() + + real_person = Drinker(drink=beer) + real_person.save() + + self.assertEqual(Drinker.objects[0].drink.name, red_bull.name) + self.assertEqual(Drinker.objects[1].drink.name, beer.name) + def test_capped_collection(self): """Ensure that capped collections work properly. """