373 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			373 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from collections import defaultdict
 | |
| 
 | |
| import pymongo
 | |
| from bson import SON
 | |
| 
 | |
| from mongoengine.base.fields import UPDATE_OPERATORS
 | |
| from mongoengine.connection import get_connection
 | |
| from mongoengine.common import _import_class
 | |
| from mongoengine.errors import InvalidQueryError, LookUpError
 | |
| 
 | |
| __all__ = ('query', 'update')
 | |
| 
 | |
| 
 | |
# Operator suffixes recognised at the end of a Django-style lookup key
# (``field__<operator>=value``), grouped by how they are translated.

# Comparison operators.
COMPARISON_OPERATORS = (
    'ne', 'gt', 'gte', 'lt', 'lte',
    'in', 'nin', 'mod', 'all', 'size',
    'exists', 'not', 'elemMatch', 'type',
)

# Geospatial operators.
GEO_OPERATORS = (
    'within_distance',
    'within_spherical_distance',
    'within_box',
    'within_polygon',
    'near',
    'near_sphere',
    'max_distance',
    'geo_within',
    'geo_within_box',
    'geo_within_polygon',
    'geo_within_center',
    'geo_within_sphere',
    'geo_intersects',
)

# String matching operators (the ``i`` prefixed names are the
# case-insensitive variants -- presumably; confirm against StringField).
STRING_OPERATORS = (
    'contains', 'icontains',
    'startswith', 'istartswith',
    'endswith', 'iendswith',
    'exact', 'iexact',
)

# Operators with bespoke handling.
CUSTOM_OPERATORS = ('match',)

# Every suffix that is treated as an operator rather than a field name.
MATCH_OPERATORS = (
    COMPARISON_OPERATORS
    + GEO_OPERATORS
    + STRING_OPERATORS
    + CUSTOM_OPERATORS
)
 | |
| 
 | |
| 
 | |
