db_alias using in model, queryset, reference fields, derefrence.

This commit is contained in:
Ross Lawley 2011-12-07 01:14:02 -08:00
parent cf4a45da11
commit 8d2bc444bb
5 changed files with 143 additions and 15 deletions

View File

@ -103,13 +103,19 @@ class DeReference(object):
for key, doc in references.iteritems(): for key, doc in references.iteritems():
object_map[key] = doc object_map[key] = doc
else: # Generic reference: use the refs data to convert to document else: # Generic reference: use the refs data to convert to document
references = get_db()[col].find({'_id': {'$in': refs}}) if doc_type:
for ref in references: references = doc_type._get_db()[col].find({'_id': {'$in': refs}})
if '_cls' in ref: for ref in references:
doc = get_document(ref['_cls'])._from_son(ref)
else:
doc = doc_type._from_son(ref) doc = doc_type._from_son(ref)
object_map[doc.id] = doc object_map[doc.id] = doc
else:
references = get_db()[col].find({'_id': {'$in': refs}})
for ref in references:
if '_cls' in ref:
doc = get_document(ref["_cls"])._from_son(ref)
else:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc
return object_map return object_map
def _attach_objects(self, items, depth=0, instance=None, name=None): def _attach_objects(self, items, depth=0, instance=None, name=None):

View File

@ -2,7 +2,7 @@ from mongoengine import signals
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
BaseDict, BaseList, DataObserver) BaseDict, BaseList, DataObserver)
from queryset import OperationError from queryset import OperationError
from connection import get_db from connection import get_db, DEFAULT_CONNECTION_NAME
import pymongo import pymongo
@ -86,11 +86,16 @@ class Document(BaseDocument):
return setattr(self, self._meta['id_field'], value) return setattr(self, self._meta['id_field'], value)
return property(fget, fset) return property(fget, fset)
@classmethod
def _get_db(self):
"""Some Model using other db_alias"""
return get_db(self._meta.get("db_alias", DEFAULT_CONNECTION_NAME ))
@classmethod @classmethod
def _get_collection(self): def _get_collection(self):
"""Returns the collection for the document.""" """Returns the collection for the document."""
if not hasattr(self, '_collection') or self._collection is None: if not hasattr(self, '_collection') or self._collection is None:
db = get_db() db = self._get_db()
collection_name = self._get_collection_name() collection_name = self._get_collection_name()
# Create collection as a capped collection if specified # Create collection as a capped collection if specified
if self._meta['max_size'] or self._meta['max_documents']: if self._meta['max_size'] or self._meta['max_documents']:
@ -318,7 +323,7 @@ class Document(BaseDocument):
:class:`~mongoengine.Document` type from the database. :class:`~mongoengine.Document` type from the database.
""" """
from mongoengine.queryset import QuerySet from mongoengine.queryset import QuerySet
db = get_db() db = cls._get_db()
db.drop_collection(cls._get_collection_name()) db.drop_collection(cls._get_collection_name())
QuerySet._reset_already_indexed(cls) QuerySet._reset_already_indexed(cls)

View File

