From 6c4aee147933dab020cced417bd556411becb3da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Thu, 17 Jul 2014 13:42:34 -0300 Subject: [PATCH] added CachedReferenceField restriction to use in EmbeddedDocument --- mongoengine/base/metaclasses.py | 40 ++++++++++++++++++++------------- mongoengine/common.py | 1 + tests/fields/fields.py | 12 ++++++++++ 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 4b2e8b9b..887c9abc 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -16,6 +16,7 @@ __all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass') class DocumentMetaclass(type): + """Metaclass for all documents. """ @@ -90,7 +91,7 @@ class DocumentMetaclass(type): # Set _fields and db_field maps attrs['_fields'] = doc_fields attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) - for k, v in doc_fields.iteritems()]) + for k, v in doc_fields.iteritems()]) attrs['_reverse_db_field_map'] = dict( (v, k) for k, v in attrs['_db_field_map'].iteritems()) @@ -105,7 +106,7 @@ class DocumentMetaclass(type): class_name = [name] for base in flattened_bases: if (not getattr(base, '_is_base_cls', True) and - not getattr(base, '_meta', {}).get('abstract', True)): + not getattr(base, '_meta', {}).get('abstract', True)): # Collate heirarchy for _cls and _subclasses class_name.append(base.__name__) @@ -115,7 +116,7 @@ class DocumentMetaclass(type): allow_inheritance = base._meta.get('allow_inheritance', ALLOW_INHERITANCE) if (allow_inheritance is not True and - not base._meta.get('abstract')): + not base._meta.get('abstract')): raise ValueError('Document %s may not be subclassed' % base.__name__) @@ -141,7 +142,8 @@ class DocumentMetaclass(type): base._subclasses += (_cls,) base._types = base._subclasses # TODO depreciate _types - Document, EmbeddedDocument, DictField = cls._import_classes() + (Document, EmbeddedDocument, DictField, + CachedReferenceField) = cls._import_classes() if issubclass(new_class, Document): new_class._collection = None @@ -170,6 +172,10 @@ class DocumentMetaclass(type): f = field f.owner_document = new_class delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) + if isinstance(f, CachedReferenceField) and issubclass( + new_class, EmbeddedDocument): + raise InvalidDocumentError( + "CachedReferenceFields is not allowed in EmbeddedDocuments") if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): delete_rule = getattr(f.field, 'reverse_delete_rule', @@ -191,7 +197,7 @@ class DocumentMetaclass(type): field.name, delete_rule) if (field.name and hasattr(Document, field.name) and - EmbeddedDocument not in new_class.mro()): + EmbeddedDocument not in new_class.mro()): msg = ("%s is a document method and not a valid " "field name" % field.name) raise InvalidDocumentError(msg) @@ -224,10 +230,12 @@ class DocumentMetaclass(type): Document = _import_class('Document') EmbeddedDocument = _import_class('EmbeddedDocument') DictField = _import_class('DictField') - return (Document, EmbeddedDocument, DictField) + CachedReferenceField = _import_class('CachedReferenceField') + return (Document, EmbeddedDocument, DictField, CachedReferenceField) class TopLevelDocumentMetaclass(DocumentMetaclass): + """Metaclass for top-level documents (i.e. documents that have their own collection in the database. """ @@ -275,21 +283,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Find the parent document class parent_doc_cls = [b for b in flattened_bases - if b.__class__ == TopLevelDocumentMetaclass] + if b.__class__ == TopLevelDocumentMetaclass] parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0] # Prevent classes setting collection different to their parents # If parent wasn't an abstract class if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) - and not parent_doc_cls._meta.get('abstract', True)): - msg = "Trying to set a collection on a subclass (%s)" % name - warnings.warn(msg, SyntaxWarning) - del(attrs['_meta']['collection']) + and not parent_doc_cls._meta.get('abstract', True)): + msg = "Trying to set a collection on a subclass (%s)" % name + warnings.warn(msg, SyntaxWarning) + del(attrs['_meta']['collection']) # Ensure abstract documents have abstract bases if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): if (parent_doc_cls and - not parent_doc_cls._meta.get('abstract', False)): + not parent_doc_cls._meta.get('abstract', False)): msg = "Abstract document cannot have non-abstract base" raise ValueError(msg) return super_new(cls, name, bases, attrs) @@ -306,7 +314,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Set collection in the meta if its callable if (getattr(base, '_is_document', False) and - not base._meta.get('abstract')): + not base._meta.get('abstract')): collection = meta.get('collection', None) if callable(collection): meta['collection'] = collection(base) @@ -318,7 +326,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): simple_class = all([b._meta.get('abstract') for b in flattened_bases if hasattr(b, '_meta')]) if (not simple_class and meta['allow_inheritance'] is False and - not meta['abstract']): + not meta['abstract']): raise ValueError('Only direct subclasses of Document may set ' '"allow_inheritance" to False') @@ -378,7 +386,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): for exc in exceptions_to_merge: name = exc.__name__ parents = tuple(getattr(base, name) for base in flattened_bases - if hasattr(base, name)) or (exc,) + if hasattr(base, name)) or (exc,) # Create new exception and set to new_class exception = type(name, parents, {'__module__': module}) setattr(new_class, name, exception) @@ -387,6 +395,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): class MetaDict(dict): + """Custom dictionary for meta classes. Handles the merging of set indexes """ @@ -401,5 +410,6 @@ class MetaDict(dict): class BasesTuple(tuple): + """Special class to handle introspection of bases tuple in __new__""" pass diff --git a/mongoengine/common.py b/mongoengine/common.py index daa194b9..7c0c18d2 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -25,6 +25,7 @@ def _import_class(cls_name): 'GenericEmbeddedDocumentField', 'GeoPointField', 'PointField', 'LineStringField', 'ListField', 'PolygonField', 'ReferenceField', 'StringField', + 'CachedReferenceField', 'ComplexBaseField', 'GeoJsonBaseField') queryset_classes = ('OperationError',) deref_classes = ('DeReference',) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 5da06981..c82c936b 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1539,6 +1539,18 @@ class FieldTest(unittest.TestCase): self.assertEqual(ocorrence.person, "teste") self.assertTrue(isinstance(ocorrence.animal, Animal)) + def test_cached_reference_fields_on_embedded_documents(self): + def build(): + class Test(Document): + name = StringField() + + type('WrongEmbeddedDocument', ( + EmbeddedDocument,), { + 'test': CachedReferenceField(Test) + }) + + self.assertRaises(InvalidDocumentError, build) + def test_cached_reference_embedded_fields(self): class Owner(EmbeddedDocument): TPS = (