From 9b30afeca9b5d61ce0bc9759ea93e20cca702ae1 Mon Sep 17 00:00:00 2001 From: DavidBord Date: Sun, 3 Aug 2014 12:54:22 +0300 Subject: [PATCH] fix-#397: Allow specifying the '_cls' as a field for indexes --- AUTHORS | 1 + mongoengine/base/document.py | 5 ++++- mongoengine/base/metaclasses.py | 4 ++++ mongoengine/dereference.py | 4 ++++ mongoengine/queryset/base.py | 7 +++++-- tests/document/class_methods.py | 4 ++-- tests/document/inheritance.py | 4 ++-- tests/document/instance.py | 2 +- tests/fields/fields.py | 23 +++++++++++++++++++++++ 9 files changed, 46 insertions(+), 8 deletions(-) diff --git a/AUTHORS b/AUTHORS index 77115407..8943bc79 100644 --- a/AUTHORS +++ b/AUTHORS @@ -208,3 +208,4 @@ that much better: * Norberto Leite (https://github.com/nleite) * Bob Cribbs (https://github.com/bocribbz) * Jay Shirley (https://github.com/jshirley) + * DavidBord (https://github.com/DavidBord) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index cfc5de2a..869449f9 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -26,7 +26,7 @@ NON_FIELD_ERRORS = '__all__' class BaseDocument(object): __slots__ = ('_changed_fields', '_initialised', '_created', '_data', - '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') + '_dynamic_fields', '_auto_id_field', '_db_field_map', '__weakref__') _dynamic = False _dynamic_lock = True @@ -78,6 +78,9 @@ class BaseDocument(object): value = getattr(self, key, None) setattr(self, key, value) + if "_cls" not in values: + self._cls = self._class_name + # Set passed values after initialisation if self._dynamic: dynamic_data = {} diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index a4bd0144..e8014553 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -47,6 +47,10 @@ class DocumentMetaclass(type): meta.merge(base._meta) attrs['_meta'] = meta + if '_meta' in attrs and attrs['_meta'].get('allow_inheritance', ALLOW_INHERITANCE): + StringField = _import_class('StringField') + attrs['_cls'] = StringField() + # Handle document Fields # Merge all fields from subclasses diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index f9c8ecd6..a22e3473 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -180,7 +180,11 @@ class DeReference(object): return self.object_map.get(items['_ref'].id, items) elif '_cls' in items: doc = get_document(items['_cls'])._from_son(items) + _cls = doc._data.pop('_cls', None) + del items['_cls'] doc._data = self._attach_objects(doc._data, depth, doc, None) + if _cls is not None: + doc._data['_cls'] = _cls return doc if not hasattr(items, 'items'): diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index ce1af2d5..7094dacc 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1413,8 +1413,11 @@ class BaseQuerySet(object): def _query(self): if self._mongo_query is None: self._mongo_query = self._query_obj.to_query(self._document) - if self._class_check: - self._mongo_query.update(self._initial_query) + if self._class_check and self._initial_query: + if "_cls" in self._mongo_query: + self._mongo_query = {"$and": [self._initial_query, self._mongo_query]} + else: + self._mongo_query.update(self._initial_query) return self._mongo_query @property diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 52e3794c..5da474ac 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -36,9 +36,9 @@ class ClassMethodsTest(unittest.TestCase): def test_definition(self): """Ensure that document may be defined using fields. """ - self.assertEqual(['age', 'id', 'name'], + self.assertEqual(['_cls', 'age', 'id', 'name'], sorted(self.Person._fields.keys())) - self.assertEqual(["IntField", "ObjectIdField", "StringField"], + self.assertEqual(["IntField", "ObjectIdField", "StringField", "StringField"], sorted([x.__class__.__name__ for x in self.Person._fields.values()])) diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 5a48f75e..566d3699 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -163,7 +163,7 @@ class InheritanceTest(unittest.TestCase): class Employee(Person): salary = IntField() - self.assertEqual(['age', 'id', 'name', 'salary'], + self.assertEqual(['_cls', 'age', 'id', 'name', 'salary'], sorted(Employee._fields.keys())) self.assertEqual(Employee._get_collection_name(), Person._get_collection_name()) @@ -180,7 +180,7 @@ class InheritanceTest(unittest.TestCase): class Employee(Person): salary = IntField() - self.assertEqual(['age', 'id', 'name', 'salary'], + self.assertEqual(['_cls', 'age', 'id', 'name', 'salary'], sorted(Employee._fields.keys())) self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(), ['_cls', 'name', 'age']) diff --git a/tests/document/instance.py b/tests/document/instance.py index c66fe026..f925aeeb 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -462,7 +462,7 @@ class InstanceTest(unittest.TestCase): self.assertEqual(person['name'], 'Another User') # Length = length(assigned fields + id) - self.assertEqual(len(person), 3) + self.assertEqual(len(person), 4) self.assertTrue('age' in person) person.age = None diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 1415fbe8..26a30bfc 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -3034,6 +3034,29 @@ class FieldTest(unittest.TestCase): test.dictionary # Just access to test getter self.assertRaises(ValidationError, test.validate) + def test_cls_field(self): + class Animal(Document): + meta = {'allow_inheritance': True} + + class Fish(Animal): + pass + + class Mammal(Animal): + pass + + class Dog(Mammal): + pass + + class Human(Mammal): + pass + + Animal.objects.delete() + Dog().save() + Fish().save() + Human().save() + self.assertEquals(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2) + self.assertEquals(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0) + if __name__ == '__main__': unittest.main()