diff --git a/artiq/master/databases.py b/artiq/master/databases.py index 977cfae44..14cfae4cd 100644 --- a/artiq/master/databases.py +++ b/artiq/master/databases.py @@ -1,5 +1,6 @@ import asyncio -import tokenize + +from artiq.tools import file_import from sipyco.sync_struct import Notifier, process_mod, update_from_dict from sipyco import pyon @@ -7,10 +8,12 @@ from sipyco.asyncio_tools import TaskObject def device_db_from_file(filename): - glbs = dict() - with tokenize.open(filename) as f: - exec(f.read(), glbs) - return glbs["device_db"] + mod = file_import(filename) + + # use __dict__ instead of direct attribute access + # for backwards compatibility of the exception interface + # (raise KeyError and not AttributeError if device_db is missing) + return mod.__dict__["device_db"] class DeviceDB: @@ -19,8 +22,7 @@ class DeviceDB: self.data = Notifier(device_db_from_file(self.backing_file)) def scan(self): - update_from_dict(self.data, - device_db_from_file(self.backing_file)) + update_from_dict(self.data, device_db_from_file(self.backing_file)) def get_device_db(self): return self.data.raw_view diff --git a/artiq/test/test_device_db.py b/artiq/test/test_device_db.py index 5cccaf379..a05093aee 100644 --- a/artiq/test/test_device_db.py +++ b/artiq/test/test_device_db.py @@ -2,6 +2,7 @@ import unittest import tempfile +from pathlib import Path from artiq.master.databases import DeviceDB from artiq.tools import file_import @@ -16,13 +17,13 @@ device_db = { "arguments": {"host": "::1", "ref_period": 1e-09}, }, - "core-alias": "core", - "unresolved-alias": "dummy", + "core_alias": "core", + "unresolved_alias": "dummy", } """ -class TestInvalidDeviceDB(unittest.TestCase): +class TestDeviceDBImport(unittest.TestCase): def test_no_device_db_in_file(self): with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as f: print("", file=f, flush=True) @@ -30,6 +31,30 @@ class TestInvalidDeviceDB(unittest.TestCase): with self.assertRaisesRegex(KeyError, "device_db"): DeviceDB(f.name) + def test_import_same_level(self): + with tempfile.TemporaryDirectory() as tmpdir: + # make sure both files land in the same directory + args = dict(mode="w+", suffix=".py", dir=tmpdir) + with tempfile.NamedTemporaryFile( + **args + ) as fileA, tempfile.NamedTemporaryFile(**args) as fileB: + print(DUMMY_DDB_FILE, file=fileA, flush=True) + print( + f""" +from {Path(fileA.name).stem} import device_db + +device_db["new_core_alias"] = "core" +""", + file=fileB, + flush=True, + ) + + ddb = DeviceDB(fileB.name) + self.assertEqual( + ddb.get("new_core_alias", resolve_alias=True), + DeviceDB(fileA.name).get("core"), + ) + class TestDeviceDB(unittest.TestCase): def setUp(self): @@ -44,15 +69,15 @@ class TestDeviceDB(unittest.TestCase): def test_get_alias(self): with self.assertRaises(TypeError): # str indexing on str - self.ddb.get("core-alias")["class"] + self.ddb.get("core_alias")["class"] self.assertEqual( - self.ddb.get("core-alias", resolve_alias=True), self.ddb.get("core") + self.ddb.get("core_alias", resolve_alias=True), self.ddb.get("core") ) def test_get_unresolved_alias(self): with self.assertRaisesRegex(KeyError, "dummy"): - self.ddb.get("unresolved-alias", resolve_alias=True) + self.ddb.get("unresolved_alias", resolve_alias=True) def test_update(self): with self.assertRaises(KeyError):