2
0
mirror of https://github.com/m-labs/artiq.git synced 2024-12-25 11:18:27 +08:00

tools/file_import: use SourceFileLoader

This allows loading modules from files with extensions not in
importlib.machinery.SOURCE_SUFFIXES

Signed-off-by: Etienne Wodey <etienne.wodey@aqt.eu>
This commit is contained in:
Etienne Wodey 2021-12-08 23:41:38 +01:00 committed by Sébastien Bourdeauducq
parent 311a818a49
commit 33a9ca2684
2 changed files with 16 additions and 5 deletions

View File

@ -1,7 +1,8 @@
from contextlib import contextmanager from contextlib import contextmanager
import unittest import importlib
from pathlib import Path from pathlib import Path
import tempfile import tempfile
import unittest
from artiq import tools from artiq import tools
@ -10,13 +11,13 @@ from artiq import tools
# Very simplified version of CPython's # Very simplified version of CPython's
# Lib/test/test_importlib/util.py:create_modules # Lib/test/test_importlib/util.py:create_modules
@contextmanager @contextmanager
def create_modules(*names): def create_modules(*names, extension=".py"):
mapping = {} mapping = {}
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
mapping[".root"] = Path(temp_dir) mapping[".root"] = Path(temp_dir)
for name in names: for name in names:
file_path = Path(temp_dir) / f"{name}.py" file_path = Path(temp_dir) / f"{name}{extension}"
with file_path.open("w") as fp: with file_path.open("w") as fp:
print(f"_MODULE_NAME = {name!r}", file=fp) print(f"_MODULE_NAME = {name!r}", file=fp)
mapping[name] = file_path mapping[name] = file_path
@ -45,6 +46,13 @@ class TestFileImport(unittest.TestCase):
self.assertEqual(mod2._M1_NAME, mod1._MODULE_NAME) self.assertEqual(mod2._M1_NAME, mod1._MODULE_NAME)
def test_can_import_not_in_source_suffixes(self):
for extension in ["", ".something"]:
self.assertNotIn(extension, importlib.machinery.SOURCE_SUFFIXES)
with create_modules(MODNAME, extension=extension) as mods:
mod = tools.file_import(str(mods[MODNAME]))
self.assertEqual(Path(mod.__file__).name, f"{MODNAME}{extension}")
class TestGetExperiment(unittest.TestCase): class TestGetExperiment(unittest.TestCase):
def test_fail_no_experiments(self): def test_fail_no_experiments(self):

View File

@ -1,5 +1,5 @@
import asyncio import asyncio
import importlib.util import importlib
import inspect import inspect
import logging import logging
import os import os
@ -78,7 +78,10 @@ def file_import(filename, prefix="file_import_"):
sys.path.insert(0, path) sys.path.insert(0, path)
try: try:
spec = importlib.util.spec_from_file_location(modname, filename) spec = importlib.util.spec_from_loader(
modname,
importlib.machinery.SourceFileLoader(modname, str(filename)),
)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
finally: finally: