Add _instance to Embedded Documents

Fixes MongoEngine/mongoengine#139
This commit is contained in:
Ross Lawley 2012-11-06 16:04:23 +00:00
parent 3cc2c617fd
commit 7d90aa76ff
5 changed files with 70 additions and 13 deletions

View File

@ -1,4 +1,5 @@
import weakref
from mongoengine.common import _import_class
__all__ = ("BaseDict", "BaseList")
@ -16,6 +17,14 @@ class BaseDict(dict):
self._name = name
return super(BaseDict, self).__init__(dict_items)
def __getitem__(self, *args, **kwargs):
value = super(BaseDict, self).__getitem__(*args, **kwargs)
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = self._instance
return value
def __setitem__(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseDict, self).__setitem__(*args, **kwargs)
@ -75,6 +84,14 @@ class BaseList(list):
self._name = name
return super(BaseList, self).__init__(list_items)
def __getitem__(self, *args, **kwargs):
value = super(BaseList, self).__getitem__(*args, **kwargs)
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = self._instance
return value
def __setitem__(self, *args, **kwargs):
self._mark_as_changed()
return super(BaseList, self).__setitem__(*args, **kwargs)
@ -84,7 +101,8 @@ class BaseList(list):
return super(BaseList, self).__delitem__(*args, **kwargs)
def __getstate__(self):
self.observer = None
self.instance = None
self._dereferenced = False
return self
def __setstate__(self, state):

View File

@ -1,5 +1,6 @@
import operator
import warnings
import weakref
from bson import DBRef, ObjectId
@ -71,6 +72,9 @@ class BaseField(object):
if callable(value):
value = value()
EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = weakref.proxy(instance)
return value
def __set__(self, instance, value):

View File

@ -40,6 +40,8 @@ class EmbeddedDocument(BaseDocument):
my_metaclass = DocumentMetaclass
__metaclass__ = DocumentMetaclass
_instance = None
def __init__(self, *args, **kwargs):
super(EmbeddedDocument, self).__init__(*args, **kwargs)
self._changed_fields = []

View File

@ -625,7 +625,8 @@ class SortedListField(ListField):
def to_mongo(self, value):
value = super(SortedListField, self).to_mongo(value)
if self._ordering is not None:
return sorted(value, key=itemgetter(self._ordering), reverse=self._order_reverse)
return sorted(value, key=itemgetter(self._ordering),
reverse=self._order_reverse)
return sorted(value, reverse=self._order_reverse)
@ -655,7 +656,9 @@ class DictField(ComplexBaseField):
self.error('Only dictionaries may be used in a DictField')
if any(k for k in value.keys() if not isinstance(k, basestring)):
self.error('Invalid dictionary key - documents must have only string keys')
msg = ("Invalid dictionary key - documents must "
"have only string keys")
self.error(msg)
if any(('.' in k or '$' in k) for k in value.keys()):
self.error('Invalid dictionary key name - keys may not contain "."'
' or "$" characters')

View File

@ -183,9 +183,6 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(list_stats, CompareStats.objects.first().stats)
def test_db_field_load(self):
"""Ensure we load data correctly
"""
@ -214,24 +211,24 @@ class InstanceTest(unittest.TestCase):
class Person(Document):
name = StringField(required=True)
rank_ = EmbeddedDocumentField(Rank, required=False, db_field='rank')
rank_ = EmbeddedDocumentField(Rank,
required=False,
db_field='rank')
@property
def rank(self):
return self.rank_.title if self.rank_ is not None else "Private"
if self.rank_ is None:
return "Private"
return self.rank_.title
Person.drop_collection()
Person(name="Jack", rank_=Rank(title="Corporal")).save()
Person(name="Fred").save()
self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal")
self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
def test_custom_id_field(self):
"""Ensure that documents may be created with custom primary keys.
"""
@ -247,7 +244,7 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(User._meta['id_field'], 'username')
def create_invalid_user():
User(name='test').save() # no primary key field
User(name='test').save() # no primary key field
self.assertRaises(ValidationError, create_invalid_user)
def define_invalid_user():
@ -424,6 +421,36 @@ class InstanceTest(unittest.TestCase):
self.assertTrue('content' in Comment._fields)
self.assertFalse('id' in Comment._fields)
def test_embedded_document_instance(self):
"""Ensure that embedded documents can reference parent instance
"""
class Embedded(EmbeddedDocument):
string = StringField()
class Doc(Document):
embedded_field = EmbeddedDocumentField(Embedded)
Doc.drop_collection()
Doc(embedded_field=Embedded(string="Hi")).save()
doc = Doc.objects.get()
self.assertEqual(doc, doc.embedded_field._instance)
def test_embedded_document_complex_instance(self):
"""Ensure that embedded documents in complex fields can reference
parent instance"""
class Embedded(EmbeddedDocument):
string = StringField()
class Doc(Document):
embedded_field = ListField(EmbeddedDocumentField(Embedded))
Doc.drop_collection()
Doc(embedded_field=[Embedded(string="Hi")]).save()
doc = Doc.objects.get()
self.assertEqual(doc, doc.embedded_field[0]._instance)
def test_embedded_document_validation(self):
"""Ensure that embedded documents may be validated.
"""
@ -442,6 +469,7 @@ class InstanceTest(unittest.TestCase):
comment.date = datetime.now()
comment.validate()
self.assertEqual(comment._instance, None)
def test_embedded_db_field_validate(self):
@ -475,11 +503,13 @@ class InstanceTest(unittest.TestCase):
self.assertEqual(person_obj['age'], 30)
self.assertEqual(person_obj['_id'], person.id)
# Test skipping validation on save
class Recipient(Document):
email = EmailField(required=True)
recipient = Recipient(email='root@localhost')
self.assertRaises(ValidationError, recipient.save)
try:
recipient.save(validate=False)
except ValidationError: