From 7c62fdc0b82f13bae0796b0d749ecb87002240a7 Mon Sep 17 00:00:00 2001 From: Colin Howe Date: Wed, 8 Jun 2011 12:20:58 +0100 Subject: [PATCH 01/13] Allow for types to never be auto-prepended to indices --- mongoengine/queryset.py | 9 ++++++--- tests/queryset.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 17a1b0da..303afb6a 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -410,8 +410,10 @@ class QuerySet(object): if use_types and not all(f._index_with_types for f in fields): use_types = False - # If _types is being used, prepend it to every specified index - if doc_cls._meta.get('allow_inheritance') and use_types: + # If _types is being used, create an index for it + index_types = doc_cls._meta.get('index_types', True) + allow_inheritance = doc_cls._meta.get('allow_inheritance') + if index_types and allow_inheritance and use_types: index_list.insert(0, ('_types', 1)) return index_list @@ -457,6 +459,7 @@ class QuerySet(object): background = self._document._meta.get('index_background', False) drop_dups = self._document._meta.get('index_drop_dups', False) index_opts = self._document._meta.get('index_options', {}) + index_types = self._document._meta.get('index_types', True) # Ensure indexes created by uniqueness constraints for index in self._document._meta['unique_indexes']: @@ -470,7 +473,7 @@ class QuerySet(object): background=background, **index_opts) # If _types is being used (for polymorphism), it needs an index - if '_types' in self._query: + if index_types and '_types' in self._query: self._collection.ensure_index('_types', background=background, **index_opts) diff --git a/tests/queryset.py b/tests/queryset.py index 1f03fbd9..1e5e7a5a 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1710,6 +1710,22 @@ class QuerySetTest(unittest.TestCase): self.assertTrue([('_types', 1)] in info) self.assertTrue([('_types', 1), ('date', -1)] in info) + def test_dont_index_types(self): + """Ensure that index_types will, when disabled, prevent _types + being added to all indices. + """ + class BlogPost(Document): + date = DateTimeField() + meta = {'index_types': False, + 'indexes': ['-date']} + + # Indexes are lazy so use list() to perform query + list(BlogPost.objects) + info = BlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('_types', 1)] not in info) + self.assertTrue([('date', -1)] in info) + BlogPost.drop_collection() class BlogPost(Document): From aa32d4301479a7cd45071ca3e5607ebe319f225e Mon Sep 17 00:00:00 2001 From: Colin Howe Date: Wed, 8 Jun 2011 12:36:32 +0100 Subject: [PATCH 02/13] Pydoc update --- mongoengine/document.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mongoengine/document.py b/mongoengine/document.py index b563f427..cae8343d 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -53,6 +53,11 @@ class Document(BaseDocument): dictionary. The value should be a list of field names or tuples of field names. Index direction may be specified by prefixing the field names with a **+** or **-** sign. + + By default, _types will be added to the start of every index (that + doesn't contain a list) if allow_inheritence is True. This can be + disabled by either setting types to False on the specific index or + by setting index_types to False on the meta dictionary for the document. """ __metaclass__ = TopLevelDocumentMetaclass From 6dc2672dbab4d0914e838c4df867daa911a33dcf Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 8 Jun 2011 13:03:42 +0100 Subject: [PATCH 03/13] Updated changelog --- docs/changelog.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index ed877ebb..0a2a273f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,7 +5,8 @@ Changelog Changes in dev ============== -- Added slave_okay kwarg to queryset +- Added queryset.slave_okay(enabled) method +- Updated queryset.timeout(enabled) and queryset.snapshot(enabled) to be chainable - Added insert method for bulk inserts - Added blinker signal support - Added query_counter context manager for tests From d32dd9ff62c0984af5062a4b52f974bb009b22a3 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 8 Jun 2011 13:07:08 +0100 Subject: [PATCH 04/13] Added _get_FIELD_display() for handy choice field display lookups closes #188 --- docs/changelog.rst | 1 + mongoengine/base.py | 12 +++++++++++- tests/fields.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 0a2a273f..c76b1154 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added get_FIELD_display() method for easy choice field displaying. - Added queryset.slave_okay(enabled) method - Updated queryset.timeout(enabled) and queryset.snapshot(enabled) to be chainable - Added insert method for bulk inserts diff --git a/mongoengine/base.py b/mongoengine/base.py index 76bb1ab7..3875fea5 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -8,6 +8,7 @@ import sys import pymongo import pymongo.objectid from operator import itemgetter +from functools import partial class NotRegistered(Exception): @@ -61,6 +62,7 @@ class BaseField(object): self.primary_key = primary_key self.validation = validation self.choices = choices + # Adjust the appropriate creation counter, and save our local copy. if self.db_field == '_id': self.creation_counter = BaseField.auto_creation_counter @@ -471,7 +473,10 @@ class BaseDocument(object): self._data = {} # Assign default values to instance - for attr_name in self._fields.keys(): + for attr_name, field in self._fields.items(): + if field.choices: # dynamically adds a way to get the display value for a field with choices + setattr(self, 'get_%s_display' % attr_name, partial(self._get_FIELD_display, field=field)) + # Use default value if present value = getattr(self, attr_name, None) setattr(self, attr_name, value) @@ -484,6 +489,11 @@ class BaseDocument(object): signals.post_init.send(self) + def _get_FIELD_display(self, field): + """Returns the display value for a choice field""" + value = getattr(self, field.name) + return dict(field.choices).get(value, value) + def validate(self): """Ensure that all fields' values are valid and that required fields are present. diff --git a/tests/fields.py b/tests/fields.py index 320e33db..d8970043 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -773,6 +773,35 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + def test_choices_get_field_display(self): + """Test dynamic helper for returning the display value of a choices field. + """ + class Shirt(Document): + size = StringField(max_length=3, choices=(('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), + ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) + style = StringField(max_length=3, choices=(('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') + + Shirt.drop_collection() + + shirt = Shirt() + + self.assertEqual(shirt.get_size_display(), None) + self.assertEqual(shirt.get_style_display(), 'Small') + + shirt.size = "XXL" + shirt.style = "B" + self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') + self.assertEqual(shirt.get_style_display(), 'Baggy') + + # Set as Z - an invalid choice + shirt.size = "Z" + shirt.style = "Z" + self.assertEqual(shirt.get_size_display(), 'Z') + self.assertEqual(shirt.get_style_display(), 'Z') + self.assertRaises(ValidationError, shirt.validate) + + Shirt.drop_collection() + def test_file_fields(self): """Ensure that file fields can be written to and their data retrieved """ From 602d7dad0020937364f7076a1930d46209d6009d Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 8 Jun 2011 17:10:26 +0100 Subject: [PATCH 05/13] Improvements to Abstract Base Classes Added test example highlighting what to do to migrate a class from complex (allows inheritance) to simple. --- mongoengine/base.py | 13 +++++-- tests/document.py | 90 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 6 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 3875fea5..8a0a1f23 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -263,7 +263,7 @@ class DocumentMetaclass(type): superclasses[base._class_name] = base superclasses.update(base._superclasses) - if hasattr(base, '_meta'): + if hasattr(base, '_meta') and not base._meta.get('abstract'): # Ensure that the Document class may be subclassed - # inheritance may be disabled to remove dependency on # additional fields _cls and _types @@ -280,7 +280,7 @@ class DocumentMetaclass(type): # Only simple classes - direct subclasses of Document - may set # allow_inheritance to False - if not simple_class and not meta['allow_inheritance']: + if not simple_class and not meta['allow_inheritance'] and not meta['abstract']: raise ValueError('Only direct subclasses of Document may set ' '"allow_inheritance" to False') attrs['_meta'] = meta @@ -360,8 +360,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): # Subclassed documents inherit collection from superclass for base in bases: - if hasattr(base, '_meta') and 'collection' in base._meta: - collection = base._meta['collection'] + if hasattr(base, '_meta'): + if 'collection' in base._meta: + collection = base._meta['collection'] # Propagate index options. for key in ('index_background', 'index_drop_dups', 'index_opts'): @@ -370,6 +371,9 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): id_field = id_field or base._meta.get('id_field') base_indexes += base._meta.get('indexes', []) + # Propagate 'allow_inheritance' + if 'allow_inheritance' in base._meta: + base_meta['allow_inheritance'] = base._meta['allow_inheritance'] meta = { 'abstract': False, @@ -384,6 +388,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): 'index_opts': {}, 'queryset_class': QuerySet, 'delete_rules': {}, + 'allow_inheritance': True } meta.update(base_meta) diff --git a/tests/document.py b/tests/document.py index a8120469..14541469 100644 --- a/tests/document.py +++ b/tests/document.py @@ -151,12 +151,12 @@ class DocumentTest(unittest.TestCase): """Ensure that inheritance may be disabled on simple classes and that _cls and _types will not be used. """ + class Animal(Document): - meta = {'allow_inheritance': False} name = StringField() + meta = {'allow_inheritance': False} Animal.drop_collection() - def create_dog_class(): class Dog(Animal): pass @@ -191,6 +191,92 @@ class DocumentTest(unittest.TestCase): self.assertFalse('_cls' in comment.to_mongo()) self.assertFalse('_types' in comment.to_mongo()) + def test_allow_inheritance_abstract_document(self): + """Ensure that abstract documents can set inheritance rules and that + _cls and _types will not be used. + """ + class FinalDocument(Document): + meta = {'abstract': True, + 'allow_inheritance': False} + + class Animal(FinalDocument): + name = StringField() + + Animal.drop_collection() + def create_dog_class(): + class Dog(Animal): + pass + self.assertRaises(ValueError, create_dog_class) + + # Check that _cls etc aren't present on simple documents + dog = Animal(name='dog') + dog.save() + collection = self.db[Animal._meta['collection']] + obj = collection.find_one() + self.assertFalse('_cls' in obj) + self.assertFalse('_types' in obj) + + Animal.drop_collection() + + def test_how_to_turn_off_inheritance(self): + """Demonstrates migrating from allow_inheritance = True to False. + """ + class Animal(Document): + name = StringField() + meta = { + 'indexes': ['name'] + } + + Animal.drop_collection() + + dog = Animal(name='dog') + dog.save() + + collection = self.db[Animal._meta['collection']] + obj = collection.find_one() + self.assertTrue('_cls' in obj) + self.assertTrue('_types' in obj) + + info = collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertEquals([[(u'_id', 1)], [(u'_types', 1)], [(u'_types', 1), (u'name', 1)]], info) + + # Turn off inheritance + class Animal(Document): + name = StringField() + meta = { + 'allow_inheritance': False, + 'indexes': ['name'] + } + collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, False, True) + + # Confirm extra data is removed + obj = collection.find_one() + self.assertFalse('_cls' in obj) + self.assertFalse('_types' in obj) + + info = collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertEquals([[(u'_id', 1)], [(u'_types', 1)], [(u'_types', 1), (u'name', 1)]], info) + + info = collection.index_information() + indexes_to_drop = [key for key, value in info.iteritems() if '_types' in dict(value['key'])] + for index in indexes_to_drop: + collection.drop_index(index) + + info = collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertEquals([[(u'_id', 1)]], info) + + # Recreate indexes + dog = Animal.objects.first() + dog.save() + info = collection.index_information() + info = [value['key'] for key, value in info.iteritems()] + self.assertEquals([[(u'_id', 1)], [(u'name', 1),]], info) + + Animal.drop_collection() + def test_abstract_documents(self): """Ensure that a document superclass can be marked as abstract thereby not using it as the name for the collection.""" From 4b9bacf7316275d2c0c1efa7b5850b98374679cc Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 6 Jun 2011 17:21:54 +0100 Subject: [PATCH 06/13] Added ComplexBaseField * Handles the efficient lazy dereferencing of DBrefs. * Handles complex nested values in ListFields and DictFields * Allows for both strictly declared ListFields and DictFields where the embedded value must be of a field type or no restrictions where the values can be a mix of field types / values. * Handles DBrefences of documents where allow_inheritance = False. --- mongoengine/base.py | 206 +++++++++++++++++++++------- mongoengine/fields.py | 102 +++----------- mongoengine/queryset.py | 47 +++++-- tests/dereference.py | 112 +++++++++++++++- tests/fields.py | 287 +++++++++++++++++++++++++++++++--------- 5 files changed, 555 insertions(+), 199 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 8a0a1f23..a22795c7 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -132,15 +132,19 @@ class BaseField(object): self.validate(value) -class DereferenceBaseField(BaseField): - """Handles the lazy dereferencing of a queryset. Will dereference all +class ComplexBaseField(BaseField): + """Handles complex fields, such as lists / dictionaries. + + Allows for nesting of embedded documents inside complex types. + Handles the lazy dereferencing of a queryset by lazily dereferencing all items in a list / dict rather than one at a time. """ + field = None + def __get__(self, instance, owner): """Descriptor to automatically dereference references. """ - from fields import ReferenceField, GenericReferenceField from connection import _get_db if instance is None: @@ -149,68 +153,175 @@ class DereferenceBaseField(BaseField): # Get value from document instance if available value_list = instance._data.get(self.name) - if not value_list: - return super(DereferenceBaseField, self).__get__(instance, owner) + if not value_list or isinstance(value_list, basestring): + return super(ComplexBaseField, self).__get__(instance, owner) is_list = False if not hasattr(value_list, 'items'): is_list = True value_list = dict([(k,v) for k,v in enumerate(value_list)]) - if isinstance(self.field, ReferenceField) and value_list: - db = _get_db() - dbref = {} - collections = {} + for k,v in value_list.items(): + if isinstance(v, dict) and '_cls' in v and '_ref' not in v: + value_list[k] = get_document(v['_cls'].split('.')[-1])._from_son(v) - for k, v in value_list.items(): - dbref[k] = v - # Save any DBRefs + # Handle all dereferencing + db = _get_db() + dbref = {} + collections = {} + for k, v in value_list.items(): + dbref[k] = v + # Save any DBRefs + if isinstance(v, (pymongo.dbref.DBRef)): + # direct reference (DBRef) + collections.setdefault(v.collection, []).append((k, v)) + elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + # generic reference + collection = get_document(v['_cls'])._meta['collection'] + collections.setdefault(collection, []).append((k, v)) + + # For each collection get the references + for collection, dbrefs in collections.items(): + id_map = {} + for k, v in dbrefs: if isinstance(v, (pymongo.dbref.DBRef)): - collections.setdefault(v.collection, []).append((k, v)) + # direct reference (DBRef), has no _cls information + id_map[v.id] = (k, None) + elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: + # generic reference - includes _cls information + id_map[v['_ref'].id] = (k, get_document(v['_cls'])) - # For each collection get the references - for collection, dbrefs in collections.items(): - id_map = dict([(v.id, k) for k, v in dbrefs]) - references = db[collection].find({'_id': {'$in': id_map.keys()}}) - for ref in references: - key = id_map[ref['_id']] - dbref[key] = get_document(ref['_cls'])._from_son(ref) + references = db[collection].find({'_id': {'$in': id_map.keys()}}) + for ref in references: + key, doc_cls = id_map[ref['_id']] + if not doc_cls: # If no doc_cls get it from the referenced doc + doc_cls = get_document(ref['_cls']) + dbref[key] = doc_cls._from_son(ref) - if is_list: - dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] - instance._data[self.name] = dbref + if is_list: + dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + instance._data[self.name] = dbref + return super(ComplexBaseField, self).__get__(instance, owner) - # Get value from document instance if available - if isinstance(self.field, GenericReferenceField) and value_list: - db = _get_db() - value_list = [(k,v) for k,v in value_list.items()] - dbref = {} - classes = {} + def to_python(self, value): + """Convert a MongoDB-compatible type to a Python type. + """ + from mongoengine import Document - for k, v in value_list: - dbref[k] = v - # Save any DBRefs - if isinstance(v, (dict, pymongo.son.SON)): - classes.setdefault(v['_cls'], []).append((k, v)) + if isinstance(value, basestring): + return value - # For each collection get the references - for doc_cls, dbrefs in classes.items(): - id_map = dict([(v['_ref'].id, k) for k, v in dbrefs]) - doc_cls = get_document(doc_cls) - collection = doc_cls._meta['collection'] - references = db[collection].find({'_id': {'$in': id_map.keys()}}) + if hasattr(value, 'to_python'): + return value.to_python() - for ref in references: - key = id_map[ref['_id']] - dbref[key] = doc_cls._from_son(ref) + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k,v) for k,v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value - if is_list: - dbref = [v for k,v in sorted(dbref.items(), key=itemgetter(0))] + if self.field: + value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()]) + else: + value_dict = {} + for k,v in value.items(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + collection = v._meta['collection'] + value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) + elif hasattr(v, 'to_python'): + value_dict[k] = v.to_python() + else: + value_dict[k] = self.to_python(v) - instance._data[self.name] = dbref + if is_list: # Convert back to a list + return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] + return value_dict - return super(DereferenceBaseField, self).__get__(instance, owner) + def to_mongo(self, value): + """Convert a Python type to a MongoDB-compatible type. + """ + from mongoengine import Document + if isinstance(value, basestring): + return value + + if hasattr(value, 'to_mongo'): + return value.to_mongo() + + is_list = False + if not hasattr(value, 'items'): + try: + is_list = True + value = dict([(k,v) for k,v in enumerate(value)]) + except TypeError: # Not iterable return the value + return value + + if self.field: + value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()]) + else: + value_dict = {} + for k,v in value.items(): + if isinstance(v, Document): + # We need the id from the saved object to create the DBRef + if v.pk is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + + # If its a document that is not inheritable it won't have + # _types / _cls data so make it a generic reference allows + # us to dereference + meta = getattr(v, 'meta', getattr(v, '_meta', {})) + if meta and not meta['allow_inheritance'] and not self.field: + from fields import GenericReferenceField + value_dict[k] = GenericReferenceField().to_mongo(v) + else: + collection = v._meta['collection'] + value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) + elif hasattr(v, 'to_mongo'): + value_dict[k] = v.to_mongo() + else: + value_dict[k] = self.to_mongo(v) + + if is_list: # Convert back to a list + return [v for k,v in sorted(value_dict.items(), key=itemgetter(0))] + return value_dict + + def validate(self, value): + """If field provided ensure the value is valid. + """ + if self.field: + try: + if hasattr(value, 'iteritems'): + [self.field.validate(v) for k,v in value.iteritems()] + else: + [self.field.validate(v) for v in value] + except Exception, err: + raise ValidationError('Invalid %s item (%s)' % ( + self.field.__class__.__name__, str(v))) + + def prepare_query_value(self, op, value): + return self.to_mongo(value) + + def lookup_member(self, member_name): + if self.field: + return self.field.lookup_member(member_name) + return None + + def _set_owner_document(self, owner_document): + if self.field: + self.field.owner_document = owner_document + self._owner_document = owner_document + + def _get_owner_document(self, owner_document): + self._owner_document = owner_document + + owner_document = property(_get_owner_document, _set_owner_document) class ObjectIdField(BaseField): @@ -219,7 +330,6 @@ class ObjectIdField(BaseField): def to_python(self, value): return value - # return unicode(value) def to_mongo(self, value): if not isinstance(value, pymongo.objectid.ObjectId): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 1995d345..f9b2580b 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1,4 +1,4 @@ -from base import (BaseField, DereferenceBaseField, ObjectIdField, +from base import (BaseField, ComplexBaseField, ObjectIdField, ValidationError, get_document) from queryset import DO_NOTHING from document import Document, EmbeddedDocument @@ -301,6 +301,8 @@ class EmbeddedDocumentField(BaseField): return value def to_mongo(self, value): + if isinstance(value, basestring): + return value return self.document_type.to_mongo(value) def validate(self, value): @@ -320,7 +322,7 @@ class EmbeddedDocumentField(BaseField): return self.to_mongo(value) -class ListField(DereferenceBaseField): +class ListField(ComplexBaseField): """A list field that wraps a standard field, allowing multiple instances of the field to be used as a list in the database. """ @@ -328,48 +330,25 @@ class ListField(DereferenceBaseField): # ListFields cannot be indexed with _types - MongoDB doesn't support this _index_with_types = False - def __init__(self, field, **kwargs): - if not isinstance(field, BaseField): - raise ValidationError('Argument to ListField constructor must be ' - 'a valid field') + def __init__(self, field=None, **kwargs): self.field = field kwargs.setdefault('default', lambda: []) super(ListField, self).__init__(**kwargs) - def to_python(self, value): - return [self.field.to_python(item) for item in value] - - def to_mongo(self, value): - return [self.field.to_mongo(item) for item in value] - def validate(self, value): """Make sure that a list of valid fields is being used. """ if not isinstance(value, (list, tuple)): raise ValidationError('Only lists and tuples may be used in a ' 'list field') - - try: - [self.field.validate(item) for item in value] - except Exception, err: - raise ValidationError('Invalid ListField item (%s)' % str(item)) + super(ListField, self).validate(value) def prepare_query_value(self, op, value): - if op in ('set', 'unset'): - return [self.field.prepare_query_value(op, v) for v in value] - return self.field.prepare_query_value(op, value) - - def lookup_member(self, member_name): - return self.field.lookup_member(member_name) - - def _set_owner_document(self, owner_document): - self.field.owner_document = owner_document - self._owner_document = owner_document - - def _get_owner_document(self, owner_document): - self._owner_document = owner_document - - owner_document = property(_get_owner_document, _set_owner_document) + if self.field: + if op in ('set', 'unset') and not isinstance(value, basestring): + return [self.field.prepare_query_value(op, v) for v in value] + return self.field.prepare_query_value(op, value) + return super(ListField, self).prepare_query_value(op, value) class SortedListField(ListField): @@ -388,20 +367,21 @@ class SortedListField(ListField): super(SortedListField, self).__init__(field, **kwargs) def to_mongo(self, value): + value = super(SortedListField, self).to_mongo(value) if self._ordering is not None: - return sorted([self.field.to_mongo(item) for item in value], - key=itemgetter(self._ordering)) - return sorted([self.field.to_mongo(item) for item in value]) + return sorted(value, key=itemgetter(self._ordering)) + return sorted(value) -class DictField(BaseField): +class DictField(ComplexBaseField): """A dictionary field that wraps a standard Python dictionary. This is similar to an embedded document, but the structure is not defined. .. versionadded:: 0.3 """ - def __init__(self, basecls=None, *args, **kwargs): + def __init__(self, basecls=None, field=None, *args, **kwargs): + self.field = field self.basecls = basecls or BaseField assert issubclass(self.basecls, BaseField) kwargs.setdefault('default', lambda: {}) @@ -417,6 +397,7 @@ class DictField(BaseField): if any(('.' in k or '$' in k) for k in value): raise ValidationError('Invalid dictionary key name - keys may not ' 'contain "." or "$" characters') + super(DictField, self).validate(value) def lookup_member(self, member_name): return DictField(basecls=self.basecls, db_field=member_name) @@ -432,7 +413,7 @@ class DictField(BaseField): return super(DictField, self).prepare_query_value(op, value) -class MapField(DereferenceBaseField): +class MapField(DictField): """A field that maps a name to a specified field type. Similar to a DictField, except the 'value' of each item must match the specified field type. @@ -444,50 +425,7 @@ class MapField(DereferenceBaseField): if not isinstance(field, BaseField): raise ValidationError('Argument to MapField constructor must be ' 'a valid field') - self.field = field - kwargs.setdefault('default', lambda: {}) - super(MapField, self).__init__(*args, **kwargs) - - def validate(self, value): - """Make sure that a list of valid fields is being used. - """ - if not isinstance(value, dict): - raise ValidationError('Only dictionaries may be used in a ' - 'DictField') - - if any(('.' in k or '$' in k) for k in value): - raise ValidationError('Invalid dictionary key name - keys may not ' - 'contain "." or "$" characters') - - try: - [self.field.validate(item) for item in value.values()] - except Exception, err: - raise ValidationError('Invalid MapField item (%s)' % str(item)) - - def to_python(self, value): - return dict([(key, self.field.to_python(item)) for key, item in value.iteritems()]) - - def to_mongo(self, value): - return dict([(key, self.field.to_mongo(item)) for key, item in value.iteritems()]) - - def prepare_query_value(self, op, value): - if op not in ('set', 'unset'): - return self.field.prepare_query_value(op, value) - for key in value: - value[key] = self.field.prepare_query_value(op, value[key]) - return value - - def lookup_member(self, member_name): - return self.field.lookup_member(member_name) - - def _set_owner_document(self, owner_document): - self.field.owner_document = owner_document - self._owner_document = owner_document - - def _get_owner_document(self, owner_document): - self._owner_document = owner_document - - owner_document = property(_get_owner_document, _set_owner_document) + super(MapField, self).__init__(field=field, *args, **kwargs) class ReferenceField(BaseField): diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 1dfe55af..666567e2 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -549,11 +549,12 @@ class QuerySet(object): parts = [parts] fields = [] field = None + for field_name in parts: # Handle ListField indexing: if field_name.isdigit(): try: - field = field.field + new_field = field.field except AttributeError, err: raise InvalidQueryError( "Can't use index on unsubscriptable field (%s)" % err) @@ -567,11 +568,17 @@ class QuerySet(object): field = document._fields[field_name] else: # Look up subfield on the previous field - field = field.lookup_member(field_name) - if field is None: + new_field = field.lookup_member(field_name) + from base import ComplexBaseField + if not new_field and isinstance(field, ComplexBaseField): + fields.append(field_name) + continue + elif not new_field: raise InvalidQueryError('Cannot resolve field "%s"' - % field_name) + % field_name) + field = new_field # update field to the new field type fields.append(field) + return fields @classmethod @@ -615,14 +622,33 @@ class QuerySet(object): if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [field.db_field for field in fields] + parts = [] + + cleaned_fields = [] + append_field = True + for field in fields: + if isinstance(field, str): + parts.append(field) + append_field = False + else: + parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) # Convert value to proper value - field = fields[-1] + field = cleaned_fields[-1] + singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops += match_operators if op in singular_ops: - value = field.prepare_query_value(op, value) + if isinstance(field, basestring): + if op in match_operators and isinstance(value, basestring): + from mongoengine import StringField + value = StringField().prepare_query_value(op, value) + else: + value = field + else: + value = field.prepare_query_value(op, value) elif op in ('in', 'nin', 'all', 'near'): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] @@ -1170,14 +1196,19 @@ class QuerySet(object): fields = QuerySet._lookup_field(_doc_cls, parts) parts = [] + cleaned_fields = [] + append_field = True for field in fields: if isinstance(field, str): parts.append(field) + append_field = False else: parts.append(field.db_field) + if append_field: + cleaned_fields.append(field) # Convert value to proper value - field = fields[-1] + field = cleaned_fields[-1] if op in (None, 'set', 'push', 'pull', 'addToSet'): value = field.prepare_query_value(op, value) diff --git a/tests/dereference.py b/tests/dereference.py index b6cee89e..68792721 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -122,6 +122,64 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + def test_list_field_complex(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField() + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + + for m in group_obj.members: + self.assertTrue('User' in m.__class__.__name__) + UserA.drop_collection() UserB.drop_collection() UserC.drop_collection() @@ -156,10 +214,13 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, User)) + User.drop_collection() Group.drop_collection() - def ztest_generic_reference_dict_field(self): + def test_dict_field(self): class UserA(Document): name = StringField() @@ -206,6 +267,9 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + group.members = {} group.save() @@ -218,11 +282,54 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 1) + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + UserA.drop_collection() UserB.drop_collection() UserC.drop_collection() Group.drop_collection() + def test_dict_field_no_field_inheritance(self): + + class UserA(Document): + name = StringField() + meta = {'allow_inheritance': False} + + class Group(Document): + members = DictField() + + UserA.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + members += [a] + + group = Group(members=dict([(str(u.id), u) for u in members])) + group.save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + + for k, m in group_obj.members.iteritems(): + self.assertTrue(isinstance(m, UserA)) + + UserA.drop_collection() + Group.drop_collection() + def test_generic_reference_map_field(self): class UserA(Document): @@ -270,6 +377,9 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 4) + for k, m in group_obj.members.iteritems(): + self.assertTrue('User' in m.__class__.__name__) + group.members = {} group.save() diff --git a/tests/fields.py b/tests/fields.py index d8970043..4d51ed51 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -322,6 +322,108 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() + def test_list_field(self): + """Ensure that list types work as expected. + """ + class BlogPost(Document): + info = ListField() + + BlogPost.drop_collection() + + post = BlogPost() + post.info = 'my post' + self.assertRaises(ValidationError, post.validate) + + post.info = {'title': 'test'} + self.assertRaises(ValidationError, post.validate) + + post.info = ['test'] + post.save() + + post = BlogPost() + post.info = [{'test': 'test'}] + post.save() + + post = BlogPost() + post.info = [{'test': 3}] + post.save() + + + self.assertEquals(BlogPost.objects.count(), 3) + self.assertEquals(BlogPost.objects.filter(info__exact='test').count(), 1) + self.assertEquals(BlogPost.objects.filter(info__0__test='test').count(), 1) + + # Confirm handles non strings or non existing keys + self.assertEquals(BlogPost.objects.filter(info__0__test__exact='5').count(), 0) + self.assertEquals(BlogPost.objects.filter(info__100__test__exact='test').count(), 0) + BlogPost.drop_collection() + + def test_list_field_Strict(self): + """Ensure that list field handles validation if provided a strict field type.""" + + class Simple(Document): + mapping = ListField(field=IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping = [1] + e.save() + + def create_invalid_mapping(): + e.mapping = ["abc"] + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Simple.drop_collection() + + def test_list_field_complex(self): + """Ensure that the list fields can handle the complex types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Simple(Document): + mapping = ListField() + + Simple.drop_collection() + e = Simple() + e.mapping.append(StringSetting(value='foo')) + e.mapping.append(IntegerSetting(value=42)) + e.mapping.append({'number': 1, 'string': 'Hi!', 'float': 1.001, + 'complex': IntegerSetting(value=42), 'list': + [IntegerSetting(value=42), StringSetting(value='foo')]}) + e.save() + + e2 = Simple.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping[0], StringSetting)) + self.assertTrue(isinstance(e2.mapping[1], IntegerSetting)) + + # Test querying + self.assertEquals(Simple.objects.filter(mapping__1__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__number=1).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__complex__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__list__0__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 1) + + # Confirm can update + Simple.objects().update(set__mapping__1=IntegerSetting(value=10)) + self.assertEquals(Simple.objects.filter(mapping__1__value=10).count(), 1) + + Simple.objects().update( + set__mapping__2__list__1=StringSetting(value='Boo')) + self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='foo').count(), 0) + self.assertEquals(Simple.objects.filter(mapping__2__list__1__value='Boo').count(), 1) + + Simple.drop_collection() + def test_dict_field(self): """Ensure that dict types work as expected. """ @@ -363,6 +465,131 @@ class FieldTest(unittest.TestCase): self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) BlogPost.drop_collection() + def test_dictfield_Strict(self): + """Ensure that dict field handles validation if provided a strict field type.""" + + class Simple(Document): + mapping = DictField(field=IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + def create_invalid_mapping(): + e.mapping['somestring'] = "abc" + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Simple.drop_collection() + + def test_dictfield_complex(self): + """Ensure that the dict field can handle the complex types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Simple(Document): + mapping = DictField() + + Simple.drop_collection() + e = Simple() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.mapping['nested_dict'] = {'number': 1, 'string': 'Hi!', 'float': 1.001, + 'complex': IntegerSetting(value=42), 'list': + [IntegerSetting(value=42), StringSetting(value='foo')]} + e.save() + + e2 = Simple.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) + self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) + + # Test querying + self.assertEquals(Simple.objects.filter(mapping__someint__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__number=1).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__complex__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__0__value=42).count(), 1) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 1) + + # Confirm can update + Simple.objects().update( + set__mapping={"someint": IntegerSetting(value=10)}) + Simple.objects().update( + set__mapping__nested_dict__list__1=StringSetting(value='Boo')) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='foo').count(), 0) + self.assertEquals(Simple.objects.filter(mapping__nested_dict__list__1__value='Boo').count(), 1) + + Simple.drop_collection() + + def test_mapfield(self): + """Ensure that the MapField handles the declared type.""" + + class Simple(Document): + mapping = MapField(IntField()) + + Simple.drop_collection() + + e = Simple() + e.mapping['someint'] = 1 + e.save() + + def create_invalid_mapping(): + e.mapping['somestring'] = "abc" + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + def create_invalid_class(): + class NoDeclaredType(Document): + mapping = MapField() + + self.assertRaises(ValidationError, create_invalid_class) + + Simple.drop_collection() + + def test_complex_mapfield(self): + """Ensure that the MapField can handle complex declared types.""" + + class SettingBase(EmbeddedDocument): + pass + + class StringSetting(SettingBase): + value = StringField() + + class IntegerSetting(SettingBase): + value = IntField() + + class Extensible(Document): + mapping = MapField(EmbeddedDocumentField(SettingBase)) + + Extensible.drop_collection() + + e = Extensible() + e.mapping['somestring'] = StringSetting(value='foo') + e.mapping['someint'] = IntegerSetting(value=42) + e.save() + + e2 = Extensible.objects.get(id=e.id) + self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) + self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) + + def create_invalid_mapping(): + e.mapping['someint'] = 123 + e.save() + + self.assertRaises(ValidationError, create_invalid_mapping) + + Extensible.drop_collection() + def test_embedded_document_validation(self): """Ensure that invalid embedded documents cannot be assigned to embedded document fields. @@ -933,66 +1160,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(d2.data, {}) self.assertEqual(d2.data2, {}) - def test_mapfield(self): - """Ensure that the MapField handles the declared type.""" - - class Simple(Document): - mapping = MapField(IntField()) - - Simple.drop_collection() - - e = Simple() - e.mapping['someint'] = 1 - e.save() - - def create_invalid_mapping(): - e.mapping['somestring'] = "abc" - e.save() - - self.assertRaises(ValidationError, create_invalid_mapping) - - def create_invalid_class(): - class NoDeclaredType(Document): - mapping = MapField() - - self.assertRaises(ValidationError, create_invalid_class) - - Simple.drop_collection() - - def test_complex_mapfield(self): - """Ensure that the MapField can handle complex declared types.""" - - class SettingBase(EmbeddedDocument): - pass - - class StringSetting(SettingBase): - value = StringField() - - class IntegerSetting(SettingBase): - value = IntField() - - class Extensible(Document): - mapping = MapField(EmbeddedDocumentField(SettingBase)) - - Extensible.drop_collection() - - e = Extensible() - e.mapping['somestring'] = StringSetting(value='foo') - e.mapping['someint'] = IntegerSetting(value=42) - e.save() - - e2 = Extensible.objects.get(id=e.id) - self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) - self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) - - def create_invalid_mapping(): - e.mapping['someint'] = 123 - e.save() - - self.assertRaises(ValidationError, create_invalid_mapping) - - Extensible.drop_collection() - if __name__ == '__main__': unittest.main() From b9255f73c381c820d13ec30fba499a3fe6868a3e Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 9 Jun 2011 11:28:57 +0100 Subject: [PATCH 07/13] Updated docs --- docs/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index c76b1154..f4be4ca6 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,13 +5,13 @@ Changelog Changes in dev ============== +- Added ComplexBaseField - for improved flexibility and performance. - Added get_FIELD_display() method for easy choice field displaying. - Added queryset.slave_okay(enabled) method - Updated queryset.timeout(enabled) and queryset.snapshot(enabled) to be chainable - Added insert method for bulk inserts - Added blinker signal support - Added query_counter context manager for tests -- Added DereferenceBaseField - for improved performance in field dereferencing - Added optional map_reduce method item_frequencies - Added inline_map_reduce option to map_reduce - Updated connection exception so it provides more info on the cause. From a66417e9d098b03b5dfaf04ab23fa8d185dd38e2 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 9 Jun 2011 11:31:47 +0100 Subject: [PATCH 08/13] pep8 update --- tests/fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/fields.py b/tests/fields.py index 4d51ed51..1b199982 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -358,7 +358,7 @@ class FieldTest(unittest.TestCase): self.assertEquals(BlogPost.objects.filter(info__100__test__exact='test').count(), 0) BlogPost.drop_collection() - def test_list_field_Strict(self): + def test_list_field_strict(self): """Ensure that list field handles validation if provided a strict field type.""" class Simple(Document): @@ -465,7 +465,7 @@ class FieldTest(unittest.TestCase): self.assertEquals(BlogPost.objects.filter(info__made_up__test__exact='test').count(), 0) BlogPost.drop_collection() - def test_dictfield_Strict(self): + def test_dictfield_strict(self): """Ensure that dict field handles validation if provided a strict field type.""" class Simple(Document): From 199b4eb860a93c581c1ddfc915f7094fc28de678 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 9 Jun 2011 12:08:37 +0100 Subject: [PATCH 09/13] Added django_tests and regression test for order_by Refs #190 --- setup.py | 2 +- tests/django_tests.py | 44 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 tests/django_tests.py diff --git a/setup.py b/setup.py index d3be64b3..1f65ae5d 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,6 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo', 'blinker'], + install_requires=['pymongo', 'blinker', 'django>=1.3'], test_suite='tests', ) diff --git a/tests/django_tests.py b/tests/django_tests.py new file mode 100644 index 00000000..e5e26022 --- /dev/null +++ b/tests/django_tests.py @@ -0,0 +1,44 @@ + +# -*- coding: utf-8 -*- + +import unittest + +from mongoengine import * + + +class QuerySetTest(unittest.TestCase): + + def setUp(self): + connect(db='mongoenginetest') + + class Person(Document): + name = StringField() + age = IntField() + self.Person = Person + + def test_order_by_in_django_template(self): + """Ensure that QuerySets are properly ordered in Django template. + """ + self.Person.drop_collection() + + self.Person(name="A", age=20).save() + self.Person(name="D", age=10).save() + self.Person(name="B", age=40).save() + self.Person(name="C", age=30).save() + + from django.conf import settings + settings.configure() + from django.template import Context, Template + + t = Template("{% for o in ol %}{{ o.name }}-{{ o.age }}:{% endfor %}") + + d = {"ol": self.Person.objects.order_by('-name')} + self.assertEqual(t.render(Context(d)), u'D-10:C-30:B-40:A-20:') + d = {"ol": self.Person.objects.order_by('+name')} + self.assertEqual(t.render(Context(d)), u'A-20:B-40:C-30:D-10:') + d = {"ol": self.Person.objects.order_by('-age')} + self.assertEqual(t.render(Context(d)), u'B-40:C-30:A-20:D-10:') + d = {"ol": self.Person.objects.order_by('+age')} + self.assertEqual(t.render(Context(d)), u'D-10:A-20:C-30:B-40:') + + self.Person.drop_collection() \ No newline at end of file From 417bb1b35d21c4bf02cb0acfd95f5b1ff6c49d70 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 9 Jun 2011 12:15:36 +0100 Subject: [PATCH 10/13] Added regression test for #185 --- tests/django_tests.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/django_tests.py b/tests/django_tests.py index e5e26022..6be1ea25 100644 --- a/tests/django_tests.py +++ b/tests/django_tests.py @@ -5,6 +5,9 @@ import unittest from mongoengine import * +from django.template import Context, Template +from django.conf import settings +settings.configure() class QuerySetTest(unittest.TestCase): @@ -26,10 +29,6 @@ class QuerySetTest(unittest.TestCase): self.Person(name="B", age=40).save() self.Person(name="C", age=30).save() - from django.conf import settings - settings.configure() - from django.template import Context, Template - t = Template("{% for o in ol %}{{ o.name }}-{{ o.age }}:{% endfor %}") d = {"ol": self.Person.objects.order_by('-name')} @@ -41,4 +40,18 @@ class QuerySetTest(unittest.TestCase): d = {"ol": self.Person.objects.order_by('+age')} self.assertEqual(t.render(Context(d)), u'D-10:A-20:C-30:B-40:') - self.Person.drop_collection() \ No newline at end of file + self.Person.drop_collection() + + def test_q_object_filter_in_template(self): + + self.Person.drop_collection() + + self.Person(name="A", age=20).save() + self.Person(name="D", age=10).save() + self.Person(name="B", age=40).save() + self.Person(name="C", age=30).save() + + t = Template("{% for o in ol %}{{ o.name }}-{{ o.age }}:{% endfor %}") + + d = {"ol": self.Person.objects.filter(Q(age=10) | Q(name="C"))} + self.assertEqual(t.render(Context(d)), u'D-10:C-30:') \ No newline at end of file From b2848b85194dee2429d35036c96a6c800cef42bf Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 9 Jun 2011 14:20:21 +0100 Subject: [PATCH 11/13] Added ComplexDateTimeField Thanks to @pelletier for the code. Refs #187 --- docs/apireference.rst | 2 + mongoengine/fields.py | 97 +++++++++++++++++++++++++++++++++++++++- tests/fields.py | 101 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 198 insertions(+), 2 deletions(-) diff --git a/docs/apireference.rst b/docs/apireference.rst index a3d287ab..2442803d 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -53,6 +53,8 @@ Fields .. autoclass:: mongoengine.DateTimeField +.. autoclass:: mongoengine.ComplexDateTimeField + .. autoclass:: mongoengine.EmbeddedDocumentField .. autoclass:: mongoengine.DictField diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f9b2580b..5d5304ae 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -18,8 +18,9 @@ import gridfs __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'ObjectIdField', 'ReferenceField', 'ValidationError', 'MapField', - 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', - 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] + 'DecimalField', 'ComplexDateTimeField', 'URLField', + 'GenericReferenceField', 'FileField', 'BinaryField', + 'SortedListField', 'EmailField', 'GeoPointField'] RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -273,6 +274,98 @@ class DateTimeField(BaseField): return None +class ComplexDateTimeField(StringField): + """ + ComplexDateTimeField handles microseconds exactly instead of rounding + like DateTimeField does. + + Derives from a StringField so you can do `gte` and `lte` filtering by + using lexicographical comparison when filtering / sorting strings. + + The stored string has the following format: + + YYYY,MM,DD,HH,MM,SS,NNNNNN + + Where NNNNNN is the number of microseconds of the represented `datetime`. + The `,` as the separator can be easily modified by passing the `separator` + keyword when initializing the field. + """ + + def __init__(self, separator=',', **kwargs): + self.names = ['year', 'month', 'day', 'hour', 'minute', 'second', + 'microsecond'] + self.separtor = separator + super(ComplexDateTimeField, self).__init__(**kwargs) + + def _leading_zero(self, number): + """ + Converts the given number to a string. + + If it has only one digit, a leading zero so as it has always at least + two digits. + """ + if int(number) < 10: + return "0%s" % number + else: + return str(number) + + def _convert_from_datetime(self, val): + """ + Convert a `datetime` object to a string representation (which will be + stored in MongoDB). This is the reverse function of + `_convert_from_string`. + + >>> a = datetime(2011, 6, 8, 20, 26, 24, 192284) + >>> RealDateTimeField()._convert_from_datetime(a) + '2011,06,08,20,26,24,192284' + """ + data = [] + for name in self.names: + data.append(self._leading_zero(getattr(val, name))) + return ','.join(data) + + def _convert_from_string(self, data): + """ + Convert a string representation to a `datetime` object (the object you + will manipulate). This is the reverse function of + `_convert_from_datetime`. + + >>> a = '2011,06,08,20,26,24,192284' + >>> ComplexDateTimeField()._convert_from_string(a) + datetime.datetime(2011, 6, 8, 20, 26, 24, 192284) + """ + data = data.split(',') + data = map(int, data) + values = {} + for i in range(7): + values[self.names[i]] = data[i] + return datetime.datetime(**values) + + def __get__(self, instance, owner): + data = super(ComplexDateTimeField, self).__get__(instance, owner) + if data == None: + return datetime.datetime.now() + return self._convert_from_string(data) + + def __set__(self, obj, val): + data = self._convert_from_datetime(val) + return super(ComplexDateTimeField, self).__set__(obj, data) + + def validate(self, value): + if not isinstance(value, datetime.datetime): + raise ValidationError('Only datetime objects may used in a \ + ComplexDateTimeField') + + def to_python(self, value): + return self._convert_from_string(value) + + def to_mongo(self, value): + return self._convert_from_datetime(value) + + def prepare_query_value(self, op, value): + return self._convert_from_datetime(value) + + class EmbeddedDocumentField(BaseField): """An embedded document field. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. diff --git a/tests/fields.py b/tests/fields.py index 1b199982..531167c8 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -247,6 +247,107 @@ class FieldTest(unittest.TestCase): LogEntry.drop_collection() + def test_complexdatetime_storage(self): + """Tests for complex datetime fields - which can handle microseconds + without rounding. + """ + class LogEntry(Document): + date = ComplexDateTimeField() + + LogEntry.drop_collection() + + # Post UTC - microseconds are rounded (down) nearest millisecond and dropped - with default datetimefields + d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) + log = LogEntry() + log.date = d1 + log.save() + log.reload() + self.assertEquals(log.date, d1) + + # Post UTC - microseconds are rounded (down) nearest millisecond - with default datetimefields + d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999) + log.date = d1 + log.save() + log.reload() + self.assertEquals(log.date, d1) + + # Pre UTC dates microseconds below 1000 are dropped - with default datetimefields + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) + log.date = d1 + log.save() + log.reload() + self.assertEquals(log.date, d1) + + # Pre UTC microseconds above 1000 is wonky - with default datetimefields + # log.date has an invalid microsecond value so I can't construct + # a date to compare. + for i in xrange(1001, 3113, 33): + d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) + log.date = d1 + log.save() + log.reload() + self.assertEquals(log.date, d1) + log1 = LogEntry.objects.get(date=d1) + self.assertEqual(log, log1) + + LogEntry.drop_collection() + + def test_complexdatetime_usage(self): + """Tests for complex datetime fields - which can handle microseconds + without rounding. + """ + class LogEntry(Document): + date = ComplexDateTimeField() + + LogEntry.drop_collection() + + d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) + log = LogEntry() + log.date = d1 + log.save() + + log1 = LogEntry.objects.get(date=d1) + self.assertEquals(log, log1) + + LogEntry.drop_collection() + + # create 60 log entries + for i in xrange(1950, 2010): + d = datetime.datetime(i, 01, 01, 00, 00, 01, 999) + LogEntry(date=d).save() + + self.assertEqual(LogEntry.objects.count(), 60) + + # Test ordering + logs = LogEntry.objects.order_by("date") + count = logs.count() + i = 0 + while i == count-1: + self.assertTrue(logs[i].date <= logs[i+1].date) + i +=1 + + logs = LogEntry.objects.order_by("-date") + count = logs.count() + i = 0 + while i == count-1: + self.assertTrue(logs[i].date >= logs[i+1].date) + i +=1 + + # Test searching + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980,1,1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980,1,1)) + self.assertEqual(logs.count(), 30) + + logs = LogEntry.objects.filter( + date__lte=datetime.datetime(2011,1,1), + date__gte=datetime.datetime(2000,1,1), + ) + self.assertEqual(logs.count(), 10) + + LogEntry.drop_collection() + def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements. """ From fb09fde2097bd557a9173c749f45a8688cf62050 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 9 Jun 2011 14:26:52 +0100 Subject: [PATCH 12/13] Updated changelog --- docs/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index f4be4ca6..0bbb5b82 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added ComplexDateTimeField - Handles datetimes correctly with microseconds - Added ComplexBaseField - for improved flexibility and performance. - Added get_FIELD_display() method for easy choice field displaying. - Added queryset.slave_okay(enabled) method From fd7f882011ce548efd7ae5fcb0f59fd38d38e98b Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 9 Jun 2011 16:09:06 +0100 Subject: [PATCH 13/13] Save no longer tramples over documents now sets or unsets explicit fields. Fixes #146, refs #18 Thanks @zhangcheng for the initial code --- docs/changelog.rst | 5 ++- mongoengine/base.py | 9 +++-- mongoengine/document.py | 10 +++++ setup.py | 2 +- tests/document.py | 84 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 104 insertions(+), 6 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 0bbb5b82..ecd7ef57 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,9 +5,10 @@ Changelog Changes in dev ============== +- Fixed saving so sets updated values rather than overwrites - Added ComplexDateTimeField - Handles datetimes correctly with microseconds -- Added ComplexBaseField - for improved flexibility and performance. -- Added get_FIELD_display() method for easy choice field displaying. +- Added ComplexBaseField - for improved flexibility and performance +- Added get_FIELD_display() method for easy choice field displaying - Added queryset.slave_okay(enabled) method - Updated queryset.timeout(enabled) and queryset.snapshot(enabled) to be chainable - Added insert method for bulk inserts diff --git a/mongoengine/base.py b/mongoengine/base.py index a22795c7..aed17bc3 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -92,6 +92,9 @@ class BaseField(object): """Descriptor for assigning a value to a field in a document. """ instance._data[self.name] = value + # If the field set is in the _present_fields list add it so we can track + if hasattr(instance, '_present_fields') and self.name not in instance._present_fields: + instance._present_fields.append(self.name) def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. @@ -592,13 +595,14 @@ class BaseDocument(object): if field.choices: # dynamically adds a way to get the display value for a field with choices setattr(self, 'get_%s_display' % attr_name, partial(self._get_FIELD_display, field=field)) - # Use default value if present value = getattr(self, attr_name, None) setattr(self, attr_name, value) + # Assign initial values to instance for attr_name in values.keys(): try: - setattr(self, attr_name, values.pop(attr_name)) + value = values.pop(attr_name) + setattr(self, attr_name, value) except AttributeError: pass @@ -739,7 +743,6 @@ class BaseDocument(object): cls = subclasses[class_name] present_fields = data.keys() - for field_name, field in cls._fields.items(): if field.db_field in data: value = data[field.db_field] diff --git a/mongoengine/document.py b/mongoengine/document.py index cae8343d..e25bea06 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -95,6 +95,16 @@ class Document(BaseDocument): collection = self.__class__.objects._collection if force_insert: object_id = collection.insert(doc, safe=safe, **write_options) + elif '_id' in doc: + # Perform a set rather than a save - this will only save set fields + object_id = doc.pop('_id') + collection.update({'_id': object_id}, {"$set": doc}, upsert=True, safe=safe, **write_options) + + # Find and unset any fields explicitly set to None + if hasattr(self, '_present_fields'): + removals = dict([(k, 1) for k in self._present_fields if k not in doc and k != '_id']) + if removals: + collection.update({'_id': object_id}, {"$unset": removals}, upsert=True, safe=safe, **write_options) else: object_id = collection.save(doc, safe=safe, **write_options) except pymongo.errors.OperationFailure, err: diff --git a/setup.py b/setup.py index 1f65ae5d..37ec4375 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,6 @@ setup(name='mongoengine', long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo', 'blinker', 'django>=1.3'], + install_requires=['pymongo', 'blinker', 'django==1.3'], test_suite='tests', ) diff --git a/tests/document.py b/tests/document.py index 14541469..f0af8f2d 100644 --- a/tests/document.py +++ b/tests/document.py @@ -789,6 +789,90 @@ class DocumentTest(unittest.TestCase): except ValidationError: self.fail() + def test_update(self): + """Ensure that an existing document is updated instead of be overwritten. + """ + # Create person object and save it to the database + person = self.Person(name='Test User', age=30) + person.save() + + # Create same person object, with same id, without age + same_person = self.Person(name='Test') + same_person.id = person.id + same_person.save() + + # Confirm only one object + self.assertEquals(self.Person.objects.count(), 1) + + # reload + person.reload() + same_person.reload() + + # Confirm the same + self.assertEqual(person, same_person) + self.assertEqual(person.name, same_person.name) + self.assertEqual(person.age, same_person.age) + + # Confirm the saved values + self.assertEqual(person.name, 'Test') + self.assertEqual(person.age, 30) + + # Test only / exclude only updates included fields + person = self.Person.objects.only('name').get() + person.name = 'User' + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 30) + + # test exclude only updates set fields + person = self.Person.objects.exclude('name').get() + person.age = 21 + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 21) + + # Test only / exclude can set non excluded / included fields + person = self.Person.objects.only('name').get() + person.name = 'Test' + person.age = 30 + person.save() + + person.reload() + self.assertEqual(person.name, 'Test') + self.assertEqual(person.age, 30) + + # test exclude only updates set fields + person = self.Person.objects.exclude('name').get() + person.name = 'User' + person.age = 21 + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, 21) + + # Confirm does remove unrequired fields + person = self.Person.objects.exclude('name').get() + person.age = None + person.save() + + person.reload() + self.assertEqual(person.name, 'User') + self.assertEqual(person.age, None) + + person = self.Person.objects.get() + person.name = None + person.age = None + person.save() + + person.reload() + self.assertEqual(person.name, None) + self.assertEqual(person.age, None) + def test_delete(self): """Ensure that document may be deleted using the delete method. """