From 4209d61b1368717047927cee40a9d64768def93a Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 10 Jul 2013 12:49:19 +0000 Subject: [PATCH] Document.select_related() now respects `db_alias` (#377) --- docs/changelog.rst | 1 + mongoengine/document.py | 4 ++-- tests/fields/fields.py | 31 +++++++++++++++++++++++++++ tests/test_dereference.py | 45 +++++++++++++++++---------------------- 4 files changed, 54 insertions(+), 27 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 42fd9bb1..926c6cb1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8.3 ================ +- Document.select_related() now respects `db_alias` (#377) - Reload uses shard_key if applicable (#384) - Dynamic fields are ordered based on creation and stored in _fields_ordered (#396) - Fixed pickling dynamic documents `_dynamic_fields` (#387) diff --git a/mongoengine/document.py b/mongoengine/document.py index c9901a2a..e331aa1f 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -440,8 +440,8 @@ class Document(BaseDocument): .. versionadded:: 0.5 """ - import dereference - self._data = dereference.DeReference()(self._data, max_depth) + DeReference = _import_class('DeReference') + DeReference()([self], max_depth + 1) return self def reload(self, max_depth=1): diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 8f024991..b3d8d52e 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2474,6 +2474,37 @@ class FieldTest(unittest.TestCase): user = User(email='me@example.com') self.assertTrue(user.validate() is None) + def test_tuples_as_tuples(self): + """ + Ensure that tuples remain tuples when they are + inside a ComplexBaseField + """ + from mongoengine.base import BaseField + + class EnumField(BaseField): + + def __init__(self, **kwargs): + super(EnumField, self).__init__(**kwargs) + + def to_mongo(self, value): + return value + + def to_python(self, value): + return tuple(value) + + class TestDoc(Document): + items = ListField(EnumField()) + + TestDoc.drop_collection() + tuples = [(100, 'Testing')] + doc = TestDoc() + doc.items = tuples + doc.save() + x = TestDoc.objects().get() + self.assertTrue(x is not None) + self.assertTrue(len(x.items) == 1) + self.assertTrue(tuple(x.items[0]) in tuples) + self.assertTrue(x.items[0] in tuples) if __name__ == '__main__': unittest.main() diff --git a/tests/test_dereference.py b/tests/test_dereference.py index e146963f..db9868a0 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1121,37 +1121,32 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) - def test_tuples_as_tuples(self): - """ - Ensure that tuples remain tuples when they are - inside a ComplexBaseField - """ - from mongoengine.base import BaseField + def test_objectid_reference_across_databases(self): + # mongoenginetest - Is default connection alias from setUp() + # Register Aliases + register_connection('testdb-1', 'mongoenginetest2') - class EnumField(BaseField): + class User(Document): + name = StringField() + meta = {"db_alias": "testdb-1"} - def __init__(self, **kwargs): - super(EnumField, self).__init__(**kwargs) + class Book(Document): + name = StringField() + author = ReferenceField(User) - def to_mongo(self, value): - return value + # Drops + User.drop_collection() + Book.drop_collection() - def to_python(self, value): - return tuple(value) + user = User(name="Ross").save() + Book(name="MongoEngine for pros", author=user).save() - class TestDoc(Document): - items = ListField(EnumField()) + # Can't use query_counter across databases - so test the _data object + book = Book.objects.first() + self.assertFalse(isinstance(book._data['author'], User)) - TestDoc.drop_collection() - tuples = [(100, 'Testing')] - doc = TestDoc() - doc.items = tuples - doc.save() - x = TestDoc.objects().get() - self.assertTrue(x is not None) - self.assertTrue(len(x.items) == 1) - self.assertTrue(tuple(x.items[0]) in tuples) - self.assertTrue(x.items[0] in tuples) + book.select_related() + self.assertTrue(isinstance(book._data['author'], User)) def test_non_ascii_pk(self): """