fixed merge conflict in queryset.py, used hmarr's code

This commit is contained in:
blackbrrr
2010-01-07 20:23:11 -06:00
18 changed files with 365 additions and 49 deletions

View File

@@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ +
__author__ = 'Harry Marr'
VERSION = (0, 1, 1)
VERSION = (0, 1, 3)
def get_version():
version = '%s.%s' % (VERSION[0], VERSION[1])

View File

@@ -49,6 +49,11 @@ class BaseField(object):
"""
return self.to_python(value)
def prepare_query_value(self, value):
"""Prepare a value that is being used in a query for PyMongo.
"""
return value
def validate(self, value):
"""Perform validation on a value.
"""
@@ -67,6 +72,9 @@ class ObjectIdField(BaseField):
return pymongo.objectid.ObjectId(value)
return value
def prepare_query_value(self, value):
return self.to_mongo(value)
def validate(self, value):
try:
pymongo.objectid.ObjectId(str(value))
@@ -199,17 +207,17 @@ class BaseDocument(object):
return all_subclasses
def __iter__(self):
# Use _data rather than _fields as iterator only looks at names so
# values don't need to be converted to Python types
return iter(self._data)
return iter(self._fields)
def __getitem__(self, name):
"""Dictionary-style field access, return a field's value if present.
"""
try:
return getattr(self, name)
if name in self._fields:
return getattr(self, name)
except AttributeError:
raise KeyError(name)
pass
raise KeyError(name)
def __setitem__(self, name, value):
"""Dictionary-style field access, set a field's value.

View File

@@ -10,6 +10,10 @@ _connection_settings = {
'pool_size': 1,
}
_connection = None
_db_name = None
_db_username = None
_db_password = None
_db = None
@@ -19,14 +23,30 @@ class ConnectionError(Exception):
def _get_connection():
global _connection
# Connect to the database if not already connected
if _connection is None:
_connection = Connection(**_connection_settings)
try:
_connection = Connection(**_connection_settings)
except:
raise ConnectionError('Cannot connect to the database')
return _connection
def _get_db():
global _db
global _db, _connection
# Connect if not already connected
if _connection is None:
_connection = _get_connection()
if _db is None:
raise ConnectionError('Not connected to database')
# _db_name will be None if the user hasn't called connect()
if _db_name is None:
raise ConnectionError('Not connected to the database')
# Get DB from current connection and authenticate if necessary
_db = _connection[_db_name]
if _db_username and _db_password:
_db.authenticate(_db_username, _db_password)
return _db
def connect(db, username=None, password=None, **kwargs):
@@ -35,12 +55,8 @@ def connect(db, username=None, password=None, **kwargs):
the default port on localhost. If authentication is needed, provide
username and password arguments as well.
"""
global _db
global _connection_settings, _db_name, _db_username, _db_password
_connection_settings.update(kwargs)
connection = _get_connection()
# Get DB from connection and auth if necessary
_db = connection[db]
if username is not None and password is not None:
_db.authenticate(username, password)
_db_name = db
_db_username = username
_db_password = password

View File

View File

@@ -0,0 +1,99 @@
from mongoengine import *
from django.utils.hashcompat import md5_constructor, sha_constructor
from django.utils.encoding import smart_str
from django.contrib.auth.models import AnonymousUser
import datetime
REDIRECT_FIELD_NAME = 'next'
def get_hexdigest(algorithm, salt, raw_password):
raw_password, salt = smart_str(raw_password), smart_str(salt)
if algorithm == 'md5':
return md5_constructor(salt + raw_password).hexdigest()
elif algorithm == 'sha1':
return sha_constructor(salt + raw_password).hexdigest()
raise ValueError('Got unknown password algorithm type in password')
class User(Document):
"""A User document that aims to mirror most of the API specified by Django
at http://docs.djangoproject.com/en/dev/topics/auth/#users
"""
username = StringField(max_length=30, required=True)
first_name = StringField(max_length=30)
last_name = StringField(max_length=30)
email = StringField()
password = StringField(max_length=128)
is_staff = BooleanField(default=False)
is_active = BooleanField(default=True)
is_superuser = BooleanField(default=False)
last_login = DateTimeField(default=datetime.datetime.now)
def get_full_name(self):
"""Returns the users first and last names, separated by a space.
"""
full_name = u'%s %s' % (self.first_name or '', self.last_name or '')
return full_name.strip()
def is_anonymous(self):
return False
def is_authenticated(self):
return True
def set_password(self, raw_password):
"""Sets the user's password - always use this rather than directly
assigning to :attr:`~mongoengine.django.auth.User.password` as the
password is hashed before storage.
"""
from random import random
algo = 'sha1'
salt = get_hexdigest(algo, str(random()), str(random()))[:5]
hash = get_hexdigest(algo, salt, raw_password)
self.password = '%s$%s$%s' % (algo, salt, hash)
def check_password(self, raw_password):
"""Checks the user's password against a provided password - always use
this rather than directly comparing to
:attr:`~mongoengine.django.auth.User.password` as the password is
hashed before storage.
"""
algo, salt, hash = self.password.split('$')
return hash == get_hexdigest(algo, salt, raw_password)
@classmethod
def create_user(cls, username, password, email=None):
"""Create (and save) a new user with the given username, password and
email address.
"""
user = User(username=username, email=email)
user.set_password(password)
user.save()
return user
class MongoEngineBackend(object):
"""Authenticate using MongoEngine and mongoengine.django.auth.User.
"""
def authenticate(self, username=None, password=None):
user = User.objects(username=username).first()
if user:
if password and user.check_password(password):
return user
return None
def get_user(self, user_id):
return User.objects.with_id(user_id)
def get_user(userid):
"""Returns a User object from an id (User.id). Django's equivalent takes
request, but taking an id instead leaves it up to the developer to store
the id in any way they want (session, signed cookie, etc.)
"""
if not userid:
return AnonymousUser()
return MongoEngineBackend().get_user(userid) or AnonymousUser()

View File

@@ -44,6 +44,11 @@ class Document(BaseDocument):
maximum size of the collection in bytes. If :attr:`max_size` is not
specified and :attr:`max_documents` is, :attr:`max_size` defaults to
10000000 bytes (10MB).
Indexes may be created by specifying :attr:`indexes` in the :attr:`meta`
dictionary. The value should be a list of field names or tuples of field
names. Index direction may be specified by prefixing the field names with
a **+** or **-** sign.
"""
__metaclass__ = TopLevelDocumentMetaclass
@@ -67,8 +72,7 @@ class Document(BaseDocument):
def reload(self):
"""Reloads all attributes from the database.
"""
object_id = self._fields['id'].to_mongo(self.id)
obj = self.__class__.objects(id=object_id).first()
obj = self.__class__.objects(id=self.id).first()
for field in self._fields:
setattr(self, field, getattr(obj, field))

View File

@@ -7,9 +7,9 @@ import pymongo
import datetime
__all__ = ['StringField', 'IntField', 'FloatField', 'DateTimeField',
'EmbeddedDocumentField', 'ListField', 'ObjectIdField',
'ReferenceField', 'ValidationError']
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField',
'ObjectIdField', 'ReferenceField', 'ValidationError']
class StringField(BaseField):
@@ -25,7 +25,7 @@ class StringField(BaseField):
return unicode(value)
def validate(self, value):
assert(isinstance(value, (str, unicode)))
assert isinstance(value, (str, unicode))
if self.max_length is not None and len(value) > self.max_length:
raise ValidationError('String value is too long')
@@ -50,7 +50,7 @@ class IntField(BaseField):
return int(value)
def validate(self, value):
assert(isinstance(value, (int, long)))
assert isinstance(value, (int, long))
if self.min_value is not None and value < self.min_value:
raise ValidationError('Integer value is too small')
@@ -71,7 +71,7 @@ class FloatField(BaseField):
return float(value)
def validate(self, value):
assert(isinstance(value, float))
assert isinstance(value, float)
if self.min_value is not None and value < self.min_value:
raise ValidationError('Float value is too small')
@@ -80,12 +80,23 @@ class FloatField(BaseField):
raise ValidationError('Float value is too large')
class BooleanField(BaseField):
"""A boolean field type.
"""
def to_python(self, value):
return bool(value)
def validate(self, value):
assert isinstance(value, bool)
class DateTimeField(BaseField):
"""A datetime field.
"""
def validate(self, value):
assert(isinstance(value, datetime.datetime))
assert isinstance(value, datetime.datetime)
class EmbeddedDocumentField(BaseField):
@@ -188,21 +199,27 @@ class ReferenceField(BaseField):
def to_mongo(self, document):
if isinstance(document, (str, unicode, pymongo.objectid.ObjectId)):
# document may already be an object id
id_ = document
else:
# We need the id from the saved object to create the DBRef
id_ = document.id
if id_ is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
# id may be a string rather than an ObjectID object
if not isinstance(id_, pymongo.objectid.ObjectId):
id_ = pymongo.objectid.ObjectId(id_)
collection = self.document_type._meta['collection']
return pymongo.dbref.DBRef(collection, id_)
def prepare_query_value(self, value):
return self.to_mongo(value)
def validate(self, value):
assert(isinstance(value, (self.document_type, pymongo.dbref.DBRef)))
assert isinstance(value, (self.document_type, pymongo.dbref.DBRef))
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)

View File

@@ -26,23 +26,28 @@ class QuerySet(object):
self._query = {'_types': self._document._class_name}
self._cursor_obj = None
def ensure_index(self, key_or_list, direction=None):
def ensure_index(self, key_or_list):
"""Ensure that the given indexes are in place.
"""
if isinstance(key_or_list, basestring):
# single-field indexes needn't specify a direction
if key_or_list.startswith("-") or key_or_list.startswith("+"):
if key_or_list.startswith(("-", "+")):
key_or_list = key_or_list[1:]
self._collection.ensure_index(key_or_list)
# Use real field name
key = QuerySet._translate_field_name(self._document, key_or_list)
self._collection.ensure_index(key)
elif isinstance(key_or_list, (list, tuple)):
index_list = []
for key in key_or_list:
# Get direction from + or -
direction = pymongo.ASCENDING
if key.startswith("-"):
index_list.append((key[1:], pymongo.DESCENDING))
else:
if key.startswith("+"):
direction = pymongo.DESCENDING
if key.startswith(("+", "-")):
key = key[1:]
index_list.append((key, pymongo.ASCENDING))
# Use real field name
key = QuerySet._translate_field_name(self._document, key)
index_list.append((key, direction))
self._collection.ensure_index(index_list)
return self
@@ -68,12 +73,13 @@ class QuerySet(object):
return self._cursor_obj
@classmethod
def _translate_field_name(cls, document, parts):
"""Translate a field attribute name to a database field name.
def _lookup_field(cls, document, parts):
"""Lookup a field based on its attribute and return a list containing
the field's parents and the field.
"""
if not isinstance(parts, (list, tuple)):
parts = [parts]
field_names = []
fields = []
field = None
for field_name in parts:
if field is None:
@@ -85,9 +91,17 @@ class QuerySet(object):
if field is None:
raise InvalidQueryError('Cannot resolve field "%s"'
% field_name)
field_names.append(field.name)
return field_names
fields.append(field)
return fields
@classmethod
def _translate_field_name(cls, doc_cls, field, sep='.'):
"""Translate a field attribute name to a database field name.
"""
parts = field.split(sep)
parts = [f.name for f in QuerySet._lookup_field(doc_cls, parts)]
return '.'.join(parts)
@classmethod
def _transform_query(cls, _doc_cls=None, **query):
"""Transform a query from Django-style format to Mongo format.
@@ -102,11 +116,22 @@ class QuerySet(object):
op = None
if parts[-1] in operators:
op = parts.pop()
value = {'$' + op: value}
# Switch field names to proper names [set in Field(name='foo')]
if _doc_cls:
parts = QuerySet._translate_field_name(_doc_cls, parts)
# Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [field.name for field in fields]
# Convert value to proper value
field = fields[-1]
if op in (None, 'neq', 'gt', 'gte', 'lt', 'lte'):
value = field.prepare_query_value(value)
elif op in ('in', 'nin', 'all'):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(v) for v in value]
if op:
value = {'$' + op: value}
key = '.'.join(parts)
if op is None or key not in mongo_query:
@@ -129,7 +154,7 @@ class QuerySet(object):
"""Retrieve the object matching the id provided.
"""
if not isinstance(object_id, pymongo.objectid.ObjectId):
object_id = pymongo.objectid.ObjectId(object_id)
object_id = pymongo.objectid.ObjectId(str(object_id))
result = self._collection.find_one(object_id)
if result is not None: