forked from M-Labs/artiq
Fix handling of default values for RPC arguments (fixes #190).
This commit is contained in:
parent
f4b19fee5c
commit
52102a1a79
|
@ -294,6 +294,7 @@ class CommGeneric:
|
|||
logger.debug("running kernel")
|
||||
|
||||
_rpc_sentinel = object()
|
||||
_rpc_undefined = object()
|
||||
|
||||
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
|
||||
def _receive_rpc_value(self, object_map):
|
||||
|
@ -331,18 +332,23 @@ class CommGeneric:
|
|||
present = self._read_int8()
|
||||
if present:
|
||||
return self._receive_rpc_value(object_map)
|
||||
else:
|
||||
return self._rpc_undefined
|
||||
elif tag == "O":
|
||||
return object_map.retrieve(self._read_int32())
|
||||
else:
|
||||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||
|
||||
def _receive_rpc_args(self, object_map):
|
||||
def _receive_rpc_args(self, object_map, defaults):
|
||||
args = []
|
||||
while True:
|
||||
value = self._receive_rpc_value(object_map)
|
||||
if value is self._rpc_sentinel:
|
||||
return args
|
||||
args.append(value)
|
||||
elif value is self._rpc_undefined:
|
||||
args.append(defaults[len(args)])
|
||||
else:
|
||||
args.append(value)
|
||||
|
||||
def _skip_rpc_value(self, tags):
|
||||
tag = tags.pop(0)
|
||||
|
@ -425,22 +431,22 @@ class CommGeneric:
|
|||
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
|
||||
|
||||
def _serve_rpc(self, object_map):
|
||||
service = self._read_int32()
|
||||
args = self._receive_rpc_args(object_map)
|
||||
service_id = self._read_int32()
|
||||
service = object_map.retrieve(service_id)
|
||||
arguments = self._receive_rpc_args(object_map, service.__defaults__)
|
||||
return_tags = self._read_bytes()
|
||||
logger.debug("rpc service: %d %r -> %s", service, args, return_tags)
|
||||
logger.debug("rpc service: [%d]%r %r -> %s", service_id, service, arguments, return_tags)
|
||||
|
||||
try:
|
||||
result = object_map.retrieve(service)(*args)
|
||||
logger.debug("rpc service: %d %r == %r", service, args, result)
|
||||
result = service(*arguments)
|
||||
logger.debug("rpc service: %d %r == %r", service_id, arguments, result)
|
||||
|
||||
self._write_header(_H2DMsgType.RPC_REPLY)
|
||||
self._write_bytes(return_tags)
|
||||
self._send_rpc_value(bytearray(return_tags), result, result,
|
||||
object_map.retrieve(service))
|
||||
self._send_rpc_value(bytearray(return_tags), result, result, service)
|
||||
self._write_flush()
|
||||
except core_language.ARTIQException as exn:
|
||||
logger.debug("rpc service: %d %r ! %r", service, args, exn)
|
||||
logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn)
|
||||
|
||||
self._write_header(_H2DMsgType.RPC_EXCEPTION)
|
||||
self._write_string(exn.name)
|
||||
|
@ -455,7 +461,7 @@ class CommGeneric:
|
|||
|
||||
self._write_flush()
|
||||
except Exception as exn:
|
||||
logger.debug("rpc service: %d %r ! %r", service, args, exn)
|
||||
logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn)
|
||||
|
||||
self._write_header(_H2DMsgType.RPC_EXCEPTION)
|
||||
self._write_string(type(exn).__name__)
|
||||
|
|
|
@ -10,7 +10,6 @@ class Roundtrip(EnvExperiment):
|
|||
def roundtrip(self, obj, fn):
|
||||
fn(obj)
|
||||
|
||||
|
||||
class RoundtripTest(ExperimentCase):
|
||||
def assertRoundtrip(self, obj):
|
||||
exp = self.create(Roundtrip)
|
||||
|
@ -41,3 +40,22 @@ class RoundtripTest(ExperimentCase):
|
|||
def test_object(self):
|
||||
obj = object()
|
||||
self.assertRoundtrip(obj)
|
||||
|
||||
|
||||
class DefaultArg(EnvExperiment):
|
||||
def build(self):
|
||||
self.setattr_device("core")
|
||||
|
||||
def test(self, foo=42) -> TInt32:
|
||||
return foo
|
||||
|
||||
@kernel
|
||||
def run(self, callback):
|
||||
callback(self.test())
|
||||
|
||||
class DefaultArgTest(ExperimentCase):
|
||||
def test_default_arg(self):
|
||||
exp = self.create(DefaultArg)
|
||||
def callback(value):
|
||||
self.assertEqual(value, 42)
|
||||
exp.run(callback)
|
||||
|
|
Loading…
Reference in New Issue