Updates to imports for future pymongo 2.2

This commit is contained in:
Ross Lawley 2012-02-17 11:18:25 +00:00
parent a15352a4f8
commit a59b518cf2
9 changed files with 81 additions and 69 deletions

View File

@ -6,9 +6,11 @@ from mongoengine import signals
import sys import sys
import pymongo import pymongo
import pymongo.objectid from bson import ObjectId
import operator import operator
from functools import partial from functools import partial
from bson.dbref import DBRef
class NotRegistered(Exception): class NotRegistered(Exception):
@ -295,7 +297,7 @@ class ComplexBaseField(BaseField):
self.error('You can only reference documents once they' self.error('You can only reference documents once they'
' have been saved to the database') ' have been saved to the database')
collection = v._get_collection_name() collection = v._get_collection_name()
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_python'): elif hasattr(v, 'to_python'):
value_dict[k] = v.to_python() value_dict[k] = v.to_python()
else: else:
@ -344,7 +346,7 @@ class ComplexBaseField(BaseField):
value_dict[k] = GenericReferenceField().to_mongo(v) value_dict[k] = GenericReferenceField().to_mongo(v)
else: else:
collection = v._get_collection_name() collection = v._get_collection_name()
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'): elif hasattr(v, 'to_mongo'):
value_dict[k] = v.to_mongo() value_dict[k] = v.to_mongo()
else: else:
@ -447,9 +449,9 @@ class ObjectIdField(BaseField):
return value return value
def to_mongo(self, value): def to_mongo(self, value):
if not isinstance(value, pymongo.objectid.ObjectId): if not isinstance(value, ObjectId):
try: try:
return pymongo.objectid.ObjectId(unicode(value)) return ObjectId(unicode(value))
except Exception, e: except Exception, e:
# e.message attribute has been deprecated since Python 2.6 # e.message attribute has been deprecated since Python 2.6
self.error(unicode(e)) self.error(unicode(e))
@ -460,7 +462,7 @@ class ObjectIdField(BaseField):
def validate(self, value): def validate(self, value):
try: try:
pymongo.objectid.ObjectId(unicode(value)) ObjectId(unicode(value))
except: except:
self.error('Invalid Object ID') self.error('Invalid Object ID')

View File