def query(_doc_cls=None, _field_operation=False, **query):
    """Transform a query from Django-style format to Mongo format.

    Each keyword is split on ``__``.  Purely numeric parts are set aside
    (and re-inserted at their original positions when the dotted key is
    built), a trailing part found in ``MATCH_OPERATORS`` is taken as the
    operator, and a ``not`` part directly before the operator negates the
    condition by wrapping it in ``$not``.

    :param _doc_cls: optional document class.  When given, its
        ``_lookup_field`` resolves the parts to field objects so that
        ``db_field`` names are used in the key and values are converted
        with the field's ``prepare_query_value``.
    :param _field_operation: accepted for interface compatibility; it is
        not referenced anywhere in this function.
    :param query: the Django-style lookups.  The special key ``__raw__``
        is merged into the result unchanged.
    :returns: a dict suitable for use as a MongoDB query document.
        Conditions on the same key that cannot be merged into one dict
        are moved into a top-level ``$and`` list.
    :raises InvalidQueryError: if ``_doc_cls._lookup_field`` raises.
    :raises NotImplementedError: for an unsupported custom or geo
        operator.
    """
    mongo_query = {}
    merge_query = defaultdict(list)
    for key, value in sorted(query.items()):
        if key == "__raw__":
            mongo_query.update(value)
            continue

        parts = key.split('__')
        # Remember where the numeric (list index) parts were so they can be
        # put back after the field lookup has replaced the other parts.
        indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
        parts = [part for part in parts if not part.isdigit()]

        # Check for an operator and transform to mongo-style if there is
        op = None
        if len(parts) > 1 and parts[-1] in MATCH_OPERATORS:
            op = parts.pop()

        # Drop a trailing empty part (left behind by a doubled ``__``
        # separator or by a key that ends in ``__``).
        if len(parts) > 1 and not parts[-1]:
            parts.pop()

        negate = False
        if len(parts) > 1 and parts[-1] == 'not':
            parts.pop()
            negate = True

        if _doc_cls:
            # Switch field names to proper names [set in Field(name='foo')]
            try:
                fields = _doc_cls._lookup_field(parts)
            except Exception as e:
                raise InvalidQueryError(e)
            parts = []

            CachedReferenceField = _import_class('CachedReferenceField')

            # ``fields`` mixes plain strings with field objects; only the
            # field objects are kept for value conversion.
            cleaned_fields = []
            for field in fields:
                if isinstance(field, basestring):
                    parts.append(field)
                    continue
                if (isinstance(field, CachedReferenceField) and
                        fields[-1] == field):
                    # A trailing CachedReferenceField is matched on its _id
                    parts.append('%s._id' % field.db_field)
                else:
                    parts.append(field.db_field)
                cleaned_fields.append(field)

            # Convert value to proper value
            field = cleaned_fields[-1]

            singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
            singular_ops += STRING_OPERATORS
            if op in singular_ops:
                value = field.prepare_query_value(op, value)

                if isinstance(field, CachedReferenceField) and value:
                    value = value['_id']

            elif (op in ('in', 'nin', 'all', 'near') and
                    not isinstance(value, dict)):
                # 'in', 'nin' and 'all' require a list of values
                value = [field.prepare_query_value(op, v) for v in value]

        if op:
            # NOTE(review): ``field`` is only bound inside the ``_doc_cls``
            # branch above, so geo and custom operators presumably require
            # a document class -- confirm with callers.
            if op in GEO_OPERATORS:
                value = _geo_operator(field, op, value)
            elif op in CUSTOM_OPERATORS:
                if op in ('elem_match', 'match'):
                    value = field.prepare_query_value(op, value)
                    value = {"$elemMatch": value}
                else:
                    raise NotImplementedError("Custom method '%s' has not "
                                              "been implemented" % op)
            elif op not in STRING_OPERATORS:
                value = {'$' + op: value}

        if negate:
            value = {'$not': value}

        for i, part in indices:
            parts.insert(i, part)
        key = '.'.join(parts)

        if op is None or key not in mongo_query:
            mongo_query[key] = value
        elif isinstance(mongo_query[key], dict):
            mongo_query[key].update(value)
            # $maxDistance needs to come last - convert to SON
            value_dict = mongo_query[key]
            if '$maxDistance' in value_dict and '$near' in value_dict:
                max_distance = value_dict['$maxDistance']
                value_son = SON()
                for k, v in value_dict.items():
                    if k != '$maxDistance':
                        value_son[k] = v

                near = value_son['$near']
                if (isinstance(near, dict) and
                        not (get_connection().max_wire_version <= 1)):
                    # Dict-valued $near on a connection whose max wire
                    # version is above 1: $maxDistance goes inside $near.
                    value_son['$near'] = SON(near)
                    value_son['$near']['$maxDistance'] = max_distance
                else:
                    # Otherwise $maxDistance stays a sibling of $near,
                    # placed last.
                    value_son['$maxDistance'] = max_distance

                mongo_query[key] = value_son
        else:
            # Store for manually merging later
            merge_query[key].append(value)

    # The queryset has been filtered in such a way we must manually merge:
    # every condition collected for a key, plus the one still stored in
    # mongo_query, becomes its own clause of a top-level $and.
    for k, v in merge_query.items():
        v.append(mongo_query.pop(k))
        value = [{k: val} for val in v]
        if '$and' in mongo_query:
            mongo_query['$and'].extend(value)
        else:
            mongo_query['$and'] = value

    return mongo_query
 | |
| 
 | |
| 
 | |
