fixed merge conflict in queryset.py, used hmarr's code

This commit is contained in:
blackbrrr 2010-01-07 20:23:11 -06:00
commit 2a7fc03e79
18 changed files with 365 additions and 49 deletions

View File

@ -38,6 +38,8 @@ Fields
.. autoclass:: mongoengine.FloatField .. autoclass:: mongoengine.FloatField
.. autoclass:: mongoengine.BooleanField
.. autoclass:: mongoengine.DateTimeField .. autoclass:: mongoengine.DateTimeField
.. autoclass:: mongoengine.EmbeddedDocumentField .. autoclass:: mongoengine.EmbeddedDocumentField

View File

@ -2,6 +2,23 @@
Changelog Changelog
========= =========
Changes is v0.1.3
=================
- Added Django authentication backend
- Added Document.meta support for indexes, which are ensured just before
querying takes place
- A few minor bugfixes
Changes in v0.1.2
=================
- Query values may be processed before before being used in queries
- Made connections lazy
- Fixed bug in Document dictionary-style access
- Added BooleanField
- Added Document.reload method
Changes in v0.1.1 Changes in v0.1.1
================= =================
- Documents may now use capped collections - Documents may now use capped collections

View File

@ -25,7 +25,7 @@ sys.path.append(os.path.abspath('..'))
extensions = ['sphinx.ext.autodoc'] extensions = ['sphinx.ext.autodoc']
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['.templates'] templates_path = ['_templates']
# The suffix of source filenames. # The suffix of source filenames.
source_suffix = '.rst' source_suffix = '.rst'

29
docs/django.rst Normal file
View File

@ -0,0 +1,29 @@
=============================
Using MongoEngine with Django
=============================
Connecting
==========
In your **settings.py** file, ignore the standard database settings (unless you
also plan to use the ORM in your project), and instead call
:func:`~mongoengine.connect` somewhere in the settings module.
Authentication
==============
MongoEngine includes a Django authentication backend, which uses MongoDB. The
:class:`~mongoengine.django.auth.User` model is a MongoEngine
:class:`~mongoengine.Document`, but implements most of the methods and
attributes that the standard Django :class:`User` model does - so the two are
moderately compatible. Using this backend will allow you to store users in
MongoDB but still use many of the Django authentication infrastucture (such as
the :func:`login_required` decorator and the :func:`authenticate` function). To
enable the MongoEngine auth backend, add the following to you **settings.py**
file::
AUTHENTICATION_BACKENDS = (
'mongoengine.django.auth.MongoEngineBackend',
)
The :mod:`~mongoengine.django.auth` module also contains a
:func:`~mongoengine.django.auth.get_user` helper function, that takes a user's
:attr:`id` and returns a :class:`~mongoengine.django.auth.User` object.

View File

@ -16,6 +16,7 @@ The source is available on `GitHub <http://github.com/hmarr/mongoengine>`_.
tutorial tutorial
userguide userguide
apireference apireference
django
changelog changelog
Indices and tables Indices and tables

View File

@ -2,8 +2,6 @@
User Guide User Guide
========== ==========
.. _guide-connecting:
Installing Installing
========== ==========
MongoEngine is available on PyPI, so to use it you can use MongoEngine is available on PyPI, so to use it you can use
@ -20,6 +18,8 @@ Alternatively, if you don't have setuptools installed, `download it from PyPi
# python setup.py install # python setup.py install
.. _guide-connecting:
Connecting to MongoDB Connecting to MongoDB
===================== =====================
To connect to a running instance of :program:`mongod`, use the To connect to a running instance of :program:`mongod`, use the
@ -168,6 +168,22 @@ The following example shows a :class:`Log` document that will be limited to
ip_address = StringField() ip_address = StringField()
meta = {'max_documents': 1000, 'max_size': 2000000} meta = {'max_documents': 1000, 'max_size': 2000000}
Indexes
-------
You can specify indexes on collections to make querying faster. This is done
by creating a list of index specifications called :attr:`indexes` in the
:attr:`~Document.meta` dictionary, where an index specification may either be
a single field name, or a tuple containing multiple field names. A direction
may be specified on fields by prefixing the field name with a **+** or a **-**
sign. Note that direction only matters on multi-field indexes. ::
class Page(Document):
title = StringField()
rating = StringField()
meta = {
'indexes': ['title', ('title', '-rating')]
}
Document inheritance Document inheritance
-------------------- --------------------
To create a specialised type of a :class:`~mongoengine.Document` you have To create a specialised type of a :class:`~mongoengine.Document` you have

