parent
3cc2c617fd
commit
7d90aa76ff
@ -1,4 +1,5 @@
|
||||
import weakref
|
||||
from mongoengine.common import _import_class
|
||||
|
||||
__all__ = ("BaseDict", "BaseList")
|
||||
|
||||
@ -16,6 +17,14 @@ class BaseDict(dict):
|
||||
self._name = name
|
||||
return super(BaseDict, self).__init__(dict_items)
|
||||
|
||||
def __getitem__(self, *args, **kwargs):
|
||||
value = super(BaseDict, self).__getitem__(*args, **kwargs)
|
||||
|
||||
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||
value._instance = self._instance
|
||||
return value
|
||||
|
||||
def __setitem__(self, *args, **kwargs):
|
||||
self._mark_as_changed()
|
||||
return super(BaseDict, self).__setitem__(*args, **kwargs)
|
||||
@ -75,6 +84,14 @@ class BaseList(list):
|
||||
self._name = name
|
||||
return super(BaseList, self).__init__(list_items)
|
||||
|
||||
def __getitem__(self, *args, **kwargs):
|
||||
value = super(BaseList, self).__getitem__(*args, **kwargs)
|
||||
|
||||
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||
value._instance = self._instance
|
||||
return value
|
||||
|
||||
def __setitem__(self, *args, **kwargs):
|
||||
self._mark_as_changed()
|
||||
return super(BaseList, self).__setitem__(*args, **kwargs)
|
||||
@ -84,7 +101,8 @@ class BaseList(list):
|
||||
return super(BaseList, self).__delitem__(*args, **kwargs)
|
||||
|
||||
def __getstate__(self):
|
||||
self.observer = None
|
||||
self.instance = None
|
||||
self._dereferenced = False
|
||||
return self
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
@ -1,5 +1,6 @@
|
||||
import operator
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
from bson import DBRef, ObjectId
|
||||
|
||||
@ -71,6 +72,9 @@ class BaseField(object):
|
||||
if callable(value):
|
||||
value = value()
|
||||
|
||||
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||
value._instance = weakref.proxy(instance)
|
||||
return value
|
||||
|
||||
def __set__(self, instance, value):
|
||||
|
@ -40,6 +40,8 @@ class EmbeddedDocument(BaseDocument):
|
||||
my_metaclass = DocumentMetaclass
|
||||
__metaclass__ = DocumentMetaclass
|
||||
|
||||
_instance = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(EmbeddedDocument, self).__init__(*args, **kwargs)
|
||||
self._changed_fields = []
|
||||
|
@ -625,7 +625,8 @@ class SortedListField(ListField):
|
||||
def to_mongo(self, value):
|
||||
value = super(SortedListField, self).to_mongo(value)
|
||||
if self._ordering is not None:
|
||||
return sorted(value, key=itemgetter(self._ordering), reverse=self._order_reverse)
|
||||
return sorted(value, key=itemgetter(self._ordering),
|
||||
reverse=self._order_reverse)
|
||||
return sorted(value, reverse=self._order_reverse)
|
||||
|
||||
|
||||
@ -655,7 +656,9 @@ class DictField(ComplexBaseField):
|
||||
self.error('Only dictionaries may be used in a DictField')
|
||||
|
||||
if any(k for k in value.keys() if not isinstance(k, basestring)):
|
||||
self.error('Invalid dictionary key - documents must have only string keys')
|
||||
msg = ("Invalid dictionary key - documents must "
|
||||
"have only string keys")
|
||||
self.error(msg)
|
||||
if any(('.' in k or '$' in k) for k in value.keys()):
|
||||
self.error('Invalid dictionary key name - keys may not contain "."'
|
||||
' or "$" characters')
|
||||
|
@ -183,9 +183,6 @@ class InstanceTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(list_stats, CompareStats.objects.first().stats)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_db_field_load(self):
|
||||
"""Ensure we load data correctly
|
||||
"""
|
||||
@ -214,24 +211,24 @@ class InstanceTest(unittest.TestCase):
|
||||
|
||||
class Person(Document):
|
||||
name = StringField(required=True)
|
||||
rank_ = EmbeddedDocumentField(Rank, required=False, db_field='rank')
|
||||
rank_ = EmbeddedDocumentField(Rank,
|
||||
required=False,
|
||||
db_field='rank')
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self.rank_.title if self.rank_ is not None else "Private"
|
||||
if self.rank_ is None:
|
||||
return "Private"
|
||||
return self.rank_.title
|
||||
|
||||
Person.drop_collection()
|
||||
|
||||
Person(name="Jack", rank_=Rank(title="Corporal")).save()
|
||||
|
||||
Person(name="Fred").save()
|
||||
|
||||
self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal")
|
||||
self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
|
||||
|
||||
|
||||
|
||||
|
||||
def test_custom_id_field(self):
|
||||
"""Ensure that documents may be created with custom primary keys.
|
||||
"""
|
||||
@ -247,7 +244,7 @@ class InstanceTest(unittest.TestCase):
|
||||
self.assertEqual(User._meta['id_field'], 'username')
|
||||
|
||||
def create_invalid_user():
|
||||
User(name='test').save() # no primary key field
|
||||
User(name='test').save() # no primary key field
|
||||
self.assertRaises(ValidationError, create_invalid_user)
|
||||
|
||||
def define_invalid_user():
|
||||
@ -424,6 +421,36 @@ class InstanceTest(unittest.TestCase):
|
||||
self.assertTrue('content' in Comment._fields)
|
||||
self.assertFalse('id' in Comment._fields)
|
||||
|
||||
def test_embedded_document_instance(self):
|
||||
"""Ensure that embedded documents can reference parent instance
|
||||
"""
|
||||
class Embedded(EmbeddedDocument):
|
||||
string = StringField()
|
||||
|
||||
class Doc(Document):
|
||||
embedded_field = EmbeddedDocumentField(Embedded)
|
||||
|
||||
Doc.drop_collection()
|
||||
Doc(embedded_field=Embedded(string="Hi")).save()
|
||||
|
||||
doc = Doc.objects.get()
|
||||
self.assertEqual(doc, doc.embedded_field._instance)
|
||||
|
||||
def test_embedded_document_complex_instance(self):
|
||||
"""Ensure that embedded documents in complex fields can reference
|
||||
parent instance"""
|
||||
class Embedded(EmbeddedDocument):
|
||||
string = StringField()
|
||||
|
||||
class Doc(Document):
|
||||
embedded_field = ListField(EmbeddedDocumentField(Embedded))
|
||||
|
||||
Doc.drop_collection()
|
||||
Doc(embedded_field=[Embedded(string="Hi")]).save()
|
||||
|
||||
doc = Doc.objects.get()
|
||||
self.assertEqual(doc, doc.embedded_field[0]._instance)
|
||||
|
||||
def test_embedded_document_validation(self):
|
||||
"""Ensure that embedded documents may be validated.
|
||||
"""
|
||||
@ -442,6 +469,7 @@ class InstanceTest(unittest.TestCase):
|
||||
|
||||
comment.date = datetime.now()
|
||||
comment.validate()
|
||||
self.assertEqual(comment._instance, None)
|
||||
|
||||
def test_embedded_db_field_validate(self):
|
||||
|
||||
@ -475,11 +503,13 @@ class InstanceTest(unittest.TestCase):
|
||||
self.assertEqual(person_obj['age'], 30)
|
||||
self.assertEqual(person_obj['_id'], person.id)
|
||||
# Test skipping validation on save
|
||||
|
||||
class Recipient(Document):
|
||||
email = EmailField(required=True)
|
||||
|
||||
recipient = Recipient(email='root@localhost')
|
||||
self.assertRaises(ValidationError, recipient.save)
|
||||
|
||||
try:
|
||||
recipient.save(validate=False)
|
||||
except ValidationError:
|
||||
|
Loading…
x
Reference in New Issue
Block a user