@ -46,7 +46,12 @@ def register_connection(alias, name, host='localhost', port=27017,
raise ConnectionError("If using URI style connection include "\ raise ConnectionError("If using URI style connection include "\
"database name in string") "database name in string")
uri_dict['name'] = uri_dict.get('database') uri_dict['name'] = uri_dict.get('database')
_connection_settings[alias] = uri_dict _connection_settings[alias] = {
'host': host,
'name': uri_dict.get('database'),
'username': uri_dict.get('username'),
'password': uri_dict.get('password')
}
return return
_connection_settings[alias] = { _connection_settings[alias] = {
@ -89,11 +94,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn_settings = _connection_settings[alias].copy() conn_settings = _connection_settings[alias].copy()
if hasattr(pymongo, 'version_tuple'): # Support for 2.1+ if hasattr(pymongo, 'version_tuple'): # Support for 2.1+
conn_settings.pop('name') conn_settings.pop('name', None)
conn_settings.pop('slaves') conn_settings.pop('slaves', None)
conn_settings.pop('is_slave') conn_settings.pop('is_slave', None)
conn_settings.pop('username') conn_settings.pop('username', None)
conn_settings.pop('password') conn_settings.pop('password', None)
else: else:
# Get all the slave connections # Get all the slave connections
if 'slaves' in conn_settings: if 'slaves' in conn_settings:
@ -106,8 +111,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
try: try:
_connections[alias] = Connection(**conn_settings) _connections[alias] = Connection(**conn_settings)
except Exception, e: except Exception, e:
raise e raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
raise ConnectionError('Cannot connect to database %s' % alias)
return _connections[alias] return _connections[alias]
@ -120,7 +124,6 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn = get_connection(alias) conn = get_connection(alias)
conn_settings = _connection_settings[alias] conn_settings = _connection_settings[alias]
_dbs[alias] = conn[conn_settings['name']] _dbs[alias] = conn[conn_settings['name']]
# Authenticate if necessary # Authenticate if necessary
if conn_settings['username'] and conn_settings['password']: if conn_settings['username'] and conn_settings['password']:
_dbs[alias].authenticate(conn_settings['username'], _dbs[alias].authenticate(conn_settings['username'],

View File

@ -1,4 +1,4 @@
import pymongo from bson import DBRef, SON
from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document) from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document)
from fields import ReferenceField from fields import ReferenceField
@ -68,9 +68,9 @@ class DeReference(object):
if hasattr(item, '_fields'): if hasattr(item, '_fields'):
for field_name, field in item._fields.iteritems(): for field_name, field in item._fields.iteritems():
v = item._data.get(field_name, None) v = item._data.get(field_name, None)
if isinstance(v, (pymongo.dbref.DBRef)): if isinstance(v, (DBRef)):
reference_map.setdefault(field.document_type, []).append(v.id) reference_map.setdefault(field.document_type, []).append(v.id)
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: elif isinstance(v, (dict, SON)) and '_ref' in v:
reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id) reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
field_cls = getattr(getattr(field, 'field', None), 'document_type', None) field_cls = getattr(getattr(field, 'field', None), 'document_type', None)
@ -79,9 +79,9 @@ class DeReference(object):
if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
key = field_cls key = field_cls
reference_map.setdefault(key, []).extend(refs) reference_map.setdefault(key, []).extend(refs)
elif isinstance(item, (pymongo.dbref.DBRef)): elif isinstance(item, (DBRef)):
reference_map.setdefault(item.collection, []).append(item.id) reference_map.setdefault(item.collection, []).append(item.id)
elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item: elif isinstance(item, (dict, SON)) and '_ref' in item:
reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id) reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id)
elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
references = self._find_references(item, depth - 1) references = self._find_references(item, depth - 1)
@ -138,7 +138,7 @@ class DeReference(object):
else: else:
return BaseList(items, instance, name) return BaseList(items, instance, name)
if isinstance(items, (dict, pymongo.son.SON)): if isinstance(items, (dict, SON)):
if '_ref' in items: if '_ref' in items:
return self.object_map.get(items['_ref'].id, items) return self.object_map.get(items['_ref'].id, items)
elif '_types' in items and '_cls' in items: elif '_types' in items and '_cls' in items:
@ -167,9 +167,9 @@ class DeReference(object):
elif hasattr(v, '_fields'): elif hasattr(v, '_fields'):
for field_name, field in v._fields.iteritems(): for field_name, field in v._fields.iteritems():
v = data[k]._data.get(field_name, None) v = data[k]._data.get(field_name, None)
if isinstance(v, (pymongo.dbref.DBRef)): if isinstance(v, (DBRef)):
data[k]._data[field_name] = self.object_map.get(v.id, v) data[k]._data[field_name] = self.object_map.get(v.id, v)
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: elif isinstance(v, (dict, SON)) and '_ref' in v:
data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v)
elif isinstance(v, dict) and depth <= self.max_depth: elif isinstance(v, dict) and depth <= self.max_depth:
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name)

View File

@ -1,11 +1,12 @@
import pymongo
from bson.dbref import DBRef
from mongoengine import signals from mongoengine import signals
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
BaseDict, BaseList) BaseDict, BaseList)
from queryset import OperationError from queryset import OperationError
from connection import get_db, DEFAULT_CONNECTION_NAME from connection import get_db, DEFAULT_CONNECTION_NAME
import pymongo
__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', __all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument',
'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError'] 'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError']
@ -151,7 +152,7 @@ class Document(BaseDocument):
.. versionchanged:: 0.5 .. versionchanged:: 0.5
In existing documents it only saves changed fields using set / unset In existing documents it only saves changed fields using set / unset
Saves are cascaded and any :class:`~pymongo.dbref.DBRef` objects Saves are cascaded and any :class:`~bson.dbref.DBRef` objects
that have changes are saved as well. that have changes are saved as well.
.. versionchanged:: 0.6 .. versionchanged:: 0.6
Cascade saves are optional = defaults to True, if you want fine grain Cascade saves are optional = defaults to True, if you want fine grain
@ -271,7 +272,7 @@ class Document(BaseDocument):
signals.post_delete.send(self.__class__, document=self) signals.post_delete.send(self.__class__, document=self)
def select_related(self, max_depth=1): def select_related(self, max_depth=1):
"""Handles dereferencing of :class:`~pymongo.dbref.DBRef` objects to """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to
a maximum depth in order to cut down the number queries to mongodb. a maximum depth in order to cut down the number queries to mongodb.
.. versionadded:: 0.5 .. versionadded:: 0.5
@ -313,12 +314,12 @@ class Document(BaseDocument):
return value return value
def to_dbref(self): def to_dbref(self):
"""Returns an instance of :class:`~pymongo.dbref.DBRef` useful in """Returns an instance of :class:`~bson.dbref.DBRef` useful in
`__raw__` queries.""" `__raw__` queries."""
if not self.pk: if not self.pk:
msg = "Only saved documents can have a valid dbref" msg = "Only saved documents can have a valid dbref"
raise OperationError(msg) raise OperationError(msg)
return pymongo.dbref.DBRef(self.__class__._get_collection_name(), self.pk) return DBRef(self.__class__._get_collection_name(), self.pk)
@classmethod @classmethod
def register_delete_rule(cls, document_cls, field_name, rule): def register_delete_rule(cls, document_cls, field_name, rule):
@ -385,7 +386,7 @@ class MapReduceDocument(object):
:param collection: An instance of :class:`~pymongo.Collection` :param collection: An instance of :class:`~pymongo.Collection`
:param key: Document/result key, often an instance of :param key: Document/result key, often an instance of
:class:`~pymongo.objectid.ObjectId`. If supplied as :class:`~bson.objectid.ObjectId`. If supplied as
an ``ObjectId`` found in the given ``collection``, an ``ObjectId`` found in the given ``collection``,
the object can be accessed via the ``object`` property. the object can be accessed via the ``object`` property.
:param value: The result(s) for this key. :param value: The result(s) for this key.

View File

@ -2,13 +2,11 @@ import datetime
import time import time
import decimal import decimal
import gridfs import gridfs
import pymongo
import pymongo.binary
import pymongo.dbref
import pymongo.son
import re import re
import uuid import uuid
from bson import Binary, DBRef, SON, ObjectId
from base import (BaseField, ComplexBaseField, ObjectIdField, from base import (BaseField, ComplexBaseField, ObjectIdField,
ValidationError, get_document) ValidationError, get_document)
from queryset import DO_NOTHING, QuerySet from queryset import DO_NOTHING, QuerySet
@ -644,7 +642,7 @@ class ReferenceField(BaseField):
# Get value from document instance if available # Get value from document instance if available
value = instance._data.get(self.name) value = instance._data.get(self.name)
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)): if isinstance(value, (DBRef)):
value = self.document_type._get_db().dereference(value) value = self.document_type._get_db().dereference(value)
if value is not None: if value is not None:
instance._data[self.name] = self.document_type._from_son(value) instance._data[self.name] = self.document_type._from_son(value)
@ -666,7 +664,7 @@ class ReferenceField(BaseField):
id_ = id_field.to_mongo(id_) id_ = id_field.to_mongo(id_)
collection = self.document_type._get_collection_name() collection = self.document_type._get_collection_name()
return pymongo.dbref.DBRef(collection, id_) return DBRef(collection, id_)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if value is None: if value is None:
@ -675,7 +673,7 @@ class ReferenceField(BaseField):
return self.to_mongo(value) return self.to_mongo(value)
def validate(self, value): def validate(self, value):
if not isinstance(value, (self.document_type, pymongo.dbref.DBRef)): if not isinstance(value, (self.document_type, DBRef)):
self.error('A ReferenceField only accepts DBRef') self.error('A ReferenceField only accepts DBRef')
if isinstance(value, Document) and value.id is None: if isinstance(value, Document) and value.id is None:
@ -701,13 +699,13 @@ class GenericReferenceField(BaseField):
return self return self
value = instance._data.get(self.name) value = instance._data.get(self.name)
if isinstance(value, (dict, pymongo.son.SON)): if isinstance(value, (dict, SON)):
instance._data[self.name] = self.dereference(value) instance._data[self.name] = self.dereference(value)
return super(GenericReferenceField, self).__get__(instance, owner) return super(GenericReferenceField, self).__get__(instance, owner)
def validate(self, value): def validate(self, value):
if not isinstance(value, (Document, pymongo.dbref.DBRef)): if not isinstance(value, (Document, DBRef)):
self.error('GenericReferences can only contain documents') self.error('GenericReferences can only contain documents')
# We need the id from the saved object to create the DBRef # We need the id from the saved object to create the DBRef
@ -741,7 +739,7 @@ class GenericReferenceField(BaseField):
id_ = id_field.to_mongo(id_) id_ = id_field.to_mongo(id_)
collection = document._get_collection_name() collection = document._get_collection_name()
ref = pymongo.dbref.DBRef(collection, id_) ref = DBRef(collection, id_)
return {'_cls': document._class_name, '_ref': ref} return {'_cls': document._class_name, '_ref': ref}
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
@ -760,7 +758,7 @@ class BinaryField(BaseField):
super(BinaryField, self).__init__(**kwargs) super(BinaryField, self).__init__(**kwargs)
def to_mongo(self, value): def to_mongo(self, value):
return pymongo.binary.Binary(value) return Binary(value)
def to_python(self, value): def to_python(self, value):
# Returns str not unicode as this is binary data # Returns str not unicode as this is binary data
@ -964,7 +962,7 @@ class FileField(BaseField):
if value.grid_id is not None: if value.grid_id is not None:
if not isinstance(value, self.proxy_class): if not isinstance(value, self.proxy_class):
self.error('FileField only accepts GridFSProxy values') self.error('FileField only accepts GridFSProxy values')
if not isinstance(value.grid_id, pymongo.objectid.ObjectId): if not isinstance(value.grid_id, ObjectId):
self.error('Invalid GridFSProxy value') self.error('Invalid GridFSProxy value')

