Merge branch 'v0.4'

This commit is contained in:
Harry Marr 2010-10-18 13:56:07 +01:00
commit 69989365c7
23 changed files with 1749 additions and 346 deletions

4
.gitignore vendored
View File

@ -1,7 +1,9 @@
*.pyc
.*.swp
*.egg
docs/.build
docs/_build
build/
dist/
mongoengine.egg-info/
mongoengine.egg-info/
env/

View File

@ -2,3 +2,4 @@ Harry Marr <harry@hmarr.com>
Matt Dennewitz <mattdennewitz@gmail.com>
Deepak Thukral <iapain@yahoo.com>
Florian Schlachter <flori@n-schlachter.de>
Steve Challis <steve@stevechallis.com>

View File

@ -64,3 +64,7 @@ Fields
.. autoclass:: mongoengine.ReferenceField
.. autoclass:: mongoengine.GenericReferenceField
.. autoclass:: mongoengine.FileField
.. autoclass:: mongoengine.GeoPointField

View File

@ -2,6 +2,32 @@
Changelog
=========
Changes in v0.4
===============
- Added ``GridFSStorage`` Django storage backend
- Added ``FileField`` for GridFS support
- New Q-object implementation, which is no longer based on Javascript
- Added ``SortedListField``
- Added ``EmailField``
- Added ``GeoPointField``
- Added ``exact`` and ``iexact`` match operators to ``QuerySet``
- Added ``get_document_or_404`` and ``get_list_or_404`` Django shortcuts
- Added new query operators for Geo queries
- Added ``not`` query operator
- Added new update operators: ``pop`` and ``add_to_set``
- Added ``__raw__`` query parameter
- Added support for custom querysets
- Fixed document inheritance primary key issue
- Added support for querying by array element position
- Base class can now be defined for ``DictField``
- Fixed MRO error that occured on document inheritance
- Added ``QuerySet.distinct``, ``QuerySet.create``, ``QuerySet.snapshot``,
``QuerySet.timeout`` and ``QuerySet.all``
- Subsequent calls to ``connect()`` now work
- Introduced ``min_length`` for ``StringField``
- Fixed multi-process connection issue
- Other minor fixes
Changes in v0.3
===============
- Added MapReduce support

View File

@ -19,7 +19,7 @@ MongoDB but still use many of the Django authentication infrastucture (such as
the :func:`login_required` decorator and the :func:`authenticate` function). To
enable the MongoEngine auth backend, add the following to you **settings.py**
file::
AUTHENTICATION_BACKENDS = (
'mongoengine.django.auth.MongoEngineBackend',
)
@ -44,3 +44,44 @@ into you settings module::
SESSION_ENGINE = 'mongoengine.django.sessions'
.. versionadded:: 0.2.1
Storage
=======
With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`,
it is useful to have a Django file storage backend that wraps this. The new
storage module is called :class:`~mongoengine.django.GridFSStorage`. Using it
is very similar to using the default FileSystemStorage.::
fs = mongoengine.django.GridFSStorage()
filename = fs.save('hello.txt', 'Hello, World!')
All of the `Django Storage API methods
<http://docs.djangoproject.com/en/dev/ref/files/storage/>`_ have been
implemented except :func:`path`. If the filename provided already exists, an
underscore and a number (before # the file extension, if one exists) will be
appended to the filename until the generated filename doesn't exist. The
:func:`save` method will return the new filename.::
>>> fs.exists('hello.txt')
True
>>> fs.open('hello.txt').read()
'Hello, World!'
>>> fs.size('hello.txt')
13
>>> fs.url('hello.txt')
'http://your_media_url/hello.txt'
>>> fs.open('hello.txt').name
'hello.txt'
>>> fs.listdir()
([], [u'hello.txt'])
All files will be saved and retrieved in GridFS via the :class::`FileDocument`
document, allowing easy access to the files without the GridFSStorage
backend.::
>>> from mongoengine.django.storage import FileDocument
>>> FileDocument.objects()
[<FileDocument: FileDocument object>]
.. versionadded:: 0.4

View File

@ -46,6 +46,12 @@ are as follows:
* :class:`~mongoengine.EmbeddedDocumentField`
* :class:`~mongoengine.ReferenceField`
* :class:`~mongoengine.GenericReferenceField`
* :class:`~mongoengine.BooleanField`
* :class:`~mongoengine.FileField`
* :class:`~mongoengine.EmailField`
* :class:`~mongoengine.SortedListField`
* :class:`~mongoengine.BinaryField`
* :class:`~mongoengine.GeoPointField`
Field arguments
---------------
@ -66,6 +72,25 @@ arguments can be set on all fields:
:attr:`default` (Default: None)
A value to use when no value is set for this field.
The definion of default parameters follow `the general rules on Python
<http://docs.python.org/reference/compound_stmts.html#function-definitions>`__,
which means that some care should be taken when dealing with default mutable objects
(like in :class:`~mongoengine.ListField` or :class:`~mongoengine.DictField`)::
class ExampleFirst(Document):
# Default an empty list
values = ListField(IntField(), default=list)
class ExampleSecond(Document):
# Default a set of values
values = ListField(IntField(), default=lambda: [1,2,3])
class ExampleDangerous(Document):
# This can make an .append call to add values to the default (and all the following objects),
# instead to just an object
values = ListField(IntField(), default=[1,2,3])
:attr:`unique` (Default: False)
When True, no documents in the collection will have the same value for this
field.
@ -214,6 +239,20 @@ either a single field name, or a list or tuple of field names::
first_name = StringField()
last_name = StringField(unique_with='first_name')
Skipping Document validation on save
------------------------------------
You can also skip the whole document validation process by setting
``validate=False`` when caling the :meth:`~mongoengine.document.Document.save`
method::
class Recipient(Document):
name = StringField()
email = EmailField()
recipient = Recipient(name='admin', email='root@localhost')
recipient.save() # will raise a ValidationError while
recipient.save(validate=False) # won't
Document collections
====================
Document classes that inherit **directly** from :class:`~mongoengine.Document`
@ -259,6 +298,10 @@ or a **-** sign. Note that direction only matters on multi-field indexes. ::
meta = {
'indexes': ['title', ('title', '-rating')]
}
.. note::
Geospatial indexes will be automatically created for all
:class:`~mongoengine.GeoPointField`\ s
Ordering
========

View File

@ -59,6 +59,13 @@ you may still use :attr:`id` to access the primary key if you want::
>>> bob.id == bob.email == 'bob@example.com'
True
You can also access the document's "primary key" using the :attr:`pk` field; in
is an alias to :attr:`id`::
>>> page = Page(title="Another Test Page")
>>> page.save()
>>> page.id == page.pk
.. note::
If you define your own primary key field, the field implicitly becomes
required, so a :class:`ValidationError` will be thrown if you don't provide

83
docs/guide/gridfs.rst Normal file
View File

@ -0,0 +1,83 @@
======
GridFS
======
.. versionadded:: 0.4
Writing
-------
GridFS support comes in the form of the :class:`~mongoengine.FileField` field
object. This field acts as a file-like object and provides a couple of
different ways of inserting and retrieving data. Arbitrary metadata such as
content type can also be stored alongside the files. In the following example,
a document is created to store details about animals, including a photo::
class Animal(Document):
genus = StringField()
family = StringField()
photo = FileField()
marmot = Animal('Marmota', 'Sciuridae')
marmot_photo = open('marmot.jpg', 'r') # Retrieve a photo from disk
marmot.photo = marmot_photo # Store photo in the document
marmot.photo.content_type = 'image/jpeg' # Store metadata
marmot.save()
Another way of writing to a :class:`~mongoengine.FileField` is to use the
:func:`put` method. This allows for metadata to be stored in the same call as
the file::
marmot.photo.put(marmot_photo, content_type='image/jpeg')
marmot.save()
Retrieval
---------
So using the :class:`~mongoengine.FileField` is just like using any other
field. The file can also be retrieved just as easily::
marmot = Animal.objects(genus='Marmota').first()
photo = marmot.photo.read()
content_type = marmot.photo.content_type
Streaming
---------
Streaming data into a :class:`~mongoengine.FileField` is achieved in a
slightly different manner. First, a new file must be created by calling the
:func:`new_file` method. Data can then be written using :func:`write`::
marmot.photo.new_file()
marmot.photo.write('some_image_data')
marmot.photo.write('some_more_image_data')
marmot.photo.close()
marmot.photo.save()
Deletion
--------
Deleting stored files is achieved with the :func:`delete` method::
marmot.photo.delete()
.. note::
The FileField in a Document actually only stores the ID of a file in a
separate GridFS collection. This means that deleting a document
with a defined FileField does not actually delete the file. You must be
careful to delete any files in a Document as above before deleting the
Document itself.
Replacing files
---------------
Files can be replaced with the :func:`replace` method. This works just like
the :func:`put` method so even metadata can (and should) be replaced::
another_marmot = open('another_marmot.png', 'r')
marmot.photo.replace(another_marmot, content_type='image/png')

View File

@ -10,3 +10,4 @@ User Guide
defining-documents
document-instances
querying
gridfs

View File

@ -34,7 +34,7 @@ arguments. The keys in the keyword arguments correspond to fields on the
Fields on embedded documents may also be referred to using field lookup syntax
by using a double-underscore in place of the dot in object attribute access
syntax::
# This will return a QuerySet that will only iterate over pages that have
# been written by a user whose 'country' field is set to 'uk'
uk_pages = Page.objects(author__country='uk')
@ -53,11 +53,21 @@ lists that contain that item will be matched::
# 'tags' list
Page.objects(tags='coding')
Raw queries
-----------
It is possible to provide a raw PyMongo query as a query parameter, which will
be integrated directly into the query. This is done using the ``__raw__``
keyword argument::
Page.objects(__raw__={'tags': 'coding'})
.. versionadded:: 0.4
Query operators
===============
Operators other than equality may also be used in queries; just attach the
operator name to a key with a double-underscore::
# Only find users whose age is 18 or less
young_users = Users.objects(age__lte=18)
@ -68,10 +78,12 @@ Available operators are as follows:
* ``lte`` -- less than or equal to
* ``gt`` -- greater than
* ``gte`` -- greater than or equal to
* ``not`` -- negate a standard check, may be used before other operators (e.g.
``Q(age__not__mod=5)``)
* ``in`` -- value is in list (a list of values should be provided)
* ``nin`` -- value is not in list (a list of values should be provided)
* ``mod`` -- ``value % x == y``, where ``x`` and ``y`` are two provided values
* ``all`` -- every item in array is in list of values provided
* ``all`` -- every item in list of values provided is in array
* ``size`` -- the size of the array is
* ``exists`` -- value for field exists
@ -89,6 +101,27 @@ expressions:
.. versionadded:: 0.3
There are a few special operators for performing geographical queries, that
may used with :class:`~mongoengine.GeoPointField`\ s:
* ``within_distance`` -- provide a list containing a point and a maximum
distance (e.g. [(41.342, -87.653), 5])
* ``within_box`` -- filter documents to those within a given bounding box (e.g.
[(35.0, -125.0), (40.0, -100.0)])
* ``near`` -- order the documents by how close they are to a given point
.. versionadded:: 0.4
Querying by position
====================
It is possible to query by position in a list by using a numerical value as a
query operator. So if you wanted to find all pages whose first tag was ``db``,
you could use the following query::
BlogPost.objects(tags__0='db')
.. versionadded:: 0.4
Limiting and skipping results
=============================
Just as with traditional ORMs, you may limit the number of results returned, or
@ -111,7 +144,7 @@ You may also index the query to retrieve a single result. If an item at that
index does not exists, an :class:`IndexError` will be raised. A shortcut for
retrieving the first result and returning :attr:`None` if no result exists is
provided (:meth:`~mongoengine.queryset.QuerySet.first`)::
>>> # Make sure there are no users
>>> User.drop_collection()
>>> User.objects[0]
@ -174,13 +207,29 @@ custom manager methods as you like::
@queryset_manager
def live_posts(doc_cls, queryset):
return queryset(published=True).filter(published=True)
return queryset.filter(published=True)
BlogPost(title='test1', published=False).save()
BlogPost(title='test2', published=True).save()
assert len(BlogPost.objects) == 2
assert len(BlogPost.live_posts) == 1
Custom QuerySets
================
Should you want to add custom methods for interacting with or filtering
documents, extending the :class:`~mongoengine.queryset.QuerySet` class may be
the way to go. To use a custom :class:`~mongoengine.queryset.QuerySet` class on
a document, set ``queryset_class`` to the custom class in a
:class:`~mongoengine.Document`\ s ``meta`` dictionary::
class AwesomerQuerySet(QuerySet):
pass
class Page(Document):
meta = {'queryset_class': AwesomerQuerySet}
.. versionadded:: 0.4
Aggregation
===========
MongoDB provides some aggregation methods out of the box, but there are not as
@ -399,14 +448,17 @@ that you may use with these methods:
* ``unset`` -- delete a particular value (since MongoDB v1.3+)
* ``inc`` -- increment a value by a given amount
* ``dec`` -- decrement a value by a given amount
* ``pop`` -- remove the last item from a list
* ``push`` -- append a value to a list
* ``push_all`` -- append several values to a list
* ``pop`` -- remove the first or last element of a list
* ``pull`` -- remove a value from a list
* ``pull_all`` -- remove several values from a list
* ``add_to_set`` -- add value to a list only if its not in the list already
The syntax for atomic updates is similar to the querying syntax, but the
modifier comes before the field, not after it::
>>> post = BlogPost(title='Test', page_views=0, tags=['database'])
>>> post.save()
>>> BlogPost.objects(id=post.id).update_one(inc__page_views=1)

View File

@ -7,7 +7,7 @@ MongoDB. To install it, simply run
.. code-block:: console
# easy_install -U mongoengine
# pip install -U mongoengine
The source is available on `GitHub <http://github.com/hmarr/mongoengine>`_.

View File

@ -12,7 +12,7 @@ __all__ = (document.__all__ + fields.__all__ + connection.__all__ +
__author__ = 'Harry Marr'
VERSION = (0, 3, 0)
VERSION = (0, 4, 0)
def get_version():
version = '%s.%s' % (VERSION[0], VERSION[1])

View File

@ -23,10 +23,11 @@ class BaseField(object):
# Fields may have _types inserted into indexes by default
_index_with_types = True
_geo_index = False
def __init__(self, db_field=None, name=None, required=False, default=None,
unique=False, unique_with=None, primary_key=False, validation=None,
choices=None):
unique=False, unique_with=None, primary_key=False,
validation=None, choices=None):
self.db_field = (db_field or name) if not primary_key else '_id'
if name:
import warnings
@ -87,22 +88,24 @@ class BaseField(object):
# check choices
if self.choices is not None:
if value not in self.choices:
raise ValidationError("Value must be one of %s."%unicode(self.choices))
raise ValidationError("Value must be one of %s."
% unicode(self.choices))
# check validation argument
if self.validation is not None:
if callable(self.validation):
if not self.validation(value):
raise ValidationError('Value does not match custom validation method.')
raise ValidationError('Value does not match custom' \
'validation method.')
else:
raise ValueError('validation argument must be a callable.')
self.validate(value)
class ObjectIdField(BaseField):
"""An field wrapper around MongoDB's ObjectIds.
"""
def to_python(self, value):
return value
# return unicode(value)
@ -148,7 +151,7 @@ class DocumentMetaclass(type):
# Get superclasses from superclass
superclasses[base._class_name] = base
superclasses.update(base._superclasses)
if hasattr(base, '_meta'):
# Ensure that the Document class may be subclassed -
# inheritance may be disabled to remove dependency on
@ -189,20 +192,23 @@ class DocumentMetaclass(type):
field.owner_document = new_class
module = attrs.get('__module__')
base_excs = tuple(base.DoesNotExist for base in bases
if hasattr(base, 'DoesNotExist')) or (DoesNotExist,)
exc = subclass_exception('DoesNotExist', base_excs, module)
new_class.add_to_class('DoesNotExist', exc)
base_excs = tuple(base.MultipleObjectsReturned for base in bases
if hasattr(base, 'MultipleObjectsReturned'))
base_excs = base_excs or (MultipleObjectsReturned,)
exc = subclass_exception('MultipleObjectsReturned', base_excs, module)
new_class.add_to_class('MultipleObjectsReturned', exc)
global _document_registry
_document_registry[name] = new_class
return new_class
def add_to_class(self, name, value):
setattr(self, name, value)
@ -213,8 +219,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
"""
def __new__(cls, name, bases, attrs):
global _document_registry
super_new = super(TopLevelDocumentMetaclass, cls).__new__
# Classes defined in this package are abstract and should not have
# their own metadata with DB collection, etc.
@ -225,15 +229,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
return super_new(cls, name, bases, attrs)
collection = name.lower()
id_field = None
base_indexes = []
base_meta = {}
# Subclassed documents inherit collection from superclass
for base in bases:
if hasattr(base, '_meta') and 'collection' in base._meta:
collection = base._meta['collection']
# Propagate index options.
for key in ('index_background', 'index_drop_dups', 'index_opts'):
if key in base._meta:
base_meta[key] = base._meta[key]
id_field = id_field or base._meta.get('id_field')
base_indexes += base._meta.get('indexes', [])
@ -244,7 +254,12 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
'ordering': [], # default ordering applied at runtime
'indexes': [], # indexes to be ensured at runtime
'id_field': id_field,
'index_background': False,
'index_drop_dups': False,
'index_opts': {},
'queryset_class': QuerySet,
}
meta.update(base_meta)
# Apply document-defined meta options
meta.update(attrs.get('meta', {}))
@ -253,18 +268,21 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
# Set up collection manager, needs the class to have fields so use
# DocumentMetaclass before instantiating CollectionManager object
new_class = super_new(cls, name, bases, attrs)
new_class.objects = QuerySetManager()
# Provide a default queryset unless one has been manually provided
if not hasattr(new_class, 'objects'):
new_class.objects = QuerySetManager()
user_indexes = [QuerySet._build_index_spec(new_class, spec)
for spec in meta['indexes']] + base_indexes
new_class._meta['indexes'] = user_indexes
unique_indexes = []
for field_name, field in new_class._fields.items():
# Generate a list of indexes needed by uniqueness constraints
if field.unique:
field.required = True
unique_fields = [field_name]
unique_fields = [field.db_field]
# Add any unique_with fields to the back of the index spec
if field.unique_with:
@ -305,8 +323,6 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
new_class._fields['id'] = ObjectIdField(db_field='_id')
new_class.id = new_class._fields['id']
_document_registry[name] = new_class
return new_class
@ -314,14 +330,17 @@ class BaseDocument(object):
def __init__(self, **values):
self._data = {}
# Assign default values to instance
for attr_name in self._fields.keys():
# Use default value if present
value = getattr(self, attr_name, None)
setattr(self, attr_name, value)
# Assign initial values to instance
for attr_name, attr_value in self._fields.items():
if attr_name in values:
for attr_name in values.keys():
try:
setattr(self, attr_name, values.pop(attr_name))
else:
# Use default value if present
value = getattr(self, attr_name, None)
setattr(self, attr_name, value)
except AttributeError:
pass
def validate(self):
"""Ensure that all fields' values are valid and that required fields
@ -337,8 +356,8 @@ class BaseDocument(object):
try:
field._validate(value)
except (ValueError, AttributeError, AssertionError), e:
raise ValidationError('Invalid value for field of type "' +
field.__class__.__name__ + '"')
raise ValidationError('Invalid value for field of type "%s": %s'
% (field.__class__.__name__, value))
elif field.required:
raise ValidationError('Field "%s" is required' % field.name)
@ -357,6 +376,16 @@ class BaseDocument(object):
all_subclasses.update(subclass._get_subclasses())
return all_subclasses
@apply
def pk():
"""Primary key alias
"""
def fget(self):
return getattr(self, self._meta['id_field'])
def fset(self, value):
return setattr(self, self._meta['id_field'], value)
return property(fget, fset)
def __iter__(self):
return iter(self._fields)
@ -413,8 +442,10 @@ class BaseDocument(object):
self._meta.get('allow_inheritance', True) == False):
data['_cls'] = self._class_name
data['_types'] = self._superclasses.keys() + [self._class_name]
if data.has_key('_id') and not data['_id']:
del data['_id']
return data
@classmethod
def _from_son(cls, son):
"""Create an instance of a Document (subclass) from a PyMongo SON.
@ -444,12 +475,14 @@ class BaseDocument(object):
for field_name, field in cls._fields.items():
if field.db_field in data:
data[field_name] = field.to_python(data[field.db_field])
value = data[field.db_field]
data[field_name] = (value if value is None
else field.to_python(value))
obj = cls(**data)
obj._present_fields = present_fields
return obj
def __eq__(self, other):
if isinstance(other, self.__class__) and hasattr(other, 'id'):
if self.id == other.id:

View File

@ -1,62 +1,71 @@
from pymongo import Connection
import multiprocessing
__all__ = ['ConnectionError', 'connect']
_connection_settings = {
_connection_defaults = {
'host': 'localhost',
'port': 27017,
}
_connection = None
_connection = {}
_connection_settings = _connection_defaults.copy()
_db_name = None
_db_username = None
_db_password = None
_db = None
_db = {}
class ConnectionError(Exception):
pass
def _get_connection():
def _get_connection(reconnect=False):
global _connection
identity = get_identity()
# Connect to the database if not already connected
if _connection is None:
if _connection.get(identity) is None or reconnect:
try:
_connection = Connection(**_connection_settings)
_connection[identity] = Connection(**_connection_settings)
except:
raise ConnectionError('Cannot connect to the database')
return _connection
return _connection[identity]
def _get_db():
def _get_db(reconnect=False):
global _db, _connection
identity = get_identity()
# Connect if not already connected
if _connection is None:
_connection = _get_connection()
if _connection.get(identity) is None or reconnect:
_connection[identity] = _get_connection(reconnect=reconnect)
if _db is None:
if _db.get(identity) is None or reconnect:
# _db_name will be None if the user hasn't called connect()
if _db_name is None:
raise ConnectionError('Not connected to the database')
# Get DB from current connection and authenticate if necessary
_db = _connection[_db_name]
_db[identity] = _connection[identity][_db_name]
if _db_username and _db_password:
_db.authenticate(_db_username, _db_password)
_db[identity].authenticate(_db_username, _db_password)
return _db
return _db[identity]
def get_identity():
identity = multiprocessing.current_process()._identity
identity = 0 if not identity else identity[0]
return identity
def connect(db, username=None, password=None, **kwargs):
"""Connect to the database specified by the 'db' argument. Connection
settings may be provided here as well if the database is not running on
the default port on localhost. If authentication is needed, provide
username and password arguments as well.
"""
global _connection_settings, _db_name, _db_username, _db_password
_connection_settings.update(kwargs)
global _connection_settings, _db_name, _db_username, _db_password, _db
_connection_settings = dict(_connection_defaults, **kwargs)
_db_name = db
_db_username = username
_db_password = password
return _get_db()
return _get_db(reconnect=True)

View File

@ -32,6 +32,9 @@ class User(Document):
last_login = DateTimeField(default=datetime.datetime.now)
date_joined = DateTimeField(default=datetime.datetime.now)
def __unicode__(self):
return self.username
def get_full_name(self):
"""Returns the users first and last names, separated by a space.
"""
@ -72,10 +75,9 @@ class User(Document):
email address.
"""
now = datetime.datetime.now()
# Normalize the address by lowercasing the domain part of the email
# address.
# Not sure why we'r allowing null email when its not allowed in django
if email is not None:
try:
email_name, domain_part = email.strip().split('@', 1)
@ -83,12 +85,12 @@ class User(Document):
pass
else:
email = '@'.join([email_name, domain_part.lower()])
user = User(username=username, email=email, date_joined=now)
user.set_password(password)
user.save()
return user
def get_and_delete_messages(self):
return []

