Allowed to switch databases for a specific query.
This commit is contained in:
parent
f099dc6a37
commit
5ae588833b
@ -1,6 +1,5 @@
|
|||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||||
from mongoengine.queryset import QuerySet
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ("switch_db", "switch_collection", "no_dereference",
|
__all__ = ("switch_db", "switch_collection", "no_dereference",
|
||||||
@ -162,12 +161,6 @@ class no_sub_classes(object):
|
|||||||
return self.cls
|
return self.cls
|
||||||
|
|
||||||
|
|
||||||
class QuerySetNoDeRef(QuerySet):
|
|
||||||
"""Special no_dereference QuerySet"""
|
|
||||||
def __dereference(items, max_depth=1, instance=None, name=None):
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
class query_counter(object):
|
class query_counter(object):
|
||||||
""" Query_counter context manager to get the number of queries. """
|
""" Query_counter context manager to get the number of queries. """
|
||||||
|
|
||||||
|
@ -13,11 +13,11 @@ import pymongo
|
|||||||
from pymongo.common import validate_read_preference
|
from pymongo.common import validate_read_preference
|
||||||
|
|
||||||
from mongoengine import signals
|
from mongoengine import signals
|
||||||
|
from mongoengine.context_managers import switch_db
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.base.common import get_document
|
from mongoengine.base.common import get_document
|
||||||
from mongoengine.errors import (OperationError, NotUniqueError,
|
from mongoengine.errors import (OperationError, NotUniqueError,
|
||||||
InvalidQueryError, LookUpError)
|
InvalidQueryError, LookUpError)
|
||||||
|
|
||||||
from mongoengine.queryset import transform
|
from mongoengine.queryset import transform
|
||||||
from mongoengine.queryset.field_list import QueryFieldList
|
from mongoengine.queryset.field_list import QueryFieldList
|
||||||
from mongoengine.queryset.visitor import Q, QNode
|
from mongoengine.queryset.visitor import Q, QNode
|
||||||
@ -389,7 +389,7 @@ class BaseQuerySet(object):
|
|||||||
ref_q = document_cls.objects(**{field_name + '__in': self})
|
ref_q = document_cls.objects(**{field_name + '__in': self})
|
||||||
ref_q_count = ref_q.count()
|
ref_q_count = ref_q.count()
|
||||||
if (doc != document_cls and ref_q_count > 0
|
if (doc != document_cls and ref_q_count > 0
|
||||||
or (doc == document_cls and ref_q_count > 0)):
|
or (doc == document_cls and ref_q_count > 0)):
|
||||||
ref_q.delete(write_concern=write_concern)
|
ref_q.delete(write_concern=write_concern)
|
||||||
elif rule == NULLIFY:
|
elif rule == NULLIFY:
|
||||||
document_cls.objects(**{field_name + '__in': self}).update(
|
document_cls.objects(**{field_name + '__in': self}).update(
|
||||||
@ -522,6 +522,19 @@ class BaseQuerySet(object):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def using(self, alias):
|
||||||
|
"""This method is for controlling which database the QuerySet will be evaluated against if you are using more than one database.
|
||||||
|
|
||||||
|
:param alias: The database alias
|
||||||
|
|
||||||
|
.. versionadded:: 0.8
|
||||||
|
"""
|
||||||
|
|
||||||
|
with switch_db(self._document, alias) as cls:
|
||||||
|
collection = cls._get_collection()
|
||||||
|
|
||||||
|
return self.clone_into(self.__class__(self._document, collection))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
"""Creates a copy of the current
|
"""Creates a copy of the current
|
||||||
:class:`~mongoengine.queryset.QuerySet`
|
:class:`~mongoengine.queryset.QuerySet`
|
||||||
@ -926,7 +939,7 @@ class BaseQuerySet(object):
|
|||||||
mr_args['out'] = output
|
mr_args['out'] = output
|
||||||
|
|
||||||
results = getattr(queryset._collection, map_reduce_function)(
|
results = getattr(queryset._collection, map_reduce_function)(
|
||||||
map_f, reduce_f, **mr_args)
|
map_f, reduce_f, **mr_args)
|
||||||
|
|
||||||
if map_reduce_function == 'map_reduce':
|
if map_reduce_function == 'map_reduce':
|
||||||
results = results.find()
|
results = results.find()
|
||||||
@ -1362,7 +1375,7 @@ class BaseQuerySet(object):
|
|||||||
for subdoc in subclasses:
|
for subdoc in subclasses:
|
||||||
try:
|
try:
|
||||||
subfield = ".".join(f.db_field for f in
|
subfield = ".".join(f.db_field for f in
|
||||||
subdoc._lookup_field(field.split('.')))
|
subdoc._lookup_field(field.split('.')))
|
||||||
ret.append(subfield)
|
ret.append(subfield)
|
||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
@ -1450,6 +1463,7 @@ class BaseQuerySet(object):
|
|||||||
# type of this field and use the corresponding
|
# type of this field and use the corresponding
|
||||||
# .to_python(...)
|
# .to_python(...)
|
||||||
from mongoengine.fields import EmbeddedDocumentField
|
from mongoengine.fields import EmbeddedDocumentField
|
||||||
|
|
||||||
obj = self._document
|
obj = self._document
|
||||||
for chunk in path.split('.'):
|
for chunk in path.split('.'):
|
||||||
obj = getattr(obj, chunk, None)
|
obj = getattr(obj, chunk, None)
|
||||||
@ -1460,6 +1474,7 @@ class BaseQuerySet(object):
|
|||||||
if obj and data is not None:
|
if obj and data is not None:
|
||||||
data = obj.to_python(data)
|
data = obj.to_python(data)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
return clean(row)
|
return clean(row)
|
||||||
|
|
||||||
def _sub_js_fields(self, code):
|
def _sub_js_fields(self, code):
|
||||||
@ -1468,6 +1483,7 @@ class BaseQuerySet(object):
|
|||||||
substituted for the MongoDB name of the field (specified using the
|
substituted for the MongoDB name of the field (specified using the
|
||||||
:attr:`name` keyword argument in a field's constructor).
|
:attr:`name` keyword argument in a field's constructor).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def field_sub(match):
|
def field_sub(match):
|
||||||
# Extract just the field name, and look up the field objects
|
# Extract just the field name, and look up the field objects
|
||||||
field_name = match.group(1).split('.')
|
field_name = match.group(1).split('.')
|
||||||
|
@ -155,3 +155,10 @@ class QuerySetNoCache(BaseQuerySet):
|
|||||||
queryset = self.clone()
|
queryset = self.clone()
|
||||||
queryset.rewind()
|
queryset.rewind()
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
|
||||||
|
class QuerySetNoDeRef(QuerySet):
|
||||||
|
"""Special no_dereference QuerySet"""
|
||||||
|
|
||||||
|
def __dereference(items, max_depth=1, instance=None, name=None):
|
||||||
|
return items
|
@ -29,6 +29,7 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
connect(db='mongoenginetest')
|
connect(db='mongoenginetest')
|
||||||
|
connect(db='mongoenginetest2', alias='test2')
|
||||||
|
|
||||||
class PersonMeta(EmbeddedDocument):
|
class PersonMeta(EmbeddedDocument):
|
||||||
weight = IntField()
|
weight = IntField()
|
||||||
@ -2957,6 +2958,21 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
|
|
||||||
Number.drop_collection()
|
Number.drop_collection()
|
||||||
|
|
||||||
|
def test_using(self):
|
||||||
|
"""Ensure that switching databases for a queryset is possible
|
||||||
|
"""
|
||||||
|
class Number2(Document):
|
||||||
|
n = IntField()
|
||||||
|
|
||||||
|
Number2.drop_collection()
|
||||||
|
|
||||||
|
for i in xrange(1, 10):
|
||||||
|
t = Number2(n=i)
|
||||||
|
t.switch_db('test2')
|
||||||
|
t.save()
|
||||||
|
|
||||||
|
self.assertEqual(len(Number2.objects.using('test2')), 9)
|
||||||
|
|
||||||
def test_unset_reference(self):
|
def test_unset_reference(self):
|
||||||
class Comment(Document):
|
class Comment(Document):
|
||||||
text = StringField()
|
text = StringField()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user