Merge pull request #719 from DavidBord/fix-397

fix-#397: Allow specifying the '_cls' as a field for indexes
This commit is contained in:
Yohan Graterol 2014-08-26 10:48:10 -05:00
commit 5b9c70ae22
9 changed files with 46 additions and 8 deletions

View File

@ -208,3 +208,4 @@ that much better:
* Norberto Leite (https://github.com/nleite)
* Bob Cribbs (https://github.com/bocribbz)
* Jay Shirley (https://github.com/jshirley)
* DavidBord (https://github.com/DavidBord)

View File

@ -26,7 +26,7 @@ NON_FIELD_ERRORS = '__all__'
class BaseDocument(object):
__slots__ = ('_changed_fields', '_initialised', '_created', '_data',
'_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__')
'_dynamic_fields', '_auto_id_field', '_db_field_map', '__weakref__')
_dynamic = False
_dynamic_lock = True
@ -78,6 +78,9 @@ class BaseDocument(object):
value = getattr(self, key, None)
setattr(self, key, value)
if "_cls" not in values:
self._cls = self._class_name
# Set passed values after initialisation
if self._dynamic:
dynamic_data = {}

View File

@ -47,6 +47,10 @@ class DocumentMetaclass(type):
meta.merge(base._meta)
attrs['_meta'] = meta
if '_meta' in attrs and attrs['_meta'].get('allow_inheritance', ALLOW_INHERITANCE):
StringField = _import_class('StringField')
attrs['_cls'] = StringField()
# Handle document Fields
# Merge all fields from subclasses

View File

@ -180,7 +180,11 @@ class DeReference(object):
return self.object_map.get(items['_ref'].id, items)
elif '_cls' in items:
doc = get_document(items['_cls'])._from_son(items)
_cls = doc._data.pop('_cls', None)
del items['_cls']
doc._data = self._attach_objects(doc._data, depth, doc, None)
if _cls is not None:
doc._data['_cls'] = _cls
return doc
if not hasattr(items, 'items'):

View File

@ -1413,7 +1413,10 @@ class BaseQuerySet(object):
def _query(self):
if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document)
if self._class_check:
if self._class_check and self._initial_query:
if "_cls" in self._mongo_query:
self._mongo_query = {"$and": [self._initial_query, self._mongo_query]}
else:
self._mongo_query.update(self._initial_query)
return self._mongo_query

View File

@ -36,9 +36,9 @@ class ClassMethodsTest(unittest.TestCase):
def test_definition(self):
"""Ensure that document may be defined using fields.
"""
self.assertEqual(['age', 'id', 'name'],
self.assertEqual(['_cls', 'age', 'id', 'name'],
sorted(self.Person._fields.keys()))
self.assertEqual(["IntField", "ObjectIdField", "StringField"],
self.assertEqual(["IntField", "ObjectIdField", "StringField", "StringField"],
sorted([x.__class__.__name__ for x in
self.Person._fields.values()]))

View File

@ -163,7 +163,7 @@ class InheritanceTest(unittest.TestCase):
class Employee(Person):
salary = IntField()
self.assertEqual(['age', 'id', 'name', 'salary'],
self.assertEqual(['_cls', 'age', 'id', 'name', 'salary'],
sorted(Employee._fields.keys()))
self.assertEqual(Employee._get_collection_name(),
Person._get_collection_name())
@ -180,7 +180,7 @@ class InheritanceTest(unittest.TestCase):
class Employee(Person):
salary = IntField()
self.assertEqual(['age', 'id', 'name', 'salary'],
self.assertEqual(['_cls', 'age', 'id', 'name', 'salary'],
sorted(Employee._fields.keys()))
self.assertEqual(Person(name="Bob", age=35).to_mongo().keys(),
['_cls', 'name', 'age'])

View File

@ -462,7 +462,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person['name'], 'Another User')
# Length = length(assigned fields + id)
self.assertEqual(len(person), 3)
self.assertEqual(len(person), 4)
self.assertTrue('age' in person)
person.age = None

View File

@ -3034,6 +3034,29 @@ class FieldTest(unittest.TestCase):
test.dictionary # Just access to test getter
self.assertRaises(ValidationError, test.validate)
def test_cls_field(self):
class Animal(Document):
meta = {'allow_inheritance': True}
class Fish(Animal):
pass
class Mammal(Animal):
pass
class Dog(Mammal):
pass
class Human(Mammal):
pass
Animal.objects.delete()
Dog().save()
Fish().save()
Human().save()
self.assertEquals(Animal.objects(_cls__in=["Animal.Mammal.Dog", "Animal.Fish"]).count(), 2)
self.assertEquals(Animal.objects(_cls__in=["Animal.Fish.Guppy"]).count(), 0)
if __name__ == '__main__':
unittest.main()