Fix handling of default values for RPC arguments (fixes #190).

This commit is contained in:
whitequark 2015-12-18 18:03:07 +08:00
parent f4b19fee5c
commit 52102a1a79
2 changed files with 36 additions and 12 deletions

View File

@ -294,6 +294,7 @@ class CommGeneric:
logger.debug("running kernel")
_rpc_sentinel = object()
_rpc_undefined = object()
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
def _receive_rpc_value(self, object_map):
@ -331,17 +332,22 @@ class CommGeneric:
present = self._read_int8()
if present:
return self._receive_rpc_value(object_map)
else:
return self._rpc_undefined
elif tag == "O":
return object_map.retrieve(self._read_int32())
else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _receive_rpc_args(self, object_map):
def _receive_rpc_args(self, object_map, defaults):
args = []
while True:
value = self._receive_rpc_value(object_map)
if value is self._rpc_sentinel:
return args
elif value is self._rpc_undefined:
args.append(defaults[len(args)])
else:
args.append(value)
def _skip_rpc_value(self, tags):
@ -425,22 +431,22 @@ class CommGeneric:
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _serve_rpc(self, object_map):
service = self._read_int32()
args = self._receive_rpc_args(object_map)
service_id = self._read_int32()
service = object_map.retrieve(service_id)
arguments = self._receive_rpc_args(object_map, service.__defaults__)
return_tags = self._read_bytes()
logger.debug("rpc service: %d %r -> %s", service, args, return_tags)
logger.debug("rpc service: [%d]%r %r -> %s", service_id, service, arguments, return_tags)
try:
result = object_map.retrieve(service)(*args)
logger.debug("rpc service: %d %r == %r", service, args, result)
result = service(*arguments)
logger.debug("rpc service: %d %r == %r", service_id, arguments, result)
self._write_header(_H2DMsgType.RPC_REPLY)
self._write_bytes(return_tags)
self._send_rpc_value(bytearray(return_tags), result, result,
object_map.retrieve(service))
self._send_rpc_value(bytearray(return_tags), result, result, service)
self._write_flush()
except core_language.ARTIQException as exn:
logger.debug("rpc service: %d %r ! %r", service, args, exn)
logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn)
self._write_header(_H2DMsgType.RPC_EXCEPTION)
self._write_string(exn.name)
@ -455,7 +461,7 @@ class CommGeneric:
self._write_flush()
except Exception as exn:
logger.debug("rpc service: %d %r ! %r", service, args, exn)
logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn)
self._write_header(_H2DMsgType.RPC_EXCEPTION)
self._write_string(type(exn).__name__)

View File

@ -10,7 +10,6 @@ class Roundtrip(EnvExperiment):
def roundtrip(self, obj, fn):
fn(obj)
class RoundtripTest(ExperimentCase):
def assertRoundtrip(self, obj):
exp = self.create(Roundtrip)
@ -41,3 +40,22 @@ class RoundtripTest(ExperimentCase):
def test_object(self):
obj = object()
self.assertRoundtrip(obj)
class DefaultArg(EnvExperiment):
def build(self):
self.setattr_device("core")
def test(self, foo=42) -> TInt32:
return foo
@kernel
def run(self, callback):
callback(self.test())
class DefaultArgTest(ExperimentCase):
def test_default_arg(self):
exp = self.create(DefaultArg)
def callback(value):
self.assertEqual(value, 42)
exp.run(callback)