View File

@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ +
__author__ = 'Harry Marr' __author__ = 'Harry Marr'
VERSION = (0, 1, 1) VERSION = (0, 1, 3)
def get_version(): def get_version():
version = '%s.%s' % (VERSION[0], VERSION[1]) version = '%s.%s' % (VERSION[0], VERSION[1])

View File

@ -49,6 +49,11 @@ class BaseField(object):
""" """
return self.to_python(value) return self.to_python(value)
def prepare_query_value(self, value):
"""Prepare a value that is being used in a query for PyMongo.
"""
return value
def validate(self, value): def validate(self, value):
"""Perform validation on a value. """Perform validation on a value.
""" """
@ -67,6 +72,9 @@ class ObjectIdField(BaseField):
return pymongo.objectid.ObjectId(value) return pymongo.objectid.ObjectId(value)
return value return value
def prepare_query_value(self, value):
return self.to_mongo(value)
def validate(self, value): def validate(self, value):
try: try:
pymongo.objectid.ObjectId(str(value)) pymongo.objectid.ObjectId(str(value))
@ -199,16 +207,16 @@ class BaseDocument(object):
return all_subclasses return all_subclasses
def __iter__(self): def __iter__(self):
# Use _data rather than _fields as iterator only looks at names so return iter(self._fields)
# values don't need to be converted to Python types
return iter(self._data)
def __getitem__(self, name): def __getitem__(self, name):
"""Dictionary-style field access, return a field's value if present. """Dictionary-style field access, return a field's value if present.
""" """
try: try:
if name in self._fields:
return getattr(self, name) return getattr(self, name)
except AttributeError: except AttributeError:
pass
raise KeyError(name) raise KeyError(name)
def __setitem__(self, name, value): def __setitem__(self, name, value):

View File

@ -10,6 +10,10 @@ _connection_settings = {
'pool_size': 1, 'pool_size': 1,
} }
_connection = None _connection = None
_db_name = None
_db_username = None
_db_password = None
_db = None _db = None
@ -19,14 +23,30 @@ class ConnectionError(Exception):
def _get_connection(): def _get_connection():
global _connection global _connection
# Connect to the database if not already connected
if _connection is None: if _connection is None:
try:
_connection = Connection(**_connection_settings) _connection = Connection(**_connection_settings)
except:
raise ConnectionError('Cannot connect to the database')
return _connection return _connection
def _get_db(): def _get_db():
global _db global _db, _connection
# Connect if not already connected
if _connection is None:
_connection = _get_connection()
if _db is None: if _db is None:
raise ConnectionError('Not connected to database') # _db_name will be None if the user hasn't called connect()
if _db_name is None:
raise ConnectionError('Not connected to the database')
# Get DB from current connection and authenticate if necessary
_db = _connection[_db_name]
if _db_username and _db_password:
_db.authenticate(_db_username, _db_password)
return _db return _db
def connect(db, username=None, password=None, **kwargs): def connect(db, username=None, password=None, **kwargs):
@ -35,12 +55,8 @@ def connect(db, username=None, password=None, **kwargs):
the default port on localhost. If authentication is needed, provide the default port on localhost. If authentication is needed, provide
username and password arguments as well. username and password arguments as well.
""" """
global _db global _connection_settings, _db_name, _db_username, _db_password
_connection_settings.update(kwargs) _connection_settings.update(kwargs)
connection = _get_connection() _db_name = db
# Get DB from connection and auth if necessary _db_username = username
_db = connection[db] _db_password = password
if username is not None and password is not None:
_db.authenticate(username, password)

View File

View File

