fixed merge conflict in queryset.py, used hmarr's code

This commit is contained in:
blackbrrr 2010-01-07 20:23:11 -06:00
commit 2a7fc03e79
18 changed files with 365 additions and 49 deletions

View File

@ -38,6 +38,8 @@ Fields
.. autoclass:: mongoengine.FloatField
.. autoclass:: mongoengine.BooleanField
.. autoclass:: mongoengine.DateTimeField
.. autoclass:: mongoengine.EmbeddedDocumentField

View File

@ -2,6 +2,23 @@
Changelog
=========
Changes is v0.1.3
=================
- Added Django authentication backend
- Added Document.meta support for indexes, which are ensured just before
querying takes place
- A few minor bugfixes
Changes in v0.1.2
=================
- Query values may be processed before before being used in queries
- Made connections lazy
- Fixed bug in Document dictionary-style access
- Added BooleanField
- Added Document.reload method
Changes in v0.1.1
=================
- Documents may now use capped collections

View File

@ -25,7 +25,7 @@ sys.path.append(os.path.abspath('..'))
extensions = ['sphinx.ext.autodoc']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['.templates']
templates_path = ['_templates']
# The suffix of source filenames.
source_suffix = '.rst'

29
docs/django.rst Normal file
View File

@ -0,0 +1,29 @@
=============================
Using MongoEngine with Django
=============================
Connecting
==========
In your **settings.py** file, ignore the standard database settings (unless you
also plan to use the ORM in your project), and instead call
:func:`~mongoengine.connect` somewhere in the settings module.
Authentication
==============
MongoEngine includes a Django authentication backend, which uses MongoDB. The
:class:`~mongoengine.django.auth.User` model is a MongoEngine
:class:`~mongoengine.Document`, but implements most of the methods and
attributes that the standard Django :class:`User` model does - so the two are
moderately compatible. Using this backend will allow you to store users in
MongoDB but still use many of the Django authentication infrastucture (such as
the :func:`login_required` decorator and the :func:`authenticate` function). To
enable the MongoEngine auth backend, add the following to you **settings.py**
file::
AUTHENTICATION_BACKENDS = (
'mongoengine.django.auth.MongoEngineBackend',
)
The :mod:`~mongoengine.django.auth` module also contains a
:func:`~mongoengine.django.auth.get_user` helper function, that takes a user's
:attr:`id` and returns a :class:`~mongoengine.django.auth.User` object.

View File

@ -16,6 +16,7 @@ The source is available on `GitHub <http://github.com/hmarr/mongoengine>`_.
tutorial
userguide
apireference
django
changelog
Indices and tables

View File

@ -2,8 +2,6 @@
User Guide
==========
.. _guide-connecting:
Installing
==========
MongoEngine is available on PyPI, so to use it you can use
@ -20,6 +18,8 @@ Alternatively, if you don't have setuptools installed, `download it from PyPi
# python setup.py install
.. _guide-connecting:
Connecting to MongoDB
=====================
To connect to a running instance of :program:`mongod`, use the
@ -168,6 +168,22 @@ The following example shows a :class:`Log` document that will be limited to
ip_address = StringField()
meta = {'max_documents': 1000, 'max_size': 2000000}
Indexes
-------
You can specify indexes on collections to make querying faster. This is done
by creating a list of index specifications called :attr:`indexes` in the
:attr:`~Document.meta` dictionary, where an index specification may either be
a single field name, or a tuple containing multiple field names. A direction
may be specified on fields by prefixing the field name with a **+** or a **-**
sign. Note that direction only matters on multi-field indexes. ::
class Page(Document):
title = StringField()
rating = StringField()
meta = {
'indexes': ['title', ('title', '-rating')]
}
Document inheritance
--------------------
To create a specialised type of a :class:`~mongoengine.Document` you have

View File

@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ +
__author__ = 'Harry Marr'
VERSION = (0, 1, 1)
VERSION = (0, 1, 3)
def get_version():
version = '%s.%s' % (VERSION[0], VERSION[1])

View File

@ -49,6 +49,11 @@ class BaseField(object):
"""
return self.to_python(value)
def prepare_query_value(self, value):
"""Prepare a value that is being used in a query for PyMongo.
"""
return value
def validate(self, value):
"""Perform validation on a value.
"""
@ -67,6 +72,9 @@ class ObjectIdField(BaseField):
return pymongo.objectid.ObjectId(value)
return value
def prepare_query_value(self, value):
return self.to_mongo(value)
def validate(self, value):
try:
pymongo.objectid.ObjectId(str(value))
@ -199,17 +207,17 @@ class BaseDocument(object):
return all_subclasses
def __iter__(self):
# Use _data rather than _fields as iterator only looks at names so
# values don't need to be converted to Python types
return iter(self._data)
return iter(self._fields)
def __getitem__(self, name):
"""Dictionary-style field access, return a field's value if present.
"""
try:
return getattr(self, name)
if name in self._fields:
return getattr(self, name)
except AttributeError:
raise KeyError(name)
pass
raise KeyError(name)
def __setitem__(self, name, value):
"""Dictionary-style field access, set a field's value.

View File

@ -10,6 +10,10 @@ _connection_settings = {
'pool_size': 1,
}
_connection = None
_db_name = None
_db_username = None
_db_password = None
_db = None
@ -19,14 +23,30 @@ class ConnectionError(Exception):
def _get_connection():
global _connection
# Connect to the database if not already connected
if _connection is None:
_connection = Connection(**_connection_settings)
try:
_connection = Connection(**_connection_settings)
except:
raise ConnectionError('Cannot connect to the database')
return _connection
def _get_db():
global _db
global _db, _connection
# Connect if not already connected
if _connection is None:
_connection = _get_connection()
if _db is None:
raise ConnectionError('Not connected to database')
# _db_name will be None if the user hasn't called connect()
if _db_name is None:
raise ConnectionError('Not connected to the database')
# Get DB from current connection and authenticate if necessary
_db = _connection[_db_name]
if _db_username and _db_password:
_db.authenticate(_db_username, _db_password)
return _db
def connect(db, username=None, password=None, **kwargs):
@ -35,12 +55,8 @@ def connect(db, username=None, password=None, **kwargs):
the default port on localhost. If authentication is needed, provide
username and password arguments as well.
"""
global _db
global _connection_settings, _db_name, _db_username, _db_password
_connection_settings.update(kwargs)
connection = _get_connection()
# Get DB from connection and auth if necessary
_db = connection[db]
if username is not None and password is not None:
_db.authenticate(username, password)
_db_name = db
_db_username = username
_db_password = password

View File

View File

@ -0,0 +1,99 @@
from mongoengine import *
from django.utils.hashcompat import md5_constructor, sha_constructor
from django.utils.encoding import smart_str
from django.contrib.auth.models import AnonymousUser
import datetime
REDIRECT_FIELD_NAME = 'next'
def get_hexdigest(algorithm, salt, raw_password):
raw_password, salt = smart_str(raw_password), smart_str(salt)
if algorithm == 'md5':
return md5_constructor(salt + raw_password).hexdigest()
elif algorithm == 'sha1':
return sha_constructor(salt + raw_password).hexdigest()
raise ValueError('Got unknown password algorithm type in password')
class User(Document):
"""A User document that aims to mirror most of the API specified by Django
at http://docs.djangoproject.com/en/dev/topics/auth/#users
"""
username = StringField(max_length=30, required=True)
first_name = StringField(max_length=30)
last_name = StringField(max_length=30)
email = StringField()
password = StringField(max_length=128)
is_staff = BooleanField(default=False)
is_active = BooleanField(default=True)
is_superuser = BooleanField(default=False)
last_login = DateTimeField(default=datetime.datetime.now)
def get_full_name(self):
"""Returns the users first and last names, separated by a space.
"""
full_name = u'%s %s' % (self.first_name or '', self.last_name or '')
return full_name.strip()
def is_anonymous(self):
return False
def is_authenticated(self):
return True
def set_password(self, raw_password):
"""Sets the user's password - always use this rather than directly
assigning to :attr:`~mongoengine.django.auth.User.password` as the
password is hashed before storage.
"""
from random import random
algo = 'sha1'
salt = get_hexdigest(algo, str(random()), str(random()))[:5]
hash = get_hexdigest(algo, salt, raw_password)
self.password = '%s$%s$%s' % (algo, salt, hash)
def check_password(self, raw_password):
"""Checks the user's password against a provided password - always use
this rather than directly comparing to
:attr:`~mongoengine.django.auth.User.password` as the password is
hashed before storage.
"""
algo, salt, hash = self.password.split('$')
return hash == get_hexdigest(algo, salt, raw_password)
@classmethod
def create_user(cls, username, password, email=None):
"""Create (and save) a new user with the given username, password and
email address.
"""
user = User(username=username, email=email)
user.set_password(password)
user.save()
return user
class MongoEngineBackend(object):
"""Authenticate using MongoEngine and mongoengine.django.auth.User.
"""
def authenticate(self, username=None, password=None):
user = User.objects(username=username).first()
if user:
if password and user.check_password(password):
return user
return None
def get_user(self, user_id):
return User.objects.with_id(user_id)
def get_user(userid):
"""Returns a User object from an id (User.id). Django's equivalent takes
request, but taking an id instead leaves it up to the developer to store
the id in any way they want (session, signed cookie, etc.)
"""
if not userid:
return AnonymousUser()
return MongoEngineBackend().get_user(userid) or AnonymousUser()

View File

@ -44,6 +44,11 @@ class Document(BaseDocument):
maximum size of the collection in bytes. If :attr:`max_size` is not
specified and :attr:`max_documents` is, :attr:`max_size` defaults to
10000000 bytes (10MB).
Indexes may be created by specifying :attr:`indexes` in the :attr:`meta`
dictionary. The value should be a list of field names or tuples of field
names. Index direction may be specified by prefixing the field names with
a **+** or **-** sign.
"""
__metaclass__ = TopLevelDocumentMetaclass
@ -67,8 +72,7 @@ class Document(BaseDocument):
def reload(self):
"""Reloads all attributes from the database.
"""
object_id = self._fields['id'].to_mongo(self.id)
obj = self.__class__.objects(id=object_id).first()
obj = self.__class__.objects(id=self.id).first()
for field in self._fields:
setattr(self, field, getattr(obj, field))

View File

@ -7,9 +7,9 @@ import pymongo
import datetime
__all__ = ['StringField', 'IntField', 'FloatField', 'DateTimeField',
'EmbeddedDocumentField', 'ListField', 'ObjectIdField',
'ReferenceField', 'ValidationError']
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField',
'ObjectIdField', 'ReferenceField', 'ValidationError']
class StringField(BaseField):
@ -25,7 +25,7 @@ class StringField(BaseField):
return unicode(value)
def validate(self, value):
assert(isinstance(value, (str, unicode)))
assert isinstance(value, (str, unicode))
if self.max_length is not None and len(value) > self.max_length:
raise ValidationError('String value is too long')
@ -50,7 +50,7 @@ class IntField(BaseField):
return int(value)
def validate(self, value):
assert(isinstance(value, (int, long)))
assert isinstance(value, (int, long))
if self.min_value is not None and value < self.min_value:
raise ValidationError('Integer value is too small')
@ -71,7 +71,7 @@ class FloatField(BaseField):
return float(value)
def validate(self, value):
assert(isinstance(value, float))
assert isinstance(value, float)
if self.min_value is not None and value < self.min_value:
raise ValidationError('Float value is too small')
@ -80,12 +80,23 @@ class FloatField(BaseField):
raise ValidationError('Float value is too large')
class BooleanField(BaseField):
"""A boolean field type.
"""
def to_python(self, value):
return bool(value)
def validate(self, value):
assert isinstance(value, bool)
class DateTimeField(BaseField):
"""A datetime field.
"""
def validate(self, value):
assert(isinstance(value, datetime.datetime))
assert isinstance(value, datetime.datetime)
class EmbeddedDocumentField(BaseField):
@ -188,21 +199,27 @@ class ReferenceField(BaseField):
def to_mongo(self, document):
if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)):
# document may already be an object id
id_ = document
else:
# We need the id from the saved object to create the DBRef
id_ = document.id
if id_ is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
# id may be a string rather than an ObjectID object
if not isinstance(id_, pymongo.objectid.ObjectId):
id_ = pymongo.objectid.ObjectId(id_)
collection = self.document_type._meta['collection']
return pymongo.dbref.DBRef(collection, id_)
def prepare_query_value(self, value):
return self.to_mongo(value)
def validate(self, value):
assert(isinstance(value, (self.document_type, pymongo.dbref.DBRef)))
assert isinstance(value, (self.document_type, pymongo.dbref.DBRef))
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)

View File

@ -26,23 +26,28 @@ class QuerySet(object):
self._query = {'_types': self._document._class_name}
self._cursor_obj = None
def ensure_index(self, key_or_list, direction=None):
def ensure_index(self, key_or_list):
"""Ensure that the given indexes are in place.
"""
if isinstance(key_or_list, basestring):
# single-field indexes needn't specify a direction
if key_or_list.startswith("-") or key_or_list.startswith("+"):
if key_or_list.startswith(("-", "+")):
key_or_list = key_or_list[1:]
self._collection.ensure_index(key_or_list)
# Use real field name
key = QuerySet._translate_field_name(self._document, key_or_list)
self._collection.ensure_index(key)
elif isinstance(key_or_list, (list, tuple)):
index_list = []
for key in key_or_list:
# Get direction from + or -
direction = pymongo.ASCENDING
if key.startswith("-"):
index_list.append((key[1:], pymongo.DESCENDING))
else:
if key.startswith("+"):
direction = pymongo.DESCENDING
if key.startswith(("+", "-")):
key = key[1:]
index_list.append((key, pymongo.ASCENDING))
# Use real field name
key = QuerySet._translate_field_name(self._document, key)
index_list.append((key, direction))
self._collection.ensure_index(index_list)
return self
@ -68,12 +73,13 @@ class QuerySet(object):
return self._cursor_obj
@classmethod
def _translate_field_name(cls, document, parts):
"""Translate a field attribute name to a database field name.
def _lookup_field(cls, document, parts):
"""Lookup a field based on its attribute and return a list containing
the field's parents and the field.
"""
if not isinstance(parts, (list, tuple)):
parts = [parts]
field_names = []
fields = []
field = None
for field_name in parts:
if field is None:
@ -85,9 +91,17 @@ class QuerySet(object):
if field is None:
raise InvalidQueryError('Cannot resolve field "%s"'
% field_name)
field_names.append(field.name)
return field_names
fields.append(field)
return fields
@classmethod
def _translate_field_name(cls, doc_cls, field, sep='.'):
"""Translate a field attribute name to a database field name.
"""
parts = field.split(sep)
parts = [f.name for f in QuerySet._lookup_field(doc_cls, parts)]
return '.'.join(parts)
@classmethod
def _transform_query(cls, _doc_cls=None, **query):
"""Transform a query from Django-style format to Mongo format.
@ -102,11 +116,22 @@ class QuerySet(object):
op = None
if parts[-1] in operators:
op = parts.pop()
value = {'$' + op: value}
# Switch field names to proper names [set in Field(name='foo')]
if _doc_cls:
parts = QuerySet._translate_field_name(_doc_cls, parts)
# Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [field.name for field in fields]
# Convert value to proper value
field = fields[-1]
if op in (None, 'neq', 'gt', 'gte', 'lt', 'lte'):
value = field.prepare_query_value(value)
elif op in ('in', 'nin', 'all'):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(v) for v in value]
if op:
value = {'$' + op: value}
key = '.'.join(parts)
if op is None or key not in mongo_query:
@ -129,7 +154,7 @@ class QuerySet(object):
"""Retrieve the object matching the id provided.
"""
if not isinstance(object_id, pymongo.objectid.ObjectId):
object_id = pymongo.objectid.ObjectId(object_id)
object_id = pymongo.objectid.ObjectId(str(object_id))
result = self._collection.find_one(object_id)
if result is not None:

