db_alias using in model, queryset, reference fields, derefrence.

This commit is contained in:
Ross Lawley
2011-12-07 01:14:02 -08:00
parent cf4a45da11
commit 8d2bc444bb
5 changed files with 143 additions and 15 deletions

View File

@@ -103,13 +103,19 @@ class DeReference(object):
for key, doc in references.iteritems():
object_map[key] = doc
else: # Generic reference: use the refs data to convert to document
references = get_db()[col].find({'_id': {'$in': refs}})
for ref in references:
if '_cls' in ref:
doc = get_document(ref['_cls'])._from_son(ref)
else:
if doc_type:
references = doc_type._get_db()[col].find({'_id': {'$in': refs}})
for ref in references:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc
object_map[doc.id] = doc
else:
references = get_db()[col].find({'_id': {'$in': refs}})
for ref in references:
if '_cls' in ref:
doc = get_document(ref["_cls"])._from_son(ref)
else:
doc = doc_type._from_son(ref)
object_map[doc.id] = doc
return object_map
def _attach_objects(self, items, depth=0, instance=None, name=None):

View File

@@ -2,7 +2,7 @@ from mongoengine import signals
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
BaseDict, BaseList, DataObserver)
from queryset import OperationError
from connection import get_db
from connection import get_db, DEFAULT_CONNECTION_NAME
import pymongo
@@ -86,11 +86,16 @@ class Document(BaseDocument):
return setattr(self, self._meta['id_field'], value)
return property(fget, fset)
@classmethod
def _get_db(self):
"""Some Model using other db_alias"""
return get_db(self._meta.get("db_alias", DEFAULT_CONNECTION_NAME ))
@classmethod
def _get_collection(self):
"""Returns the collection for the document."""
if not hasattr(self, '_collection') or self._collection is None:
db = get_db()
db = self._get_db()
collection_name = self._get_collection_name()
# Create collection as a capped collection if specified
if self._meta['max_size'] or self._meta['max_documents']:
@@ -318,7 +323,7 @@ class Document(BaseDocument):
:class:`~mongoengine.Document` type from the database.
"""
from mongoengine.queryset import QuerySet
db = get_db()
db = cls._get_db()
db.drop_collection(cls._get_collection_name())
QuerySet._reset_already_indexed(cls)

View File

@@ -645,7 +645,7 @@ class ReferenceField(BaseField):
value = instance._data.get(self.name)
# Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)):
value = get_db().dereference(value)
value = self.document_type._get_db().dereference(value)
if value is not None:
instance._data[self.name] = self.document_type._from_son(value)
@@ -718,7 +718,7 @@ class GenericReferenceField(BaseField):
def dereference(self, value):
doc_cls = get_document(value['_cls'])
reference = value['_ref']
doc = get_db().dereference(reference)
doc = doc_cls._get_db().dereference(reference)
if doc is not None:
doc = doc_cls._from_son(doc)
return doc
@@ -1162,8 +1162,9 @@ class SequenceField(IntField):
.. versionadded:: 0.5
"""
def __init__(self, collection_name=None, *args, **kwargs):
def __init__(self, collection_name=None, db_alias = None, *args, **kwargs):
self.collection_name = collection_name or 'mongoengine.counters'
self.db_alias = db_alias or DEFAULT_CONNECTION_NAME
return super(SequenceField, self).__init__(*args, **kwargs)
def generate_new_value(self):
@@ -1172,7 +1173,7 @@ class SequenceField(IntField):
"""
sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(),
self.name)
collection = get_db()[self.collection_name]
collection = get_db(alias = self.db_alias )[self.collection_name]
counter = collection.find_and_modify(query={"_id": sequence_id},
update={"$inc": {"next": 1}},
new=True,

View File

@@ -481,7 +481,7 @@ class QuerySet(object):
if self._document not in QuerySet.__already_indexed:
# Ensure collection exists
db = get_db()
db = self._document._get_db()
if self._collection_obj.name not in db.collection_names():
self._document._collection = None
self._collection_obj = self._document._get_collection()
@@ -1452,7 +1452,7 @@ class QuerySet(object):
scope['query'] = query
code = pymongo.code.Code(code, scope=scope)
db = get_db()
db = self._document._get_db()
return db.eval(code, *fields)
def where(self, where_clause):