added CachedReferenceField restriction to use in EmbeddedDocument

This commit is contained in:
Wilson Júnior 2014-07-17 13:42:34 -03:00
parent 73549a9044
commit 6c4aee1479
3 changed files with 38 additions and 15 deletions

View File

@ -16,6 +16,7 @@ __all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass')
class DocumentMetaclass(type): class DocumentMetaclass(type):
"""Metaclass for all documents. """Metaclass for all documents.
""" """
@ -90,7 +91,7 @@ class DocumentMetaclass(type):
# Set _fields and db_field maps # Set _fields and db_field maps
attrs['_fields'] = doc_fields attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k))
for k, v in doc_fields.iteritems()]) for k, v in doc_fields.iteritems()])
attrs['_reverse_db_field_map'] = dict( attrs['_reverse_db_field_map'] = dict(
(v, k) for k, v in attrs['_db_field_map'].iteritems()) (v, k) for k, v in attrs['_db_field_map'].iteritems())
@ -105,7 +106,7 @@ class DocumentMetaclass(type):
class_name = [name] class_name = [name]
for base in flattened_bases: for base in flattened_bases:
if (not getattr(base, '_is_base_cls', True) and if (not getattr(base, '_is_base_cls', True) and
not getattr(base, '_meta', {}).get('abstract', True)): not getattr(base, '_meta', {}).get('abstract', True)):
# Collate heirarchy for _cls and _subclasses # Collate heirarchy for _cls and _subclasses
class_name.append(base.__name__) class_name.append(base.__name__)
@ -115,7 +116,7 @@ class DocumentMetaclass(type):
allow_inheritance = base._meta.get('allow_inheritance', allow_inheritance = base._meta.get('allow_inheritance',
ALLOW_INHERITANCE) ALLOW_INHERITANCE)
if (allow_inheritance is not True and if (allow_inheritance is not True and
not base._meta.get('abstract')): not base._meta.get('abstract')):
raise ValueError('Document %s may not be subclassed' % raise ValueError('Document %s may not be subclassed' %
base.__name__) base.__name__)
@ -141,7 +142,8 @@ class DocumentMetaclass(type):
base._subclasses += (_cls,) base._subclasses += (_cls,)
base._types = base._subclasses # TODO depreciate _types base._types = base._subclasses # TODO depreciate _types
Document, EmbeddedDocument, DictField = cls._import_classes() (Document, EmbeddedDocument, DictField,
CachedReferenceField) = cls._import_classes()
if issubclass(new_class, Document): if issubclass(new_class, Document):
new_class._collection = None new_class._collection = None
@ -170,6 +172,10 @@ class DocumentMetaclass(type):
f = field f = field
f.owner_document = new_class f.owner_document = new_class
delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING) delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING)
if isinstance(f, CachedReferenceField) and issubclass(
new_class, EmbeddedDocument):
raise InvalidDocumentError(
"CachedReferenceFields is not allowed in EmbeddedDocuments")
if isinstance(f, ComplexBaseField) and hasattr(f, 'field'): if isinstance(f, ComplexBaseField) and hasattr(f, 'field'):
delete_rule = getattr(f.field, delete_rule = getattr(f.field,
'reverse_delete_rule', 'reverse_delete_rule',
@ -191,7 +197,7 @@ class DocumentMetaclass(type):
field.name, delete_rule) field.name, delete_rule)
if (field.name and hasattr(Document, field.name) and if (field.name and hasattr(Document, field.name) and
EmbeddedDocument not in new_class.mro()): EmbeddedDocument not in new_class.mro()):
msg = ("%s is a document method and not a valid " msg = ("%s is a document method and not a valid "
"field name" % field.name) "field name" % field.name)
raise InvalidDocumentError(msg) raise InvalidDocumentError(msg)
@ -224,10 +230,12 @@ class DocumentMetaclass(type):
Document = _import_class('Document') Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument') EmbeddedDocument = _import_class('EmbeddedDocument')
DictField = _import_class('DictField') DictField = _import_class('DictField')
return (Document, EmbeddedDocument, DictField) CachedReferenceField = _import_class('CachedReferenceField')
return (Document, EmbeddedDocument, DictField, CachedReferenceField)
class TopLevelDocumentMetaclass(DocumentMetaclass): class TopLevelDocumentMetaclass(DocumentMetaclass):
"""Metaclass for top-level documents (i.e. documents that have their own """Metaclass for top-level documents (i.e. documents that have their own
collection in the database. collection in the database.
""" """
@ -275,21 +283,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Find the parent document class # Find the parent document class
parent_doc_cls = [b for b in flattened_bases parent_doc_cls = [b for b in flattened_bases
if b.__class__ == TopLevelDocumentMetaclass] if b.__class__ == TopLevelDocumentMetaclass]
parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0] parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0]
# Prevent classes setting collection different to their parents # Prevent classes setting collection different to their parents
# If parent wasn't an abstract class # If parent wasn't an abstract class
if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) if (parent_doc_cls and 'collection' in attrs.get('_meta', {})
and not parent_doc_cls._meta.get('abstract', True)): and not parent_doc_cls._meta.get('abstract', True)):
msg = "Trying to set a collection on a subclass (%s)" % name msg = "Trying to set a collection on a subclass (%s)" % name
warnings.warn(msg, SyntaxWarning) warnings.warn(msg, SyntaxWarning)
del(attrs['_meta']['collection']) del(attrs['_meta']['collection'])
# Ensure abstract documents have abstract bases # Ensure abstract documents have abstract bases
if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
if (parent_doc_cls and if (parent_doc_cls and
not parent_doc_cls._meta.get('abstract', False)): not parent_doc_cls._meta.get('abstract', False)):
msg = "Abstract document cannot have non-abstract base" msg = "Abstract document cannot have non-abstract base"
raise ValueError(msg) raise ValueError(msg)
return super_new(cls, name, bases, attrs) return super_new(cls, name, bases, attrs)
@ -306,7 +314,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Set collection in the meta if its callable # Set collection in the meta if its callable
if (getattr(base, '_is_document', False) and if (getattr(base, '_is_document', False) and
not base._meta.get('abstract')): not base._meta.get('abstract')):
collection = meta.get('collection', None) collection = meta.get('collection', None)
if callable(collection): if callable(collection):
meta['collection'] = collection(base) meta['collection'] = collection(base)
@ -318,7 +326,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
simple_class = all([b._meta.get('abstract') simple_class = all([b._meta.get('abstract')
for b in flattened_bases if hasattr(b, '_meta')]) for b in flattened_bases if hasattr(b, '_meta')])
if (not simple_class and meta['allow_inheritance'] is False and if (not simple_class and meta['allow_inheritance'] is False and
not meta['abstract']): not meta['abstract']):
raise ValueError('Only direct subclasses of Document may set ' raise ValueError('Only direct subclasses of Document may set '
'"allow_inheritance" to False') '"allow_inheritance" to False')
@ -378,7 +386,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
for exc in exceptions_to_merge: for exc in exceptions_to_merge:
name = exc.__name__ name = exc.__name__
parents = tuple(getattr(base, name) for base in flattened_bases parents = tuple(getattr(base, name) for base in flattened_bases
if hasattr(base, name)) or (exc,) if hasattr(base, name)) or (exc,)
# Create new exception and set to new_class # Create new exception and set to new_class
exception = type(name, parents, {'__module__': module}) exception = type(name, parents, {'__module__': module})
setattr(new_class, name, exception) setattr(new_class, name, exception)
@ -387,6 +395,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
class MetaDict(dict): class MetaDict(dict):
"""Custom dictionary for meta classes. """Custom dictionary for meta classes.
Handles the merging of set indexes Handles the merging of set indexes
""" """
@ -401,5 +410,6 @@ class MetaDict(dict):
class BasesTuple(tuple): class BasesTuple(tuple):
"""Special class to handle introspection of bases tuple in __new__""" """Special class to handle introspection of bases tuple in __new__"""
pass pass

