From a59b518cf2ca53da5699ce34a95e15cc07561a6d Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 17 Feb 2012 11:18:25 +0000 Subject: [PATCH] Updates to imports for future pymongo 2.2 --- mongoengine/base.py | 14 ++++++------ mongoengine/connection.py | 21 ++++++++++-------- mongoengine/dereference.py | 16 +++++++------- mongoengine/document.py | 15 +++++++------ mongoengine/fields.py | 22 +++++++++---------- mongoengine/queryset.py | 44 ++++++++++++++++++-------------------- tests/connection.py | 11 +++++++++- tests/document.py | 3 ++- tests/queryset.py | 4 ++-- 9 files changed, 81 insertions(+), 69 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 3190de6f..be60db91 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -6,9 +6,11 @@ from mongoengine import signals import sys import pymongo -import pymongo.objectid +from bson import ObjectId import operator + from functools import partial +from bson.dbref import DBRef class NotRegistered(Exception): @@ -295,7 +297,7 @@ class ComplexBaseField(BaseField): self.error('You can only reference documents once they' ' have been saved to the database') collection = v._get_collection_name() - value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) + value_dict[k] = DBRef(collection, v.pk) elif hasattr(v, 'to_python'): value_dict[k] = v.to_python() else: @@ -344,7 +346,7 @@ class ComplexBaseField(BaseField): value_dict[k] = GenericReferenceField().to_mongo(v) else: collection = v._get_collection_name() - value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) + value_dict[k] = DBRef(collection, v.pk) elif hasattr(v, 'to_mongo'): value_dict[k] = v.to_mongo() else: @@ -447,9 +449,9 @@ class ObjectIdField(BaseField): return value def to_mongo(self, value): - if not isinstance(value, pymongo.objectid.ObjectId): + if not isinstance(value, ObjectId): try: - return pymongo.objectid.ObjectId(unicode(value)) + return ObjectId(unicode(value)) except Exception, e: # e.message attribute has been deprecated since Python 2.6 self.error(unicode(e)) @@ -460,7 +462,7 @@ class ObjectIdField(BaseField): def validate(self, value): try: - pymongo.objectid.ObjectId(unicode(value)) + ObjectId(unicode(value)) except: self.error('Invalid Object ID') diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 822c604b..b6b716f8 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -46,7 +46,12 @@ def register_connection(alias, name, host='localhost', port=27017, raise ConnectionError("If using URI style connection include "\ "database name in string") uri_dict['name'] = uri_dict.get('database') - _connection_settings[alias] = uri_dict + _connection_settings[alias] = { + 'host': host, + 'name': uri_dict.get('database'), + 'username': uri_dict.get('username'), + 'password': uri_dict.get('password') + } return _connection_settings[alias] = { @@ -89,11 +94,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings = _connection_settings[alias].copy() if hasattr(pymongo, 'version_tuple'): # Support for 2.1+ - conn_settings.pop('name') - conn_settings.pop('slaves') - conn_settings.pop('is_slave') - conn_settings.pop('username') - conn_settings.pop('password') + conn_settings.pop('name', None) + conn_settings.pop('slaves', None) + conn_settings.pop('is_slave', None) + conn_settings.pop('username', None) + conn_settings.pop('password', None) else: # Get all the slave connections if 'slaves' in conn_settings: @@ -106,8 +111,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): try: _connections[alias] = Connection(**conn_settings) except Exception, e: - raise e - raise ConnectionError('Cannot connect to database %s' % alias) + raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) return _connections[alias] @@ -120,7 +124,6 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn = get_connection(alias) conn_settings = _connection_settings[alias] _dbs[alias] = conn[conn_settings['name']] - # Authenticate if necessary if conn_settings['username'] and conn_settings['password']: _dbs[alias].authenticate(conn_settings['username'], diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 40ceb49c..a5ad9166 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,4 +1,4 @@ -import pymongo +from bson import DBRef, SON from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document) from fields import ReferenceField @@ -68,9 +68,9 @@ class DeReference(object): if hasattr(item, '_fields'): for field_name, field in item._fields.iteritems(): v = item._data.get(field_name, None) - if isinstance(v, (pymongo.dbref.DBRef)): + if isinstance(v, (DBRef)): reference_map.setdefault(field.document_type, []).append(v.id) - elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + elif isinstance(v, (dict, SON)) and '_ref' in v: reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: field_cls = getattr(getattr(field, 'field', None), 'document_type', None) @@ -79,9 +79,9 @@ class DeReference(object): if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): key = field_cls reference_map.setdefault(key, []).extend(refs) - elif isinstance(item, (pymongo.dbref.DBRef)): + elif isinstance(item, (DBRef)): reference_map.setdefault(item.collection, []).append(item.id) - elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item: + elif isinstance(item, (dict, SON)) and '_ref' in item: reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id) elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: references = self._find_references(item, depth - 1) @@ -138,7 +138,7 @@ class DeReference(object): else: return BaseList(items, instance, name) - if isinstance(items, (dict, pymongo.son.SON)): + if isinstance(items, (dict, SON)): if '_ref' in items: return self.object_map.get(items['_ref'].id, items) elif '_types' in items and '_cls' in items: @@ -167,9 +167,9 @@ class DeReference(object): elif hasattr(v, '_fields'): for field_name, field in v._fields.iteritems(): v = data[k]._data.get(field_name, None) - if isinstance(v, (pymongo.dbref.DBRef)): + if isinstance(v, (DBRef)): data[k]._data[field_name] = self.object_map.get(v.id, v) - elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + elif isinstance(v, (dict, SON)) and '_ref' in v: data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) elif isinstance(v, dict) and depth <= self.max_depth: data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) diff --git a/mongoengine/document.py b/mongoengine/document.py index 686945c2..6d9d3148 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,11 +1,12 @@ +import pymongo +from bson.dbref import DBRef + from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList) from queryset import OperationError from connection import get_db, DEFAULT_CONNECTION_NAME -import pymongo - __all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError'] @@ -151,7 +152,7 @@ class Document(BaseDocument): .. versionchanged:: 0.5 In existing documents it only saves changed fields using set / unset - Saves are cascaded and any :class:`~pymongo.dbref.DBRef` objects + Saves are cascaded and any :class:`~bson.dbref.DBRef` objects that have changes are saved as well. .. versionchanged:: 0.6 Cascade saves are optional = defaults to True, if you want fine grain @@ -271,7 +272,7 @@ class Document(BaseDocument): signals.post_delete.send(self.__class__, document=self) def select_related(self, max_depth=1): - """Handles dereferencing of :class:`~pymongo.dbref.DBRef` objects to + """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to a maximum depth in order to cut down the number queries to mongodb. .. versionadded:: 0.5 @@ -313,12 +314,12 @@ class Document(BaseDocument): return value def to_dbref(self): - """Returns an instance of :class:`~pymongo.dbref.DBRef` useful in + """Returns an instance of :class:`~bson.dbref.DBRef` useful in `__raw__` queries.""" if not self.pk: msg = "Only saved documents can have a valid dbref" raise OperationError(msg) - return pymongo.dbref.DBRef(self.__class__._get_collection_name(), self.pk) + return DBRef(self.__class__._get_collection_name(), self.pk) @classmethod def register_delete_rule(cls, document_cls, field_name, rule): @@ -385,7 +386,7 @@ class MapReduceDocument(object): :param collection: An instance of :class:`~pymongo.Collection` :param key: Document/result key, often an instance of - :class:`~pymongo.objectid.ObjectId`. If supplied as + :class:`~bson.objectid.ObjectId`. If supplied as an ``ObjectId`` found in the given ``collection``, the object can be accessed via the ``object`` property. :param value: The result(s) for this key. diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 2e4411a2..33b0ed5b 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -2,13 +2,11 @@ import datetime import time import decimal import gridfs -import pymongo -import pymongo.binary -import pymongo.dbref -import pymongo.son import re import uuid +from bson import Binary, DBRef, SON, ObjectId + from base import (BaseField, ComplexBaseField, ObjectIdField, ValidationError, get_document) from queryset import DO_NOTHING, QuerySet @@ -644,7 +642,7 @@ class ReferenceField(BaseField): # Get value from document instance if available value = instance._data.get(self.name) # Dereference DBRefs - if isinstance(value, (pymongo.dbref.DBRef)): + if isinstance(value, (DBRef)): value = self.document_type._get_db().dereference(value) if value is not None: instance._data[self.name] = self.document_type._from_son(value) @@ -666,7 +664,7 @@ class ReferenceField(BaseField): id_ = id_field.to_mongo(id_) collection = self.document_type._get_collection_name() - return pymongo.dbref.DBRef(collection, id_) + return DBRef(collection, id_) def prepare_query_value(self, op, value): if value is None: @@ -675,7 +673,7 @@ class ReferenceField(BaseField): return self.to_mongo(value) def validate(self, value): - if not isinstance(value, (self.document_type, pymongo.dbref.DBRef)): + if not isinstance(value, (self.document_type, DBRef)): self.error('A ReferenceField only accepts DBRef') if isinstance(value, Document) and value.id is None: @@ -701,13 +699,13 @@ class GenericReferenceField(BaseField): return self value = instance._data.get(self.name) - if isinstance(value, (dict, pymongo.son.SON)): + if isinstance(value, (dict, SON)): instance._data[self.name] = self.dereference(value) return super(GenericReferenceField, self).__get__(instance, owner) def validate(self, value): - if not isinstance(value, (Document, pymongo.dbref.DBRef)): + if not isinstance(value, (Document, DBRef)): self.error('GenericReferences can only contain documents') # We need the id from the saved object to create the DBRef @@ -741,7 +739,7 @@ class GenericReferenceField(BaseField): id_ = id_field.to_mongo(id_) collection = document._get_collection_name() - ref = pymongo.dbref.DBRef(collection, id_) + ref = DBRef(collection, id_) return {'_cls': document._class_name, '_ref': ref} def prepare_query_value(self, op, value): @@ -760,7 +758,7 @@ class BinaryField(BaseField): super(BinaryField, self).__init__(**kwargs) def to_mongo(self, value): - return pymongo.binary.Binary(value) + return Binary(value) def to_python(self, value): # Returns str not unicode as this is binary data @@ -964,7 +962,7 @@ class FileField(BaseField): if value.grid_id is not None: if not isinstance(value, self.proxy_class): self.error('FileField only accepts GridFSProxy values') - if not isinstance(value.grid_id, pymongo.objectid.ObjectId): + if not isinstance(value.grid_id, ObjectId): self.error('Invalid GridFSProxy value') diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 0c0b11f4..9032fc7f 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -1,16 +1,14 @@ -from connection import get_db -from mongoengine import signals - import pprint -import pymongo -import pymongo.code -import pymongo.dbref -import pymongo.objectid import re import copy import itertools import operator +import pymongo +from bson.code import Code + +from mongoengine import signals + __all__ = ['queryset_manager', 'Q', 'InvalidQueryError', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] @@ -935,9 +933,9 @@ class QuerySet(object): and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` tests in ``tests.queryset.QuerySetTest`` for usage examples. - :param map_f: map function, as :class:`~pymongo.code.Code` or string + :param map_f: map function, as :class:`~bson.code.Code` or string :param reduce_f: reduce function, as - :class:`~pymongo.code.Code` or string + :class:`~bson.code.Code` or string :param output: output collection name, if set to 'inline' will try to use :class:`~pymongo.collection.Collection.inline_map_reduce` :param finalize_f: finalize function, an optional function that @@ -967,27 +965,27 @@ class QuerySet(object): raise NotImplementedError("Requires MongoDB >= 1.7.1") map_f_scope = {} - if isinstance(map_f, pymongo.code.Code): + if isinstance(map_f, Code): map_f_scope = map_f.scope map_f = unicode(map_f) - map_f = pymongo.code.Code(self._sub_js_fields(map_f), map_f_scope) + map_f = Code(self._sub_js_fields(map_f), map_f_scope) reduce_f_scope = {} - if isinstance(reduce_f, pymongo.code.Code): + if isinstance(reduce_f, Code): reduce_f_scope = reduce_f.scope reduce_f = unicode(reduce_f) reduce_f_code = self._sub_js_fields(reduce_f) - reduce_f = pymongo.code.Code(reduce_f_code, reduce_f_scope) + reduce_f = Code(reduce_f_code, reduce_f_scope) mr_args = {'query': self._query} if finalize_f: finalize_f_scope = {} - if isinstance(finalize_f, pymongo.code.Code): + if isinstance(finalize_f, Code): finalize_f_scope = finalize_f.scope finalize_f = unicode(finalize_f) finalize_f_code = self._sub_js_fields(finalize_f) - finalize_f = pymongo.code.Code(finalize_f_code, finalize_f_scope) + finalize_f = Code(finalize_f_code, finalize_f_scope) mr_args['finalize'] = finalize_f if scope: @@ -1499,7 +1497,7 @@ class QuerySet(object): query['$where'] = self._where_clause scope['query'] = query - code = pymongo.code.Code(code, scope=scope) + code = Code(code, scope=scope) db = self._document._get_db() return db.eval(code, *fields) @@ -1528,13 +1526,13 @@ class QuerySet(object): .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work with sharding. """ - map_func = pymongo.code.Code(""" + map_func = Code(""" function() { emit(1, this[field] || 0); } """, scope={'field': field}) - reduce_func = pymongo.code.Code(""" + reduce_func = Code(""" function(key, values) { var sum = 0; for (var i in values) { @@ -1558,14 +1556,14 @@ class QuerySet(object): .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work with sharding. """ - map_func = pymongo.code.Code(""" + map_func = Code(""" function() { if (this.hasOwnProperty(field)) emit(1, {t: this[field] || 0, c: 1}); } """, scope={'field': field}) - reduce_func = pymongo.code.Code(""" + reduce_func = Code(""" function(key, values) { var out = {t: 0, c: 0}; for (var i in values) { @@ -1577,7 +1575,7 @@ class QuerySet(object): } """) - finalize_func = pymongo.code.Code(""" + finalize_func = Code(""" function(key, value) { return value.t / value.c; } @@ -1719,7 +1717,7 @@ class QuerySet(object): def __repr__(self): limit = REPR_OUTPUT_SIZE + 1 - start = ( 0 if self._skip is None else self._skip ) + start = (0 if self._skip is None else self._skip) if self._limit is None: stop = start + limit if self._limit is not None: @@ -1736,7 +1734,7 @@ class QuerySet(object): return repr(data) def select_related(self, max_depth=1): - """Handles dereferencing of :class:`~pymongo.dbref.DBRef` objects to + """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to a maximum depth in order to cut down the number queries to mongodb. .. versionadded:: 0.5 diff --git a/tests/connection.py b/tests/connection.py index 7ff0998e..ce20a54a 100644 --- a/tests/connection.py +++ b/tests/connection.py @@ -4,7 +4,7 @@ import pymongo import mongoengine.connection from mongoengine import * -from mongoengine.connection import get_db, get_connection +from mongoengine.connection import get_db, get_connection, ConnectionError class ConnectionTest(unittest.TestCase): @@ -33,6 +33,15 @@ class ConnectionTest(unittest.TestCase): def test_connect_uri(self): """Ensure that the connect() method works properly with uri's """ + c = connect(db='mongoenginetest', alias='admin') + c.admin.system.users.remove({}) + c.mongoenginetest.system.users.remove({}) + + c.admin.add_user("admin", "password") + c.admin.authenticate("admin", "password") + c.mongoenginetest.add_user("username", "password") + + self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost/mongoenginetest') connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') diff --git a/tests/document.py b/tests/document.py index dabc27e1..b06731dd 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1,5 +1,6 @@ import pickle import pymongo +import bson import unittest import warnings @@ -2222,7 +2223,7 @@ class DocumentTest(unittest.TestCase): # Test laziness self.assertTrue(isinstance(post_obj._data['author'], - pymongo.dbref.DBRef)) + bson.DBRef)) self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertEqual(post_obj.author.name, 'Test User') diff --git a/tests/queryset.py b/tests/queryset.py index 555e7d00..06b65bf4 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import unittest import pymongo +from bson import ObjectId from datetime import datetime, timedelta from mongoengine.queryset import (QuerySet, QuerySetManager, @@ -59,8 +60,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(len(people), 2) results = list(people) self.assertTrue(isinstance(results[0], self.Person)) - self.assertTrue(isinstance(results[0].id, (pymongo.objectid.ObjectId, - str, unicode))) + self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode))) self.assertEqual(results[0].name, "User A") self.assertEqual(results[0].age, 20) self.assertEqual(results[1].name, "User B")