Merge branch 'v0.4' of git://github.com/hmarr/mongoengine into v0.4

This commit is contained in:
Steve Challis 2010-09-29 23:39:09 +01:00
commit 67a9b358a0
8 changed files with 144 additions and 26 deletions

View File

@ -220,6 +220,20 @@ either a single field name, or a list or tuple of field names::
first_name = StringField() first_name = StringField()
last_name = StringField(unique_with='first_name') last_name = StringField(unique_with='first_name')
Skipping Document validation on save
------------------------------------
You can also skip the whole document validation process by setting
``validate=False`` when caling the :meth:`~mongoengine.document.Document.save`
method::
class Recipient(Document):
name = StringField()
email = EmailField()
recipient = Recipient(name='admin', email='root@localhost')
recipient.save() # will raise a ValidationError while
recipient.save(validate=False) # won't
Document collections Document collections
==================== ====================
Document classes that inherit **directly** from :class:`~mongoengine.Document` Document classes that inherit **directly** from :class:`~mongoengine.Document`

View File

@ -255,6 +255,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
'index_background': False, 'index_background': False,
'index_drop_dups': False, 'index_drop_dups': False,
'index_opts': {}, 'index_opts': {},
'queryset_class': QuerySet,
} }
meta.update(base_meta) meta.update(base_meta)

View File

@ -1,5 +1,5 @@
from pymongo import Connection from pymongo import Connection
import multiprocessing
__all__ = ['ConnectionError', 'connect'] __all__ = ['ConnectionError', 'connect']
@ -8,12 +8,12 @@ _connection_settings = {
'host': 'localhost', 'host': 'localhost',
'port': 27017, 'port': 27017,
} }
_connection = None _connection = {}
_db_name = None _db_name = None
_db_username = None _db_username = None
_db_password = None _db_password = None
_db = None _db = {}
class ConnectionError(Exception): class ConnectionError(Exception):
@ -22,32 +22,39 @@ class ConnectionError(Exception):
def _get_connection(): def _get_connection():
global _connection global _connection
identity = get_identity()
# Connect to the database if not already connected # Connect to the database if not already connected
if _connection is None: if _connection.get(identity) is None:
try: try:
_connection = Connection(**_connection_settings) _connection[identity] = Connection(**_connection_settings)
except: except:
raise ConnectionError('Cannot connect to the database') raise ConnectionError('Cannot connect to the database')
return _connection return _connection[identity]
def _get_db(): def _get_db():
global _db, _connection global _db, _connection
identity = get_identity()
# Connect if not already connected # Connect if not already connected
if _connection is None: if _connection.get(identity) is None:
_connection = _get_connection() _connection[identity] = _get_connection()
if _db is None: if _db.get(identity) is None:
# _db_name will be None if the user hasn't called connect() # _db_name will be None if the user hasn't called connect()
if _db_name is None: if _db_name is None:
raise ConnectionError('Not connected to the database') raise ConnectionError('Not connected to the database')
# Get DB from current connection and authenticate if necessary # Get DB from current connection and authenticate if necessary
_db = _connection[_db_name] _db[identity] = _connection[identity][_db_name]
if _db_username and _db_password: if _db_username and _db_password:
_db.authenticate(_db_username, _db_password) _db[identity].authenticate(_db_username, _db_password)
return _db return _db[identity]
def get_identity():
identity = multiprocessing.current_process()._identity
identity = 0 if not identity else identity[0]
return identity
def connect(db, username=None, password=None, **kwargs): def connect(db, username=None, password=None, **kwargs):
"""Connect to the database specified by the 'db' argument. Connection """Connect to the database specified by the 'db' argument. Connection
settings may be provided here as well if the database is not running on settings may be provided here as well if the database is not running on

View File

@ -56,7 +56,7 @@ class Document(BaseDocument):
__metaclass__ = TopLevelDocumentMetaclass __metaclass__ = TopLevelDocumentMetaclass
def save(self, safe=True, force_insert=False): def save(self, safe=True, force_insert=False, validate=True):
"""Save the :class:`~mongoengine.Document` to the database. If the """Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be document already exists, it will be updated, otherwise it will be
created. created.
@ -67,8 +67,10 @@ class Document(BaseDocument):
:param safe: check if the operation succeeded before returning :param safe: check if the operation succeeded before returning
:param force_insert: only try to create a new document, don't allow :param force_insert: only try to create a new document, don't allow
updates of existing documents updates of existing documents
:param validate: validates the document; set to ``False`` for skiping
""" """
self.validate() if validate:
self.validate()
doc = self.to_mongo() doc = self.to_mongo()
try: try:
collection = self.__class__.objects._collection collection = self.__class__.objects._collection

