From 08ba51f714731a6fc27340cae8cdb47bf1d60302 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 20 Jun 2011 15:41:23 +0100 Subject: [PATCH] Updated geo_index checking to be recursive Fixes #127 - Embedded Documents can declare geo indexes and have them created automatically --- docs/changelog.rst | 1 + mongoengine/base.py | 194 +++++++++++++++++++++------------------- mongoengine/queryset.py | 11 ++- tests/fields.py | 21 +++++ 4 files changed, 129 insertions(+), 98 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index cfae79e0..0737171c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Updated geo index checking to be recursive and check in embedded documents - Updated default collection naming convention - Added Document Mixin support - Fixed queryet __repr__ mid iteration diff --git a/mongoengine/base.py b/mongoengine/base.py index 94f00cbf..12c760aa 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -644,28 +644,6 @@ class BaseDocument(object): signals.post_init.send(self.__class__, document=self) - def __getstate__(self): - self_dict = self.__dict__ - removals = ["get_%s_display" % k for k,v in self._fields.items() if v.choices] - for k in removals: - if hasattr(self, k): - delattr(self, k) - return self.__dict__ - - def __setstate__(self, __dict__): - self.__dict__ = __dict__ - self.__set_field_display() - - def __set_field_display(self): - for attr_name, field in self._fields.items(): - if field.choices: # dynamically adds a way to get the display value for a field with choices - setattr(self, 'get_%s_display' % attr_name, partial(self.__get_field_display, field=field)) - - def __get_field_display(self, field): - """Returns the display value for a choice field""" - value = getattr(self, field.name) - return dict(field.choices).get(value, value) - def validate(self): """Ensure that all fields' values are valid and that required fields are present. @@ -685,6 +663,33 @@ class BaseDocument(object): elif field.required: raise ValidationError('Field "%s" is required' % field.name) + @apply + def pk(): + """Primary key alias + """ + def fget(self): + return getattr(self, self._meta['id_field']) + def fset(self, value): + return setattr(self, self._meta['id_field'], value) + return property(fget, fset) + + def to_mongo(self): + """Return data dictionary ready for use with MongoDB. + """ + data = {} + for field_name, field in self._fields.items(): + value = getattr(self, field_name, None) + if value is not None: + data[field.db_field] = field.to_mongo(value) + # Only add _cls and _types if allow_inheritance is not False + if not (hasattr(self, '_meta') and + self._meta.get('allow_inheritance', True) == False): + data['_cls'] = self._class_name + data['_types'] = self._superclasses.keys() + [self._class_name] + if '_id' in data and data['_id'] is None: + del data['_id'] + return data + @classmethod def _get_collection_name(cls): """Returns the collection name for this class. @@ -706,76 +711,6 @@ class BaseDocument(object): all_subclasses.update(subclass._get_subclasses()) return all_subclasses - @apply - def pk(): - """Primary key alias - """ - def fget(self): - return getattr(self, self._meta['id_field']) - def fset(self, value): - return setattr(self, self._meta['id_field'], value) - return property(fget, fset) - - def __iter__(self): - return iter(self._fields) - - def __getitem__(self, name): - """Dictionary-style field access, return a field's value if present. - """ - try: - if name in self._fields: - return getattr(self, name) - except AttributeError: - pass - raise KeyError(name) - - def __setitem__(self, name, value): - """Dictionary-style field access, set a field's value. - """ - # Ensure that the field exists before settings its value - if name not in self._fields: - raise KeyError(name) - return setattr(self, name, value) - - def __contains__(self, name): - try: - val = getattr(self, name) - return val is not None - except AttributeError: - return False - - def __len__(self): - return len(self._data) - - def __repr__(self): - try: - u = unicode(self) - except (UnicodeEncodeError, UnicodeDecodeError): - u = '[Bad Unicode data]' - return u'<%s: %s>' % (self.__class__.__name__, u) - - def __str__(self): - if hasattr(self, '__unicode__'): - return unicode(self).encode('utf-8') - return '%s object' % self.__class__.__name__ - - def to_mongo(self): - """Return data dictionary ready for use with MongoDB. - """ - data = {} - for field_name, field in self._fields.items(): - value = getattr(self, field_name, None) - if value is not None: - data[field.db_field] = field.to_mongo(value) - # Only add _cls and _types if allow_inheritance is not False - if not (hasattr(self, '_meta') and - self._meta.get('allow_inheritance', True) == False): - data['_cls'] = self._class_name - data['_types'] = self._superclasses.keys() + [self._class_name] - if '_id' in data and data['_id'] is None: - del data['_id'] - return data - @classmethod def _from_son(cls, son): """Create an instance of a Document (subclass) from a PyMongo SON. @@ -874,6 +809,81 @@ class BaseDocument(object): unset_data[k] = 1 return set_data, unset_data + @classmethod + def _geo_indices(cls): + geo_indices = [] + for field in cls._fields.values(): + if hasattr(field, 'document_type'): + geo_indices += field.document_type._geo_indices() + elif field._geo_index: + geo_indices.append(field) + return geo_indices + + def __getstate__(self): + self_dict = self.__dict__ + removals = ["get_%s_display" % k for k,v in self._fields.items() if v.choices] + for k in removals: + if hasattr(self, k): + delattr(self, k) + return self.__dict__ + + def __setstate__(self, __dict__): + self.__dict__ = __dict__ + self.__set_field_display() + + def __set_field_display(self): + for attr_name, field in self._fields.items(): + if field.choices: # dynamically adds a way to get the display value for a field with choices + setattr(self, 'get_%s_display' % attr_name, partial(self.__get_field_display, field=field)) + + def __get_field_display(self, field): + """Returns the display value for a choice field""" + value = getattr(self, field.name) + return dict(field.choices).get(value, value) + + def __iter__(self): + return iter(self._fields) + + def __getitem__(self, name): + """Dictionary-style field access, return a field's value if present. + """ + try: + if name in self._fields: + return getattr(self, name) + except AttributeError: + pass + raise KeyError(name) + + def __setitem__(self, name, value): + """Dictionary-style field access, set a field's value. + """ + # Ensure that the field exists before settings its value + if name not in self._fields: + raise KeyError(name) + return setattr(self, name, value) + + def __contains__(self, name): + try: + val = getattr(self, name) + return val is not None + except AttributeError: + return False + + def __len__(self): + return len(self._data) + + def __repr__(self): + try: + u = unicode(self) + except (UnicodeEncodeError, UnicodeDecodeError): + u = '[Bad Unicode data]' + return u'<%s: %s>' % (self.__class__.__name__, u) + + def __str__(self): + if hasattr(self, '__unicode__'): + return unicode(self).encode('utf-8') + return '%s object' % self.__class__.__name__ + def __eq__(self, other): if isinstance(other, self.__class__) and hasattr(other, 'id'): if self.id == other.id: diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 2a5d3edb..e2947a00 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -494,12 +494,11 @@ class QuerySet(object): self._collection.ensure_index('_types', background=background, **index_opts) - # Ensure all needed field indexes are created - for field in self._document._fields.values(): - if field.__class__._geo_index: - index_spec = [(field.db_field, pymongo.GEO2D)] - self._collection.ensure_index(index_spec, - background=background, **index_opts) + # Add geo indicies + for field in self._document._geo_indices(): + index_spec = [(field.db_field, pymongo.GEO2D)] + self._collection.ensure_index(index_spec, + background=background, **index_opts) return self._collection_obj diff --git a/tests/fields.py b/tests/fields.py index 22049309..fe53d9e7 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1312,6 +1312,27 @@ class FieldTest(unittest.TestCase): Event.drop_collection() + def test_geo_embedded_indexes(self): + """Ensure that indexes are created automatically for GeoPointFields on + embedded documents. + """ + class Venue(EmbeddedDocument): + location = GeoPointField() + name = StringField() + + class Event(Document): + title = StringField() + venue = EmbeddedDocumentField(Venue) + + Event.drop_collection() + venue = Venue(name="Double Door", location=[41.909889, -87.677137]) + event = Event(title="Coltrane Motion", venue=venue) + event.save() + + info = Event.objects._collection.index_information() + self.assertTrue(u'location_2d' in info) + self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')]) + def test_ensure_unique_default_instances(self): """Ensure that every field has it's own unique default instance.""" class D(Document):