diff --git a/mongoengine/base.py b/mongoengine/base.py index c8c162b4..c1306ff5 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -22,6 +22,7 @@ class BaseField(object): # Fields may have _types inserted into indexes by default _index_with_types = True + _geo_index = False def __init__(self, db_field=None, name=None, required=False, default=None, unique=False, unique_with=None, primary_key=False, validation=None, diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 127f029f..f9fa4dee 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -13,7 +13,7 @@ __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'ObjectIdField', 'ReferenceField', 'ValidationError', 'DecimalField', 'URLField', 'GenericReferenceField', - 'BinaryField', 'SortedListField', 'EmailField', 'GeoLocationField'] + 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -369,23 +369,6 @@ class DictField(BaseField): def lookup_member(self, member_name): return self.basecls(db_field=member_name) -class GeoLocationField(DictField): - """Supports geobased fields""" - - def validate(self, value): - """Make sure that a geo-value is of type (x, y) - """ - if not isinstance(value, tuple) and not isinstance(value, list): - raise ValidationError('GeoLocationField can only hold tuples or lists of (x, y)') - - if len(value) <> 2: - raise ValidationError('GeoLocationField must have exactly two elements (x, y)') - - def to_mongo(self, value): - return {'x': value[0], 'y': value[1]} - - def to_python(self, value): - return value.keys() class ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on @@ -500,6 +483,7 @@ class GenericReferenceField(BaseField): def prepare_query_value(self, op, value): return self.to_mongo(value)['_ref'] + class BinaryField(BaseField): """A binary data field. """ @@ -520,3 +504,24 @@ class BinaryField(BaseField): if self.max_bytes is not None and len(value) > self.max_bytes: raise ValidationError('Binary value is too long') + + +class GeoPointField(BaseField): + """A list storing a latitude and longitude. + """ + + _geo_index = True + + def validate(self, value): + """Make sure that a geo-value is of type (x, y) + """ + if not isinstance(value, (list, tuple)): + raise ValidationError('GeoPointField can only accept tuples or ' + 'lists of (x, y)') + + if not len(value) == 2: + raise ValidationError('Value must be a two-dimensional point.') + if (not isinstance(value[0], (float, int)) and + not isinstance(value[1], (float, int))): + raise ValidationError('Both values in point must be float or int.') + diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 069ab113..5a837c4f 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -241,21 +241,22 @@ class QuerySet(object): # Ensure document-defined indexes are created if self._document._meta['indexes']: for key_or_list in self._document._meta['indexes']: - #self.ensure_index(key_or_list) self._collection.ensure_index(key_or_list) # Ensure indexes created by uniqueness constraints for index in self._document._meta['unique_indexes']: self._collection.ensure_index(index, unique=True) - + # If _types is being used (for polymorphism), it needs an index if '_types' in self._query: self._collection.ensure_index('_types') # Ensure all needed field indexes are created - for field_name, field_instance in self._document._fields.iteritems(): - if field_instance.__class__.__name__ == 'GeoLocationField': - self._collection.ensure_index([(field_name, pymongo.GEO2D),]) + for field in self._document._fields.values(): + if field.__class__._geo_index: + index_spec = [(field.db_field, pymongo.GEO2D)] + self._collection.ensure_index(index_spec) + return self._collection_obj @property @@ -311,9 +312,10 @@ class QuerySet(object): """Transform a query from Django-style format to Mongo format. """ operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', - 'all', 'size', 'exists', 'near'] - match_operators = ['contains', 'icontains', 'startswith', - 'istartswith', 'endswith', 'iendswith', + 'all', 'size', 'exists'] + geo_operators = ['within_distance', 'within_box', 'near'] + match_operators = ['contains', 'icontains', 'startswith', + 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] mongo_query = {} @@ -321,7 +323,7 @@ class QuerySet(object): parts = key.split('__') # Check for an operator and transform to mongo-style if there is op = None - if parts[-1] in operators + match_operators: + if parts[-1] in operators + match_operators + geo_operators: op = parts.pop() if _doc_cls: @@ -335,15 +337,25 @@ class QuerySet(object): singular_ops += match_operators if op in singular_ops: value = field.prepare_query_value(op, value) - elif op in ('in', 'nin', 'all'): + elif op in ('in', 'nin', 'all', 'near'): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] if field.__class__.__name__ == 'GenericReferenceField': parts.append('_ref') - if op and op not in match_operators: - value = {'$' + op: value} + # if op and op not in match_operators: + if op: + if op in geo_operators: + if op == "within_distance": + value = {'$within': {'$center': value}} + elif op == "near": + value = {'$near': value} + else: + raise NotImplementedError("Geo method '%s' has not " + "been implemented" % op) + elif op not in match_operators: + value = {'$' + op: value} key = '.'.join(parts) if op is None or key not in mongo_query: diff --git a/tests/queryset.py b/tests/queryset.py index 51f92993..a7719b88 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1164,6 +1164,77 @@ class QuerySetTest(unittest.TestCase): def tearDown(self): self.Person.drop_collection() + def test_geospatial_operators(self): + """Ensure that geospatial queries are working. + """ + class Event(Document): + title = StringField() + date = DateTimeField() + location = GeoPointField() + + def __unicode__(self): + return self.title + + meta = {'geo_indexes': ["location"]} + + Event.drop_collection() + + event1 = Event(title="Coltrane Motion @ Double Door", + date=datetime.now() - timedelta(days=1), + location=[41.909889, -87.677137]) + event2 = Event(title="Coltrane Motion @ Bottom of the Hill", + date=datetime.now() - timedelta(days=10), + location=[37.7749295, -122.4194155]) + event3 = Event(title="Coltrane Motion @ Empty Bottle", + date=datetime.now(), + location=[41.900474, -87.686638]) + + event1.save() + event2.save() + event3.save() + + # find all events "near" pitchfork office, chicago. + # note that "near" will show the san francisco event, too, + # although it sorts to last. + events = Event.objects(location__near=[41.9120459, -87.67892]) + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event1, event3, event2]) + + # find events within 5 miles of pitchfork office, chicago + point_and_distance = [[41.9120459, -87.67892], 5] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 2) + events = list(events) + self.assertTrue(event2 not in events) + self.assertTrue(event1 in events) + self.assertTrue(event3 in events) + + # ensure ordering is respected by "near" + events = Event.objects(location__near=[41.9120459, -87.67892]) + events = events.order_by("-date") + self.assertEqual(events.count(), 3) + self.assertEqual(list(events), [event3, event1, event2]) + + # find events around san francisco + point_and_distance = [[37.7566023, -122.415579], 10] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 1) + self.assertEqual(events[0], event2) + + # find events within 1 mile of greenpoint, broolyn, nyc, ny + point_and_distance = [[40.7237134, -73.9509714], 1] + events = Event.objects(location__within_distance=point_and_distance) + self.assertEqual(events.count(), 0) + + # ensure ordering is respected by "within_distance" + point_and_distance = [[41.9120459, -87.67892], 10] + events = Event.objects(location__within_distance=point_and_distance) + events = events.order_by("-date") + self.assertEqual(events.count(), 2) + self.assertEqual(events[0], event3) + + Event.drop_collection() + class QTest(unittest.TestCase):