@ -645,7 +645,7 @@ class ReferenceField(BaseField):
value = instance._data.get(self.name) value = instance._data.get(self.name)
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)): if isinstance(value, (pymongo.dbref.DBRef)):
value = get_db().dereference(value) value = self.document_type._get_db().dereference(value)
if value is not None: if value is not None:
instance._data[self.name] = self.document_type._from_son(value) instance._data[self.name] = self.document_type._from_son(value)
@ -718,7 +718,7 @@ class GenericReferenceField(BaseField):
def dereference(self, value): def dereference(self, value):
doc_cls = get_document(value['_cls']) doc_cls = get_document(value['_cls'])
reference = value['_ref'] reference = value['_ref']
doc = get_db().dereference(reference) doc = doc_cls._get_db().dereference(reference)
if doc is not None: if doc is not None:
doc = doc_cls._from_son(doc) doc = doc_cls._from_son(doc)
return doc return doc
@ -1162,8 +1162,9 @@ class SequenceField(IntField):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
def __init__(self, collection_name=None, *args, **kwargs): def __init__(self, collection_name=None, db_alias = None, *args, **kwargs):
self.collection_name = collection_name or 'mongoengine.counters' self.collection_name = collection_name or 'mongoengine.counters'
self.db_alias = db_alias or DEFAULT_CONNECTION_NAME
return super(SequenceField, self).__init__(*args, **kwargs) return super(SequenceField, self).__init__(*args, **kwargs)
def generate_new_value(self): def generate_new_value(self):
@ -1172,7 +1173,7 @@ class SequenceField(IntField):
""" """
sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(), sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(),
self.name) self.name)
collection = get_db()[self.collection_name] collection = get_db(alias = self.db_alias )[self.collection_name]
counter = collection.find_and_modify(query={"_id": sequence_id}, counter = collection.find_and_modify(query={"_id": sequence_id},
update={"$inc": {"next": 1}}, update={"$inc": {"next": 1}},
new=True, new=True,

View File

@ -481,7 +481,7 @@ class QuerySet(object):
if self._document not in QuerySet.__already_indexed: if self._document not in QuerySet.__already_indexed:
# Ensure collection exists # Ensure collection exists
db = get_db() db = self._document._get_db()
if self._collection_obj.name not in db.collection_names(): if self._collection_obj.name not in db.collection_names():
self._document._collection = None self._document._collection = None
self._collection_obj = self._document._get_collection() self._collection_obj = self._document._get_collection()
@ -1452,7 +1452,7 @@ class QuerySet(object):
scope['query'] = query scope['query'] = query
code = pymongo.code.Code(code, scope=scope) code = pymongo.code.Code(code, scope=scope)
db = get_db() db = self._document._get_db()
return db.eval(code, *fields) return db.eval(code, *fields)
def where(self, where_clause): def where(self, where_clause):

View File

@ -2528,6 +2528,122 @@ class DocumentTest(unittest.TestCase):
finally: finally:
Collection.update = orig_update Collection.update = orig_update
def test_db_alias_tests(self):
""" DB Alias tests """
# mongoenginetest - Is default connection alias from setUp()
# Register Aliases
register_connection('testdb-1', "mongoenginetest2" )
register_connection('testdb-2', "mongoenginetest3" )
register_connection('testdb-3', 'mongoenginetest4')
class User(Document):
name = StringField()
class Book(Document):
name = StringField()
meta = {"db_alias" : "testdb-1" }
# Drops
User.drop_collection()
Book.drop_collection()
# Create
bob = User.objects.create(name = "Bob")
hp = Book.objects.create(name = "Harry Potter")
# Selects
self.assertEqual( User.objects.first(), bob)
self.assertEqual( Book.objects.first(), hp)
# DeRefecence
class AuthorBooks(Document):
author = ReferenceField(User)
book = ReferenceField(Book)
meta = {"db_alias" : "testdb-2" }
# Drops
AuthorBooks.drop_collection()
ab = AuthorBooks.objects.create( author = bob, book = hp)
# select
self.assertEqual( AuthorBooks.objects.first(), ab)
self.assertEqual( AuthorBooks.objects.first().book, hp)
self.assertEqual( AuthorBooks.objects.first().author, bob)
def test_db_ref_usage(self):
""" DB Ref usage in __raw__ queries """
class User(Document):
name = StringField()
class Book(Document):
name = StringField()
author = ReferenceField(User)
extra = DictField()
meta = {
'ordering': ['+name']
}
def __unicode__(self):
return self.name
def __str__(self):
return self.name
# Drops
User.drop_collection()
Book.drop_collection()
# Authors
bob = User.objects.create(name = "Bob")
jon = User.objects.create(name = "Jon")
# Redactors
karl = User.objects.create( name = "Karl")
susan = User.objects.create( name = "Susan")
peter = User.objects.create( name = "Peter")
# Bob
Book.objects.create( name = "1", author = bob, extra = {"a": bob.to_dbref(), "b" : [karl.to_dbref(), susan.to_dbref(),] } )
Book.objects.create( name = "2", author = bob, extra = {"a": bob.to_dbref(), "b" : karl.to_dbref()} )
Book.objects.create( name = "3", author = bob, extra = {"a": bob.to_dbref(), "c" : [jon.to_dbref(), peter.to_dbref() ] })
Book.objects.create( name = "4", author = bob,)
# Jon
Book.objects.create( name = "5", author = jon,)
Book.objects.create( name = "6", author = peter,)
Book.objects.create( name = "7", author = jon,)
Book.objects.create( name = "8", author = jon,)
Book.objects.create( name = "9", author = jon, extra = {"a": peter.to_dbref() })
# Checks
self.assertEqual(u",".join([str(b) for b in Book.objects.all()] ) , "1,2,3,4,5,6,7,8,9" )
# bob related books
self.assertEqual(u",".join([str(b) for b in Book.objects.filter(
Q(extra__a = bob ) |
Q(author = bob) |
Q(extra__b = bob ) )] ) ,
"1,2,3,4" )
# Susan & Karl related books
self.assertEqual(u",".join([str(b) for b in Book.objects.filter(
Q(extra__a__all = [karl, susan] ) |
Q(author__all = [karl, susan ] ) |
Q(extra__b__all = [karl.to_dbref(), susan.to_dbref()] )
) ] ) , "1" )
# $Where
self.assertEqual(u",".join([str(b) for b in Book.objects.filter(
__raw__ = {
"$where" : """function(){ return this.name == '1' || this.name == '2'; } """
}
) ] ) , "1,2" )
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()