View File

@ -512,6 +512,10 @@ class BinaryField(BaseField):
raise ValidationError('Binary value is too long') raise ValidationError('Binary value is too long')
class GridFSError(Exception):
pass
class GridFSProxy(object): class GridFSProxy(object):
"""Proxy object to handle writing and reading of files to and from GridFS """Proxy object to handle writing and reading of files to and from GridFS
@ -527,6 +531,7 @@ class GridFSProxy(object):
obj = self.get() obj = self.get()
if name in dir(obj): if name in dir(obj):
return getattr(obj, name) return getattr(obj, name)
raise AttributeError
def __get__(self, instance, value): def __get__(self, instance, value):
return self return self
@ -545,12 +550,18 @@ class GridFSProxy(object):
self.grid_id = self.newfile._id self.grid_id = self.newfile._id
def put(self, file, **kwargs): def put(self, file, **kwargs):
if self.grid_id:
raise GridFSError('This document alreay has a file. Either delete '
'it or call replace to overwrite it')
self.grid_id = self.fs.put(file, **kwargs) self.grid_id = self.fs.put(file, **kwargs)
def write(self, string): def write(self, string):
if not self.newfile: if self.grid_id:
if not self.newfile:
raise GridFSError('This document alreay has a file. Either '
'delete it or call replace to overwrite it')
else:
self.new_file() self.new_file()
self.grid_id = self.newfile._id
self.newfile.write(string) self.newfile.write(string)
def writelines(self, lines): def writelines(self, lines):

View File

@ -344,6 +344,8 @@ class QuerySet(object):
mongo_query = {} mongo_query = {}
for key, value in query.items(): for key, value in query.items():
parts = key.split('__') parts = key.split('__')
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
parts = [part for part in parts if not part.isdigit()]
# Check for an operator and transform to mongo-style if there is # Check for an operator and transform to mongo-style if there is
op = None op = None
if parts[-1] in operators + match_operators + geo_operators: if parts[-1] in operators + match_operators + geo_operators:
@ -381,7 +383,9 @@ class QuerySet(object):
"been implemented" % op) "been implemented" % op)
elif op not in match_operators: elif op not in match_operators:
value = {'$' + op: value} value = {'$' + op: value}
for i, part in indices:
parts.insert(i, part)
key = '.'.join(parts) key = '.'.join(parts)
if op is None or key not in mongo_query: if op is None or key not in mongo_query:
mongo_query[key] = value mongo_query[key] = value
@ -762,7 +766,8 @@ class QuerySet(object):
return mongo_update return mongo_update
def update(self, safe_update=True, upsert=False, **update): def update(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on the fields matched by the query. """Perform an atomic update on the fields matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe: check if the operation succeeded before returning
:param update: Django-style update keyword arguments :param update: Django-style update keyword arguments
@ -774,8 +779,10 @@ class QuerySet(object):
update = QuerySet._transform_update(self._document, **update) update = QuerySet._transform_update(self._document, **update)
try: try:
self._collection.update(self._query, update, safe=safe_update, ret = self._collection.update(self._query, update, multi=True,
upsert=upsert, multi=True) upsert=upsert, safe=safe_update)
if ret is not None and 'n' in ret:
return ret['n']
except pymongo.errors.OperationFailure, err: except pymongo.errors.OperationFailure, err:
if unicode(err) == u'multi not coded yet': if unicode(err) == u'multi not coded yet':
message = u'update() method requires MongoDB 1.1.3+' message = u'update() method requires MongoDB 1.1.3+'
@ -783,7 +790,8 @@ class QuerySet(object):
raise OperationError(u'Update failed (%s)' % unicode(err)) raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, safe_update=True, upsert=False, **update): def update_one(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on first field matched by the query. """Perform an atomic update on first field matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning :param safe: check if the operation succeeded before returning
:param update: Django-style update keyword arguments :param update: Django-style update keyword arguments
@ -795,11 +803,14 @@ class QuerySet(object):
# Explicitly provide 'multi=False' to newer versions of PyMongo # Explicitly provide 'multi=False' to newer versions of PyMongo
# as the default may change to 'True' # as the default may change to 'True'
if pymongo.version >= '1.1.1': if pymongo.version >= '1.1.1':
self._collection.update(self._query, update, safe=safe_update, ret = self._collection.update(self._query, update, multi=False,
upsert=upsert, multi=False) upsert=upsert, safe=safe_update)
else: else:
# Older versions of PyMongo don't support 'multi' # Older versions of PyMongo don't support 'multi'
self._collection.update(self._query, update, safe=safe_update) ret = self._collection.update(self._query, update,
safe=safe_update)
if ret is not None and 'n' in ret:
return ret['n']
except pymongo.errors.OperationFailure, e: except pymongo.errors.OperationFailure, e:
raise OperationError(u'Update failed [%s]' % unicode(e)) raise OperationError(u'Update failed [%s]' % unicode(e))
@ -988,7 +999,8 @@ class QuerySetManager(object):
self._collection = db[collection] self._collection = db[collection]
# owner is the document that contains the QuerySetManager # owner is the document that contains the QuerySetManager
queryset = QuerySet(owner, self._collection) queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collection)
if self._manager_func: if self._manager_func:
if self._manager_func.func_code.co_argcount == 1: if self._manager_func.func_code.co_argcount == 1:
queryset = self._manager_func(queryset) queryset = self._manager_func(queryset)