View File

@ -0,0 +1,112 @@
import os
import itertools
import urlparse
from mongoengine import *
from django.conf import settings
from django.core.files.storage import Storage
from django.core.exceptions import ImproperlyConfigured
class FileDocument(Document):
"""A document used to store a single file in GridFS.
"""
file = FileField()
class GridFSStorage(Storage):
"""A custom storage backend to store files in GridFS
"""
def __init__(self, base_url=None):
if base_url is None:
base_url = settings.MEDIA_URL
self.base_url = base_url
self.document = FileDocument
self.field = 'file'
def delete(self, name):
"""Deletes the specified file from the storage system.
"""
if self.exists(name):
doc = self.document.objects.first()
field = getattr(doc, self.field)
self._get_doc_with_name(name).delete() # Delete the FileField
field.delete() # Delete the FileDocument
def exists(self, name):
"""Returns True if a file referened by the given name already exists in the
storage system, or False if the name is available for a new file.
"""
doc = self._get_doc_with_name(name)
if doc:
field = getattr(doc, self.field)
return bool(field.name)
else:
return False
def listdir(self, path=None):
"""Lists the contents of the specified path, returning a 2-tuple of lists;
the first item being directories, the second item being files.
"""
def name(doc):
return getattr(doc, self.field).name
docs = self.document.objects
return [], [name(d) for d in docs if name(d)]
def size(self, name):
"""Returns the total size, in bytes, of the file specified by name.
"""
doc = self._get_doc_with_name(name)
if doc:
return getattr(doc, self.field).length
else:
raise ValueError("No such file or directory: '%s'" % name)
def url(self, name):
"""Returns an absolute URL where the file's contents can be accessed
directly by a web browser.
"""
if self.base_url is None:
raise ValueError("This file is not accessible via a URL.")
return urlparse.urljoin(self.base_url, name).replace('\\', '/')
def _get_doc_with_name(self, name):
"""Find the documents in the store with the given name
"""
docs = self.document.objects
doc = [d for d in docs if getattr(d, self.field).name == name]
if doc:
return doc[0]
else:
return None
def _open(self, name, mode='rb'):
doc = self._get_doc_with_name(name)
if doc:
return getattr(doc, self.field)
else:
raise ValueError("No file found with the name '%s'." % name)
def get_available_name(self, name):
"""Returns a filename that's free on the target storage system, and
available for new content to be written to.
"""
file_root, file_ext = os.path.splitext(name)
# If the filename already exists, add an underscore and a number (before
# the file extension, if one exists) to the filename until the generated
# filename doesn't exist.
count = itertools.count(1)
while self.exists(name):
# file_ext includes the dot.
name = os.path.join("%s_%s%s" % (file_root, count.next(), file_ext))
return name
def _save(self, name, content):
doc = self.document()
getattr(doc, self.field).put(content, filename=name)
doc.save()
return name

View File

@ -0,0 +1,21 @@
#coding: utf-8
from django.test import TestCase
from django.conf import settings
from mongoengine import connect
class MongoTestCase(TestCase):
"""
TestCase class that clear the collection between the tests
"""
db_name = 'test_%s' % settings.MONGO_DATABASE_NAME
def __init__(self, methodName='runtest'):
self.db = connect(self.db_name)
super(MongoTestCase, self).__init__(methodName)
def _post_teardown(self):
super(MongoTestCase, self)._post_teardown()
for collection in self.db.collection_names():
if collection == 'system.indexes':
continue
self.db.drop_collection(collection)

View File

@ -15,7 +15,7 @@ class EmbeddedDocument(BaseDocument):
fields on :class:`~mongoengine.Document`\ s through the
:class:`~mongoengine.EmbeddedDocumentField` field type.
"""
__metaclass__ = DocumentMetaclass
@ -56,7 +56,7 @@ class Document(BaseDocument):
__metaclass__ = TopLevelDocumentMetaclass
def save(self, safe=True, force_insert=False):
def save(self, safe=True, force_insert=False, validate=True):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@ -67,8 +67,10 @@ class Document(BaseDocument):
:param safe: check if the operation succeeded before returning
:param force_insert: only try to create a new document, don't allow
updates of existing documents
:param validate: validates the document; set to ``False`` for skiping
"""
self.validate()
if validate:
self.validate()
doc = self.to_mongo()
try:
collection = self.__class__.objects._collection
@ -119,23 +121,23 @@ class Document(BaseDocument):
class MapReduceDocument(object):
"""A document returned from a map/reduce query.
:param collection: An instance of :class:`~pymongo.Collection`
:param key: Document/result key, often an instance of
:class:`~pymongo.objectid.ObjectId`. If supplied as
an ``ObjectId`` found in the given ``collection``,
the object can be accessed via the ``object`` property.
:param value: The result(s) for this key.
.. versionadded:: 0.3
"""
def __init__(self, document, collection, key, value):
self._document = document
self._collection = collection
self.key = key
self.value = value
@property
def object(self):
"""Lazy-load the object referenced by ``self.key``. ``self.key``
@ -143,7 +145,7 @@ class MapReduceDocument(object):
"""
id_field = self._document()._meta['id_field']
id_field_type = type(id_field)
if not isinstance(self.key, id_field_type):
try:
self.key = id_field_type(self.key)

View File

@ -10,13 +10,16 @@ import pymongo.son
import pymongo.binary
import datetime
import decimal
import gridfs
import warnings
import types
__all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField',
'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField',
'ObjectIdField', 'ReferenceField', 'ValidationError',
'DecimalField', 'URLField', 'GenericReferenceField',
'BinaryField', 'SortedListField', 'EmailField', 'GeoLocationField']
'DecimalField', 'URLField', 'GenericReferenceField', 'FileField',
'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField']
RECURSIVE_REFERENCE_CONSTANT = 'self'
@ -39,7 +42,7 @@ class StringField(BaseField):
if self.max_length is not None and len(value) > self.max_length:
raise ValidationError('String value is too long')
if self.min_length is not None and len(value) < self.min_length:
raise ValidationError('String value is too short')
@ -67,6 +70,9 @@ class StringField(BaseField):
regex = r'%s$'
elif op == 'exact':
regex = r'^%s$'
# escape unsafe characters which could lead to a re.error
value = re.escape(value)
value = re.compile(regex % value, flags)
return value
@ -103,8 +109,11 @@ class URLField(StringField):
message = 'This URL appears to be a broken link: %s' % e
raise ValidationError(message)
class EmailField(StringField):
"""A field that validates input as an E-Mail-Address.
.. versionadded:: 0.4
"""
EMAIL_REGEX = re.compile(
@ -112,11 +121,12 @@ class EmailField(StringField):
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-011\013\014\016-\177])*"' # quoted-string
r')@(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?$', re.IGNORECASE # domain
)
def validate(self, value):
if not EmailField.EMAIL_REGEX.match(value):
raise ValidationError('Invalid Mail-address: %s' % value)
class IntField(BaseField):
"""An integer field.
"""
@ -140,6 +150,7 @@ class IntField(BaseField):
if self.max_value is not None and value > self.max_value:
raise ValidationError('Integer value is too large')
class FloatField(BaseField):
"""An floating point number field.
"""
@ -176,7 +187,7 @@ class DecimalField(BaseField):
if not isinstance(value, basestring):
value = unicode(value)
return decimal.Decimal(value)
def to_mongo(self, value):
return unicode(value)
@ -195,6 +206,7 @@ class DecimalField(BaseField):
if self.max_value is not None and value > self.max_value:
raise ValidationError('Decimal value is too large')
class BooleanField(BaseField):
"""A boolean field type.
@ -207,6 +219,7 @@ class BooleanField(BaseField):
def validate(self, value):
assert isinstance(value, bool)
class DateTimeField(BaseField):
"""A datetime field.
"""
@ -214,38 +227,49 @@ class DateTimeField(BaseField):
def validate(self, value):
assert isinstance(value, datetime.datetime)
class EmbeddedDocumentField(BaseField):
"""An embedded document field. Only valid values are subclasses of
:class:`~mongoengine.EmbeddedDocument`.
"""
def __init__(self, document, **kwargs):
if not issubclass(document, EmbeddedDocument):
raise ValidationError('Invalid embedded document class provided '
'to an EmbeddedDocumentField')
self.document = document
def __init__(self, document_type, **kwargs):
if not isinstance(document_type, basestring):
if not issubclass(document_type, EmbeddedDocument):
raise ValidationError('Invalid embedded document class '
'provided to an EmbeddedDocumentField')
self.document_type_obj = document_type
super(EmbeddedDocumentField, self).__init__(**kwargs)
@property
def document_type(self):
if isinstance(self.document_type_obj, basestring):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj
def to_python(self, value):
if not isinstance(value, self.document):
return self.document._from_son(value)
if not isinstance(value, self.document_type):
return self.document_type._from_son(value)
return value
def to_mongo(self, value):
return self.document.to_mongo(value)
return self.document_type.to_mongo(value)
def validate(self, value):
"""Make sure that the document instance is an instance of the
EmbeddedDocument subclass provided when the document was defined.
"""
# Using isinstance also works for subclasses of self.document
if not isinstance(value, self.document):
if not isinstance(value, self.document_type):
raise ValidationError('Invalid embedded document instance '
'provided to an EmbeddedDocumentField')
self.document.validate(value)
self.document_type.validate(value)
def lookup_member(self, member_name):
return self.document._fields.get(member_name)
return self.document_type._fields.get(member_name)
def prepare_query_value(self, op, value):
return self.to_mongo(value)
@ -264,6 +288,7 @@ class ListField(BaseField):
raise ValidationError('Argument to ListField constructor must be '
'a valid field')
self.field = field
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs)
def __get__(self, instance, owner):
@ -318,20 +343,32 @@ class ListField(BaseField):
try:
[self.field.validate(item) for item in value]
except Exception, err:
raise ValidationError('Invalid ListField item (%s)' % str(err))
raise ValidationError('Invalid ListField item (%s)' % str(item))
def prepare_query_value(self, op, value):
if op in ('set', 'unset'):
return [self.field.to_mongo(v) for v in value]
return self.field.to_mongo(value)
return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value)
def lookup_member(self, member_name):
return self.field.lookup_member(member_name)
def _set_owner_document(self, owner_document):
self.field.owner_document = owner_document
self._owner_document = owner_document
def _get_owner_document(self, owner_document):
self._owner_document = owner_document
owner_document = property(_get_owner_document, _set_owner_document)
class SortedListField(ListField):
"""A ListField that sorts the contents of its list before writing to
the database in order to ensure that a sorted list is always
retrieved.
.. versionadded:: 0.4
"""
_ordering = None
@ -343,9 +380,11 @@ class SortedListField(ListField):
def to_mongo(self, value):
if self._ordering is not None:
return sorted([self.field.to_mongo(item) for item in value], key=itemgetter(self._ordering))
return sorted([self.field.to_mongo(item) for item in value],
key=itemgetter(self._ordering))
return sorted([self.field.to_mongo(item) for item in value])
class DictField(BaseField):
"""A dictionary field that wraps a standard Python dictionary. This is
similar to an embedded document, but the structure is not defined.
@ -356,6 +395,7 @@ class DictField(BaseField):
def __init__(self, basecls=None, *args, **kwargs):
self.basecls = basecls or BaseField
assert issubclass(self.basecls, BaseField)
kwargs.setdefault('default', lambda: {})
super(DictField, self).__init__(*args, **kwargs)
def validate(self, value):
@ -372,24 +412,6 @@ class DictField(BaseField):
def lookup_member(self, member_name):
return self.basecls(db_field=member_name)
class GeoLocationField(DictField):
"""Supports geobased fields"""
def validate(self, value):
"""Make sure that a geo-value is of type (x, y)
"""
if not isinstance(value, tuple) and not isinstance(value, list):
raise ValidationError('GeoLocationField can only hold tuples or lists of (x, y)')
if len(value) <> 2:
raise ValidationError('GeoLocationField must have exactly two elements (x, y)')
def to_mongo(self, value):
return {'x': value[0], 'y': value[1]}
def to_python(self, value):
return value.keys()
class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on
access (lazily).
@ -401,7 +423,6 @@ class ReferenceField(BaseField):
raise ValidationError('Argument to ReferenceField constructor '
'must be a document class or a string')
self.document_type_obj = document_type
self.document_obj = None
super(ReferenceField, self).__init__(**kwargs)
@property
@ -501,7 +522,8 @@ class GenericReferenceField(BaseField):
return {'_cls': document.__class__.__name__, '_ref': ref}
def prepare_query_value(self, op, value):
return self.to_mongo(value)['_ref']
return self.to_mongo(value)
class BinaryField(BaseField):
"""A binary data field.
@ -523,3 +545,161 @@ class BinaryField(BaseField):
if self.max_bytes is not None and len(value) > self.max_bytes:
raise ValidationError('Binary value is too long')
class GridFSError(Exception):
pass
class GridFSProxy(object):
"""Proxy object to handle writing and reading of files to and from GridFS
.. versionadded:: 0.4
"""
def __init__(self, grid_id=None):
self.fs = gridfs.GridFS(_get_db()) # Filesystem instance
self.newfile = None # Used for partial writes
self.grid_id = grid_id # Store GridFS id for file
def __getattr__(self, name):
obj = self.get()
if name in dir(obj):
return getattr(obj, name)
raise AttributeError
def __get__(self, instance, value):
return self
def get(self, id=None):
if id:
self.grid_id = id
try:
return self.fs.get(id or self.grid_id)
except:
# File has been deleted
return None
def new_file(self, **kwargs):
self.newfile = self.fs.new_file(**kwargs)
self.grid_id = self.newfile._id
def put(self, file, **kwargs):
if self.grid_id:
raise GridFSError('This document already has a file. Either delete '
'it or call replace to overwrite it')
self.grid_id = self.fs.put(file, **kwargs)
def write(self, string):
if self.grid_id:
if not self.newfile:
raise GridFSError('This document already has a file. Either '
'delete it or call replace to overwrite it')
else:
self.new_file()
self.newfile.write(string)
def writelines(self, lines):
if not self.newfile:
self.new_file()
self.grid_id = self.newfile._id
self.newfile.writelines(lines)
def read(self):
try:
return self.get().read()
except:
return None
def delete(self):
# Delete file from GridFS, FileField still remains
self.fs.delete(self.grid_id)
self.grid_id = None
def replace(self, file, **kwargs):
self.delete()
self.put(file, **kwargs)
def close(self):
if self.newfile:
self.newfile.close()
else:
msg = "The close() method is only necessary after calling write()"
warnings.warn(msg)
class FileField(BaseField):
"""A GridFS storage field.
.. versionadded:: 0.4
"""
def __init__(self, **kwargs):
super(FileField, self).__init__(**kwargs)
def __get__(self, instance, owner):
if instance is None:
return self
# Check if a file already exists for this model
grid_file = instance._data.get(self.name)
self.grid_file = grid_file
if self.grid_file:
return self.grid_file
return GridFSProxy()
def __set__(self, instance, value):
if isinstance(value, file) or isinstance(value, str):
# using "FileField() = file/string" notation
grid_file = instance._data.get(self.name)
# If a file already exists, delete it
if grid_file:
try:
grid_file.delete()
except:
pass
# Create a new file with the new data
grid_file.put(value)
else:
# Create a new proxy object as we don't already have one
instance._data[self.name] = GridFSProxy()
instance._data[self.name].put(value)
else:
instance._data[self.name] = value
def to_mongo(self, value):
# Store the GridFS file id in MongoDB
if isinstance(value, GridFSProxy) and value.grid_id is not None:
return value.grid_id
return None
def to_python(self, value):
if value is not None:
return GridFSProxy(value)
def validate(self, value):
if value.grid_id is not None:
assert isinstance(value, GridFSProxy)
assert isinstance(value.grid_id, pymongo.objectid.ObjectId)
class GeoPointField(BaseField):
"""A list storing a latitude and longitude.
.. versionadded:: 0.4
"""
_geo_index = True
def validate(self, value):
"""Make sure that a geo-value is of type (x, y)
"""
if not isinstance(value, (list, tuple)):
raise ValidationError('GeoPointField can only accept tuples or '
'lists of (x, y)')
if not len(value) == 2:
raise ValidationError('Value must be a two-dimensional point.')
if (not isinstance(value[0], (float, int)) and
not isinstance(value[1], (float, int))):
raise ValidationError('Both values in point must be float or int.')

View File

@ -1,11 +1,13 @@
from connection import _get_db
import pprint
import pymongo
import pymongo.code
import pymongo.dbref
import pymongo.objectid
import re
import copy
import itertools
__all__ = ['queryset_manager', 'Q', 'InvalidQueryError',
'InvalidCollectionError']
@ -17,6 +19,7 @@ REPR_OUTPUT_SIZE = 20
class DoesNotExist(Exception):
pass
class MultipleObjectsReturned(Exception):
pass
@ -28,50 +31,192 @@ class InvalidQueryError(Exception):
class OperationError(Exception):
pass
class InvalidCollectionError(Exception):
pass
RE_TYPE = type(re.compile(''))
class Q(object):
class QNodeVisitor(object):
"""Base visitor class for visiting Q-object nodes in a query tree.
"""
OR = '||'
AND = '&&'
OPERATORS = {
'eq': ('((this.%(field)s instanceof Array) && '
' this.%(field)s.indexOf(%(value)s) != -1) ||'
' this.%(field)s == %(value)s'),
'ne': 'this.%(field)s != %(value)s',
'gt': 'this.%(field)s > %(value)s',
'gte': 'this.%(field)s >= %(value)s',
'lt': 'this.%(field)s < %(value)s',
'lte': 'this.%(field)s <= %(value)s',
'lte': 'this.%(field)s <= %(value)s',
'in': '%(value)s.indexOf(this.%(field)s) != -1',
'nin': '%(value)s.indexOf(this.%(field)s) == -1',
'mod': '%(field)s %% %(value)s',
'all': ('%(value)s.every(function(a){'
'return this.%(field)s.indexOf(a) != -1 })'),
'size': 'this.%(field)s.length == %(value)s',
'exists': 'this.%(field)s != null',
'regex_eq': '%(value)s.test(this.%(field)s)',
'regex_ne': '!%(value)s.test(this.%(field)s)',
}
def visit_combination(self, combination):
"""Called by QCombination objects.
"""
return combination
def __init__(self, **query):
self.query = [query]
def visit_query(self, query):
"""Called by (New)Q objects.
"""
return query
def _combine(self, other, op):
obj = Q()
if not other.query[0]:
class SimplificationVisitor(QNodeVisitor):
"""Simplifies query trees by combinging unnecessary 'and' connection nodes
into a single Q-object.
"""
def visit_combination(self, combination):
if combination.operation == combination.AND:
# The simplification only applies to 'simple' queries
if all(isinstance(node, Q) for node in combination.children):
queries = [node.query for node in combination.children]
return Q(**self._query_conjunction(queries))
return combination
def _query_conjunction(self, queries):
"""Merges query dicts - effectively &ing them together.
"""
query_ops = set()
combined_query = {}
for query in queries:
ops = set(query.keys())
# Make sure that the same operation isn't applied more than once
# to a single field
intersection = ops.intersection(query_ops)
if intersection:
msg = 'Duplicate query contitions: '
raise InvalidQueryError(msg + ', '.join(intersection))
query_ops.update(ops)
combined_query.update(copy.deepcopy(query))
return combined_query
class QueryTreeTransformerVisitor(QNodeVisitor):
"""Transforms the query tree in to a form that may be used with MongoDB.
"""
def visit_combination(self, combination):
if combination.operation == combination.AND:
# MongoDB doesn't allow us to have too many $or operations in our
# queries, so the aim is to move the ORs up the tree to one
# 'master' $or. Firstly, we must find all the necessary parts (part
# of an AND combination or just standard Q object), and store them
# separately from the OR parts.
or_groups = []
and_parts = []
for node in combination.children:
if isinstance(node, QCombination):
if node.operation == node.OR:
# Any of the children in an $or component may cause
# the query to succeed
or_groups.append(node.children)
elif node.operation == node.AND:
and_parts.append(node)
elif isinstance(node, Q):
and_parts.append(node)
# Now we combine the parts into a usable query. AND together all of
# the necessary parts. Then for each $or part, create a new query
# that ANDs the necessary part with the $or part.
clauses = []
for or_group in itertools.product(*or_groups):
q_object = reduce(lambda a, b: a & b, and_parts, Q())
q_object = reduce(lambda a, b: a & b, or_group, q_object)
clauses.append(q_object)
# Finally, $or the generated clauses in to one query. Each of the
# clauses is sufficient for the query to succeed.
return reduce(lambda a, b: a | b, clauses, Q())
if combination.operation == combination.OR:
children = []
# Crush any nested ORs in to this combination as MongoDB doesn't
# support nested $or operations
for node in combination.children:
if (isinstance(node, QCombination) and
node.operation == combination.OR):
children += node.children
else:
children.append(node)
combination.children = children
return combination
class QueryCompilerVisitor(QNodeVisitor):
"""Compiles the nodes in a query tree to a PyMongo-compatible query
dictionary.
"""
def __init__(self, document):
self.document = document
def visit_combination(self, combination):
if combination.operation == combination.OR:
return {'$or': combination.children}
elif combination.operation == combination.AND:
return self._mongo_query_conjunction(combination.children)
return combination
def visit_query(self, query):
return QuerySet._transform_query(self.document, **query.query)
def _mongo_query_conjunction(self, queries):
"""Merges Mongo query dicts - effectively &ing them together.
"""
combined_query = {}
for query in queries:
for field, ops in query.items():
if field not in combined_query:
combined_query[field] = ops
else:
# The field is already present in the query the only way
# we can merge is if both the existing value and the new
# value are operation dicts, reject anything else
if (not isinstance(combined_query[field], dict) or
not isinstance(ops, dict)):
message = 'Conflicting values for ' + field
raise InvalidQueryError(message)
current_ops = set(combined_query[field].keys())
new_ops = set(ops.keys())
# Make sure that the same operation isn't applied more than
# once to a single field
intersection = current_ops.intersection(new_ops)
if intersection:
msg = 'Duplicate query contitions: '
raise InvalidQueryError(msg + ', '.join(intersection))
# Right! We've got two non-overlapping dicts of operations!
combined_query[field].update(copy.deepcopy(ops))
return combined_query
class QNode(object):
"""Base class for nodes in query trees.
"""
AND = 0
OR = 1
def to_query(self, document):
query = self.accept(SimplificationVisitor())
query = query.accept(QueryTreeTransformerVisitor())
query = query.accept(QueryCompilerVisitor(document))
return query
def accept(self, visitor):
raise NotImplementedError
def _combine(self, other, operation):
"""Combine this node with another node into a QCombination object.
"""
if other.empty:
return self
if self.query[0]:
obj.query = (['('] + copy.deepcopy(self.query) + [op] +
copy.deepcopy(other.query) + [')'])
else:
obj.query = copy.deepcopy(other.query)
return obj
if self.empty:
return other
return QCombination(operation, [self, other])
@property
def empty(self):
return False
def __or__(self, other):
return self._combine(other, self.OR)
@ -79,70 +224,49 @@ class Q(object):
def __and__(self, other):
return self._combine(other, self.AND)
def as_js(self, document):
js = []
js_scope = {}
for i, item in enumerate(self.query):
if isinstance(item, dict):
item_query = QuerySet._transform_query(document, **item)
# item_query will values will either be a value or a dict
js.append(self._item_query_as_js(item_query, js_scope, i))
class QCombination(QNode):
"""Represents the combination of several conditions by a given logical
operator.
"""
def __init__(self, operation, children):
self.operation = operation
self.children = []
for node in children:
# If the child is a combination of the same type, we can merge its
# children directly into this combinations children
if isinstance(node, QCombination) and node.operation == operation:
self.children += node.children
else:
js.append(item)
return pymongo.code.Code(' '.join(js), js_scope)
self.children.append(node)
def _item_query_as_js(self, item_query, js_scope, item_num):
# item_query will be in one of the following forms
# {'age': 25, 'name': 'Test'}
# {'age': {'$lt': 25}, 'name': {'$in': ['Test', 'Example']}
# {'age': {'$lt': 25, '$gt': 18}}
js = []
for i, (key, value) in enumerate(item_query.items()):
op = 'eq'
# Construct a variable name for the value in the JS
value_name = 'i%sf%s' % (item_num, i)
if isinstance(value, dict):
# Multiple operators for this field
for j, (op, value) in enumerate(value.items()):
# Create a custom variable name for this operator
op_value_name = '%so%s' % (value_name, j)
# Construct the JS that uses this op
value, operation_js = self._build_op_js(op, key, value,
op_value_name)
# Update the js scope with the value for this op
js_scope[op_value_name] = value
js.append(operation_js)
else:
# Construct the JS for this field
value, field_js = self._build_op_js(op, key, value, value_name)
js_scope[value_name] = value
js.append(field_js)
print ' && '.join(js)
return ' && '.join(js)
def accept(self, visitor):
for i in range(len(self.children)):
self.children[i] = self.children[i].accept(visitor)
def _build_op_js(self, op, key, value, value_name):
"""Substitute the values in to the correct chunk of Javascript.
"""
print op, key, value, value_name
if isinstance(value, RE_TYPE):
# Regexes are handled specially
if op.strip('$') == 'ne':
op_js = Q.OPERATORS['regex_ne']
else:
op_js = Q.OPERATORS['regex_eq']
else:
op_js = Q.OPERATORS[op.strip('$')]
return visitor.visit_combination(self)
# Comparing two ObjectIds in Javascript doesn't work..
if isinstance(value, pymongo.objectid.ObjectId):
value = unicode(value)
@property
def empty(self):
return not bool(self.children)
class Q(QNode):
"""A simple query object, used in a query tree to build up more complex
query structures.
"""
def __init__(self, **query):
self.query = query
def accept(self, visitor):
return visitor.visit_query(self)
@property
def empty(self):
return not bool(self.query)
# Perform the substitution
operation_js = op_js % {
'field': key,
'value': value_name
}
return value, operation_js
class QuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor,
@ -153,20 +277,32 @@ class QuerySet(object):
self._document = document
self._collection_obj = collection
self._accessed_collection = False
self._query = {}
self._mongo_query = None
self._query_obj = Q()
self._initial_query = {}
self._where_clause = None
self._loaded_fields = []
self._ordering = []
self._snapshot = False
self._timeout = True
# If inheritance is allowed, only return instances and instances of
# subclasses of the class being used
if document._meta.get('allow_inheritance'):
self._query = {'_types': self._document._class_name}
self._initial_query = {'_types': self._document._class_name}
self._cursor_obj = None
self._limit = None
self._skip = None
def ensure_index(self, key_or_list):
@property
def _query(self):
if self._mongo_query is None:
self._mongo_query = self._query_obj.to_query(self._document)
self._mongo_query.update(self._initial_query)
return self._mongo_query
def ensure_index(self, key_or_list, drop_dups=False, background=False,
**kwargs):
"""Ensure that the given indexes are in place.
:param key_or_list: a single index key or a list of index keys (to
@ -174,7 +310,8 @@ class QuerySet(object):
or a **-** to determine the index ordering
"""
index_list = QuerySet._build_index_spec(self._document, key_or_list)
self._collection.ensure_index(index_list)
self._collection.ensure_index(index_list, drop_dups=drop_dups,
background=background)
return self
@classmethod
@ -222,10 +359,14 @@ class QuerySet(object):
objects, only the last one will be used
:param query: Django-style query keyword arguments
"""
#if q_obj:
#self._where_clause = q_obj.as_js(self._document)
query = Q(**query)
if q_obj:
self._where_clause = q_obj.as_js(self._document)
query = QuerySet._transform_query(_doc_cls=self._document, **query)
self._query.update(query)
query &= q_obj
self._query_obj &= query
self._mongo_query = None
self._cursor_obj = None
return self
def filter(self, *q_objs, **query):
@ -233,6 +374,10 @@ class QuerySet(object):
"""
return self.__call__(*q_objs, **query)
def all(self):
"""Returns all documents."""
return self.__call__()
@property
def _collection(self):
"""Property that returns the collection object. This allows us to
@ -240,33 +385,45 @@ class QuerySet(object):
"""
if not self._accessed_collection:
self._accessed_collection = True
background = self._document._meta.get('index_background', False)
drop_dups = self._document._meta.get('index_drop_dups', False)
index_opts = self._document._meta.get('index_options', {})
# Ensure document-defined indexes are created
if self._document._meta['indexes']:
for key_or_list in self._document._meta['indexes']:
#self.ensure_index(key_or_list)
self._collection.ensure_index(key_or_list)
self._collection.ensure_index(key_or_list,
background=background, **index_opts)
# Ensure indexes created by uniqueness constraints
for index in self._document._meta['unique_indexes']:
self._collection.ensure_index(index, unique=True)
self._collection.ensure_index(index, unique=True,
background=background, drop_dups=drop_dups, **index_opts)
# If _types is being used (for polymorphism), it needs an index
if '_types' in self._query:
self._collection.ensure_index('_types')
self._collection.ensure_index('_types',
background=background, **index_opts)
# Ensure all needed field indexes are created
for field_name, field_instance in self._document._fields.iteritems():
if field_instance.__class__.__name__ == 'GeoLocationField':
self._collection.ensure_index([(field_name, pymongo.GEO2D),])
for field in self._document._fields.values():
if field.__class__._geo_index:
index_spec = [(field.db_field, pymongo.GEO2D)]
self._collection.ensure_index(index_spec,
background=background, **index_opts)
return self._collection_obj
@property
def _cursor(self):
if self._cursor_obj is None:
cursor_args = {}
cursor_args = {
'snapshot': self._snapshot,
'timeout': self._timeout,
}
if self._loaded_fields:
cursor_args = {'fields': self._loaded_fields}
cursor_args['fields'] = self._loaded_fields
self._cursor_obj = self._collection.find(self._query,
**cursor_args)
# Apply where clauses to cursor
@ -291,6 +448,9 @@ class QuerySet(object):
for field_name in parts:
if field is None:
# Look up first field from the document
if field_name == 'pk':
# Deal with "primary key" alias
field_name = document._meta['id_field']
field = document._fields[field_name]
else:
# Look up subfield on the previous field
@ -314,19 +474,31 @@ class QuerySet(object):
"""Transform a query from Django-style format to Mongo format.
"""
operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'near']
match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith',
'all', 'size', 'exists', 'not']
geo_operators = ['within_distance', 'within_box', 'near']
match_operators = ['contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith',
'exact', 'iexact']
mongo_query = {}
for key, value in query.items():
if key == "__raw__":
mongo_query.update(value)
continue
parts = key.split('__')
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
parts = [part for part in parts if not part.isdigit()]
# Check for an operator and transform to mongo-style if there is
op = None
if parts[-1] in operators + match_operators:
if parts[-1] in operators + match_operators + geo_operators:
op = parts.pop()
negate = False
if parts[-1] == 'not':
parts.pop()
negate = True
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts)
@ -334,20 +506,34 @@ class QuerySet(object):
# Convert value to proper value
field = fields[-1]
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte']
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops += match_operators
if op in singular_ops:
value = field.prepare_query_value(op, value)
elif op in ('in', 'nin', 'all'):
elif op in ('in', 'nin', 'all', 'near'):
# 'in', 'nin' and 'all' require a list of values
value = [field.prepare_query_value(op, v) for v in value]
if field.__class__.__name__ == 'GenericReferenceField':
parts.append('_ref')
# if op and op not in match_operators:
if op:
if op in geo_operators:
if op == "within_distance":
value = {'$within': {'$center': value}}
elif op == "near":
value = {'$near': value}
elif op == 'within_box':
value = {'$within': {'$box': value}}
else:
raise NotImplementedError("Geo method '%s' has not "
"been implemented" % op)
elif op not in match_operators:
value = {'$' + op: value}
if op and op not in match_operators:
value = {'$' + op: value}
if negate:
value = {'$not': value}
for i, part in indices:
parts.insert(i, part)
key = '.'.join(parts)
if op is None or key not in mongo_query:
mongo_query[key] = value
@ -405,6 +591,15 @@ class QuerySet(object):
message = u'%d items returned, instead of 1' % count
raise self._document.MultipleObjectsReturned(message)
def create(self, **kwargs):
"""Create new object. Returns the saved object instance.
.. versionadded:: 0.4
"""
doc = self._document(**kwargs)
doc.save()
return doc
def first(self):
"""Retrieve the first object matching the query.
"""
@ -429,7 +624,7 @@ class QuerySet(object):
def in_bulk(self, object_ids):
"""Retrieve a set of documents by their ids.
:param object_ids: a list or tuple of ``ObjectId``\ s
:rtype: dict of ObjectIds as keys and collection-specific
Document subclasses as values.
@ -441,7 +636,7 @@ class QuerySet(object):
docs = self._collection.find({'_id': {'$in': object_ids}})
for doc in docs:
doc_map[doc['_id']] = self._document._from_son(doc)
return doc_map
def next(self):
@ -595,12 +790,22 @@ class QuerySet(object):
# Integer index provided
elif isinstance(key, int):
return self._document._from_son(self._cursor[key])
raise AttributeError
def distinct(self, field):
"""Return a list of distinct values for a given field.
:param field: the field to select distinct values from
.. versionadded:: 0.4
"""
return self._cursor.distinct(field)
def only(self, *fields):
"""Load only a subset of this document's fields. ::
post = BlogPost.objects(...).only("title")
:param fields: fields to include
.. versionadded:: 0.3
@ -629,11 +834,13 @@ class QuerySet(object):
"""
key_list = []
for key in keys:
if not key: continue
direction = pymongo.ASCENDING
if key[0] == '-':
direction = pymongo.DESCENDING
if key[0] in ('-', '+'):
key = key[1:]
key = key.replace('__', '.')
key_list.append((key, direction))
self._ordering = key_list
@ -649,10 +856,23 @@ class QuerySet(object):
plan = self._cursor.explain()
if format:
import pprint
plan = pprint.pformat(plan)
return plan
def snapshot(self, enabled):
"""Enable or disable snapshot mode when querying.
:param enabled: whether or not snapshot mode is enabled
"""
self._snapshot = enabled
def timeout(self, enabled):
"""Enable or disable the default mongod timeout when querying.
:param enabled: whether or not the timeout is used
"""
self._timeout = enabled
def delete(self, safe=False):
"""Delete the documents matched by the query.
@ -664,8 +884,8 @@ class QuerySet(object):
def _transform_update(cls, _doc_cls=None, **update):
"""Transform an update spec from Django-style format to Mongo format.
"""
operators = ['set', 'unset', 'inc', 'dec', 'push', 'push_all', 'pull',
'pull_all']
operators = ['set', 'unset', 'inc', 'dec', 'pop', 'push', 'push_all',
'pull', 'pull_all', 'add_to_set']
mongo_update = {}
for key, value in update.items():
@ -683,6 +903,8 @@ class QuerySet(object):
op = 'inc'
if value > 0:
value = -value
elif op == 'add_to_set':
op = op.replace('_to_set', 'ToSet')
if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')]
@ -691,7 +913,8 @@ class QuerySet(object):
# Convert value to proper value
field = fields[-1]
if op in (None, 'set', 'unset', 'push', 'pull'):
if op in (None, 'set', 'unset', 'pop', 'push', 'pull',
'addToSet'):
value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'):
value = [field.prepare_query_value(op, v) for v in value]
@ -710,7 +933,8 @@ class QuerySet(object):
return mongo_update
def update(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on the fields matched by the query.
"""Perform an atomic update on the fields matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning
:param update: Django-style update keyword arguments
@ -722,8 +946,10 @@ class QuerySet(object):
update = QuerySet._transform_update(self._document, **update)
try:
self._collection.update(self._query, update, safe=safe_update,
upsert=upsert, multi=True)
ret = self._collection.update(self._query, update, multi=True,
upsert=upsert, safe=safe_update)
if ret is not None and 'n' in ret:
return ret['n']
except pymongo.errors.OperationFailure, err:
if unicode(err) == u'multi not coded yet':
message = u'update() method requires MongoDB 1.1.3+'
@ -731,7 +957,8 @@ class QuerySet(object):
raise OperationError(u'Update failed (%s)' % unicode(err))
def update_one(self, safe_update=True, upsert=False, **update):
"""Perform an atomic update on first field matched by the query.
"""Perform an atomic update on first field matched by the query. When
``safe_update`` is used, the number of affected documents is returned.
:param safe: check if the operation succeeded before returning
:param update: Django-style update keyword arguments
@ -743,11 +970,14 @@ class QuerySet(object):
# Explicitly provide 'multi=False' to newer versions of PyMongo
# as the default may change to 'True'
if pymongo.version >= '1.1.1':
self._collection.update(self._query, update, safe=safe_update,
upsert=upsert, multi=False)
ret = self._collection.update(self._query, update, multi=False,
upsert=upsert, safe=safe_update)
else:
# Older versions of PyMongo don't support 'multi'
self._collection.update(self._query, update, safe=safe_update)
ret = self._collection.update(self._query, update,
safe=safe_update)
if ret is not None and 'n' in ret:
return ret['n']
except pymongo.errors.OperationFailure, e:
raise OperationError(u'Update failed [%s]' % unicode(e))
@ -840,7 +1070,7 @@ class QuerySet(object):
var total = 0.0;
var num = 0;
db[collection].find(query).forEach(function(doc) {
if (doc[averageField]) {
if (doc[averageField] !== undefined) {
total += doc[averageField];
num += 1;
}
@ -850,20 +1080,27 @@ class QuerySet(object):
"""
return self.exec_js(average_func, field)
def item_frequencies(self, list_field, normalize=False):
"""Returns a dictionary of all items present in a list field across
def item_frequencies(self, field, normalize=False):
"""Returns a dictionary of all items present in a field across
the whole queried set of documents, and their corresponding frequency.
This is useful for generating tag clouds, or searching documents.
:param list_field: the list field to use
If the field is a :class:`~mongoengine.ListField`, the items within
each list will be counted individually.
:param field: the field to use
:param normalize: normalize the results so they add to 1.0
"""
freq_func = """
function(listField) {
function(field) {
if (options.normalize) {
var total = 0.0;
db[collection].find(query).forEach(function(doc) {
total += doc[listField].length;
if (doc[field].constructor == Array) {
total += doc[field].length;
} else {
total++;
}
});
}
@ -873,14 +1110,19 @@ class QuerySet(object):
inc /= total;
}
db[collection].find(query).forEach(function(doc) {
doc[listField].forEach(function(item) {
if (doc[field].constructor == Array) {
doc[field].forEach(function(item) {
frequencies[item] = inc + (frequencies[item] || 0);
});
} else {
var item = doc[field];
frequencies[item] = inc + (frequencies[item] || 0);
});
}
});
return frequencies;
}
"""
return self.exec_js(freq_func, list_field, normalize=normalize)
return self.exec_js(freq_func, field, normalize=normalize)
def __repr__(self):
limit = REPR_OUTPUT_SIZE + 1
@ -896,7 +1138,7 @@ class QuerySetManager(object):
def __init__(self, manager_func=None):
self._manager_func = manager_func
self._collection = None
self._collections = {}
def __get__(self, instance, owner):
"""Descriptor for instantiating a new QuerySet object when
@ -906,10 +1148,9 @@ class QuerySetManager(object):
# Document class being used rather than a document object
return self
if self._collection is None:
db = _get_db()
collection = owner._meta['collection']
db = _get_db()
collection = owner._meta['collection']
if (db, collection) not in self._collections:
# Create collection as a capped collection if specified
if owner._meta['max_size'] or owner._meta['max_documents']:
# Get max document limit and max byte size from meta
@ -917,10 +1158,10 @@ class QuerySetManager(object):
max_documents = owner._meta['max_documents']
if collection in db.collection_names():
self._collection = db[collection]
self._collections[(db, collection)] = db[collection]
# The collection already exists, check if its capped
# options match the specified capped options
options = self._collection.options()
options = self._collections[(db, collection)].options()
if options.get('max') != max_documents or \
options.get('size') != max_size:
msg = ('Cannot create collection "%s" as a capped '
@ -931,12 +1172,15 @@ class QuerySetManager(object):
opts = {'capped': True, 'size': max_size}
if max_documents:
opts['max'] = max_documents
self._collection = db.create_collection(collection, **opts)
self._collections[(db, collection)] = db.create_collection(
collection, **opts
)
else:
self._collection = db[collection]
self._collections[(db, collection)] = db[collection]
# owner is the document that contains the QuerySetManager
queryset = QuerySet(owner, self._collection)
queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collections[(db, collection)])
if self._manager_func:
if self._manager_func.func_code.co_argcount == 1:
queryset = self._manager_func(queryset)

View File

@ -200,6 +200,37 @@ class DocumentTest(unittest.TestCase):
Person.drop_collection()
self.assertFalse(collection in self.db.collection_names())
def test_inherited_collections(self):
"""Ensure that subclassed documents don't override parents' collections.
"""
class Drink(Document):
name = StringField()
class AlcoholicDrink(Drink):
meta = {'collection': 'booze'}
class Drinker(Document):
drink = GenericReferenceField()
Drink.drop_collection()
AlcoholicDrink.drop_collection()
Drinker.drop_collection()
red_bull = Drink(name='Red Bull')
red_bull.save()
programmer = Drinker(drink=red_bull)
programmer.save()
beer = AlcoholicDrink(name='Beer')
beer.save()
real_person = Drinker(drink=beer)
real_person.save()
self.assertEqual(Drinker.objects[0].drink.name, red_bull.name)
self.assertEqual(Drinker.objects[1].drink.name, beer.name)
def test_capped_collection(self):
"""Ensure that capped collections work properly.
"""
@ -264,11 +295,12 @@ class DocumentTest(unittest.TestCase):
# Indexes are lazy so use list() to perform query
list(BlogPost.objects)
info = BlogPost.objects._collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)]
in info.values())
self.assertTrue([('_types', 1), ('addDate', -1)] in info.values())
in info)
self.assertTrue([('_types', 1), ('addDate', -1)] in info)
# tags is a list field so it shouldn't have _types in the index
self.assertTrue([('tags', 1)] in info.values())
self.assertTrue([('tags', 1)] in info)
class ExtendedBlogPost(BlogPost):
title = StringField()
@ -278,10 +310,11 @@ class DocumentTest(unittest.TestCase):
list(ExtendedBlogPost.objects)
info = ExtendedBlogPost.objects._collection.index_information()
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)]
in info.values())
self.assertTrue([('_types', 1), ('addDate', -1)] in info.values())
self.assertTrue([('_types', 1), ('title', 1)] in info.values())
in info)
self.assertTrue([('_types', 1), ('addDate', -1)] in info)
self.assertTrue([('_types', 1), ('title', 1)] in info)
BlogPost.drop_collection()
@ -353,12 +386,26 @@ class DocumentTest(unittest.TestCase):
user_obj = User.objects.first()
self.assertEqual(user_obj.id, 'test')
self.assertEqual(user_obj.pk, 'test')
user_son = User.objects._collection.find_one()
self.assertEqual(user_son['_id'], 'test')
self.assertTrue('username' not in user_son['_id'])
User.drop_collection()
user = User(pk='mongo', name='mongo user')
user.save()
user_obj = User.objects.first()
self.assertEqual(user_obj.id, 'mongo')
self.assertEqual(user_obj.pk, 'mongo')
user_son = User.objects._collection.find_one()
self.assertEqual(user_son['_id'], 'mongo')
self.assertTrue('username' not in user_son['_id'])
User.drop_collection()
def test_creation(self):
"""Ensure that document may be created using keyword arguments.
@ -446,6 +493,16 @@ class DocumentTest(unittest.TestCase):
self.assertEqual(person_obj['name'], 'Test User')
self.assertEqual(person_obj['age'], 30)
self.assertEqual(person_obj['_id'], person.id)
# Test skipping validation on save
class Recipient(Document):
email = EmailField(required=True)
recipient = Recipient(email='root@localhost')
self.assertRaises(ValidationError, recipient.save)
try:
recipient.save(validate=False)
except ValidationError:
fail()
def test_delete(self):
"""Ensure that document may be deleted using the delete method.
@ -467,6 +524,18 @@ class DocumentTest(unittest.TestCase):
collection = self.db[self.Person._meta['collection']]
person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_custom_pk(self):
"""Ensure that a document may be saved with a custom _id using pk alias.
"""
# Create person object and save it to the database
person = self.Person(name='Test User', age=30,
pk='497ce96f395f2f052a494fd4')
person.save()
# Ensure that the object is in the database with the correct _id
collection = self.db[self.Person._meta['collection']]
person_obj = collection.find_one({'name': 'Test User'})
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
def test_save_list(self):
"""Ensure that a list field may be properly saved.

View File

@ -3,6 +3,7 @@ import datetime
from decimal import Decimal
import pymongo
import gridfs
from mongoengine import *
from mongoengine.connection import _get_db
@ -188,6 +189,9 @@ class FieldTest(unittest.TestCase):
def test_list_validation(self):
"""Ensure that a list field only accepts lists with valid elements.
"""
class User(Document):
pass
class Comment(EmbeddedDocument):
content = StringField()
@ -195,6 +199,7 @@ class FieldTest(unittest.TestCase):
content = StringField()
comments = ListField(EmbeddedDocumentField(Comment))
tags = ListField(StringField())
authors = ListField(ReferenceField(User))
post = BlogPost(content='Went for a walk today...')
post.validate()
@ -209,15 +214,21 @@ class FieldTest(unittest.TestCase):
post.tags = ('fun', 'leisure')
post.validate()
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
post.comments = comments
post.validate()
post.comments = ['a']
self.assertRaises(ValidationError, post.validate)
post.comments = 'yay'
self.assertRaises(ValidationError, post.validate)
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
post.comments = comments
post.validate()
post.authors = [Comment()]
self.assertRaises(ValidationError, post.validate)
post.authors = [User()]
post.validate()
def test_sorted_list_sorting(self):
"""Ensure that a sorted list field properly sorts values.
"""
@ -227,7 +238,8 @@ class FieldTest(unittest.TestCase):
class BlogPost(Document):
content = StringField()
comments = SortedListField(EmbeddedDocumentField(Comment), ordering='order')
comments = SortedListField(EmbeddedDocumentField(Comment),
ordering='order')
tags = SortedListField(StringField())
post = BlogPost(content='Went for a walk today...')
@ -393,14 +405,54 @@ class FieldTest(unittest.TestCase):
class Employee(Document):
name = StringField()
boss = ReferenceField('self')
friends = ListField(ReferenceField('self'))
bill = Employee(name='Bill Lumbergh')
bill.save()
peter = Employee(name='Peter Gibbons', boss=bill)
michael = Employee(name='Michael Bolton')
michael.save()
samir = Employee(name='Samir Nagheenanajar')
samir.save()
friends = [michael, samir]
peter = Employee(name='Peter Gibbons', boss=bill, friends=friends)
peter.save()
peter = Employee.objects.with_id(peter.id)
self.assertEqual(peter.boss, bill)
self.assertEqual(peter.friends, friends)
def test_recursive_embedding(self):
"""Ensure that EmbeddedDocumentFields can contain their own documents.
"""
class Tree(Document):
name = StringField()
children = ListField(EmbeddedDocumentField('TreeNode'))
class TreeNode(EmbeddedDocument):
name = StringField()
children = ListField(EmbeddedDocumentField('self'))
tree = Tree(name="Tree")
first_child = TreeNode(name="Child 1")
tree.children.append(first_child)
second_child = TreeNode(name="Child 2")
first_child.children.append(second_child)
third_child = TreeNode(name="Child 3")
first_child.children.append(third_child)
tree.save()
tree_obj = Tree.objects.first()
self.assertEqual(len(tree.children), 1)
self.assertEqual(tree.children[0].name, first_child.name)
self.assertEqual(tree.children[0].children[0].name, second_child.name)
self.assertEqual(tree.children[0].children[1].name, third_child.name)
def test_undefined_reference(self):
"""Ensure that ReferenceFields may reference undefined Documents.
@ -607,7 +659,130 @@ class FieldTest(unittest.TestCase):
Shirt.drop_collection()
def test_file_fields(self):
"""Ensure that file fields can be written to and their data retrieved
"""
class PutFile(Document):
file = FileField()
class StreamFile(Document):
file = FileField()
class SetFile(Document):
file = FileField()
text = 'Hello, World!'
more_text = 'Foo Bar'
content_type = 'text/plain'
PutFile.drop_collection()
StreamFile.drop_collection()
SetFile.drop_collection()
putfile = PutFile()
putfile.file.put(text, content_type=content_type)
putfile.save()
putfile.validate()
result = PutFile.objects.first()
self.assertTrue(putfile == result)
self.assertEquals(result.file.read(), text)
self.assertEquals(result.file.content_type, content_type)
result.file.delete() # Remove file from GridFS
streamfile = StreamFile()
streamfile.file.new_file(content_type=content_type)
streamfile.file.write(text)
streamfile.file.write(more_text)
streamfile.file.close()
streamfile.save()
streamfile.validate()
result = StreamFile.objects.first()
self.assertTrue(streamfile == result)
self.assertEquals(result.file.read(), text + more_text)
self.assertEquals(result.file.content_type, content_type)
result.file.delete()
# Ensure deleted file returns None
self.assertTrue(result.file.read() == None)
setfile = SetFile()
setfile.file = text
setfile.save()
setfile.validate()
result = SetFile.objects.first()
self.assertTrue(setfile == result)
self.assertEquals(result.file.read(), text)
# Try replacing file with new one
result.file.replace(more_text)
result.save()
result.validate()
result = SetFile.objects.first()
self.assertTrue(setfile == result)
self.assertEquals(result.file.read(), more_text)
result.file.delete()
PutFile.drop_collection()
StreamFile.drop_collection()
SetFile.drop_collection()
# Make sure FileField is optional and not required
class DemoFile(Document):
file = FileField()
d = DemoFile.objects.create()
def test_file_uniqueness(self):
"""Ensure that each instance of a FileField is unique
"""
class TestFile(Document):
name = StringField()
file = FileField()
# First instance
testfile = TestFile()
testfile.name = "Hello, World!"
testfile.file.put('Hello, World!')
testfile.save()
# Second instance
testfiledupe = TestFile()
data = testfiledupe.file.read() # Should be None
self.assertTrue(testfile.name != testfiledupe.name)
self.assertTrue(testfile.file.read() != data)
TestFile.drop_collection()
def test_geo_indexes(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""
class Event(Document):
title = StringField()
location = GeoPointField()
Event.drop_collection()
event = Event(title="Coltrane Motion @ Double Door",
location=[41.909889, -87.677137])
event.save()
info = Event.objects._collection.index_information()
self.assertTrue(u'location_2d' in info)
self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')])
Event.drop_collection()
def test_ensure_unique_default_instances(self):
"""Ensure that every field has it's own unique default instance."""
class D(Document):
data = DictField()
data2 = DictField(default=lambda: {})
d1 = D()
d1.data['foo'] = 'bar'
d1.data2['foo'] = 'bar'
d2 = D()
self.assertEqual(d2.data, {})
self.assertEqual(d2.data2, {})
if __name__ == '__main__':
unittest.main()

View File

@ -53,9 +53,6 @@ class QuerySetTest(unittest.TestCase):
person2 = self.Person(name="User B", age=30)
person2.save()
q1 = Q(name='test')
q2 = Q(age__gte=18)
# Find all people in the collection
people = self.Person.objects
self.assertEqual(len(people), 2)
@ -156,7 +153,8 @@ class QuerySetTest(unittest.TestCase):
# Retrieve the first person from the database
self.assertRaises(MultipleObjectsReturned, self.Person.objects.get)
self.assertRaises(self.Person.MultipleObjectsReturned, self.Person.objects.get)
self.assertRaises(self.Person.MultipleObjectsReturned,
self.Person.objects.get)
# Use a query to filter the people found to just person2
person = self.Person.objects.get(age=30)
@ -165,8 +163,49 @@ class QuerySetTest(unittest.TestCase):
person = self.Person.objects.get(age__lt=30)
self.assertEqual(person.name, "User A")
def test_find_array_position(self):
"""Ensure that query by array position works.
"""
class Comment(EmbeddedDocument):
name = StringField()
class Post(EmbeddedDocument):
comments = ListField(EmbeddedDocumentField(Comment))
class Blog(Document):
tags = ListField(StringField())
posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection()
Blog.objects.create(tags=['a', 'b'])
self.assertEqual(len(Blog.objects(tags__0='a')), 1)
self.assertEqual(len(Blog.objects(tags__0='b')), 0)
self.assertEqual(len(Blog.objects(tags__1='a')), 0)
self.assertEqual(len(Blog.objects(tags__1='b')), 1)
Blog.drop_collection()
comment1 = Comment(name='testa')
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
blog = Blog.objects(posts__0__comments__0__name='testa').get()
self.assertEqual(blog, blog1)
query = Blog.objects(posts__1__comments__1__name='testb')
self.assertEqual(len(query), 2)
query = Blog.objects(posts__1__comments__1__name='testa')
self.assertEqual(len(query), 0)
query = Blog.objects(posts__0__comments__1__name='testa')
self.assertEqual(len(query), 0)
Blog.drop_collection()
def test_get_or_create(self):
"""Ensure that ``get_or_create`` returns one result or creates a new
@ -193,7 +232,8 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(created, False)
# Try retrieving when no objects exists - new doc should be created
person, created = self.Person.objects.get_or_create(age=50, defaults={'name': 'User C'})
kwargs = dict(age=50, defaults={'name': 'User C'})
person, created = self.Person.objects.get_or_create(**kwargs)
self.assertEqual(created, True)
person = self.Person.objects.get(age=50)
@ -288,6 +328,25 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__iexact='gUIDO VAN rOSSU')).first()
self.assertEqual(obj, None)
# Test unsafe expressions
person = self.Person(name='Guido van Rossum [.\'Geek\']')
person.save()
obj = self.Person.objects(Q(name__icontains='[.\'Geek')).first()
self.assertEqual(obj, person)
def test_not(self):
"""Ensure that the __not operator works as expected.
"""
alice = self.Person(name='Alice', age=25)
alice.save()
obj = self.Person.objects(name__iexact='alice').first()
self.assertEqual(obj, alice)
obj = self.Person.objects(name__not__iexact='alice').first()
self.assertEqual(obj, None)
def test_filter_chaining(self):
"""Ensure filters can be chained together.
@ -498,9 +557,10 @@ class QuerySetTest(unittest.TestCase):
obj = self.Person.objects(Q(name=re.compile('^gui', re.I))).first()
self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__ne=re.compile('^bob'))).first()
obj = self.Person.objects(Q(name__not=re.compile('^bob'))).first()
self.assertEqual(obj, person)
obj = self.Person.objects(Q(name__ne=re.compile('^Gui'))).first()
obj = self.Person.objects(Q(name__not=re.compile('^Gui'))).first()
self.assertEqual(obj, None)
def test_q_lists(self):
@ -664,28 +724,32 @@ class QuerySetTest(unittest.TestCase):
post.reload()
self.assertTrue('db' in post.tags and 'nosql' in post.tags)
tags = post.tags[:-1]
BlogPost.objects.update(pop__tags=1)
post.reload()
self.assertEqual(post.tags, tags)
BlogPost.objects.update_one(add_to_set__tags='unique')
BlogPost.objects.update_one(add_to_set__tags='unique')
post.reload()
self.assertEqual(post.tags.count('unique'), 1)
BlogPost.drop_collection()
def test_update_pull(self):
"""Ensure that the 'pull' update operation works correctly.
"""
class Comment(EmbeddedDocument):
content = StringField()
class BlogPost(Document):
slug = StringField()
comments = ListField(EmbeddedDocumentField(Comment))
tags = ListField(StringField())
comment1 = Comment(content="test1")
comment2 = Comment(content="test2")
post = BlogPost(slug="test", comments=[comment1, comment2])
post = BlogPost(slug="test", tags=['code', 'mongodb', 'code'])
post.save()
self.assertTrue(comment2 in post.comments)
BlogPost.objects(slug="test").update(pull__comments__content="test2")
BlogPost.objects(slug="test").update(pull__tags="code")
post.reload()
self.assertTrue(comment2 not in post.comments)
self.assertTrue('code' not in post.tags)
self.assertEqual(len(post.tags), 1)
def test_order_by(self):
"""Ensure that QuerySets may be ordered.
@ -921,7 +985,7 @@ class QuerySetTest(unittest.TestCase):
BlogPost(hits=1, tags=['music', 'film', 'actors']).save()
BlogPost(hits=2, tags=['music']).save()
BlogPost(hits=3, tags=['music', 'actors']).save()
BlogPost(hits=2, tags=['music', 'actors']).save()
f = BlogPost.objects.item_frequencies('tags')
f = dict((key, int(val)) for key, val in f.items())
@ -943,16 +1007,26 @@ class QuerySetTest(unittest.TestCase):
self.assertAlmostEqual(f['actors'], 2.0/6.0)
self.assertAlmostEqual(f['film'], 1.0/6.0)
# Check item_frequencies works for non-list fields
f = BlogPost.objects.item_frequencies('hits')
f = dict((key, int(val)) for key, val in f.items())
self.assertEqual(set(['1', '2']), set(f.keys()))
self.assertEqual(f['1'], 1)
self.assertEqual(f['2'], 2)
BlogPost.drop_collection()
def test_average(self):
"""Ensure that field can be averaged correctly.
"""
self.Person(name='person', age=0).save()
self.assertEqual(int(self.Person.objects.average('age')), 0)
ages = [23, 54, 12, 94, 27]
for i, age in enumerate(ages):
self.Person(name='test%s' % i, age=age).save()
avg = float(sum(ages)) / len(ages)
avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0
self.assertAlmostEqual(int(self.Person.objects.average('age')), avg)
self.Person(name='ageless person').save()
@ -970,15 +1044,34 @@ class QuerySetTest(unittest.TestCase):
self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
def test_distinct(self):
"""Ensure that the QuerySet.distinct method works.
"""
self.Person(name='Mr Orange', age=20).save()
self.Person(name='Mr White', age=20).save()
self.Person(name='Mr Orange', age=30).save()
self.Person(name='Mr Pink', age=30).save()
self.assertEqual(set(self.Person.objects.distinct('name')),
set(['Mr Orange', 'Mr White', 'Mr Pink']))
self.assertEqual(set(self.Person.objects.distinct('age')),
set([20, 30]))
self.assertEqual(set(self.Person.objects(age=30).distinct('name')),
set(['Mr Orange', 'Mr Pink']))
def test_custom_manager(self):
"""Ensure that custom QuerySetManager instances work as expected.
"""
class BlogPost(Document):
tags = ListField(StringField())
deleted = BooleanField(default=False)
@queryset_manager
def objects(doc_cls, queryset):
return queryset(deleted=False)
@queryset_manager
def music_posts(doc_cls, queryset):
return queryset(tags='music')
return queryset(tags='music', deleted=False)
BlogPost.drop_collection()
@ -988,6 +1081,8 @@ class QuerySetTest(unittest.TestCase):
post2.save()
post3 = BlogPost(tags=['film', 'actors'])
post3.save()
post4 = BlogPost(tags=['film', 'actors'], deleted=True)
post4.save()
self.assertEqual([p.id for p in BlogPost.objects],
[post1.id, post2.id, post3.id])
@ -1011,7 +1106,8 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
data = {'title': 'Post 1', 'comments': [Comment(content='test')]}
BlogPost(**data).save()
post = BlogPost(**data)
post.save()
self.assertTrue('postTitle' in
BlogPost.objects(title=data['title'])._query)
@ -1019,12 +1115,33 @@ class QuerySetTest(unittest.TestCase):
BlogPost.objects(title=data['title'])._query)
self.assertEqual(len(BlogPost.objects(title=data['title'])), 1)
self.assertTrue('_id' in BlogPost.objects(pk=post.id)._query)
self.assertEqual(len(BlogPost.objects(pk=post.id)), 1)
self.assertTrue('postComments.commentContent' in
BlogPost.objects(comments__content='test')._query)
self.assertEqual(len(BlogPost.objects(comments__content='test')), 1)
BlogPost.drop_collection()
def test_query_pk_field_name(self):
"""Ensure that the correct "primary key" field name is used when querying
"""
class BlogPost(Document):
title = StringField(primary_key=True, db_field='postTitle')
BlogPost.drop_collection()
data = { 'title':'Post 1' }
post = BlogPost(**data)
post.save()
self.assertTrue('_id' in BlogPost.objects(pk=data['title'])._query)
self.assertTrue('_id' in BlogPost.objects(title=data['title'])._query)
self.assertEqual(len(BlogPost.objects(pk=data['title'])), 1)
BlogPost.drop_collection()
def test_query_value_conversion(self):
"""Ensure that query values are properly converted when necessary.
"""
@ -1087,8 +1204,9 @@ class QuerySetTest(unittest.TestCase):
# Indexes are lazy so use list() to perform query
list(BlogPost.objects)
info = BlogPost.objects._collection.index_information()
self.assertTrue([('_types', 1)] in info.values())
self.assertTrue([('_types', 1), ('date', -1)] in info.values())
info = [value['key'] for key, value in info.iteritems()]
self.assertTrue([('_types', 1)] in info)
self.assertTrue([('_types', 1), ('date', -1)] in info)
BlogPost.drop_collection()
@ -1164,46 +1282,104 @@ class QuerySetTest(unittest.TestCase):
def tearDown(self):
self.Person.drop_collection()
def test_geospatial_operators(self):
"""Ensure that geospatial queries are working.
"""
class Event(Document):
title = StringField()
date = DateTimeField()
location = GeoPointField()
def __unicode__(self):
return self.title
Event.drop_collection()
event1 = Event(title="Coltrane Motion @ Double Door",
date=datetime.now() - timedelta(days=1),
location=[41.909889, -87.677137])
event2 = Event(title="Coltrane Motion @ Bottom of the Hill",
date=datetime.now() - timedelta(days=10),
location=[37.7749295, -122.4194155])
event3 = Event(title="Coltrane Motion @ Empty Bottle",
date=datetime.now(),
location=[41.900474, -87.686638])
event1.save()
event2.save()
event3.save()
# find all events "near" pitchfork office, chicago.
# note that "near" will show the san francisco event, too,
# although it sorts to last.
events = Event.objects(location__near=[41.9120459, -87.67892])
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event1, event3, event2])
# find events within 5 miles of pitchfork office, chicago
point_and_distance = [[41.9120459, -87.67892], 5]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 2)
events = list(events)
self.assertTrue(event2 not in events)
self.assertTrue(event1 in events)
self.assertTrue(event3 in events)
# ensure ordering is respected by "near"
events = Event.objects(location__near=[41.9120459, -87.67892])
events = events.order_by("-date")
self.assertEqual(events.count(), 3)
self.assertEqual(list(events), [event3, event1, event2])
# find events around san francisco
point_and_distance = [[37.7566023, -122.415579], 10]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0], event2)
# find events within 1 mile of greenpoint, broolyn, nyc, ny
point_and_distance = [[40.7237134, -73.9509714], 1]
events = Event.objects(location__within_distance=point_and_distance)
self.assertEqual(events.count(), 0)
# ensure ordering is respected by "within_distance"
point_and_distance = [[41.9120459, -87.67892], 10]
events = Event.objects(location__within_distance=point_and_distance)
events = events.order_by("-date")
self.assertEqual(events.count(), 2)
self.assertEqual(events[0], event3)
# check that within_box works
box = [(35.0, -125.0), (40.0, -100.0)]
events = Event.objects(location__within_box=box)
self.assertEqual(events.count(), 1)
self.assertEqual(events[0].id, event2.id)
Event.drop_collection()
def test_custom_querysets(self):
"""Ensure that custom QuerySet classes may be used.
"""
class CustomQuerySet(QuerySet):
def not_empty(self):
return len(self) > 0
class Post(Document):
meta = {'queryset_class': CustomQuerySet}
Post.drop_collection()
self.assertTrue(isinstance(Post.objects, CustomQuerySet))
self.assertFalse(Post.objects.not_empty())
Post().save()
self.assertTrue(Post.objects.not_empty())
Post.drop_collection()
class QTest(unittest.TestCase):
def test_or_and(self):
"""Ensure that Q objects may be combined correctly.
"""
q1 = Q(name='test')
q2 = Q(age__gte=18)
query = ['(', {'name': 'test'}, '||', {'age__gte': 18}, ')']
self.assertEqual((q1 | q2).query, query)
query = ['(', {'name': 'test'}, '&&', {'age__gte': 18}, ')']
self.assertEqual((q1 & q2).query, query)
query = ['(', '(', {'name': 'test'}, '&&', {'age__gte': 18}, ')', '||',
{'name': 'example'}, ')']
self.assertEqual((q1 & q2 | Q(name='example')).query, query)
def test_item_query_as_js(self):
"""Ensure that the _item_query_as_js utilitiy method works properly.
"""
q = Q()
examples = [
({'name': 'test'}, ('((this.name instanceof Array) && '
'this.name.indexOf(i0f0) != -1) || this.name == i0f0'),
{'i0f0': 'test'}),
({'age': {'$gt': 18}}, 'this.age > i0f0o0', {'i0f0o0': 18}),
({'name': 'test', 'age': {'$gt': 18, '$lte': 65}},
('this.age <= i0f0o0 && this.age > i0f0o1 && '
'((this.name instanceof Array) && '
'this.name.indexOf(i0f1) != -1) || this.name == i0f1'),
{'i0f0o0': 65, 'i0f0o1': 18, 'i0f1': 'test'}),
]
for item, js, scope in examples:
test_scope = {}
self.assertEqual(q._item_query_as_js(item, test_scope, 0), js)
self.assertEqual(scope, test_scope)
def test_empty_q(self):
"""Ensure that empty Q objects won't hurt.
"""
@ -1213,11 +1389,131 @@ class QTest(unittest.TestCase):
q4 = Q(name='test')
q5 = Q()
query = ['(', {'age__gte': 18}, '||', {'name': 'test'}, ')']
self.assertEqual((q1 | q2 | q3 | q4 | q5).query, query)
class Person(Document):
name = StringField()
age = IntField()
query = {'$or': [{'age': {'$gte': 18}}, {'name': 'test'}]}
self.assertEqual((q1 | q2 | q3 | q4 | q5).to_query(Person), query)
query = {'age': {'$gte': 18}, 'name': 'test'}
self.assertEqual((q1 & q2 & q3 & q4 & q5).to_query(Person), query)
def test_q_with_dbref(self):
"""Ensure Q objects handle DBRefs correctly"""
connect(db='mongoenginetest')
class User(Document):
pass
class Post(Document):
created_user = ReferenceField(User)
user = User.objects.create()
Post.objects.create(created_user=user)
self.assertEqual(Post.objects.filter(created_user=user).count(), 1)
self.assertEqual(Post.objects.filter(Q(created_user=user)).count(), 1)
def test_and_combination(self):
"""Ensure that Q-objects correctly AND together.
"""
class TestDoc(Document):
x = IntField()
y = StringField()
# Check than an error is raised when conflicting queries are anded
def invalid_combination():
query = Q(x__lt=7) & Q(x__lt=3)
query.to_query(TestDoc)
self.assertRaises(InvalidQueryError, invalid_combination)
# Check normal cases work without an error
query = Q(x__lt=7) & Q(x__gt=3)
q1 = Q(x__lt=7)
q2 = Q(x__gt=3)
query = (q1 & q2).to_query(TestDoc)
self.assertEqual(query, {'x': {'$lt': 7, '$gt': 3}})
# More complex nested example
query = Q(x__lt=100) & Q(y__ne='NotMyString')
query &= Q(y__in=['a', 'b', 'c']) & Q(x__gt=-100)
mongo_query = {
'x': {'$lt': 100, '$gt': -100},
'y': {'$ne': 'NotMyString', '$in': ['a', 'b', 'c']},
}
self.assertEqual(query.to_query(TestDoc), mongo_query)
def test_or_combination(self):
"""Ensure that Q-objects correctly OR together.
"""
class TestDoc(Document):
x = IntField()
q1 = Q(x__lt=3)
q2 = Q(x__gt=7)
query = (q1 | q2).to_query(TestDoc)
self.assertEqual(query, {
'$or': [
{'x': {'$lt': 3}},
{'x': {'$gt': 7}},
]
})
def test_and_or_combination(self):
"""Ensure that Q-objects handle ANDing ORed components.
"""
class TestDoc(Document):
x = IntField()
y = BooleanField()
query = (Q(x__gt=0) | Q(x__exists=False))
query &= Q(x__lt=100)
self.assertEqual(query.to_query(TestDoc), {
'$or': [
{'x': {'$lt': 100, '$gt': 0}},
{'x': {'$lt': 100, '$exists': False}},
]
})
q1 = (Q(x__gt=0) | Q(x__exists=False))
q2 = (Q(x__lt=100) | Q(y=True))
query = (q1 & q2).to_query(TestDoc)
self.assertEqual(['$or'], query.keys())
conditions = [
{'x': {'$lt': 100, '$gt': 0}},
{'x': {'$lt': 100, '$exists': False}},
{'x': {'$gt': 0}, 'y': True},
{'x': {'$exists': False}, 'y': True},
]
self.assertEqual(len(conditions), len(query['$or']))
for condition in conditions:
self.assertTrue(condition in query['$or'])
def test_or_and_or_combination(self):
"""Ensure that Q-objects handle ORing ANDed ORed components. :)
"""
class TestDoc(Document):
x = IntField()
y = BooleanField()
q1 = (Q(x__gt=0) & (Q(y=True) | Q(y__exists=False)))
q2 = (Q(x__lt=100) & (Q(y=False) | Q(y__exists=False)))
query = (q1 | q2).to_query(TestDoc)
self.assertEqual(['$or'], query.keys())
conditions = [
{'x': {'$gt': 0}, 'y': True},
{'x': {'$gt': 0}, 'y': {'$exists': False}},
{'x': {'$lt': 100}, 'y':False},
{'x': {'$lt': 100}, 'y': {'$exists': False}},
]
self.assertEqual(len(conditions), len(query['$or']))
for condition in conditions:
self.assertTrue(condition in query['$or'])
query = ['(', {'age__gte': 18}, '&&', {'name': 'test'}, ')']
self.assertEqual((q1 & q2 & q3 & q4 & q5).query, query)
if __name__ == '__main__':
unittest.main()