forked from M-Labs/artiq
1
0
Fork 0

master/databases: use tools.file_import to load the device_db

Signed-off-by: Etienne Wodey <wodey@iqo.uni-hannover.de>
This commit is contained in:
Etienne Wodey 2021-06-16 16:27:02 +02:00 committed by Sébastien Bourdeauducq
parent 5c23e6edb6
commit b8ab5f2607
2 changed files with 40 additions and 13 deletions

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import tokenize
from artiq.tools import file_import
from sipyco.sync_struct import Notifier, process_mod, update_from_dict from sipyco.sync_struct import Notifier, process_mod, update_from_dict
from sipyco import pyon from sipyco import pyon
@ -7,10 +8,12 @@ from sipyco.asyncio_tools import TaskObject
def device_db_from_file(filename): def device_db_from_file(filename):
glbs = dict() mod = file_import(filename)
with tokenize.open(filename) as f:
exec(f.read(), glbs) # use __dict__ instead of direct attribute access
return glbs["device_db"] # for backwards compatibility of the exception interface
# (raise KeyError and not AttributeError if device_db is missing)
return mod.__dict__["device_db"]
class DeviceDB: class DeviceDB:
@ -19,8 +22,7 @@ class DeviceDB:
self.data = Notifier(device_db_from_file(self.backing_file)) self.data = Notifier(device_db_from_file(self.backing_file))
def scan(self): def scan(self):
update_from_dict(self.data, update_from_dict(self.data, device_db_from_file(self.backing_file))
device_db_from_file(self.backing_file))
def get_device_db(self): def get_device_db(self):
return self.data.raw_view return self.data.raw_view

View File

@ -2,6 +2,7 @@
import unittest import unittest
import tempfile import tempfile
from pathlib import Path
from artiq.master.databases import DeviceDB from artiq.master.databases import DeviceDB
from artiq.tools import file_import from artiq.tools import file_import
@ -16,13 +17,13 @@ device_db = {
"arguments": {"host": "::1", "ref_period": 1e-09}, "arguments": {"host": "::1", "ref_period": 1e-09},
}, },
"core-alias": "core", "core_alias": "core",
"unresolved-alias": "dummy", "unresolved_alias": "dummy",
} }
""" """
class TestInvalidDeviceDB(unittest.TestCase): class TestDeviceDBImport(unittest.TestCase):
def test_no_device_db_in_file(self): def test_no_device_db_in_file(self):
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as f: with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as f:
print("", file=f, flush=True) print("", file=f, flush=True)
@ -30,6 +31,30 @@ class TestInvalidDeviceDB(unittest.TestCase):
with self.assertRaisesRegex(KeyError, "device_db"): with self.assertRaisesRegex(KeyError, "device_db"):
DeviceDB(f.name) DeviceDB(f.name)
def test_import_same_level(self):
with tempfile.TemporaryDirectory() as tmpdir:
# make sure both files land in the same directory
args = dict(mode="w+", suffix=".py", dir=tmpdir)
with tempfile.NamedTemporaryFile(
**args
) as fileA, tempfile.NamedTemporaryFile(**args) as fileB:
print(DUMMY_DDB_FILE, file=fileA, flush=True)
print(
f"""
from {Path(fileA.name).stem} import device_db
device_db["new_core_alias"] = "core"
""",
file=fileB,
flush=True,
)
ddb = DeviceDB(fileB.name)
self.assertEqual(
ddb.get("new_core_alias", resolve_alias=True),
DeviceDB(fileA.name).get("core"),
)
class TestDeviceDB(unittest.TestCase): class TestDeviceDB(unittest.TestCase):
def setUp(self): def setUp(self):
@ -44,15 +69,15 @@ class TestDeviceDB(unittest.TestCase):
def test_get_alias(self): def test_get_alias(self):
with self.assertRaises(TypeError): # str indexing on str with self.assertRaises(TypeError): # str indexing on str
self.ddb.get("core-alias")["class"] self.ddb.get("core_alias")["class"]
self.assertEqual( self.assertEqual(
self.ddb.get("core-alias", resolve_alias=True), self.ddb.get("core") self.ddb.get("core_alias", resolve_alias=True), self.ddb.get("core")
) )
def test_get_unresolved_alias(self): def test_get_unresolved_alias(self):
with self.assertRaisesRegex(KeyError, "dummy"): with self.assertRaisesRegex(KeyError, "dummy"):
self.ddb.get("unresolved-alias", resolve_alias=True) self.ddb.get("unresolved_alias", resolve_alias=True)
def test_update(self): def test_update(self):
with self.assertRaises(KeyError): with self.assertRaises(KeyError):