diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 0d8d7f15..3dac92b5 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -103,13 +103,19 @@ class DeReference(object): for key, doc in references.iteritems(): object_map[key] = doc else: # Generic reference: use the refs data to convert to document - references = get_db()[col].find({'_id': {'$in': refs}}) - for ref in references: - if '_cls' in ref: - doc = get_document(ref['_cls'])._from_son(ref) - else: + if doc_type: + references = doc_type._get_db()[col].find({'_id': {'$in': refs}}) + for ref in references: doc = doc_type._from_son(ref) - object_map[doc.id] = doc + object_map[doc.id] = doc + else: + references = get_db()[col].find({'_id': {'$in': refs}}) + for ref in references: + if '_cls' in ref: + doc = get_document(ref["_cls"])._from_son(ref) + else: + doc = doc_type._from_son(ref) + object_map[doc.id] = doc return object_map def _attach_objects(self, items, depth=0, instance=None, name=None): diff --git a/mongoengine/document.py b/mongoengine/document.py index 3f4c9d77..e56e4abe 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -2,7 +2,7 @@ from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, DataObserver) from queryset import OperationError -from connection import get_db +from connection import get_db, DEFAULT_CONNECTION_NAME import pymongo @@ -86,11 +86,16 @@ class Document(BaseDocument): return setattr(self, self._meta['id_field'], value) return property(fget, fset) + @classmethod + def _get_db(self): + """Some Model using other db_alias""" + return get_db(self._meta.get("db_alias", DEFAULT_CONNECTION_NAME )) + @classmethod def _get_collection(self): """Returns the collection for the document.""" if not hasattr(self, '_collection') or self._collection is None: - db = get_db() + db = self._get_db() collection_name = self._get_collection_name() # Create collection as a capped collection if specified if self._meta['max_size'] or self._meta['max_documents']: @@ -318,7 +323,7 @@ class Document(BaseDocument): :class:`~mongoengine.Document` type from the database. """ from mongoengine.queryset import QuerySet - db = get_db() + db = cls._get_db() db.drop_collection(cls._get_collection_name()) QuerySet._reset_already_indexed(cls) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 914cb925..2f5911f9 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -645,7 +645,7 @@ class ReferenceField(BaseField): value = instance._data.get(self.name) # Dereference DBRefs if isinstance(value, (pymongo.dbref.DBRef)): - value = get_db().dereference(value) + value = self.document_type._get_db().dereference(value) if value is not None: instance._data[self.name] = self.document_type._from_son(value) @@ -718,7 +718,7 @@ class GenericReferenceField(BaseField): def dereference(self, value): doc_cls = get_document(value['_cls']) reference = value['_ref'] - doc = get_db().dereference(reference) + doc = doc_cls._get_db().dereference(reference) if doc is not None: doc = doc_cls._from_son(doc) return doc @@ -1162,8 +1162,9 @@ class SequenceField(IntField): .. versionadded:: 0.5 """ - def __init__(self, collection_name=None, *args, **kwargs): + def __init__(self, collection_name=None, db_alias = None, *args, **kwargs): self.collection_name = collection_name or 'mongoengine.counters' + self.db_alias = db_alias or DEFAULT_CONNECTION_NAME return super(SequenceField, self).__init__(*args, **kwargs) def generate_new_value(self): @@ -1172,7 +1173,7 @@ class SequenceField(IntField): """ sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(), self.name) - collection = get_db()[self.collection_name] + collection = get_db(alias = self.db_alias )[self.collection_name] counter = collection.find_and_modify(query={"_id": sequence_id}, update={"$inc": {"next": 1}}, new=True, diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index ce9b175c..0c39253b 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -481,7 +481,7 @@ class QuerySet(object): if self._document not in QuerySet.__already_indexed: # Ensure collection exists - db = get_db() + db = self._document._get_db() if self._collection_obj.name not in db.collection_names(): self._document._collection = None self._collection_obj = self._document._get_collection() @@ -1452,7 +1452,7 @@ class QuerySet(object): scope['query'] = query code = pymongo.code.Code(code, scope=scope) - db = get_db() + db = self._document._get_db() return db.eval(code, *fields) def where(self, where_clause): diff --git a/tests/document.py b/tests/document.py index 5f3bc635..0d9f8d7d 100644 --- a/tests/document.py +++ b/tests/document.py @@ -2528,6 +2528,122 @@ class DocumentTest(unittest.TestCase): finally: Collection.update = orig_update + def test_db_alias_tests(self): + """ DB Alias tests """ + # mongoenginetest - Is default connection alias from setUp() + # Register Aliases + register_connection('testdb-1', "mongoenginetest2" ) + register_connection('testdb-2', "mongoenginetest3" ) + register_connection('testdb-3', 'mongoenginetest4') + + class User(Document): + name = StringField() + class Book(Document): + name = StringField() + meta = {"db_alias" : "testdb-1" } + + # Drops + User.drop_collection() + Book.drop_collection() + + # Create + bob = User.objects.create(name = "Bob") + hp = Book.objects.create(name = "Harry Potter") + + # Selects + self.assertEqual( User.objects.first(), bob) + self.assertEqual( Book.objects.first(), hp) + + # DeRefecence + class AuthorBooks(Document): + author = ReferenceField(User) + book = ReferenceField(Book) + meta = {"db_alias" : "testdb-2" } + # Drops + AuthorBooks.drop_collection() + + ab = AuthorBooks.objects.create( author = bob, book = hp) + # select + self.assertEqual( AuthorBooks.objects.first(), ab) + self.assertEqual( AuthorBooks.objects.first().book, hp) + self.assertEqual( AuthorBooks.objects.first().author, bob) + + + + + + + + + def test_db_ref_usage(self): + """ DB Ref usage in __raw__ queries """ + class User(Document): + name = StringField() + + class Book(Document): + name = StringField() + author = ReferenceField(User) + extra = DictField() + meta = { + 'ordering': ['+name'] + } + + def __unicode__(self): + return self.name + def __str__(self): + return self.name + + + # Drops + User.drop_collection() + Book.drop_collection() + + # Authors + bob = User.objects.create(name = "Bob") + jon = User.objects.create(name = "Jon") + + # Redactors + karl = User.objects.create( name = "Karl") + susan = User.objects.create( name = "Susan") + peter = User.objects.create( name = "Peter") + + # Bob + Book.objects.create( name = "1", author = bob, extra = {"a": bob.to_dbref(), "b" : [karl.to_dbref(), susan.to_dbref(),] } ) + Book.objects.create( name = "2", author = bob, extra = {"a": bob.to_dbref(), "b" : karl.to_dbref()} ) + Book.objects.create( name = "3", author = bob, extra = {"a": bob.to_dbref(), "c" : [jon.to_dbref(), peter.to_dbref() ] }) + Book.objects.create( name = "4", author = bob,) + + # Jon + Book.objects.create( name = "5", author = jon,) + Book.objects.create( name = "6", author = peter,) + Book.objects.create( name = "7", author = jon,) + Book.objects.create( name = "8", author = jon,) + Book.objects.create( name = "9", author = jon, extra = {"a": peter.to_dbref() }) + + # Checks + self.assertEqual(u",".join([str(b) for b in Book.objects.all()] ) , "1,2,3,4,5,6,7,8,9" ) + # bob related books + self.assertEqual(u",".join([str(b) for b in Book.objects.filter( + Q(extra__a = bob ) | + Q(author = bob) | + Q(extra__b = bob ) )] ) , + "1,2,3,4" ) + # Susan & Karl related books + self.assertEqual(u",".join([str(b) for b in Book.objects.filter( + Q(extra__a__all = [karl, susan] ) | + Q(author__all = [karl, susan ] ) | + Q(extra__b__all = [karl.to_dbref(), susan.to_dbref()] ) + ) ] ) , "1" ) + + # $Where + self.assertEqual(u",".join([str(b) for b in Book.objects.filter( + __raw__ = { + "$where" : """function(){ return this.name == '1' || this.name == '2'; } """ + } + ) ] ) , "1,2" ) + + + if __name__ == '__main__': unittest.main()