diff --git a/localized_fields/fields/integer_field.py b/localized_fields/fields/integer_field.py index 745c2c6..4b800be 100644 --- a/localized_fields/fields/integer_field.py +++ b/localized_fields/fields/integer_field.py @@ -1,6 +1,7 @@ from typing import Dict, Optional, Union from django.conf import settings +from django.contrib.postgres.fields.hstore import KeyTransform from django.db.utils import IntegrityError from ..forms import LocalizedIntegerFieldForm @@ -8,17 +9,42 @@ from ..value import LocalizedIntegerValue, LocalizedValue from .field import LocalizedField +class LocalizedIntegerFieldKeyTransform(KeyTransform): + """Transform that selects a single key from a hstore value and casts it to + an integer.""" + + def as_sql(self, compiler, connection): + sql, params = super().as_sql(compiler, connection) + return f"{sql}::integer", params + + class LocalizedIntegerField(LocalizedField): """Stores integers as a localized value.""" attr_class = LocalizedIntegerValue + def get_transform(self, name): + """Gets the transformation to apply when selecting this value. + + This is where the SQL expression to grab a single is added and + the cast to integer so that sorting by a hstore value works as + expected. + """ + + def _transform(*args, **kwargs): + return LocalizedIntegerFieldKeyTransform(name, *args, **kwargs) + + return _transform + @classmethod def from_db_value(cls, value, *_) -> Optional[LocalizedIntegerValue]: db_value = super().from_db_value(value) if db_value is None: return db_value + if isinstance(db_value, str): + return int(db_value) + # if we were used in an expression somehow then it might be # that we're returning an individual value or an array, so # we should not convert that into an :see:LocalizedIntegerValue diff --git a/tests/test_integer_field.py b/tests/test_integer_field.py index 61f0170..8084ea4 100644 --- a/tests/test_integer_field.py +++ b/tests/test_integer_field.py @@ -1,3 +1,5 @@ +import django + from django.conf import settings from django.db import connection from django.db.utils import IntegrityError @@ -194,3 +196,33 @@ class LocalizedIntegerFieldTestCase(TestCase): ) obj.refresh_from_db() assert obj.score.get(settings.LANGUAGE_CODE) == 75 + + def test_order_by(self): + """Tests whether ordering by a :see:LocalizedIntegerField key works + expected.""" + + # using key transforms (score__en) in order_by(..) is only + # supported since Django 2.1 + # https://github.com/django/django/commit/2162f0983de0dfe2178531638ce7ea56f54dd4e7#diff-0edd853580d56db07e4020728d59e193 + + if django.VERSION < (2, 1): + return + + model = get_fake_model( + { + "score": LocalizedIntegerField( + default={settings.LANGUAGE_CODE: 1337}, null=True + ) + } + ) + + model.objects.create(score=dict(en=982)) + model.objects.create(score=dict(en=382)) + model.objects.create(score=dict(en=1331)) + + res = list( + model.objects.values_list("score__en", flat=True).order_by( + "-score__en" + ) + ) + assert res == [1331, 982, 382]