Made lists of recursive reference fields possible

This commit is contained in:
Harry Marr 2010-10-03 01:48:42 +01:00
parent 98bc0a7c10
commit 159923fae2
2 changed files with 33 additions and 2 deletions

View File

@ -20,6 +20,7 @@ __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
RECURSIVE_REFERENCE_CONSTANT = 'self' RECURSIVE_REFERENCE_CONSTANT = 'self'
class StringField(BaseField): class StringField(BaseField):
"""A unicode string field. """A unicode string field.
""" """
@ -105,6 +106,7 @@ class URLField(StringField):
message = 'This URL appears to be a broken link: %s' % e message = 'This URL appears to be a broken link: %s' % e
raise ValidationError(message) raise ValidationError(message)
class EmailField(StringField): class EmailField(StringField):
"""A field that validates input as an E-Mail-Address. """A field that validates input as an E-Mail-Address.
""" """
@ -119,6 +121,7 @@ class EmailField(StringField):
if not EmailField.EMAIL_REGEX.match(value): if not EmailField.EMAIL_REGEX.match(value):
raise ValidationError('Invalid Mail-address: %s' % value) raise ValidationError('Invalid Mail-address: %s' % value)
class IntField(BaseField): class IntField(BaseField):
"""An integer field. """An integer field.
""" """
@ -142,6 +145,7 @@ class IntField(BaseField):
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Integer value is too large') raise ValidationError('Integer value is too large')
class FloatField(BaseField): class FloatField(BaseField):
"""An floating point number field. """An floating point number field.
""" """
@ -197,6 +201,7 @@ class DecimalField(BaseField):
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Decimal value is too large') raise ValidationError('Decimal value is too large')
class BooleanField(BaseField): class BooleanField(BaseField):
"""A boolean field type. """A boolean field type.
@ -209,6 +214,7 @@ class BooleanField(BaseField):
def validate(self, value): def validate(self, value):
assert isinstance(value, bool) assert isinstance(value, bool)
class DateTimeField(BaseField): class DateTimeField(BaseField):
"""A datetime field. """A datetime field.
""" """
@ -216,6 +222,7 @@ class DateTimeField(BaseField):
def validate(self, value): def validate(self, value):
assert isinstance(value, datetime.datetime) assert isinstance(value, datetime.datetime)
class EmbeddedDocumentField(BaseField): class EmbeddedDocumentField(BaseField):
"""An embedded document field. Only valid values are subclasses of """An embedded document field. Only valid values are subclasses of
:class:`~mongoengine.EmbeddedDocument`. :class:`~mongoengine.EmbeddedDocument`.
@ -331,6 +338,16 @@ class ListField(BaseField):
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.field.lookup_member(member_name) return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
class SortedListField(ListField): class SortedListField(ListField):
"""A ListField that sorts the contents of its list before writing to """A ListField that sorts the contents of its list before writing to
the database in order to ensure that a sorted list is always the database in order to ensure that a sorted list is always
@ -346,9 +363,11 @@ class SortedListField(ListField):
def to_mongo(self, value): def to_mongo(self, value):
if self._ordering is not None: if self._ordering is not None:
return sorted([self.field.to_mongo(item) for item in value], key=itemgetter(self._ordering)) return sorted([self.field.to_mongo(item) for item in value],
key=itemgetter(self._ordering))
return sorted([self.field.to_mongo(item) for item in value]) return sorted([self.field.to_mongo(item) for item in value])
class DictField(BaseField): class DictField(BaseField):
"""A dictionary field that wraps a standard Python dictionary. This is """A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined. similar to an embedded document, but the structure is not defined.
@ -442,6 +461,7 @@ class ReferenceField(BaseField):
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
class GenericReferenceField(BaseField): class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass """A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily). that will be automatically dereferenced on access (lazily).
@ -640,6 +660,7 @@ class FileField(BaseField):
assert isinstance(value, GridFSProxy) assert isinstance(value, GridFSProxy)
assert isinstance(value.grid_id, pymongo.objectid.ObjectId) assert isinstance(value.grid_id, pymongo.objectid.ObjectId)
class GeoPointField(BaseField): class GeoPointField(BaseField):
"""A list storing a latitude and longitude. """A list storing a latitude and longitude.
""" """

View File

@ -394,14 +394,24 @@ class FieldTest(unittest.TestCase):
class Employee(Document): class Employee(Document):
name = StringField() name = StringField()
boss = ReferenceField('self') boss = ReferenceField('self')
friends = ListField(ReferenceField('self'))
bill = Employee(name='Bill Lumbergh') bill = Employee(name='Bill Lumbergh')
bill.save() bill.save()
peter = Employee(name='Peter Gibbons', boss=bill)
michael = Employee(name='Michael Bolton')
michael.save()
samir = Employee(name='Samir Nagheenanajar')
samir.save()
friends = [michael, samir]
peter = Employee(name='Peter Gibbons', boss=bill, friends=friends)
peter.save() peter.save()
peter = Employee.objects.with_id(peter.id) peter = Employee.objects.with_id(peter.id)
self.assertEqual(peter.boss, bill) self.assertEqual(peter.boss, bill)
self.assertEqual(peter.friends, friends)
def test_undefined_reference(self): def test_undefined_reference(self):
"""Ensure that ReferenceFields may reference undefined Documents. """Ensure that ReferenceFields may reference undefined Documents.