View File

@ -25,6 +25,7 @@ def _import_class(cls_name):
'GenericEmbeddedDocumentField', 'GeoPointField', 'GenericEmbeddedDocumentField', 'GeoPointField',
'PointField', 'LineStringField', 'ListField', 'PointField', 'LineStringField', 'ListField',
'PolygonField', 'ReferenceField', 'StringField', 'PolygonField', 'ReferenceField', 'StringField',
'CachedReferenceField',
'ComplexBaseField', 'GeoJsonBaseField') 'ComplexBaseField', 'GeoJsonBaseField')
queryset_classes = ('OperationError',) queryset_classes = ('OperationError',)
deref_classes = ('DeReference',) deref_classes = ('DeReference',)

View File

@ -1539,6 +1539,18 @@ class FieldTest(unittest.TestCase):
self.assertEqual(ocorrence.person, "teste") self.assertEqual(ocorrence.person, "teste")
self.assertTrue(isinstance(ocorrence.animal, Animal)) self.assertTrue(isinstance(ocorrence.animal, Animal))
def test_cached_reference_fields_on_embedded_documents(self):
def build():
class Test(Document):
name = StringField()
type('WrongEmbeddedDocument', (
EmbeddedDocument,), {
'test': CachedReferenceField(Test)
})
self.assertRaises(InvalidDocumentError, build)
def test_cached_reference_embedded_fields(self): def test_cached_reference_embedded_fields(self):
class Owner(EmbeddedDocument): class Owner(EmbeddedDocument):
TPS = ( TPS = (