@ -0,0 +1,99 @@
from mongoengine import *
from django.utils.hashcompat import md5_constructor, sha_constructor
from django.utils.encoding import smart_str
from django.contrib.auth.models import AnonymousUser
import datetime
REDIRECT_FIELD_NAME = 'next'
def get_hexdigest(algorithm, salt, raw_password):
raw_password, salt = smart_str(raw_password), smart_str(salt)
if algorithm == 'md5':
return md5_constructor(salt + raw_password).hexdigest()
elif algorithm == 'sha1':
return sha_constructor(salt + raw_password).hexdigest()
raise ValueError('Got unknown password algorithm type in password')
class User(Document):
"""A User document that aims to mirror most of the API specified by Django
at http://docs.djangoproject.com/en/dev/topics/auth/#users
"""
username = StringField(max_length=30, required=True)
first_name = StringField(max_length=30)
last_name = StringField(max_length=30)
email = StringField()
password = StringField(max_length=128)
is_staff = BooleanField(default=False)
is_active = BooleanField(default=True)
is_superuser = BooleanField(default=False)
last_login = DateTimeField(default=datetime.datetime.now)
def get_full_name(self):
"""Returns the users first and last names, separated by a space.
"""
full_name = u'%s %s' % (self.first_name or '', self.last_name or '')
return full_name.strip()
def is_anonymous(self):
return False
def is_authenticated(self):
return True
def set_password(self, raw_password):
"""Sets the user's password - always use this rather than directly
assigning to :attr:`~mongoengine.django.auth.User.password` as the
password is hashed before storage.
"""
from random import random
algo = 'sha1'
salt = get_hexdigest(algo, str(random()), str(random()))[:5]
hash = get_hexdigest(algo, salt, raw_password)
self.password = '%s$%s$%s' % (algo, salt, hash)
def check_password(self, raw_password):
"""Checks the user's password against a provided password - always use
this rather than directly comparing to
:attr:`~mongoengine.django.auth.User.password` as the password is
hashed before storage.
"""
algo, salt, hash = self.password.split('$')
return hash == get_hexdigest(algo, salt, raw_password)
@classmethod
def create_user(cls, username, password, email=None):
"""Create (and save) a new user with the given username, password and
email address.
"""
user = User(username=username, email=email)
user.set_password(password)
user.save()
return user
class MongoEngineBackend(object):
"""Authenticate using MongoEngine and mongoengine.django.auth.User.
"""
def authenticate(self, username=None, password=None):
user = User.objects(username=username).first()
if user:
if password and user.check_password(password):
return user
return None
def get_user(self, user_id):
return User.objects.with_id(user_id)
def get_user(userid):
"""Returns a User object from an id (User.id). Django's equivalent takes
request, but taking an id instead leaves it up to the developer to store
the id in any way they want (session, signed cookie, etc.)
"""
if not userid:
return AnonymousUser()
return MongoEngineBackend().get_user(userid) or AnonymousUser()

View File

@ -44,6 +44,11 @@ class Document(BaseDocument):
maximum size of the collection in bytes. If :attr:`max_size` is not maximum size of the collection in bytes. If :attr:`max_size` is not
specified and :attr:`max_documents` is, :attr:`max_size` defaults to specified and :attr:`max_documents` is, :attr:`max_size` defaults to
10000000 bytes (10MB). 10000000 bytes (10MB).
Indexes may be created by specifying :attr:`indexes` in the :attr:`meta`
dictionary. The value should be a list of field names or tuples of field
names. Index direction may be specified by prefixing the field names with
a **+** or **-** sign.
""" """
__metaclass__ = TopLevelDocumentMetaclass __metaclass__ = TopLevelDocumentMetaclass
@ -67,8 +72,7 @@ class Document(BaseDocument):
def reload(self): def reload(self):
"""Reloads all attributes from the database. """Reloads all attributes from the database.
""" """
object_id = self._fields['id'].to_mongo(self.id) obj = self.__class__.objects(id=self.id).first()
obj = self.__class__.objects(id=object_id).first()
for field in self._fields: for field in self._fields:
setattr(self, field, getattr(obj, field)) setattr(self, field, getattr(obj, field))

View File

@ -7,9 +7,9 @@ import pymongo
import datetime import datetime
__all__ = ['StringField', 'IntField', 'FloatField', 'DateTimeField', __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'EmbeddedDocumentField', 'ListField', 'ObjectIdField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField',
'ReferenceField', 'ValidationError'] 'ObjectIdField', 'ReferenceField', 'ValidationError']
class StringField(BaseField): class StringField(BaseField):
@ -25,7 +25,7 @@ class StringField(BaseField):
return unicode(value) return unicode(value)
def validate(self, value): def validate(self, value):
assert(isinstance(value, (str, unicode))) assert isinstance(value, (str, unicode))
if self.max_length is not None and len(value) > self.max_length: if self.max_length is not None and len(value) > self.max_length:
raise ValidationError('String value is too long') raise ValidationError('String value is too long')
@ -50,7 +50,7 @@ class IntField(BaseField):
return int(value) return int(value)
def validate(self, value): def validate(self, value):
assert(isinstance(value, (int, long))) assert isinstance(value, (int, long))
if self.min_value is not None and value < self.min_value: if self.min_value is not None and value < self.min_value:
raise ValidationError('Integer value is too small') raise ValidationError('Integer value is too small')
@ -71,7 +71,7 @@ class FloatField(BaseField):
return float(value) return float(value)
def validate(self, value): def validate(self, value):
assert(isinstance(value, float)) assert isinstance(value, float)
if self.min_value is not None and value < self.min_value: if self.min_value is not None and value < self.min_value:
raise ValidationError('Float value is too small') raise ValidationError('Float value is too small')
@ -80,12 +80,23 @@ class FloatField(BaseField):
raise ValidationError('Float value is too large') raise ValidationError('Float value is too large')
class BooleanField(BaseField):
"""A boolean field type.
"""
def to_python(self, value):
return bool(value)
def validate(self, value):
assert isinstance(value, bool)
class DateTimeField(BaseField): class DateTimeField(BaseField):
"""A datetime field. """A datetime field.
""" """
def validate(self, value): def validate(self, value):
assert(isinstance(value, datetime.datetime)) assert isinstance(value, datetime.datetime)
class EmbeddedDocumentField(BaseField): class EmbeddedDocumentField(BaseField):
@ -188,21 +199,27 @@ class ReferenceField(BaseField):
def to_mongo(self, document): def to_mongo(self, document):
if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)): if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)):
# document may already be an object id
id_ = document id_ = document
else: else:
# We need the id from the saved object to create the DBRef
id_ = document.id id_ = document.id
if id_ is None: if id_ is None:
raise ValidationError('You can only reference documents once ' raise ValidationError('You can only reference documents once '
'they have been saved to the database') 'they have been saved to the database')
# id may be a string rather than an ObjectID object
if not isinstance(id_, pymongo.objectid.ObjectId): if not isinstance(id_, pymongo.objectid.ObjectId):
id_ = pymongo.objectid.ObjectId(id_) id_ = pymongo.objectid.ObjectId(id_)
collection = self.document_type._meta['collection'] collection = self.document_type._meta['collection']
return pymongo.dbref.DBRef(collection, id_) return pymongo.dbref.DBRef(collection, id_)
def prepare_query_value(self, value):
return self.to_mongo(value)
def validate(self, value): def validate(self, value):
assert(isinstance(value, (self.document_type, pymongo.dbref.DBRef))) assert isinstance(value, (self.document_type, pymongo.dbref.DBRef))
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)

View File

