parent
3cc2c617fd
commit
7d90aa76ff
@ -1,4 +1,5 @@
|
|||||||
import weakref
|
import weakref
|
||||||
|
from mongoengine.common import _import_class
|
||||||
|
|
||||||
__all__ = ("BaseDict", "BaseList")
|
__all__ = ("BaseDict", "BaseList")
|
||||||
|
|
||||||
@ -16,6 +17,14 @@ class BaseDict(dict):
|
|||||||
self._name = name
|
self._name = name
|
||||||
return super(BaseDict, self).__init__(dict_items)
|
return super(BaseDict, self).__init__(dict_items)
|
||||||
|
|
||||||
|
def __getitem__(self, *args, **kwargs):
|
||||||
|
value = super(BaseDict, self).__getitem__(*args, **kwargs)
|
||||||
|
|
||||||
|
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||||
|
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||||
|
value._instance = self._instance
|
||||||
|
return value
|
||||||
|
|
||||||
def __setitem__(self, *args, **kwargs):
|
def __setitem__(self, *args, **kwargs):
|
||||||
self._mark_as_changed()
|
self._mark_as_changed()
|
||||||
return super(BaseDict, self).__setitem__(*args, **kwargs)
|
return super(BaseDict, self).__setitem__(*args, **kwargs)
|
||||||
@ -75,6 +84,14 @@ class BaseList(list):
|
|||||||
self._name = name
|
self._name = name
|
||||||
return super(BaseList, self).__init__(list_items)
|
return super(BaseList, self).__init__(list_items)
|
||||||
|
|
||||||
|
def __getitem__(self, *args, **kwargs):
|
||||||
|
value = super(BaseList, self).__getitem__(*args, **kwargs)
|
||||||
|
|
||||||
|
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||||
|
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||||
|
value._instance = self._instance
|
||||||
|
return value
|
||||||
|
|
||||||
def __setitem__(self, *args, **kwargs):
|
def __setitem__(self, *args, **kwargs):
|
||||||
self._mark_as_changed()
|
self._mark_as_changed()
|
||||||
return super(BaseList, self).__setitem__(*args, **kwargs)
|
return super(BaseList, self).__setitem__(*args, **kwargs)
|
||||||
@ -84,7 +101,8 @@ class BaseList(list):
|
|||||||
return super(BaseList, self).__delitem__(*args, **kwargs)
|
return super(BaseList, self).__delitem__(*args, **kwargs)
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
self.observer = None
|
self.instance = None
|
||||||
|
self._dereferenced = False
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import operator
|
import operator
|
||||||
import warnings
|
import warnings
|
||||||
|
import weakref
|
||||||
|
|
||||||
from bson import DBRef, ObjectId
|
from bson import DBRef, ObjectId
|
||||||
|
|
||||||
@ -71,6 +72,9 @@ class BaseField(object):
|
|||||||
if callable(value):
|
if callable(value):
|
||||||
value = value()
|
value = value()
|
||||||
|
|
||||||
|
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||||
|
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||||
|
value._instance = weakref.proxy(instance)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def __set__(self, instance, value):
|
def __set__(self, instance, value):
|
||||||
|
@ -40,6 +40,8 @@ class EmbeddedDocument(BaseDocument):
|
|||||||
my_metaclass = DocumentMetaclass
|
my_metaclass = DocumentMetaclass
|
||||||
__metaclass__ = DocumentMetaclass
|
__metaclass__ = DocumentMetaclass
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(EmbeddedDocument, self).__init__(*args, **kwargs)
|
super(EmbeddedDocument, self).__init__(*args, **kwargs)
|
||||||
self._changed_fields = []
|
self._changed_fields = []
|
||||||
|
@ -625,7 +625,8 @@ class SortedListField(ListField):
|
|||||||
def to_mongo(self, value):
|
def to_mongo(self, value):
|
||||||
value = super(SortedListField, self).to_mongo(value)
|
value = super(SortedListField, self).to_mongo(value)
|
||||||
if self._ordering is not None:
|
if self._ordering is not None:
|
||||||
return sorted(value, key=itemgetter(self._ordering), reverse=self._order_reverse)
|
return sorted(value, key=itemgetter(self._ordering),
|
||||||
|
reverse=self._order_reverse)
|
||||||
return sorted(value, reverse=self._order_reverse)
|
return sorted(value, reverse=self._order_reverse)
|
||||||
|
|
||||||
|
|
||||||
@ -655,7 +656,9 @@ class DictField(ComplexBaseField):
|
|||||||
self.error('Only dictionaries may be used in a DictField')
|
self.error('Only dictionaries may be used in a DictField')
|
||||||
|
|
||||||
if any(k for k in value.keys() if not isinstance(k, basestring)):
|
if any(k for k in value.keys() if not isinstance(k, basestring)):
|
||||||
self.error('Invalid dictionary key - documents must have only string keys')
|
msg = ("Invalid dictionary key - documents must "
|
||||||
|
"have only string keys")
|
||||||
|
self.error(msg)
|
||||||
if any(('.' in k or '$' in k) for k in value.keys()):
|
if any(('.' in k or '$' in k) for k in value.keys()):
|
||||||
self.error('Invalid dictionary key name - keys may not contain "."'
|
self.error('Invalid dictionary key name - keys may not contain "."'
|
||||||
' or "$" characters')
|
' or "$" characters')
|
||||||
|
@ -183,9 +183,6 @@ class InstanceTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(list_stats, CompareStats.objects.first().stats)
|
self.assertEqual(list_stats, CompareStats.objects.first().stats)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_db_field_load(self):
|
def test_db_field_load(self):
|
||||||
"""Ensure we load data correctly
|
"""Ensure we load data correctly
|
||||||
"""
|
"""
|
||||||
@ -214,24 +211,24 @@ class InstanceTest(unittest.TestCase):
|
|||||||
|
|
||||||
class Person(Document):
|
class Person(Document):
|
||||||
name = StringField(required=True)
|
name = StringField(required=True)
|
||||||
rank_ = EmbeddedDocumentField(Rank, required=False, db_field='rank')
|
rank_ = EmbeddedDocumentField(Rank,
|
||||||
|
required=False,
|
||||||
|
db_field='rank')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rank(self):
|
def rank(self):
|
||||||
return self.rank_.title if self.rank_ is not None else "Private"
|
if self.rank_ is None:
|
||||||
|
return "Private"
|
||||||
|
return self.rank_.title
|
||||||
|
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
|
|
||||||
Person(name="Jack", rank_=Rank(title="Corporal")).save()
|
Person(name="Jack", rank_=Rank(title="Corporal")).save()
|
||||||
|
|
||||||
Person(name="Fred").save()
|
Person(name="Fred").save()
|
||||||
|
|
||||||
self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal")
|
self.assertEqual(Person.objects.get(name="Jack").rank, "Corporal")
|
||||||
self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
|
self.assertEqual(Person.objects.get(name="Fred").rank, "Private")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_custom_id_field(self):
|
def test_custom_id_field(self):
|
||||||
"""Ensure that documents may be created with custom primary keys.
|
"""Ensure that documents may be created with custom primary keys.
|
||||||
"""
|
"""
|
||||||
@ -247,7 +244,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(User._meta['id_field'], 'username')
|
self.assertEqual(User._meta['id_field'], 'username')
|
||||||
|
|
||||||
def create_invalid_user():
|
def create_invalid_user():
|
||||||
User(name='test').save() # no primary key field
|
User(name='test').save() # no primary key field
|
||||||
self.assertRaises(ValidationError, create_invalid_user)
|
self.assertRaises(ValidationError, create_invalid_user)
|
||||||
|
|
||||||
def define_invalid_user():
|
def define_invalid_user():
|
||||||
@ -424,6 +421,36 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.assertTrue('content' in Comment._fields)
|
self.assertTrue('content' in Comment._fields)
|
||||||
self.assertFalse('id' in Comment._fields)
|
self.assertFalse('id' in Comment._fields)
|
||||||
|
|
||||||
|
def test_embedded_document_instance(self):
|
||||||
|
"""Ensure that embedded documents can reference parent instance
|
||||||
|
"""
|
||||||
|
class Embedded(EmbeddedDocument):
|
||||||
|
string = StringField()
|
||||||
|
|
||||||
|
class Doc(Document):
|
||||||
|
embedded_field = EmbeddedDocumentField(Embedded)
|
||||||
|
|
||||||
|
Doc.drop_collection()
|
||||||
|
Doc(embedded_field=Embedded(string="Hi")).save()
|
||||||
|
|
||||||
|
doc = Doc.objects.get()
|
||||||
|
self.assertEqual(doc, doc.embedded_field._instance)
|
||||||
|
|
||||||
|
def test_embedded_document_complex_instance(self):
|
||||||
|
"""Ensure that embedded documents in complex fields can reference
|
||||||
|
parent instance"""
|
||||||
|
class Embedded(EmbeddedDocument):
|
||||||
|
string = StringField()
|
||||||
|
|
||||||
|
class Doc(Document):
|
||||||
|
embedded_field = ListField(EmbeddedDocumentField(Embedded))
|
||||||
|
|
||||||
|
Doc.drop_collection()
|
||||||
|
Doc(embedded_field=[Embedded(string="Hi")]).save()
|
||||||
|
|
||||||
|
doc = Doc.objects.get()
|
||||||
|
self.assertEqual(doc, doc.embedded_field[0]._instance)
|
||||||
|
|
||||||
def test_embedded_document_validation(self):
|
def test_embedded_document_validation(self):
|
||||||
"""Ensure that embedded documents may be validated.
|
"""Ensure that embedded documents may be validated.
|
||||||
"""
|
"""
|
||||||
@ -442,6 +469,7 @@ class InstanceTest(unittest.TestCase):
|
|||||||
|
|
||||||
comment.date = datetime.now()
|
comment.date = datetime.now()
|
||||||
comment.validate()
|
comment.validate()
|
||||||
|
self.assertEqual(comment._instance, None)
|
||||||
|
|
||||||
def test_embedded_db_field_validate(self):
|
def test_embedded_db_field_validate(self):
|
||||||
|
|
||||||
@ -475,11 +503,13 @@ class InstanceTest(unittest.TestCase):
|
|||||||
self.assertEqual(person_obj['age'], 30)
|
self.assertEqual(person_obj['age'], 30)
|
||||||
self.assertEqual(person_obj['_id'], person.id)
|
self.assertEqual(person_obj['_id'], person.id)
|
||||||
# Test skipping validation on save
|
# Test skipping validation on save
|
||||||
|
|
||||||
class Recipient(Document):
|
class Recipient(Document):
|
||||||
email = EmailField(required=True)
|
email = EmailField(required=True)
|
||||||
|
|
||||||
recipient = Recipient(email='root@localhost')
|
recipient = Recipient(email='root@localhost')
|
||||||
self.assertRaises(ValidationError, recipient.save)
|
self.assertRaises(ValidationError, recipient.save)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
recipient.save(validate=False)
|
recipient.save(validate=False)
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user