forked from M-Labs/artiq
Fix handling of default values for RPC arguments (fixes #190).
This commit is contained in:
parent
f4b19fee5c
commit
52102a1a79
|
@ -294,6 +294,7 @@ class CommGeneric:
|
||||||
logger.debug("running kernel")
|
logger.debug("running kernel")
|
||||||
|
|
||||||
_rpc_sentinel = object()
|
_rpc_sentinel = object()
|
||||||
|
_rpc_undefined = object()
|
||||||
|
|
||||||
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
|
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
|
||||||
def _receive_rpc_value(self, object_map):
|
def _receive_rpc_value(self, object_map):
|
||||||
|
@ -331,18 +332,23 @@ class CommGeneric:
|
||||||
present = self._read_int8()
|
present = self._read_int8()
|
||||||
if present:
|
if present:
|
||||||
return self._receive_rpc_value(object_map)
|
return self._receive_rpc_value(object_map)
|
||||||
|
else:
|
||||||
|
return self._rpc_undefined
|
||||||
elif tag == "O":
|
elif tag == "O":
|
||||||
return object_map.retrieve(self._read_int32())
|
return object_map.retrieve(self._read_int32())
|
||||||
else:
|
else:
|
||||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||||
|
|
||||||
def _receive_rpc_args(self, object_map):
|
def _receive_rpc_args(self, object_map, defaults):
|
||||||
args = []
|
args = []
|
||||||
while True:
|
while True:
|
||||||
value = self._receive_rpc_value(object_map)
|
value = self._receive_rpc_value(object_map)
|
||||||
if value is self._rpc_sentinel:
|
if value is self._rpc_sentinel:
|
||||||
return args
|
return args
|
||||||
args.append(value)
|
elif value is self._rpc_undefined:
|
||||||
|
args.append(defaults[len(args)])
|
||||||
|
else:
|
||||||
|
args.append(value)
|
||||||
|
|
||||||
def _skip_rpc_value(self, tags):
|
def _skip_rpc_value(self, tags):
|
||||||
tag = tags.pop(0)
|
tag = tags.pop(0)
|
||||||
|
@ -425,22 +431,22 @@ class CommGeneric:
|
||||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||||
|
|
||||||
def _serve_rpc(self, object_map):
|
def _serve_rpc(self, object_map):
|
||||||
service = self._read_int32()
|
service_id = self._read_int32()
|
||||||
args = self._receive_rpc_args(object_map)
|
service = object_map.retrieve(service_id)
|
||||||
|
arguments = self._receive_rpc_args(object_map, service.__defaults__)
|
||||||
return_tags = self._read_bytes()
|
return_tags = self._read_bytes()
|
||||||
logger.debug("rpc service: %d %r -> %s", service, args, return_tags)
|
logger.debug("rpc service: [%d]%r %r -> %s", service_id, service, arguments, return_tags)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = object_map.retrieve(service)(*args)
|
result = service(*arguments)
|
||||||
logger.debug("rpc service: %d %r == %r", service, args, result)
|
logger.debug("rpc service: %d %r == %r", service_id, arguments, result)
|
||||||
|
|
||||||
self._write_header(_H2DMsgType.RPC_REPLY)
|
self._write_header(_H2DMsgType.RPC_REPLY)
|
||||||
self._write_bytes(return_tags)
|
self._write_bytes(return_tags)
|
||||||
self._send_rpc_value(bytearray(return_tags), result, result,
|
self._send_rpc_value(bytearray(return_tags), result, result, service)
|
||||||
object_map.retrieve(service))
|
|
||||||
self._write_flush()
|
self._write_flush()
|
||||||
except core_language.ARTIQException as exn:
|
except core_language.ARTIQException as exn:
|
||||||
logger.debug("rpc service: %d %r ! %r", service, args, exn)
|
logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn)
|
||||||
|
|
||||||
self._write_header(_H2DMsgType.RPC_EXCEPTION)
|
self._write_header(_H2DMsgType.RPC_EXCEPTION)
|
||||||
self._write_string(exn.name)
|
self._write_string(exn.name)
|
||||||
|
@ -455,7 +461,7 @@ class CommGeneric:
|
||||||
|
|
||||||
self._write_flush()
|
self._write_flush()
|
||||||
except Exception as exn:
|
except Exception as exn:
|
||||||
logger.debug("rpc service: %d %r ! %r", service, args, exn)
|
logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn)
|
||||||
|
|
||||||
self._write_header(_H2DMsgType.RPC_EXCEPTION)
|
self._write_header(_H2DMsgType.RPC_EXCEPTION)
|
||||||
self._write_string(type(exn).__name__)
|
self._write_string(type(exn).__name__)
|
||||||
|
|
|
@ -10,7 +10,6 @@ class Roundtrip(EnvExperiment):
|
||||||
def roundtrip(self, obj, fn):
|
def roundtrip(self, obj, fn):
|
||||||
fn(obj)
|
fn(obj)
|
||||||
|
|
||||||
|
|
||||||
class RoundtripTest(ExperimentCase):
|
class RoundtripTest(ExperimentCase):
|
||||||
def assertRoundtrip(self, obj):
|
def assertRoundtrip(self, obj):
|
||||||
exp = self.create(Roundtrip)
|
exp = self.create(Roundtrip)
|
||||||
|
@ -41,3 +40,22 @@ class RoundtripTest(ExperimentCase):
|
||||||
def test_object(self):
|
def test_object(self):
|
||||||
obj = object()
|
obj = object()
|
||||||
self.assertRoundtrip(obj)
|
self.assertRoundtrip(obj)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultArg(EnvExperiment):
|
||||||
|
def build(self):
|
||||||
|
self.setattr_device("core")
|
||||||
|
|
||||||
|
def test(self, foo=42) -> TInt32:
|
||||||
|
return foo
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def run(self, callback):
|
||||||
|
callback(self.test())
|
||||||
|
|
||||||
|
class DefaultArgTest(ExperimentCase):
|
||||||
|
def test_default_arg(self):
|
||||||
|
exp = self.create(DefaultArg)
|
||||||
|
def callback(value):
|
||||||
|
self.assertEqual(value, 42)
|
||||||
|
exp.run(callback)
|
||||||
|
|
Loading…
Reference in New Issue