Fix handling of default values for RPC arguments (fixes #190).

This commit is contained in:
whitequark 2015-12-18 18:03:07 +08:00
parent f4b19fee5c
commit 52102a1a79
2 changed files with 36 additions and 12 deletions

View File

@ -294,6 +294,7 @@ class CommGeneric:
logger.debug("running kernel") logger.debug("running kernel")
_rpc_sentinel = object() _rpc_sentinel = object()
_rpc_undefined = object()
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
def _receive_rpc_value(self, object_map): def _receive_rpc_value(self, object_map):
@ -331,18 +332,23 @@ class CommGeneric:
present = self._read_int8() present = self._read_int8()
if present: if present:
return self._receive_rpc_value(object_map) return self._receive_rpc_value(object_map)
else:
return self._rpc_undefined
elif tag == "O": elif tag == "O":
return object_map.retrieve(self._read_int32()) return object_map.retrieve(self._read_int32())
else: else:
raise IOError("Unknown RPC value tag: {}".format(repr(tag))) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _receive_rpc_args(self, object_map): def _receive_rpc_args(self, object_map, defaults):
args = [] args = []
while True: while True:
value = self._receive_rpc_value(object_map) value = self._receive_rpc_value(object_map)
if value is self._rpc_sentinel: if value is self._rpc_sentinel:
return args return args
args.append(value) elif value is self._rpc_undefined:
args.append(defaults[len(args)])
else:
args.append(value)
def _skip_rpc_value(self, tags): def _skip_rpc_value(self, tags):
tag = tags.pop(0) tag = tags.pop(0)
@ -425,22 +431,22 @@ class CommGeneric:
raise IOError("Unknown RPC value tag: {}".format(repr(tag))) raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
def _serve_rpc(self, object_map): def _serve_rpc(self, object_map):
service = self._read_int32() service_id = self._read_int32()
args = self._receive_rpc_args(object_map) service = object_map.retrieve(service_id)
arguments = self._receive_rpc_args(object_map, service.__defaults__)
return_tags = self._read_bytes() return_tags = self._read_bytes()
logger.debug("rpc service: %d %r -> %s", service, args, return_tags) logger.debug("rpc service: [%d]%r %r -> %s", service_id, service, arguments, return_tags)
try: try:
result = object_map.retrieve(service)(*args) result = service(*arguments)
logger.debug("rpc service: %d %r == %r", service, args, result) logger.debug("rpc service: %d %r == %r", service_id, arguments, result)
self._write_header(_H2DMsgType.RPC_REPLY) self._write_header(_H2DMsgType.RPC_REPLY)
self._write_bytes(return_tags) self._write_bytes(return_tags)
self._send_rpc_value(bytearray(return_tags), result, result, self._send_rpc_value(bytearray(return_tags), result, result, service)
object_map.retrieve(service))
self._write_flush() self._write_flush()
except core_language.ARTIQException as exn: except core_language.ARTIQException as exn:
logger.debug("rpc service: %d %r ! %r", service, args, exn) logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn)
self._write_header(_H2DMsgType.RPC_EXCEPTION) self._write_header(_H2DMsgType.RPC_EXCEPTION)
self._write_string(exn.name) self._write_string(exn.name)
@ -455,7 +461,7 @@ class CommGeneric:
self._write_flush() self._write_flush()
except Exception as exn: except Exception as exn:
logger.debug("rpc service: %d %r ! %r", service, args, exn) logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn)
self._write_header(_H2DMsgType.RPC_EXCEPTION) self._write_header(_H2DMsgType.RPC_EXCEPTION)
self._write_string(type(exn).__name__) self._write_string(type(exn).__name__)

View File

@ -10,7 +10,6 @@ class Roundtrip(EnvExperiment):
def roundtrip(self, obj, fn): def roundtrip(self, obj, fn):
fn(obj) fn(obj)
class RoundtripTest(ExperimentCase): class RoundtripTest(ExperimentCase):
def assertRoundtrip(self, obj): def assertRoundtrip(self, obj):
exp = self.create(Roundtrip) exp = self.create(Roundtrip)
@ -41,3 +40,22 @@ class RoundtripTest(ExperimentCase):
def test_object(self): def test_object(self):
obj = object() obj = object()
self.assertRoundtrip(obj) self.assertRoundtrip(obj)
class DefaultArg(EnvExperiment):
def build(self):
self.setattr_device("core")
def test(self, foo=42) -> TInt32:
return foo
@kernel
def run(self, callback):
callback(self.test())
class DefaultArgTest(ExperimentCase):
def test_default_arg(self):
exp = self.create(DefaultArg)
def callback(value):
self.assertEqual(value, 42)
exp.run(callback)