def update(_doc_cls=None, **update):
    """Transform an update spec from Django-style format to Mongo format.

    Each keyword has the form ``<op>__<field path>[__<comparison>]``.
    When the key has fewer than three parts and does not start with a
    name from ``UPDATE_OPERATORS``, ``set`` is assumed.

    :param _doc_cls: optional document class.  When given, its
        ``_lookup_field`` resolves the path to field objects so that
        ``db_field`` names are used and values are converted with the
        field's ``prepare_query_value``.
    :param update: the Django-style update spec.  The special key
        ``__raw__`` is merged into the result unchanged.
    :returns: a dict mapping ``$<operator>`` to ``{key: value}`` dicts.
    :raises InvalidQueryError: if the field lookup raises, if no
        operator can be determined, or if ``pullAll`` is used with a
        dotted key.
    """
    mongo_update = {}
    for key, value in update.items():
        if key == "__raw__":
            mongo_update.update(value)
            continue
        parts = key.split('__')
        # if there is no operator, default to "set"
        if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
            parts.insert(0, 'set')
        # Check for an operator and transform to mongo-style if there is
        op = None
        if parts[0] in UPDATE_OPERATORS:
            op = parts.pop(0)
            # Convert Pythonic names to Mongo equivalents
            if op in ('push_all', 'pull_all'):
                op = op.replace('_all', 'All')
            elif op == 'dec':
                # Support decrement by flipping a positive value's sign
                # and using 'inc'
                op = 'inc'
                if value > 0:
                    value = -value
            elif op == 'add_to_set':
                op = 'addToSet'
            elif op == 'set_on_insert':
                op = "setOnInsert"

        # A trailing comparison operator wraps the value as {'$<match>': v}
        match = None
        if parts[-1] in COMPARISON_OPERATORS:
            match = parts.pop()

        if _doc_cls:
            # Switch field names to proper names [set in Field(name='foo')]
            try:
                fields = _doc_cls._lookup_field(parts)
            except Exception as e:
                raise InvalidQueryError(e)
            parts = []

            cleaned_fields = []
            # True when the last entry of cleaned_fields is the inner
            # ``.field`` of the field object appended just before it.
            appended_sub_field = False
            for field in fields:
                if isinstance(field, basestring):
                    # Convert the S operator to $
                    if field == 'S':
                        field = '$'
                    parts.append(field)
                    continue

                parts.append(field.db_field)
                cleaned_fields.append(field)
                appended_sub_field = hasattr(field, 'field')
                if appended_sub_field:
                    cleaned_fields.append(field.field)

            # Convert value to proper value, using the outer field when an
            # inner ``.field`` was appended after it.
            if appended_sub_field:
                field = cleaned_fields[-2]
            else:
                field = cleaned_fields[-1]

            GeoJsonBaseField = _import_class("GeoJsonBaseField")
            if isinstance(field, GeoJsonBaseField):
                value = field.to_mongo(value)

            if op in (None, 'set', 'push', 'pull'):
                if field.required or value is not None:
                    value = field.prepare_query_value(op, value)
            elif op in ('pushAll', 'pullAll'):
                value = [field.prepare_query_value(op, v) for v in value]
            elif op in ('addToSet', 'setOnInsert'):
                if isinstance(value, (list, tuple, set)):
                    value = [field.prepare_query_value(op, v) for v in value]
                elif field.required or value is not None:
                    value = field.prepare_query_value(op, value)
            elif op == "unset":
                value = 1

        if match:
            match = '$' + match
            value = {match: value}

        key = '.'.join(parts)

        if not op:
            raise InvalidQueryError("Updates must supply an operation "
                                    "eg: set__FIELD=value")

        if 'pull' in op and '.' in key:
            # Dot operators don't work on pull operations
            # unless they point to a list field
            # Otherwise it uses nested dict syntax
            if op == 'pullAll':
                raise InvalidQueryError("pullAll operations only support "
                                        "a single field depth")

            # Look for the last list field and use dot notation until there
            # NOTE(review): ``cleaned_fields`` is only bound when
            # ``_doc_cls`` is given -- confirm callers always pass it here.
            field_classes = [c.__class__ for c in cleaned_fields]
            field_classes.reverse()
            ListField = _import_class('ListField')
            if ListField in field_classes:
                # Join all fields via dot notation to the last ListField
                # Then process as normal
                last_listField = len(
                    cleaned_fields) - field_classes.index(ListField)
                key = ".".join(parts[:last_listField])
                parts = parts[last_listField:]
                parts.insert(0, key)

            # Nest the value innermost-first: {p0: {p1: ... {pn: value}}}
            parts.reverse()
            for part in parts:
                value = {part: value}
        elif op == 'addToSet' and isinstance(value, list):
            value = {key: {"$each": value}}
        else:
            value = {key: value}
        key = '$' + op

        if key not in mongo_update:
            mongo_update[key] = value
        elif isinstance(mongo_update[key], dict):
            mongo_update[key].update(value)

    return mongo_update
 | |
