Merge remote branch 'upstream/dev' into dev
This commit is contained in:
		
							
								
								
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								AUTHORS
									
									
									
									
									
								
							| @@ -3,3 +3,4 @@ Matt Dennewitz <mattdennewitz@gmail.com> | |||||||
| Deepak Thukral <iapain@yahoo.com> | Deepak Thukral <iapain@yahoo.com> | ||||||
| Florian Schlachter <flori@n-schlachter.de> | Florian Schlachter <flori@n-schlachter.de> | ||||||
| Steve Challis <steve@stevechallis.com> | Steve Challis <steve@stevechallis.com> | ||||||
|  | Ross Lawley <ross.lawley@gmail.com> | ||||||
|   | |||||||
| @@ -5,6 +5,11 @@ Changelog | |||||||
| Changes in dev | Changes in dev | ||||||
| ============== | ============== | ||||||
|  |  | ||||||
|  | - Added slave_okay kwarg to queryset | ||||||
|  | - Added insert method for bulk inserts | ||||||
|  | - Added blinker signal support | ||||||
|  | - Added query_counter context manager for tests | ||||||
|  | - Added DereferenceBaseField - for improved performance in field dereferencing | ||||||
| - Added optional map_reduce method item_frequencies | - Added optional map_reduce method item_frequencies | ||||||
| - Added inline_map_reduce option to map_reduce | - Added inline_map_reduce option to map_reduce | ||||||
| - Updated connection exception so it provides more info on the cause. | - Updated connection exception so it provides more info on the cause. | ||||||
|   | |||||||
| @@ -49,10 +49,11 @@ Storage | |||||||
| ======= | ======= | ||||||
| With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`, | With MongoEngine's support for GridFS via the :class:`~mongoengine.FileField`, | ||||||
| it is useful to have a Django file storage backend that wraps this. The new | it is useful to have a Django file storage backend that wraps this. The new | ||||||
| storage module is called :class:`~mongoengine.django.GridFSStorage`. Using it | storage module is called :class:`~mongoengine.django.storage.GridFSStorage`.  | ||||||
| is very similar to using the default FileSystemStorage.:: | Using it is very similar to using the default FileSystemStorage.:: | ||||||
|      |      | ||||||
|     fs = mongoengine.django.GridFSStorage() |     from mongoengine.django.storage import GridFSStorage | ||||||
|  |     fs = GridFSStorage() | ||||||
|  |  | ||||||
|     filename = fs.save('hello.txt', 'Hello, World!') |     filename = fs.save('hello.txt', 'Hello, World!') | ||||||
|  |  | ||||||
|   | |||||||
| @@ -341,9 +341,10 @@ Indexes | |||||||
| You can specify indexes on collections to make querying faster. This is done | You can specify indexes on collections to make querying faster. This is done | ||||||
| by creating a list of index specifications called :attr:`indexes` in the | by creating a list of index specifications called :attr:`indexes` in the | ||||||
| :attr:`~mongoengine.Document.meta` dictionary, where an index specification may | :attr:`~mongoengine.Document.meta` dictionary, where an index specification may | ||||||
| either be a single field name, or a tuple containing multiple field names. A | either be a single field name, a tuple containing multiple field names, or a | ||||||
| direction may be specified on fields by prefixing the field name with a **+** | dictionary containing a full index definition. A direction may be specified on | ||||||
| or a **-** sign. Note that direction only matters on multi-field indexes. :: | fields by prefixing the field name with a **+** or a **-** sign. Note that | ||||||
|  | direction only matters on multi-field indexes. :: | ||||||
|  |  | ||||||
|     class Page(Document): |     class Page(Document): | ||||||
|         title = StringField() |         title = StringField() | ||||||
| @@ -352,6 +353,21 @@ or a **-** sign. Note that direction only matters on multi-field indexes. :: | |||||||
|             'indexes': ['title', ('title', '-rating')] |             'indexes': ['title', ('title', '-rating')] | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  | If a dictionary is passed then the following options are available: | ||||||
|  |  | ||||||
|  | :attr:`fields` (Default: None) | ||||||
|  |     The fields to index. Specified in the same format as described above. | ||||||
|  |  | ||||||
|  | :attr:`types` (Default: True) | ||||||
|  |     Whether the index should have the :attr:`_types` field added automatically | ||||||
|  |     to the start of the index. | ||||||
|  |  | ||||||
|  | :attr:`sparse` (Default: False) | ||||||
|  |     Whether the index should be sparse. | ||||||
|  |  | ||||||
|  | :attr:`unique` (Default: False) | ||||||
|  |     Whether the index should be sparse. | ||||||
|  |  | ||||||
| .. note:: | .. note:: | ||||||
|    Geospatial indexes will be automatically created for all  |    Geospatial indexes will be automatically created for all  | ||||||
|    :class:`~mongoengine.GeoPointField`\ s |    :class:`~mongoengine.GeoPointField`\ s | ||||||
|   | |||||||
| @@ -11,3 +11,4 @@ User Guide | |||||||
|    document-instances |    document-instances | ||||||
|    querying |    querying | ||||||
|    gridfs |    gridfs | ||||||
|  |    signals | ||||||
|   | |||||||
							
								
								
									
										49
									
								
								docs/guide/signals.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								docs/guide/signals.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | |||||||
|  | .. _signals: | ||||||
|  |  | ||||||
|  | Signals | ||||||
|  | ======= | ||||||
|  |  | ||||||
|  | .. versionadded:: 0.5 | ||||||
|  |  | ||||||
|  | Signal support is provided by the excellent `blinker`_ library and | ||||||
|  | will gracefully fall back if it is not available. | ||||||
|  |  | ||||||
|  |  | ||||||
|  | The following document signals exist in MongoEngine and are pretty self explaintary: | ||||||
|  |  | ||||||
|  |   * `mongoengine.signals.pre_init` | ||||||
|  |   * `mongoengine.signals.post_init` | ||||||
|  |   * `mongoengine.signals.pre_save` | ||||||
|  |   * `mongoengine.signals.post_save` | ||||||
|  |   * `mongoengine.signals.pre_delete` | ||||||
|  |   * `mongoengine.signals.post_delete` | ||||||
|  |  | ||||||
|  | Example usage:: | ||||||
|  |  | ||||||
|  |     from mongoengine import * | ||||||
|  |     from mongoengine import signals | ||||||
|  |  | ||||||
|  |     class Author(Document): | ||||||
|  |         name = StringField() | ||||||
|  |  | ||||||
|  |         def __unicode__(self): | ||||||
|  |             return self.name | ||||||
|  |  | ||||||
|  |         @classmethod | ||||||
|  |         def pre_save(cls, instance, **kwargs): | ||||||
|  |             logging.debug("Pre Save: %s" % instance.name) | ||||||
|  |  | ||||||
|  |         @classmethod | ||||||
|  |         def post_save(cls, instance, **kwargs): | ||||||
|  |             logging.debug("Post Save: %s" % instance.name) | ||||||
|  |             if 'created' in kwargs: | ||||||
|  |                 if kwargs['created']: | ||||||
|  |                     logging.debug("Created") | ||||||
|  |                 else: | ||||||
|  |                     logging.debug("Updated") | ||||||
|  |          | ||||||
|  |         signals.pre_save.connect(Author.pre_save) | ||||||
|  |         signals.post_save.connect(Author.post_save) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | .. _blinker: http://pypi.python.org/pypi/blinker | ||||||
| @@ -6,9 +6,11 @@ import connection | |||||||
| from connection import * | from connection import * | ||||||
| import queryset | import queryset | ||||||
| from queryset import * | from queryset import * | ||||||
|  | import signals | ||||||
|  | from signals import * | ||||||
|  |  | ||||||
| __all__ = (document.__all__ + fields.__all__ + connection.__all__ + | __all__ = (document.__all__ + fields.__all__ + connection.__all__ + | ||||||
|            queryset.__all__) |            queryset.__all__ + signals.__all__) | ||||||
|  |  | ||||||
| __author__ = 'Harry Marr' | __author__ = 'Harry Marr' | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,9 +2,12 @@ from queryset import QuerySet, QuerySetManager | |||||||
| from queryset import DoesNotExist, MultipleObjectsReturned | from queryset import DoesNotExist, MultipleObjectsReturned | ||||||
| from queryset import DO_NOTHING | from queryset import DO_NOTHING | ||||||
|  |  | ||||||
|  | from mongoengine import signals | ||||||
|  |  | ||||||
| import sys | import sys | ||||||
| import pymongo | import pymongo | ||||||
| import pymongo.objectid | import pymongo.objectid | ||||||
|  | from operator import itemgetter | ||||||
|  |  | ||||||
|  |  | ||||||
| class NotRegistered(Exception): | class NotRegistered(Exception): | ||||||
| @@ -126,6 +129,88 @@ class BaseField(object): | |||||||
|  |  | ||||||
|         self.validate(value) |         self.validate(value) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DereferenceBaseField(BaseField): | ||||||
|  |     """Handles the lazy dereferencing of a queryset.  Will dereference all | ||||||
|  |     items in a list / dict rather than one at a time. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __get__(self, instance, owner): | ||||||
|  |         """Descriptor to automatically dereference references. | ||||||
|  |         """ | ||||||
|  |         from fields import ReferenceField, GenericReferenceField | ||||||
|  |         from connection import _get_db | ||||||
|  |  | ||||||
|  |         if instance is None: | ||||||
|  |             # Document class being used rather than a document object | ||||||
|  |             return self | ||||||
|  |  | ||||||
|  |         # Get value from document instance if available | ||||||
|  |         value_list = instance._data.get(self.name) | ||||||
|  |         if not value_list: | ||||||
|  |             return super(DereferenceBaseField, self).__get__(instance, owner) | ||||||
|  |  | ||||||
|  |         is_list = False | ||||||
|  |         if not hasattr(value_list, 'items'): | ||||||
|  |             is_list = True | ||||||
|  |             value_list = dict([(k,v) for k,v in enumerate(value_list)]) | ||||||
|  |  | ||||||
|  |         if isinstance(self.field, ReferenceField) and value_list: | ||||||
|  |             db = _get_db() | ||||||
|  |             dbref = {} | ||||||
|  |             collections = {} | ||||||
|  |  | ||||||
|  |             for k, v in value_list.items(): | ||||||
|  |                 dbref[k] = v | ||||||
|  |                 # Save any DBRefs | ||||||
|  |                 if isinstance(v, (pymongo.dbref.DBRef)): | ||||||
|  |                     collections.setdefault(v.collection, []).append((k, v)) | ||||||
|  |  | ||||||
|  |             # For each collection get the references | ||||||
|  |             for collection, dbrefs in collections.items(): | ||||||
|  |                 id_map = dict([(v.id, k) for k, v in dbrefs]) | ||||||
|  |                 references = db[collection].find({'_id': {'$in': id_map.keys()}}) | ||||||
|  |                 for ref in references: | ||||||
|  |                     key = id_map[ref['_id']] | ||||||
|  |                     dbref[key] = get_document(ref['_cls'])._from_son(ref) | ||||||
|  |  | ||||||
|  |             if is_list: | ||||||
|  |                 dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] | ||||||
|  |             instance._data[self.name] = dbref | ||||||
|  |  | ||||||
|  |         # Get value from document instance if available | ||||||
|  |         if isinstance(self.field, GenericReferenceField) and value_list: | ||||||
|  |             db = _get_db() | ||||||
|  |             value_list = [(k,v) for k,v in value_list.items()] | ||||||
|  |             dbref = {} | ||||||
|  |             classes = {} | ||||||
|  |  | ||||||
|  |             for k, v in value_list: | ||||||
|  |                 dbref[k] = v | ||||||
|  |                 # Save any DBRefs | ||||||
|  |                 if isinstance(v, (dict, pymongo.son.SON)): | ||||||
|  |                     classes.setdefault(v['_cls'], []).append((k, v)) | ||||||
|  |  | ||||||
|  |             # For each collection get the references | ||||||
|  |             for doc_cls, dbrefs in classes.items(): | ||||||
|  |                 id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) | ||||||
|  |                 doc_cls = get_document(doc_cls) | ||||||
|  |                 collection = doc_cls._meta['collection'] | ||||||
|  |                 references = db[collection].find({'_id': {'$in': id_map.keys()}}) | ||||||
|  |  | ||||||
|  |                 for ref in references: | ||||||
|  |                     key = id_map[ref['_id']] | ||||||
|  |                     dbref[key] = doc_cls._from_son(ref) | ||||||
|  |  | ||||||
|  |             if is_list: | ||||||
|  |                 dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] | ||||||
|  |  | ||||||
|  |             instance._data[self.name] = dbref | ||||||
|  |  | ||||||
|  |         return super(DereferenceBaseField, self).__get__(instance, owner) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ObjectIdField(BaseField): | class ObjectIdField(BaseField): | ||||||
|     """An field wrapper around MongoDB's ObjectIds. |     """An field wrapper around MongoDB's ObjectIds. | ||||||
|     """ |     """ | ||||||
| @@ -382,6 +467,8 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): | |||||||
| class BaseDocument(object): | class BaseDocument(object): | ||||||
|  |  | ||||||
|     def __init__(self, **values): |     def __init__(self, **values): | ||||||
|  |         signals.pre_init.send(self, values=values) | ||||||
|  |  | ||||||
|         self._data = {} |         self._data = {} | ||||||
|         # Assign default values to instance |         # Assign default values to instance | ||||||
|         for attr_name in self._fields.keys(): |         for attr_name in self._fields.keys(): | ||||||
| @@ -395,6 +482,8 @@ class BaseDocument(object): | |||||||
|             except AttributeError: |             except AttributeError: | ||||||
|                 pass |                 pass | ||||||
|  |  | ||||||
|  |         signals.post_init.send(self) | ||||||
|  |  | ||||||
|     def validate(self): |     def validate(self): | ||||||
|         """Ensure that all fields' values are valid and that required fields |         """Ensure that all fields' values are valid and that required fields | ||||||
|         are present. |         are present. | ||||||
|   | |||||||
| @@ -1,3 +1,4 @@ | |||||||
|  | from mongoengine import signals | ||||||
| from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, | from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, | ||||||
|                   ValidationError) |                   ValidationError) | ||||||
| from queryset import OperationError | from queryset import OperationError | ||||||
| @@ -75,6 +76,8 @@ class Document(BaseDocument): | |||||||
|                 For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers |                 For example, ``save(..., w=2, fsync=True)`` will wait until at least two servers | ||||||
|                 have recorded the write and will force an fsync on each server being written to. |                 have recorded the write and will force an fsync on each server being written to. | ||||||
|         """ |         """ | ||||||
|  |         signals.pre_save.send(self) | ||||||
|  |  | ||||||
|         if validate: |         if validate: | ||||||
|             self.validate() |             self.validate() | ||||||
|  |  | ||||||
| @@ -82,6 +85,7 @@ class Document(BaseDocument): | |||||||
|             write_options = {} |             write_options = {} | ||||||
|  |  | ||||||
|         doc = self.to_mongo() |         doc = self.to_mongo() | ||||||
|  |         created = '_id' not in doc | ||||||
|         try: |         try: | ||||||
|             collection = self.__class__.objects._collection |             collection = self.__class__.objects._collection | ||||||
|             if force_insert: |             if force_insert: | ||||||
| @@ -96,12 +100,16 @@ class Document(BaseDocument): | |||||||
|         id_field = self._meta['id_field'] |         id_field = self._meta['id_field'] | ||||||
|         self[id_field] = self._fields[id_field].to_python(object_id) |         self[id_field] = self._fields[id_field].to_python(object_id) | ||||||
|  |  | ||||||
|  |         signals.post_save.send(self, created=created) | ||||||
|  |  | ||||||
|     def delete(self, safe=False): |     def delete(self, safe=False): | ||||||
|         """Delete the :class:`~mongoengine.Document` from the database. This |         """Delete the :class:`~mongoengine.Document` from the database. This | ||||||
|         will only take effect if the document has been previously saved. |         will only take effect if the document has been previously saved. | ||||||
|  |  | ||||||
|         :param safe: check if the operation succeeded before returning |         :param safe: check if the operation succeeded before returning | ||||||
|         """ |         """ | ||||||
|  |         signals.pre_delete.send(self) | ||||||
|  |  | ||||||
|         id_field = self._meta['id_field'] |         id_field = self._meta['id_field'] | ||||||
|         object_id = self._fields[id_field].to_mongo(self[id_field]) |         object_id = self._fields[id_field].to_mongo(self[id_field]) | ||||||
|         try: |         try: | ||||||
| @@ -110,6 +118,8 @@ class Document(BaseDocument): | |||||||
|             message = u'Could not delete document (%s)' % err.message |             message = u'Could not delete document (%s)' % err.message | ||||||
|             raise OperationError(message) |             raise OperationError(message) | ||||||
|  |  | ||||||
|  |         signals.post_delete.send(self) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def register_delete_rule(cls, document_cls, field_name, rule): |     def register_delete_rule(cls, document_cls, field_name, rule): | ||||||
|         """This method registers the delete rules to apply when removing this |         """This method registers the delete rules to apply when removing this | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| from base import BaseField, ObjectIdField, ValidationError, get_document | from base import (BaseField, DereferenceBaseField, ObjectIdField, | ||||||
|  |                   ValidationError, get_document) | ||||||
| from queryset import DO_NOTHING | from queryset import DO_NOTHING | ||||||
| from document import Document, EmbeddedDocument | from document import Document, EmbeddedDocument | ||||||
| from connection import _get_db | from connection import _get_db | ||||||
| @@ -12,7 +13,6 @@ import pymongo.binary | |||||||
| import datetime, time | import datetime, time | ||||||
| import decimal | import decimal | ||||||
| import gridfs | import gridfs | ||||||
| import warnings |  | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', | __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', | ||||||
| @@ -153,6 +153,7 @@ class IntField(BaseField): | |||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         return int(value) |         return int(value) | ||||||
|  |  | ||||||
|  |  | ||||||
| class FloatField(BaseField): | class FloatField(BaseField): | ||||||
|     """An floating point number field. |     """An floating point number field. | ||||||
|     """ |     """ | ||||||
| @@ -178,6 +179,7 @@ class FloatField(BaseField): | |||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         return float(value) |         return float(value) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DecimalField(BaseField): | class DecimalField(BaseField): | ||||||
|     """A fixed-point decimal number field. |     """A fixed-point decimal number field. | ||||||
|  |  | ||||||
| @@ -227,6 +229,10 @@ class BooleanField(BaseField): | |||||||
|  |  | ||||||
| class DateTimeField(BaseField): | class DateTimeField(BaseField): | ||||||
|     """A datetime field. |     """A datetime field. | ||||||
|  |  | ||||||
|  |     Note: Microseconds are rounded to the nearest millisecond. | ||||||
|  |       Pre UTC microsecond support is effecively broken see | ||||||
|  |       `tests.field.test_datetime` for more information. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
| @@ -255,7 +261,6 @@ class DateTimeField(BaseField): | |||||||
|         try:  # Seconds are optional, so try converting seconds first. |         try:  # Seconds are optional, so try converting seconds first. | ||||||
|             return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], |             return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M:%S')[:6], | ||||||
|                                      **kwargs) |                                      **kwargs) | ||||||
|  |  | ||||||
|         except ValueError: |         except ValueError: | ||||||
|             try:  # Try without seconds. |             try:  # Try without seconds. | ||||||
|                 return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], |                 return datetime.datetime(*time.strptime(value, '%Y-%m-%d %H:%M')[:5], | ||||||
| @@ -267,6 +272,7 @@ class DateTimeField(BaseField): | |||||||
|                 except ValueError: |                 except ValueError: | ||||||
|                     return None |                     return None | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmbeddedDocumentField(BaseField): | class EmbeddedDocumentField(BaseField): | ||||||
|     """An embedded document field. Only valid values are subclasses of |     """An embedded document field. Only valid values are subclasses of | ||||||
|     :class:`~mongoengine.EmbeddedDocument`. |     :class:`~mongoengine.EmbeddedDocument`. | ||||||
| @@ -314,7 +320,7 @@ class EmbeddedDocumentField(BaseField): | |||||||
|         return self.to_mongo(value) |         return self.to_mongo(value) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ListField(BaseField): | class ListField(DereferenceBaseField): | ||||||
|     """A list field that wraps a standard field, allowing multiple instances |     """A list field that wraps a standard field, allowing multiple instances | ||||||
|     of the field to be used as a list in the database. |     of the field to be used as a list in the database. | ||||||
|     """ |     """ | ||||||
| @@ -330,42 +336,6 @@ class ListField(BaseField): | |||||||
|         kwargs.setdefault('default', lambda: []) |         kwargs.setdefault('default', lambda: []) | ||||||
|         super(ListField, self).__init__(**kwargs) |         super(ListField, self).__init__(**kwargs) | ||||||
|  |  | ||||||
|     def __get__(self, instance, owner): |  | ||||||
|         """Descriptor to automatically dereference references. |  | ||||||
|         """ |  | ||||||
|         if instance is None: |  | ||||||
|             # Document class being used rather than a document object |  | ||||||
|             return self |  | ||||||
|  |  | ||||||
|         if isinstance(self.field, ReferenceField): |  | ||||||
|             referenced_type = self.field.document_type |  | ||||||
|             # Get value from document instance if available |  | ||||||
|             value_list = instance._data.get(self.name) |  | ||||||
|             if value_list: |  | ||||||
|                 deref_list = [] |  | ||||||
|                 for value in value_list: |  | ||||||
|                     # Dereference DBRefs |  | ||||||
|                     if isinstance(value, (pymongo.dbref.DBRef)): |  | ||||||
|                         value = _get_db().dereference(value) |  | ||||||
|                         deref_list.append(referenced_type._from_son(value)) |  | ||||||
|                     else: |  | ||||||
|                         deref_list.append(value) |  | ||||||
|                 instance._data[self.name] = deref_list |  | ||||||
|  |  | ||||||
|         if isinstance(self.field, GenericReferenceField): |  | ||||||
|             value_list = instance._data.get(self.name) |  | ||||||
|             if value_list: |  | ||||||
|                 deref_list = [] |  | ||||||
|                 for value in value_list: |  | ||||||
|                     # Dereference DBRefs |  | ||||||
|                     if isinstance(value, (dict, pymongo.son.SON)): |  | ||||||
|                         deref_list.append(self.field.dereference(value)) |  | ||||||
|                     else: |  | ||||||
|                         deref_list.append(value) |  | ||||||
|                 instance._data[self.name] = deref_list |  | ||||||
|  |  | ||||||
|         return super(ListField, self).__get__(instance, owner) |  | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         return [self.field.to_python(item) for item in value] |         return [self.field.to_python(item) for item in value] | ||||||
|  |  | ||||||
| @@ -462,7 +432,7 @@ class DictField(BaseField): | |||||||
|         return super(DictField, self).prepare_query_value(op, value) |         return super(DictField, self).prepare_query_value(op, value) | ||||||
|  |  | ||||||
|  |  | ||||||
| class MapField(BaseField): | class MapField(DereferenceBaseField): | ||||||
|     """A field that maps a name to a specified field type. Similar to |     """A field that maps a name to a specified field type. Similar to | ||||||
|     a DictField, except the 'value' of each item must match the specified |     a DictField, except the 'value' of each item must match the specified | ||||||
|     field type. |     field type. | ||||||
| @@ -494,42 +464,6 @@ class MapField(BaseField): | |||||||
|         except Exception, err: |         except Exception, err: | ||||||
|             raise ValidationError('Invalid MapField item (%s)' % str(item)) |             raise ValidationError('Invalid MapField item (%s)' % str(item)) | ||||||
|  |  | ||||||
|     def __get__(self, instance, owner): |  | ||||||
|         """Descriptor to automatically dereference references. |  | ||||||
|         """ |  | ||||||
|         if instance is None: |  | ||||||
|             # Document class being used rather than a document object |  | ||||||
|             return self |  | ||||||
|  |  | ||||||
|         if isinstance(self.field, ReferenceField): |  | ||||||
|             referenced_type = self.field.document_type |  | ||||||
|             # Get value from document instance if available |  | ||||||
|             value_dict = instance._data.get(self.name) |  | ||||||
|             if value_dict: |  | ||||||
|                 deref_dict = [] |  | ||||||
|                 for key,value in value_dict.iteritems(): |  | ||||||
|                     # Dereference DBRefs |  | ||||||
|                     if isinstance(value, (pymongo.dbref.DBRef)): |  | ||||||
|                         value = _get_db().dereference(value) |  | ||||||
|                         deref_dict[key] = referenced_type._from_son(value) |  | ||||||
|                     else: |  | ||||||
|                         deref_dict[key] = value |  | ||||||
|                 instance._data[self.name] = deref_dict |  | ||||||
|  |  | ||||||
|         if isinstance(self.field, GenericReferenceField): |  | ||||||
|             value_dict = instance._data.get(self.name) |  | ||||||
|             if value_dict: |  | ||||||
|                 deref_dict = [] |  | ||||||
|                 for key,value in value_dict.iteritems(): |  | ||||||
|                     # Dereference DBRefs |  | ||||||
|                     if isinstance(value, (dict, pymongo.son.SON)): |  | ||||||
|                         deref_dict[key] = self.field.dereference(value) |  | ||||||
|                     else: |  | ||||||
|                         deref_dict[key] = value |  | ||||||
|                 instance._data[self.name] = deref_dict |  | ||||||
|  |  | ||||||
|         return super(MapField, self).__get__(instance, owner) |  | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) |         return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) | ||||||
|  |  | ||||||
| @@ -752,11 +686,11 @@ class GridFSProxy(object): | |||||||
|         self.newfile = self.fs.new_file(**kwargs) |         self.newfile = self.fs.new_file(**kwargs) | ||||||
|         self.grid_id = self.newfile._id |         self.grid_id = self.newfile._id | ||||||
|  |  | ||||||
|     def put(self, file, **kwargs): |     def put(self, file_obj, **kwargs): | ||||||
|         if self.grid_id: |         if self.grid_id: | ||||||
|             raise GridFSError('This document already has a file. Either delete ' |             raise GridFSError('This document already has a file. Either delete ' | ||||||
|                               'it or call replace to overwrite it') |                               'it or call replace to overwrite it') | ||||||
|         self.grid_id = self.fs.put(file, **kwargs) |         self.grid_id = self.fs.put(file_obj, **kwargs) | ||||||
|  |  | ||||||
|     def write(self, string): |     def write(self, string): | ||||||
|         if self.grid_id: |         if self.grid_id: | ||||||
| @@ -785,9 +719,9 @@ class GridFSProxy(object): | |||||||
|         self.grid_id = None |         self.grid_id = None | ||||||
|         self.gridout = None |         self.gridout = None | ||||||
|  |  | ||||||
|     def replace(self, file, **kwargs): |     def replace(self, file_obj, **kwargs): | ||||||
|         self.delete() |         self.delete() | ||||||
|         self.put(file, **kwargs) |         self.put(file_obj, **kwargs) | ||||||
|  |  | ||||||
|     def close(self): |     def close(self): | ||||||
|         if self.newfile: |         if self.newfile: | ||||||
|   | |||||||
| @@ -336,6 +336,7 @@ class QuerySet(object): | |||||||
|         self._snapshot = False |         self._snapshot = False | ||||||
|         self._timeout = True |         self._timeout = True | ||||||
|         self._class_check = True |         self._class_check = True | ||||||
|  |         self._slave_okay = False | ||||||
|  |  | ||||||
|         # If inheritance is allowed, only return instances and instances of |         # If inheritance is allowed, only return instances and instances of | ||||||
|         # subclasses of the class being used |         # subclasses of the class being used | ||||||
| @@ -352,7 +353,7 @@ class QuerySet(object): | |||||||
|  |  | ||||||
|         copy_props = ('_initial_query', '_query_obj', '_where_clause', |         copy_props = ('_initial_query', '_query_obj', '_where_clause', | ||||||
|                     '_loaded_fields', '_ordering', '_snapshot', |                     '_loaded_fields', '_ordering', '_snapshot', | ||||||
|                     '_timeout', '_limit', '_skip') |                     '_timeout', '_limit', '_skip', '_slave_okay') | ||||||
|  |  | ||||||
|         for prop in copy_props: |         for prop in copy_props: | ||||||
|             val = getattr(self, prop) |             val = getattr(self, prop) | ||||||
| @@ -376,21 +377,27 @@ class QuerySet(object): | |||||||
|             construct a multi-field index); keys may be prefixed with a **+** |             construct a multi-field index); keys may be prefixed with a **+** | ||||||
|             or a **-** to determine the index ordering |             or a **-** to determine the index ordering | ||||||
|         """ |         """ | ||||||
|         index_list = QuerySet._build_index_spec(self._document, key_or_list) |         index_spec = QuerySet._build_index_spec(self._document, key_or_list) | ||||||
|         self._collection.ensure_index(index_list, drop_dups=drop_dups, |         self._collection.ensure_index( | ||||||
|             background=background) |             index_spec['fields'], | ||||||
|  |             drop_dups=drop_dups, | ||||||
|  |             background=background, | ||||||
|  |             sparse=index_spec.get('sparse', False), | ||||||
|  |             unique=index_spec.get('unique', False)) | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _build_index_spec(cls, doc_cls, key_or_list): |     def _build_index_spec(cls, doc_cls, spec): | ||||||
|         """Build a PyMongo index spec from a MongoEngine index spec. |         """Build a PyMongo index spec from a MongoEngine index spec. | ||||||
|         """ |         """ | ||||||
|         if isinstance(key_or_list, basestring): |         if isinstance(spec, basestring): | ||||||
|             key_or_list = [key_or_list] |             spec = {'fields': [spec]} | ||||||
|  |         if isinstance(spec, (list, tuple)): | ||||||
|  |             spec = {'fields': spec} | ||||||
|  |  | ||||||
|         index_list = [] |         index_list = [] | ||||||
|         use_types = doc_cls._meta.get('allow_inheritance', True) |         use_types = doc_cls._meta.get('allow_inheritance', True) | ||||||
|         for key in key_or_list: |         for key in spec['fields']: | ||||||
|             # Get direction from + or - |             # Get direction from + or - | ||||||
|             direction = pymongo.ASCENDING |             direction = pymongo.ASCENDING | ||||||
|             if key.startswith("-"): |             if key.startswith("-"): | ||||||
| @@ -411,12 +418,20 @@ class QuerySet(object): | |||||||
|                 use_types = False |                 use_types = False | ||||||
|  |  | ||||||
|         # If _types is being used, prepend it to every specified index |         # If _types is being used, prepend it to every specified index | ||||||
|         if doc_cls._meta.get('allow_inheritance') and use_types: |         if (spec.get('types', True) and doc_cls._meta.get('allow_inheritance') | ||||||
|  |                 and use_types): | ||||||
|             index_list.insert(0, ('_types', 1)) |             index_list.insert(0, ('_types', 1)) | ||||||
|  |  | ||||||
|         return index_list |         spec['fields'] = index_list | ||||||
|  |  | ||||||
|     def __call__(self, q_obj=None, class_check=True, **query): |         if spec.get('sparse', False) and len(spec['fields']) > 1: | ||||||
|  |             raise ValueError( | ||||||
|  |                 'Sparse indexes can only have one field in them. ' | ||||||
|  |                 'See https://jira.mongodb.org/browse/SERVER-2193') | ||||||
|  |  | ||||||
|  |         return spec | ||||||
|  |  | ||||||
|  |     def __call__(self, q_obj=None, class_check=True, slave_okay=False, **query): | ||||||
|         """Filter the selected documents by calling the |         """Filter the selected documents by calling the | ||||||
|         :class:`~mongoengine.queryset.QuerySet` with a query. |         :class:`~mongoengine.queryset.QuerySet` with a query. | ||||||
|  |  | ||||||
| @@ -426,6 +441,8 @@ class QuerySet(object): | |||||||
|             objects, only the last one will be used |             objects, only the last one will be used | ||||||
|         :param class_check: If set to False bypass class name check when |         :param class_check: If set to False bypass class name check when | ||||||
|             querying collection |             querying collection | ||||||
|  |         :param slave_okay: if True, allows this query to be run against a | ||||||
|  |             replica secondary. | ||||||
|         :param query: Django-style query keyword arguments |         :param query: Django-style query keyword arguments | ||||||
|         """ |         """ | ||||||
|         query = Q(**query) |         query = Q(**query) | ||||||
| @@ -465,9 +482,12 @@ class QuerySet(object): | |||||||
|  |  | ||||||
|             # Ensure document-defined indexes are created |             # Ensure document-defined indexes are created | ||||||
|             if self._document._meta['indexes']: |             if self._document._meta['indexes']: | ||||||
|                 for key_or_list in self._document._meta['indexes']: |                 for spec in self._document._meta['indexes']: | ||||||
|                     self._collection.ensure_index(key_or_list, |                     opts = index_opts.copy() | ||||||
|                         background=background, **index_opts) |                     opts['unique'] = spec.get('unique', False) | ||||||
|  |                     opts['sparse'] = spec.get('sparse', False) | ||||||
|  |                     self._collection.ensure_index(spec['fields'], | ||||||
|  |                         background=background, **opts) | ||||||
|  |  | ||||||
|             # If _types is being used (for polymorphism), it needs an index |             # If _types is being used (for polymorphism), it needs an index | ||||||
|             if '_types' in self._query: |             if '_types' in self._query: | ||||||
| @@ -484,16 +504,22 @@ class QuerySet(object): | |||||||
|         return self._collection_obj |         return self._collection_obj | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def _cursor(self): |     def _cursor_args(self): | ||||||
|         if self._cursor_obj is None: |  | ||||||
|         cursor_args = { |         cursor_args = { | ||||||
|             'snapshot': self._snapshot, |             'snapshot': self._snapshot, | ||||||
|             'timeout': self._timeout, |             'timeout': self._timeout, | ||||||
|  |             'slave_okay': self._slave_okay | ||||||
|         } |         } | ||||||
|         if self._loaded_fields: |         if self._loaded_fields: | ||||||
|             cursor_args['fields'] = self._loaded_fields.as_dict() |             cursor_args['fields'] = self._loaded_fields.as_dict() | ||||||
|  |         return cursor_args | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def _cursor(self): | ||||||
|  |         if self._cursor_obj is None: | ||||||
|  |  | ||||||
|             self._cursor_obj = self._collection.find(self._query, |             self._cursor_obj = self._collection.find(self._query, | ||||||
|                                                      **cursor_args) |                                                      **self._cursor_args) | ||||||
|             # Apply where clauses to cursor |             # Apply where clauses to cursor | ||||||
|             if self._where_clause: |             if self._where_clause: | ||||||
|                 self._cursor_obj.where(self._where_clause) |                 self._cursor_obj.where(self._where_clause) | ||||||
| @@ -702,6 +728,46 @@ class QuerySet(object): | |||||||
|             result = None |             result = None | ||||||
|         return result |         return result | ||||||
|  |  | ||||||
|  |     def insert(self, doc_or_docs, load_bulk=True): | ||||||
|  |         """bulk insert documents | ||||||
|  |  | ||||||
|  |         :param docs_or_doc: a document or list of documents to be inserted | ||||||
|  |         :param load_bulk (optional): If True returns the list of document instances | ||||||
|  |  | ||||||
|  |         By default returns document instances, set ``load_bulk`` to False to | ||||||
|  |         return just ``ObjectIds`` | ||||||
|  |  | ||||||
|  |         .. versionadded:: 0.5 | ||||||
|  |         """ | ||||||
|  |         from document import Document | ||||||
|  |  | ||||||
|  |         docs = doc_or_docs | ||||||
|  |         return_one = False | ||||||
|  |         if isinstance(docs, Document) or issubclass(docs.__class__, Document): | ||||||
|  |             return_one = True | ||||||
|  |             docs = [docs] | ||||||
|  |  | ||||||
|  |         raw = [] | ||||||
|  |         for doc in docs: | ||||||
|  |             if not isinstance(doc, self._document): | ||||||
|  |                 msg = "Some documents inserted aren't instances of %s" % str(self._document) | ||||||
|  |                 raise OperationError(msg) | ||||||
|  |             if doc.pk: | ||||||
|  |                 msg = "Some documents have ObjectIds use doc.update() instead" | ||||||
|  |                 raise OperationError(msg) | ||||||
|  |             raw.append(doc.to_mongo()) | ||||||
|  |  | ||||||
|  |         ids = self._collection.insert(raw) | ||||||
|  |  | ||||||
|  |         if not load_bulk: | ||||||
|  |             return return_one and ids[0] or ids | ||||||
|  |  | ||||||
|  |         documents = self.in_bulk(ids) | ||||||
|  |         results = [] | ||||||
|  |         for obj_id in ids: | ||||||
|  |             results.append(documents.get(obj_id)) | ||||||
|  |         return return_one and results[0] or results | ||||||
|  |  | ||||||
|     def with_id(self, object_id): |     def with_id(self, object_id): | ||||||
|         """Retrieve the object matching the id provided. |         """Retrieve the object matching the id provided. | ||||||
|  |  | ||||||
| @@ -710,7 +776,7 @@ class QuerySet(object): | |||||||
|         id_field = self._document._meta['id_field'] |         id_field = self._document._meta['id_field'] | ||||||
|         object_id = self._document._fields[id_field].to_mongo(object_id) |         object_id = self._document._fields[id_field].to_mongo(object_id) | ||||||
|  |  | ||||||
|         result = self._collection.find_one({'_id': object_id}) |         result = self._collection.find_one({'_id': object_id}, **self._cursor_args) | ||||||
|         if result is not None: |         if result is not None: | ||||||
|             result = self._document._from_son(result) |             result = self._document._from_son(result) | ||||||
|         return result |         return result | ||||||
| @@ -726,7 +792,8 @@ class QuerySet(object): | |||||||
|         """ |         """ | ||||||
|         doc_map = {} |         doc_map = {} | ||||||
|  |  | ||||||
|         docs = self._collection.find({'_id': {'$in': object_ids}}) |         docs = self._collection.find({'_id': {'$in': object_ids}}, | ||||||
|  |                                      **self._cursor_args) | ||||||
|         for doc in docs: |         for doc in docs: | ||||||
|             doc_map[doc['_id']] = self._document._from_son(doc) |             doc_map[doc['_id']] = self._document._from_son(doc) | ||||||
|  |  | ||||||
| @@ -1023,6 +1090,7 @@ class QuerySet(object): | |||||||
|         :param enabled: whether or not snapshot mode is enabled |         :param enabled: whether or not snapshot mode is enabled | ||||||
|         """ |         """ | ||||||
|         self._snapshot = enabled |         self._snapshot = enabled | ||||||
|  |         return self | ||||||
|  |  | ||||||
|     def timeout(self, enabled): |     def timeout(self, enabled): | ||||||
|         """Enable or disable the default mongod timeout when querying. |         """Enable or disable the default mongod timeout when querying. | ||||||
| @@ -1030,6 +1098,15 @@ class QuerySet(object): | |||||||
|         :param enabled: whether or not the timeout is used |         :param enabled: whether or not the timeout is used | ||||||
|         """ |         """ | ||||||
|         self._timeout = enabled |         self._timeout = enabled | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |     def slave_okay(self, enabled): | ||||||
|  |         """Enable or disable the slave_okay when querying. | ||||||
|  |  | ||||||
|  |         :param enabled: whether or not the slave_okay is enabled | ||||||
|  |         """ | ||||||
|  |         self._slave_okay = enabled | ||||||
|  |         return self | ||||||
|  |  | ||||||
|     def delete(self, safe=False): |     def delete(self, safe=False): | ||||||
|         """Delete the documents matched by the query. |         """Delete the documents matched by the query. | ||||||
|   | |||||||
							
								
								
									
										44
									
								
								mongoengine/signals.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								mongoengine/signals.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  |  | ||||||
|  | __all__ = ['pre_init', 'post_init', 'pre_save', 'post_save', | ||||||
|  |            'pre_delete', 'post_delete'] | ||||||
|  |  | ||||||
|  | signals_available = False | ||||||
|  | try: | ||||||
|  |     from blinker import Namespace | ||||||
|  |     signals_available = True | ||||||
|  | except ImportError: | ||||||
|  |     class Namespace(object): | ||||||
|  |         def signal(self, name, doc=None): | ||||||
|  |             return _FakeSignal(name, doc) | ||||||
|  |  | ||||||
|  |     class _FakeSignal(object): | ||||||
|  |         """If blinker is unavailable, create a fake class with the same | ||||||
|  |         interface that allows sending of signals but will fail with an | ||||||
|  |         error on anything else.  Instead of doing anything on send, it | ||||||
|  |         will just ignore the arguments and do nothing instead. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         def __init__(self, name, doc=None): | ||||||
|  |             self.name = name | ||||||
|  |             self.__doc__ = doc | ||||||
|  |  | ||||||
|  |         def _fail(self, *args, **kwargs): | ||||||
|  |             raise RuntimeError('signalling support is unavailable ' | ||||||
|  |                                'because the blinker library is ' | ||||||
|  |                                'not installed.') | ||||||
|  |         send = lambda *a, **kw: None | ||||||
|  |         connect = disconnect = has_receivers_for = receivers_for = \ | ||||||
|  |             temporarily_connected_to = _fail | ||||||
|  |         del _fail | ||||||
|  |  | ||||||
|  | # the namespace for code signals.  If you are not mongoengine code, do | ||||||
|  | # not put signals in here.  Create your own namespace instead. | ||||||
|  | _signals = Namespace() | ||||||
|  |  | ||||||
|  | pre_init = _signals.signal('pre_init') | ||||||
|  | post_init = _signals.signal('post_init') | ||||||
|  | pre_save = _signals.signal('pre_save') | ||||||
|  | post_save = _signals.signal('post_save') | ||||||
|  | pre_delete = _signals.signal('pre_delete') | ||||||
|  | post_delete = _signals.signal('post_delete') | ||||||
							
								
								
									
										59
									
								
								mongoengine/tests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								mongoengine/tests.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | |||||||
|  | from mongoengine.connection import _get_db | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class query_counter(object): | ||||||
|  |     """ Query_counter contextmanager to get the number of queries. """ | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         """ Construct the query_counter. """ | ||||||
|  |         self.counter = 0 | ||||||
|  |         self.db = _get_db() | ||||||
|  |  | ||||||
|  |     def __enter__(self): | ||||||
|  |         """ On every with block we need to drop the profile collection. """ | ||||||
|  |         self.db.set_profiling_level(0) | ||||||
|  |         self.db.system.profile.drop() | ||||||
|  |         self.db.set_profiling_level(2) | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |     def __exit__(self, t, value, traceback): | ||||||
|  |         """ Reset the profiling level. """ | ||||||
|  |         self.db.set_profiling_level(0) | ||||||
|  |  | ||||||
|  |     def __eq__(self, value): | ||||||
|  |         """ == Compare querycounter. """ | ||||||
|  |         return value == self._get_count() | ||||||
|  |  | ||||||
|  |     def __ne__(self, value): | ||||||
|  |         """ != Compare querycounter. """ | ||||||
|  |         return not self.__eq__(value) | ||||||
|  |  | ||||||
|  |     def __lt__(self, value): | ||||||
|  |         """ < Compare querycounter. """ | ||||||
|  |         return self._get_count() < value | ||||||
|  |  | ||||||
|  |     def __le__(self, value): | ||||||
|  |         """ <= Compare querycounter. """ | ||||||
|  |         return self._get_count() <= value | ||||||
|  |  | ||||||
|  |     def __gt__(self, value): | ||||||
|  |         """ > Compare querycounter. """ | ||||||
|  |         return self._get_count() > value | ||||||
|  |  | ||||||
|  |     def __ge__(self, value): | ||||||
|  |         """ >= Compare querycounter. """ | ||||||
|  |         return self._get_count() >= value | ||||||
|  |  | ||||||
|  |     def __int__(self): | ||||||
|  |         """ int representation. """ | ||||||
|  |         return self._get_count() | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         """ repr query_counter as the number of queries. """ | ||||||
|  |         return u"%s" % self._get_count() | ||||||
|  |  | ||||||
|  |     def _get_count(self): | ||||||
|  |         """ Get the number of queries. """ | ||||||
|  |         count = self.db.system.profile.find().count() - self.counter | ||||||
|  |         self.counter += 1 | ||||||
|  |         return count | ||||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							| @@ -45,6 +45,6 @@ setup(name='mongoengine', | |||||||
|       long_description=LONG_DESCRIPTION, |       long_description=LONG_DESCRIPTION, | ||||||
|       platforms=['any'], |       platforms=['any'], | ||||||
|       classifiers=CLASSIFIERS, |       classifiers=CLASSIFIERS, | ||||||
|       install_requires=['pymongo'], |       install_requires=['pymongo', 'blinker'], | ||||||
|       test_suite='tests', |       test_suite='tests', | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										288
									
								
								tests/dereference.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										288
									
								
								tests/dereference.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,288 @@ | |||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from mongoengine import * | ||||||
|  | from mongoengine.connection import _get_db | ||||||
|  | from mongoengine.tests import query_counter | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FieldTest(unittest.TestCase): | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         connect(db='mongoenginetest') | ||||||
|  |         self.db = _get_db() | ||||||
|  |  | ||||||
|  |     def test_list_item_dereference(self): | ||||||
|  |         """Ensure that DBRef items in ListFields are dereferenced. | ||||||
|  |         """ | ||||||
|  |         class User(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Group(Document): | ||||||
|  |             members = ListField(ReferenceField(User)) | ||||||
|  |  | ||||||
|  |         User.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |         for i in xrange(1, 51): | ||||||
|  |             user = User(name='user %s' % i) | ||||||
|  |             user.save() | ||||||
|  |  | ||||||
|  |         group = Group(members=User.objects) | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             group_obj = Group.objects.first() | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 2) | ||||||
|  |  | ||||||
|  |         User.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_recursive_reference(self): | ||||||
|  |         """Ensure that ReferenceFields can reference their own documents. | ||||||
|  |         """ | ||||||
|  |         class Employee(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             boss = ReferenceField('self') | ||||||
|  |             friends = ListField(ReferenceField('self')) | ||||||
|  |  | ||||||
|  |         bill = Employee(name='Bill Lumbergh') | ||||||
|  |         bill.save() | ||||||
|  |  | ||||||
|  |         michael = Employee(name='Michael Bolton') | ||||||
|  |         michael.save() | ||||||
|  |  | ||||||
|  |         samir = Employee(name='Samir Nagheenanajar') | ||||||
|  |         samir.save() | ||||||
|  |  | ||||||
|  |         friends = [michael, samir] | ||||||
|  |         peter = Employee(name='Peter Gibbons', boss=bill, friends=friends) | ||||||
|  |         peter.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             peter = Employee.objects.with_id(peter.id) | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             peter.boss | ||||||
|  |             self.assertEqual(q, 2) | ||||||
|  |  | ||||||
|  |             peter.friends | ||||||
|  |             self.assertEqual(q, 3) | ||||||
|  |  | ||||||
|  |     def test_generic_reference(self): | ||||||
|  |  | ||||||
|  |         class UserA(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class UserB(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class UserC(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Group(Document): | ||||||
|  |             members = ListField(GenericReferenceField()) | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         UserB.drop_collection() | ||||||
|  |         UserC.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |         members = [] | ||||||
|  |         for i in xrange(1, 51): | ||||||
|  |             a = UserA(name='User A %s' % i) | ||||||
|  |             a.save() | ||||||
|  |  | ||||||
|  |             b = UserB(name='User B %s' % i) | ||||||
|  |             b.save() | ||||||
|  |  | ||||||
|  |             c = UserC(name='User C %s' % i) | ||||||
|  |             c.save() | ||||||
|  |  | ||||||
|  |             members += [a, b, c] | ||||||
|  |  | ||||||
|  |         group = Group(members=members) | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             group_obj = Group.objects.first() | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         UserB.drop_collection() | ||||||
|  |         UserC.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_map_field_reference(self): | ||||||
|  |  | ||||||
|  |         class User(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Group(Document): | ||||||
|  |             members = MapField(ReferenceField(User)) | ||||||
|  |  | ||||||
|  |         User.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |         members = [] | ||||||
|  |         for i in xrange(1, 51): | ||||||
|  |             user = User(name='user %s' % i) | ||||||
|  |             user.save() | ||||||
|  |             members.append(user) | ||||||
|  |  | ||||||
|  |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             group_obj = Group.objects.first() | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 2) | ||||||
|  |  | ||||||
|  |         User.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |     def ztest_generic_reference_dict_field(self): | ||||||
|  |  | ||||||
|  |         class UserA(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class UserB(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class UserC(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Group(Document): | ||||||
|  |             members = DictField() | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         UserB.drop_collection() | ||||||
|  |         UserC.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |         members = [] | ||||||
|  |         for i in xrange(1, 51): | ||||||
|  |             a = UserA(name='User A %s' % i) | ||||||
|  |             a.save() | ||||||
|  |  | ||||||
|  |             b = UserB(name='User B %s' % i) | ||||||
|  |             b.save() | ||||||
|  |  | ||||||
|  |             c = UserC(name='User C %s' % i) | ||||||
|  |             c.save() | ||||||
|  |  | ||||||
|  |             members += [a, b, c] | ||||||
|  |  | ||||||
|  |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             group_obj = Group.objects.first() | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |         group.members = {} | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             group_obj = Group.objects.first() | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         UserB.drop_collection() | ||||||
|  |         UserC.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |     def test_generic_reference_map_field(self): | ||||||
|  |  | ||||||
|  |         class UserA(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class UserB(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class UserC(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Group(Document): | ||||||
|  |             members = MapField(GenericReferenceField()) | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         UserB.drop_collection() | ||||||
|  |         UserC.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
|  |  | ||||||
|  |         members = [] | ||||||
|  |         for i in xrange(1, 51): | ||||||
|  |             a = UserA(name='User A %s' % i) | ||||||
|  |             a.save() | ||||||
|  |  | ||||||
|  |             b = UserB(name='User B %s' % i) | ||||||
|  |             b.save() | ||||||
|  |  | ||||||
|  |             c = UserC(name='User C %s' % i) | ||||||
|  |             c.save() | ||||||
|  |  | ||||||
|  |             members += [a, b, c] | ||||||
|  |  | ||||||
|  |         group = Group(members=dict([(str(u.id), u) for u in members])) | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             group_obj = Group.objects.first() | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 4) | ||||||
|  |  | ||||||
|  |         group.members = {} | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             group_obj = Group.objects.first() | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |             [m for m in group_obj.members] | ||||||
|  |             self.assertEqual(q, 1) | ||||||
|  |  | ||||||
|  |         UserA.drop_collection() | ||||||
|  |         UserB.drop_collection() | ||||||
|  |         UserC.drop_collection() | ||||||
|  |         Group.drop_collection() | ||||||
| @@ -377,6 +377,40 @@ class DocumentTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         BlogPost.drop_collection() |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def test_dictionary_indexes(self): | ||||||
|  |         """Ensure that indexes are used when meta[indexes] contains dictionaries | ||||||
|  |         instead of lists. | ||||||
|  |         """ | ||||||
|  |         class BlogPost(Document): | ||||||
|  |             date = DateTimeField(db_field='addDate', default=datetime.now) | ||||||
|  |             category = StringField() | ||||||
|  |             tags = ListField(StringField()) | ||||||
|  |             meta = { | ||||||
|  |                 'indexes': [ | ||||||
|  |                     { 'fields': ['-date'], 'unique': True, | ||||||
|  |                       'sparse': True, 'types': False }, | ||||||
|  |                 ], | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |         info = BlogPost.objects._collection.index_information() | ||||||
|  |         # _id, '-date' | ||||||
|  |         self.assertEqual(len(info), 3) | ||||||
|  |  | ||||||
|  |         # Indexes are lazy so use list() to perform query | ||||||
|  |         list(BlogPost.objects) | ||||||
|  |         info = BlogPost.objects._collection.index_information() | ||||||
|  |         info = [(value['key'], | ||||||
|  |                  value.get('unique', False), | ||||||
|  |                  value.get('sparse', False)) | ||||||
|  |                 for key, value in info.iteritems()] | ||||||
|  |         self.assertTrue(([('addDate', -1)], True, True) in info) | ||||||
|  |  | ||||||
|  |         BlogPost.drop_collection() | ||||||
|  |  | ||||||
|  |  | ||||||
|     def test_unique(self): |     def test_unique(self): | ||||||
|         """Ensure that uniqueness constraints are applied to fields. |         """Ensure that uniqueness constraints are applied to fields. | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -187,6 +187,66 @@ class FieldTest(unittest.TestCase): | |||||||
|         log.time = '1pm' |         log.time = '1pm' | ||||||
|         self.assertRaises(ValidationError, log.validate) |         self.assertRaises(ValidationError, log.validate) | ||||||
|  |  | ||||||
|  |     def test_datetime(self): | ||||||
|  |         """Tests showing pymongo datetime fields handling of microseconds. | ||||||
|  |         Microseconds are rounded to the nearest millisecond and pre UTC | ||||||
|  |         handling is wonky. | ||||||
|  |  | ||||||
|  |         See: http://api.mongodb.org/python/current/api/bson/son.html#dt | ||||||
|  |         """ | ||||||
|  |         class LogEntry(Document): | ||||||
|  |             date = DateTimeField() | ||||||
|  |  | ||||||
|  |         LogEntry.drop_collection() | ||||||
|  |  | ||||||
|  |         # Post UTC - microseconds are rounded (down) nearest millisecond and dropped | ||||||
|  |         d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) | ||||||
|  |         d2 = datetime.datetime(1970, 01, 01, 00, 00, 01) | ||||||
|  |         log = LogEntry() | ||||||
|  |         log.date = d1 | ||||||
|  |         log.save() | ||||||
|  |         log.reload() | ||||||
|  |         self.assertNotEquals(log.date, d1) | ||||||
|  |         self.assertEquals(log.date, d2) | ||||||
|  |  | ||||||
|  |         # Post UTC - microseconds are rounded (down) nearest millisecond | ||||||
|  |         d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999) | ||||||
|  |         d2 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9000) | ||||||
|  |         log.date = d1 | ||||||
|  |         log.save() | ||||||
|  |         log.reload() | ||||||
|  |         self.assertNotEquals(log.date, d1) | ||||||
|  |         self.assertEquals(log.date, d2) | ||||||
|  |  | ||||||
|  |         # Pre UTC dates microseconds below 1000 are dropped | ||||||
|  |         d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) | ||||||
|  |         d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) | ||||||
|  |         log.date = d1 | ||||||
|  |         log.save() | ||||||
|  |         log.reload() | ||||||
|  |         self.assertNotEquals(log.date, d1) | ||||||
|  |         self.assertEquals(log.date, d2) | ||||||
|  |  | ||||||
|  |         # Pre UTC microseconds above 1000 is wonky. | ||||||
|  |         # log.date has an invalid microsecond value so I can't construct | ||||||
|  |         # a date to compare. | ||||||
|  |         # | ||||||
|  |         # However, the timedelta is predicable with pre UTC timestamps | ||||||
|  |         # It always adds 16 seconds and [777216-776217] microseconds | ||||||
|  |         for i in xrange(1001, 3113, 33): | ||||||
|  |             d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) | ||||||
|  |             log.date = d1 | ||||||
|  |             log.save() | ||||||
|  |             log.reload() | ||||||
|  |             self.assertNotEquals(log.date, d1) | ||||||
|  |  | ||||||
|  |             delta = log.date - d1 | ||||||
|  |             self.assertEquals(delta.seconds, 16) | ||||||
|  |             microseconds = 777216 - (i % 1000) | ||||||
|  |             self.assertEquals(delta.microseconds, microseconds) | ||||||
|  |  | ||||||
|  |         LogEntry.drop_collection() | ||||||
|  |  | ||||||
|     def test_list_validation(self): |     def test_list_validation(self): | ||||||
|         """Ensure that a list field only accepts lists with valid elements. |         """Ensure that a list field only accepts lists with valid elements. | ||||||
|         """ |         """ | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ from mongoengine.queryset import (QuerySet, QuerySetManager, | |||||||
|                                   MultipleObjectsReturned, DoesNotExist, |                                   MultipleObjectsReturned, DoesNotExist, | ||||||
|                                   QueryFieldList) |                                   QueryFieldList) | ||||||
| from mongoengine import * | from mongoengine import * | ||||||
|  | from mongoengine.tests import query_counter | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuerySetTest(unittest.TestCase): | class QuerySetTest(unittest.TestCase): | ||||||
| @@ -331,6 +332,125 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         person = self.Person.objects.get(age=50) |         person = self.Person.objects.get(age=50) | ||||||
|         self.assertEqual(person.name, "User C") |         self.assertEqual(person.name, "User C") | ||||||
|  |  | ||||||
|  |     def test_bulk_insert(self): | ||||||
|  |         """Ensure that query by array position works. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         class Comment(EmbeddedDocument): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         class Post(EmbeddedDocument): | ||||||
|  |             comments = ListField(EmbeddedDocumentField(Comment)) | ||||||
|  |  | ||||||
|  |         class Blog(Document): | ||||||
|  |             title = StringField() | ||||||
|  |             tags = ListField(StringField()) | ||||||
|  |             posts = ListField(EmbeddedDocumentField(Post)) | ||||||
|  |  | ||||||
|  |         Blog.drop_collection() | ||||||
|  |  | ||||||
|  |         with query_counter() as q: | ||||||
|  |             self.assertEqual(q, 0) | ||||||
|  |  | ||||||
|  |             comment1 = Comment(name='testa') | ||||||
|  |             comment2 = Comment(name='testb') | ||||||
|  |             post1 = Post(comments=[comment1, comment2]) | ||||||
|  |             post2 = Post(comments=[comment2, comment2]) | ||||||
|  |  | ||||||
|  |             blogs = [] | ||||||
|  |             for i in xrange(1, 100): | ||||||
|  |                 blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) | ||||||
|  |  | ||||||
|  |             Blog.objects.insert(blogs, load_bulk=False) | ||||||
|  |             self.assertEqual(q, 2) # 1 for the inital connection and 1 for the insert | ||||||
|  |  | ||||||
|  |             Blog.objects.insert(blogs) | ||||||
|  |             self.assertEqual(q, 4) # 1 for insert, and 1 for in bulk | ||||||
|  |  | ||||||
|  |         Blog.drop_collection() | ||||||
|  |  | ||||||
|  |         comment1 = Comment(name='testa') | ||||||
|  |         comment2 = Comment(name='testb') | ||||||
|  |         post1 = Post(comments=[comment1, comment2]) | ||||||
|  |         post2 = Post(comments=[comment2, comment2]) | ||||||
|  |         blog1 = Blog(title="code", posts=[post1, post2]) | ||||||
|  |         blog2 = Blog(title="mongodb", posts=[post2, post1]) | ||||||
|  |         blog1, blog2 = Blog.objects.insert([blog1, blog2]) | ||||||
|  |         self.assertEqual(blog1.title, "code") | ||||||
|  |         self.assertEqual(blog2.title, "mongodb") | ||||||
|  |  | ||||||
|  |         self.assertEqual(Blog.objects.count(), 2) | ||||||
|  |  | ||||||
|  |         # test handles people trying to upsert | ||||||
|  |         def throw_operation_error(): | ||||||
|  |             blogs = Blog.objects | ||||||
|  |             Blog.objects.insert(blogs) | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, throw_operation_error) | ||||||
|  |  | ||||||
|  |         # test handles other classes being inserted | ||||||
|  |         def throw_operation_error_wrong_doc(): | ||||||
|  |             class Author(Document): | ||||||
|  |                 pass | ||||||
|  |             Blog.objects.insert(Author()) | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, throw_operation_error_wrong_doc) | ||||||
|  |  | ||||||
|  |         def throw_operation_error_not_a_document(): | ||||||
|  |             Blog.objects.insert("HELLO WORLD") | ||||||
|  |  | ||||||
|  |         self.assertRaises(OperationError, throw_operation_error_not_a_document) | ||||||
|  |  | ||||||
|  |         Blog.drop_collection() | ||||||
|  |  | ||||||
|  |         blog1 = Blog(title="code", posts=[post1, post2]) | ||||||
|  |         blog1 = Blog.objects.insert(blog1) | ||||||
|  |         self.assertEqual(blog1.title, "code") | ||||||
|  |         self.assertEqual(Blog.objects.count(), 1) | ||||||
|  |  | ||||||
|  |         Blog.drop_collection() | ||||||
|  |         blog1 = Blog(title="code", posts=[post1, post2]) | ||||||
|  |         obj_id = Blog.objects.insert(blog1, load_bulk=False) | ||||||
|  |         self.assertEquals(obj_id.__class__.__name__, 'ObjectId') | ||||||
|  |  | ||||||
|  |     def test_slave_okay(self): | ||||||
|  |         """Ensures that a query can take slave_okay syntax | ||||||
|  |         """ | ||||||
|  |         person1 = self.Person(name="User A", age=20) | ||||||
|  |         person1.save() | ||||||
|  |         person2 = self.Person(name="User B", age=30) | ||||||
|  |         person2.save() | ||||||
|  |  | ||||||
|  |         # Retrieve the first person from the database | ||||||
|  |         person = self.Person.objects.slave_okay(True).first() | ||||||
|  |         self.assertTrue(isinstance(person, self.Person)) | ||||||
|  |         self.assertEqual(person.name, "User A") | ||||||
|  |         self.assertEqual(person.age, 20) | ||||||
|  |  | ||||||
|  |     def test_cursor_args(self): | ||||||
|  |         """Ensures the cursor args can be set as expected | ||||||
|  |         """ | ||||||
|  |         p = self.Person.objects | ||||||
|  |         # Check default | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': False, 'slave_okay': False, 'timeout': True}) | ||||||
|  |  | ||||||
|  |         p.snapshot(False).slave_okay(False).timeout(False) | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': False, 'slave_okay': False, 'timeout': False}) | ||||||
|  |  | ||||||
|  |         p.snapshot(True).slave_okay(False).timeout(False) | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': True, 'slave_okay': False, 'timeout': False}) | ||||||
|  |  | ||||||
|  |         p.snapshot(True).slave_okay(True).timeout(False) | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': True, 'slave_okay': True, 'timeout': False}) | ||||||
|  |  | ||||||
|  |         p.snapshot(True).slave_okay(True).timeout(True) | ||||||
|  |         self.assertEqual(p._cursor_args, | ||||||
|  |                 {'snapshot': True, 'slave_okay': True, 'timeout': True}) | ||||||
|  |  | ||||||
|     def test_repeated_iteration(self): |     def test_repeated_iteration(self): | ||||||
|         """Ensure that QuerySet rewinds itself one iteration finishes. |         """Ensure that QuerySet rewinds itself one iteration finishes. | ||||||
|         """ |         """ | ||||||
| @@ -2099,8 +2219,27 @@ class QuerySetTest(unittest.TestCase): | |||||||
|         Number.drop_collection() |         Number.drop_collection() | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def test_ensure_index(self): | ||||||
|  |         """Ensure that manual creation of indexes works. | ||||||
|  |         """ | ||||||
|  |         class Comment(Document): | ||||||
|  |             message = StringField() | ||||||
|  |  | ||||||
|  |         Comment.objects.ensure_index('message') | ||||||
|  |  | ||||||
|  |         info = Comment.objects._collection.index_information() | ||||||
|  |         info = [(value['key'], | ||||||
|  |                  value.get('unique', False), | ||||||
|  |                  value.get('sparse', False)) | ||||||
|  |                 for key, value in info.iteritems()] | ||||||
|  |         self.assertTrue(([('_types', 1), ('message', 1)], False, False) in info) | ||||||
|  |  | ||||||
|  |  | ||||||
| class QTest(unittest.TestCase): | class QTest(unittest.TestCase): | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         connect(db='mongoenginetest') | ||||||
|  |  | ||||||
|     def test_empty_q(self): |     def test_empty_q(self): | ||||||
|         """Ensure that empty Q objects won't hurt. |         """Ensure that empty Q objects won't hurt. | ||||||
|         """ |         """ | ||||||
|   | |||||||
							
								
								
									
										130
									
								
								tests/signals.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								tests/signals.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,130 @@ | |||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from mongoengine import * | ||||||
|  | from mongoengine import signals | ||||||
|  |  | ||||||
|  | signal_output = [] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SignalTests(unittest.TestCase): | ||||||
|  |     """ | ||||||
|  |     Testing signals before/after saving and deleting. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def get_signal_output(self, fn, *args, **kwargs): | ||||||
|  |         # Flush any existing signal output | ||||||
|  |         global signal_output | ||||||
|  |         signal_output = [] | ||||||
|  |         fn(*args, **kwargs) | ||||||
|  |         return signal_output | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         connect(db='mongoenginetest') | ||||||
|  |         class Author(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |             def __unicode__(self): | ||||||
|  |                 return self.name | ||||||
|  |  | ||||||
|  |             @classmethod | ||||||
|  |             def pre_init(cls, instance, **kwargs): | ||||||
|  |                 signal_output.append('pre_init signal, %s' % cls.__name__) | ||||||
|  |                 signal_output.append(str(kwargs['values'])) | ||||||
|  |  | ||||||
|  |             @classmethod | ||||||
|  |             def post_init(cls, instance, **kwargs): | ||||||
|  |                 signal_output.append('post_init signal, %s' % instance) | ||||||
|  |  | ||||||
|  |             @classmethod | ||||||
|  |             def pre_save(cls, instance, **kwargs): | ||||||
|  |                 signal_output.append('pre_save signal, %s' % instance) | ||||||
|  |  | ||||||
|  |             @classmethod | ||||||
|  |             def post_save(cls, instance, **kwargs): | ||||||
|  |                 signal_output.append('post_save signal, %s' % instance) | ||||||
|  |                 if 'created' in kwargs: | ||||||
|  |                     if kwargs['created']: | ||||||
|  |                         signal_output.append('Is created') | ||||||
|  |                     else: | ||||||
|  |                         signal_output.append('Is updated') | ||||||
|  |  | ||||||
|  |             @classmethod | ||||||
|  |             def pre_delete(cls, instance, **kwargs): | ||||||
|  |                 signal_output.append('pre_delete signal, %s' % instance) | ||||||
|  |  | ||||||
|  |             @classmethod | ||||||
|  |             def post_delete(cls, instance, **kwargs): | ||||||
|  |                 signal_output.append('post_delete signal, %s' % instance) | ||||||
|  |  | ||||||
|  |         self.Author = Author | ||||||
|  |  | ||||||
|  |         # Save up the number of connected signals so that we can check at the end | ||||||
|  |         # that all the signals we register get properly unregistered | ||||||
|  |         self.pre_signals = ( | ||||||
|  |             len(signals.pre_init.receivers), | ||||||
|  |             len(signals.post_init.receivers), | ||||||
|  |             len(signals.pre_save.receivers), | ||||||
|  |             len(signals.post_save.receivers), | ||||||
|  |             len(signals.pre_delete.receivers), | ||||||
|  |             len(signals.post_delete.receivers) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         signals.pre_init.connect(Author.pre_init) | ||||||
|  |         signals.post_init.connect(Author.post_init) | ||||||
|  |         signals.pre_save.connect(Author.pre_save) | ||||||
|  |         signals.post_save.connect(Author.post_save) | ||||||
|  |         signals.pre_delete.connect(Author.pre_delete) | ||||||
|  |         signals.post_delete.connect(Author.post_delete) | ||||||
|  |  | ||||||
|  |     def tearDown(self): | ||||||
|  |         signals.pre_init.disconnect(self.Author.pre_init) | ||||||
|  |         signals.post_init.disconnect(self.Author.post_init) | ||||||
|  |         signals.post_delete.disconnect(self.Author.post_delete) | ||||||
|  |         signals.pre_delete.disconnect(self.Author.pre_delete) | ||||||
|  |         signals.post_save.disconnect(self.Author.post_save) | ||||||
|  |         signals.pre_save.disconnect(self.Author.pre_save) | ||||||
|  |  | ||||||
|  |         # Check that all our signals got disconnected properly. | ||||||
|  |         post_signals = ( | ||||||
|  |             len(signals.pre_init.receivers), | ||||||
|  |             len(signals.post_init.receivers), | ||||||
|  |             len(signals.pre_save.receivers), | ||||||
|  |             len(signals.post_save.receivers), | ||||||
|  |             len(signals.pre_delete.receivers), | ||||||
|  |             len(signals.post_delete.receivers) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.assertEqual(self.pre_signals, post_signals) | ||||||
|  |  | ||||||
|  |     def test_model_signals(self): | ||||||
|  |         """ Model saves should throw some signals. """ | ||||||
|  |  | ||||||
|  |         def create_author(): | ||||||
|  |             a1 = self.Author(name='Bill Shakespeare') | ||||||
|  |  | ||||||
|  |         self.assertEqual(self.get_signal_output(create_author), [ | ||||||
|  |             "pre_init signal, Author", | ||||||
|  |             "{'name': 'Bill Shakespeare'}", | ||||||
|  |             "post_init signal, Bill Shakespeare", | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|  |         a1 = self.Author(name='Bill Shakespeare') | ||||||
|  |         self.assertEqual(self.get_signal_output(a1.save), [ | ||||||
|  |             "pre_save signal, Bill Shakespeare", | ||||||
|  |             "post_save signal, Bill Shakespeare", | ||||||
|  |             "Is created" | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|  |         a1.reload() | ||||||
|  |         a1.name='William Shakespeare' | ||||||
|  |         self.assertEqual(self.get_signal_output(a1.save), [ | ||||||
|  |             "pre_save signal, William Shakespeare", | ||||||
|  |             "post_save signal, William Shakespeare", | ||||||
|  |             "Is updated" | ||||||
|  |         ]) | ||||||
|  |  | ||||||
|  |         self.assertEqual(self.get_signal_output(a1.delete), [ | ||||||
|  |             'pre_delete signal, William Shakespeare', | ||||||
|  |             'post_delete signal, William Shakespeare', | ||||||
|  |         ]) | ||||||
		Reference in New Issue
	
	Block a user