View File

@ -1,6 +1,5 @@
from setuptools import setup, find_packages
VERSION = '0.1.1'
import os
DESCRIPTION = "A Python Document-Object Mapper for working with MongoDB"
@ -10,6 +9,20 @@ try:
except:
pass
def get_version(version_tuple):
version = '%s.%s' % (version_tuple[0], version_tuple[1])
if version_tuple[2]:
version = '%s.%s' % (version, version_tuple[2])
return version
# Dirty hack to get version number from monogengine/__init__.py - we can't
# import it as it depends on PyMongo and PyMongo isn't installed until this
# file is read
init = os.path.join(os.path.dirname(__file__), 'mongoengine', '__init__.py')
version_line = filter(lambda l: l.startswith('VERSION'), open(init))[0]
VERSION = get_version(eval(version_line.split('=')[-1]))
print VERSION
CLASSIFIERS = [
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',

View File

@ -221,6 +221,32 @@ class DocumentTest(unittest.TestCase):
Log.drop_collection()
def test_indexes(self):
"""Ensure that indexes are used when meta[indexes] is specified.
"""
class BlogPost(Document):
date = DateTimeField(name='addDate', default=datetime.datetime.now)
category = StringField()
meta = {
'indexes': [
'-date',
('category', '-date')
],
}
BlogPost.drop_collection()
info = BlogPost.objects._collection.index_information()
self.assertEqual(len(info), 0)
BlogPost.objects()
info = BlogPost.objects._collection.index_information()
self.assertTrue([('category', 1), ('addDate', -1)] in info.values())
# Even though descending order was specified, single-key indexes use 1
self.assertTrue([('addDate', 1)] in info.values())
BlogPost.drop_collection()
def test_creation(self):
"""Ensure that document may be created using keyword arguments.
"""
@ -321,6 +347,8 @@ class DocumentTest(unittest.TestCase):
comments = ListField(EmbeddedDocumentField(Comment))
tags = ListField(StringField())
BlogPost.drop_collection()
post = BlogPost(content='Went for a walk today...')
post.tags = tags = ['fun', 'leisure']
comments = [Comment(content='Good for you'), Comment(content='Yay.')]

View File

@ -113,6 +113,21 @@ class FieldTest(unittest.TestCase):
person.height = 4.0
self.assertRaises(ValidationError, person.validate)
def test_boolean_validation(self):
"""Ensure that invalid values cannot be assigned to boolean fields.
"""
class Person(Document):
admin = BooleanField()
person = Person()
person.admin = True
person.validate()
person.admin = 2
self.assertRaises(ValidationError, person.validate)
person.admin = 'Yes'
self.assertRaises(ValidationError, person.validate)
def test_datetime_validation(self):
"""Ensure that invalid values cannot be assigned to datetime fields.
"""

View File

@ -300,6 +300,32 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
def test_query_value_conversion(self):
"""Ensure that query values are properly converted when necessary.
"""
class BlogPost(Document):
author = ReferenceField(self.Person)
BlogPost.drop_collection()
person = self.Person(name='test', age=30)
person.save()
post = BlogPost(author=person)
post.save()
# Test that query may be performed by providing a document as a value
# while using a ReferenceField's name - the document should be
# converted to an DBRef, which is legal, unlike a Document object
post_obj = BlogPost.objects(author=person).first()
self.assertEqual(post.id, post_obj.id)
# Test that lists of values work when using the 'in', 'nin' and 'all'
post_obj = BlogPost.objects(author__in=[person]).first()
self.assertEqual(post.id, post_obj.id)
BlogPost.drop_collection()
def tearDown(self):
self.Person.drop_collection()