| 
 | |
| 
 | |
def _geo_operator(field, op, value):
    """Build the Mongo query fragment for a geo operator.

    Fields whose ``_geo_index`` equals ``pymongo.GEO2D`` accept the
    ``within_*`` / ``near`` / ``near_sphere`` / ``max_distance`` family;
    every other field accepts the ``geo_*`` family plus ``near`` and
    ``max_distance``.

    :param field: the field being queried; only ``_geo_index`` (and
        ``_name`` for the error message) are read.
    :param op: the geo operator name.
    :param value: the operand supplied by the caller.
    :returns: a dict holding the translated operator and operand.
    :raises NotImplementedError: if ``op`` is not supported for the
        field's index type.
    """
    if field._geo_index == pymongo.GEO2D:
        # Operators expressed as {'$within': {<shape>: value}}
        within_shapes = {
            'within_distance': '$center',
            'within_spherical_distance': '$centerSphere',
            'within_polygon': '$polygon',
            'within_box': '$box',
        }
        # Operators that map straight onto a single Mongo operator
        direct = {
            'near': '$near',
            'near_sphere': '$nearSphere',
            'max_distance': '$maxDistance',
        }
        if op in within_shapes:
            return {'$within': {within_shapes[op]: value}}
        if op in direct:
            return {direct[op]: value}
        raise NotImplementedError("Geo method '%s' has not "
                                  "been implemented for a GeoPointField" % op)

    # Operators whose operand is first normalised by _infer_geometry
    geometry_ops = {
        'geo_within': '$geoWithin',
        'geo_intersects': '$geoIntersects',
        'near': '$near',
    }
    # Operators expressed as {'$geoWithin': {<shape>: value}}
    geo_within_shapes = {
        'geo_within_box': '$box',
        'geo_within_polygon': '$polygon',
        'geo_within_center': '$center',
        'geo_within_sphere': '$centerSphere',
    }
    if op in geometry_ops:
        return {geometry_ops[op]: _infer_geometry(value)}
    if op in geo_within_shapes:
        return {'$geoWithin': {geo_within_shapes[op]: value}}
    if op == 'max_distance':
        return {'$maxDistance': value}
    raise NotImplementedError("Geo method '%s' has not "
                              "been implemented for a %s " % (op, field._name))
 | |
| 
 | |
| 
 | |
| def _infer_geometry(value):
 | |
|     """Helper method that tries to infer the $geometry shape for a given value"""
 | |
|     if isinstance(value, dict):
 | |
|         if "$geometry" in value:
 | |
|             return value
 | |
|         elif 'coordinates' in value and 'type' in value:
 | |
|             return {"$geometry": value}
 | |
|         raise InvalidQueryError("Invalid $geometry dictionary should have "
 | |
|                                 "type and coordinates keys")
 | |
|     elif isinstance(value, (list, set)):
 | |
|         try:
 | |
|             value[0][0][0]
 | |
|             return {"$geometry": {"type": "Polygon", "coordinates": value}}
 | |
|         except:
 | |
|             pass
 | |
|         try:
 | |
|             value[0][0]
 | |
|             return {"$geometry": {"type": "LineString", "coordinates": value}}
 | |
|         except:
 | |
|             pass
 | |
|         try:
 | |
|             value[0]
 | |
|             return {"$geometry": {"type": "Point", "coordinates": value}}
 | |
|         except:
 | |
|             pass
 | |
| 
 | |
|     raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary "
 | |
|                             "or (nested) lists of coordinate(s)")
 |