mirror of https://github.com/m-labs/artiq.git
transforms/inline: encode exceptions
This commit is contained in:
parent
655835e8de
commit
e821f9eb83
|
@ -44,7 +44,8 @@ class Core:
|
||||||
# transform/simplify AST
|
# transform/simplify AST
|
||||||
_debug_unparse = _make_debug_unparse("fold_constants_2")
|
_debug_unparse = _make_debug_unparse("fold_constants_2")
|
||||||
|
|
||||||
func_def, rpc_map = inline(self, k_function, k_args, k_kwargs)
|
func_def, rpc_map, exception_map = inline(
|
||||||
|
self, k_function, k_args, k_kwargs)
|
||||||
_debug_unparse("inline", func_def)
|
_debug_unparse("inline", func_def)
|
||||||
|
|
||||||
lower_units(func_def, self.runtime_env.ref_period)
|
lower_units(func_def, self.runtime_env.ref_period)
|
||||||
|
|
|
@ -27,13 +27,14 @@ class _ReferenceManager:
|
||||||
# inlined_name -> use_count
|
# inlined_name -> use_count
|
||||||
self.use_count = dict()
|
self.use_count = dict()
|
||||||
self.rpc_map = defaultdict(lambda: len(self.rpc_map))
|
self.rpc_map = defaultdict(lambda: len(self.rpc_map))
|
||||||
|
self.exception_map = defaultdict(lambda: len(self.exception_map))
|
||||||
self.kernel_attr_init = []
|
self.kernel_attr_init = []
|
||||||
|
|
||||||
# reserved names
|
# reserved names
|
||||||
for kg in core_language.kernel_globals:
|
for kg in core_language.kernel_globals:
|
||||||
self.use_count[kg] = 1
|
self.use_count[kg] = 1
|
||||||
for name in ("int", "round", "int64", "round64", "float", "array",
|
for name in ("int", "round", "int64", "round64", "float", "array",
|
||||||
"range", "Fraction", "Quantity",
|
"range", "Fraction", "Quantity", "EncodedException",
|
||||||
"s_unit", "Hz_unit", "microcycle_unit"):
|
"s_unit", "Hz_unit", "microcycle_unit"):
|
||||||
self.use_count[name] = 1
|
self.use_count[name] = 1
|
||||||
|
|
||||||
|
@ -93,7 +94,7 @@ _embeddable_calls = {
|
||||||
core_language.syscall,
|
core_language.syscall,
|
||||||
range, int, float, round,
|
range, int, float, round,
|
||||||
core_language.int64, core_language.round64, core_language.array,
|
core_language.int64, core_language.round64, core_language.array,
|
||||||
Fraction, units.Quantity
|
Fraction, units.Quantity, core_language.EncodedException
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -173,8 +174,8 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
retval_name = self.rm.new_name(
|
retval_name = self.rm.new_name(
|
||||||
func.k_function_info.k_function.__name__ + "_return")
|
func.k_function_info.k_function.__name__ + "_return")
|
||||||
args = [func.__self__] + new_args
|
args = [func.__self__] + new_args
|
||||||
inlined, _ = inline(self.core, func.k_function_info.k_function,
|
inlined, _, _ = inline(self.core, func.k_function_info.k_function,
|
||||||
args, dict(), self.rm, retval_name)
|
args, dict(), self.rm, retval_name)
|
||||||
self._insertion_point.append(ast.With(
|
self._insertion_point.append(ast.With(
|
||||||
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
items=[ast.withitem(context_expr=ast.Name(id="sequential",
|
||||||
ctx=ast.Load()),
|
ctx=ast.Load()),
|
||||||
|
@ -214,6 +215,35 @@ class _ReferenceReplacer(ast.NodeVisitor):
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def visit_Raise(self, node):
|
||||||
|
if node.cause is not None:
|
||||||
|
raise NotImplementedError("Exception causes are not supported")
|
||||||
|
exception_class = self.rm.get(self.obj, self.func_name, node.exc)
|
||||||
|
if not inspect.isclass(exception_class):
|
||||||
|
raise NotImplementedError("Exception must be a class")
|
||||||
|
exception_id = self.rm.exception_map[exception_class]
|
||||||
|
node.exc = ast.copy_location(
|
||||||
|
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
||||||
|
args=[value_to_ast(exception_id)],
|
||||||
|
keywords=[], starargs=None, kwargs=None),
|
||||||
|
node.exc)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_ExceptHandler(self, node):
|
||||||
|
if node.name is not None:
|
||||||
|
raise NotImplementedError("'as target' is not supported")
|
||||||
|
exception_class = self.rm.get(self.obj, self.func_name, node.type)
|
||||||
|
if not inspect.isclass(exception_class):
|
||||||
|
raise NotImplementedError("Exception type must be a class")
|
||||||
|
exception_id = self.rm.exception_map[exception_class]
|
||||||
|
node.type = ast.copy_location(
|
||||||
|
ast.Call(func=ast.Name("EncodedException", ast.Load()),
|
||||||
|
args=[value_to_ast(exception_id)],
|
||||||
|
keywords=[], starargs=None, kwargs=None),
|
||||||
|
node.type)
|
||||||
|
self.generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
class _ListReadOnlyParams(ast.NodeVisitor):
|
class _ListReadOnlyParams(ast.NodeVisitor):
|
||||||
def visit_FunctionDef(self, node):
|
def visit_FunctionDef(self, node):
|
||||||
|
@ -272,4 +302,7 @@ def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
|
||||||
|
|
||||||
r_rpc_map = dict((rpc_num, rpc_fun)
|
r_rpc_map = dict((rpc_num, rpc_fun)
|
||||||
for rpc_fun, rpc_num in rm.rpc_map.items())
|
for rpc_fun, rpc_num in rm.rpc_map.items())
|
||||||
return func_def, r_rpc_map
|
r_exception_map = dict((exception_num, exception_class)
|
||||||
|
for exception_class, exception_num
|
||||||
|
in rm.exception_map.items())
|
||||||
|
return func_def, r_rpc_map, r_exception_map
|
||||||
|
|
Loading…
Reference in New Issue