diff --git a/artiq/test/test_tools.py b/artiq/test/test_tools.py index 0c24dd29e..8d104281d 100644 --- a/artiq/test/test_tools.py +++ b/artiq/test/test_tools.py @@ -1,7 +1,8 @@ from contextlib import contextmanager -import unittest +import importlib from pathlib import Path import tempfile +import unittest from artiq import tools @@ -10,13 +11,13 @@ from artiq import tools # Very simplified version of CPython's # Lib/test/test_importlib/util.py:create_modules @contextmanager -def create_modules(*names): +def create_modules(*names, extension=".py"): mapping = {} with tempfile.TemporaryDirectory() as temp_dir: mapping[".root"] = Path(temp_dir) for name in names: - file_path = Path(temp_dir) / f"{name}.py" + file_path = Path(temp_dir) / f"{name}{extension}" with file_path.open("w") as fp: print(f"_MODULE_NAME = {name!r}", file=fp) mapping[name] = file_path @@ -45,6 +46,13 @@ class TestFileImport(unittest.TestCase): self.assertEqual(mod2._M1_NAME, mod1._MODULE_NAME) + def test_can_import_not_in_source_suffixes(self): + for extension in ["", ".something"]: + self.assertNotIn(extension, importlib.machinery.SOURCE_SUFFIXES) + with create_modules(MODNAME, extension=extension) as mods: + mod = tools.file_import(str(mods[MODNAME])) + self.assertEqual(Path(mod.__file__).name, f"{MODNAME}{extension}") + class TestGetExperiment(unittest.TestCase): def test_fail_no_experiments(self): diff --git a/artiq/tools.py b/artiq/tools.py index 167f8cf74..d98059356 100644 --- a/artiq/tools.py +++ b/artiq/tools.py @@ -1,5 +1,5 @@ import asyncio -import importlib.util +import importlib import inspect import logging import os @@ -78,7 +78,10 @@ def file_import(filename, prefix="file_import_"): sys.path.insert(0, path) try: - spec = importlib.util.spec_from_file_location(modname, filename) + spec = importlib.util.spec_from_loader( + modname, + importlib.machinery.SourceFileLoader(modname, str(filename)), + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) finally: