forked from M-Labs/artiq
tools/file_import: use SourceFileLoader
This allows loading modules from files with extensions not in importlib.machinery.SOURCE_SUFFIXES Signed-off-by: Etienne Wodey <etienne.wodey@aqt.eu>
This commit is contained in:
parent
311a818a49
commit
33a9ca2684
|
@ -1,7 +1,8 @@
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import unittest
|
import importlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
from artiq import tools
|
from artiq import tools
|
||||||
|
|
||||||
|
@ -10,13 +11,13 @@ from artiq import tools
|
||||||
# Very simplified version of CPython's
|
# Very simplified version of CPython's
|
||||||
# Lib/test/test_importlib/util.py:create_modules
|
# Lib/test/test_importlib/util.py:create_modules
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def create_modules(*names):
|
def create_modules(*names, extension=".py"):
|
||||||
mapping = {}
|
mapping = {}
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
mapping[".root"] = Path(temp_dir)
|
mapping[".root"] = Path(temp_dir)
|
||||||
|
|
||||||
for name in names:
|
for name in names:
|
||||||
file_path = Path(temp_dir) / f"{name}.py"
|
file_path = Path(temp_dir) / f"{name}{extension}"
|
||||||
with file_path.open("w") as fp:
|
with file_path.open("w") as fp:
|
||||||
print(f"_MODULE_NAME = {name!r}", file=fp)
|
print(f"_MODULE_NAME = {name!r}", file=fp)
|
||||||
mapping[name] = file_path
|
mapping[name] = file_path
|
||||||
|
@ -45,6 +46,13 @@ class TestFileImport(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(mod2._M1_NAME, mod1._MODULE_NAME)
|
self.assertEqual(mod2._M1_NAME, mod1._MODULE_NAME)
|
||||||
|
|
||||||
|
def test_can_import_not_in_source_suffixes(self):
|
||||||
|
for extension in ["", ".something"]:
|
||||||
|
self.assertNotIn(extension, importlib.machinery.SOURCE_SUFFIXES)
|
||||||
|
with create_modules(MODNAME, extension=extension) as mods:
|
||||||
|
mod = tools.file_import(str(mods[MODNAME]))
|
||||||
|
self.assertEqual(Path(mod.__file__).name, f"{MODNAME}{extension}")
|
||||||
|
|
||||||
|
|
||||||
class TestGetExperiment(unittest.TestCase):
|
class TestGetExperiment(unittest.TestCase):
|
||||||
def test_fail_no_experiments(self):
|
def test_fail_no_experiments(self):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import importlib.util
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -78,7 +78,10 @@ def file_import(filename, prefix="file_import_"):
|
||||||
sys.path.insert(0, path)
|
sys.path.insert(0, path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
spec = importlib.util.spec_from_file_location(modname, filename)
|
spec = importlib.util.spec_from_loader(
|
||||||
|
modname,
|
||||||
|
importlib.machinery.SourceFileLoader(modname, str(filename)),
|
||||||
|
)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
finally:
|
finally:
|
||||||
|
|
Loading…
Reference in New Issue