Updated for pymongo

This commit is contained in:
Ross Lawley 2012-07-20 16:17:35 +01:00
parent 69989365c7
commit 3da37fbf6e
7 changed files with 158 additions and 138 deletions

View File

@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ +
__author__ = 'Harry Marr' __author__ = 'Harry Marr'
VERSION = (0, 4, 0) VERSION = (0, 4, 1)
def get_version(): def get_version():
version = '%s.%s' % (VERSION[0], VERSION[1]) version = '%s.%s' % (VERSION[0], VERSION[1])

View File

@ -2,8 +2,8 @@ from queryset import QuerySet, QuerySetManager
from queryset import DoesNotExist, MultipleObjectsReturned from queryset import DoesNotExist, MultipleObjectsReturned
import sys import sys
import bson
import pymongo import pymongo
import pymongo.objectid
_document_registry = {} _document_registry = {}
@ -111,9 +111,9 @@ class ObjectIdField(BaseField):
# return unicode(value) # return unicode(value)
def to_mongo(self, value): def to_mongo(self, value):
if not isinstance(value, pymongo.objectid.ObjectId): if not isinstance(value, bson.objectid.ObjectId):
try: try:
return pymongo.objectid.ObjectId(unicode(value)) return bson.objectid.ObjectId(unicode(value))
except Exception, e: except Exception, e:
#e.message attribute has been deprecated since Python 2.6 #e.message attribute has been deprecated since Python 2.6
raise ValidationError(unicode(e)) raise ValidationError(unicode(e))
@ -124,7 +124,7 @@ class ObjectIdField(BaseField):
def validate(self, value): def validate(self, value):
try: try:
pymongo.objectid.ObjectId(unicode(value)) bson.objectid.ObjectId(unicode(value))
except: except:
raise ValidationError('Invalid Object ID') raise ValidationError('Invalid Object ID')

View File

@ -124,7 +124,7 @@ class MapReduceDocument(object):
:param collection: An instance of :class:`~pymongo.Collection` :param collection: An instance of :class:`~pymongo.Collection`
:param key: Document/result key, often an instance of :param key: Document/result key, often an instance of
:class:`~pymongo.objectid.ObjectId`. If supplied as :class:`~bson.objectid.ObjectId`. If supplied as
an ``ObjectId`` found in the given ``collection``, an ``ObjectId`` found in the given ``collection``,
the object can be accessed via the ``object`` property. the object can be accessed via the ``object`` property.
:param value: The result(s) for this key. :param value: The result(s) for this key.

View File

@ -5,9 +5,9 @@ from operator import itemgetter
import re import re
import pymongo import pymongo
import pymongo.dbref import bson.dbref
import pymongo.son import bson.son
import pymongo.binary import bson.binary
import datetime import datetime
import decimal import decimal
import gridfs import gridfs
@ -306,7 +306,7 @@ class ListField(BaseField):
deref_list = [] deref_list = []
for value in value_list: for value in value_list:
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)): if isinstance(value, (bson.dbref.DBRef)):
value = _get_db().dereference(value) value = _get_db().dereference(value)
deref_list.append(referenced_type._from_son(value)) deref_list.append(referenced_type._from_son(value))
else: else:
@ -319,7 +319,7 @@ class ListField(BaseField):
deref_list = [] deref_list = []
for value in value_list: for value in value_list:
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, (dict, pymongo.son.SON)): if isinstance(value, (dict, bson.son.SON)):
deref_list.append(self.field.dereference(value)) deref_list.append(self.field.dereference(value))
else: else:
deref_list.append(value) deref_list.append(value)
@ -444,7 +444,7 @@ class ReferenceField(BaseField):
# Get value from document instance if available # Get value from document instance if available
value = instance._data.get(self.name) value = instance._data.get(self.name)
# Dereference DBRefs # Dereference DBRefs
if isinstance(value, (pymongo.dbref.DBRef)): if isinstance(value, (bson.dbref.DBRef)):
value = _get_db().dereference(value) value = _get_db().dereference(value)
if value is not None: if value is not None:
instance._data[self.name] = self.document_type._from_son(value) instance._data[self.name] = self.document_type._from_son(value)
@ -466,13 +466,13 @@ class ReferenceField(BaseField):
id_ = id_field.to_mongo(id_) id_ = id_field.to_mongo(id_)
collection = self.document_type._meta['collection'] collection = self.document_type._meta['collection']
return pymongo.dbref.DBRef(collection, id_) return bson.dbref.DBRef(collection, id_)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
def validate(self, value): def validate(self, value):
assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) assert isinstance(value, (self.document_type, bson.dbref.DBRef))
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
@ -490,7 +490,7 @@ class GenericReferenceField(BaseField):
return self return self
value = instance._data.get(self.name) value = instance._data.get(self.name)
if isinstance(value, (dict, pymongo.son.SON)): if isinstance(value, (dict, bson.son.SON)):
instance._data[self.name] = self.dereference(value) instance._data[self.name] = self.dereference(value)
return super(GenericReferenceField, self).__get__(instance, owner) return super(GenericReferenceField, self).__get__(instance, owner)
@ -518,7 +518,7 @@ class GenericReferenceField(BaseField):
id_ = id_field.to_mongo(id_) id_ = id_field.to_mongo(id_)
collection = document._meta['collection'] collection = document._meta['collection']
ref = pymongo.dbref.DBRef(collection, id_) ref = bson.dbref.DBRef(collection, id_)
return {'_cls': document.__class__.__name__, '_ref': ref} return {'_cls': document.__class__.__name__, '_ref': ref}
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
@ -534,7 +534,7 @@ class BinaryField(BaseField):
super(BinaryField, self).__init__(**kwargs) super(BinaryField, self).__init__(**kwargs)
def to_mongo(self, value): def to_mongo(self, value):
return pymongo.binary.Binary(value) return bson.binary.Binary(value)
def to_python(self, value): def to_python(self, value):
# Returns str not unicode as this is binary data # Returns str not unicode as this is binary data
@ -680,7 +680,7 @@ class FileField(BaseField):
def validate(self, value): def validate(self, value):
if value.grid_id is not None: if value.grid_id is not None:
assert isinstance(value, GridFSProxy) assert isinstance(value, GridFSProxy)
assert isinstance(value.grid_id, pymongo.objectid.ObjectId) assert isinstance(value.grid_id, bson.objectid.ObjectId)
class GeoPointField(BaseField): class GeoPointField(BaseField):

View File

@ -2,9 +2,9 @@ from connection import _get_db
import pprint import pprint
import pymongo import pymongo
import pymongo.code import bson.code
import pymongo.dbref import bson.dbref
import pymongo.objectid import bson.objectid
import re import re
import copy import copy
import itertools import itertools
@ -667,8 +667,8 @@ class QuerySet(object):
def __len__(self): def __len__(self):
return self.count() return self.count()
def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None, def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
scope=None, keep_temp=False): scope=None):
"""Perform a map/reduce query using the current query spec """Perform a map/reduce query using the current query spec
and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, and ordering. While ``map_reduce`` respects ``QuerySet`` chaining,
it must be the last call made, as it does not return a maleable it must be the last call made, as it does not return a maleable
@ -678,52 +678,61 @@ class QuerySet(object):
and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced`
tests in ``tests.queryset.QuerySetTest`` for usage examples. tests in ``tests.queryset.QuerySetTest`` for usage examples.
:param map_f: map function, as :class:`~pymongo.code.Code` or string :param map_f: map function, as :class:`~bson.code.Code` or string
:param reduce_f: reduce function, as :param reduce_f: reduce function, as
:class:`~pymongo.code.Code` or string :class:`~bson.code.Code` or string
:param output: output collection name, if set to 'inline' will try to
use :class:`~pymongo.collection.Collection.inline_map_reduce`
This can also be a dictionary containing output options
see: http://docs.mongodb.org/manual/reference/commands/#mapReduce
:param finalize_f: finalize function, an optional function that :param finalize_f: finalize function, an optional function that
performs any post-reduction processing. performs any post-reduction processing.
:param scope: values to insert into map/reduce global scope. Optional. :param scope: values to insert into map/reduce global scope. Optional.
:param limit: number of objects from current query to provide :param limit: number of objects from current query to provide
to map/reduce method to map/reduce method
:param keep_temp: keep temporary table (boolean, default ``True``)
Returns an iterator yielding Returns an iterator yielding
:class:`~mongoengine.document.MapReduceDocument`. :class:`~mongoengine.document.MapReduceDocument`.
.. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo .. note::
Map/Reduce changed in server version **>= 1.7.4**. The PyMongo
:meth:`~pymongo.collection.Collection.map_reduce` helper requires :meth:`~pymongo.collection.Collection.map_reduce` helper requires
PyMongo version **>= 1.2**. PyMongo version **>= 1.11**.
.. versionchanged:: 0.5
- removed ``keep_temp`` keyword argument, which was only relevant
for MongoDB server versions older than 1.7.4
.. versionadded:: 0.3 .. versionadded:: 0.3
""" """
from document import MapReduceDocument from document import MapReduceDocument
if not hasattr(self._collection, "map_reduce"): if not hasattr(self._collection, "map_reduce"):
raise NotImplementedError("Requires MongoDB >= 1.1.1") raise NotImplementedError("Requires MongoDB >= 1.7.1")
map_f_scope = {} map_f_scope = {}
if isinstance(map_f, pymongo.code.Code): if isinstance(map_f, bson.code.Code):
map_f_scope = map_f.scope map_f_scope = map_f.scope
map_f = unicode(map_f) map_f = unicode(map_f)
map_f = pymongo.code.Code(self._sub_js_fields(map_f), map_f_scope) map_f = bson.code.Code(self._sub_js_fields(map_f), map_f_scope)
reduce_f_scope = {} reduce_f_scope = {}
if isinstance(reduce_f, pymongo.code.Code): if isinstance(reduce_f, bson.code.Code):
reduce_f_scope = reduce_f.scope reduce_f_scope = reduce_f.scope
reduce_f = unicode(reduce_f) reduce_f = unicode(reduce_f)
reduce_f_code = self._sub_js_fields(reduce_f) reduce_f_code = self._sub_js_fields(reduce_f)
reduce_f = pymongo.code.Code(reduce_f_code, reduce_f_scope) reduce_f = bson.code.Code(reduce_f_code, reduce_f_scope)
mr_args = {'query': self._query, 'keeptemp': keep_temp} mr_args = {'query': self._query}
if finalize_f: if finalize_f:
finalize_f_scope = {} finalize_f_scope = {}
if isinstance(finalize_f, pymongo.code.Code): if isinstance(finalize_f, bson.code.Code):
finalize_f_scope = finalize_f.scope finalize_f_scope = finalize_f.scope
finalize_f = unicode(finalize_f) finalize_f = unicode(finalize_f)
finalize_f_code = self._sub_js_fields(finalize_f) finalize_f_code = self._sub_js_fields(finalize_f)
finalize_f = pymongo.code.Code(finalize_f_code, finalize_f_scope) finalize_f = bson.code.Code(finalize_f_code, finalize_f_scope)
mr_args['finalize'] = finalize_f mr_args['finalize'] = finalize_f
if scope: if scope:
@ -732,7 +741,15 @@ class QuerySet(object):
if limit: if limit:
mr_args['limit'] = limit mr_args['limit'] = limit
results = self._collection.map_reduce(map_f, reduce_f, **mr_args) if output == 'inline' and not self._ordering:
map_reduce_function = 'inline_map_reduce'
else:
map_reduce_function = 'map_reduce'
mr_args['out'] = output
results = getattr(self._collection, map_reduce_function)(map_f, reduce_f, **mr_args)
if map_reduce_function == 'map_reduce':
results = results.find() results = results.find()
if self._ordering: if self._ordering:
@ -1037,7 +1054,7 @@ class QuerySet(object):
query['$where'] = self._where_clause query['$where'] = self._where_clause
scope['query'] = query scope['query'] = query
code = pymongo.code.Code(code, scope=scope) code = bson.code.Code(code, scope=scope)
db = _get_db() db = _get_db()
return db.eval(code, *fields) return db.eval(code, *fields)

View File

@ -1,5 +1,6 @@
import unittest import unittest
from datetime import datetime from datetime import datetime
import bson
import pymongo import pymongo
from mongoengine import * from mongoengine import *
@ -611,7 +612,7 @@ class DocumentTest(unittest.TestCase):
# Test laziness # Test laziness
self.assertTrue(isinstance(post_obj._data['author'], self.assertTrue(isinstance(post_obj._data['author'],
pymongo.dbref.DBRef)) bson.dbref.DBRef))
self.assertTrue(isinstance(post_obj.author, self.Person)) self.assertTrue(isinstance(post_obj.author, self.Person))
self.assertEqual(post_obj.author.name, 'Test User') self.assertEqual(post_obj.author.name, 'Test User')

View File

@ -3,6 +3,7 @@
import unittest import unittest
import pymongo import pymongo
import bson
from datetime import datetime, timedelta from datetime import datetime, timedelta
from mongoengine.queryset import (QuerySet, MultipleObjectsReturned, from mongoengine.queryset import (QuerySet, MultipleObjectsReturned,
@ -58,7 +59,7 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(len(people), 2) self.assertEqual(len(people), 2)
results = list(people) results = list(people)
self.assertTrue(isinstance(results[0], self.Person)) self.assertTrue(isinstance(results[0], self.Person))
self.assertTrue(isinstance(results[0].id, (pymongo.objectid.ObjectId, self.assertTrue(isinstance(results[0].id, (bson.objectid.ObjectId,
str, unicode))) str, unicode)))
self.assertEqual(results[0].name, "User A") self.assertEqual(results[0].name, "User A")
self.assertEqual(results[0].age, 20) self.assertEqual(results[0].age, 20)
@ -802,7 +803,7 @@ class QuerySetTest(unittest.TestCase):
""" """
# run a map/reduce operation spanning all posts # run a map/reduce operation spanning all posts
results = BlogPost.objects.map_reduce(map_f, reduce_f) results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results) results = list(results)
self.assertEqual(len(results), 4) self.assertEqual(len(results), 4)
@ -851,7 +852,7 @@ class QuerySetTest(unittest.TestCase):
} }
""" """
results = BlogPost.objects.map_reduce(map_f, reduce_f) results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results) results = list(results)
self.assertEqual(results[0].object, post1) self.assertEqual(results[0].object, post1)
@ -962,6 +963,7 @@ class QuerySetTest(unittest.TestCase):
results = Link.objects.order_by("-value") results = Link.objects.order_by("-value")
results = results.map_reduce(map_f, results = results.map_reduce(map_f,
reduce_f, reduce_f,
"myresults",
finalize_f=finalize_f, finalize_f=finalize_f,
scope=scope) scope=scope)
results = list(results) results = list(results)