diff --git a/mongomap/base.py b/mongomap/base.py index a906fbc6..c2bfa0fe 100644 --- a/mongomap/base.py +++ b/mongomap/base.py @@ -1,5 +1,7 @@ from collection import CollectionManager +import pymongo + class ValidationError(Exception): pass @@ -65,6 +67,25 @@ class BaseField(object): pass +class ObjectIdField(BaseField): + """An field wrapper around MongoDB's ObjectIds. + """ + + def _to_python(self, value): + return str(value) + + def _to_mongo(self, value): + if not isinstance(value, pymongo.objectid.ObjectId): + return pymongo.objectid.ObjectId(value) + return value + + def _validate(self, value): + try: + pymongo.objectid.ObjectId(str(value)) + except: + raise ValidationError('Invalid Object ID') + + class DocumentMetaclass(type): """Metaclass for all documents. """ @@ -76,7 +97,6 @@ class DocumentMetaclass(type): return super_new(cls, name, bases, attrs) doc_fields = {} - # Include all fields present in superclasses for base in bases: if hasattr(base, '_fields'): @@ -85,7 +105,6 @@ class DocumentMetaclass(type): # Add the document's fields to the _fields attribute for attr_name, attr_value in attrs.items(): if issubclass(attr_value.__class__, BaseField): - #print attr_value.name if not attr_value.name: attr_value.name = attr_name doc_fields[attr_name] = attr_value @@ -114,22 +133,15 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): for base in bases: if hasattr(base, '_meta') and 'collection' in base._meta: collection = base._meta['collection'] - - # Get primary key field - object_id_field = None - for attr_name, attr_value in attrs.items(): - if issubclass(attr_value.__class__, BaseField): - if hasattr(attr_value, 'object_id') and attr_value.object_id: - object_id_field = attr_name - attr_value.required = True meta = { 'collection': collection, - 'object_id_field': object_id_field, } meta.update(attrs.get('meta', {})) attrs['_meta'] = meta + attrs['_id'] = ObjectIdField() + # Set up collection manager, needs the class to have fields so use # DocumentMetaclass before instantiating CollectionManager object new_class = super_new(cls, name, bases, attrs) diff --git a/mongomap/collection.py b/mongomap/collection.py index 5160fd8a..5016a1a5 100644 --- a/mongomap/collection.py +++ b/mongomap/collection.py @@ -10,7 +10,6 @@ class CollectionManager(object): self._collection_name = document._meta['collection'] # This will create the collection if it doesn't exist self._collection = db[self._collection_name] - self._id_field = document._meta['object_id_field'] def _save_document(self, document): """Save the provided document to the collection. diff --git a/mongomap/fields.py b/mongomap/fields.py index 61bfb357..3003e3c9 100644 --- a/mongomap/fields.py +++ b/mongomap/fields.py @@ -1,20 +1,20 @@ -from base import BaseField, ValidationError +from base import BaseField, ObjectIdField, ValidationError from document import EmbeddedDocument import re +import pymongo __all__ = ['StringField', 'IntField', 'EmbeddedDocumentField', - 'ValidationError'] + 'ObjectIdField', 'ValidationError'] class StringField(BaseField): """A unicode string field. """ - def __init__(self, regex=None, max_length=None, object_id=False, **kwargs): + def __init__(self, regex=None, max_length=None, **kwargs): self.regex = re.compile(regex) if regex else None - self.object_id = object_id self.max_length = max_length super(StringField, self).__init__(**kwargs) diff --git a/tests/document.py b/tests/document.py index 57ed4dbb..7a6e1c4f 100644 --- a/tests/document.py +++ b/tests/document.py @@ -31,6 +31,7 @@ class DocumentTest(unittest.TestCase): self.assertEqual(Person._fields['name'], name_field) self.assertEqual(Person._fields['age'], age_field) self.assertFalse('non_field' in Person._fields) + self.assertTrue('_id' in Person._fields) # Test iteration over fields fields = list(Person()) self.assertTrue('name' in fields and 'age' in fields) @@ -67,7 +68,8 @@ class DocumentTest(unittest.TestCase): person['name'] = 'Another User' self.assertEquals(person['name'], 'Another User') - self.assertEquals(len(person), 2) + # Length = length(assigned fields + _id) + self.assertEquals(len(person), 3) self.assertTrue('age' in person) person.age = None @@ -81,6 +83,7 @@ class DocumentTest(unittest.TestCase): content = StringField() self.assertTrue('content' in Comment._fields) + self.assertFalse('_id' in Comment._fields) self.assertFalse(hasattr(Comment, '_meta')) def test_save(self): @@ -95,6 +98,18 @@ class DocumentTest(unittest.TestCase): self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['age'], 30) + def test_save_custom_id(self): + """Ensure that a document may be saved with a custom _id. + """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30, + _id='497ce96f395f2f052a494fd4') + person.save() + # Ensure that the object is in the database with the correct _id + collection = self.db[self.Person._meta['collection']] + person_obj = collection.find_one({'name': 'Test User'}) + self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4') + def test_save_embedded_document(self): """Ensure that a document with an embedded document field may be saved in the database. diff --git a/tests/fields.py b/tests/fields.py index f657debb..c4af2be0 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -24,7 +24,7 @@ class FieldTest(unittest.TestCase): """Ensure that required field constraints are enforced. """ class Person(Document): - name = StringField(object_id=True) + name = StringField(required=True) age = IntField(required=True) userid = StringField() @@ -36,6 +36,18 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, person.__setattr__, 'age', None) person.userid = None + def test_object_id_validation(self): + """Ensure that invalid values cannot be assigned to string fields. + """ + class Person(Document): + name = StringField() + + person = Person(name='Test User') + self.assertRaises(AttributeError, getattr, person, '_id') + self.assertRaises(ValidationError, person.__setattr__, '_id', 47) + self.assertRaises(ValidationError, person.__setattr__, '_id', 'abc') + person._id = '497ce96f395f2f052a494fd4' + def test_string_validation(self): """Ensure that invalid values cannot be assigned to string fields. """