View File

@ -448,6 +448,16 @@ class DocumentTest(unittest.TestCase):
self.assertEqual(person_obj['name'], 'Test User') self.assertEqual(person_obj['name'], 'Test User')
self.assertEqual(person_obj['age'], 30) self.assertEqual(person_obj['age'], 30)
self.assertEqual(person_obj['_id'], person.id) self.assertEqual(person_obj['_id'], person.id)
# Test skipping validation on save
class Recipient(Document):
email = EmailField(required=True)
recipient = Recipient(email='root@localhost')
self.assertRaises(ValidationError, recipient.save)
try:
recipient.save(validate=False)
except ValidationError:
fail()
def test_delete(self): def test_delete(self):
"""Ensure that document may be deleted using the delete method. """Ensure that document may be deleted using the delete method.

View File

@ -165,8 +165,49 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects.get(age__lt=30) person = self.Person.objects.get(age__lt=30)
self.assertEqual(person.name, "User A") self.assertEqual(person.name, "User A")
def test_find_array_position(self):
"""Ensure that query by array position works.
"""
class Comment(EmbeddedDocument):
name = StringField()
class Post(EmbeddedDocument):
comments = ListField(EmbeddedDocumentField(Comment))
class Blog(Document):
tags = ListField(StringField())
posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection()
Blog.objects.create(tags=['a', 'b'])
self.assertEqual(len(Blog.objects(tags__0='a')), 1)
self.assertEqual(len(Blog.objects(tags__0='b')), 0)
self.assertEqual(len(Blog.objects(tags__1='a')), 0)
self.assertEqual(len(Blog.objects(tags__1='b')), 1)
Blog.drop_collection()
comment1 = Comment(name='testa')
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
blog = Blog.objects(posts__0__comments__0__name='testa').get()
self.assertEqual(blog, blog1)
query = Blog.objects(posts__1__comments__1__name='testb')
self.assertEqual(len(query), 2)
query = Blog.objects(posts__1__comments__1__name='testa')
self.assertEqual(len(query), 0)
query = Blog.objects(posts__0__comments__1__name='testa')
self.assertEqual(len(query), 0)
Blog.drop_collection()
def test_get_or_create(self): def test_get_or_create(self):
"""Ensure that ``get_or_create`` returns one result or creates a new """Ensure that ``get_or_create`` returns one result or creates a new
@ -1266,6 +1307,26 @@ class QuerySetTest(unittest.TestCase):
Event.drop_collection() Event.drop_collection()
def test_custom_querysets(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySet(QuerySet):
def not_empty(self):
return len(self) > 0
class Post(Document):
meta = {'queryset_class': CustomQuerySet}
Post.drop_collection()
self.assertTrue(isinstance(Post.objects, CustomQuerySet))
self.assertFalse(Post.objects.not_empty())
Post().save()
self.assertTrue(Post.objects.not_empty())
Post.drop_collection()
class QTest(unittest.TestCase): class QTest(unittest.TestCase):