Updated docs and added a NotRegistered exception

For handling GenericReferences that reference documents that haven't
been imported.

Closes #170
This commit is contained in:
Ross Lawley 2011-05-20 10:22:22 +01:00
parent 40b69baa29
commit 9260ff9e83
3 changed files with 81 additions and 34 deletions

View File

@ -7,22 +7,32 @@ import pymongo
import pymongo.objectid import pymongo.objectid
_document_registry = {} class NotRegistered(Exception):
pass
def get_document(name):
return _document_registry[name]
class ValidationError(Exception): class ValidationError(Exception):
pass pass
_document_registry = {}
def get_document(name):
if name not in _document_registry:
raise NotRegistered("""
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip() % name)
return _document_registry[name]
class BaseField(object): class BaseField(object):
"""A base class for fields in a MongoDB document. Instances of this class """A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema. may be added to subclasses of `Document` to define a document's schema.
""" """
# Fields may have _types inserted into indexes by default # Fields may have _types inserted into indexes by default
_index_with_types = True _index_with_types = True
_geo_index = False _geo_index = False
@ -32,7 +42,7 @@ class BaseField(object):
creation_counter = 0 creation_counter = 0
auto_creation_counter = -1 auto_creation_counter = -1
def __init__(self, db_field=None, name=None, required=False, default=None, def __init__(self, db_field=None, name=None, required=False, default=None,
unique=False, unique_with=None, primary_key=False, unique=False, unique_with=None, primary_key=False,
validation=None, choices=None): validation=None, choices=None):
self.db_field = (db_field or name) if not primary_key else '_id' self.db_field = (db_field or name) if not primary_key else '_id'
@ -57,7 +67,7 @@ class BaseField(object):
BaseField.creation_counter += 1 BaseField.creation_counter += 1
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor for retrieving a value from a field in a document. Do """Descriptor for retrieving a value from a field in a document. Do
any necessary conversion between Python and MongoDB types. any necessary conversion between Python and MongoDB types.
""" """
if instance is None: if instance is None:
@ -167,8 +177,8 @@ class DocumentMetaclass(type):
superclasses.update(base._superclasses) superclasses.update(base._superclasses)
if hasattr(base, '_meta'): if hasattr(base, '_meta'):
# Ensure that the Document class may be subclassed - # Ensure that the Document class may be subclassed -
# inheritance may be disabled to remove dependency on # inheritance may be disabled to remove dependency on
# additional fields _cls and _types # additional fields _cls and _types
if base._meta.get('allow_inheritance', True) == False: if base._meta.get('allow_inheritance', True) == False:
raise ValueError('Document %s may not be subclassed' % raise ValueError('Document %s may not be subclassed' %
@ -211,12 +221,12 @@ class DocumentMetaclass(type):
module = attrs.get('__module__') module = attrs.get('__module__')
base_excs = tuple(base.DoesNotExist for base in bases base_excs = tuple(base.DoesNotExist for base in bases
if hasattr(base, 'DoesNotExist')) or (DoesNotExist,) if hasattr(base, 'DoesNotExist')) or (DoesNotExist,)
exc = subclass_exception('DoesNotExist', base_excs, module) exc = subclass_exception('DoesNotExist', base_excs, module)
new_class.add_to_class('DoesNotExist', exc) new_class.add_to_class('DoesNotExist', exc)
base_excs = tuple(base.MultipleObjectsReturned for base in bases base_excs = tuple(base.MultipleObjectsReturned for base in bases
if hasattr(base, 'MultipleObjectsReturned')) if hasattr(base, 'MultipleObjectsReturned'))
base_excs = base_excs or (MultipleObjectsReturned,) base_excs = base_excs or (MultipleObjectsReturned,)
exc = subclass_exception('MultipleObjectsReturned', base_excs, module) exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
@ -238,9 +248,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
super_new = super(TopLevelDocumentMetaclass, cls).__new__ super_new = super(TopLevelDocumentMetaclass, cls).__new__
# Classes defined in this package are abstract and should not have # Classes defined in this package are abstract and should not have
# their own metadata with DB collection, etc. # their own metadata with DB collection, etc.
# __metaclass__ is only set on the class with the __metaclass__ # __metaclass__ is only set on the class with the __metaclass__
# attribute (i.e. it is not set on subclasses). This differentiates # attribute (i.e. it is not set on subclasses). This differentiates
# 'real' documents from the 'Document' class # 'real' documents from the 'Document' class
if attrs.get('__metaclass__') == TopLevelDocumentMetaclass: if attrs.get('__metaclass__') == TopLevelDocumentMetaclass:
@ -366,7 +376,7 @@ class BaseDocument(object):
are present. are present.
""" """
# Get a list of tuples of field names and their current values # Get a list of tuples of field names and their current values
fields = [(field, getattr(self, name)) fields = [(field, getattr(self, name))
for name, field in self._fields.items()] for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value # Ensure that each field is matched to a valid value

View File

@ -339,7 +339,7 @@ class ListField(BaseField):
if isinstance(self.field, ReferenceField): if isinstance(self.field, ReferenceField):
referenced_type = self.field.document_type referenced_type = self.field.document_type
# Get value from document instance if available # Get value from document instance if available
value_list = instance._data.get(self.name) value_list = instance._data.get(self.name)
if value_list: if value_list:
deref_list = [] deref_list = []
@ -522,6 +522,9 @@ class GenericReferenceField(BaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass """A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily). that will be automatically dereferenced on access (lazily).
note: Any documents used as a generic reference must be registered in the
document registry. Importing the model will automatically register it.
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
@ -648,7 +651,7 @@ class GridFSProxy(object):
if not self.newfile: if not self.newfile:
self.new_file() self.new_file()
self.grid_id = self.newfile._id self.grid_id = self.newfile._id
self.newfile.writelines(lines) self.newfile.writelines(lines)
def read(self, size=-1): def read(self, size=-1):
try: try:

View File

@ -7,6 +7,7 @@ import gridfs
from mongoengine import * from mongoengine import *
from mongoengine.connection import _get_db from mongoengine.connection import _get_db
from mongoengine.base import _document_registry, NotRegistered
class FieldTest(unittest.TestCase): class FieldTest(unittest.TestCase):
@ -45,7 +46,7 @@ class FieldTest(unittest.TestCase):
""" """
class Person(Document): class Person(Document):
name = StringField() name = StringField()
person = Person(name='Test User') person = Person(name='Test User')
self.assertEqual(person.id, None) self.assertEqual(person.id, None)
@ -95,7 +96,7 @@ class FieldTest(unittest.TestCase):
link.url = 'http://www.google.com:8080' link.url = 'http://www.google.com:8080'
link.validate() link.validate()
def test_int_validation(self): def test_int_validation(self):
"""Ensure that invalid values cannot be assigned to int fields. """Ensure that invalid values cannot be assigned to int fields.
""" """
@ -129,12 +130,12 @@ class FieldTest(unittest.TestCase):
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
person.height = 4.0 person.height = 4.0
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
def test_decimal_validation(self): def test_decimal_validation(self):
"""Ensure that invalid values cannot be assigned to decimal fields. """Ensure that invalid values cannot be assigned to decimal fields.
""" """
class Person(Document): class Person(Document):
height = DecimalField(min_value=Decimal('0.1'), height = DecimalField(min_value=Decimal('0.1'),
max_value=Decimal('3.5')) max_value=Decimal('3.5'))
Person.drop_collection() Person.drop_collection()
@ -249,7 +250,7 @@ class FieldTest(unittest.TestCase):
post.save() post.save()
post.reload() post.reload()
self.assertEqual(post.tags, ['fun', 'leisure']) self.assertEqual(post.tags, ['fun', 'leisure'])
comment1 = Comment(content='Good for you', order=1) comment1 = Comment(content='Good for you', order=1)
comment2 = Comment(content='Yay.', order=0) comment2 = Comment(content='Yay.', order=0)
comments = [comment1, comment2] comments = [comment1, comment2]
@ -315,7 +316,7 @@ class FieldTest(unittest.TestCase):
person.validate() person.validate()
def test_embedded_document_inheritance(self): def test_embedded_document_inheritance(self):
"""Ensure that subclasses of embedded documents may be provided to """Ensure that subclasses of embedded documents may be provided to
EmbeddedDocumentFields of the superclass' type. EmbeddedDocumentFields of the superclass' type.
""" """
class User(EmbeddedDocument): class User(EmbeddedDocument):
@ -327,7 +328,7 @@ class FieldTest(unittest.TestCase):
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
author = EmbeddedDocumentField(User) author = EmbeddedDocumentField(User)
post = BlogPost(content='What I did today...') post = BlogPost(content='What I did today...')
post.author = User(name='Test User') post.author = User(name='Test User')
post.author = PowerUser(name='Test User', power=47) post.author = PowerUser(name='Test User', power=47)
@ -370,7 +371,7 @@ class FieldTest(unittest.TestCase):
User.drop_collection() User.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def test_list_item_dereference(self): def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced. """Ensure that DBRef items in ListFields are dereferenced.
""" """
@ -434,7 +435,7 @@ class FieldTest(unittest.TestCase):
class TreeNode(EmbeddedDocument): class TreeNode(EmbeddedDocument):
name = StringField() name = StringField()
children = ListField(EmbeddedDocumentField('self')) children = ListField(EmbeddedDocumentField('self'))
tree = Tree(name="Tree") tree = Tree(name="Tree")
first_child = TreeNode(name="Child 1") first_child = TreeNode(name="Child 1")
@ -442,7 +443,7 @@ class FieldTest(unittest.TestCase):
second_child = TreeNode(name="Child 2") second_child = TreeNode(name="Child 2")
first_child.children.append(second_child) first_child.children.append(second_child)
third_child = TreeNode(name="Child 3") third_child = TreeNode(name="Child 3")
first_child.children.append(third_child) first_child.children.append(third_child)
@ -506,20 +507,20 @@ class FieldTest(unittest.TestCase):
Member.drop_collection() Member.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()
def test_generic_reference(self): def test_generic_reference(self):
"""Ensure that a GenericReferenceField properly dereferences items. """Ensure that a GenericReferenceField properly dereferences items.
""" """
class Link(Document): class Link(Document):
title = StringField() title = StringField()
meta = {'allow_inheritance': False} meta = {'allow_inheritance': False}
class Post(Document): class Post(Document):
title = StringField() title = StringField()
class Bookmark(Document): class Bookmark(Document):
bookmark_object = GenericReferenceField() bookmark_object = GenericReferenceField()
Link.drop_collection() Link.drop_collection()
Post.drop_collection() Post.drop_collection()
Bookmark.drop_collection() Bookmark.drop_collection()
@ -574,16 +575,49 @@ class FieldTest(unittest.TestCase):
user = User(bookmarks=[post_1, link_1]) user = User(bookmarks=[post_1, link_1])
user.save() user.save()
user = User.objects(bookmarks__all=[post_1, link_1]).first() user = User.objects(bookmarks__all=[post_1, link_1]).first()
self.assertEqual(user.bookmarks[0], post_1) self.assertEqual(user.bookmarks[0], post_1)
self.assertEqual(user.bookmarks[1], link_1) self.assertEqual(user.bookmarks[1], link_1)
Link.drop_collection() Link.drop_collection()
Post.drop_collection() Post.drop_collection()
User.drop_collection() User.drop_collection()
def test_generic_reference_document_not_registered(self):
"""Ensure dereferencing out of the document registry throws a
`NotRegistered` error.
"""
class Link(Document):
title = StringField()
class User(Document):
bookmarks = ListField(GenericReferenceField())
Link.drop_collection()
User.drop_collection()
link_1 = Link(title="Pitchfork")
link_1.save()
user = User(bookmarks=[link_1])
user.save()
# Mimic User and Link definitions being in a different file
# and the Link model not being imported in the User file.
del(_document_registry["Link"])
user = User.objects.first()
try:
user.bookmarks
raise AssertionError, "Link was removed from the registry"
except NotRegistered:
pass
Link.drop_collection()
User.drop_collection()
def test_binary_fields(self): def test_binary_fields(self):
"""Ensure that binary fields can be stored and retrieved. """Ensure that binary fields can be stored and retrieved.
""" """
@ -727,7 +761,7 @@ class FieldTest(unittest.TestCase):
result = SetFile.objects.first() result = SetFile.objects.first()
self.assertTrue(setfile == result) self.assertTrue(setfile == result)
self.assertEquals(result.file.read(), more_text) self.assertEquals(result.file.read(), more_text)
result.file.delete() result.file.delete()
PutFile.drop_collection() PutFile.drop_collection()
StreamFile.drop_collection() StreamFile.drop_collection()