Merge pull request #2069 from bagerard/some_refactoring

minor refactoring and additional of tests
This commit is contained in:
erdenezul 2019-06-05 22:30:09 +02:00 committed by GitHub
commit 072e86a2f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 94 additions and 32 deletions

View File

@ -37,6 +37,7 @@ from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError
from mongoengine.python_support import StringIO from mongoengine.python_support import StringIO
from mongoengine.queryset import DO_NOTHING from mongoengine.queryset import DO_NOTHING
from mongoengine.queryset.base import BaseQuerySet from mongoengine.queryset.base import BaseQuerySet
from mongoengine.queryset.transform import STRING_OPERATORS
try: try:
from PIL import Image, ImageOps from PIL import Image, ImageOps
@ -106,11 +107,11 @@ class StringField(BaseField):
if not isinstance(op, six.string_types): if not isinstance(op, six.string_types):
return value return value
if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): if op in STRING_OPERATORS:
flags = 0 case_insensitive = op.startswith('i')
if op.startswith('i'): op = op.lstrip('i')
flags = re.IGNORECASE
op = op.lstrip('i') flags = re.IGNORECASE if case_insensitive else 0
regex = r'%s' regex = r'%s'
if op == 'startswith': if op == 'startswith':
@ -497,15 +498,18 @@ class DateTimeField(BaseField):
if not isinstance(value, six.string_types): if not isinstance(value, six.string_types):
return None return None
return self._parse_datetime(value)
def _parse_datetime(self, value):
# Attempt to parse a datetime from a string
value = value.strip() value = value.strip()
if not value: if not value:
return None return None
# Attempt to parse a datetime:
if dateutil: if dateutil:
try: try:
return dateutil.parser.parse(value) return dateutil.parser.parse(value)
except (TypeError, ValueError): except (TypeError, ValueError, OverflowError):
return None return None
# split usecs, because they are not recognized by strptime. # split usecs, because they are not recognized by strptime.

View File

@ -6,6 +6,7 @@ from mongoengine.connection import get_connection
# Constant that can be used to compare the version retrieved with # Constant that can be used to compare the version retrieved with
# get_mongodb_version() # get_mongodb_version()
MONGODB_34 = (3, 4)
MONGODB_36 = (3, 6) MONGODB_36 = (3, 6)

View File

@ -12,6 +12,7 @@ from bson import DBRef, ObjectId
from pymongo.errors import DuplicateKeyError from pymongo.errors import DuplicateKeyError
from six import iteritems from six import iteritems
from mongoengine.mongodb_support import get_mongodb_version, MONGODB_36, MONGODB_34
from mongoengine.pymongo_support import list_collection_names from mongoengine.pymongo_support import list_collection_names
from tests import fixtures from tests import fixtures
from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest, from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
@ -466,7 +467,16 @@ class InstanceTest(MongoDBTestCase):
Animal.drop_collection() Animal.drop_collection()
doc = Animal(superphylum='Deuterostomia') doc = Animal(superphylum='Deuterostomia')
doc.save() doc.save()
doc.reload()
mongo_db = get_mongodb_version()
CMD_QUERY_KEY = 'command' if mongo_db >= MONGODB_36 else 'query'
with query_counter() as q:
doc.reload()
query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0]
self.assertEqual(set(query_op[CMD_QUERY_KEY]['filter'].keys()), set(['_id', 'superphylum']))
Animal.drop_collection()
def test_reload_sharded_nested(self): def test_reload_sharded_nested(self):
class SuperPhylum(EmbeddedDocument): class SuperPhylum(EmbeddedDocument):
@ -480,6 +490,34 @@ class InstanceTest(MongoDBTestCase):
doc = Animal(superphylum=SuperPhylum(name='Deuterostomia')) doc = Animal(superphylum=SuperPhylum(name='Deuterostomia'))
doc.save() doc.save()
doc.reload() doc.reload()
Animal.drop_collection()
def test_update_shard_key_routing(self):
"""Ensures updating a doc with a specified shard_key includes it in
the query.
"""
class Animal(Document):
is_mammal = BooleanField()
name = StringField()
meta = {'shard_key': ('is_mammal', 'id')}
Animal.drop_collection()
doc = Animal(is_mammal=True, name='Dog')
doc.save()
mongo_db = get_mongodb_version()
with query_counter() as q:
doc.name = 'Cat'
doc.save()
query_op = q.db.system.profile.find({'ns': 'mongoenginetest.animal'})[0]
self.assertEqual(query_op['op'], 'update')
if mongo_db == MONGODB_34:
self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal']))
else:
self.assertEqual(set(query_op['command']['q'].keys()), set(['_id', 'is_mammal']))
Animal.drop_collection()
def test_reload_with_changed_fields(self): def test_reload_with_changed_fields(self):
"""Ensures reloading will not affect changed fields""" """Ensures reloading will not affect changed fields"""

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import datetime import datetime as dt
import six import six
try: try:
@ -41,13 +41,13 @@ class TestDateTimeField(MongoDBTestCase):
a document. a document.
""" """
class Person(Document): class Person(Document):
created = DateTimeField(default=datetime.datetime.utcnow) created = DateTimeField(default=dt.datetime.utcnow)
utcnow = datetime.datetime.utcnow() utcnow = dt.datetime.utcnow()
person = Person() person = Person()
person.validate() person.validate()
person_created_t0 = person.created person_created_t0 = person.created
self.assertLess(person.created - utcnow, datetime.timedelta(seconds=1)) self.assertLess(person.created - utcnow, dt.timedelta(seconds=1))
self.assertEqual(person_created_t0, person.created) # make sure it does not change self.assertEqual(person_created_t0, person.created) # make sure it does not change
self.assertEqual(person._data['created'], person.created) self.assertEqual(person._data['created'], person.created)
@ -65,15 +65,15 @@ class TestDateTimeField(MongoDBTestCase):
# Test can save dates # Test can save dates
log = LogEntry() log = LogEntry()
log.date = datetime.date.today() log.date = dt.date.today()
log.save() log.save()
log.reload() log.reload()
self.assertEqual(log.date.date(), datetime.date.today()) self.assertEqual(log.date.date(), dt.date.today())
# Post UTC - microseconds are rounded (down) nearest millisecond and # Post UTC - microseconds are rounded (down) nearest millisecond and
# dropped # dropped
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 999)
d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) d2 = dt.datetime(1970, 1, 1, 0, 0, 1)
log = LogEntry() log = LogEntry()
log.date = d1 log.date = d1
log.save() log.save()
@ -82,8 +82,8 @@ class TestDateTimeField(MongoDBTestCase):
self.assertEqual(log.date, d2) self.assertEqual(log.date, d2)
# Post UTC - microseconds are rounded (down) nearest millisecond # Post UTC - microseconds are rounded (down) nearest millisecond
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) d1 = dt.datetime(1970, 1, 1, 0, 0, 1, 9999)
d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) d2 = dt.datetime(1970, 1, 1, 0, 0, 1, 9000)
log.date = d1 log.date = d1
log.save() log.save()
log.reload() log.reload()
@ -93,8 +93,8 @@ class TestDateTimeField(MongoDBTestCase):
if not six.PY3: if not six.PY3:
# Pre UTC dates microseconds below 1000 are dropped # Pre UTC dates microseconds below 1000 are dropped
# This does not seem to be true in PY3 # This does not seem to be true in PY3
d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) d1 = dt.datetime(1969, 12, 31, 23, 59, 59, 999)
d2 = datetime.datetime(1969, 12, 31, 23, 59, 59) d2 = dt.datetime(1969, 12, 31, 23, 59, 59)
log.date = d1 log.date = d1
log.save() log.save()
log.reload() log.reload()
@ -108,7 +108,7 @@ class TestDateTimeField(MongoDBTestCase):
LogEntry.drop_collection() LogEntry.drop_collection()
d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) d1 = dt.datetime(1970, 1, 1, 0, 0, 1)
log = LogEntry() log = LogEntry()
log.date = d1 log.date = d1
log.validate() log.validate()
@ -124,7 +124,7 @@ class TestDateTimeField(MongoDBTestCase):
# create additional 19 log entries for a total of 20 # create additional 19 log entries for a total of 20
for i in range(1971, 1990): for i in range(1971, 1990):
d = datetime.datetime(i, 1, 1, 0, 0, 1) d = dt.datetime(i, 1, 1, 0, 0, 1)
LogEntry(date=d).save() LogEntry(date=d).save()
self.assertEqual(LogEntry.objects.count(), 20) self.assertEqual(LogEntry.objects.count(), 20)
@ -143,15 +143,15 @@ class TestDateTimeField(MongoDBTestCase):
i += 1 i += 1
# Test searching # Test searching
logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) logs = LogEntry.objects.filter(date__gte=dt.datetime(1980, 1, 1))
self.assertEqual(logs.count(), 10) self.assertEqual(logs.count(), 10)
logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) logs = LogEntry.objects.filter(date__lte=dt.datetime(1980, 1, 1))
self.assertEqual(logs.count(), 10) self.assertEqual(logs.count(), 10)
logs = LogEntry.objects.filter( logs = LogEntry.objects.filter(
date__lte=datetime.datetime(1980, 1, 1), date__lte=dt.datetime(1980, 1, 1),
date__gte=datetime.datetime(1975, 1, 1), date__gte=dt.datetime(1975, 1, 1),
) )
self.assertEqual(logs.count(), 5) self.assertEqual(logs.count(), 5)
@ -163,20 +163,20 @@ class TestDateTimeField(MongoDBTestCase):
time = DateTimeField() time = DateTimeField()
log = LogEntry() log = LogEntry()
log.time = datetime.datetime.now() log.time = dt.datetime.now()
log.validate() log.validate()
log.time = datetime.date.today() log.time = dt.date.today()
log.validate() log.validate()
log.time = datetime.datetime.now().isoformat(' ') log.time = dt.datetime.now().isoformat(' ')
log.validate() log.validate()
log.time = '2019-05-16 21:42:57.897847' log.time = '2019-05-16 21:42:57.897847'
log.validate() log.validate()
if dateutil: if dateutil:
log.time = datetime.datetime.now().isoformat('T') log.time = dt.datetime.now().isoformat('T')
log.validate() log.validate()
log.time = -1 log.time = -1
@ -190,6 +190,25 @@ class TestDateTimeField(MongoDBTestCase):
log.time = '2019-05-16 21:42:57.123.456' log.time = '2019-05-16 21:42:57.123.456'
self.assertRaises(ValidationError, log.validate) self.assertRaises(ValidationError, log.validate)
def test_parse_datetime_as_str(self):
class DTDoc(Document):
date = DateTimeField()
date_str = '2019-03-02 22:26:01'
# make sure that passing a parsable datetime works
dtd = DTDoc()
dtd.date = date_str
self.assertIsInstance(dtd.date, six.string_types)
dtd.save()
dtd.reload()
self.assertIsInstance(dtd.date, dt.datetime)
self.assertEqual(str(dtd.date), date_str)
dtd.date = 'January 1st, 9999999999'
self.assertRaises(ValidationError, dtd.validate)
class TestDateTimeTzAware(MongoDBTestCase): class TestDateTimeTzAware(MongoDBTestCase):
def test_datetime_tz_aware_mark_as_changed(self): def test_datetime_tz_aware_mark_as_changed(self):
@ -205,8 +224,8 @@ class TestDateTimeTzAware(MongoDBTestCase):
LogEntry.drop_collection() LogEntry.drop_collection()
LogEntry(time=datetime.datetime(2013, 1, 1, 0, 0, 0)).save() LogEntry(time=dt.datetime(2013, 1, 1, 0, 0, 0)).save()
log = LogEntry.objects.first() log = LogEntry.objects.first()
log.time = datetime.datetime(2013, 1, 1, 0, 0, 0) log.time = dt.datetime(2013, 1, 1, 0, 0, 0)
self.assertEqual(['time'], log._changed_fields) self.assertEqual(['time'], log._changed_fields)