Updated docs and added a NotRegistered exception
For handling GenericReferences that reference documents that haven't been imported. Closes #170
This commit is contained in:
parent
40b69baa29
commit
9260ff9e83
@ -7,22 +7,32 @@ import pymongo
|
|||||||
import pymongo.objectid
|
import pymongo.objectid
|
||||||
|
|
||||||
|
|
||||||
_document_registry = {}
|
class NotRegistered(Exception):
|
||||||
|
pass
|
||||||
def get_document(name):
|
|
||||||
return _document_registry[name]
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationError(Exception):
|
class ValidationError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_document_registry = {}
|
||||||
|
|
||||||
|
def get_document(name):
|
||||||
|
if name not in _document_registry:
|
||||||
|
raise NotRegistered("""
|
||||||
|
`%s` has not been registered in the document registry.
|
||||||
|
Importing the document class automatically registers it, has it
|
||||||
|
been imported?
|
||||||
|
""".strip() % name)
|
||||||
|
return _document_registry[name]
|
||||||
|
|
||||||
|
|
||||||
class BaseField(object):
|
class BaseField(object):
|
||||||
"""A base class for fields in a MongoDB document. Instances of this class
|
"""A base class for fields in a MongoDB document. Instances of this class
|
||||||
may be added to subclasses of `Document` to define a document's schema.
|
may be added to subclasses of `Document` to define a document's schema.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Fields may have _types inserted into indexes by default
|
# Fields may have _types inserted into indexes by default
|
||||||
_index_with_types = True
|
_index_with_types = True
|
||||||
_geo_index = False
|
_geo_index = False
|
||||||
|
|
||||||
@ -32,7 +42,7 @@ class BaseField(object):
|
|||||||
creation_counter = 0
|
creation_counter = 0
|
||||||
auto_creation_counter = -1
|
auto_creation_counter = -1
|
||||||
|
|
||||||
def __init__(self, db_field=None, name=None, required=False, default=None,
|
def __init__(self, db_field=None, name=None, required=False, default=None,
|
||||||
unique=False, unique_with=None, primary_key=False,
|
unique=False, unique_with=None, primary_key=False,
|
||||||
validation=None, choices=None):
|
validation=None, choices=None):
|
||||||
self.db_field = (db_field or name) if not primary_key else '_id'
|
self.db_field = (db_field or name) if not primary_key else '_id'
|
||||||
@ -57,7 +67,7 @@ class BaseField(object):
|
|||||||
BaseField.creation_counter += 1
|
BaseField.creation_counter += 1
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
"""Descriptor for retrieving a value from a field in a document. Do
|
"""Descriptor for retrieving a value from a field in a document. Do
|
||||||
any necessary conversion between Python and MongoDB types.
|
any necessary conversion between Python and MongoDB types.
|
||||||
"""
|
"""
|
||||||
if instance is None:
|
if instance is None:
|
||||||
@ -167,8 +177,8 @@ class DocumentMetaclass(type):
|
|||||||
superclasses.update(base._superclasses)
|
superclasses.update(base._superclasses)
|
||||||
|
|
||||||
if hasattr(base, '_meta'):
|
if hasattr(base, '_meta'):
|
||||||
# Ensure that the Document class may be subclassed -
|
# Ensure that the Document class may be subclassed -
|
||||||
# inheritance may be disabled to remove dependency on
|
# inheritance may be disabled to remove dependency on
|
||||||
# additional fields _cls and _types
|
# additional fields _cls and _types
|
||||||
if base._meta.get('allow_inheritance', True) == False:
|
if base._meta.get('allow_inheritance', True) == False:
|
||||||
raise ValueError('Document %s may not be subclassed' %
|
raise ValueError('Document %s may not be subclassed' %
|
||||||
@ -211,12 +221,12 @@ class DocumentMetaclass(type):
|
|||||||
|
|
||||||
module = attrs.get('__module__')
|
module = attrs.get('__module__')
|
||||||
|
|
||||||
base_excs = tuple(base.DoesNotExist for base in bases
|
base_excs = tuple(base.DoesNotExist for base in bases
|
||||||
if hasattr(base, 'DoesNotExist')) or (DoesNotExist,)
|
if hasattr(base, 'DoesNotExist')) or (DoesNotExist,)
|
||||||
exc = subclass_exception('DoesNotExist', base_excs, module)
|
exc = subclass_exception('DoesNotExist', base_excs, module)
|
||||||
new_class.add_to_class('DoesNotExist', exc)
|
new_class.add_to_class('DoesNotExist', exc)
|
||||||
|
|
||||||
base_excs = tuple(base.MultipleObjectsReturned for base in bases
|
base_excs = tuple(base.MultipleObjectsReturned for base in bases
|
||||||
if hasattr(base, 'MultipleObjectsReturned'))
|
if hasattr(base, 'MultipleObjectsReturned'))
|
||||||
base_excs = base_excs or (MultipleObjectsReturned,)
|
base_excs = base_excs or (MultipleObjectsReturned,)
|
||||||
exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
|
exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
|
||||||
@ -238,9 +248,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
|
|||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
def __new__(cls, name, bases, attrs):
|
||||||
super_new = super(TopLevelDocumentMetaclass, cls).__new__
|
super_new = super(TopLevelDocumentMetaclass, cls).__new__
|
||||||
# Classes defined in this package are abstract and should not have
|
# Classes defined in this package are abstract and should not have
|
||||||
# their own metadata with DB collection, etc.
|
# their own metadata with DB collection, etc.
|
||||||
# __metaclass__ is only set on the class with the __metaclass__
|
# __metaclass__ is only set on the class with the __metaclass__
|
||||||
# attribute (i.e. it is not set on subclasses). This differentiates
|
# attribute (i.e. it is not set on subclasses). This differentiates
|
||||||
# 'real' documents from the 'Document' class
|
# 'real' documents from the 'Document' class
|
||||||
if attrs.get('__metaclass__') == TopLevelDocumentMetaclass:
|
if attrs.get('__metaclass__') == TopLevelDocumentMetaclass:
|
||||||
@ -366,7 +376,7 @@ class BaseDocument(object):
|
|||||||
are present.
|
are present.
|
||||||
"""
|
"""
|
||||||
# Get a list of tuples of field names and their current values
|
# Get a list of tuples of field names and their current values
|
||||||
fields = [(field, getattr(self, name))
|
fields = [(field, getattr(self, name))
|
||||||
for name, field in self._fields.items()]
|
for name, field in self._fields.items()]
|
||||||
|
|
||||||
# Ensure that each field is matched to a valid value
|
# Ensure that each field is matched to a valid value
|
||||||
|
@ -339,7 +339,7 @@ class ListField(BaseField):
|
|||||||
|
|
||||||
if isinstance(self.field, ReferenceField):
|
if isinstance(self.field, ReferenceField):
|
||||||
referenced_type = self.field.document_type
|
referenced_type = self.field.document_type
|
||||||
# Get value from document instance if available
|
# Get value from document instance if available
|
||||||
value_list = instance._data.get(self.name)
|
value_list = instance._data.get(self.name)
|
||||||
if value_list:
|
if value_list:
|
||||||
deref_list = []
|
deref_list = []
|
||||||
@ -522,6 +522,9 @@ class GenericReferenceField(BaseField):
|
|||||||
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
|
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
|
||||||
that will be automatically dereferenced on access (lazily).
|
that will be automatically dereferenced on access (lazily).
|
||||||
|
|
||||||
|
note: Any documents used as a generic reference must be registered in the
|
||||||
|
document registry. Importing the model will automatically register it.
|
||||||
|
|
||||||
.. versionadded:: 0.3
|
.. versionadded:: 0.3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -648,7 +651,7 @@ class GridFSProxy(object):
|
|||||||
if not self.newfile:
|
if not self.newfile:
|
||||||
self.new_file()
|
self.new_file()
|
||||||
self.grid_id = self.newfile._id
|
self.grid_id = self.newfile._id
|
||||||
self.newfile.writelines(lines)
|
self.newfile.writelines(lines)
|
||||||
|
|
||||||
def read(self, size=-1):
|
def read(self, size=-1):
|
||||||
try:
|
try:
|
||||||
|
@ -7,6 +7,7 @@ import gridfs
|
|||||||
|
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
from mongoengine.connection import _get_db
|
from mongoengine.connection import _get_db
|
||||||
|
from mongoengine.base import _document_registry, NotRegistered
|
||||||
|
|
||||||
|
|
||||||
class FieldTest(unittest.TestCase):
|
class FieldTest(unittest.TestCase):
|
||||||
@ -45,7 +46,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
class Person(Document):
|
class Person(Document):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
|
|
||||||
person = Person(name='Test User')
|
person = Person(name='Test User')
|
||||||
self.assertEqual(person.id, None)
|
self.assertEqual(person.id, None)
|
||||||
|
|
||||||
@ -95,7 +96,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
link.url = 'http://www.google.com:8080'
|
link.url = 'http://www.google.com:8080'
|
||||||
link.validate()
|
link.validate()
|
||||||
|
|
||||||
def test_int_validation(self):
|
def test_int_validation(self):
|
||||||
"""Ensure that invalid values cannot be assigned to int fields.
|
"""Ensure that invalid values cannot be assigned to int fields.
|
||||||
"""
|
"""
|
||||||
@ -129,12 +130,12 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.assertRaises(ValidationError, person.validate)
|
self.assertRaises(ValidationError, person.validate)
|
||||||
person.height = 4.0
|
person.height = 4.0
|
||||||
self.assertRaises(ValidationError, person.validate)
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
|
||||||
def test_decimal_validation(self):
|
def test_decimal_validation(self):
|
||||||
"""Ensure that invalid values cannot be assigned to decimal fields.
|
"""Ensure that invalid values cannot be assigned to decimal fields.
|
||||||
"""
|
"""
|
||||||
class Person(Document):
|
class Person(Document):
|
||||||
height = DecimalField(min_value=Decimal('0.1'),
|
height = DecimalField(min_value=Decimal('0.1'),
|
||||||
max_value=Decimal('3.5'))
|
max_value=Decimal('3.5'))
|
||||||
|
|
||||||
Person.drop_collection()
|
Person.drop_collection()
|
||||||
@ -249,7 +250,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
post.save()
|
post.save()
|
||||||
post.reload()
|
post.reload()
|
||||||
self.assertEqual(post.tags, ['fun', 'leisure'])
|
self.assertEqual(post.tags, ['fun', 'leisure'])
|
||||||
|
|
||||||
comment1 = Comment(content='Good for you', order=1)
|
comment1 = Comment(content='Good for you', order=1)
|
||||||
comment2 = Comment(content='Yay.', order=0)
|
comment2 = Comment(content='Yay.', order=0)
|
||||||
comments = [comment1, comment2]
|
comments = [comment1, comment2]
|
||||||
@ -315,7 +316,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
person.validate()
|
person.validate()
|
||||||
|
|
||||||
def test_embedded_document_inheritance(self):
|
def test_embedded_document_inheritance(self):
|
||||||
"""Ensure that subclasses of embedded documents may be provided to
|
"""Ensure that subclasses of embedded documents may be provided to
|
||||||
EmbeddedDocumentFields of the superclass' type.
|
EmbeddedDocumentFields of the superclass' type.
|
||||||
"""
|
"""
|
||||||
class User(EmbeddedDocument):
|
class User(EmbeddedDocument):
|
||||||
@ -327,7 +328,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
class BlogPost(Document):
|
class BlogPost(Document):
|
||||||
content = StringField()
|
content = StringField()
|
||||||
author = EmbeddedDocumentField(User)
|
author = EmbeddedDocumentField(User)
|
||||||
|
|
||||||
post = BlogPost(content='What I did today...')
|
post = BlogPost(content='What I did today...')
|
||||||
post.author = User(name='Test User')
|
post.author = User(name='Test User')
|
||||||
post.author = PowerUser(name='Test User', power=47)
|
post.author = PowerUser(name='Test User', power=47)
|
||||||
@ -370,7 +371,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
def test_list_item_dereference(self):
|
def test_list_item_dereference(self):
|
||||||
"""Ensure that DBRef items in ListFields are dereferenced.
|
"""Ensure that DBRef items in ListFields are dereferenced.
|
||||||
"""
|
"""
|
||||||
@ -434,7 +435,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
class TreeNode(EmbeddedDocument):
|
class TreeNode(EmbeddedDocument):
|
||||||
name = StringField()
|
name = StringField()
|
||||||
children = ListField(EmbeddedDocumentField('self'))
|
children = ListField(EmbeddedDocumentField('self'))
|
||||||
|
|
||||||
tree = Tree(name="Tree")
|
tree = Tree(name="Tree")
|
||||||
|
|
||||||
first_child = TreeNode(name="Child 1")
|
first_child = TreeNode(name="Child 1")
|
||||||
@ -442,7 +443,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
second_child = TreeNode(name="Child 2")
|
second_child = TreeNode(name="Child 2")
|
||||||
first_child.children.append(second_child)
|
first_child.children.append(second_child)
|
||||||
|
|
||||||
third_child = TreeNode(name="Child 3")
|
third_child = TreeNode(name="Child 3")
|
||||||
first_child.children.append(third_child)
|
first_child.children.append(third_child)
|
||||||
|
|
||||||
@ -506,20 +507,20 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
Member.drop_collection()
|
Member.drop_collection()
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
|
||||||
def test_generic_reference(self):
|
def test_generic_reference(self):
|
||||||
"""Ensure that a GenericReferenceField properly dereferences items.
|
"""Ensure that a GenericReferenceField properly dereferences items.
|
||||||
"""
|
"""
|
||||||
class Link(Document):
|
class Link(Document):
|
||||||
title = StringField()
|
title = StringField()
|
||||||
meta = {'allow_inheritance': False}
|
meta = {'allow_inheritance': False}
|
||||||
|
|
||||||
class Post(Document):
|
class Post(Document):
|
||||||
title = StringField()
|
title = StringField()
|
||||||
|
|
||||||
class Bookmark(Document):
|
class Bookmark(Document):
|
||||||
bookmark_object = GenericReferenceField()
|
bookmark_object = GenericReferenceField()
|
||||||
|
|
||||||
Link.drop_collection()
|
Link.drop_collection()
|
||||||
Post.drop_collection()
|
Post.drop_collection()
|
||||||
Bookmark.drop_collection()
|
Bookmark.drop_collection()
|
||||||
@ -574,16 +575,49 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
user = User(bookmarks=[post_1, link_1])
|
user = User(bookmarks=[post_1, link_1])
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
user = User.objects(bookmarks__all=[post_1, link_1]).first()
|
user = User.objects(bookmarks__all=[post_1, link_1]).first()
|
||||||
|
|
||||||
self.assertEqual(user.bookmarks[0], post_1)
|
self.assertEqual(user.bookmarks[0], post_1)
|
||||||
self.assertEqual(user.bookmarks[1], link_1)
|
self.assertEqual(user.bookmarks[1], link_1)
|
||||||
|
|
||||||
Link.drop_collection()
|
Link.drop_collection()
|
||||||
Post.drop_collection()
|
Post.drop_collection()
|
||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
|
|
||||||
|
def test_generic_reference_document_not_registered(self):
|
||||||
|
"""Ensure dereferencing out of the document registry throws a
|
||||||
|
`NotRegistered` error.
|
||||||
|
"""
|
||||||
|
class Link(Document):
|
||||||
|
title = StringField()
|
||||||
|
|
||||||
|
class User(Document):
|
||||||
|
bookmarks = ListField(GenericReferenceField())
|
||||||
|
|
||||||
|
Link.drop_collection()
|
||||||
|
User.drop_collection()
|
||||||
|
|
||||||
|
link_1 = Link(title="Pitchfork")
|
||||||
|
link_1.save()
|
||||||
|
|
||||||
|
user = User(bookmarks=[link_1])
|
||||||
|
user.save()
|
||||||
|
|
||||||
|
# Mimic User and Link definitions being in a different file
|
||||||
|
# and the Link model not being imported in the User file.
|
||||||
|
del(_document_registry["Link"])
|
||||||
|
|
||||||
|
user = User.objects.first()
|
||||||
|
try:
|
||||||
|
user.bookmarks
|
||||||
|
raise AssertionError, "Link was removed from the registry"
|
||||||
|
except NotRegistered:
|
||||||
|
pass
|
||||||
|
|
||||||
|
Link.drop_collection()
|
||||||
|
User.drop_collection()
|
||||||
|
|
||||||
def test_binary_fields(self):
|
def test_binary_fields(self):
|
||||||
"""Ensure that binary fields can be stored and retrieved.
|
"""Ensure that binary fields can be stored and retrieved.
|
||||||
"""
|
"""
|
||||||
@ -727,7 +761,7 @@ class FieldTest(unittest.TestCase):
|
|||||||
result = SetFile.objects.first()
|
result = SetFile.objects.first()
|
||||||
self.assertTrue(setfile == result)
|
self.assertTrue(setfile == result)
|
||||||
self.assertEquals(result.file.read(), more_text)
|
self.assertEquals(result.file.read(), more_text)
|
||||||
result.file.delete()
|
result.file.delete()
|
||||||
|
|
||||||
PutFile.drop_collection()
|
PutFile.drop_collection()
|
||||||
StreamFile.drop_collection()
|
StreamFile.drop_collection()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user