pk as a property with a setter + get rid of basestring

This commit is contained in:
Stefan Wojcik 2016-12-06 23:02:08 -05:00
parent 557f9bd971
commit 7dd4639037
10 changed files with 73 additions and 66 deletions

View File

@ -30,8 +30,6 @@ def get_version():
"""Return the VERSION as a string, e.g. for VERSION == (0, 10, 7), """Return the VERSION as a string, e.g. for VERSION == (0, 10, 7),
return '0.10.7'. return '0.10.7'.
""" """
if isinstance(VERSION[-1], basestring):
return '.'.join(map(str, VERSION[:-1])) + VERSION[-1]
return '.'.join(map(str, VERSION)) return '.'.join(map(str, VERSION))

View File

@ -762,7 +762,7 @@ class BaseDocument(object):
def _build_index_spec(cls, spec): def _build_index_spec(cls, spec):
"""Build a PyMongo index spec from a MongoEngine index spec. """Build a PyMongo index spec from a MongoEngine index spec.
""" """
if isinstance(spec, basestring): if isinstance(spec, six.string_types):
spec = {'fields': [spec]} spec = {'fields': [spec]}
elif isinstance(spec, (list, tuple)): elif isinstance(spec, (list, tuple)):
spec = {'fields': list(spec)} spec = {'fields': list(spec)}
@ -856,7 +856,7 @@ class BaseDocument(object):
# Add any unique_with fields to the back of the index spec # Add any unique_with fields to the back of the index spec
if field.unique_with: if field.unique_with:
if isinstance(field.unique_with, basestring): if isinstance(field.unique_with, six.string_types):
field.unique_with = [field.unique_with] field.unique_with = [field.unique_with]
# Convert unique_with field names to real field names # Convert unique_with field names to real field names

View File

@ -299,7 +299,7 @@ class ComplexBaseField(BaseField):
def to_python(self, value): def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type.
""" """
if isinstance(value, basestring): if isinstance(value, six.string_types):
return value return value
if hasattr(value, 'to_python'): if hasattr(value, 'to_python'):
@ -345,7 +345,7 @@ class ComplexBaseField(BaseField):
EmbeddedDocument = _import_class("EmbeddedDocument") EmbeddedDocument = _import_class("EmbeddedDocument")
GenericReferenceField = _import_class("GenericReferenceField") GenericReferenceField = _import_class("GenericReferenceField")
if isinstance(value, basestring): if isinstance(value, six.string_types):
return value return value
if hasattr(value, 'to_mongo'): if hasattr(value, 'to_mongo'):

View File

@ -160,7 +160,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
# Discard port since it can't be used on MongoReplicaSetClient # Discard port since it can't be used on MongoReplicaSetClient
conn_settings.pop('port', None) conn_settings.pop('port', None)
# Discard replicaSet if not base string # Discard replicaSet if not base string
if not isinstance(conn_settings['replicaSet'], basestring): if not isinstance(conn_settings['replicaSet'], six.string_types):
conn_settings.pop('replicaSet', None) conn_settings.pop('replicaSet', None)
if not IS_PYMONGO_3: if not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient connection_class = MongoReplicaSetClient

View File

@ -25,7 +25,7 @@ class DeReference(object):
:class:`~mongoengine.base.ComplexBaseField` :class:`~mongoengine.base.ComplexBaseField`
:param get: A boolean determining if being called by __get__ :param get: A boolean determining if being called by __get__
""" """
if items is None or isinstance(items, basestring): if items is None or isinstance(items, six.string_types):
return items return items
# cheapest way to convert a queryset to a list # cheapest way to convert a queryset to a list

View File

@ -33,7 +33,7 @@ def includes_cls(fields):
first_field = None first_field = None
if len(fields): if len(fields):
if isinstance(fields[0], basestring): if isinstance(fields[0], six.string_types):
first_field = fields[0] first_field = fields[0]
elif isinstance(fields[0], (list, tuple)) and len(fields[0]): elif isinstance(fields[0], (list, tuple)) and len(fields[0]):
first_field = fields[0][0] first_field = fields[0][0]
@ -146,21 +146,17 @@ class Document(BaseDocument):
__slots__ = ('__objects',) __slots__ = ('__objects',)
def pk(): @property
"""Primary key alias def pk(self):
""" """Get the primary key."""
if 'id_field' not in self._meta:
return None
return getattr(self, self._meta['id_field'])
def fget(self): @pk.setter
if 'id_field' not in self._meta: def pk(self, value):
return None """Set the primary key."""
return getattr(self, self._meta['id_field']) return setattr(self, self._meta['id_field'], value)
def fset(self, value):
return setattr(self, self._meta['id_field'], value)
return property(fget, fset)
pk = pk()
@classmethod @classmethod
def _get_db(cls): def _get_db(cls):
@ -208,7 +204,7 @@ class Document(BaseDocument):
cls.ensure_indexes() cls.ensure_indexes()
return cls._collection return cls._collection
def modify(self, query={}, **update): def modify(self, query=None, **update):
"""Perform an atomic update of the document in the database and reload """Perform an atomic update of the document in the database and reload
the document object using updated version. the document object using updated version.
@ -222,6 +218,8 @@ class Document(BaseDocument):
database matches the query database matches the query
:param update: Django-style update keyword arguments :param update: Django-style update keyword arguments
""" """
if query is None:
query = {}
if self.pk is None: if self.pk is None:
raise InvalidDocumentError("The document does not have a primary key.") raise InvalidDocumentError("The document does not have a primary key.")
@ -412,9 +410,10 @@ class Document(BaseDocument):
self._created = False self._created = False
return self return self
def cascade_save(self, *args, **kwargs): def cascade_save(self, **kwargs):
"""Recursively saves any references / """Recursively save any references and generic references on the
generic references on the document""" document.
"""
_refs = kwargs.get('_refs', []) or [] _refs = kwargs.get('_refs', []) or []
ReferenceField = _import_class('ReferenceField') ReferenceField = _import_class('ReferenceField')
@ -441,16 +440,17 @@ class Document(BaseDocument):
@property @property
def _qs(self): def _qs(self):
""" """Return the queryset to use for updating / reloading / deletions."""
Returns the queryset to use for updating / reloading / deletions
"""
if not hasattr(self, '__objects'): if not hasattr(self, '__objects'):
self.__objects = QuerySet(self, self._get_collection()) self.__objects = QuerySet(self, self._get_collection())
return self.__objects return self.__objects
@property @property
def _object_key(self): def _object_key(self):
"""Dict to identify object in collection """Get the query dict that can be used to fetch this object from
the database. Most of the time it's a simple PK lookup, but in
case of a sharded collection with a compound shard key, it can
contain a more complex query.
""" """
select_dict = {'pk': self.pk} select_dict = {'pk': self.pk}
shard_key = self.__class__._meta.get('shard_key', tuple()) shard_key = self.__class__._meta.get('shard_key', tuple())

View File

@ -78,7 +78,7 @@ class StringField(BaseField):
return value return value
def validate(self, value): def validate(self, value):
if not isinstance(value, basestring): if not isinstance(value, six.string_types):
self.error('StringField only accepts string values') self.error('StringField only accepts string values')
if self.max_length is not None and len(value) > self.max_length: if self.max_length is not None and len(value) > self.max_length:
@ -94,7 +94,7 @@ class StringField(BaseField):
return None return None
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if not isinstance(op, basestring): if not isinstance(op, six.string_types):
return value return value
if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'):
@ -349,7 +349,7 @@ class DecimalField(BaseField):
def validate(self, value): def validate(self, value):
if not isinstance(value, decimal.Decimal): if not isinstance(value, decimal.Decimal):
if not isinstance(value, basestring): if not isinstance(value, six.string_types):
value = six.text_type(value) value = six.text_type(value)
try: try:
value = decimal.Decimal(value) value = decimal.Decimal(value)
@ -413,7 +413,7 @@ class DateTimeField(BaseField):
if callable(value): if callable(value):
return value() return value()
if not isinstance(value, basestring): if not isinstance(value, six.string_types):
return None return None
# Attempt to parse a datetime: # Attempt to parse a datetime:
@ -540,16 +540,19 @@ class EmbeddedDocumentField(BaseField):
""" """
def __init__(self, document_type, **kwargs): def __init__(self, document_type, **kwargs):
if not isinstance(document_type, basestring): if (
if not issubclass(document_type, EmbeddedDocument): not isinstance(document_type, six.string_types) and
self.error('Invalid embedded document class provided to an ' not issubclass(document_type, EmbeddedDocument)
'EmbeddedDocumentField') ):
self.error('Invalid embedded document class provided to an '
'EmbeddedDocumentField')
self.document_type_obj = document_type self.document_type_obj = document_type
super(EmbeddedDocumentField, self).__init__(**kwargs) super(EmbeddedDocumentField, self).__init__(**kwargs)
@property @property
def document_type(self): def document_type(self):
if isinstance(self.document_type_obj, basestring): if isinstance(self.document_type_obj, six.string_types):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document self.document_type_obj = self.owner_document
else: else:
@ -634,7 +637,7 @@ class DynamicField(BaseField):
"""Convert a Python type to a MongoDB compatible type. """Convert a Python type to a MongoDB compatible type.
""" """
if isinstance(value, basestring): if isinstance(value, six.string_types):
return value return value
if hasattr(value, 'to_mongo'): if hasattr(value, 'to_mongo'):
@ -677,7 +680,7 @@ class DynamicField(BaseField):
return member_name return member_name
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if isinstance(value, basestring): if isinstance(value, six.string_types):
return StringField().prepare_query_value(op, value) return StringField().prepare_query_value(op, value)
return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value)) return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value))
@ -705,14 +708,14 @@ class ListField(ComplexBaseField):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used.
""" """
if (not isinstance(value, (list, tuple, QuerySet)) or if (not isinstance(value, (list, tuple, QuerySet)) or
isinstance(value, basestring)): isinstance(value, six.string_types)):
self.error('Only lists and tuples may be used in a list field') self.error('Only lists and tuples may be used in a list field')
super(ListField, self).validate(value) super(ListField, self).validate(value)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
if self.field: if self.field:
if op in ('set', 'unset', None) and ( if op in ('set', 'unset', None) and (
not isinstance(value, basestring) and not isinstance(value, six.string_types) and
not isinstance(value, BaseDocument) and not isinstance(value, BaseDocument) and
hasattr(value, '__iter__')): hasattr(value, '__iter__')):
return [self.field.prepare_query_value(op, v) for v in value] return [self.field.prepare_query_value(op, v) for v in value]
@ -782,7 +785,7 @@ def key_not_string(d):
not a string. not a string.
""" """
for k, v in d.items(): for k, v in d.items():
if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)): if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)):
return True return True
@ -838,7 +841,7 @@ class DictField(ComplexBaseField):
'istartswith', 'endswith', 'iendswith', 'istartswith', 'endswith', 'iendswith',
'exact', 'iexact'] 'exact', 'iexact']
if op in match_operators and isinstance(value, basestring): if op in match_operators and isinstance(value, six.string_types):
return StringField().prepare_query_value(op, value) return StringField().prepare_query_value(op, value)
if hasattr(self.field, 'field'): if hasattr(self.field, 'field'):
@ -914,10 +917,12 @@ class ReferenceField(BaseField):
A reference to an abstract document type is always stored as a A reference to an abstract document type is always stored as a
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
""" """
if not isinstance(document_type, basestring): if (
if not issubclass(document_type, (Document, basestring)): not isinstance(document_type, six.string_types) and
self.error('Argument to ReferenceField constructor must be a ' not issubclass(document_type, Document)
'document class or a string') ):
self.error('Argument to ReferenceField constructor must be a '
'document class or a string')
self.dbref = dbref self.dbref = dbref
self.document_type_obj = document_type self.document_type_obj = document_type
@ -926,7 +931,7 @@ class ReferenceField(BaseField):
@property @property
def document_type(self): def document_type(self):
if isinstance(self.document_type_obj, basestring): if isinstance(self.document_type_obj, six.string_types):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document self.document_type_obj = self.owner_document
else: else:
@ -1039,8 +1044,10 @@ class CachedReferenceField(BaseField):
:param fields: A list of fields to be cached in document :param fields: A list of fields to be cached in document
:param auto_sync: if True documents are auto updated. :param auto_sync: if True documents are auto updated.
""" """
if not isinstance(document_type, basestring) and \ if (
not issubclass(document_type, (Document, basestring)): not isinstance(document_type, six.string_types) and
not issubclass(document_type, Document)
):
self.error('Argument to CachedReferenceField constructor must be a' self.error('Argument to CachedReferenceField constructor must be a'
' document class or a string') ' document class or a string')
@ -1080,7 +1087,7 @@ class CachedReferenceField(BaseField):
@property @property
def document_type(self): def document_type(self):
if isinstance(self.document_type_obj, basestring): if isinstance(self.document_type_obj, six.string_types):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document self.document_type_obj = self.owner_document
else: else:
@ -1194,13 +1201,13 @@ class GenericReferenceField(BaseField):
# Keep the choices as a list of allowed Document class names # Keep the choices as a list of allowed Document class names
if choices: if choices:
for choice in choices: for choice in choices:
if isinstance(choice, basestring): if isinstance(choice, six.string_types):
self.choices.append(choice) self.choices.append(choice)
elif isinstance(choice, type) and issubclass(choice, Document): elif isinstance(choice, type) and issubclass(choice, Document):
self.choices.append(choice._class_name) self.choices.append(choice._class_name)
else: else:
self.error('Invalid choices provided: must be a list of' self.error('Invalid choices provided: must be a list of'
'Document subclasses and/or basestrings') 'Document subclasses and/or six.string_typess')
def _validate_choices(self, value): def _validate_choices(self, value):
if isinstance(value, dict): if isinstance(value, dict):
@ -1866,7 +1873,7 @@ class UUIDField(BaseField):
if not self._binary: if not self._binary:
original_value = value original_value = value
try: try:
if not isinstance(value, basestring): if not isinstance(value, six.string_types):
value = six.text_type(value) value = six.text_type(value)
return uuid.UUID(value) return uuid.UUID(value)
except Exception: except Exception:
@ -1876,7 +1883,7 @@ class UUIDField(BaseField):
def to_mongo(self, value): def to_mongo(self, value):
if not self._binary: if not self._binary:
return six.text_type(value) return six.text_type(value)
elif isinstance(value, basestring): elif isinstance(value, six.string_types):
return uuid.UUID(value) return uuid.UUID(value)
return value return value
@ -1887,7 +1894,7 @@ class UUIDField(BaseField):
def validate(self, value): def validate(self, value):
if not isinstance(value, uuid.UUID): if not isinstance(value, uuid.UUID):
if not isinstance(value, basestring): if not isinstance(value, six.string_types):
value = str(value) value = str(value)
try: try:
uuid.UUID(value) uuid.UUID(value)

View File

@ -1188,7 +1188,7 @@ class BaseQuerySet(object):
else: else:
map_reduce_function = 'map_reduce' map_reduce_function = 'map_reduce'
if isinstance(output, basestring): if isinstance(output, six.string_types):
mr_args['out'] = output mr_args['out'] = output
elif isinstance(output, dict): elif isinstance(output, dict):

View File

@ -2,6 +2,7 @@ from collections import defaultdict
from bson import SON from bson import SON
import pymongo import pymongo
import six
from mongoengine.base.fields import UPDATE_OPERATORS from mongoengine.base.fields import UPDATE_OPERATORS
from mongoengine.common import _import_class from mongoengine.common import _import_class
@ -66,7 +67,7 @@ def query(_doc_cls=None, **kwargs):
cleaned_fields = [] cleaned_fields = []
for field in fields: for field in fields:
append_field = True append_field = True
if isinstance(field, basestring): if isinstance(field, six.string_types):
parts.append(field) parts.append(field)
append_field = False append_field = False
# is last and CachedReferenceField # is last and CachedReferenceField
@ -84,9 +85,9 @@ def query(_doc_cls=None, **kwargs):
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops += STRING_OPERATORS singular_ops += STRING_OPERATORS
if op in singular_ops: if op in singular_ops:
if isinstance(field, basestring): if isinstance(field, six.string_types):
if (op in STRING_OPERATORS and if (op in STRING_OPERATORS and
isinstance(value, basestring)): isinstance(value, six.string_types)):
StringField = _import_class('StringField') StringField = _import_class('StringField')
value = StringField.prepare_query_value(op, value) value = StringField.prepare_query_value(op, value)
else: else:
@ -231,7 +232,7 @@ def update(_doc_cls=None, **update):
appended_sub_field = False appended_sub_field = False
for field in fields: for field in fields:
append_field = True append_field = True
if isinstance(field, basestring): if isinstance(field, six.string_types):
# Convert the S operator to $ # Convert the S operator to $
if field == 'S': if field == 'S':
field = '$' field = '$'

View File

@ -21,8 +21,9 @@ except Exception:
def get_version(version_tuple): def get_version(version_tuple):
if not isinstance(version_tuple[-1], int): """Return the version tuple as a string, e.g. for (0, 10, 7),
return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] return '0.10.7'.
"""
return '.'.join(map(str, version_tuple)) return '.'.join(map(str, version_tuple))