View File

@ -1,16 +1,14 @@
from connection import get_db
from mongoengine import signals
import pprint import pprint
import pymongo
import pymongo.code
import pymongo.dbref
import pymongo.objectid
import re import re
import copy import copy
import itertools import itertools
import operator import operator
import pymongo
from bson.code import Code
from mongoengine import signals
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError', __all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY'] 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY']
@ -935,9 +933,9 @@ class QuerySet(object):
and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced`
tests in ``tests.queryset.QuerySetTest`` for usage examples. tests in ``tests.queryset.QuerySetTest`` for usage examples.
:param map_f: map function, as :class:`~pymongo.code.Code` or string :param map_f: map function, as :class:`~bson.code.Code` or string
:param reduce_f: reduce function, as :param reduce_f: reduce function, as
:class:`~pymongo.code.Code` or string :class:`~bson.code.Code` or string
:param output: output collection name, if set to 'inline' will try to :param output: output collection name, if set to 'inline' will try to
use :class:`~pymongo.collection.Collection.inline_map_reduce` use :class:`~pymongo.collection.Collection.inline_map_reduce`
:param finalize_f: finalize function, an optional function that :param finalize_f: finalize function, an optional function that
@ -967,27 +965,27 @@ class QuerySet(object):
raise NotImplementedError("Requires MongoDB >= 1.7.1") raise NotImplementedError("Requires MongoDB >= 1.7.1")
map_f_scope = {} map_f_scope = {}
if isinstance(map_f, pymongo.code.Code): if isinstance(map_f, Code):
map_f_scope = map_f.scope map_f_scope = map_f.scope
map_f = unicode(map_f) map_f = unicode(map_f)
map_f = pymongo.code.Code(self._sub_js_fields(map_f), map_f_scope) map_f = Code(self._sub_js_fields(map_f), map_f_scope)
reduce_f_scope = {} reduce_f_scope = {}
if isinstance(reduce_f, pymongo.code.Code): if isinstance(reduce_f, Code):
reduce_f_scope = reduce_f.scope reduce_f_scope = reduce_f.scope
reduce_f = unicode(reduce_f) reduce_f = unicode(reduce_f)
reduce_f_code = self._sub_js_fields(reduce_f) reduce_f_code = self._sub_js_fields(reduce_f)
reduce_f = pymongo.code.Code(reduce_f_code, reduce_f_scope) reduce_f = Code(reduce_f_code, reduce_f_scope)
mr_args = {'query': self._query} mr_args = {'query': self._query}
if finalize_f: if finalize_f:
finalize_f_scope = {} finalize_f_scope = {}
if isinstance(finalize_f, pymongo.code.Code): if isinstance(finalize_f, Code):
finalize_f_scope = finalize_f.scope finalize_f_scope = finalize_f.scope
finalize_f = unicode(finalize_f) finalize_f = unicode(finalize_f)
finalize_f_code = self._sub_js_fields(finalize_f) finalize_f_code = self._sub_js_fields(finalize_f)
finalize_f = pymongo.code.Code(finalize_f_code, finalize_f_scope) finalize_f = Code(finalize_f_code, finalize_f_scope)
mr_args['finalize'] = finalize_f mr_args['finalize'] = finalize_f
if scope: if scope:
@ -1499,7 +1497,7 @@ class QuerySet(object):
query['$where'] = self._where_clause query['$where'] = self._where_clause
scope['query'] = query scope['query'] = query
code = pymongo.code.Code(code, scope=scope) code = Code(code, scope=scope)
db = self._document._get_db() db = self._document._get_db()
return db.eval(code, *fields) return db.eval(code, *fields)
@ -1528,13 +1526,13 @@ class QuerySet(object):
.. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work
with sharding. with sharding.
""" """
map_func = pymongo.code.Code(""" map_func = Code("""
function() { function() {
emit(1, this[field] || 0); emit(1, this[field] || 0);
} }
""", scope={'field': field}) """, scope={'field': field})
reduce_func = pymongo.code.Code(""" reduce_func = Code("""
function(key, values) { function(key, values) {
var sum = 0; var sum = 0;
for (var i in values) { for (var i in values) {
@ -1558,14 +1556,14 @@ class QuerySet(object):
.. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work .. versionchanged:: 0.5 - updated to map_reduce as db.eval doesnt work
with sharding. with sharding.
""" """
map_func = pymongo.code.Code(""" map_func = Code("""
function() { function() {
if (this.hasOwnProperty(field)) if (this.hasOwnProperty(field))
emit(1, {t: this[field] || 0, c: 1}); emit(1, {t: this[field] || 0, c: 1});
} }
""", scope={'field': field}) """, scope={'field': field})
reduce_func = pymongo.code.Code(""" reduce_func = Code("""
function(key, values) { function(key, values) {
var out = {t: 0, c: 0}; var out = {t: 0, c: 0};
for (var i in values) { for (var i in values) {
@ -1577,7 +1575,7 @@ class QuerySet(object):
} }
""") """)
finalize_func = pymongo.code.Code(""" finalize_func = Code("""
function(key, value) { function(key, value) {
return value.t / value.c; return value.t / value.c;
} }
@ -1719,7 +1717,7 @@ class QuerySet(object):
def __repr__(self): def __repr__(self):
limit = REPR_OUTPUT_SIZE + 1 limit = REPR_OUTPUT_SIZE + 1
start = ( 0 if self._skip is None else self._skip ) start = (0 if self._skip is None else self._skip)
if self._limit is None: if self._limit is None:
stop = start + limit stop = start + limit
if self._limit is not None: if self._limit is not None:
@ -1736,7 +1734,7 @@ class QuerySet(object):
return repr(data) return repr(data)
def select_related(self, max_depth=1): def select_related(self, max_depth=1):
"""Handles dereferencing of :class:`~pymongo.dbref.DBRef` objects to """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to
a maximum depth in order to cut down the number queries to mongodb. a maximum depth in order to cut down the number queries to mongodb.
.. versionadded:: 0.5 .. versionadded:: 0.5

View File

@ -4,7 +4,7 @@ import pymongo
import mongoengine.connection import mongoengine.connection
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db, get_connection from mongoengine.connection import get_db, get_connection, ConnectionError
class ConnectionTest(unittest.TestCase): class ConnectionTest(unittest.TestCase):
@ -33,6 +33,15 @@ class ConnectionTest(unittest.TestCase):
def test_connect_uri(self): def test_connect_uri(self):
"""Ensure that the connect() method works properly with uri's """Ensure that the connect() method works properly with uri's
""" """
c = connect(db='mongoenginetest', alias='admin')
c.admin.system.users.remove({})
c.mongoenginetest.system.users.remove({})
c.admin.add_user("admin", "password")
c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password")
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost/mongoenginetest')
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')

View File

@ -1,5 +1,6 @@
import pickle import pickle
import pymongo import pymongo
import bson
import unittest import unittest
import warnings import warnings
@ -2222,7 +2223,7 @@ class DocumentTest(unittest.TestCase):
# Test laziness # Test laziness
self.assertTrue(isinstance(post_obj._data['author'], self.assertTrue(isinstance(post_obj._data['author'],
pymongo.dbref.DBRef)) bson.DBRef))
self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertTrue(isinstance(post_obj.author, self.Person))
self.assertEqual(post_obj.author.name, 'Test User') self.assertEqual(post_obj.author.name, 'Test User')

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import unittest import unittest
import pymongo import pymongo
from bson import ObjectId
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mongoengine.queryset import (QuerySet, QuerySetManager, from mongoengine.queryset import (QuerySet, QuerySetManager,
@ -59,8 +60,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(people), 2) self.assertEqual(len(people), 2)
results = list(people) results = list(people)
self.assertTrue(isinstance(results[0], self.Person)) self.assertTrue(isinstance(results[0], self.Person))
self.assertTrue(isinstance(results[0].id, (pymongo.objectid.ObjectId, self.assertTrue(isinstance(results[0].id, (ObjectId, str, unicode)))
str, unicode)))
self.assertEqual(results[0].name, "User A") self.assertEqual(results[0].name, "User A")
self.assertEqual(results[0].age, 20) self.assertEqual(results[0].age, 20)
self.assertEqual(results[1].name, "User B") self.assertEqual(results[1].name, "User B")