From 33a9ca26848f63db70490289087d1a46e63f54f8 Mon Sep 17 00:00:00 2001 From: Etienne Wodey Date: Wed, 8 Dec 2021 23:41:38 +0100 Subject: [PATCH] tools/file_import: use SourceFileLoader This allows loading modules from files with extensions not in importlib.machinery.SOURCE_SUFFIXES Signed-off-by: Etienne Wodey --- artiq/test/test_tools.py | 14 +++++++++++--- artiq/tools.py | 7 +++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/artiq/test/test_tools.py b/artiq/test/test_tools.py index 0c24dd29e..8d104281d 100644 --- a/artiq/test/test_tools.py +++ b/artiq/test/test_tools.py @@ -1,7 +1,8 @@ from contextlib import contextmanager -import unittest +import importlib from pathlib import Path import tempfile +import unittest from artiq import tools @@ -10,13 +11,13 @@ from artiq import tools # Very simplified version of CPython's # Lib/test/test_importlib/util.py:create_modules @contextmanager -def create_modules(*names): +def create_modules(*names, extension=".py"): mapping = {} with tempfile.TemporaryDirectory() as temp_dir: mapping[".root"] = Path(temp_dir) for name in names: - file_path = Path(temp_dir) / f"{name}.py" + file_path = Path(temp_dir) / f"{name}{extension}" with file_path.open("w") as fp: print(f"_MODULE_NAME = {name!r}", file=fp) mapping[name] = file_path @@ -45,6 +46,13 @@ class TestFileImport(unittest.TestCase): self.assertEqual(mod2._M1_NAME, mod1._MODULE_NAME) + def test_can_import_not_in_source_suffixes(self): + for extension in ["", ".something"]: + self.assertNotIn(extension, importlib.machinery.SOURCE_SUFFIXES) + with create_modules(MODNAME, extension=extension) as mods: + mod = tools.file_import(str(mods[MODNAME])) + self.assertEqual(Path(mod.__file__).name, f"{MODNAME}{extension}") + class TestGetExperiment(unittest.TestCase): def test_fail_no_experiments(self): diff --git a/artiq/tools.py b/artiq/tools.py index 167f8cf74..d98059356 100644 --- a/artiq/tools.py +++ b/artiq/tools.py @@ -1,5 +1,5 @@ import asyncio -import importlib.util +import importlib import inspect import logging import os @@ -78,7 +78,10 @@ def file_import(filename, prefix="file_import_"): sys.path.insert(0, path) try: - spec = importlib.util.spec_from_file_location(modname, filename) + spec = importlib.util.spec_from_loader( + modname, + importlib.machinery.SourceFileLoader(modname, str(filename)), + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) finally: