added GenericEmbeddedDocumentField

This commit is contained in:
Wilson Júnior 2011-07-27 08:45:15 -03:00
parent e3cbeb9df0
commit 6471c6e133
2 changed files with 52 additions and 1 deletions

View File

@ -21,7 +21,7 @@ __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DecimalField', 'ComplexDateTimeField', 'URLField',
'GenericReferenceField', 'FileField', 'BinaryField',
'SortedListField', 'EmailField', 'GeoPointField',
'SequenceField']
'SequenceField', 'GenericEmbeddedDocumentField']
RECURSIVE_REFERENCE_CONSTANT = 'self'
@ -420,6 +420,32 @@ class EmbeddedDocumentField(BaseField):
def prepare_query_value(self, op, value):
return self.to_mongo(value)
class GenericEmbeddedDocumentField(BaseField):
def prepare_query_value(self, op, value):
return self.to_mongo(value)
def to_python(self, value):
if isinstance(value, dict):
doc_cls = get_document(value['_cls'])
value = doc_cls._from_son(value)
return value
def validate(self, value):
if not isinstance(value, EmbeddedDocument):
raise ValidationError('Invalid embedded document instance '
'provided to an GenericEmbeddedDocumentField')
value.validate()
def to_mongo(self, document):
if document is None:
return None
data = document.to_mongo()
if not '_cls' in data:
data['_cls'] = document._class_name
return data
class ListField(ComplexBaseField):
"""A list field that wraps a standard field, allowing multiple instances

View File

@ -1488,5 +1488,30 @@ class FieldTest(unittest.TestCase):
self.assertEqual(c['next'], 10)
def test_generic_embedded_document(self):
class Car(EmbeddedDocument):
name = StringField()
class Dish(EmbeddedDocument):
food = StringField(required=True)
number = IntField()
class Person(Document):
name = StringField()
like = GenericEmbeddedDocumentField()
person = Person(name='Test User')
person.like = Car(name='Fiat')
person.save()
person = Person.objects.first()
self.assertTrue(isinstance(person.like, Car))
person.like = Dish(food="arroz", number=15)
person.save()
person = Person.objects.first()
self.assertTrue(isinstance(person.like, Dish))
if __name__ == '__main__':
unittest.main()