From c99f5c4ec16cde8ef2ce654971c9ec8acb7904b8 Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Wed, 18 Nov 2009 19:02:57 +0000 Subject: [PATCH] Added CollectionManager, made connection module All connection-related functions are now in connection.py. Created a ConnectionManager class for interacting with a collection in the database. Top-level document classes have an instance of a ConnectionManager (Document.collection). Defined a 'save' method on top-level document's that uses the collection manager's '_save_document' method to save the document to the database. Added tests for CollectionManagers -- all unit tests now require a valid connection to the database, which is set up in the tests' setUp method. --- mongomap/__init__.py | 34 +++--------------------------- mongomap/base.py | 18 +++++++++++++--- mongomap/collection.py | 18 ++++++++++++++++ mongomap/connection.py | 48 ++++++++++++++++++++++++++++++++++++++++++ mongomap/document.py | 4 ++-- mongomap/fields.py | 6 ++++++ tests/collection.py | 33 +++++++++++++++++++++++++++++ tests/document.py | 45 +++++++++++++++++++++++++++++++++++++++ tests/fields.py | 3 +++ 9 files changed, 173 insertions(+), 36 deletions(-) create mode 100644 mongomap/collection.py create mode 100644 mongomap/connection.py create mode 100644 tests/collection.py diff --git a/mongomap/__init__.py b/mongomap/__init__.py index eb72c30a..b91ca4b5 100644 --- a/mongomap/__init__.py +++ b/mongomap/__init__.py @@ -2,39 +2,11 @@ import document from document import * import fields from fields import * +import connection +from connection import * -from pymongo import Connection - -__all__ = document.__all__ + fields.__all__ + ['connect'] +__all__ = document.__all__ + fields.__all__ + connection.__all__ __author__ = 'Harry Marr' __version__ = '0.1' -_connection_settings = { - 'host': 'localhost', - 'port': 27017, - 'pool_size': 1, -} -_connection = None -_db = None - -def _get_connection(): - if _connection is None: - _connection = Connection(**_connection_settings) - return _connection - -def connect(db=None, username=None, password=None, **kwargs): - """Connect to the database specified by the 'db' argument. Connection - settings may be provided here as well if the database is not running on - the default port on localhost. If authentication is needed, provide - username and password arguments as well. - """ - if db is None: - raise TypeError('"db" argument must be provided to connect()') - - _connection_settings.update(kwargs) - connection = _get_connection() - # Get DB from connection and auth if necessary - _db = connection[db] - if username is not None and password is not None: - _db.authenticate(username, password) diff --git a/mongomap/base.py b/mongomap/base.py index ddae4615..a906fbc6 100644 --- a/mongomap/base.py +++ b/mongomap/base.py @@ -1,3 +1,4 @@ +from collection import CollectionManager class ValidationError(Exception): @@ -128,7 +129,13 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): } meta.update(attrs.get('meta', {})) attrs['_meta'] = meta - return super_new(cls, name, bases, attrs) + + # Set up collection manager, needs the class to have fields so use + # DocumentMetaclass before instantiating CollectionManager object + new_class = super_new(cls, name, bases, attrs) + setattr(new_class, 'collection', CollectionManager(new_class)) + + return new_class class BaseDocument(object): @@ -141,7 +148,7 @@ class BaseDocument(object): setattr(self, attr_name, values.pop(attr_name)) else: if attr_value.required: - raise ValidationError('Field "%s" is required' % self.name) + raise ValidationError('Field "%s" is required' % attr_name) # Use default value setattr(self, attr_name, getattr(self, attr_name, None)) @@ -179,4 +186,9 @@ class BaseDocument(object): def _to_mongo(self): """Return data dictionary ready for use with MongoDB. """ - return dict((k, v) for k, v in self._data.items() if v is not None) + data = {} + for field_name, field in self._fields.items(): + value = getattr(self, field_name, None) + if value is not None: + data[field_name] = field._to_mongo(value) + return data diff --git a/mongomap/collection.py b/mongomap/collection.py new file mode 100644 index 00000000..5160fd8a --- /dev/null +++ b/mongomap/collection.py @@ -0,0 +1,18 @@ +from connection import _get_db + +class CollectionManager(object): + + def __init__(self, document): + """Set up the collection manager for a specific document. + """ + db = _get_db() + self._document = document + self._collection_name = document._meta['collection'] + # This will create the collection if it doesn't exist + self._collection = db[self._collection_name] + self._id_field = document._meta['object_id_field'] + + def _save_document(self, document): + """Save the provided document to the collection. + """ + _id = self._collection.save(document) diff --git a/mongomap/connection.py b/mongomap/connection.py new file mode 100644 index 00000000..de66c476 --- /dev/null +++ b/mongomap/connection.py @@ -0,0 +1,48 @@ +from pymongo import Connection + + +__all__ = ['ConnectionError', 'connect'] + + +_connection_settings = { + 'host': 'localhost', + 'port': 27017, + 'pool_size': 1, +} +_connection = None +_db = None + + +class ConnectionError(Exception): + pass + + +def _get_connection(): + global _connection + if _connection is None: + _connection = Connection(**_connection_settings) + return _connection + +def _get_db(): + global _db + if _db is None: + raise ConnectionError('Not connected to database') + return _db + +def connect(db=None, username=None, password=None, **kwargs): + """Connect to the database specified by the 'db' argument. Connection + settings may be provided here as well if the database is not running on + the default port on localhost. If authentication is needed, provide + username and password arguments as well. + """ + global _db + if db is None: + raise TypeError('"db" argument must be provided to connect()') + + _connection_settings.update(kwargs) + connection = _get_connection() + # Get DB from connection and auth if necessary + _db = connection[db] + if username is not None and password is not None: + _db.authenticate(username, password) + diff --git a/mongomap/document.py b/mongomap/document.py index b1b3ffe3..7b7f0289 100644 --- a/mongomap/document.py +++ b/mongomap/document.py @@ -1,7 +1,5 @@ from base import DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument -#import pymongo - __all__ = ['Document', 'EmbeddedDocument'] @@ -15,3 +13,5 @@ class Document(BaseDocument): __metaclass__ = TopLevelDocumentMetaclass + def save(self): + self.collection._save_document(self._to_mongo()) diff --git a/mongomap/fields.py b/mongomap/fields.py index 6235cb6d..61bfb357 100644 --- a/mongomap/fields.py +++ b/mongomap/fields.py @@ -61,7 +61,13 @@ class EmbeddedDocumentField(BaseField): def _to_python(self, value): return value + def _to_mongo(self, value): + return self.document._to_mongo(value) + def _validate(self, value): + """Make sure that the document instance is an instance of the + EmbeddedDocument subclass provided when the document was defined. + """ if not isinstance(value, self.document): raise ValidationError('Invalid embedded document instance ' 'provided to an EmbeddedDocumentField') diff --git a/tests/collection.py b/tests/collection.py new file mode 100644 index 00000000..4f2bc71d --- /dev/null +++ b/tests/collection.py @@ -0,0 +1,33 @@ +import unittest +import pymongo + +from mongomap.collection import CollectionManager +from mongomap import * + + +class CollectionManagerTest(unittest.TestCase): + + def setUp(self): + connect(db='mongotest') + + class Person(Document): + name = StringField() + age = IntField() + self.Person = Person + + def test_initialisation(self): + """Ensure that CollectionManager is correctly initialised. + """ + class Person(Document): + name = StringField() + age = IntField() + + self.assertTrue(isinstance(Person.collection, CollectionManager)) + self.assertEqual(Person.collection._collection_name, + Person._meta['collection']) + self.assertTrue(isinstance(Person.collection._collection, + pymongo.collection.Collection)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/document.py b/tests/document.py index 177560f1..57ed4dbb 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1,16 +1,22 @@ import unittest from mongomap import * +from mongomap.connection import _get_db class DocumentTest(unittest.TestCase): def setUp(self): + connect(db='mongomaptest') + class Person(Document): name = StringField() age = IntField() self.Person = Person + self.db = _get_db() + self.db.drop_collection(self.Person._meta['collection']) + def test_definition(self): """Ensure that document may be defined using fields. """ @@ -77,6 +83,45 @@ class DocumentTest(unittest.TestCase): self.assertTrue('content' in Comment._fields) self.assertFalse(hasattr(Comment, '_meta')) + def test_save(self): + """Ensure that a document may be saved in the database. + """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30) + person.save() + # Ensure that the object is in the database + collection = self.db[self.Person._meta['collection']] + person_obj = collection.find_one({'name': 'Test User'}) + self.assertEqual(person_obj['name'], 'Test User') + self.assertEqual(person_obj['age'], 30) + + def test_save_embedded_document(self): + """Ensure that a document with an embedded document field may be + saved in the database. + """ + class EmployeeDetails(EmbeddedDocument): + position = StringField() + + class Employee(self.Person): + salary = IntField() + details = EmbeddedDocumentField(EmployeeDetails) + + # Create employee object and save it to the database + employee = Employee(name='Test Employee', age=50, salary=20000) + employee.details = EmployeeDetails(position='Developer') + employee.save() + + # Ensure that the object is in the database + collection = self.db[self.Person._meta['collection']] + employee_obj = collection.find_one({'name': 'Test Employee'}) + self.assertEqual(employee_obj['name'], 'Test Employee') + self.assertEqual(employee_obj['age'], 50) + # Ensure that the 'details' embedded object saved correctly + self.assertEqual(employee_obj['details']['position'], 'Developer') + + def tearDown(self): + self.db.drop_collection(self.Person._meta['collection']) + if __name__ == '__main__': unittest.main() diff --git a/tests/fields.py b/tests/fields.py index 58342453..0bfb6380 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -5,6 +5,9 @@ from mongomap import * class FieldTest(unittest.TestCase): + def setUp(self): + connect(db='mongomaptest') + def test_default_values(self): """Ensure that default field values are used when creating a document. """