From 52102a1a79a2afc5f71e5e5ba8998d7cb4d52828 Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 18 Dec 2015 18:03:07 +0800 Subject: [PATCH] Fix handling of default values for RPC arguments (fixes #190). --- artiq/coredevice/comm_generic.py | 28 +++++++++++++++++----------- artiq/test/coredevice/embedding.py | 20 +++++++++++++++++++- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/artiq/coredevice/comm_generic.py b/artiq/coredevice/comm_generic.py index c5ecd4309..1d3d69f2a 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -294,6 +294,7 @@ class CommGeneric: logger.debug("running kernel") _rpc_sentinel = object() + _rpc_undefined = object() # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. def _receive_rpc_value(self, object_map): @@ -331,18 +332,23 @@ class CommGeneric: present = self._read_int8() if present: return self._receive_rpc_value(object_map) + else: + return self._rpc_undefined elif tag == "O": return object_map.retrieve(self._read_int32()) else: raise IOError("Unknown RPC value tag: {}".format(repr(tag))) - def _receive_rpc_args(self, object_map): + def _receive_rpc_args(self, object_map, defaults): args = [] while True: value = self._receive_rpc_value(object_map) if value is self._rpc_sentinel: return args - args.append(value) + elif value is self._rpc_undefined: + args.append(defaults[len(args)]) + else: + args.append(value) def _skip_rpc_value(self, tags): tag = tags.pop(0) @@ -425,22 +431,22 @@ class CommGeneric: raise IOError("Unknown RPC value tag: {}".format(repr(tag))) def _serve_rpc(self, object_map): - service = self._read_int32() - args = self._receive_rpc_args(object_map) + service_id = self._read_int32() + service = object_map.retrieve(service_id) + arguments = self._receive_rpc_args(object_map, service.__defaults__) return_tags = self._read_bytes() - logger.debug("rpc service: %d %r -> %s", service, args, return_tags) + logger.debug("rpc service: [%d]%r %r -> %s", service_id, service, arguments, return_tags) try: - result = object_map.retrieve(service)(*args) - logger.debug("rpc service: %d %r == %r", service, args, result) + result = service(*arguments) + logger.debug("rpc service: %d %r == %r", service_id, arguments, result) self._write_header(_H2DMsgType.RPC_REPLY) self._write_bytes(return_tags) - self._send_rpc_value(bytearray(return_tags), result, result, - object_map.retrieve(service)) + self._send_rpc_value(bytearray(return_tags), result, result, service) self._write_flush() except core_language.ARTIQException as exn: - logger.debug("rpc service: %d %r ! %r", service, args, exn) + logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn) self._write_header(_H2DMsgType.RPC_EXCEPTION) self._write_string(exn.name) @@ -455,7 +461,7 @@ class CommGeneric: self._write_flush() except Exception as exn: - logger.debug("rpc service: %d %r ! %r", service, args, exn) + logger.debug("rpc service: %d %r ! %r", service_id, arguments, exn) self._write_header(_H2DMsgType.RPC_EXCEPTION) self._write_string(type(exn).__name__) diff --git a/artiq/test/coredevice/embedding.py b/artiq/test/coredevice/embedding.py index b80abe0c2..a5a7095ad 100644 --- a/artiq/test/coredevice/embedding.py +++ b/artiq/test/coredevice/embedding.py @@ -10,7 +10,6 @@ class Roundtrip(EnvExperiment): def roundtrip(self, obj, fn): fn(obj) - class RoundtripTest(ExperimentCase): def assertRoundtrip(self, obj): exp = self.create(Roundtrip) @@ -41,3 +40,22 @@ class RoundtripTest(ExperimentCase): def test_object(self): obj = object() self.assertRoundtrip(obj) + + +class DefaultArg(EnvExperiment): + def build(self): + self.setattr_device("core") + + def test(self, foo=42) -> TInt32: + return foo + + @kernel + def run(self, callback): + callback(self.test()) + +class DefaultArgTest(ExperimentCase): + def test_default_arg(self): + exp = self.create(DefaultArg) + def callback(value): + self.assertEqual(value, 42) + exp.run(callback)