diff --git a/aerich/utils.py b/aerich/utils.py index 14664a0..6e4fb88 100644 --- a/aerich/utils.py +++ b/aerich/utils.py @@ -4,7 +4,7 @@ import re import sys from pathlib import Path from types import ModuleType -from typing import Dict, Optional +from typing import Dict, Optional, Union from click import BadOptionUsage, ClickException, Context from tortoise import BaseDBAsyncClient, Tortoise @@ -95,7 +95,7 @@ def is_default_function(string: str) -> Optional[re.Match]: return re.match(r"^$", str(string or "")) -def import_py_file(file: Path) -> ModuleType: +def import_py_file(file: Union[str, Path]) -> ModuleType: module_name, file_ext = os.path.splitext(os.path.split(file)[-1]) spec = importlib.util.spec_from_file_location(module_name, file) module = importlib.util.module_from_spec(spec) # type:ignore[arg-type] diff --git a/tests/test_utils.py b/tests/test_utils.py index 654bc0d..9d640d2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,6 @@ from aerich.utils import import_py_file -def test_import_py_file(): +def test_import_py_file() -> None: m = import_py_file("aerich/utils.py") assert getattr(m, "import_py_file")