@ -26,23 +26,28 @@ class QuerySet(object):
self._query = {'_types': self._document._class_name} self._query = {'_types': self._document._class_name}
self._cursor_obj = None self._cursor_obj = None
def ensure_index(self, key_or_list, direction=None): def ensure_index(self, key_or_list):
"""Ensure that the given indexes are in place. """Ensure that the given indexes are in place.
""" """
if isinstance(key_or_list, basestring): if isinstance(key_or_list, basestring):
# single-field indexes needn't specify a direction # single-field indexes needn't specify a direction
if key_or_list.startswith("-") or key_or_list.startswith("+"): if key_or_list.startswith(("-", "+")):
key_or_list = key_or_list[1:] key_or_list = key_or_list[1:]
self._collection.ensure_index(key_or_list) # Use real field name
key = QuerySet._translate_field_name(self._document, key_or_list)
self._collection.ensure_index(key)
elif isinstance(key_or_list, (list, tuple)): elif isinstance(key_or_list, (list, tuple)):
index_list = [] index_list = []
for key in key_or_list: for key in key_or_list:
# Get direction from + or -
direction = pymongo.ASCENDING
if key.startswith("-"): if key.startswith("-"):
index_list.append((key[1:], pymongo.DESCENDING)) direction = pymongo.DESCENDING
else: if key.startswith(("+", "-")):
if key.startswith("+"):
key = key[1:] key = key[1:]
index_list.append((key, pymongo.ASCENDING)) # Use real field name
key = QuerySet._translate_field_name(self._document, key)
index_list.append((key, direction))
self._collection.ensure_index(index_list) self._collection.ensure_index(index_list)
return self return self
@ -68,12 +73,13 @@ class QuerySet(object):
return self._cursor_obj return self._cursor_obj
@classmethod @classmethod
def _translate_field_name(cls, document, parts): def _lookup_field(cls, document, parts):
"""Translate a field attribute name to a database field name. """Lookup a field based on its attribute and return a list containing
the field's parents and the field.
""" """
if not isinstance(parts, (list, tuple)): if not isinstance(parts, (list, tuple)):
parts = [parts] parts = [parts]
field_names = [] fields = []
field = None field = None
for field_name in parts: for field_name in parts:
if field is None: if field is None:
@ -85,8 +91,16 @@ class QuerySet(object):
if field is None: if field is None:
raise InvalidQueryError('Cannot resolve field "%s"' raise InvalidQueryError('Cannot resolve field "%s"'
% field_name) % field_name)
field_names.append(field.name) fields.append(field)
return field_names return fields
@classmethod
def _translate_field_name(cls, doc_cls, field, sep='.'):
"""Translate a field attribute name to a database field name.
"""
parts = field.split(sep)
parts = [f.name for f in QuerySet._lookup_field(doc_cls, parts)]
return '.'.join(parts)
@classmethod @classmethod
def _transform_query(cls, _doc_cls=None, **query): def _transform_query(cls, _doc_cls=None, **query):
@ -102,11 +116,22 @@ class QuerySet(object):
op = None op = None
if parts[-1] in operators: if parts[-1] in operators:
op = parts.pop() op = parts.pop()
value = {'$' + op: value}
# Switch field names to proper names [set in Field(name='foo')]
if _doc_cls: if _doc_cls:
parts = QuerySet._translate_field_name(_doc_cls, parts) # Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [field.name for field in fields]
# Convert value to proper value
field = fields[-1]
if op in (None, 'neq', 'gt', 'gte', 'lt', 'lte'):
value = field.prepare_query_value(value)
elif op in ('in', 'nin', 'all'):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(v) for v in value]
if op:
value = {'$' + op: value}
key = '.'.join(parts) key = '.'.join(parts)
if op is None or key not in mongo_query: if op is None or key not in mongo_query:
@ -129,7 +154,7 @@ class QuerySet(object):
"""Retrieve the object matching the id provided. """Retrieve the object matching the id provided.
""" """
if not isinstance(object_id, pymongo.objectid.ObjectId): if not isinstance(object_id, pymongo.objectid.ObjectId):
object_id = pymongo.objectid.ObjectId(object_id) object_id = pymongo.objectid.ObjectId(str(object_id))
result = self._collection.find_one(object_id) result = self._collection.find_one(object_id)
if result is not None: if result is not None:

View File

@ -1,6 +1,5 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
import os
VERSION = '0.1.1'
DESCRIPTION = "A Python Document-Object Mapper for working with MongoDB" DESCRIPTION = "A Python Document-Object Mapper for working with MongoDB"
@ -10,6 +9,20 @@ try:
except: except:
pass pass
def get_version(version_tuple):
version = '%s.%s' % (version_tuple[0], version_tuple[1])
if version_tuple[2]:
version = '%s.%s' % (version, version_tuple[2])
return version
# Dirty hack to get version number from monogengine/__init__.py - we can't
# import it as it depends on PyMongo and PyMongo isn't installed until this
# file is read
init = os.path.join(os.path.dirname(__file__), 'mongoengine', '__init__.py')
version_line = filter(lambda l: l.startswith('VERSION'), open(init))[0]
VERSION = get_version(eval(version_line.split('=')[-1]))
print VERSION
CLASSIFIERS = [ CLASSIFIERS = [
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',
'Intended Audience :: Developers', 'Intended Audience :: Developers',

View File

@ -221,6 +221,32 @@ class DocumentTest(unittest.TestCase):
Log.drop_collection() Log.drop_collection()
def test_indexes(self):
"""Ensure that indexes are used when meta[indexes] is specified.
"""
class BlogPost(Document):
date = DateTimeField(name='addDate', default=datetime.datetime.now)
category = StringField()
meta = {
'indexes': [
'-date',
('category', '-date')
],
}
BlogPost.drop_collection()
info = BlogPost.objects._collection.index_information()
self.assertEqual(len(info), 0)
BlogPost.objects()
info = BlogPost.objects._collection.index_information()
self.assertTrue([('category', 1), ('addDate', -1)] in info.values())
# Even though descending order was specified, single-key indexes use 1
self.assertTrue([('addDate', 1)] in info.values())
BlogPost.drop_collection()
def test_creation(self): def test_creation(self):
"""Ensure that document may be created using keyword arguments. """Ensure that document may be created using keyword arguments.
""" """
@ -321,6 +347,8 @@ class DocumentTest(unittest.TestCase):
comments = ListField(EmbeddedDocumentField(Comment)) comments = ListField(EmbeddedDocumentField(Comment))
tags = ListField(StringField()) tags = ListField(StringField())
BlogPost.drop_collection()
post = BlogPost(content='Went for a walk today...') post = BlogPost(content='Went for a walk today...')
post.tags = tags = ['fun', 'leisure'] post.tags = tags = ['fun', 'leisure']
comments = [Comment(content='Good for you'), Comment(content='Yay.')] comments = [Comment(content='Good for you'), Comment(content='Yay.')]

View File

@ -113,6 +113,21 @@ class FieldTest(unittest.TestCase):
person.height = 4.0 person.height = 4.0
self.assertRaises(ValidationError, person.validate) self.assertRaises(ValidationError, person.validate)
def test_boolean_validation(self):
"""Ensure that invalid values cannot be assigned to boolean fields.
"""
class Person(Document):
admin = BooleanField()
person = Person()
person.admin = True
person.validate()
person.admin = 2
self.assertRaises(ValidationError, person.validate)
person.admin = 'Yes'
self.assertRaises(ValidationError, person.validate)
def test_datetime_validation(self): def test_datetime_validation(self):
"""Ensure that invalid values cannot be assigned to datetime fields. """Ensure that invalid values cannot be assigned to datetime fields.
""" """

View File

@ -300,6 +300,32 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_query_value_conversion(self):
"""Ensure that query values are properly converted when necessary.
"""
class BlogPost(Document):
author = ReferenceField(self.Person)
BlogPost.drop_collection()
person = self.Person(name='test', age=30)
person.save()
post = BlogPost(author=person)
post.save()
# Test that query may be performed by providing a document as a value
# while using a ReferenceField's name - the document should be
# converted to an DBRef, which is legal, unlike a Document object
post_obj = BlogPost.objects(author=person).first()
self.assertEqual(post.id, post_obj.id)
# Test that lists of values work when using the 'in', 'nin' and 'all'
post_obj = BlogPost.objects(author__in=[person]).first()
self.assertEqual(post.id, post_obj.id)
BlogPost.drop_collection()
def tearDown(self): def tearDown(self):
self.Person.drop_collection() self.Person.drop_collection()