diff --git a/.gitignore b/.gitignore index 950cbbedb..8153f9c71 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__ *.bin *.elf *.fbi +*.pyc doc/manual/_build /build /dist @@ -15,3 +16,10 @@ artiq/test/h5types.h5 examples/master/results examples/master/dataset_db.pyon examples/sim/dataset_db.pyon +Output/ +/lit-test/libartiq_support/libartiq_support.so + +# for developer convenience +/test*.py +/device_db.pyon +/dataset_db.pyon diff --git a/.gitmodules b/.gitmodules index dcb309bfe..d6a4341d5 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "artiq/runtime/lwip"] path = artiq/runtime/lwip url = git://git.savannah.nongnu.org/lwip.git - ignore = untracked + ignore = untracked diff --git a/.travis.yml b/.travis.yml index 881694a54..b62e44c07 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,9 @@ language: python python: - '3.5' +branches: + only: + - master sudo: false env: global: diff --git a/artiq/compiler/__init__.py b/artiq/compiler/__init__.py new file mode 100644 index 000000000..0fabfd4ad --- /dev/null +++ b/artiq/compiler/__init__.py @@ -0,0 +1,2 @@ +from .module import Module, Source +from .embedding import Stitcher diff --git a/artiq/compiler/algorithms/__init__.py b/artiq/compiler/algorithms/__init__.py new file mode 100644 index 000000000..47fcd2dbf --- /dev/null +++ b/artiq/compiler/algorithms/__init__.py @@ -0,0 +1 @@ +from .inline import inline diff --git a/artiq/compiler/algorithms/inline.py b/artiq/compiler/algorithms/inline.py new file mode 100644 index 000000000..b0b691919 --- /dev/null +++ b/artiq/compiler/algorithms/inline.py @@ -0,0 +1,80 @@ +""" +:func:`inline` inlines a call instruction in ARTIQ IR. +The call instruction must have a statically known callee, +it must be second to last in the basic block, and the basic +block must have exactly one successor. +""" + +from .. import types, builtins, iodelay, ir + +def inline(call_insn): + assert isinstance(call_insn, ir.Call) + assert call_insn.static_target_function is not None + assert len(call_insn.basic_block.successors()) == 1 + assert call_insn.basic_block.index(call_insn) == \ + len(call_insn.basic_block.instructions) - 2 + + value_map = {} + source_function = call_insn.static_target_function + target_function = call_insn.basic_block.function + target_predecessor = call_insn.basic_block + target_successor = call_insn.basic_block.successors()[0] + + if builtins.is_none(source_function.type.ret): + target_return_phi = None + else: + target_return_phi = target_successor.prepend(ir.Phi(source_function.type.ret)) + + closure = target_predecessor.insert(ir.GetAttr(call_insn.target_function(), '__closure__'), + before=call_insn) + for actual_arg, formal_arg in zip([closure] + call_insn.arguments(), + source_function.arguments): + value_map[formal_arg] = actual_arg + + for source_block in source_function.basic_blocks: + target_block = ir.BasicBlock([], "i." 
+ source_block.name) + target_function.add(target_block) + value_map[source_block] = target_block + + def mapper(value): + if isinstance(value, ir.Constant): + return value + else: + return value_map[value] + + for source_insn in source_function.instructions(): + target_block = value_map[source_insn.basic_block] + if isinstance(source_insn, ir.Return): + if target_return_phi is not None: + target_return_phi.add_incoming(mapper(source_insn.value()), target_block) + target_insn = ir.Branch(target_successor) + elif isinstance(source_insn, ir.Phi): + target_insn = ir.Phi() + elif isinstance(source_insn, ir.Delay): + substs = source_insn.substs() + mapped_substs = {var: value_map[substs[var]] for var in substs} + const_substs = {var: iodelay.Const(mapped_substs[var].value) + for var in mapped_substs + if isinstance(mapped_substs[var], ir.Constant)} + other_substs = {var: mapped_substs[var] + for var in mapped_substs + if not isinstance(mapped_substs[var], ir.Constant)} + target_insn = ir.Delay(source_insn.expr.fold(const_substs), other_substs, + value_map[source_insn.decomposition()], + value_map[source_insn.target()]) + else: + target_insn = source_insn.copy(mapper) + target_insn.name = "i." + source_insn.name + value_map[source_insn] = target_insn + target_block.append(target_insn) + + for source_insn in source_function.instructions(): + if isinstance(source_insn, ir.Phi): + target_insn = value_map[source_insn] + for block, value in source_insn.incoming(): + target_insn.add_incoming(value_map[value], value_map[block]) + + target_predecessor.terminator().replace_with(ir.Branch(value_map[source_function.entry()])) + if target_return_phi is not None: + call_insn.replace_all_uses_with(target_return_phi) + call_insn.erase() diff --git a/artiq/compiler/analyses/__init__.py b/artiq/compiler/analyses/__init__.py new file mode 100644 index 000000000..708b325ef --- /dev/null +++ b/artiq/compiler/analyses/__init__.py @@ -0,0 +1,2 @@ +from .domination import DominatorTree +from .devirtualization import Devirtualization diff --git a/artiq/compiler/analyses/devirtualization.py b/artiq/compiler/analyses/devirtualization.py new file mode 100644 index 000000000..14f93c882 --- /dev/null +++ b/artiq/compiler/analyses/devirtualization.py @@ -0,0 +1,119 @@ +""" +:class:`Devirtualizer` performs method resolution at +compile time. + +Devirtualization is implemented using a lattice +with three states: unknown → assigned once → diverges. +The lattice is computed individually for every +variable in scope as well as every +(instance type, field name) pair. +""" + +from pythonparser import algorithm +from .. 
import asttyped, ir, types + +def _advance(target_map, key, value): + if key not in target_map: + target_map[key] = value # unknown → assigned once + else: + target_map[key] = None # assigned once → diverges + +class FunctionResolver(algorithm.Visitor): + def __init__(self, variable_map): + self.variable_map = variable_map + + self.scope_map = dict() + self.queue = [] + + self.in_assign = False + self.current_scopes = [] + + def finalize(self): + for thunk in self.queue: + thunk() + + def visit_scope(self, node): + self.current_scopes.append(node) + self.generic_visit(node) + self.current_scopes.pop() + + def visit_in_assign(self, node): + self.in_assign = True + self.visit(node) + self.in_assign = False + + def visit_Assign(self, node): + self.visit(node.value) + self.visit_in_assign(node.targets) + + def visit_For(self, node): + self.visit(node.iter) + self.visit_in_assign(node.target) + self.visit(node.body) + self.visit(node.orelse) + + def visit_withitem(self, node): + self.visit(node.context_expr) + self.visit_in_assign(node.optional_vars) + + def visit_comprehension(self, node): + self.visit(node.iter) + self.visit_in_assign(node.target) + self.visit(node.ifs) + + def visit_ModuleT(self, node): + self.visit_scope(node) + + def visit_FunctionDefT(self, node): + _advance(self.scope_map, (self.current_scopes[-1], node.name), node) + self.visit_scope(node) + + def visit_NameT(self, node): + if self.in_assign: + # Just give up if we assign anything at all to a variable, and + # assume it diverges. + _advance(self.scope_map, (self.current_scopes[-1], node.id), None) + else: + # Look up the final value in scope_map and copy it into variable_map. + keys = [(scope, node.id) for scope in reversed(self.current_scopes)] + def thunk(): + for key in keys: + if key in self.scope_map: + self.variable_map[node] = self.scope_map[key] + return + self.queue.append(thunk) + +class MethodResolver(algorithm.Visitor): + def __init__(self, variable_map, method_map): + self.variable_map = variable_map + self.method_map = method_map + + # embedding.Stitcher.finalize generates initialization statements + # of form "constructor.meth = meth_body". + def visit_Assign(self, node): + if node.value not in self.variable_map: + return + + value = self.variable_map[node.value] + for target in node.targets: + if isinstance(target, asttyped.AttributeT): + if types.is_constructor(target.value.type): + instance_type = target.value.type.instance + elif types.is_instance(target.value.type): + instance_type = target.value.type + else: + continue + _advance(self.method_map, (instance_type, target.attr), value) + +class Devirtualization: + def __init__(self): + self.variable_map = dict() + self.method_map = dict() + + def visit(self, node): + function_resolver = FunctionResolver(self.variable_map) + function_resolver.visit(node) + function_resolver.finalize() + + method_resolver = MethodResolver(self.variable_map, self.method_map) + method_resolver.visit(node) diff --git a/artiq/compiler/analyses/domination.py b/artiq/compiler/analyses/domination.py new file mode 100644 index 000000000..53aa3a2f1 --- /dev/null +++ b/artiq/compiler/analyses/domination.py @@ -0,0 +1,137 @@ +""" +:class:`DominatorTree` computes the dominance relation over +control flow graphs. + +See http://www.cs.rice.edu/~keith/EMBED/dom.pdf. 
+""" + +class GenericDominatorTree: + def __init__(self): + self._assign_names() + self._compute() + + def _traverse_in_postorder(self): + raise NotImplementedError + + def _prev_block_names(self, block): + raise NotImplementedError + + def _assign_names(self): + postorder = self._traverse_in_postorder() + + self._start_name = len(postorder) - 1 + self._block_of_name = postorder + self._name_of_block = {} + for block_name, block in enumerate(postorder): + self._name_of_block[block] = block_name + + def _intersect(self, block_name_1, block_name_2): + finger_1, finger_2 = block_name_1, block_name_2 + while finger_1 != finger_2: + while finger_1 < finger_2: + finger_1 = self._doms[finger_1] + while finger_2 < finger_1: + finger_2 = self._doms[finger_2] + return finger_1 + + def _compute(self): + self._doms = {} + + # Start block dominates itself. + self._doms[self._start_name] = self._start_name + + # We don't yet know what blocks dominate all other blocks. + for block_name in range(self._start_name): + self._doms[block_name] = None + + changed = True + while changed: + changed = False + + # For all blocks except start block, in reverse postorder... + for block_name in reversed(range(self._start_name)): + # Select a new immediate dominator from the blocks we have + # already processed, and remember all others. + # We've already processed at least one previous block because + # of the graph traverse order. + new_idom, prev_block_names = None, [] + for prev_block_name in self._prev_block_names(block_name): + if new_idom is None and self._doms[prev_block_name] is not None: + new_idom = prev_block_name + else: + prev_block_names.append(prev_block_name) + + # Find a common previous block + for prev_block_name in prev_block_names: + if self._doms[prev_block_name] is not None: + new_idom = self._intersect(prev_block_name, new_idom) + + if self._doms[block_name] != new_idom: + self._doms[block_name] = new_idom + changed = True + + def immediate_dominator(self, block): + return self._block_of_name[self._doms[self._name_of_block[block]]] + + def dominators(self, block): + yield block + + block_name = self._name_of_block[block] + while block_name != self._doms[block_name]: + block_name = self._doms[block_name] + yield self._block_of_name[block_name] + +class DominatorTree(GenericDominatorTree): + def __init__(self, function): + self.function = function + super().__init__() + + def _traverse_in_postorder(self): + postorder = [] + + visited = set() + def visit(block): + visited.add(block) + for next_block in block.successors(): + if next_block not in visited: + visit(next_block) + postorder.append(block) + + visit(self.function.entry()) + + return postorder + + def _prev_block_names(self, block_name): + for block in self._block_of_name[block_name].predecessors(): + yield self._name_of_block[block] + +class PostDominatorTree(GenericDominatorTree): + def __init__(self, function): + self.function = function + super().__init__() + + def _traverse_in_postorder(self): + postorder = [] + + visited = set() + def visit(block): + visited.add(block) + for next_block in block.predecessors(): + if next_block not in visited: + visit(next_block) + postorder.append(block) + + for block in self.function.basic_blocks: + if not any(block.successors()): + visit(block) + + postorder.append(None) # virtual exit block + return postorder + + def _prev_block_names(self, block_name): + succ_blocks = self._block_of_name[block_name].successors() + if len(succ_blocks) > 0: + for block in succ_blocks: + yield 
self._name_of_block[block] + else: + yield self._start_name diff --git a/artiq/compiler/asttyped.py b/artiq/compiler/asttyped.py new file mode 100644 index 000000000..6d7908d8d --- /dev/null +++ b/artiq/compiler/asttyped.py @@ -0,0 +1,99 @@ +""" +The typedtree module exports the PythonParser AST enriched with +typing information. +""" + +from pythonparser import ast + +class commontyped(ast.commonloc): + """A mixin for typed AST nodes.""" + + _types = ("type",) + + def _reprfields(self): + return self._fields + self._locs + self._types + +class scoped(object): + """ + :ivar typing_env: (dict with string keys and :class:`.types.Type` values) + map of variable names to variable types + :ivar globals_in_scope: (set of string keys) + set of variables resolved as globals + """ + +# Typed versions of untyped nodes +class argT(ast.arg, commontyped): + pass + +class ClassDefT(ast.ClassDef): + _types = ("constructor_type",) +class FunctionDefT(ast.FunctionDef, scoped): + _types = ("signature_type",) +class ModuleT(ast.Module, scoped): + pass + +class ExceptHandlerT(ast.ExceptHandler): + _fields = ("filter", "name", "body") # rename ast.ExceptHandler.type to filter + _types = ("name_type",) + +class SliceT(ast.Slice, commontyped): + pass + +class AttributeT(ast.Attribute, commontyped): + pass +class BinOpT(ast.BinOp, commontyped): + pass +class BoolOpT(ast.BoolOp, commontyped): + pass +class CallT(ast.Call, commontyped): + """ + :ivar iodelay: (:class:`iodelay.Expr`) + """ +class CompareT(ast.Compare, commontyped): + pass +class DictT(ast.Dict, commontyped): + pass +class DictCompT(ast.DictComp, commontyped, scoped): + pass +class EllipsisT(ast.Ellipsis, commontyped): + pass +class GeneratorExpT(ast.GeneratorExp, commontyped, scoped): + pass +class IfExpT(ast.IfExp, commontyped): + pass +class LambdaT(ast.Lambda, commontyped, scoped): + pass +class ListT(ast.List, commontyped): + pass +class ListCompT(ast.ListComp, commontyped, scoped): + pass +class NameT(ast.Name, commontyped): + pass +class NameConstantT(ast.NameConstant, commontyped): + pass +class NumT(ast.Num, commontyped): + pass +class SetT(ast.Set, commontyped): + pass +class SetCompT(ast.SetComp, commontyped, scoped): + pass +class StrT(ast.Str, commontyped): + pass +class StarredT(ast.Starred, commontyped): + pass +class SubscriptT(ast.Subscript, commontyped): + pass +class TupleT(ast.Tuple, commontyped): + pass +class UnaryOpT(ast.UnaryOp, commontyped): + pass +class YieldT(ast.Yield, commontyped): + pass +class YieldFromT(ast.YieldFrom, commontyped): + pass + +# Novel typed nodes +class CoerceT(ast.expr, commontyped): + _fields = ('value',) # other_value deliberately not in _fields +class QuoteT(ast.expr, commontyped): + _fields = ('value',) diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py new file mode 100644 index 000000000..99065d39b --- /dev/null +++ b/artiq/compiler/builtins.py @@ -0,0 +1,245 @@ +""" +The :mod:`builtins` module contains the builtin Python +and ARTIQ types, such as int or float. +""" + +from collections import OrderedDict +from . 
import types + +# Types + +class TNone(types.TMono): + def __init__(self): + super().__init__("NoneType") + +class TBool(types.TMono): + def __init__(self): + super().__init__("bool") + + @staticmethod + def zero(): + return False + + @staticmethod + def one(): + return True + +class TInt(types.TMono): + def __init__(self, width=None): + if width is None: + width = types.TVar() + super().__init__("int", {"width": width}) + + @staticmethod + def zero(): + return 0 + + @staticmethod + def one(): + return 1 + +def TInt32(): + return TInt(types.TValue(32)) + +def TInt64(): + return TInt(types.TValue(64)) + +class TFloat(types.TMono): + def __init__(self): + super().__init__("float") + + @staticmethod + def zero(): + return 0.0 + + @staticmethod + def one(): + return 1.0 + +class TStr(types.TMono): + def __init__(self): + super().__init__("str") + +class TList(types.TMono): + def __init__(self, elt=None): + if elt is None: + elt = types.TVar() + super().__init__("list", {"elt": elt}) + +class TRange(types.TMono): + def __init__(self, elt=None): + if elt is None: + elt = types.TVar() + super().__init__("range", {"elt": elt}) + self.attributes = OrderedDict([ + ("start", elt), + ("stop", elt), + ("step", elt), + ]) + +class TException(types.TMono): + # All exceptions share the same internal layout: + # * Pointer to the unique global with the name of the exception (str) + # (which also serves as the EHABI type_info). + # * File, line and column where it was raised (str, int, int). + # * Message, which can contain substitutions {0}, {1} and {2} (str). + # * Three 64-bit integers, parameterizing the message (int(width=64)). + + + # Keep this in sync with the function ARTIQIRGenerator.alloc_exn. + attributes = OrderedDict([ + ("__name__", TStr()), + ("__file__", TStr()), + ("__line__", TInt(types.TValue(32))), + ("__col__", TInt(types.TValue(32))), + ("__func__", TStr()), + ("__message__", TStr()), + ("__param0__", TInt(types.TValue(64))), + ("__param1__", TInt(types.TValue(64))), + ("__param2__", TInt(types.TValue(64))), + ]) + + def __init__(self, name="Exception"): + super().__init__(name) + +def fn_bool(): + return types.TConstructor(TBool()) + +def fn_int(): + return types.TConstructor(TInt()) + +def fn_float(): + return types.TConstructor(TFloat()) + +def fn_str(): + return types.TConstructor(TStr()) + +def fn_list(): + return types.TConstructor(TList()) + +def fn_Exception(): + return types.TExceptionConstructor(TException("Exception")) + +def fn_IndexError(): + return types.TExceptionConstructor(TException("IndexError")) + +def fn_ValueError(): + return types.TExceptionConstructor(TException("ValueError")) + +def fn_ZeroDivisionError(): + return types.TExceptionConstructor(TException("ZeroDivisionError")) + +def fn_range(): + return types.TBuiltinFunction("range") + +def fn_len(): + return types.TBuiltinFunction("len") + +def fn_round(): + return types.TBuiltinFunction("round") + +def fn_print(): + return types.TBuiltinFunction("print") + +def fn_kernel(): + return types.TBuiltinFunction("kernel") + +def fn_parallel(): + return types.TBuiltinFunction("parallel") + +def fn_sequential(): + return types.TBuiltinFunction("sequential") + +def fn_now(): + return types.TBuiltinFunction("now") + +def fn_delay(): + return types.TBuiltinFunction("delay") + +def fn_at(): + return types.TBuiltinFunction("at") + +def fn_now_mu(): + return types.TBuiltinFunction("now_mu") + +def fn_delay_mu(): + return types.TBuiltinFunction("delay_mu") + +def fn_at_mu(): + return types.TBuiltinFunction("at_mu") + +def 
fn_mu_to_seconds(): + return types.TBuiltinFunction("mu_to_seconds") + +def fn_seconds_to_mu(): + return types.TBuiltinFunction("seconds_to_mu") + +# Accessors + +def is_none(typ): + return types.is_mono(typ, "NoneType") + +def is_bool(typ): + return types.is_mono(typ, "bool") + +def is_int(typ, width=None): + if width is not None: + return types.is_mono(typ, "int", width=width) + else: + return types.is_mono(typ, "int") + +def get_int_width(typ): + if is_int(typ): + return types.get_value(typ.find()["width"]) + +def is_float(typ): + return types.is_mono(typ, "float") + +def is_str(typ): + return types.is_mono(typ, "str") + +def is_numeric(typ): + typ = typ.find() + return isinstance(typ, types.TMono) and \ + typ.name in ('int', 'float') + +def is_list(typ, elt=None): + if elt is not None: + return types.is_mono(typ, "list", elt=elt) + else: + return types.is_mono(typ, "list") + +def is_range(typ, elt=None): + if elt is not None: + return types.is_mono(typ, "range", {"elt": elt}) + else: + return types.is_mono(typ, "range") + +def is_exception(typ, name=None): + if name is None: + return isinstance(typ.find(), TException) + else: + return isinstance(typ.find(), TException) and \ + typ.name == name + +def is_iterable(typ): + typ = typ.find() + return isinstance(typ, types.TMono) and \ + typ.name in ('list', 'range') + +def get_iterable_elt(typ): + if is_iterable(typ): + return typ.find()["elt"].find() + +def is_collection(typ): + typ = typ.find() + return isinstance(typ, types.TTuple) or \ + types.is_mono(typ, "list") + +def is_allocated(typ): + return not (is_none(typ) or is_bool(typ) or is_int(typ) or + is_float(typ) or is_range(typ) or + types._is_pointer(typ) or types.is_function(typ) or + types.is_c_function(typ) or types.is_rpc_function(typ) or + types.is_method(typ) or types.is_tuple(typ) or + types.is_value(typ)) diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py new file mode 100644 index 000000000..637153dc0 --- /dev/null +++ b/artiq/compiler/embedding.py @@ -0,0 +1,676 @@ +""" +The :class:`Stitcher` class allows to transparently combine compiled +Python code and Python code executed on the host system: it resolves +the references to the host objects and translates the functions +annotated as ``@kernel`` when they are referenced. +""" + +import sys, os, re, linecache, inspect, textwrap +from collections import OrderedDict, defaultdict + +from pythonparser import ast, algorithm, source, diagnostic, parse_buffer +from pythonparser import lexer as source_lexer, parser as source_parser + +from Levenshtein import jaro_winkler + +from ..language import core as language_core +from . 
import types, builtins, asttyped, prelude +from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer + + +class ObjectMap: + def __init__(self): + self.current_key = 0 + self.forward_map = {} + self.reverse_map = {} + + def store(self, obj_ref): + obj_id = id(obj_ref) + if obj_id in self.reverse_map: + return self.reverse_map[obj_id] + + self.current_key += 1 + self.forward_map[self.current_key] = obj_ref + self.reverse_map[obj_id] = self.current_key + return self.current_key + + def retrieve(self, obj_key): + return self.forward_map[obj_key] + + def has_rpc(self): + return any(filter(lambda x: inspect.isfunction(x) or inspect.ismethod(x), + self.forward_map.values())) + +class ASTSynthesizer: + def __init__(self, type_map, value_map, quote_function=None, expanded_from=None): + self.source = "" + self.source_buffer = source.Buffer(self.source, "") + self.type_map, self.value_map = type_map, value_map + self.quote_function = quote_function + self.expanded_from = expanded_from + + def finalize(self): + self.source_buffer.source = self.source + return self.source_buffer + + def _add(self, fragment): + range_from = len(self.source) + self.source += fragment + range_to = len(self.source) + return source.Range(self.source_buffer, range_from, range_to, + expanded_from=self.expanded_from) + + def quote(self, value): + """Construct an AST fragment equal to `value`.""" + if value is None: + typ = builtins.TNone() + return asttyped.NameConstantT(value=value, type=typ, + loc=self._add(repr(value))) + elif value is True or value is False: + typ = builtins.TBool() + return asttyped.NameConstantT(value=value, type=typ, + loc=self._add(repr(value))) + elif isinstance(value, (int, float)): + if isinstance(value, int): + typ = builtins.TInt() + elif isinstance(value, float): + typ = builtins.TFloat() + return asttyped.NumT(n=value, ctx=None, type=typ, + loc=self._add(repr(value))) + elif isinstance(value, language_core.int): + typ = builtins.TInt(width=types.TValue(value.width)) + return asttyped.NumT(n=int(value), ctx=None, type=typ, + loc=self._add(repr(value))) + elif isinstance(value, str): + return asttyped.StrT(s=value, ctx=None, type=builtins.TStr(), + loc=self._add(repr(value))) + elif isinstance(value, list): + begin_loc = self._add("[") + elts = [] + for index, elt in enumerate(value): + elts.append(self.quote(elt)) + if index < len(value) - 1: + self._add(", ") + end_loc = self._add("]") + return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(), + begin_loc=begin_loc, end_loc=end_loc, + loc=begin_loc.join(end_loc)) + elif inspect.isfunction(value) or inspect.ismethod(value): + quote_loc = self._add('`') + repr_loc = self._add(repr(value)) + unquote_loc = self._add('`') + loc = quote_loc.join(unquote_loc) + + function_name, function_type = self.quote_function(value, self.expanded_from) + return asttyped.NameT(id=function_name, ctx=None, type=function_type, loc=loc) + else: + quote_loc = self._add('`') + repr_loc = self._add(repr(value)) + unquote_loc = self._add('`') + loc = quote_loc.join(unquote_loc) + + if isinstance(value, type): + typ = value + else: + typ = type(value) + + if typ in self.type_map: + instance_type, constructor_type = self.type_map[typ] + else: + instance_type = types.TInstance("{}.{}".format(typ.__module__, typ.__qualname__), + OrderedDict()) + instance_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32)) + + constructor_type = types.TConstructor(instance_type) + constructor_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32)) + 
instance_type.constructor = constructor_type + + self.type_map[typ] = instance_type, constructor_type + + if isinstance(value, type): + self.value_map[constructor_type].append((value, loc)) + return asttyped.QuoteT(value=value, type=constructor_type, + loc=loc) + else: + self.value_map[instance_type].append((value, loc)) + return asttyped.QuoteT(value=value, type=instance_type, + loc=loc) + + def call(self, function_node, args, kwargs): + """ + Construct an AST fragment calling a function specified by + an AST node `function_node`, with given arguments. + """ + arg_nodes = [] + kwarg_nodes = [] + kwarg_locs = [] + + name_loc = self._add(function_node.name) + begin_loc = self._add("(") + for index, arg in enumerate(args): + arg_nodes.append(self.quote(arg)) + if index < len(args) - 1: + self._add(", ") + if any(args) and any(kwargs): + self._add(", ") + for index, kw in enumerate(kwargs): + arg_loc = self._add(kw) + equals_loc = self._add("=") + kwarg_locs.append((arg_loc, equals_loc)) + kwarg_nodes.append(self.quote(kwargs[kw])) + if index < len(kwargs) - 1: + self._add(", ") + end_loc = self._add(")") + + return asttyped.CallT( + func=asttyped.NameT(id=function_node.name, ctx=None, + type=function_node.signature_type, + loc=name_loc), + args=arg_nodes, + keywords=[ast.keyword(arg=kw, value=value, + arg_loc=arg_loc, equals_loc=equals_loc, + loc=arg_loc.join(value.loc)) + for kw, value, (arg_loc, equals_loc) + in zip(kwargs, kwarg_nodes, kwarg_locs)], + starargs=None, kwargs=None, + type=types.TVar(), iodelay=None, + begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None, + loc=name_loc.join(end_loc)) + + def assign_local(self, var_name, value): + name_loc = self._add(var_name) + _ = self._add(" ") + equals_loc = self._add("=") + _ = self._add(" ") + value_node = self.quote(value) + + var_node = asttyped.NameT(id=var_name, ctx=None, type=value_node.type, + loc=name_loc) + + return ast.Assign(targets=[var_node], value=value_node, + op_locs=[equals_loc], loc=name_loc.join(value_node.loc)) + + def assign_attribute(self, obj, attr_name, value): + obj_node = self.quote(obj) + dot_loc = self._add(".") + name_loc = self._add(attr_name) + _ = self._add(" ") + equals_loc = self._add("=") + _ = self._add(" ") + value_node = self.quote(value) + + attr_node = asttyped.AttributeT(value=obj_node, attr=attr_name, ctx=None, + type=value_node.type, + dot_loc=dot_loc, attr_loc=name_loc, + loc=obj_node.loc.join(name_loc)) + + return ast.Assign(targets=[attr_node], value=value_node, + op_locs=[equals_loc], loc=name_loc.join(value_node.loc)) + +class StitchingASTTypedRewriter(ASTTypedRewriter): + def __init__(self, engine, prelude, globals, host_environment, quote): + super().__init__(engine, prelude) + self.globals = globals + self.env_stack.append(self.globals) + + self.host_environment = host_environment + self.quote = quote + + def visit_Name(self, node): + typ = super()._try_find_name(node.id) + if typ is not None: + # Value from device environment. + return asttyped.NameT(type=typ, id=node.id, ctx=node.ctx, + loc=node.loc) + else: + # Try to find this value in the host environment and quote it. 
+ if node.id in self.host_environment: + return self.quote(self.host_environment[node.id], node.loc) + else: + suggestion = self._most_similar_ident(node.id) + if suggestion is not None: + diag = diagnostic.Diagnostic("fatal", + "name '{name}' is not bound to anything; did you mean '{suggestion}'?", + {"name": node.id, "suggestion": suggestion}, + node.loc) + self.engine.process(diag) + else: + diag = diagnostic.Diagnostic("fatal", + "name '{name}' is not bound to anything", {"name": node.id}, + node.loc) + self.engine.process(diag) + + def _most_similar_ident(self, id): + names = set() + names.update(self.host_environment.keys()) + for typing_env in reversed(self.env_stack): + names.update(typing_env.keys()) + + sorted_names = sorted(names, key=lambda other: jaro_winkler(id, other), reverse=True) + if len(sorted_names) > 0: + if jaro_winkler(id, sorted_names[0]) > 0.0: + return sorted_names[0] + +class StitchingInferencer(Inferencer): + def __init__(self, engine, value_map, quote): + super().__init__(engine) + self.value_map = value_map + self.quote = quote + + def visit_AttributeT(self, node): + self.generic_visit(node) + object_type = node.value.type.find() + + # The inferencer can only observe types, not values; however, + # when we work with host objects, we have to get the values + # somewhere, since host interpreter does not have types. + # Since we have categorized every host object we quoted according to + # its type, we now interrogate every host object we have to ensure + # that we can successfully serialize the value of the attribute we + # are now adding at the code generation stage. + # + # FIXME: We perform exhaustive checks of every known host object every + # time an attribute access is visited, which is potentially quadratic. + # This is done because it is simpler than performing the checks only when: + # * a previously unknown attribute is encountered, + # * a previously unknown host object is encountered; + # which would be the optimal solution. + for object_value, object_loc in self.value_map[object_type]: + if not hasattr(object_value, node.attr): + note = diagnostic.Diagnostic("note", + "attribute accessed here", {}, + node.loc) + diag = diagnostic.Diagnostic("error", + "host object does not have an attribute '{attr}'", + {"attr": node.attr}, + object_loc, notes=[note]) + self.engine.process(diag) + return + + # Figure out what ARTIQ type does the value of the attribute have. + # We do this by quoting it, as if to serialize. This has some + # overhead (i.e. synthesizing a source buffer), but has the advantage + # of having the host-to-ARTIQ mapping code in only one place and + # also immediately getting proper diagnostics on type errors. + attr_value = getattr(object_value, node.attr) + if (inspect.ismethod(attr_value) and hasattr(attr_value.__func__, 'artiq_embedded') + and types.is_instance(object_type)): + # In cases like: + # class c: + # @kernel + # def f(self): pass + # we want f to be defined on the class, not on the instance. 
+ attributes = object_type.constructor.attributes + attr_value = attr_value.__func__ + else: + attributes = object_type.attributes + + ast = self.quote(attr_value, object_loc.expanded_from) + + def proxy_diagnostic(diag): + note = diagnostic.Diagnostic("note", + "while inferring a type for an attribute '{attr}' of a host object", + {"attr": node.attr}, + node.loc) + diag.notes.append(note) + + self.engine.process(diag) + + proxy_engine = diagnostic.Engine() + proxy_engine.process = proxy_diagnostic + Inferencer(engine=proxy_engine).visit(ast) + IntMonomorphizer(engine=proxy_engine).visit(ast) + + if node.attr not in attributes: + # We just figured out what the type should be. Add it. + attributes[node.attr] = ast.type + elif attributes[node.attr] != ast.type: + # Does this conflict with an earlier guess? + printer = types.TypePrinter() + diag = diagnostic.Diagnostic("error", + "host object has an attribute '{attr}' of type {typea}, which is" + " different from previously inferred type {typeb} for the same attribute", + {"typea": printer.name(ast.type), + "typeb": printer.name(attributes[node.attr]), + "attr": node.attr}, + object_loc) + self.engine.process(diag) + + super().visit_AttributeT(node) + +class TypedtreeHasher(algorithm.Visitor): + def generic_visit(self, node): + def freeze(obj): + if isinstance(obj, ast.AST): + return self.visit(obj) + elif isinstance(obj, types.Type): + return hash(obj.find()) + else: + # We don't care; only types change during inference. + pass + + fields = node._fields + if hasattr(node, '_types'): + fields = fields + node._types + return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields)) + +class Stitcher: + def __init__(self, engine=None): + if engine is None: + self.engine = diagnostic.Engine(all_errors_are_fatal=True) + else: + self.engine = engine + + self.name = "" + self.typedtree = [] + self.inject_at = 0 + self.prelude = prelude.globals() + self.globals = {} + + self.functions = {} + + self.object_map = ObjectMap() + self.type_map = {} + self.value_map = defaultdict(lambda: []) + + def stitch_call(self, function, args, kwargs): + function_node = self._quote_embedded_function(function) + self.typedtree.append(function_node) + + # We synthesize source code for the initial call so that + # diagnostics would have something meaningful to display to the user. + synthesizer = self._synthesizer() + call_node = synthesizer.call(function_node, args, kwargs) + synthesizer.finalize() + self.typedtree.append(call_node) + + def finalize(self): + inferencer = StitchingInferencer(engine=self.engine, + value_map=self.value_map, + quote=self._quote) + hasher = TypedtreeHasher() + + # Iterate inference to fixed point. + old_typedtree_hash = None + while True: + inferencer.visit(self.typedtree) + typedtree_hash = hasher.visit(self.typedtree) + + if old_typedtree_hash == typedtree_hash: + break + old_typedtree_hash = typedtree_hash + + # For every host class we embed, add an appropriate constructor + # as a global. This is necessary for method lookup, which uses + # the getconstructor instruction. + for instance_type, constructor_type in list(self.type_map.values()): + # Do we have any direct reference to a constructor? + if len(self.value_map[constructor_type]) > 0: + # Yes, use it. + constructor, _constructor_loc = self.value_map[constructor_type][0] + else: + # No, extract one from a reference to an instance. 
+ instance, _instance_loc = self.value_map[instance_type][0] + constructor = type(instance) + + self.globals[constructor_type.name] = constructor_type + + synthesizer = self._synthesizer() + ast = synthesizer.assign_local(constructor_type.name, constructor) + synthesizer.finalize() + self._inject(ast) + + for attr in constructor_type.attributes: + if types.is_function(constructor_type.attributes[attr]): + synthesizer = self._synthesizer() + ast = synthesizer.assign_attribute(constructor, attr, + getattr(constructor, attr)) + synthesizer.finalize() + self._inject(ast) + + # After we have found all functions, synthesize a module to hold them. + source_buffer = source.Buffer("", "") + self.typedtree = asttyped.ModuleT( + typing_env=self.globals, globals_in_scope=set(), + body=self.typedtree, loc=source.Range(source_buffer, 0, 0)) + + def _inject(self, node): + self.typedtree.insert(self.inject_at, node) + self.inject_at += 1 + + def _synthesizer(self, expanded_from=None): + return ASTSynthesizer(expanded_from=expanded_from, + type_map=self.type_map, + value_map=self.value_map, + quote_function=self._quote_function) + + def _quote_embedded_function(self, function): + if not hasattr(function, "artiq_embedded"): + raise ValueError("{} is not an embedded function".format(repr(function))) + + # Extract function source. + embedded_function = function.artiq_embedded.function + source_code = inspect.getsource(embedded_function) + filename = embedded_function.__code__.co_filename + module_name = embedded_function.__globals__['__name__'] + first_line = embedded_function.__code__.co_firstlineno + + # Extract function environment. + host_environment = dict() + host_environment.update(embedded_function.__globals__) + cells = embedded_function.__closure__ + cell_names = embedded_function.__code__.co_freevars + host_environment.update({var: cells[index] for index, var in enumerate(cell_names)}) + + # Find out how indented we are. + initial_whitespace = re.search(r"^\s*", source_code).group(0) + initial_indent = len(initial_whitespace.expandtabs()) + + # Parse. + source_buffer = source.Buffer(source_code, filename, first_line) + lexer = source_lexer.Lexer(source_buffer, version=sys.version_info[0:2], + diagnostic_engine=self.engine) + lexer.indent = [(initial_indent, + source.Range(source_buffer, 0, len(initial_whitespace)), + initial_whitespace)] + parser = source_parser.Parser(lexer, version=sys.version_info[0:2], + diagnostic_engine=self.engine) + function_node = parser.file_input().body[0] + + # Mangle the name, since we put everything into a single module. + function_node.name = "{}.{}".format(module_name, function.__qualname__) + + # Normally, LocalExtractor would populate the typing environment + # of the module with the function name. However, since we run + # ASTTypedRewriter on the function node directly, we need to do it + # explicitly. + self.globals[function_node.name] = types.TVar() + + # Memoize the function before typing it to handle recursive + # invocations. + self.functions[function] = function_node.name + + # Rewrite into typed form. 
+ asttyped_rewriter = StitchingASTTypedRewriter( + engine=self.engine, prelude=self.prelude, + globals=self.globals, host_environment=host_environment, + quote=self._quote) + return asttyped_rewriter.visit(function_node) + + def _function_loc(self, function): + filename = function.__code__.co_filename + line = function.__code__.co_firstlineno + name = function.__code__.co_name + + source_line = linecache.getline(filename, line) + while source_line.lstrip().startswith("@"): + line += 1 + source_line = linecache.getline(filename, line) + + if "" in function.__qualname__: + column = 0 # can't get column of lambda + else: + column = re.search("def", source_line).start(0) + source_buffer = source.Buffer(source_line, filename, line) + return source.Range(source_buffer, column, column) + + def _call_site_note(self, call_loc, is_syscall): + if call_loc: + if is_syscall: + return [diagnostic.Diagnostic("note", + "in system call here", {}, + call_loc)] + else: + return [diagnostic.Diagnostic("note", + "in function called remotely here", {}, + call_loc)] + else: + return [] + + def _extract_annot(self, function, annot, kind, call_loc, is_syscall): + if not isinstance(annot, types.Type): + diag = diagnostic.Diagnostic("error", + "type annotation for {kind}, '{annot}', is not an ARTIQ type", + {"kind": kind, "annot": repr(annot)}, + self._function_loc(function), + notes=self._call_site_note(call_loc, is_syscall)) + self.engine.process(diag) + + return types.TVar() + else: + return annot + + def _type_of_param(self, function, loc, param, is_syscall): + if param.annotation is not inspect.Parameter.empty: + # Type specified explicitly. + return self._extract_annot(function, param.annotation, + "argument '{}'".format(param.name), loc, + is_syscall) + elif is_syscall: + # Syscalls must be entirely annotated. + diag = diagnostic.Diagnostic("error", + "system call argument '{argument}' must have a type annotation", + {"argument": param.name}, + self._function_loc(function), + notes=self._call_site_note(loc, is_syscall)) + self.engine.process(diag) + elif param.default is not inspect.Parameter.empty: + # Try and infer the type from the default value. + # This is tricky, because the default value might not have + # a well-defined type in APython. + # In this case, we bail out, but mention why we do it. + ast = self._quote(param.default, None) + + def proxy_diagnostic(diag): + note = diagnostic.Diagnostic("note", + "expanded from here while trying to infer a type for an" + " unannotated optional argument '{argument}' from its default value", + {"argument": param.name}, + self._function_loc(function)) + diag.notes.append(note) + + note = self._call_site_note(loc, is_syscall) + if note: + diag.notes += note + + self.engine.process(diag) + + proxy_engine = diagnostic.Engine() + proxy_engine.process = proxy_diagnostic + Inferencer(engine=proxy_engine).visit(ast) + IntMonomorphizer(engine=proxy_engine).visit(ast) + + return ast.type + else: + # Let the rest of the program decide. + return types.TVar() + + def _quote_foreign_function(self, function, loc, syscall): + signature = inspect.signature(function) + + arg_types = OrderedDict() + optarg_types = OrderedDict() + for param in signature.parameters.values(): + if param.kind not in (inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD): + # We pretend we don't see *args, kwpostargs=..., **kwargs. 
+ # Since every method can be still invoked without any arguments + # going into *args and the slots after it, this is always safe, + # if sometimes constraining. + # + # Accepting POSITIONAL_ONLY is OK, because the compiler + # desugars the keyword arguments into positional ones internally. + continue + + if param.default is inspect.Parameter.empty: + arg_types[param.name] = self._type_of_param(function, loc, param, + is_syscall=syscall is not None) + elif syscall is None: + optarg_types[param.name] = self._type_of_param(function, loc, param, + is_syscall=False) + else: + diag = diagnostic.Diagnostic("error", + "system call argument '{argument}' must not have a default value", + {"argument": param.name}, + self._function_loc(function), + notes=self._call_site_note(loc, is_syscall=True)) + self.engine.process(diag) + + if signature.return_annotation is not inspect.Signature.empty: + ret_type = self._extract_annot(function, signature.return_annotation, + "return type", loc, is_syscall=syscall is not None) + elif syscall is None: + ret_type = builtins.TNone() + else: # syscall is not None + diag = diagnostic.Diagnostic("error", + "system call must have a return type annotation", {}, + self._function_loc(function), + notes=self._call_site_note(loc, is_syscall=True)) + self.engine.process(diag) + ret_type = types.TVar() + + if syscall is None: + function_type = types.TRPCFunction(arg_types, optarg_types, ret_type, + service=self.object_map.store(function)) + function_name = "rpc${}".format(function_type.service) + else: + function_type = types.TCFunction(arg_types, ret_type, + name=syscall) + function_name = "ffi${}".format(function_type.name) + + self.globals[function_name] = function_type + self.functions[function] = function_name + + return function_name, function_type + + def _quote_function(self, function, loc): + if function in self.functions: + function_name = self.functions[function] + return function_name, self.globals[function_name] + + if hasattr(function, "artiq_embedded"): + if function.artiq_embedded.function is not None: + # Insert the typed AST for the new function and restart inference. + # It doesn't really matter where we insert as long as it is before + # the final call. + function_node = self._quote_embedded_function(function) + self._inject(function_node) + return function_node.name, self.globals[function_node.name] + elif function.artiq_embedded.syscall is not None: + # Insert a storage-less global whose type instructs the compiler + # to perform a system call instead of a regular call. + return self._quote_foreign_function(function, loc, + syscall=function.artiq_embedded.syscall) + else: + assert False + else: + # Insert a storage-less global whose type instructs the compiler + # to perform an RPC instead of a regular call. + return self._quote_foreign_function(function, loc, + syscall=None) + + def _quote(self, value, loc): + synthesizer = self._synthesizer(loc) + node = synthesizer.quote(value) + synthesizer.finalize() + return node diff --git a/artiq/compiler/iodelay.py b/artiq/compiler/iodelay.py new file mode 100644 index 000000000..6cab8588c --- /dev/null +++ b/artiq/compiler/iodelay.py @@ -0,0 +1,249 @@ +""" +The :mod:`iodelay` module contains the classes describing +the statically inferred RTIO delay arising from executing +a function. 
+""" + +from functools import reduce + +class Expr: + def __add__(lhs, rhs): + assert isinstance(rhs, Expr) + return Add(lhs, rhs) + __iadd__ = __add__ + + def __sub__(lhs, rhs): + assert isinstance(rhs, Expr) + return Sub(lhs, rhs) + __isub__ = __sub__ + + def __mul__(lhs, rhs): + assert isinstance(rhs, Expr) + return Mul(lhs, rhs) + __imul__ = __mul__ + + def __truediv__(lhs, rhs): + assert isinstance(rhs, Expr) + return TrueDiv(lhs, rhs) + __itruediv__ = __truediv__ + + def __floordiv__(lhs, rhs): + assert isinstance(rhs, Expr) + return FloorDiv(lhs, rhs) + __ifloordiv__ = __floordiv__ + + def __ne__(lhs, rhs): + return not (lhs == rhs) + + def free_vars(self): + return set() + + def fold(self, vars=None): + return self + +class Const(Expr): + _priority = 1 + + def __init__(self, value): + assert isinstance(value, (int, float)) + self.value = value + + def __str__(self): + return str(self.value) + + def __eq__(lhs, rhs): + return rhs.__class__ == lhs.__class__ and lhs.value == rhs.value + + def eval(self, env): + return self.value + +class Var(Expr): + _priority = 1 + + def __init__(self, name): + assert isinstance(name, str) + self.name = name + + def __str__(self): + return self.name + + def __eq__(lhs, rhs): + return rhs.__class__ == lhs.__class__ and lhs.name == rhs.name + + def free_vars(self): + return {self.name} + + def fold(self, vars=None): + if vars is not None and self.name in vars: + return vars[self.name] + else: + return self + +class Conv(Expr): + _priority = 1 + + def __init__(self, operand, ref_period): + assert isinstance(operand, Expr) + assert isinstance(ref_period, float) + self.operand, self.ref_period = operand, ref_period + + def __eq__(lhs, rhs): + return rhs.__class__ == lhs.__class__ and \ + lhs.ref_period == rhs.ref_period and \ + lhs.operand == rhs.operand + + def free_vars(self): + return self.operand.free_vars() + +class MUToS(Conv): + def __str__(self): + return "mu->s({})".format(self.operand) + + def eval(self, env): + return self.operand.eval(env) * self.ref_period + + def fold(self, vars=None): + operand = self.operand.fold(vars) + if isinstance(operand, Const): + return Const(operand.value * self.ref_period) + else: + return MUToS(operand, ref_period=self.ref_period) + +class SToMU(Conv): + def __str__(self): + return "s->mu({})".format(self.operand) + + def eval(self, env): + return int(self.operand.eval(env) / self.ref_period) + + def fold(self, vars=None): + operand = self.operand.fold(vars) + if isinstance(operand, Const): + return Const(int(operand.value / self.ref_period)) + else: + return SToMU(operand, ref_period=self.ref_period) + +class BinOp(Expr): + def __init__(self, lhs, rhs): + self.lhs, self.rhs = lhs, rhs + + def __str__(self): + lhs = "({})".format(self.lhs) if self.lhs._priority > self._priority else str(self.lhs) + rhs = "({})".format(self.rhs) if self.rhs._priority > self._priority else str(self.rhs) + return "{} {} {}".format(lhs, self._symbol, rhs) + + def __eq__(lhs, rhs): + return rhs.__class__ == lhs.__class__ and lhs.lhs == rhs.lhs and lhs.rhs == rhs.rhs + + def eval(self, env): + return self.__class__._op(self.lhs.eval(env), self.rhs.eval(env)) + + def free_vars(self): + return self.lhs.free_vars() | self.rhs.free_vars() + + def _fold_binop(self, lhs, rhs): + if isinstance(lhs, Const) and lhs.__class__ == rhs.__class__: + return Const(self.__class__._op(lhs.value, rhs.value)) + elif isinstance(lhs, (MUToS, SToMU)) and lhs.__class__ == rhs.__class__: + return lhs.__class__(self.__class__(lhs.operand, rhs.operand), + 
ref_period=lhs.ref_period).fold() + else: + return self.__class__(lhs, rhs) + + def fold(self, vars=None): + return self._fold_binop(self.lhs.fold(vars), self.rhs.fold(vars)) + +class BinOpFixpoint(BinOp): + def _fold_binop(self, lhs, rhs): + if isinstance(lhs, Const) and lhs.value == self._fixpoint: + return rhs + elif isinstance(rhs, Const) and rhs.value == self._fixpoint: + return lhs + else: + return super()._fold_binop(lhs, rhs) + +class Add(BinOpFixpoint): + _priority = 2 + _symbol = "+" + _op = lambda a, b: a + b + _fixpoint = 0 + +class Mul(BinOpFixpoint): + _priority = 1 + _symbol = "*" + _op = lambda a, b: a * b + _fixpoint = 1 + +class Sub(BinOp): + _priority = 2 + _symbol = "-" + _op = lambda a, b: a - b + + def _fold_binop(self, lhs, rhs): + if isinstance(rhs, Const) and rhs.value == 0: + return lhs + else: + return super()._fold_binop(lhs, rhs) + +class Div(BinOp): + def _fold_binop(self, lhs, rhs): + if isinstance(rhs, Const) and rhs.value == 1: + return lhs + else: + return super()._fold_binop(lhs, rhs) + +class TrueDiv(Div): + _priority = 1 + _symbol = "/" + _op = lambda a, b: a / b if b != 0 else 0 + +class FloorDiv(Div): + _priority = 1 + _symbol = "//" + _op = lambda a, b: a // b if b != 0 else 0 + +class Max(Expr): + _priority = 1 + + def __init__(self, operands): + assert isinstance(operands, list) + assert all([isinstance(operand, Expr) for operand in operands]) + assert operands != [] + self.operands = operands + + def __str__(self): + return "max({})".format(", ".join([str(operand) for operand in self.operands])) + + def __eq__(lhs, rhs): + return rhs.__class__ == lhs.__class__ and lhs.operands == rhs.operands + + def free_vars(self): + return reduce(lambda a, b: a | b, [operand.free_vars() for operand in self.operands]) + + def eval(self, env): + return max([operand.eval() for operand in self.operands]) + + def fold(self, vars=None): + consts, exprs = [], [] + for operand in self.operands: + operand = operand.fold(vars) + if isinstance(operand, Const): + consts.append(operand.value) + elif operand not in exprs: + exprs.append(operand) + if len(consts) > 0: + exprs.append(Const(max(consts))) + if len(exprs) == 1: + return exprs[0] + else: + return Max(exprs) + +def is_const(expr, value=None): + expr = expr.fold() + if value is None: + return isinstance(expr, Const) + else: + return isinstance(expr, Const) and expr.value == value + +def is_zero(expr): + return is_const(expr, 0) diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py new file mode 100644 index 000000000..54c6f0500 --- /dev/null +++ b/artiq/compiler/ir.py @@ -0,0 +1,1374 @@ +""" +The :mod:`ir` module contains the intermediate representation +of the ARTIQ compiler. +""" + +from collections import OrderedDict +from pythonparser import ast +from . import types, builtins + +# Generic SSA IR classes + +def escape_name(name): + if all([str.isalnum(x) or x == "." for x in name]): + return name + else: + return "\"{}\"".format(name.replace("\"", "\\\"")) + +class TBasicBlock(types.TMono): + def __init__(self): + super().__init__("label") + +def is_basic_block(typ): + return isinstance(typ, TBasicBlock) + +class TOption(types.TMono): + def __init__(self, inner): + super().__init__("option", {"inner": inner}) + +def is_option(typ): + return isinstance(typ, TOption) + +class TExceptionTypeInfo(types.TMono): + def __init__(self): + super().__init__("exntypeinfo") + +def is_exn_typeinfo(typ): + return isinstance(typ, TExceptionTypeInfo) + +class Value: + """ + An SSA value that keeps track of its uses. 
+ + :ivar type: (:class:`.types.Type`) type of this value + :ivar uses: (list of :class:`Value`) values that use this value + """ + + def __init__(self, typ): + self.uses, self.type = set(), typ.find() + + def replace_all_uses_with(self, value): + for user in set(self.uses): + user.replace_uses_of(self, value) + + def __str__(self): + return self.as_entity(type_printer=types.TypePrinter()) + +class Constant(Value): + """ + A constant value. + + :ivar value: (Python object) value + """ + + def __init__(self, value, typ): + super().__init__(typ) + self.value = value + + def as_operand(self, type_printer): + return self.as_entity(type_printer) + + def as_entity(self, type_printer): + return "{} {}".format(type_printer.name(self.type), + repr(self.value)) + + def __hash__(self): + return hash(self.value) + + def __eq__(self, other): + return isinstance(other, Constant) and \ + other.type == self.type and other.value == self.value + + def __ne__(self, other): + return not (self == other) + +class NamedValue(Value): + """ + An SSA value that has a name. + + :ivar name: (string) name of this value + :ivar function: (:class:`Function`) function containing this value + """ + + def __init__(self, typ, name): + super().__init__(typ) + self.name, self.function = name, None + + def set_name(self, new_name): + if self.function is not None: + self.function._remove_name(self.name) + self.name = self.function._add_name(new_name) + else: + self.name = new_name + + def _set_function(self, new_function): + if self.function != new_function: + if self.function is not None: + self.function._remove_name(self.name) + self.function = new_function + if self.function is not None: + self.name = self.function._add_name(self.name) + + def _detach(self): + self.function = None + + def as_operand(self, type_printer): + return "{} %{}".format(type_printer.name(self.type), + escape_name(self.name)) + +class User(NamedValue): + """ + An SSA value that has operands. + + :ivar operands: (list of :class:`Value`) operands of this value + """ + + def __init__(self, operands, typ, name): + super().__init__(typ, name) + self.operands = [] + self.set_operands(operands) + + def set_operands(self, new_operands): + for operand in set(self.operands): + operand.uses.remove(self) + self.operands = new_operands + for operand in set(self.operands): + operand.uses.add(self) + + def drop_references(self): + self.set_operands([]) + + def replace_uses_of(self, value, replacement): + assert value in self.operands + + for index, operand in enumerate(self.operands): + if operand == value: + self.operands[index] = replacement + + value.uses.remove(self) + replacement.uses.add(self) + +class Instruction(User): + """ + An SSA instruction. + + :ivar loc: (:class:`pythonparser.source.Range` or None) + source location + """ + + def __init__(self, operands, typ, name=""): + assert isinstance(operands, list) + assert isinstance(typ, types.Type) + super().__init__(operands, typ, name) + self.basic_block = None + self.loc = None + + def copy(self, mapper): + self_copy = self.__class__.__new__(self.__class__) + Instruction.__init__(self_copy, list(map(mapper, self.operands)), + self.type, self.name) + return self_copy + + def set_basic_block(self, new_basic_block): + self.basic_block = new_basic_block + if self.basic_block is not None: + self._set_function(self.basic_block.function) + else: + self._set_function(None) + + def opcode(self): + """String representation of the opcode.""" + return "???" 
+ + def _detach(self): + self.set_basic_block(None) + + def remove_from_parent(self): + if self.basic_block is not None: + self.basic_block.remove(self) + + def erase(self): + self.remove_from_parent() + self.drop_references() + # Check this after drop_references in case this + # is a self-referencing phi. + assert not any(self.uses) + + def replace_with(self, value): + self.replace_all_uses_with(value) + if isinstance(value, Instruction): + self.basic_block.replace(self, value) + self.drop_references() + else: + self.erase() + + def _operands_as_string(self, type_printer): + return ", ".join([operand.as_operand(type_printer) for operand in self.operands]) + + def as_entity(self, type_printer): + if builtins.is_none(self.type) and len(self.uses) == 0: + prefix = "" + else: + prefix = "%{} = {} ".format(escape_name(self.name), + type_printer.name(self.type)) + + if any(self.operands): + return "{}{} {}".format(prefix, self.opcode(), + self._operands_as_string(type_printer)) + else: + return "{}{}".format(prefix, self.opcode()) + +class Phi(Instruction): + """ + An SSA instruction that joins data flow. + + Use :meth:`incoming` and :meth:`add_incoming` instead of + directly reading :attr:`operands` or calling :meth:`set_operands`. + """ + + def __init__(self, typ, name=""): + super().__init__([], typ, name) + + def opcode(self): + return "phi" + + def incoming(self): + operand_iter = iter(self.operands) + while True: + yield next(operand_iter), next(operand_iter) + + def incoming_blocks(self): + return (block for (block, value) in self.incoming()) + + def incoming_values(self): + return (value for (block, value) in self.incoming()) + + def incoming_value_for_block(self, target_block): + for (block, value) in self.incoming(): + if block == target_block: + return value + assert False + + def add_incoming(self, value, block): + assert value.type == self.type + self.operands.append(value) + value.uses.add(self) + self.operands.append(block) + block.uses.add(self) + + def remove_incoming_value(self, value): + index = self.operands.index(value) + self.operands[index].uses.remove(self) + self.operands[index + 1].uses.remove(self) + del self.operands[index:index + 2] + + def remove_incoming_block(self, block): + index = self.operands.index(block) + self.operands[index - 1].uses.remove(self) + self.operands[index].uses.remove(self) + del self.operands[index - 1:index + 1] + + def as_entity(self, type_printer): + if builtins.is_none(self.type): + prefix = "" + else: + prefix = "%{} = {} ".format(escape_name(self.name), + type_printer.name(self.type)) + + if any(self.operands): + operand_list = ["%{} => {}".format(escape_name(block.name), + value.as_operand(type_printer)) + for value, block in self.incoming()] + return "{}{} [{}]".format(prefix, self.opcode(), ", ".join(operand_list)) + else: + return "{}{} [???]".format(prefix, self.opcode()) + +class Terminator(Instruction): + """ + An SSA instruction that performs control flow. + """ + + def successors(self): + return [operand for operand in self.operands if isinstance(operand, BasicBlock)] + +class BasicBlock(NamedValue): + """ + A block of instructions with no control flow inside it. 
+ + :ivar instructions: (list of :class:`Instruction`) + """ + + def __init__(self, instructions, name=""): + super().__init__(TBasicBlock(), name) + self.instructions = [] + self.set_instructions(instructions)
+ + def set_instructions(self, new_insns): + for insn in self.instructions: + insn._detach() + self.instructions = new_insns + for insn in self.instructions: + insn.set_basic_block(self)
+ + def remove_from_parent(self): + if self.function is not None: + self.function.remove(self)
+ + def erase(self): + # self.instructions is updated while iterating + for insn in reversed(self.instructions): + insn.erase() + self.remove_from_parent() + # Check this after erasing instructions in case the block + # loops into itself. + assert not any(self.uses)
+ + def prepend(self, insn): + assert isinstance(insn, Instruction) + insn.set_basic_block(self) + self.instructions.insert(0, insn) + return insn
+ + def append(self, insn): + assert isinstance(insn, Instruction) + insn.set_basic_block(self) + self.instructions.append(insn) + return insn
+ + def index(self, insn): + return self.instructions.index(insn)
+ + def insert(self, insn, before): + assert isinstance(insn, Instruction) + insn.set_basic_block(self) + self.instructions.insert(self.index(before), insn) + return insn
+ + def remove(self, insn): + assert insn in self.instructions + insn._detach() + self.instructions.remove(insn) + return insn
+ + def replace(self, insn, replacement): + self.insert(replacement, before=insn) + self.remove(insn)
+ + def is_terminated(self): + return any(self.instructions) and isinstance(self.instructions[-1], Terminator)
+ + def terminator(self): + assert self.is_terminated() + return self.instructions[-1]
+ + def successors(self): + return self.terminator().successors()
+ + def predecessors(self): + return [use.basic_block for use in self.uses if isinstance(use, Terminator)]
+ + def as_entity(self, type_printer): + # Header + lines = ["{}:".format(escape_name(self.name))] + if self.function is not None: + lines[0] += " ; predecessors: {}".format( + ", ".join([escape_name(pred.name) for pred in self.predecessors()]))
+ + # Annotated instructions + loc = None + for insn in self.instructions: + if loc != insn.loc: + loc = insn.loc + + if loc is None: + lines.append("; <synthesized>") + else: + source_lines = loc.source_lines() + beg_col, end_col = loc.column(), loc.end().column() + source_lines[-1] = \ + source_lines[-1][:end_col] + "\x1b[0m" + source_lines[-1][end_col:] + source_lines[0] = \ + source_lines[0][:beg_col] + "\x1b[1;32m" + source_lines[0][beg_col:]
+ + line_desc = "{}:{}".format(loc.source_buffer.name, loc.line()) + lines += ["; {} {}".format(line_desc, line.rstrip("\n")) + for line in source_lines] + lines.append(" " + insn.as_entity(type_printer)) + + return "\n".join(lines)
+ + def __repr__(self): + return "<artiq.compiler.ir.BasicBlock {}>".format(repr(self.name))
+ +class Argument(NamedValue): + """ + A function argument. + """ + + def as_entity(self, type_printer): + return self.as_operand(type_printer)
+ +class Function: + """ + A function containing SSA IR.
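
    For example, a minimal sketch of building an empty function by hand
    (illustration only; ``OrderedDict`` is ``collections.OrderedDict``, and
    ``BasicBlock``, ``Return`` and ``Constant`` come from this file)::

        typ   = types.TFunction(OrderedDict(), OrderedDict(), builtins.TNone())
        func  = Function(typ, "noop", [])
        entry = BasicBlock([], "entry")
        func.add(entry)
        entry.append(Return(Constant(None, builtins.TNone())))
        assert func.entry() is entry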
+ + :ivar loc: (:class:`pythonparser.source.Range` or None) + source location of function definition + :ivar is_internal: + (bool) if True, the function should not be accessible from outside + the module it is contained in + """ + + def __init__(self, typ, name, arguments, loc=None): + self.type, self.name, self.loc = typ, name, loc + self.names, self.arguments, self.basic_blocks = set(), [], [] + self.next_name = 1 + self.set_arguments(arguments) + self.is_internal = False + + def _remove_name(self, name): + self.names.remove(name) + + def _add_name(self, base_name): + if base_name == "": + name = "v.{}".format(self.next_name) + self.next_name += 1 + elif base_name in self.names: + name = "{}.{}".format(base_name, self.next_name) + self.next_name += 1 + else: + name = base_name + + self.names.add(name) + return name + + def set_arguments(self, new_arguments): + for argument in self.arguments: + argument._set_function(None) + self.arguments = new_arguments + for argument in self.arguments: + argument._set_function(self) + + def add(self, basic_block): + basic_block._set_function(self) + self.basic_blocks.append(basic_block) + + def remove(self, basic_block): + basic_block._detach() + self.basic_blocks.remove(basic_block) + + def entry(self): + assert any(self.basic_blocks) + return self.basic_blocks[0] + + def exits(self): + return [block for block in self.basic_blocks if not any(block.successors())] + + def instructions(self): + for basic_block in self.basic_blocks: + yield from iter(basic_block.instructions) + + def as_entity(self, type_printer): + postorder = [] + visited = set() + def visit(block): + visited.add(block) + for next_block in block.successors(): + if next_block not in visited: + visit(next_block) + postorder.append(block) + + visit(self.entry()) + + lines = [] + lines.append("{} {}({}) {{ ; type: {}".format( + type_printer.name(self.type.ret), self.name, + ", ".join([arg.as_operand(type_printer) for arg in self.arguments]), + type_printer.name(self.type))) + + postorder_blocks = list(reversed(postorder)) + orphan_blocks = [block for block in self.basic_blocks if block not in postorder] + for block in postorder_blocks + orphan_blocks: + lines.append(block.as_entity(type_printer)) + + lines.append("}") + return "\n".join(lines) + + def __str__(self): + return self.as_entity(types.TypePrinter()) + +# Python-specific SSA IR classes + +class TEnvironment(types.TMono): + def __init__(self, vars, outer=None): + if outer is not None: + assert isinstance(outer, TEnvironment) + env = OrderedDict({"$outer": outer}) + env.update(vars) + else: + env = OrderedDict(vars) + + super().__init__("environment", env) + + def type_of(self, name): + if name in self.params: + return self.params[name].find() + elif "$outer" in self.params: + return self.params["$outer"].type_of(name) + else: + assert False + + def outermost(self): + if "$outer" in self.params: + return self.params["$outer"].outermost() + else: + return self + + """ + Add a new binding, ensuring hygiene. + + :returns: (string) mangled name + """ + def add(self, base_name, typ): + name, counter = base_name, 1 + while name in self.params or name == "": + if base_name == "": + name = str(counter) + else: + name = "{}.{}".format(name, counter) + counter += 1 + + self.params[name] = typ.find() + return name + +def is_environment(typ): + return isinstance(typ, TEnvironment) + +class EnvironmentArgument(Argument): + """ + A function argument specifying an outer environment. 
+ """ + + def as_operand(self, type_printer): + return "environment(...) %{}".format(escape_name(self.name)) + +class Alloc(Instruction): + """ + An instruction that allocates an object specified by + the type of the intsruction. + """ + + def __init__(self, operands, typ, name=""): + for operand in operands: assert isinstance(operand, Value) + super().__init__(operands, typ, name) + + def opcode(self): + return "alloc" + + def as_operand(self, type_printer): + if is_environment(self.type): + # Only show full environment in the instruction itself + return "%{}".format(escape_name(self.name)) + else: + return super().as_operand(type_printer) + +class GetLocal(Instruction): + """ + An intruction that loads a local variable from an environment, + possibly going through multiple levels of indirection. + + :ivar var_name: (string) variable name + """ + + """ + :param env: (:class:`Value`) local environment + :param var_name: (string) local variable name + """ + def __init__(self, env, var_name, name=""): + assert isinstance(env, Value) + assert isinstance(env.type, TEnvironment) + assert isinstance(var_name, str) + super().__init__([env], env.type.type_of(var_name), name) + self.var_name = var_name + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.var_name = self.var_name + return self_copy + + def opcode(self): + return "getlocal({})".format(repr(self.var_name)) + + def environment(self): + return self.operands[0] + +class SetLocal(Instruction): + """ + An intruction that stores a local variable into an environment, + possibly going through multiple levels of indirection. + + :ivar var_name: (string) variable name + """ + + """ + :param env: (:class:`Value`) local environment + :param var_name: (string) local variable name + :param value: (:class:`Value`) value to assign + """ + def __init__(self, env, var_name, value, name=""): + assert isinstance(env, Value) + assert isinstance(env.type, TEnvironment) + assert isinstance(var_name, str) + assert env.type.type_of(var_name) == value.type + assert isinstance(value, Value) + super().__init__([env, value], builtins.TNone(), name) + self.var_name = var_name + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.var_name = self.var_name + return self_copy + + def opcode(self): + return "setlocal({})".format(repr(self.var_name)) + + def environment(self): + return self.operands[0] + + def value(self): + return self.operands[1] + +class GetConstructor(Instruction): + """ + An intruction that loads a local variable with the given type + from an environment, possibly going through multiple levels of indirection. + + :ivar var_name: (string) variable name + """ + + """ + :param env: (:class:`Value`) local environment + :param var_name: (string) local variable name + :param var_type: (:class:`types.Type`) local variable type + """ + def __init__(self, env, var_name, var_type, name=""): + assert isinstance(env, Value) + assert isinstance(env.type, TEnvironment) + assert isinstance(var_name, str) + assert isinstance(var_type, types.Type) + super().__init__([env], var_type, name) + self.var_name = var_name + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.var_name = self.var_name + return self_copy + + def opcode(self): + return "getconstructor({})".format(repr(self.var_name)) + + def environment(self): + return self.operands[0] + +class GetAttr(Instruction): + """ + An intruction that loads an attribute from an object, + or extracts a tuple element. 
+ + :ivar attr: (string) variable name + """ + + """ + :param obj: (:class:`Value`) object or tuple + :param attr: (string or integer) attribute or index + """ + def __init__(self, obj, attr, name=""): + assert isinstance(obj, Value) + assert isinstance(attr, (str, int)) + if isinstance(attr, int): + assert isinstance(obj.type, types.TTuple) + typ = obj.type.elts[attr] + else: + typ = obj.type.attributes[attr] + super().__init__([obj], typ, name) + self.attr = attr + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.attr = self.attr + return self_copy + + def opcode(self): + return "getattr({})".format(repr(self.attr)) + + def object(self): + return self.operands[0] + +class SetAttr(Instruction): + """ + An intruction that stores an attribute to an object. + + :ivar attr: (string) variable name + """ + + """ + :param obj: (:class:`Value`) object or tuple + :param attr: (string or integer) attribute + :param value: (:class:`Value`) value to store + """ + def __init__(self, obj, attr, value, name=""): + assert isinstance(obj, Value) + assert isinstance(attr, (str, int)) + assert isinstance(value, Value) + if isinstance(attr, int): + assert value.type == obj.type.elts[attr].find() + else: + assert value.type == obj.type.attributes[attr].find() + super().__init__([obj, value], builtins.TNone(), name) + self.attr = attr + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.attr = self.attr + return self_copy + + def opcode(self): + return "setattr({})".format(repr(self.attr)) + + def object(self): + return self.operands[0] + + def value(self): + return self.operands[1] + +class GetElem(Instruction): + """ + An intruction that loads an element from a list. + """ + + """ + :param lst: (:class:`Value`) list + :param index: (:class:`Value`) index + """ + def __init__(self, lst, index, name=""): + assert isinstance(lst, Value) + assert isinstance(index, Value) + super().__init__([lst, index], builtins.get_iterable_elt(lst.type), name) + + def opcode(self): + return "getelem" + + def list(self): + return self.operands[0] + + def index(self): + return self.operands[1] + +class SetElem(Instruction): + """ + An intruction that stores an element into a list. + """ + + """ + :param lst: (:class:`Value`) list + :param index: (:class:`Value`) index + :param value: (:class:`Value`) value to store + """ + def __init__(self, lst, index, value, name=""): + assert isinstance(lst, Value) + assert isinstance(index, Value) + assert isinstance(value, Value) + assert builtins.get_iterable_elt(lst.type) == value.type.find() + super().__init__([lst, index, value], builtins.TNone(), name) + + def opcode(self): + return "setelem" + + def list(self): + return self.operands[0] + + def index(self): + return self.operands[1] + + def value(self): + return self.operands[2] + +class Coerce(Instruction): + """ + A coercion operation for numbers. + """ + + def __init__(self, value, typ, name=""): + assert isinstance(value, Value) + assert isinstance(typ, types.Type) + super().__init__([value], typ, name) + + def opcode(self): + return "coerce" + + def value(self): + return self.operands[0] + +class Arith(Instruction): + """ + An arithmetic operation on numbers. 
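
    For example (illustration only; ``block``, ``lhs`` and ``rhs`` are assumed
    to exist, both operands must share one numeric type, and ``ast`` is
    ``pythonparser.ast``)::

        total = block.append(Arith(ast.Add(loc=None), lhs, rhs))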
+ + :ivar op: (:class:`pythonparser.ast.operator`) operation + """ + + """ + :param op: (:class:`pythonparser.ast.operator`) operation + :param lhs: (:class:`Value`) left-hand operand + :param rhs: (:class:`Value`) right-hand operand + """ + def __init__(self, op, lhs, rhs, name=""): + assert isinstance(op, ast.operator) + assert isinstance(lhs, Value) + assert isinstance(rhs, Value) + assert lhs.type == rhs.type + super().__init__([lhs, rhs], lhs.type, name) + self.op = op + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.op = self.op + return self_copy + + def opcode(self): + return "arith({})".format(type(self.op).__name__) + + def lhs(self): + return self.operands[0] + + def rhs(self): + return self.operands[1] + +class Compare(Instruction): + """ + A comparison operation on numbers. + + :ivar op: (:class:`pythonparser.ast.cmpop`) operation + """ + + """ + :param op: (:class:`pythonparser.ast.cmpop`) operation + :param lhs: (:class:`Value`) left-hand operand + :param rhs: (:class:`Value`) right-hand operand + """ + def __init__(self, op, lhs, rhs, name=""): + assert isinstance(op, ast.cmpop) + assert isinstance(lhs, Value) + assert isinstance(rhs, Value) + assert lhs.type == rhs.type + super().__init__([lhs, rhs], builtins.TBool(), name) + self.op = op + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.op = self.op + return self_copy + + def opcode(self): + return "compare({})".format(type(self.op).__name__) + + def lhs(self): + return self.operands[0] + + def rhs(self): + return self.operands[1] + +class Builtin(Instruction): + """ + A builtin operation. Similar to a function call that + never raises. + + :ivar op: (string) operation name + """ + + """ + :param op: (string) operation name + """ + def __init__(self, op, operands, typ, name=""): + assert isinstance(op, str) + for operand in operands: assert isinstance(operand, Value) + super().__init__(operands, typ, name) + self.op = op + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.op = self.op + return self_copy + + def opcode(self): + return "builtin({})".format(self.op) + +class Closure(Instruction): + """ + A closure creation operation. + + :ivar target_function: (:class:`Function`) function to invoke + """ + + """ + :param func: (:class:`Function`) function + :param env: (:class:`Value`) outer environment + """ + def __init__(self, func, env, name=""): + assert isinstance(func, Function) + assert isinstance(env, Value) + assert is_environment(env.type) + super().__init__([env], func.type, name) + self.target_function = func + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.target_function = self.target_function + return self_copy + + def opcode(self): + return "closure({})".format(self.target_function.name) + + def environment(self): + return self.operands[0] + +class Call(Instruction): + """ + A function call operation. 
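
    For example (illustration only; ``block`` and ``arg`` are assumed to
    exist, and ``closure`` is an SSA value of a function type, e.g. the
    result of a :class:`Closure` instruction)::

        result = block.append(Call(closure, [arg]))   # result.type is closure.type.ret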
+ + :ivar static_target_function: (:class:`Function` or None) + statically resolved callee + """ + + """ + :param func: (:class:`Value`) function to call + :param args: (list of :class:`Value`) function arguments + """ + def __init__(self, func, args, name=""): + assert isinstance(func, Value) + for arg in args: assert isinstance(arg, Value) + super().__init__([func] + args, func.type.ret, name) + self.static_target_function = None + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.static_target_function = self.static_target_function + return self_copy + + def opcode(self): + return "call" + + def target_function(self): + return self.operands[0] + + def arguments(self): + return self.operands[1:] + + def as_entity(self, type_printer): + result = super().as_entity(type_printer) + if self.static_target_function is not None: + result += " ; calls {}".format(self.static_target_function.name) + return result + +class Select(Instruction): + """ + A conditional select instruction. + """ + + """ + :param cond: (:class:`Value`) select condition + :param if_true: (:class:`Value`) value of select if condition is truthful + :param if_false: (:class:`Value`) value of select if condition is falseful + """ + def __init__(self, cond, if_true, if_false, name=""): + assert isinstance(cond, Value) + assert builtins.is_bool(cond.type) + assert isinstance(if_true, Value) + assert isinstance(if_false, Value) + assert if_true.type == if_false.type + super().__init__([cond, if_true, if_false], if_true.type, name) + + def opcode(self): + return "select" + + def condition(self): + return self.operands[0] + + def if_true(self): + return self.operands[1] + + def if_false(self): + return self.operands[2] + +class Quote(Instruction): + """ + A quote operation. Returns a host interpreter value as a constant. + + :ivar value: (string) operation name + """ + + """ + :param value: (string) operation name + """ + def __init__(self, value, typ, name=""): + super().__init__([], typ, name) + self.value = value + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.value = self.value + return self_copy + + def opcode(self): + return "quote({})".format(repr(self.value)) + +class Branch(Terminator): + """ + An unconditional branch instruction. + """ + + """ + :param target: (:class:`BasicBlock`) branch target + """ + def __init__(self, target, name=""): + assert isinstance(target, BasicBlock) + super().__init__([target], builtins.TNone(), name) + + def opcode(self): + return "branch" + + def target(self): + return self.operands[0] + + def set_target(self, new_target): + self.operands[0].uses.remove(self) + self.operands[0] = new_target + self.operands[0].uses.add(self) + +class BranchIf(Terminator): + """ + A conditional branch instruction. 
+ """ + + """ + :param cond: (:class:`Value`) branch condition + :param if_true: (:class:`BasicBlock`) branch target if condition is truthful + :param if_false: (:class:`BasicBlock`) branch target if condition is falseful + """ + def __init__(self, cond, if_true, if_false, name=""): + assert isinstance(cond, Value) + assert builtins.is_bool(cond.type) + assert isinstance(if_true, BasicBlock) + assert isinstance(if_false, BasicBlock) + assert if_true != if_false # use Branch instead + super().__init__([cond, if_true, if_false], builtins.TNone(), name) + + def opcode(self): + return "branchif" + + def condition(self): + return self.operands[0] + + def if_true(self): + return self.operands[1] + + def if_false(self): + return self.operands[2] + +class IndirectBranch(Terminator): + """ + An indirect branch instruction. + """ + + """ + :param target: (:class:`Value`) branch target + :param destinations: (list of :class:`BasicBlock`) all possible values of `target` + """ + def __init__(self, target, destinations, name=""): + assert isinstance(target, Value) + assert all([isinstance(dest, BasicBlock) for dest in destinations]) + super().__init__([target] + destinations, builtins.TNone(), name) + + def opcode(self): + return "indirectbranch" + + def target(self): + return self.operands[0] + + def destinations(self): + return self.operands[1:] + + def add_destination(self, destination): + destination.uses.add(self) + self.operands.append(destination) + + def _operands_as_string(self, type_printer): + return "{}, [{}]".format(self.operands[0].as_operand(type_printer), + ", ".join([dest.as_operand(type_printer) + for dest in self.operands[1:]])) + +class Return(Terminator): + """ + A return instruction. + """ + + """ + :param value: (:class:`Value`) return value + """ + def __init__(self, value, name=""): + assert isinstance(value, Value) + super().__init__([value], builtins.TNone(), name) + + def opcode(self): + return "return" + + def value(self): + return self.operands[0] + +class Unreachable(Terminator): + """ + An instruction used to mark unreachable branches. + """ + + """ + :param target: (:class:`BasicBlock`) branch target + """ + def __init__(self, name=""): + super().__init__([], builtins.TNone(), name) + + def opcode(self): + return "unreachable" + +class Raise(Terminator): + """ + A raise instruction. + """ + + """ + :param value: (:class:`Value`) exception value + :param exn: (:class:`BasicBlock` or None) exceptional target + """ + def __init__(self, value=None, exn=None, name=""): + assert isinstance(value, Value) + operands = [value] + if exn is not None: + assert isinstance(exn, BasicBlock) + operands.append(exn) + super().__init__(operands, builtins.TNone(), name) + + def opcode(self): + return "raise" + + def value(self): + return self.operands[0] + + def exception_target(self): + if len(self.operands) > 1: + return self.operands[1] + +class Reraise(Terminator): + """ + A reraise instruction. + """ + + """ + :param exn: (:class:`BasicBlock` or None) exceptional target + """ + def __init__(self, exn=None, name=""): + operands = [] + if exn is not None: + assert isinstance(exn, BasicBlock) + operands.append(exn) + super().__init__(operands, builtins.TNone(), name) + + def opcode(self): + return "reraise" + + def exception_target(self): + if len(self.operands) > 0: + return self.operands[0] + +class Invoke(Terminator): + """ + A function call operation that supports exception handling. 
+ + :ivar static_target_function: (:class:`Function` or None) + statically resolved callee + """ + + """ + :param func: (:class:`Value`) function to call + :param args: (list of :class:`Value`) function arguments + :param normal: (:class:`BasicBlock`) normal target + :param exn: (:class:`BasicBlock`) exceptional target + """ + def __init__(self, func, args, normal, exn, name=""): + assert isinstance(func, Value) + for arg in args: assert isinstance(arg, Value) + assert isinstance(normal, BasicBlock) + assert isinstance(exn, BasicBlock) + super().__init__([func] + args + [normal, exn], func.type.ret, name) + self.static_target_function = None + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.static_target_function = self.static_target_function + return self_copy + + def opcode(self): + return "invoke" + + def target_function(self): + return self.operands[0] + + def arguments(self): + return self.operands[1:-2] + + def normal_target(self): + return self.operands[-2] + + def exception_target(self): + return self.operands[-1] + + def _operands_as_string(self, type_printer): + result = ", ".join([operand.as_operand(type_printer) for operand in self.operands[:-2]]) + result += " to {} unwind {}".format(self.operands[-2].as_operand(type_printer), + self.operands[-1].as_operand(type_printer)) + return result + + def as_entity(self, type_printer): + result = super().as_entity(type_printer) + if self.static_target_function is not None: + result += " ; calls {}".format(self.static_target_function.name) + return result + +class LandingPad(Terminator): + """ + An instruction that gives an incoming exception a name and + dispatches it according to its type. + + Once dispatched, the exception should be cast to its proper + type by calling the "exncast" builtin on the landing pad value. + + :ivar types: (a list of :class:`builtins.TException`) + exception types corresponding to the basic block operands + """ + + def __init__(self, cleanup, name=""): + super().__init__([cleanup], builtins.TException(), name) + self.types = [] + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.types = list(self.types) + return self_copy + + def opcode(self): + return "landingpad" + + def cleanup(self): + return self.operands[0] + + def clauses(self): + return zip(self.operands[1:], self.types) + + def add_clause(self, target, typ): + assert isinstance(target, BasicBlock) + assert typ is None or builtins.is_exception(typ) + self.operands.append(target) + self.types.append(typ.find() if typ is not None else None) + target.uses.add(self) + + def _operands_as_string(self, type_printer): + table = [] + for target, typ in self.clauses(): + if typ is None: + table.append("... => {}".format(target.as_operand(type_printer))) + else: + table.append("{} => {}".format(type_printer.name(typ), + target.as_operand(type_printer))) + return "cleanup {}, [{}]".format(self.cleanup().as_operand(type_printer), + ", ".join(table)) + +class Delay(Terminator): + """ + A delay operation. Ties an :class:`iodelay.Expr` to SSA values so that + inlining could lead to the expression folding to a constant. 
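
    For example (illustration only; ``expr`` is assumed to be an
    :class:`iodelay.Expr` over the variable ``"t"``, ``t_value`` the SSA value
    bound to that variable, and ``delay_call`` the :class:`Call` or delay
    builtin this delay was decomposed from)::

        block.append(Delay(expr, {"t": t_value}, delay_call, next_block))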
+ + :ivar expr: (:class:`iodelay.Expr`) expression + :ivar var_names: (list of string) + iodelay variable names corresponding to operands + """ + + """ + :param expr: (:class:`iodelay.Expr`) expression + :param substs: (dict of str to :class:`Value`) + SSA values corresponding to iodelay variable names + :param call: (:class:`Call` or ``Constant(None, builtins.TNone())``) + the call instruction that caused this delay, if any + :param target: (:class:`BasicBlock`) branch target + """ + def __init__(self, expr, substs, decomposition, target, name=""): + for var_name in substs: assert isinstance(var_name, str) + assert isinstance(decomposition, Call) or \ + isinstance(decomposition, Builtin) and decomposition.op in ("delay", "delay_mu") + assert isinstance(target, BasicBlock) + super().__init__([decomposition, target, *substs.values()], builtins.TNone(), name) + self.expr = expr + self.var_names = list(substs.keys()) + + def copy(self, mapper): + self_copy = super().copy(mapper) + self_copy.expr = self.expr + self_copy.var_names = list(self.var_names) + return self_copy + + def decomposition(self): + return self.operands[0] + + def set_decomposition(self, new_decomposition): + self.operands[0].uses.remove(self) + self.operands[0] = new_decomposition + self.operands[0].uses.add(self) + + def target(self): + return self.operands[1] + + def set_target(self, new_target): + self.operands[1].uses.remove(self) + self.operands[1] = new_target + self.operands[1].uses.add(self) + + def substs(self): + return {key: value for key, value in zip(self.var_names, self.operands[2:])} + + def _operands_as_string(self, type_printer): + substs = self.substs() + substs_as_strings = [] + for var_name in substs: + substs_as_strings.append("{} = {}".format(var_name, substs[var_name])) + result = "[{}]".format(", ".join(substs_as_strings)) + result += ", decomp {}, to {}".format(self.decomposition().as_operand(type_printer), + self.target().as_operand(type_printer)) + return result + + def opcode(self): + return "delay({})".format(self.expr) + +class Parallel(Terminator): + """ + An instruction that schedules several threads of execution + in parallel. + """ + + def __init__(self, destinations, name=""): + super().__init__(destinations, builtins.TNone(), name) + + def opcode(self): + return "parallel" + + def destinations(self): + return self.operands + + def add_destination(self, destination): + destination.uses.add(self) + self.operands.append(destination) diff --git a/artiq/compiler/module.py b/artiq/compiler/module.py new file mode 100644 index 000000000..03ddaae91 --- /dev/null +++ b/artiq/compiler/module.py @@ -0,0 +1,94 @@ +""" +The :class:`Module` class encapsulates a single Python module, +which corresponds to a single ARTIQ translation unit (one LLVM +bitcode file and one object file, unless LTO is used). +A :class:`Module` can be created from a typed AST. + +The :class:`Source` class parses a single source file or +string and infers types for it using a trivial :module:`prelude`. +""" + +import os +from pythonparser import source, diagnostic, parse_buffer +from . 
import prelude, types, transforms, analyses, validators + + class Source: + def __init__(self, source_buffer, engine=None): + if engine is None: + self.engine = diagnostic.Engine(all_errors_are_fatal=True) + else: + self.engine = engine + + self.object_map = None + + self.name, _ = os.path.splitext(os.path.basename(source_buffer.name))
+ + asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine, + prelude=prelude.globals()) + inferencer = transforms.Inferencer(engine=engine) + + self.parsetree, self.comments = parse_buffer(source_buffer, engine=engine) + self.typedtree = asttyped_rewriter.visit(self.parsetree) + self.globals = asttyped_rewriter.globals + inferencer.visit(self.typedtree)
+ + @classmethod + def from_string(cls, source_string, name="input.py", first_line=1, engine=None): + return cls(source.Buffer(source_string + "\n", name, first_line), engine=engine)
+ + @classmethod + def from_filename(cls, filename, engine=None): + with open(filename) as f: + return cls(source.Buffer(f.read(), filename, 1), engine=engine)
+ +class Module: + def __init__(self, src, ref_period=1e-6): + self.engine = src.engine + self.object_map = src.object_map
+ + int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine) + inferencer = transforms.Inferencer(engine=self.engine) + monomorphism_validator = validators.MonomorphismValidator(engine=self.engine) + escape_validator = validators.EscapeValidator(engine=self.engine) + iodelay_estimator = transforms.IODelayEstimator(engine=self.engine, + ref_period=ref_period) + artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine, + module_name=src.name, + ref_period=ref_period) + dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine) + local_access_validator = validators.LocalAccessValidator(engine=self.engine) + devirtualization = analyses.Devirtualization() + interleaver = transforms.Interleaver(engine=self.engine)
+ + self.name = src.name + self.globals = src.globals + int_monomorphizer.visit(src.typedtree) + inferencer.visit(src.typedtree) + monomorphism_validator.visit(src.typedtree) + escape_validator.visit(src.typedtree) + iodelay_estimator.visit_fixpoint(src.typedtree) + devirtualization.visit(src.typedtree) + self.artiq_ir = artiq_ir_generator.visit(src.typedtree) + artiq_ir_generator.annotate_calls(devirtualization) + dead_code_eliminator.process(self.artiq_ir) + local_access_validator.process(self.artiq_ir) + interleaver.process(self.artiq_ir)
+ + def build_llvm_ir(self, target): + """Compile the module to LLVM IR for the specified target.""" + llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine, + module_name=self.name, target=target, + object_map=self.object_map) + return llvm_ir_generator.process(self.artiq_ir)
+ + def entry_point(self): + """Return the name of the function that is the entry point of this module.""" + if self.name != "": + return self.name + ".__modinit__" + else: + return "__modinit__"
+ + def __repr__(self): + printer = types.TypePrinter() + globals = ["%s: %s" % (var, printer.name(self.globals[var])) for var in self.globals] + return "<artiq.compiler.Module %s {\n  %s\n}>" % (repr(self.name), ",\n ".join(globals)) diff --git a/artiq/compiler/prelude.py b/artiq/compiler/prelude.py new file mode 100644 index 000000000..26870cb52 --- /dev/null +++ b/artiq/compiler/prelude.py @@ -0,0 +1,44 @@ +""" +The :mod:`prelude` module contains the initial global environment +in which ARTIQ kernels are evaluated. +""" + +from .
import builtins + +def globals(): + return { + # Value constructors + "bool": builtins.fn_bool(), + "int": builtins.fn_int(), + "float": builtins.fn_float(), + "list": builtins.fn_list(), + "range": builtins.fn_range(), + + # Exception constructors + "Exception": builtins.fn_Exception(), + "IndexError": builtins.fn_IndexError(), + "ValueError": builtins.fn_ValueError(), + "ZeroDivisionError": builtins.fn_ZeroDivisionError(), + + # Built-in Python functions + "len": builtins.fn_len(), + "round": builtins.fn_round(), + "print": builtins.fn_print(), + + # ARTIQ decorators + "kernel": builtins.fn_kernel(), + + # ARTIQ context managers + "parallel": builtins.fn_parallel(), + "sequential": builtins.fn_sequential(), + + # ARTIQ time management functions + "now": builtins.fn_now(), + "delay": builtins.fn_delay(), + "at": builtins.fn_at(), + "now_mu": builtins.fn_now_mu(), + "delay_mu": builtins.fn_delay_mu(), + "at_mu": builtins.fn_at_mu(), + "mu_to_seconds": builtins.fn_mu_to_seconds(), + "seconds_to_mu": builtins.fn_seconds_to_mu(), + } diff --git a/artiq/compiler/targets.py b/artiq/compiler/targets.py new file mode 100644 index 000000000..8889ee67c --- /dev/null +++ b/artiq/compiler/targets.py @@ -0,0 +1,169 @@ +import os, sys, tempfile, subprocess +from artiq.compiler import types +from llvmlite_artiq import ir as ll, binding as llvm + +llvm.initialize() +llvm.initialize_all_targets() +llvm.initialize_all_asmprinters() + +class RunTool: + def __init__(self, pattern, **tempdata): + self.files = [] + self.pattern = pattern + self.tempdata = tempdata + + def maketemp(self, data): + f = tempfile.NamedTemporaryFile() + f.write(data) + f.flush() + self.files.append(f) + return f + + def __enter__(self): + tempfiles = {} + tempnames = {} + for key in self.tempdata: + tempfiles[key] = self.maketemp(self.tempdata[key]) + tempnames[key] = tempfiles[key].name + + cmdline = [] + for argument in self.pattern: + cmdline.append(argument.format(**tempnames)) + + process = subprocess.Popen(cmdline, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + if process.returncode != 0: + raise Exception("{} invocation failed: {}". + format(cmdline[0], stderr.decode('utf-8'))) + + tempfiles["__stdout__"] = stdout.decode('utf-8') + return tempfiles + + def __exit__(self, exc_typ, exc_value, exc_trace): + for f in self.files: + f.close() + +class Target: + """ + A description of the target environment where the binaries + generated by the ARTIQ compiler will be deployed. + + :var triple: (string) + LLVM target triple, e.g. ``"or1k"`` + :var data_layout: (string) + LLVM target data layout, e.g. ``"E-m:e-p:32:32-i64:32-f64:32-v64:32-v128:32-a:0:32-n32"`` + :var features: (list of string) + LLVM target CPU features, e.g. ``["mul", "div", "ffl1"]`` + :var print_function: (string) + Name of a formatted print functions (with the signature of ``printf``) + provided by the target, e.g. ``"printf"``. 
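
    A concrete target overrides these class attributes, as :class:`OR1KTarget`
    at the end of this file does; a typical use is then (illustration only;
    ``module`` is assumed to be a compiled :class:`..module.Module`)::

        library = OR1KTarget().compile_and_link([module])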
+ """ + triple = "unknown" + data_layout = "" + features = [] + print_function = "printf" + + + def __init__(self): + self.llcontext = ll.Context() + + def compile(self, module): + """Compile the module to a relocatable object for this target.""" + + if os.getenv("ARTIQ_DUMP_SIG"): + print("====== MODULE_SIGNATURE DUMP ======", file=sys.stderr) + print(module, file=sys.stderr) + + if os.getenv("ARTIQ_DUMP_IR"): + print("====== ARTIQ IR DUMP ======", file=sys.stderr) + type_printer = types.TypePrinter() + for function in module.artiq_ir: + print(function.as_entity(type_printer), file=sys.stderr) + + llmod = module.build_llvm_ir(self) + llparsedmod = llvm.parse_assembly(str(llmod)) + llparsedmod.verify() + + if os.getenv("ARTIQ_DUMP_LLVM"): + print("====== LLVM IR DUMP ======", file=sys.stderr) + print(str(llparsedmod), file=sys.stderr) + + llpassmgrbuilder = llvm.create_pass_manager_builder() + llpassmgrbuilder.opt_level = 2 # -O2 + llpassmgrbuilder.size_level = 1 # -Os + + llpassmgr = llvm.create_module_pass_manager() + llpassmgrbuilder.populate(llpassmgr) + llpassmgr.run(llparsedmod) + + if os.getenv("ARTIQ_DUMP_LLVM"): + print("====== LLVM IR DUMP (OPTIMIZED) ======", file=sys.stderr) + print(str(llparsedmod), file=sys.stderr) + + lltarget = llvm.Target.from_triple(self.triple) + llmachine = lltarget.create_target_machine( + features=",".join(["+{}".format(f) for f in self.features]), + reloc="pic", codemodel="default") + + if os.getenv("ARTIQ_DUMP_ASSEMBLY"): + print("====== ASSEMBLY DUMP ======", file=sys.stderr) + print(llmachine.emit_assembly(llparsedmod), file=sys.stderr) + + return llmachine.emit_object(llparsedmod) + + def link(self, objects, init_fn): + """Link the relocatable objects into a shared library for this target.""" + with RunTool([self.triple + "-ld", "-shared", "--eh-frame-hdr", "-init", init_fn] + + ["{{obj{}}}".format(index) for index in range(len(objects))] + + ["-o", "{output}"], + output=b"", + **{"obj{}".format(index): obj for index, obj in enumerate(objects)}) \ + as results: + library = results["output"].read() + + if os.getenv("ARTIQ_DUMP_ELF"): + shlib_temp = tempfile.NamedTemporaryFile(suffix=".so", delete=False) + shlib_temp.write(library) + shlib_temp.close() + print("====== SHARED LIBRARY DUMP ======", file=sys.stderr) + print("Shared library dumped as {}".format(shlib_temp.name), file=sys.stderr) + + return library + + def compile_and_link(self, modules): + return self.link([self.compile(module) for module in modules], + init_fn=modules[0].entry_point()) + + def strip(self, library): + with RunTool([self.triple + "-strip", "--strip-debug", "{library}", "-o", "{output}"], + library=library, output=b"") \ + as results: + return results["output"].read() + + def symbolize(self, library, addresses): + # Addresses point one instruction past the jump; offset them back by 1. 
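        # For each address, addr2line with --functions prints the function name
        # on one line and "file:line" on the next; the zip() below relies on
        # that two-lines-per-address layout to pair the results back up with
        # the original addresses.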
+ offset_addresses = [hex(addr - 1) for addr in addresses] + with RunTool([self.triple + "-addr2line", "--functions", "--inlines", + "--exe={library}"] + offset_addresses, + library=library) \ + as results: + lines = results["__stdout__"].rstrip().split("\n") + backtrace = [] + for function_name, location, address in zip(lines[::2], lines[1::2], addresses): + filename, line = location.rsplit(":", 1) + if filename == "??": + continue + # can't get column out of addr2line D: + backtrace.append((filename, int(line), -1, function_name, address)) + return backtrace + +class NativeTarget(Target): + def __init__(self): + super().__init__() + self.triple = llvm.get_default_triple() + +class OR1KTarget(Target): + triple = "or1k-linux" + data_layout = "E-m:e-p:32:32-i64:32-f64:32-v64:32-v128:32-a:0:32-n32" + features = ["mul", "div", "ffl1", "cmov", "addc"] + print_function = "lognonl" diff --git a/artiq/compiler/testbench/__init__.py b/artiq/compiler/testbench/__init__.py new file mode 100644 index 000000000..68ea51d7b --- /dev/null +++ b/artiq/compiler/testbench/__init__.py @@ -0,0 +1,21 @@ +import time, cProfile as profile, pstats + +def benchmark(f, name): + profiler = profile.Profile() + profiler.enable() + + start = time.perf_counter() + end = 0 + runs = 0 + while end - start < 5 or runs < 10: + f() + runs += 1 + end = time.perf_counter() + + profiler.create_stats() + + print("{} {} runs: {:.2f}s, {:.2f}ms/run".format( + runs, name, end - start, (end - start) / runs * 1000)) + + stats = pstats.Stats(profiler) + stats.strip_dirs().sort_stats('time').print_stats(10) diff --git a/artiq/compiler/testbench/embedding.py b/artiq/compiler/testbench/embedding.py new file mode 100644 index 000000000..abe5f2fd3 --- /dev/null +++ b/artiq/compiler/testbench/embedding.py @@ -0,0 +1,35 @@ +import sys, os + +from artiq.master.databases import DeviceDB +from artiq.master.worker_db import DeviceManager + +from artiq.coredevice.core import Core, CompileError + +def main(): + if len(sys.argv) > 1 and sys.argv[1] == "+compile": + del sys.argv[1] + compile_only = True + else: + compile_only = False + + ddb_path = os.path.join(os.path.dirname(sys.argv[1]), "device_db.pyon") + dmgr = DeviceManager(DeviceDB(ddb_path)) + + with open(sys.argv[1]) as f: + testcase_code = compile(f.read(), f.name, "exec") + testcase_vars = {'__name__': 'testbench', 'dmgr': dmgr} + exec(testcase_code, testcase_vars) + + try: + core = dmgr.get("core") + if compile_only: + core.compile(testcase_vars["entrypoint"], (), {}) + else: + core.run(testcase_vars["entrypoint"], (), {}) + print(core.comm.get_log()) + core.comm.clear_log() + except CompileError as error: + print("\n".join(error.__cause__.diagnostic.render(only_line=True))) + +if __name__ == "__main__": + main() diff --git a/artiq/compiler/testbench/inferencer.py b/artiq/compiler/testbench/inferencer.py new file mode 100644 index 000000000..174baf8f4 --- /dev/null +++ b/artiq/compiler/testbench/inferencer.py @@ -0,0 +1,83 @@ +import sys, fileinput, os +from pythonparser import source, diagnostic, algorithm, parse_buffer +from .. import prelude, types +from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer + +class Printer(algorithm.Visitor): + """ + :class:`Printer` prints ``:`` and the node type after every typed node, + and ``->`` and the node type before the colon in a function definition. + + In almost all cases (except function definition) this does not result + in valid Python syntax. 
+ + :ivar rewriter: (:class:`pythonparser.source.Rewriter`) rewriter instance + """ + + def __init__(self, buf): + self.rewriter = source.Rewriter(buf) + self.type_printer = types.TypePrinter() + + def rewrite(self): + return self.rewriter.rewrite() + + def visit_FunctionDefT(self, node): + super().generic_visit(node) + + self.rewriter.insert_before(node.colon_loc, + "->{}".format(self.type_printer.name(node.return_type))) + + def visit_ExceptHandlerT(self, node): + super().generic_visit(node) + + if node.name_loc: + self.rewriter.insert_after(node.name_loc, + ":{}".format(self.type_printer.name(node.name_type))) + + def generic_visit(self, node): + super().generic_visit(node) + + if hasattr(node, "type"): + self.rewriter.insert_after(node.loc, + ":{}".format(self.type_printer.name(node.type))) + +def main(): + if len(sys.argv) > 1 and sys.argv[1] == "+mono": + del sys.argv[1] + monomorphize = True + else: + monomorphize = False + + if len(sys.argv) > 1 and sys.argv[1] == "+diag": + del sys.argv[1] + def process_diagnostic(diag): + print("\n".join(diag.render(only_line=True))) + if diag.level == "fatal": + exit() + else: + def process_diagnostic(diag): + print("\n".join(diag.render())) + if diag.level in ("fatal", "error"): + exit(1) + + engine = diagnostic.Engine() + engine.process = process_diagnostic + + buf = source.Buffer("".join(fileinput.input()).expandtabs(), + os.path.basename(fileinput.filename())) + parsed, comments = parse_buffer(buf, engine=engine) + typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed) + Inferencer(engine=engine).visit(typed) + if monomorphize: + IntMonomorphizer(engine=engine).visit(typed) + Inferencer(engine=engine).visit(typed) + + printer = Printer(buf) + printer.visit(typed) + for comment in comments: + if comment.text.find("CHECK") >= 0: + printer.rewriter.remove(comment.loc) + print(printer.rewrite().source) + +if __name__ == "__main__": + main() diff --git a/artiq/compiler/testbench/irgen.py b/artiq/compiler/testbench/irgen.py new file mode 100644 index 000000000..4add27c20 --- /dev/null +++ b/artiq/compiler/testbench/irgen.py @@ -0,0 +1,19 @@ +import sys, fileinput +from pythonparser import diagnostic +from .. import Module, Source + +def main(): + def process_diagnostic(diag): + print("\n".join(diag.render())) + if diag.level in ("fatal", "error"): + exit(1) + + engine = diagnostic.Engine() + engine.process = process_diagnostic + + mod = Module(Source.from_string("".join(fileinput.input()).expandtabs(), engine=engine)) + for fn in mod.artiq_ir: + print(fn) + +if __name__ == "__main__": + main() diff --git a/artiq/compiler/testbench/jit.py b/artiq/compiler/testbench/jit.py new file mode 100644 index 000000000..c1c90dd56 --- /dev/null +++ b/artiq/compiler/testbench/jit.py @@ -0,0 +1,35 @@ +import os, sys, fileinput, ctypes +from pythonparser import diagnostic +from llvmlite_artiq import binding as llvm +from .. 
import Module, Source +from ..targets import NativeTarget + +def main(): + libartiq_support = os.getenv('LIBARTIQ_SUPPORT') + if libartiq_support is not None: + llvm.load_library_permanently(libartiq_support) + + def process_diagnostic(diag): + print("\n".join(diag.render())) + if diag.level in ("fatal", "error"): + exit(1) + + engine = diagnostic.Engine() + engine.process = process_diagnostic + + source = "".join(fileinput.input()) + source = source.replace("#ARTIQ#", "") + mod = Module(Source.from_string(source.expandtabs(), engine=engine)) + + target = NativeTarget() + llmod = mod.build_llvm_ir(target) + llparsedmod = llvm.parse_assembly(str(llmod)) + llparsedmod.verify() + + llmachine = llvm.Target.from_triple(target.triple).create_target_machine() + lljit = llvm.create_mcjit_compiler(llparsedmod, llmachine) + llmain = lljit.get_pointer_to_global(llparsedmod.get_function(llmod.name + ".__modinit__")) + ctypes.CFUNCTYPE(None)(llmain)() + +if __name__ == "__main__": + main() diff --git a/artiq/compiler/testbench/llvmgen.py b/artiq/compiler/testbench/llvmgen.py new file mode 100644 index 000000000..33500ec8b --- /dev/null +++ b/artiq/compiler/testbench/llvmgen.py @@ -0,0 +1,30 @@ +import sys, fileinput +from pythonparser import diagnostic +from llvmlite_artiq import ir as ll +from .. import Module, Source +from ..targets import NativeTarget + +def main(): + def process_diagnostic(diag): + print("\n".join(diag.render())) + if diag.level in ("fatal", "error"): + exit(1) + + engine = diagnostic.Engine() + engine.process = process_diagnostic + + mod = Module(Source.from_string("".join(fileinput.input()).expandtabs(), engine=engine)) + + target = NativeTarget() + llmod = mod.build_llvm_ir(target=target) + + # Add main so that the result can be executed with lli + llmain = ll.Function(llmod, ll.FunctionType(ll.VoidType(), []), "main") + llbuilder = ll.IRBuilder(llmain.append_basic_block("entry")) + llbuilder.call(llmod.get_global(llmod.name + ".__modinit__"), []) + llbuilder.ret_void() + + print(llmod) + +if __name__ == "__main__": + main() diff --git a/artiq/compiler/testbench/perf.py b/artiq/compiler/testbench/perf.py new file mode 100644 index 000000000..c01e02376 --- /dev/null +++ b/artiq/compiler/testbench/perf.py @@ -0,0 +1,37 @@ +import sys, os +from pythonparser import diagnostic +from .. import Module, Source +from ..targets import OR1KTarget +from . 
import benchmark + +def main(): + if not len(sys.argv) == 2: + print("Expected exactly one module filename", file=sys.stderr) + exit(1) + + def process_diagnostic(diag): + print("\n".join(diag.render()), file=sys.stderr) + if diag.level in ("fatal", "error"): + exit(1) + + engine = diagnostic.Engine() + engine.process = process_diagnostic + + # Make sure everything's valid + filename = sys.argv[1] + with open(filename) as f: + code = f.read() + source = Source.from_string(code, filename, engine=engine) + module = Module(source) + + benchmark(lambda: Source.from_string(code, filename), + "ARTIQ parsing and inference") + + benchmark(lambda: Module(source), + "ARTIQ transforms and validators") + + benchmark(lambda: OR1KTarget().compile_and_link([module]), + "LLVM optimization and linking") + +if __name__ == "__main__": + main() diff --git a/artiq/compiler/testbench/perf_embedding.py b/artiq/compiler/testbench/perf_embedding.py new file mode 100644 index 000000000..258978eac --- /dev/null +++ b/artiq/compiler/testbench/perf_embedding.py @@ -0,0 +1,52 @@ +import sys, os +from pythonparser import diagnostic +from ...protocols.file_db import FlatFileDB +from ...master.worker_db import DeviceManager +from .. import Module +from ..embedding import Stitcher +from ..targets import OR1KTarget +from . import benchmark + +def main(): + if not len(sys.argv) == 2: + print("Expected exactly one module filename", file=sys.stderr) + exit(1) + + def process_diagnostic(diag): + print("\n".join(diag.render()), file=sys.stderr) + if diag.level in ("fatal", "error"): + exit(1) + + engine = diagnostic.Engine() + engine.process = process_diagnostic + + with open(sys.argv[1]) as f: + testcase_code = compile(f.read(), f.name, "exec") + testcase_vars = {'__name__': 'testbench'} + exec(testcase_code, testcase_vars) + + ddb_path = os.path.join(os.path.dirname(sys.argv[1]), "ddb.pyon") + dmgr = DeviceManager(FlatFileDB(ddb_path)) + + def embed(): + experiment = testcase_vars["Benchmark"](dmgr) + + stitcher = Stitcher() + stitcher.stitch_call(experiment.run, (experiment,), {}) + stitcher.finalize() + return stitcher + + stitcher = embed() + module = Module(stitcher) + + benchmark(lambda: embed(), + "ARTIQ embedding") + + benchmark(lambda: Module(stitcher), + "ARTIQ transforms and validators") + + benchmark(lambda: OR1KTarget().compile_and_link([module]), + "LLVM optimization and linking") + +if __name__ == "__main__": + main() diff --git a/artiq/compiler/testbench/shlib.py b/artiq/compiler/testbench/shlib.py new file mode 100644 index 000000000..97c19f11b --- /dev/null +++ b/artiq/compiler/testbench/shlib.py @@ -0,0 +1,30 @@ +import sys, os +from pythonparser import diagnostic +from .. 
import Module, Source +from ..targets import OR1KTarget + +def main(): + if not len(sys.argv) > 1: + print("Expected at least one module filename", file=sys.stderr) + exit(1) + + def process_diagnostic(diag): + print("\n".join(diag.render()), file=sys.stderr) + if diag.level in ("fatal", "error"): + exit(1) + + engine = diagnostic.Engine() + engine.process = process_diagnostic + + modules = [] + for filename in sys.argv[1:]: + modules.append(Module(Source.from_filename(filename, engine=engine))) + + llobj = OR1KTarget().compile_and_link(modules) + + basename, ext = os.path.splitext(sys.argv[-1]) + with open(basename + ".so", "wb") as f: + f.write(llobj) + +if __name__ == "__main__": + main() diff --git a/artiq/compiler/testbench/signature.py b/artiq/compiler/testbench/signature.py new file mode 100644 index 000000000..98d4687fb --- /dev/null +++ b/artiq/compiler/testbench/signature.py @@ -0,0 +1,43 @@ +import sys, fileinput +from pythonparser import diagnostic +from .. import types, iodelay, Module, Source + +def main(): + if len(sys.argv) > 1 and sys.argv[1] == "+diag": + del sys.argv[1] + diag = True + def process_diagnostic(diag): + print("\n".join(diag.render(only_line=True))) + if diag.level == "fatal": + exit() + else: + diag = False + def process_diagnostic(diag): + print("\n".join(diag.render(colored=True))) + if diag.level in ("fatal", "error"): + exit(1) + + if len(sys.argv) > 1 and sys.argv[1] == "+delay": + del sys.argv[1] + force_delays = True + else: + force_delays = False + + engine = diagnostic.Engine() + engine.process = process_diagnostic + + try: + mod = Module(Source.from_string("".join(fileinput.input()).expandtabs(), engine=engine)) + + if force_delays: + for var in mod.globals: + typ = mod.globals[var].find() + if types.is_function(typ) and types.is_indeterminate_delay(typ.delay): + process_diagnostic(typ.delay.find().cause) + + print(repr(mod)) + except: + if not diag: raise + +if __name__ == "__main__": + main() diff --git a/artiq/compiler/transforms/__init__.py b/artiq/compiler/transforms/__init__.py new file mode 100644 index 000000000..665fd3ea1 --- /dev/null +++ b/artiq/compiler/transforms/__init__.py @@ -0,0 +1,8 @@ +from .asttyped_rewriter import ASTTypedRewriter +from .inferencer import Inferencer +from .int_monomorphizer import IntMonomorphizer +from .iodelay_estimator import IODelayEstimator +from .artiq_ir_generator import ARTIQIRGenerator +from .dead_code_eliminator import DeadCodeEliminator +from .llvm_ir_generator import LLVMIRGenerator +from .interleaver import Interleaver diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py new file mode 100644 index 000000000..e29949299 --- /dev/null +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -0,0 +1,1816 @@ +""" +:class:`ARTIQIRGenerator` transforms typed AST into ARTIQ intermediate +representation. ARTIQ IR is designed to be low-level enough that +its operations are elementary--contain no internal branching-- +but without too much detail, such as exposing the reference/value +semantics explicitly. +""" + +from collections import OrderedDict, defaultdict +from pythonparser import algorithm, diagnostic, ast +from .. 
import types, builtins, asttyped, ir, iodelay + +def _readable_name(insn): + if isinstance(insn, ir.Constant): + return str(insn.value) + else: + return insn.name + +def _extract_loc(node): + if "keyword_loc" in node._locs: + return node.keyword_loc + else: + return node.loc + +# We put some effort in keeping generated IR readable, +# i.e. with a more or less linear correspondence to the source. +# This is why basic blocks sometimes seem to be produced in an odd order. +class ARTIQIRGenerator(algorithm.Visitor): + """ + :class:`ARTIQIRGenerator` contains a lot of internal state, + which is effectively maintained in a stack--with push/pop + pairs around any state updates. It is comprised of following: + + :ivar current_loc: (:class:`pythonparser.source.Range`) + source range of the node being currently visited + :ivar current_function: (:class:`ir.Function` or None) + module, def or lambda currently being translated + :ivar current_globals: (set of string) + set of variables that will be resolved in global scope + :ivar current_block: (:class:`ir.BasicBlock`) + basic block to which any new instruction will be appended + :ivar current_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`) + the chained function environment, containing variables that + can become upvalues + :ivar current_private_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`) + the private function environment, containing internal state + :ivar current_args: (dict of string to :class:`ir.Argument`) + the map of Python names of formal arguments to + the current function to their SSA names + :ivar current_assign: (:class:`ir.Value` or None) + the right-hand side of current assignment statement, or + a component of a composite right-hand side when visiting + a composite left-hand side, such as, in ``x, y = z``, + the 2nd tuple element when visting ``y`` + :ivar current_assert_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`) + the environment where the individual components of current assert + statement are stored until display + :ivar current_assert_subexprs: (list of (:class:`ast.AST`, string)) + the mapping from components of current assert statement to the names + their values have in :ivar:`current_assert_env` + :ivar break_target: (:class:`ir.BasicBlock` or None) + the basic block to which ``break`` will transfer control + :ivar continue_target: (:class:`ir.BasicBlock` or None) + the basic block to which ``continue`` will transfer control + :ivar return_target: (:class:`ir.BasicBlock` or None) + the basic block to which ``return`` will transfer control + :ivar unwind_target: (:class:`ir.BasicBlock` or None) + the basic block to which unwinding will transfer control + + There is, additionally, some global state that is used to translate + the results of analyses on AST level to IR level: + + :ivar function_map: (map of :class:`ast.FunctionDefT` to :class:`ir.Function`) + the map from function definition nodes to IR functions + :ivar variable_map: (map of :class:`ast.NameT` to :class:`ir.GetLocal`) + the map from variable name nodes to instructions retrieving + the variable values + :ivar method_map: (map of :class:`ast.AttributeT` to :class:`ir.GetAttribute`) + the map from method resolution nodes to instructions retrieving + the called function inside a translated :class:`ast.CallT` node + """ + + _size_type = builtins.TInt(types.TValue(32)) + + def __init__(self, module_name, engine, ref_period): + self.engine = engine + self.functions = [] + self.name = [module_name] if module_name != "" else [] + 
self.ref_period = ir.Constant(ref_period, builtins.TFloat()) + self.current_loc = None + self.current_function = None + self.current_class = None + self.current_globals = set() + self.current_block = None + self.current_env = None + self.current_private_env = None + self.current_args = None + self.current_assign = None + self.current_assert_env = None + self.current_assert_subexprs = None + self.break_target = None + self.continue_target = None + self.return_target = None + self.unwind_target = None + self.function_map = dict() + self.variable_map = dict() + self.method_map = defaultdict(lambda: []) + + def annotate_calls(self, devirtualization): + for var_node in devirtualization.variable_map: + callee_node = devirtualization.variable_map[var_node] + if callee_node is None: + continue + callee = self.function_map[callee_node] + + call_target = self.variable_map[var_node] + for use in call_target.uses: + if isinstance(use, (ir.Call, ir.Invoke)) and \ + use.target_function() == call_target: + use.static_target_function = callee + + for type_and_method in devirtualization.method_map: + callee_node = devirtualization.method_map[type_and_method] + if callee_node is None: + continue + callee = self.function_map[callee_node] + + for call in self.method_map[type_and_method]: + assert isinstance(call, (ir.Call, ir.Invoke)) + call.static_target_function = callee + + def add_block(self, name=""): + block = ir.BasicBlock([], name) + self.current_function.add(block) + return block + + def append(self, insn, block=None, loc=None): + if loc is None: + loc = self.current_loc + if block is None: + block = self.current_block + + if insn.loc is None: + insn.loc = loc + return block.append(insn) + + def terminate(self, insn): + if not self.current_block.is_terminated(): + self.append(insn) + else: + insn.drop_references() + + # Visitors + + def visit(self, obj): + if isinstance(obj, list): + for elt in obj: + self.visit(elt) + if self.current_block.is_terminated(): + break + elif isinstance(obj, ast.AST): + try: + old_loc, self.current_loc = self.current_loc, _extract_loc(obj) + return self._visit_one(obj) + finally: + self.current_loc = old_loc + + # Module visitor + + def visit_ModuleT(self, node): + # Treat start of module as synthesized + self.current_loc = None + + try: + typ = types.TFunction(OrderedDict(), OrderedDict(), builtins.TNone()) + func = ir.Function(typ, ".".join(self.name + ['__modinit__']), [], + loc=node.loc.begin()) + self.functions.append(func) + old_func, self.current_function = self.current_function, func + + entry = self.add_block("entry") + old_block, self.current_block = self.current_block, entry + + env = self.append(ir.Alloc([], ir.TEnvironment(node.typing_env), name="env")) + old_env, self.current_env = self.current_env, env + + priv_env = self.append(ir.Alloc([], ir.TEnvironment({ "$return": typ.ret }), + name="privenv")) + old_priv_env, self.current_private_env = self.current_private_env, priv_env + + self.generic_visit(node) + self.terminate(ir.Return(ir.Constant(None, builtins.TNone()))) + + return self.functions + finally: + self.current_function = old_func + self.current_block = old_block + self.current_env = old_env + self.current_private_env = old_priv_env + + # Statement visitors + + def visit_ClassDefT(self, node): + klass = self.append(ir.Alloc([], node.constructor_type, + name="class.{}".format(node.name))) + self._set_local(node.name, klass) + + try: + old_class, self.current_class = self.current_class, klass + self.visit(node.body) + finally: + self.current_class = 
old_class + + def visit_function(self, node, is_lambda, is_internal): + if is_lambda: + name = "lambda@{}:{}".format(node.loc.line(), node.loc.column()) + typ = node.type.find() + else: + name = node.name + typ = node.signature_type.find() + + try: + defaults = [] + for arg_name, default_node in zip(typ.optargs, node.args.defaults): + default = self.visit(default_node) + env_default_name = \ + self.current_env.type.add("default$" + arg_name, default.type) + self.append(ir.SetLocal(self.current_env, env_default_name, default)) + defaults.append(env_default_name) + + old_name, self.name = self.name, self.name + [name] + + env_arg = ir.EnvironmentArgument(self.current_env.type, "outerenv") + + old_args, self.current_args = self.current_args, {} + + args = [] + for arg_name in typ.args: + arg = ir.Argument(typ.args[arg_name], "arg." + arg_name) + self.current_args[arg_name] = arg + args.append(arg) + + optargs = [] + for arg_name in typ.optargs: + arg = ir.Argument(ir.TOption(typ.optargs[arg_name]), "arg." + arg_name) + self.current_args[arg_name] = arg + optargs.append(arg) + + func = ir.Function(typ, ".".join(self.name), [env_arg] + args + optargs, + loc=node.lambda_loc if is_lambda else node.keyword_loc) + func.is_internal = is_internal + self.functions.append(func) + old_func, self.current_function = self.current_function, func + + if not is_lambda: + self.function_map[node] = func + + entry = self.add_block() + old_block, self.current_block = self.current_block, entry + + old_globals, self.current_globals = self.current_globals, node.globals_in_scope + + env_without_globals = \ + {var: node.typing_env[var] + for var in node.typing_env + if var not in node.globals_in_scope} + env_type = ir.TEnvironment(env_without_globals, self.current_env.type) + env = self.append(ir.Alloc([], env_type, name="env")) + old_env, self.current_env = self.current_env, env + + if not is_lambda: + priv_env = self.append(ir.Alloc([], ir.TEnvironment({ "$return": typ.ret }), + name="privenv")) + old_priv_env, self.current_private_env = self.current_private_env, priv_env + + self.append(ir.SetLocal(env, "$outer", env_arg)) + for index, arg_name in enumerate(typ.args): + self.append(ir.SetLocal(env, arg_name, args[index])) + for index, (arg_name, env_default_name) in enumerate(zip(typ.optargs, defaults)): + default = self.append(ir.GetLocal(self.current_env, env_default_name)) + value = self.append(ir.Builtin("unwrap_or", [optargs[index], default], + typ.optargs[arg_name])) + self.append(ir.SetLocal(env, arg_name, value)) + + result = self.visit(node.body) + + if is_lambda: + self.terminate(ir.Return(result)) + elif builtins.is_none(typ.ret): + if not self.current_block.is_terminated(): + self.current_block.append(ir.Return(ir.Constant(None, builtins.TNone()))) + else: + if not self.current_block.is_terminated(): + self.current_block.append(ir.Unreachable()) + finally: + self.name = old_name + self.current_args = old_args + self.current_function = old_func + self.current_block = old_block + self.current_globals = old_globals + self.current_env = old_env + if not is_lambda: + self.current_private_env = old_priv_env + + return self.append(ir.Closure(func, self.current_env)) + + def visit_FunctionDefT(self, node, in_class=None): + func = self.visit_function(node, is_lambda=False, is_internal=len(self.name) > 2) + if in_class is None: + self._set_local(node.name, func) + else: + self.append(ir.SetAttr(in_class, node.name, func)) + + def visit_Return(self, node): + if node.value is None: + return_value = 
ir.Constant(None, builtins.TNone()) + else: + return_value = self.visit(node.value) + + if self.return_target is None: + self.append(ir.Return(return_value)) + else: + self.append(ir.SetLocal(self.current_private_env, "$return", return_value)) + self.append(ir.Branch(self.return_target)) + + def visit_Expr(self, node): + # Ignore the value, do it for side effects. + result = self.visit(node.value) + + # See comment in visit_Pass. + if isinstance(result, ir.Constant): + self.visit_Pass(node) + + def visit_Pass(self, node): + # Insert a dummy instruction so that analyses which extract + # locations from CFG have something to use. + self.append(ir.Builtin("nop", [], builtins.TNone())) + + def visit_Assign(self, node): + try: + self.current_assign = self.visit(node.value) + assert self.current_assign is not None + for target in node.targets: + self.visit(target) + finally: + self.current_assign = None + + def visit_AugAssign(self, node): + lhs = self.visit(node.target) + rhs = self.visit(node.value) + value = self.append(ir.Arith(node.op, lhs, rhs)) + try: + self.current_assign = value + self.visit(node.target) + finally: + self.current_assign = None + + def coerce_to_bool(self, insn, block=None): + if builtins.is_bool(insn.type): + return insn + elif builtins.is_int(insn.type): + return self.append(ir.Compare(ast.NotEq(loc=None), insn, ir.Constant(0, insn.type)), + block=block) + elif builtins.is_float(insn.type): + return self.append(ir.Compare(ast.NotEq(loc=None), insn, ir.Constant(0, insn.type)), + block=block) + elif builtins.is_iterable(insn.type): + length = self.iterable_len(insn) + return self.append(ir.Compare(ast.NotEq(loc=None), length, ir.Constant(0, length.type)), + block=block) + else: + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(insn.type)}, + insn.loc) + diag = diagnostic.Diagnostic("warning", + "this expression, which is always truthful, is coerced to bool", {}, + insn.loc, notes=[note]) + self.engine.process(diag) + return ir.Constant(True, builtins.TBool()) + + def visit_If(self, node): + cond = self.visit(node.test) + cond = self.coerce_to_bool(cond) + head = self.current_block + + if_true = self.add_block() + self.current_block = if_true + self.visit(node.body) + post_if_true = self.current_block + + if any(node.orelse): + if_false = self.add_block() + self.current_block = if_false + self.visit(node.orelse) + post_if_false = self.current_block + + tail = self.add_block() + self.current_block = tail + if not post_if_true.is_terminated(): + post_if_true.append(ir.Branch(tail)) + + if any(node.orelse): + if not post_if_false.is_terminated(): + post_if_false.append(ir.Branch(tail)) + self.append(ir.BranchIf(cond, if_true, if_false), block=head) + else: + self.append(ir.BranchIf(cond, if_true, tail), block=head) + + def visit_While(self, node): + try: + head = self.add_block("while.head") + self.append(ir.Branch(head)) + self.current_block = head + old_continue, self.continue_target = self.continue_target, head + cond = self.visit(node.test) + + break_block = self.add_block("while.break") + old_break, self.break_target = self.break_target, break_block + + body = self.add_block("while.body") + self.current_block = body + self.visit(node.body) + post_body = self.current_block + + if any(node.orelse): + else_tail = self.add_block("while.else") + self.current_block = else_tail + self.visit(node.orelse) + post_else_tail = self.current_block + + tail = self.add_block("while.tail") + self.current_block = tail + + if 
any(node.orelse): + if not post_else_tail.is_terminated(): + post_else_tail.append(ir.Branch(tail)) + else: + else_tail = tail + + head.append(ir.BranchIf(cond, body, else_tail)) + if not post_body.is_terminated(): + post_body.append(ir.Branch(head)) + break_block.append(ir.Branch(tail)) + finally: + self.break_target = old_break + self.continue_target = old_continue + + def iterable_len(self, value, typ=_size_type): + if builtins.is_list(value.type): + return self.append(ir.Builtin("len", [value], typ, + name="{}.len".format(value.name))) + elif builtins.is_range(value.type): + start = self.append(ir.GetAttr(value, "start")) + stop = self.append(ir.GetAttr(value, "stop")) + step = self.append(ir.GetAttr(value, "step")) + spread = self.append(ir.Arith(ast.Sub(loc=None), stop, start)) + return self.append(ir.Arith(ast.FloorDiv(loc=None), spread, step, + name="{}.len".format(value.name))) + else: + assert False + + def iterable_get(self, value, index): + # Assuming the value is within bounds. + if builtins.is_list(value.type): + return self.append(ir.GetElem(value, index)) + elif builtins.is_range(value.type): + start = self.append(ir.GetAttr(value, "start")) + step = self.append(ir.GetAttr(value, "step")) + offset = self.append(ir.Arith(ast.Mult(loc=None), step, index)) + return self.append(ir.Arith(ast.Add(loc=None), start, offset)) + else: + assert False + + def visit_For(self, node): + try: + iterable = self.visit(node.iter) + length = self.iterable_len(iterable) + prehead = self.current_block + + head = self.add_block("for.head") + self.append(ir.Branch(head)) + self.current_block = head + phi = self.append(ir.Phi(length.type)) + phi.add_incoming(ir.Constant(0, phi.type), prehead) + cond = self.append(ir.Compare(ast.Lt(loc=None), phi, length)) + + break_block = self.add_block("for.break") + old_break, self.break_target = self.break_target, break_block + + continue_block = self.add_block("for.continue") + old_continue, self.continue_target = self.continue_target, continue_block + self.current_block = continue_block + + updated_index = self.append(ir.Arith(ast.Add(loc=None), phi, ir.Constant(1, phi.type))) + phi.add_incoming(updated_index, continue_block) + self.append(ir.Branch(head)) + + body = self.add_block("for.body") + self.current_block = body + elt = self.iterable_get(iterable, phi) + try: + self.current_assign = elt + self.visit(node.target) + finally: + self.current_assign = None + self.visit(node.body) + post_body = self.current_block + + if any(node.orelse): + else_tail = self.add_block("for.else") + self.current_block = else_tail + self.visit(node.orelse) + post_else_tail = self.current_block + + tail = self.add_block("for.tail") + self.current_block = tail + + if any(node.orelse): + if not post_else_tail.is_terminated(): + post_else_tail.append(ir.Branch(tail)) + else: + else_tail = tail + + head.append(ir.BranchIf(cond, body, else_tail)) + if not post_body.is_terminated(): + post_body.append(ir.Branch(continue_block)) + break_block.append(ir.Branch(tail)) + finally: + self.break_target = old_break + self.continue_target = old_continue + + def visit_Break(self, node): + self.append(ir.Branch(self.break_target)) + + def visit_Continue(self, node): + self.append(ir.Branch(self.continue_target)) + + def raise_exn(self, exn, loc=None): + if exn is not None: + if loc is None: + loc = self.current_loc + + loc_file = ir.Constant(loc.source_buffer.name, builtins.TStr()) + loc_line = ir.Constant(loc.line(), builtins.TInt(types.TValue(32))) + loc_column = ir.Constant(loc.column(), 
builtins.TInt(types.TValue(32))) + loc_function = ir.Constant(".".join(self.name), builtins.TStr()) + + self.append(ir.SetAttr(exn, "__file__", loc_file)) + self.append(ir.SetAttr(exn, "__line__", loc_line)) + self.append(ir.SetAttr(exn, "__col__", loc_column)) + self.append(ir.SetAttr(exn, "__func__", loc_function)) + + if self.unwind_target is not None: + self.append(ir.Raise(exn, self.unwind_target)) + else: + self.append(ir.Raise(exn)) + else: + if self.unwind_target is not None: + self.append(ir.Reraise(self.unwind_target)) + else: + self.append(ir.Reraise()) + + def visit_Raise(self, node): + self.raise_exn(self.visit(node.exc)) + + def visit_Try(self, node): + dispatcher = self.add_block("try.dispatch") + + if any(node.finalbody): + # k for continuation + final_state = self.append(ir.Alloc([], ir.TEnvironment({ "$k": ir.TBasicBlock() }))) + final_targets = [] + + if self.break_target is not None: + break_proxy = self.add_block("try.break") + old_break, self.break_target = self.break_target, break_proxy + break_proxy.append(ir.SetLocal(final_state, "$k", old_break)) + final_targets.append(old_break) + if self.continue_target is not None: + continue_proxy = self.add_block("try.continue") + old_continue, self.continue_target = self.continue_target, continue_proxy + continue_proxy.append(ir.SetLocal(final_state, "$k", old_continue)) + final_targets.append(old_continue) + + return_proxy = self.add_block("try.return") + old_return, self.return_target = self.return_target, return_proxy + if old_return is not None: + return_proxy.append(ir.SetLocal(final_state, "$k", old_return)) + final_targets.append(old_return) + else: + return_action = self.add_block("try.doreturn") + value = return_action.append(ir.GetLocal(self.current_private_env, "$return")) + return_action.append(ir.Return(value)) + return_proxy.append(ir.SetLocal(final_state, "$k", return_action)) + final_targets.append(return_action) + + body = self.add_block("try.body") + self.append(ir.Branch(body)) + self.current_block = body + + try: + old_unwind, self.unwind_target = self.unwind_target, dispatcher + self.visit(node.body) + finally: + self.unwind_target = old_unwind + + self.visit(node.orelse) + body = self.current_block + + if any(node.finalbody): + if self.break_target: + self.break_target = old_break + if self.continue_target: + self.continue_target = old_continue + self.return_target = old_return + + cleanup = self.add_block('handler.cleanup') + landingpad = dispatcher.append(ir.LandingPad(cleanup)) + + handlers = [] + for handler_node in node.handlers: + exn_type = handler_node.name_type.find() + if handler_node.filter is not None and \ + not builtins.is_exception(exn_type, 'Exception'): + handler = self.add_block("handler." 
+ exn_type.name) + landingpad.add_clause(handler, exn_type) + else: + handler = self.add_block("handler.catchall") + landingpad.add_clause(handler, None) + + self.current_block = handler + if handler_node.name is not None: + exn = self.append(ir.Builtin("exncast", [landingpad], handler_node.name_type)) + self._set_local(handler_node.name, exn) + self.visit(handler_node.body) + post_handler = self.current_block + + handlers.append((handler, post_handler)) + + if any(node.finalbody): + finalizer = self.add_block("finally") + self.current_block = finalizer + + self.visit(node.finalbody) + post_finalizer = self.current_block + + reraise = self.add_block('try.reraise') + reraise.append(ir.Reraise(self.unwind_target)) + + self.current_block = tail = self.add_block("try.tail") + if any(node.finalbody): + final_targets.append(tail) + final_targets.append(reraise) + + if self.break_target: + break_proxy.append(ir.Branch(finalizer)) + if self.continue_target: + continue_proxy.append(ir.Branch(finalizer)) + return_proxy.append(ir.Branch(finalizer)) + + if not body.is_terminated(): + body.append(ir.SetLocal(final_state, "$k", tail)) + body.append(ir.Branch(finalizer)) + + cleanup.append(ir.SetLocal(final_state, "$k", reraise)) + cleanup.append(ir.Branch(finalizer)) + + for handler, post_handler in handlers: + if not post_handler.is_terminated(): + post_handler.append(ir.SetLocal(final_state, "$k", tail)) + post_handler.append(ir.Branch(finalizer)) + + if not post_finalizer.is_terminated(): + dest = post_finalizer.append(ir.GetLocal(final_state, "$k")) + post_finalizer.append(ir.IndirectBranch(dest, final_targets)) + else: + if not body.is_terminated(): + body.append(ir.Branch(tail)) + + cleanup.append(ir.Reraise(self.unwind_target)) + + for handler, post_handler in handlers: + if not post_handler.is_terminated(): + post_handler.append(ir.Branch(tail)) + + def visit_With(self, node): + if len(node.items) != 1: + diag = diagnostic.Diagnostic("fatal", + "only one expression per 'with' statement is supported", + {"type": types.TypePrinter().name(typ)}, + node.context_expr.loc) + self.engine.process(diag) + + context_expr_node = node.items[0].context_expr + optional_vars_node = node.items[0].optional_vars + + if types.is_builtin(context_expr_node.type, "sequential"): + self.visit(node.body) + elif types.is_builtin(context_expr_node.type, "parallel"): + parallel = self.append(ir.Parallel([])) + + heads, tails = [], [] + for stmt in node.body: + self.current_block = self.add_block() + heads.append(self.current_block) + self.visit(stmt) + tails.append(self.current_block) + + for head in heads: + parallel.add_destination(head) + + self.current_block = self.add_block() + for tail in tails: + if not tail.is_terminated(): + tail.append(ir.Branch(self.current_block)) + + # Expression visitors + # These visitors return a node in addition to mutating + # the IR. 
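(A minimal illustration of the pattern these expression visitors follow, using toy stand-ins rather than the ARTIQ `ir` classes: the generator keeps threading `current_block`, emits each arm into its own block, and returns a phi that merges the arm results. The names below — `Block`, `lower_ifexp` — are hypothetical and exist only for this sketch, not part of the patch.)

    class Block:
        """Toy basic block: a name plus an ordered list of instructions."""
        def __init__(self, name):
            self.name = name
            self.instructions = []

        def append(self, insn):
            self.instructions.append(insn)
            return insn

    def lower_ifexp(cond, emit_true, emit_false):
        """Shape of the CFG that visit_IfExpT (below) builds for `t if cond else f`."""
        head, if_true, if_false, tail = (Block(n) for n in
                                         ("ifexp.head", "ifexp.true", "ifexp.false", "ifexp.tail"))
        true_result = emit_true(if_true)      # body visited with current_block = if_true
        false_result = emit_false(if_false)   # orelse visited with current_block = if_false
        head.append(("branch_if", cond, if_true.name, if_false.name))
        if_true.append(("branch", tail.name))
        if_false.append(("branch", tail.name))
        # The phi in the tail block is the value returned for the whole expression.
        phi = tail.append(("phi", [(true_result, if_true.name), (false_result, if_false.name)]))
        return [head, if_true, if_false, tail], phi

    # Example: `42 if cond else 7` merges the two constants in the phi.
    blocks, result = lower_ifexp("cond",
                                 lambda blk: blk.append(("const", 42)),
                                 lambda blk: blk.append(("const", 7)))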
+ + def visit_LambdaT(self, node): + return self.visit_function(node, is_lambda=True, is_internal=True) + + def visit_IfExpT(self, node): + cond = self.visit(node.test) + head = self.current_block + + if_true = self.add_block() + self.current_block = if_true + true_result = self.visit(node.body) + post_if_true = self.current_block + + if_false = self.add_block() + self.current_block = if_false + false_result = self.visit(node.orelse) + post_if_false = self.current_block + + tail = self.add_block() + self.current_block = tail + + if not post_if_true.is_terminated(): + post_if_true.append(ir.Branch(tail)) + if not post_if_false.is_terminated(): + post_if_false.append(ir.Branch(tail)) + head.append(ir.BranchIf(cond, if_true, if_false)) + + phi = self.append(ir.Phi(node.type)) + phi.add_incoming(true_result, post_if_true) + phi.add_incoming(false_result, post_if_false) + return phi + + def visit_NumT(self, node): + return ir.Constant(node.n, node.type) + + def visit_StrT(self, node): + return ir.Constant(node.s, node.type) + + def visit_NameConstantT(self, node): + return ir.Constant(node.value, node.type) + + def _env_for(self, name): + if name in self.current_globals: + return self.append(ir.Builtin("globalenv", [self.current_env], + self.current_env.type.outermost())) + else: + return self.current_env + + def _get_local(self, name): + if self.current_class is not None and \ + name in self.current_class.type.attributes: + return self.append(ir.GetAttr(self.current_class, name, + name="local." + name)) + + return self.append(ir.GetLocal(self._env_for(name), name, + name="local." + name)) + + def _set_local(self, name, value): + if self.current_class is not None and \ + name in self.current_class.type.attributes: + return self.append(ir.SetAttr(self.current_class, name, value)) + + self.append(ir.SetLocal(self._env_for(name), name, value)) + + def visit_NameT(self, node): + if self.current_assign is None: + insn = self._get_local(node.id) + self.variable_map[node] = insn + return insn + else: + return self._set_local(node.id, self.current_assign) + + def visit_AttributeT(self, node): + try: + old_assign, self.current_assign = self.current_assign, None + obj = self.visit(node.value) + finally: + self.current_assign = old_assign + + if node.attr not in obj.type.find().attributes: + # A class attribute. Get the constructor (class object) and + # extract the attribute from it. + constr_type = obj.type.constructor + constr = self.append(ir.GetConstructor(self._env_for(constr_type.name), + constr_type.name, constr_type, + name="constructor." + constr_type.name)) + + if types.is_function(constr.type.attributes[node.attr]): + # A method. Construct a method object instead. 
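+ # The (function, object) pair allocated just below is what visit_CallT later
+ # unpacks as __func__ and __self__ when such a method value is called.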
+ func = self.append(ir.GetAttr(constr, node.attr)) + return self.append(ir.Alloc([func, obj], node.type)) + else: + obj = constr + + if self.current_assign is None: + return self.append(ir.GetAttr(obj, node.attr, + name="{}.{}".format(_readable_name(obj), node.attr))) + else: + self.append(ir.SetAttr(obj, node.attr, self.current_assign)) + + def _map_index(self, length, index, one_past_the_end=False, loc=None): + lt_0 = self.append(ir.Compare(ast.Lt(loc=None), + index, ir.Constant(0, index.type))) + from_end = self.append(ir.Arith(ast.Add(loc=None), length, index)) + mapped_index = self.append(ir.Select(lt_0, from_end, index)) + mapped_ge_0 = self.append(ir.Compare(ast.GtE(loc=None), + mapped_index, ir.Constant(0, mapped_index.type))) + end_cmpop = ast.LtE(loc=None) if one_past_the_end else ast.Lt(loc=None) + mapped_lt_len = self.append(ir.Compare(end_cmpop, mapped_index, length)) + in_bounds = self.append(ir.Select(mapped_ge_0, mapped_lt_len, + ir.Constant(False, builtins.TBool()))) + head = self.current_block + + self.current_block = out_of_bounds_block = self.add_block() + exn = self.alloc_exn(builtins.TException("IndexError"), + ir.Constant("index {0} out of bounds 0:{1}", builtins.TStr()), + index, length) + self.raise_exn(exn, loc=loc) + + self.current_block = in_bounds_block = self.add_block() + head.append(ir.BranchIf(in_bounds, in_bounds_block, out_of_bounds_block)) + + return mapped_index + + def _make_check(self, cond, exn_gen, loc=None): + # cond: bool Value, condition + # exn_gen: lambda()->exn Value, exception if condition not true + cond_block = self.current_block + + self.current_block = body_block = self.add_block() + self.raise_exn(exn_gen(), loc=loc) + + self.current_block = tail_block = self.add_block() + cond_block.append(ir.BranchIf(cond, tail_block, body_block)) + + def _make_loop(self, init, cond_gen, body_gen): + # init: 'iter Value, initial loop variable value + # cond_gen: lambda('iter Value)->bool Value, loop condition + # body_gen: lambda('iter Value)->'iter Value, loop body, + # returns next loop variable value + init_block = self.current_block + + self.current_block = head_block = self.add_block() + init_block.append(ir.Branch(head_block)) + phi = self.append(ir.Phi(init.type)) + phi.add_incoming(init, init_block) + cond = cond_gen(phi) + + self.current_block = body_block = self.add_block() + body = body_gen(phi) + self.append(ir.Branch(head_block)) + phi.add_incoming(body, self.current_block) + + self.current_block = tail_block = self.add_block() + head_block.append(ir.BranchIf(cond, body_block, tail_block)) + + return head_block, body_block, tail_block + + def visit_SubscriptT(self, node): + try: + old_assign, self.current_assign = self.current_assign, None + value = self.visit(node.value) + finally: + self.current_assign = old_assign + + if isinstance(node.slice, ast.Index): + try: + old_assign, self.current_assign = self.current_assign, None + index = self.visit(node.slice.value) + finally: + self.current_assign = old_assign + + length = self.iterable_len(value, index.type) + mapped_index = self._map_index(length, index, + loc=node.begin_loc) + if self.current_assign is None: + result = self.iterable_get(value, mapped_index) + result.set_name("{}.at.{}".format(value.name, _readable_name(index))) + return result + else: + self.append(ir.SetElem(value, mapped_index, self.current_assign, + name="{}.at.{}".format(value.name, _readable_name(index)))) + else: # Slice + length = self.iterable_len(value, node.slice.type) + + if node.slice.lower is not None: + 
try: + old_assign, self.current_assign = self.current_assign, None + start_index = self.visit(node.slice.lower) + finally: + self.current_assign = old_assign + else: + start_index = ir.Constant(0, node.slice.type) + mapped_start_index = self._map_index(length, start_index, + loc=node.begin_loc) + + if node.slice.upper is not None: + try: + old_assign, self.current_assign = self.current_assign, None + stop_index = self.visit(node.slice.upper) + finally: + self.current_assign = old_assign + else: + stop_index = length + mapped_stop_index = self._map_index(length, stop_index, one_past_the_end=True, + loc=node.begin_loc) + + if node.slice.step is not None: + try: + old_assign, self.current_assign = self.current_assign, None + step = self.visit(node.slice.step) + finally: + self.current_assign = old_assign + + self._make_check( + self.append(ir.Compare(ast.NotEq(loc=None), step, ir.Constant(0, step.type))), + lambda: self.alloc_exn(builtins.TException("ValueError"), + ir.Constant("step cannot be zero", builtins.TStr())), + loc=node.slice.step.loc) + else: + step = ir.Constant(1, node.slice.type) + counting_up = self.append(ir.Compare(ast.Gt(loc=None), step, + ir.Constant(0, step.type))) + + unstepped_size = self.append(ir.Arith(ast.Sub(loc=None), + mapped_stop_index, mapped_start_index)) + slice_size_a = self.append(ir.Arith(ast.FloorDiv(loc=None), unstepped_size, step)) + slice_size_b = self.append(ir.Arith(ast.Mod(loc=None), unstepped_size, step)) + rem_not_empty = self.append(ir.Compare(ast.NotEq(loc=None), slice_size_b, + ir.Constant(0, slice_size_b.type))) + slice_size_c = self.append(ir.Arith(ast.Add(loc=None), slice_size_a, + ir.Constant(1, slice_size_a.type))) + slice_size = self.append(ir.Select(rem_not_empty, + slice_size_c, slice_size_a, + name="slice.size")) + self._make_check( + self.append(ir.Compare(ast.LtE(loc=None), slice_size, length)), + lambda: self.alloc_exn(builtins.TException("ValueError"), + ir.Constant("slice size {0} is larger than iterable length {1}", + builtins.TStr()), + slice_size, length), + loc=node.slice.loc) + + if self.current_assign is None: + is_neg_size = self.append(ir.Compare(ast.Lt(loc=None), + slice_size, ir.Constant(0, slice_size.type))) + abs_slice_size = self.append(ir.Select(is_neg_size, + ir.Constant(0, slice_size.type), slice_size)) + other_value = self.append(ir.Alloc([abs_slice_size], value.type, + name="slice.result")) + else: + other_value = self.current_assign + + prehead = self.current_block + + head = self.current_block = self.add_block() + prehead.append(ir.Branch(head)) + + index = self.append(ir.Phi(node.slice.type, + name="slice.index")) + index.add_incoming(mapped_start_index, prehead) + other_index = self.append(ir.Phi(node.slice.type, + name="slice.resindex")) + other_index.add_incoming(ir.Constant(0, node.slice.type), prehead) + + # Still within bounds? 
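+ # The step may be negative, so the bound check selects between the upward
+ # (index < stop) and downward (index > stop) comparison via `counting_up`.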
+ bounded_up = self.append(ir.Compare(ast.Lt(loc=None), index, mapped_stop_index)) + bounded_down = self.append(ir.Compare(ast.Gt(loc=None), index, mapped_stop_index)) + within_bounds = self.append(ir.Select(counting_up, bounded_up, bounded_down)) + + body = self.current_block = self.add_block() + + if self.current_assign is None: + elem = self.iterable_get(value, index) + self.append(ir.SetElem(other_value, other_index, elem)) + else: + elem = self.append(ir.GetElem(self.current_assign, other_index)) + self.append(ir.SetElem(value, index, elem)) + + next_index = self.append(ir.Arith(ast.Add(loc=None), index, step)) + index.add_incoming(next_index, body) + next_other_index = self.append(ir.Arith(ast.Add(loc=None), other_index, + ir.Constant(1, node.slice.type))) + other_index.add_incoming(next_other_index, body) + self.append(ir.Branch(head)) + + tail = self.current_block = self.add_block() + head.append(ir.BranchIf(within_bounds, body, tail)) + + if self.current_assign is None: + return other_value + + def visit_TupleT(self, node): + if self.current_assign is None: + return self.append(ir.Alloc([self.visit(elt) for elt in node.elts], node.type)) + else: + try: + old_assign = self.current_assign + for index, elt_node in enumerate(node.elts): + self.current_assign = \ + self.append(ir.GetAttr(old_assign, index, + name="{}.e{}".format(old_assign.name, index)), + loc=elt_node.loc) + self.visit(elt_node) + finally: + self.current_assign = old_assign + + def visit_ListT(self, node): + if self.current_assign is None: + elts = [self.visit(elt_node) for elt_node in node.elts] + lst = self.append(ir.Alloc([ir.Constant(len(node.elts), self._size_type)], + node.type)) + for index, elt_node in enumerate(elts): + self.append(ir.SetElem(lst, ir.Constant(index, self._size_type), elt_node)) + return lst + else: + length = self.iterable_len(self.current_assign) + self._make_check( + self.append(ir.Compare(ast.Eq(loc=None), length, + ir.Constant(len(node.elts), self._size_type))), + lambda: self.alloc_exn(builtins.TException("ValueError"), + ir.Constant("list must be {0} elements long to decompose", builtins.TStr()), + length)) + + for index, elt_node in enumerate(node.elts): + elt = self.append(ir.GetElem(self.current_assign, + ir.Constant(index, self._size_type))) + try: + old_assign, self.current_assign = self.current_assign, elt + self.visit(elt_node) + finally: + self.current_assign = old_assign + + def visit_ListCompT(self, node): + assert len(node.generators) == 1 + comprehension = node.generators[0] + assert comprehension.ifs == [] + + iterable = self.visit(comprehension.iter) + length = self.iterable_len(iterable) + result = self.append(ir.Alloc([length], node.type)) + + try: + env_type = ir.TEnvironment(node.typing_env, self.current_env.type) + env = self.append(ir.Alloc([], env_type, name="env.gen")) + old_env, self.current_env = self.current_env, env + + self.append(ir.SetLocal(env, "$outer", old_env)) + + def body_gen(index): + elt = self.iterable_get(iterable, index) + try: + old_assign, self.current_assign = self.current_assign, elt + self.visit(comprehension.target) + finally: + self.current_assign = old_assign + + mapped_elt = self.visit(node.elt) + self.append(ir.SetElem(result, index, mapped_elt)) + return self.append(ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, length.type))) + self._make_loop(ir.Constant(0, length.type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)), + body_gen) + + return result + finally: + self.current_env = old_env + + def 
visit_BoolOpT(self, node): + blocks = [] + for value_node in node.values: + value_head = self.current_block + value = self.visit(value_node) + self.instrument_assert(value_node, value) + value_tail = self.current_block + + blocks.append((value, value_head, value_tail)) + self.current_block = self.add_block() + + tail = self.current_block + phi = self.append(ir.Phi(node.type)) + for ((value, value_head, value_tail), (next_value_head, next_value_tail)) in \ + zip(blocks, [(h,t) for (v,h,t) in blocks[1:]] + [(tail, tail)]): + phi.add_incoming(value, value_tail) + if next_value_head != tail: + cond = self.coerce_to_bool(value, block=value_tail) + if isinstance(node.op, ast.And): + value_tail.append(ir.BranchIf(cond, next_value_head, tail)) + else: + value_tail.append(ir.BranchIf(cond, tail, next_value_head)) + else: + value_tail.append(ir.Branch(tail)) + return phi + + def visit_UnaryOpT(self, node): + if isinstance(node.op, ast.Not): + cond = self.coerce_to_bool(self.visit(node.operand)) + return self.append(ir.Select(cond, + ir.Constant(False, builtins.TBool()), + ir.Constant(True, builtins.TBool()))) + elif isinstance(node.op, ast.USub): + operand = self.visit(node.operand) + return self.append(ir.Arith(ast.Sub(loc=None), + ir.Constant(0, operand.type), operand)) + elif isinstance(node.op, ast.UAdd): + # No-op. + return self.visit(node.operand) + else: + assert False + + def visit_CoerceT(self, node): + value = self.visit(node.value) + if node.type.find() == value.type: + return value + else: + return self.append(ir.Coerce(value, node.type, + name="{}.{}".format(_readable_name(value), + node.type.name))) + + def visit_BinOpT(self, node): + if builtins.is_numeric(node.type): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + if isinstance(node.op, (ast.LShift, ast.RShift)): + # Check for negative shift amount. 
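+ # A negative shift count is rejected at run time with ValueError, mirroring
+ # the zero-divisor check performed for the division operators below.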
+ self._make_check( + self.append(ir.Compare(ast.GtE(loc=None), rhs, ir.Constant(0, rhs.type))), + lambda: self.alloc_exn(builtins.TException("ValueError"), + ir.Constant("shift amount must be nonnegative", builtins.TStr())), + loc=node.right.loc) + elif isinstance(node.op, (ast.Div, ast.FloorDiv)): + self._make_check( + self.append(ir.Compare(ast.NotEq(loc=None), rhs, ir.Constant(0, rhs.type))), + lambda: self.alloc_exn(builtins.TException("ZeroDivisionError"), + ir.Constant("cannot divide by zero", builtins.TStr())), + loc=node.right.loc) + + return self.append(ir.Arith(node.op, lhs, rhs)) + elif isinstance(node.op, ast.Add): # list + list, tuple + tuple + lhs, rhs = self.visit(node.left), self.visit(node.right) + if types.is_tuple(node.left.type) and types.is_tuple(node.right.type): + elts = [] + for index, elt in enumerate(node.left.type.elts): + elts.append(self.append(ir.GetAttr(lhs, index))) + for index, elt in enumerate(node.right.type.elts): + elts.append(self.append(ir.GetAttr(rhs, index))) + return self.append(ir.Alloc(elts, node.type)) + elif builtins.is_list(node.left.type) and builtins.is_list(node.right.type): + lhs_length = self.iterable_len(lhs) + rhs_length = self.iterable_len(rhs) + + result_length = self.append(ir.Arith(ast.Add(loc=None), lhs_length, rhs_length)) + result = self.append(ir.Alloc([result_length], node.type)) + + # Copy lhs + def body_gen(index): + elt = self.append(ir.GetElem(lhs, index)) + self.append(ir.SetElem(result, index, elt)) + return self.append(ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, self._size_type))) + self._make_loop(ir.Constant(0, self._size_type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, lhs_length)), + body_gen) + + # Copy rhs + def body_gen(index): + elt = self.append(ir.GetElem(rhs, index)) + result_index = self.append(ir.Arith(ast.Add(loc=None), index, lhs_length)) + self.append(ir.SetElem(result, result_index, elt)) + return self.append(ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, self._size_type))) + self._make_loop(ir.Constant(0, self._size_type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, rhs_length)), + body_gen) + + return result + else: + assert False + elif isinstance(node.op, ast.Mult): # list * int, int * list + lhs, rhs = self.visit(node.left), self.visit(node.right) + if builtins.is_list(lhs.type) and builtins.is_int(rhs.type): + lst, num = lhs, rhs + elif builtins.is_int(lhs.type) and builtins.is_list(rhs.type): + lst, num = rhs, lhs + else: + assert False + + lst_length = self.iterable_len(lst) + + result_length = self.append(ir.Arith(ast.Mult(loc=None), lst_length, num)) + result = self.append(ir.Alloc([result_length], node.type)) + + # num times... + def body_gen(num_index): + # ... 
copy the list + def body_gen(lst_index): + elt = self.append(ir.GetElem(lst, lst_index)) + base_index = self.append(ir.Arith(ast.Mult(loc=None), + num_index, lst_length)) + result_index = self.append(ir.Arith(ast.Add(loc=None), + base_index, lst_index)) + self.append(ir.SetElem(result, base_index, elt)) + return self.append(ir.Arith(ast.Add(loc=None), lst_index, + ir.Constant(1, self._size_type))) + self._make_loop(ir.Constant(0, self._size_type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, lst_length)), + body_gen) + + return self.append(ir.Arith(ast.Add(loc=None), num_index, + ir.Constant(1, self._size_type))) + self._make_loop(ir.Constant(0, self._size_type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, num)), + body_gen) + + return result + else: + assert False + + def polymorphic_compare_pair_order(self, op, lhs, rhs): + if builtins.is_none(lhs.type) and builtins.is_none(rhs.type): + return self.append(ir.Compare(op, lhs, rhs)) + elif builtins.is_numeric(lhs.type) and builtins.is_numeric(rhs.type): + return self.append(ir.Compare(op, lhs, rhs)) + elif builtins.is_bool(lhs.type) and builtins.is_bool(rhs.type): + return self.append(ir.Compare(op, lhs, rhs)) + elif types.is_tuple(lhs.type) and types.is_tuple(rhs.type): + result = None + for index in range(len(lhs.type.elts)): + lhs_elt = self.append(ir.GetAttr(lhs, index)) + rhs_elt = self.append(ir.GetAttr(rhs, index)) + elt_result = self.append(ir.Compare(op, lhs_elt, rhs_elt)) + if result is None: + result = elt_result + else: + result = self.append(ir.Select(result, elt_result, + ir.Constant(False, builtins.TBool()))) + return result + elif builtins.is_list(lhs.type) and builtins.is_list(rhs.type): + head = self.current_block + lhs_length = self.iterable_len(lhs) + rhs_length = self.iterable_len(rhs) + compare_length = self.append(ir.Compare(op, lhs_length, rhs_length)) + eq_length = self.append(ir.Compare(ast.Eq(loc=None), lhs_length, rhs_length)) + + # If the length is the same, compare element-by-element + # and break when the comparison result is false + loop_head = self.add_block() + self.current_block = loop_head + index_phi = self.append(ir.Phi(self._size_type)) + index_phi.add_incoming(ir.Constant(0, self._size_type), head) + loop_cond = self.append(ir.Compare(ast.Lt(loc=None), index_phi, lhs_length)) + + loop_body = self.add_block() + self.current_block = loop_body + lhs_elt = self.append(ir.GetElem(lhs, index_phi)) + rhs_elt = self.append(ir.GetElem(rhs, index_phi)) + body_result = self.polymorphic_compare_pair(op, lhs_elt, rhs_elt) + + loop_body2 = self.add_block() + self.current_block = loop_body2 + index_next = self.append(ir.Arith(ast.Add(loc=None), index_phi, + ir.Constant(1, self._size_type))) + self.append(ir.Branch(loop_head)) + index_phi.add_incoming(index_next, loop_body2) + + tail = self.add_block() + self.current_block = tail + phi = self.append(ir.Phi(builtins.TBool())) + head.append(ir.BranchIf(eq_length, loop_head, tail)) + phi.add_incoming(compare_length, head) + loop_head.append(ir.BranchIf(loop_cond, loop_body, tail)) + phi.add_incoming(ir.Constant(True, builtins.TBool()), loop_head) + loop_body.append(ir.BranchIf(body_result, loop_body2, tail)) + phi.add_incoming(body_result, loop_body) + + if isinstance(op, ast.NotEq): + result = self.append(ir.Select(phi, + ir.Constant(False, builtins.TBool()), ir.Constant(True, builtins.TBool()))) + else: + result = phi + + return result + else: + assert False + + def polymorphic_compare_pair_inclusion(self, op, needle, haystack): 
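+ # Lowers `needle in haystack`: ranges get a closed-form bounds-and-step check,
+ # other iterables a linear scan that stops at the first element comparing equal.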
+ if builtins.is_range(haystack.type): + # Optimized range `in` operator + start = self.append(ir.GetAttr(haystack, "start")) + stop = self.append(ir.GetAttr(haystack, "stop")) + step = self.append(ir.GetAttr(haystack, "step")) + after_start = self.append(ir.Compare(ast.GtE(loc=None), needle, start)) + after_stop = self.append(ir.Compare(ast.Lt(loc=None), needle, stop)) + from_start = self.append(ir.Arith(ast.Sub(loc=None), needle, start)) + mod_step = self.append(ir.Arith(ast.Mod(loc=None), from_start, step)) + on_step = self.append(ir.Compare(ast.Eq(loc=None), mod_step, + ir.Constant(0, mod_step.type))) + result = self.append(ir.Select(after_start, after_stop, + ir.Constant(False, builtins.TBool()))) + result = self.append(ir.Select(result, on_step, + ir.Constant(False, builtins.TBool()))) + elif builtins.is_iterable(haystack.type): + length = self.iterable_len(haystack) + + cmp_result = loop_body2 = None + def body_gen(index): + nonlocal cmp_result, loop_body2 + + elt = self.iterable_get(haystack, index) + cmp_result = self.polymorphic_compare_pair(ast.Eq(loc=None), needle, elt) + + loop_body2 = self.add_block() + self.current_block = loop_body2 + return self.append(ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, length.type))) + loop_head, loop_body, loop_tail = \ + self._make_loop(ir.Constant(0, length.type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)), + body_gen) + + loop_body.append(ir.BranchIf(cmp_result, loop_tail, loop_body2)) + phi = loop_tail.prepend(ir.Phi(builtins.TBool())) + phi.add_incoming(ir.Constant(False, builtins.TBool()), loop_head) + phi.add_incoming(ir.Constant(True, builtins.TBool()), loop_body) + + result = phi + else: + assert False + + if isinstance(op, ast.NotIn): + result = self.append(ir.Select(result, + ir.Constant(False, builtins.TBool()), + ir.Constant(True, builtins.TBool()))) + + return result + + def polymorphic_compare_pair(self, op, lhs, rhs): + if isinstance(op, (ast.Is, ast.IsNot)): + # The backend will handle equality of aggregates. + return self.append(ir.Compare(op, lhs, rhs)) + elif isinstance(op, (ast.In, ast.NotIn)): + return self.polymorphic_compare_pair_inclusion(op, lhs, rhs) + else: # Eq, NotEq, Lt, LtE, Gt, GtE + return self.polymorphic_compare_pair_order(op, lhs, rhs) + + def visit_CompareT(self, node): + # Essentially a sequence of `and`s performed over results + # of comparisons. + blocks = [] + lhs = self.visit(node.left) + self.instrument_assert(node.left, lhs) + for op, rhs_node in zip(node.ops, node.comparators): + result_head = self.current_block + rhs = self.visit(rhs_node) + self.instrument_assert(rhs_node, rhs) + result = self.polymorphic_compare_pair(op, lhs, rhs) + result_tail = self.current_block + + blocks.append((result, result_head, result_tail)) + self.current_block = self.add_block() + lhs = rhs + + tail = self.current_block + phi = self.append(ir.Phi(node.type)) + for ((result, result_head, result_tail), (next_result_head, next_result_tail)) in \ + zip(blocks, [(h,t) for (v,h,t) in blocks[1:]] + [(tail, tail)]): + phi.add_incoming(result, result_tail) + if next_result_head != tail: + result_tail.append(ir.BranchIf(result, next_result_head, tail)) + else: + result_tail.append(ir.Branch(tail)) + return phi + + # Keep this function with builtins.TException.attributes. 
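(Illustrative only, not part of the patch: `alloc_exn` below fills the exception aggregate positionally, so its argument order must mirror the attribute list referred to above. The field names and defaults here are inferred from the constants the function emits; `ExceptionLayoutSketch` is a hypothetical stand-in, not the actual builtins definition.)

    from dataclasses import dataclass

    @dataclass
    class ExceptionLayoutSketch:
        typeinfo: str          # exception type name
        file: str = ""         # later overwritten by raise_exn via __file__
        line: int = 0          # __line__
        column: int = 0        # __col__
        function: str = ""     # __func__
        message: str = ""      # defaults to the type name when no message is passed
        param0: int = 0        # parameters are 64-bit integers; other widths are coerced
        param1: int = 0
        param2: int = 0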
+ def alloc_exn(self, typ, message=None, param0=None, param1=None, param2=None): + attributes = [ + ir.Constant(typ.find().name, ir.TExceptionTypeInfo()), # typeinfo + ir.Constant("", builtins.TStr()), # file + ir.Constant(0, builtins.TInt(types.TValue(32))), # line + ir.Constant(0, builtins.TInt(types.TValue(32))), # column + ir.Constant("", builtins.TStr()), # function + ] + + if message is None: + attributes.append(ir.Constant(typ.find().name, builtins.TStr())) + else: + attributes.append(message) # message + + param_type = builtins.TInt(types.TValue(64)) + for param in [param0, param1, param2]: + if param is None: + attributes.append(ir.Constant(0, builtins.TInt(types.TValue(64)))) + else: + if param.type != param_type: + param = self.append(ir.Coerce(param, param_type)) + attributes.append(param) # paramN, N=0:2 + + return self.append(ir.Alloc(attributes, typ)) + + def visit_builtin_call(self, node): + # A builtin by any other name... Ignore node.func, just use the type. + typ = node.func.type + if types.is_builtin(typ, "bool"): + if len(node.args) == 0 and len(node.keywords) == 0: + return ir.Constant(False, builtins.TBool()) + elif len(node.args) == 1 and len(node.keywords) == 0: + arg = self.visit(node.args[0]) + return self.coerce_to_bool(arg) + else: + assert False + elif types.is_builtin(typ, "int"): + if len(node.args) == 0 and len(node.keywords) == 0: + return ir.Constant(0, node.type) + elif len(node.args) == 1 and \ + (len(node.keywords) == 0 or \ + len(node.keywords) == 1 and node.keywords[0].arg == 'width'): + # The width argument is purely type-level + arg = self.visit(node.args[0]) + return self.append(ir.Coerce(arg, node.type)) + else: + assert False + elif types.is_builtin(typ, "float"): + if len(node.args) == 0 and len(node.keywords) == 0: + return ir.Constant(0.0, builtins.TFloat()) + elif len(node.args) == 1 and len(node.keywords) == 0: + arg = self.visit(node.args[0]) + return self.append(ir.Coerce(arg, node.type)) + else: + assert False + elif types.is_builtin(typ, "list"): + if len(node.args) == 0 and len(node.keywords) == 0: + length = ir.Constant(0, builtins.TInt(types.TValue(32))) + return self.append(ir.Alloc([length], node.type)) + elif len(node.args) == 1 and len(node.keywords) == 0: + arg = self.visit(node.args[0]) + length = self.iterable_len(arg) + result = self.append(ir.Alloc([length], node.type)) + + def body_gen(index): + elt = self.iterable_get(arg, index) + self.append(ir.SetElem(result, index, elt)) + return self.append(ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, length.type))) + self._make_loop(ir.Constant(0, length.type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)), + body_gen) + + return result + else: + assert False + elif types.is_builtin(typ, "range"): + elt_typ = builtins.get_iterable_elt(node.type) + if len(node.args) == 1 and len(node.keywords) == 0: + max_arg = self.visit(node.args[0]) + return self.append(ir.Alloc([ + ir.Constant(elt_typ.zero(), elt_typ), + max_arg, + ir.Constant(elt_typ.one(), elt_typ), + ], node.type)) + elif len(node.args) == 2 and len(node.keywords) == 0: + min_arg = self.visit(node.args[0]) + max_arg = self.visit(node.args[1]) + return self.append(ir.Alloc([ + min_arg, + max_arg, + ir.Constant(elt_typ.one(), elt_typ), + ], node.type)) + elif len(node.args) == 3 and len(node.keywords) == 0: + min_arg = self.visit(node.args[0]) + max_arg = self.visit(node.args[1]) + step_arg = self.visit(node.args[2]) + return self.append(ir.Alloc([ + min_arg, + max_arg, + step_arg, + ], 
node.type)) + else: + assert False + elif types.is_builtin(typ, "len"): + if len(node.args) == 1 and len(node.keywords) == 0: + arg = self.visit(node.args[0]) + return self.iterable_len(arg) + else: + assert False + elif types.is_builtin(typ, "round"): + if len(node.args) == 1 and len(node.keywords) == 0: + arg = self.visit(node.args[0]) + return self.append(ir.Builtin("round", [arg], node.type)) + else: + assert False + elif types.is_builtin(typ, "print"): + self.polymorphic_print([self.visit(arg) for arg in node.args], + separator=" ", suffix="\n") + return ir.Constant(None, builtins.TNone()) + elif types.is_builtin(typ, "now"): + if len(node.args) == 0 and len(node.keywords) == 0: + now_mu = self.append(ir.Builtin("now_mu", [], builtins.TInt(types.TValue(64)))) + now_mu_float = self.append(ir.Coerce(now_mu, builtins.TFloat())) + return self.append(ir.Arith(ast.Mult(loc=None), now_mu_float, self.ref_period)) + else: + assert False + elif types.is_builtin(typ, "delay") or types.is_builtin(typ, "at"): + if len(node.args) == 1 and len(node.keywords) == 0: + arg = self.visit(node.args[0]) + arg_mu_float = self.append(ir.Arith(ast.Div(loc=None), arg, self.ref_period)) + arg_mu = self.append(ir.Coerce(arg_mu_float, builtins.TInt(types.TValue(64)))) + return self.append(ir.Builtin(typ.name + "_mu", [arg_mu], builtins.TNone())) + else: + assert False + elif types.is_builtin(typ, "now_mu") or types.is_builtin(typ, "delay_mu") \ + or types.is_builtin(typ, "at_mu"): + return self.append(ir.Builtin(typ.name, + [self.visit(arg) for arg in node.args], node.type)) + elif types.is_builtin(typ, "mu_to_seconds"): + if len(node.args) == 1 and len(node.keywords) == 0: + arg = self.visit(node.args[0]) + arg_float = self.append(ir.Coerce(arg, builtins.TFloat())) + return self.append(ir.Arith(ast.Mult(loc=None), arg_float, self.ref_period)) + else: + assert False + elif types.is_builtin(typ, "seconds_to_mu"): + if len(node.args) == 1 and len(node.keywords) == 0: + arg = self.visit(node.args[0]) + arg_mu = self.append(ir.Arith(ast.Div(loc=None), arg, self.ref_period)) + return self.append(ir.Coerce(arg_mu, builtins.TInt(types.TValue(64)))) + else: + assert False + elif types.is_exn_constructor(typ): + return self.alloc_exn(node.type, *[self.visit(arg_node) for arg_node in node.args]) + elif types.is_constructor(typ): + return self.append(ir.Alloc([], typ.instance)) + else: + assert False + + def visit_CallT(self, node): + typ = node.func.type.find() + + if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0): + before_delay = self.current_block + during_delay = self.add_block() + before_delay.append(ir.Branch(during_delay)) + self.current_block = during_delay + + if types.is_builtin(typ): + insn = self.visit_builtin_call(node) + else: + if types.is_function(typ): + func = self.visit(node.func) + self_arg = None + fn_typ = typ + offset = 0 + elif types.is_method(typ): + method = self.visit(node.func) + func = self.append(ir.GetAttr(method, "__func__")) + self_arg = self.append(ir.GetAttr(method, "__self__")) + fn_typ = types.get_method_function(typ) + offset = 1 + else: + assert False + + args = [None] * (len(fn_typ.args) + len(fn_typ.optargs)) + + for index, arg_node in enumerate(node.args): + arg = self.visit(arg_node) + if index < len(fn_typ.args): + args[index + offset] = arg + else: + args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type))) + + for keyword in node.keywords: + arg = self.visit(keyword.value) + if keyword.arg in fn_typ.args: + for index, arg_name in 
enumerate(fn_typ.args): + if keyword.arg == arg_name: + assert args[index] is None + args[index] = arg + break + elif keyword.arg in fn_typ.optargs: + for index, optarg_name in enumerate(fn_typ.optargs): + if keyword.arg == optarg_name: + assert args[len(fn_typ.args) + index] is None + args[len(fn_typ.args) + index] = \ + self.append(ir.Alloc([arg], ir.TOption(arg.type))) + break + + for index, optarg_name in enumerate(fn_typ.optargs): + if args[len(fn_typ.args) + index] is None: + args[len(fn_typ.args) + index] = \ + self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name]))) + + if self_arg is not None: + assert args[0] is None + args[0] = self_arg + + assert None not in args + + if self.unwind_target is None: + insn = self.append(ir.Call(func, args)) + else: + after_invoke = self.add_block() + insn = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target)) + self.current_block = after_invoke + + method_key = None + if isinstance(node.func, asttyped.AttributeT): + attr_node = node.func + self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn) + + if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0): + after_delay = self.add_block() + substs = {var_name: self.current_args[var_name] + for var_name in node.iodelay.free_vars()} + self.append(ir.Delay(node.iodelay, substs, insn, after_delay)) + self.current_block = after_delay + + return insn + + def visit_QuoteT(self, node): + return self.append(ir.Quote(node.value, node.type)) + + def instrument_assert(self, node, value): + if self.current_assert_env is not None: + if isinstance(value, ir.Constant): + return # don't display the values of constants + + if any([algorithm.compare(node, subexpr) + for (subexpr, name) in self.current_assert_subexprs]): + return # don't display the same subexpression twice + + name = self.current_assert_env.type.add("$subexpr", ir.TOption(node.type)) + value_opt = self.append(ir.Alloc([value], ir.TOption(node.type)), + loc=node.loc) + self.append(ir.SetLocal(self.current_assert_env, name, value_opt), + loc=node.loc) + self.current_assert_subexprs.append((node, name)) + + def visit_Assert(self, node): + try: + assert_env = self.current_assert_env = \ + self.append(ir.Alloc([], ir.TEnvironment({}), name="assertenv")) + assert_subexprs = self.current_assert_subexprs = [] + init = self.current_block + + prehead = self.current_block = self.add_block() + cond = self.visit(node.test) + head = self.current_block + finally: + self.current_assert_env = None + self.current_assert_subexprs = None + + for subexpr_node, subexpr_name in assert_subexprs: + empty = init.append(ir.Alloc([], ir.TOption(subexpr_node.type))) + init.append(ir.SetLocal(assert_env, subexpr_name, empty)) + init.append(ir.Branch(prehead)) + + if_failed = self.current_block = self.add_block() + + if node.msg: + explanation = node.msg.s + else: + explanation = node.loc.source() + self.append(ir.Builtin("printf", [ + ir.Constant("assertion failed at %s: %s\n", builtins.TStr()), + ir.Constant(str(node.loc.begin()), builtins.TStr()), + ir.Constant(str(explanation), builtins.TStr()), + ], builtins.TNone())) + + for subexpr_node, subexpr_name in assert_subexprs: + subexpr_head = self.current_block + subexpr_value_opt = self.append(ir.GetLocal(assert_env, subexpr_name)) + subexpr_cond = self.append(ir.Builtin("is_some", [subexpr_value_opt], + builtins.TBool())) + + subexpr_body = self.current_block = self.add_block() + self.append(ir.Builtin("printf", [ + ir.Constant(" (%s) = ", builtins.TStr()), + 
ir.Constant(subexpr_node.loc.source(), builtins.TStr()) + ], builtins.TNone())) + subexpr_value = self.append(ir.Builtin("unwrap", [subexpr_value_opt], + subexpr_node.type)) + self.polymorphic_print([subexpr_value], separator="", suffix="\n") + subexpr_postbody = self.current_block + + subexpr_tail = self.current_block = self.add_block() + self.append(ir.Branch(subexpr_tail), block=subexpr_postbody) + self.append(ir.BranchIf(subexpr_cond, subexpr_body, subexpr_tail), block=subexpr_head) + + self.append(ir.Builtin("abort", [], builtins.TNone())) + self.append(ir.Unreachable()) + + tail = self.current_block = self.add_block() + self.append(ir.BranchIf(cond, tail, if_failed), block=head) + + def polymorphic_print(self, values, separator, suffix="", as_repr=False): + format_string = "" + args = [] + def flush(): + nonlocal format_string, args + if format_string != "": + format_arg = [ir.Constant(format_string, builtins.TStr())] + self.append(ir.Builtin("printf", format_arg + args, builtins.TNone())) + format_string = "" + args = [] + + for value in values: + if format_string != "": + format_string += separator + + if types.is_tuple(value.type): + format_string += "("; flush() + self.polymorphic_print([self.append(ir.GetAttr(value, index)) + for index in range(len(value.type.elts))], + separator=", ", as_repr=True) + if len(value.type.elts) == 1: + format_string += ",)" + else: + format_string += ")" + elif types.is_function(value.type): + format_string += "" + args.append(self.append(ir.GetAttr(value, '__code__'))) + args.append(self.append(ir.GetAttr(value, '__closure__'))) + elif builtins.is_none(value.type): + format_string += "None" + elif builtins.is_bool(value.type): + format_string += "%s" + args.append(self.append(ir.Select(value, + ir.Constant("True", builtins.TStr()), + ir.Constant("False", builtins.TStr())))) + elif builtins.is_int(value.type): + width = builtins.get_int_width(value.type) + if width <= 32: + format_string += "%d" + elif width <= 64: + format_string += "%lld" + else: + assert False + args.append(value) + elif builtins.is_float(value.type): + format_string += "%g" + args.append(value) + elif builtins.is_str(value.type): + if as_repr: + format_string += "\"%s\"" + else: + format_string += "%s" + args.append(value) + elif builtins.is_list(value.type): + format_string += "["; flush() + + length = self.iterable_len(value) + last = self.append(ir.Arith(ast.Sub(loc=None), length, ir.Constant(1, length.type))) + def body_gen(index): + elt = self.iterable_get(value, index) + self.polymorphic_print([elt], separator="", as_repr=True) + is_last = self.append(ir.Compare(ast.Lt(loc=None), index, last)) + head = self.current_block + + if_last = self.current_block = self.add_block() + self.append(ir.Builtin("printf", + [ir.Constant(", ", builtins.TStr())], builtins.TNone())) + + tail = self.current_block = self.add_block() + if_last.append(ir.Branch(tail)) + head.append(ir.BranchIf(is_last, if_last, tail)) + + return self.append(ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, length.type))) + self._make_loop(ir.Constant(0, length.type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)), + body_gen) + + format_string += "]" + elif builtins.is_range(value.type): + format_string += "range("; flush() + + start = self.append(ir.GetAttr(value, "start")) + stop = self.append(ir.GetAttr(value, "stop")) + step = self.append(ir.GetAttr(value, "step")) + self.polymorphic_print([start, stop, step], separator=", ") + + format_string += ")" + elif 
builtins.is_exception(value.type): + name = self.append(ir.GetAttr(value, "__name__")) + message = self.append(ir.GetAttr(value, "__message__")) + param1 = self.append(ir.GetAttr(value, "__param0__")) + param2 = self.append(ir.GetAttr(value, "__param1__")) + param3 = self.append(ir.GetAttr(value, "__param2__")) + + format_string += "%s(%s, %lld, %lld, %lld)" + args += [name, message, param1, param2, param3] + else: + assert False + + format_string += suffix + flush() diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py new file mode 100644 index 000000000..e8f811423 --- /dev/null +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -0,0 +1,496 @@ +""" +:class:`ASTTypedRewriter` rewrites a parsetree (:mod:`pythonparser.ast`) +to a typedtree (:mod:`..asttyped`). +""" + +from collections import OrderedDict +from pythonparser import ast, algorithm, diagnostic +from .. import asttyped, types, builtins + +# This visitor will be called for every node with a scope, +# i.e.: class, function, comprehension, lambda +class LocalExtractor(algorithm.Visitor): + def __init__(self, env_stack, engine): + super().__init__() + self.env_stack = env_stack + self.engine = engine + + self.in_root = False + self.in_assign = False + self.typing_env = OrderedDict() + + # which names are global have to be recorded in the current scope + self.global_ = set() + + # which names are nonlocal only affects whether the current scope + # gets a new binding or not, so we throw this away + self.nonlocal_ = set() + + # parameters can't be declared as global or nonlocal + self.params = set() + + def visit_in_assign(self, node, in_assign): + try: + old_in_assign, self.in_assign = self.in_assign, in_assign + return self.visit(node) + finally: + self.in_assign = old_in_assign + + def visit_Assign(self, node): + self.visit(node.value) + self.visit_in_assign(node.targets, in_assign=True) + + def visit_For(self, node): + self.visit(node.iter) + self.visit_in_assign(node.target, in_assign=True) + self.visit(node.body) + self.visit(node.orelse) + + def visit_withitem(self, node): + self.visit(node.context_expr) + self.visit_in_assign(node.optional_vars, in_assign=True) + + def visit_comprehension(self, node): + self.visit(node.iter) + self.visit_in_assign(node.target, in_assign=True) + self.visit(node.ifs) + + def visit_generator(self, node): + if self.in_root: + return + self.in_root = True + self.visit(list(reversed(node.generators))) + self.visit(node.elt) + + visit_ListComp = visit_generator + visit_SetComp = visit_generator + visit_GeneratorExp = visit_generator + + def visit_DictComp(self, node): + if self.in_root: + return + self.in_root = True + self.visit(list(reversed(node.generators))) + self.visit(node.key) + self.visit(node.value) + + def visit_root(self, node): + if self.in_root: + return + self.in_root = True + self.generic_visit(node) + + visit_Module = visit_root # don't look at inner scopes + visit_ClassDef = visit_root + visit_Lambda = visit_root + + def visit_FunctionDef(self, node): + if self.in_root: + self._assignable(node.name) + self.visit_root(node) + + def _assignable(self, name): + assert name is not None + if name not in self.typing_env and name not in self.nonlocal_: + self.typing_env[name] = types.TVar() + + def visit_arg(self, node): + if node.arg in self.params: + diag = diagnostic.Diagnostic("error", + "duplicate parameter '{name}'", {"name": node.arg}, + node.loc) + self.engine.process(diag) + return + self._assignable(node.arg) + 
self.params.add(node.arg) + + def visit_Name(self, node): + if self.in_assign: + # code like: + # x = 1 + # def f(): + # x = 1 + # creates a new binding for x in f's scope + self._assignable(node.id) + + def visit_Attribute(self, node): + self.visit_in_assign(node.value, in_assign=False) + + def visit_Subscript(self, node): + self.visit_in_assign(node.value, in_assign=False) + self.visit_in_assign(node.slice, in_assign=False) + + def _check_not_in(self, name, names, curkind, newkind, loc): + if name in names: + diag = diagnostic.Diagnostic("error", + "name '{name}' cannot be {curkind} and {newkind} simultaneously", + {"name": name, "curkind": curkind, "newkind": newkind}, loc) + self.engine.process(diag) + return True + return False + + def visit_Global(self, node): + for name, loc in zip(node.names, node.name_locs): + if self._check_not_in(name, self.nonlocal_, "nonlocal", "global", loc) or \ + self._check_not_in(name, self.params, "a parameter", "global", loc): + continue + + self.global_.add(name) + if len(self.env_stack) == 1: + self._assignable(name) # already in global scope + else: + if name not in self.env_stack[1]: + self.env_stack[1][name] = types.TVar() + self.typing_env[name] = self.env_stack[1][name] + + def visit_Nonlocal(self, node): + for name, loc in zip(node.names, node.name_locs): + if self._check_not_in(name, self.global_, "global", "nonlocal", loc) or \ + self._check_not_in(name, self.params, "a parameter", "nonlocal", loc): + continue + + # nonlocal does not search prelude and global scopes + found = False + for outer_env in reversed(self.env_stack[2:]): + if name in outer_env: + found = True + break + if not found: + diag = diagnostic.Diagnostic("error", + "cannot declare name '{name}' as nonlocal: it is not bound in any outer scope", + {"name": name}, + loc, [node.keyword_loc]) + self.engine.process(diag) + continue + + self.nonlocal_.add(name) + + def visit_ExceptHandler(self, node): + self.visit(node.type) + if node.name is not None: + self._assignable(node.name) + for stmt in node.body: + self.visit(stmt) + + +class ASTTypedRewriter(algorithm.Transformer): + """ + :class:`ASTTypedRewriter` converts an untyped AST to a typed AST + where all type fields of non-literals are filled with fresh type variables, + and type fields of literals are filled with corresponding types. + + :class:`ASTTypedRewriter` also discovers the scope of variable bindings + via :class:`LocalExtractor`. 
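+
+    A typical invocation looks roughly like this (illustrative sketch only;
+    the `engine` and `prelude` objects are assumed to be supplied by the
+    surrounding compiler setup):
+
+        rewriter = ASTTypedRewriter(engine=engine, prelude=prelude)
+        typedtree = rewriter.visit(parsetree)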
+ """ + + def __init__(self, engine, prelude): + self.engine = engine + self.globals = None + self.env_stack = [prelude] + self.in_class = None + + def _try_find_name(self, name): + for typing_env in reversed(self.env_stack): + if name in typing_env: + return typing_env[name] + + def _find_name(self, name, loc): + if self.in_class is not None: + typ = self.in_class.constructor_type.attributes.get(name) + if typ is not None: + return typ + + typ = self._try_find_name(name) + if typ is not None: + return typ + + diag = diagnostic.Diagnostic("fatal", + "undefined variable '{name}'", {"name":name}, loc) + self.engine.process(diag) + + # Visitors that replace node with a typed node + # + def visit_Module(self, node): + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + node = asttyped.ModuleT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + body=node.body, loc=node.loc) + self.globals = node.typing_env + + try: + self.env_stack.append(node.typing_env) + return self.generic_visit(node) + finally: + self.env_stack.pop() + + def visit_FunctionDef(self, node): + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + node = asttyped.FunctionDefT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + signature_type=self._find_name(node.name, node.name_loc), return_type=types.TVar(), + name=node.name, args=node.args, returns=node.returns, + body=node.body, decorator_list=node.decorator_list, + keyword_loc=node.keyword_loc, name_loc=node.name_loc, + arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, + loc=node.loc) + + try: + self.env_stack.append(node.typing_env) + return self.generic_visit(node) + finally: + self.env_stack.pop() + + def visit_ClassDef(self, node): + if any(node.bases) or any(node.keywords) or \ + node.starargs is not None or node.kwargs is not None: + diag = diagnostic.Diagnostic("error", + "inheritance is not supported", {}, + node.lparen_loc.join(node.rparen_loc)) + self.engine.process(diag) + + for child in node.body: + if isinstance(child, (ast.Assign, ast.FunctionDef, ast.Pass)): + continue + + diag = diagnostic.Diagnostic("fatal", + "class body must contain only assignments and function definitions", {}, + child.loc) + self.engine.process(diag) + + if node.name in self.env_stack[-1]: + diag = diagnostic.Diagnostic("fatal", + "variable '{name}' is already defined", {"name":node.name}, node.name_loc) + self.engine.process(diag) + + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + # Now we create two types. + # The first type is the type of instances created by the constructor. + # Its attributes are those of the class environment, but wrapped + # appropriately so that they are linked to the class from which they + # originate. + instance_type = types.TInstance(node.name, OrderedDict()) + + # The second type is the type of the constructor itself (in other words, + # the class object): it is simply a singleton type that has the class + # environment as attributes. 
+ constructor_type = types.TConstructor(instance_type) + constructor_type.attributes = extractor.typing_env + instance_type.constructor = constructor_type + + self.env_stack[-1][node.name] = constructor_type + + node = asttyped.ClassDefT( + constructor_type=constructor_type, + name=node.name, + bases=self.visit(node.bases), keywords=self.visit(node.keywords), + starargs=self.visit(node.starargs), kwargs=self.visit(node.kwargs), + body=node.body, + decorator_list=self.visit(node.decorator_list), + keyword_loc=node.keyword_loc, name_loc=node.name_loc, + lparen_loc=node.lparen_loc, star_loc=node.star_loc, + dstar_loc=node.dstar_loc, rparen_loc=node.rparen_loc, + colon_loc=node.colon_loc, at_locs=node.at_locs, + loc=node.loc) + + try: + old_in_class, self.in_class = self.in_class, node + return self.generic_visit(node) + finally: + self.in_class = old_in_class + + def visit_arg(self, node): + return asttyped.argT(type=self._find_name(node.arg, node.loc), + arg=node.arg, annotation=self.visit(node.annotation), + arg_loc=node.arg_loc, colon_loc=node.colon_loc, loc=node.loc) + + def visit_Num(self, node): + if isinstance(node.n, int): + typ = builtins.TInt() + elif isinstance(node.n, float): + typ = builtins.TFloat() + else: + diag = diagnostic.Diagnostic("fatal", + "numeric type {type} is not supported", {"type": node.n.__class__.__name__}, + node.loc) + self.engine.process(diag) + return asttyped.NumT(type=typ, + n=node.n, loc=node.loc) + + def visit_Str(self, node): + return asttyped.StrT(type=builtins.TStr(), + s=node.s, + begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc) + + def visit_Name(self, node): + return asttyped.NameT(type=self._find_name(node.id, node.loc), + id=node.id, ctx=node.ctx, loc=node.loc) + + def visit_NameConstant(self, node): + if node.value is True or node.value is False: + typ = builtins.TBool() + elif node.value is None: + typ = builtins.TNone() + return asttyped.NameConstantT(type=typ, value=node.value, loc=node.loc) + + def visit_Tuple(self, node): + node = self.generic_visit(node) + return asttyped.TupleT(type=types.TTuple([x.type for x in node.elts]), + elts=node.elts, ctx=node.ctx, loc=node.loc) + + def visit_List(self, node): + node = self.generic_visit(node) + node = asttyped.ListT(type=builtins.TList(), + elts=node.elts, ctx=node.ctx, + begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc) + return self.visit(node) + + def visit_Attribute(self, node): + node = self.generic_visit(node) + node = asttyped.AttributeT(type=types.TVar(), + value=node.value, attr=node.attr, ctx=node.ctx, + dot_loc=node.dot_loc, attr_loc=node.attr_loc, loc=node.loc) + return self.visit(node) + + def visit_Slice(self, node): + node = self.generic_visit(node) + node = asttyped.SliceT(type=types.TVar(), + lower=node.lower, upper=node.upper, step=node.step, + bound_colon_loc=node.bound_colon_loc, + step_colon_loc=node.step_colon_loc, + loc=node.loc) + return self.visit(node) + + def visit_Subscript(self, node): + node = self.generic_visit(node) + node = asttyped.SubscriptT(type=types.TVar(), + value=node.value, slice=node.slice, ctx=node.ctx, + begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc) + return self.visit(node) + + def visit_BoolOp(self, node): + node = self.generic_visit(node) + node = asttyped.BoolOpT(type=types.TVar(), + op=node.op, values=node.values, + op_locs=node.op_locs, loc=node.loc) + return self.visit(node) + + def visit_UnaryOp(self, node): + node = self.generic_visit(node) + node = asttyped.UnaryOpT(type=types.TVar(), + op=node.op, 
operand=node.operand, + loc=node.loc) + return self.visit(node) + + def visit_BinOp(self, node): + node = self.generic_visit(node) + node = asttyped.BinOpT(type=types.TVar(), + left=node.left, op=node.op, right=node.right, + loc=node.loc) + return self.visit(node) + + def visit_Compare(self, node): + node = self.generic_visit(node) + node = asttyped.CompareT(type=types.TVar(), + left=node.left, ops=node.ops, comparators=node.comparators, + loc=node.loc) + return self.visit(node) + + def visit_IfExp(self, node): + node = self.generic_visit(node) + node = asttyped.IfExpT(type=types.TVar(), + test=node.test, body=node.body, orelse=node.orelse, + if_loc=node.if_loc, else_loc=node.else_loc, loc=node.loc) + return self.visit(node) + + def visit_ListComp(self, node): + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + node = asttyped.ListCompT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + type=types.TVar(), + elt=node.elt, generators=node.generators, + begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc) + + try: + self.env_stack.append(node.typing_env) + return self.generic_visit(node) + finally: + self.env_stack.pop() + + def visit_Call(self, node): + node = self.generic_visit(node) + node = asttyped.CallT(type=types.TVar(), iodelay=None, + func=node.func, args=node.args, keywords=node.keywords, + starargs=node.starargs, kwargs=node.kwargs, + star_loc=node.star_loc, dstar_loc=node.dstar_loc, + begin_loc=node.begin_loc, end_loc=node.end_loc, loc=node.loc) + return node + + def visit_Lambda(self, node): + extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine) + extractor.visit(node) + + node = asttyped.LambdaT( + typing_env=extractor.typing_env, globals_in_scope=extractor.global_, + type=types.TVar(), + args=node.args, body=node.body, + lambda_loc=node.lambda_loc, colon_loc=node.colon_loc, loc=node.loc) + + try: + self.env_stack.append(node.typing_env) + return self.generic_visit(node) + finally: + self.env_stack.pop() + + def visit_ExceptHandler(self, node): + node = self.generic_visit(node) + if node.name is not None: + name_type = self._find_name(node.name, node.name_loc) + else: + name_type = types.TVar() + node = asttyped.ExceptHandlerT( + name_type=name_type, + filter=node.type, name=node.name, body=node.body, + except_loc=node.except_loc, as_loc=node.as_loc, name_loc=node.name_loc, + colon_loc=node.colon_loc, loc=node.loc) + return node + + def visit_Raise(self, node): + node = self.generic_visit(node) + if node.cause: + diag = diagnostic.Diagnostic("error", + "'raise from' syntax is not supported", {}, + node.from_loc) + self.engine.process(diag) + return node + + # Unsupported visitors + # + def visit_unsupported(self, node): + diag = diagnostic.Diagnostic("fatal", + "this syntax is not supported", {}, + node.loc) + self.engine.process(diag) + + # expr + visit_Dict = visit_unsupported + visit_DictComp = visit_unsupported + visit_Ellipsis = visit_unsupported + visit_GeneratorExp = visit_unsupported + visit_Set = visit_unsupported + visit_SetComp = visit_unsupported + visit_Starred = visit_unsupported + visit_Yield = visit_unsupported + visit_YieldFrom = visit_unsupported + + # stmt + visit_Delete = visit_unsupported + visit_Import = visit_unsupported + visit_ImportFrom = visit_unsupported diff --git a/artiq/compiler/transforms/dead_code_eliminator.py b/artiq/compiler/transforms/dead_code_eliminator.py new file mode 100644 index 000000000..dfe5ccf6c --- /dev/null +++ 
b/artiq/compiler/transforms/dead_code_eliminator.py @@ -0,0 +1,43 @@ +""" +:class:`DeadCodeEliminator` is a very simple dead code elimination +transform: it only removes basic blocks with no predecessors. +""" + +from .. import ir + +class DeadCodeEliminator: + def __init__(self, engine): + self.engine = engine + + def process(self, functions): + for func in functions: + self.process_function(func) + + def process_function(self, func): + for block in list(func.basic_blocks): + if not any(block.predecessors()) and block != func.entry(): + for use in set(block.uses): + if isinstance(use, ir.SetLocal): + use.erase() + self.remove_block(block) + + def remove_block(self, block): + # block.uses are updated while iterating + for use in set(block.uses): + if isinstance(use, ir.Phi): + use.remove_incoming_block(block) + if not any(use.operands): + self.remove_instruction(use) + else: + assert False + + block.erase() + + def remove_instruction(self, insn): + for use in set(insn.uses): + if isinstance(use, ir.Phi): + use.remove_incoming_value(insn) + if not any(use.operands): + self.remove_instruction(use) + + insn.erase() diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py new file mode 100644 index 000000000..f823014f8 --- /dev/null +++ b/artiq/compiler/transforms/inferencer.py @@ -0,0 +1,1100 @@ +""" +:class:`Inferencer` performs unification-based inference on a typedtree. +""" + +from collections import OrderedDict +from pythonparser import algorithm, diagnostic, ast +from .. import asttyped, types, builtins + +class Inferencer(algorithm.Visitor): + """ + :class:`Inferencer` infers types by recursively applying the unification + algorithm. It does not treat inability to infer a concrete type as an error; + the result can still contain type variables. + + :class:`Inferencer` is idempotent, but does not guarantee that it will + perform all possible inference in a single pass. 
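+
+    A typical invocation (illustrative; `engine` is the diagnostic engine and
+    `typedtree` a tree produced by :class:`ASTTypedRewriter`) is simply:
+
+        Inferencer(engine=engine).visit(typedtree)
+
+    and may be repeated until no further types can be inferred.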
+ """ + + def __init__(self, engine): + self.engine = engine + self.function = None # currently visited function, for Return inference + self.in_loop = False + self.has_return = False + + def _unify(self, typea, typeb, loca, locb, makenotes=None, when=""): + try: + typea.unify(typeb) + except types.UnificationError as e: + printer = types.TypePrinter() + + if makenotes: + notes = makenotes(printer, typea, typeb, loca, locb) + else: + notes = [ + diagnostic.Diagnostic("note", + "expression of type {typea}", + {"typea": printer.name(typea)}, + loca) + ] + if locb: + notes.append( + diagnostic.Diagnostic("note", + "expression of type {typeb}", + {"typeb": printer.name(typeb)}, + locb)) + + highlights = [locb] if locb else [] + if e.typea.find() == typea.find() and e.typeb.find() == typeb.find() or \ + e.typeb.find() == typea.find() and e.typea.find() == typeb.find(): + diag = diagnostic.Diagnostic("error", + "cannot unify {typea} with {typeb}{when}", + {"typea": printer.name(typea), "typeb": printer.name(typeb), + "when": when}, + loca, highlights, notes) + else: # give more detail + diag = diagnostic.Diagnostic("error", + "cannot unify {typea} with {typeb}{when}: {fraga} is incompatible with {fragb}", + {"typea": printer.name(typea), "typeb": printer.name(typeb), + "fraga": printer.name(e.typea), "fragb": printer.name(e.typeb), + "when": when}, + loca, highlights, notes) + self.engine.process(diag) + + # makenotes for the case where types of multiple elements are unified + # with the type of parent expression + def _makenotes_elts(self, elts, kind): + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "{kind} of type {typea}", + {"kind": kind, "typea": printer.name(elts[0].type)}, + elts[0].loc), + diagnostic.Diagnostic("note", + "{kind} of type {typeb}", + {"kind": kind, "typeb": printer.name(typeb)}, + locb) + ] + return makenotes + + def visit_ListT(self, node): + self.generic_visit(node) + elt_type_loc = node.loc + for elt in node.elts: + self._unify(node.type["elt"], elt.type, + elt_type_loc, elt.loc, + self._makenotes_elts(node.elts, "a list element")) + elt_type_loc = elt.loc + + def visit_AttributeT(self, node): + self.generic_visit(node) + object_type = node.value.type.find() + if not types.is_var(object_type): + if node.attr in object_type.attributes: + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "expression of type {typea}", + {"typea": printer.name(typea)}, + loca), + diagnostic.Diagnostic("note", + "expression of type {typeb}", + {"typeb": printer.name(object_type)}, + node.value.loc) + ] + + # Assumes no free type variables in .attributes. + self._unify(node.type, object_type.attributes[node.attr], + node.loc, None, + makenotes=makenotes, when=" for attribute '{}'".format(node.attr)) + elif types.is_instance(object_type) and \ + node.attr in object_type.constructor.attributes: + # Assumes no free type variables in .attributes. + attr_type = object_type.constructor.attributes[node.attr].find() + if types.is_function(attr_type): + # Convert to a method. 
+ if len(attr_type.args) < 1: + diag = diagnostic.Diagnostic("error", + "function '{attr}{type}' of class '{class}' cannot accept a self argument", + {"attr": node.attr, "type": types.TypePrinter().name(attr_type), + "class": object_type.name}, + node.loc) + self.engine.process(diag) + return + else: + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "expression of type {typea}", + {"typea": printer.name(typea)}, + loca), + diagnostic.Diagnostic("note", + "reference to a class function of type {typeb}", + {"typeb": printer.name(attr_type)}, + locb) + ] + + self._unify(object_type, list(attr_type.args.values())[0], + node.value.loc, node.loc, + makenotes=makenotes, + when=" while inferring the type for self argument") + + attr_type = types.TMethod(object_type, attr_type) + + if not types.is_var(attr_type): + self._unify(node.type, attr_type, + node.loc, None) + else: + if node.attr_loc.source_buffer == node.value.loc.source_buffer: + highlights, notes = [node.value.loc], [] + else: + # This happens when the object being accessed is embedded + # from the host program. + note = diagnostic.Diagnostic("note", + "object being accessed", {}, + node.value.loc) + highlights, notes = [], [note] + + diag = diagnostic.Diagnostic("error", + "type {type} does not have an attribute '{attr}'", + {"type": types.TypePrinter().name(object_type), "attr": node.attr}, + node.attr_loc, highlights, notes) + self.engine.process(diag) + + def _unify_iterable(self, element, collection): + if builtins.is_iterable(collection.type): + rhs_type = collection.type.find() + rhs_wrapped_lhs_type = types.TMono(rhs_type.name, {"elt": element.type}) + self._unify(rhs_wrapped_lhs_type, rhs_type, + element.loc, collection.loc) + elif not types.is_var(collection.type): + diag = diagnostic.Diagnostic("error", + "type {type} is not iterable", + {"type": types.TypePrinter().name(collection.type)}, + collection.loc, []) + self.engine.process(diag) + + def visit_Index(self, node): + self.generic_visit(node) + value = node.value + if types.is_tuple(value.type): + diag = diagnostic.Diagnostic("error", + "multi-dimensional slices are not supported", {}, + node.loc, []) + self.engine.process(diag) + else: + self._unify(value.type, builtins.TInt(), + value.loc, None) + + def visit_SliceT(self, node): + self._unify(node.type, builtins.TInt(), + node.loc, None) + for operand in (node.lower, node.upper, node.step): + if operand is not None: + self._unify(operand.type, node.type, + operand.loc, None) + + def visit_ExtSlice(self, node): + diag = diagnostic.Diagnostic("error", + "multi-dimensional slices are not supported", {}, + node.loc, []) + self.engine.process(diag) + + def visit_SubscriptT(self, node): + self.generic_visit(node) + if isinstance(node.slice, ast.Index): + self._unify_iterable(element=node, collection=node.value) + elif isinstance(node.slice, ast.Slice): + self._unify(node.type, node.value.type, + node.loc, node.value.loc) + else: # ExtSlice + pass # error emitted above + + def visit_IfExpT(self, node): + self.generic_visit(node) + self._unify(node.body.type, node.orelse.type, + node.body.loc, node.orelse.loc) + self._unify(node.type, node.body.type, + node.loc, None) + + def visit_BoolOpT(self, node): + self.generic_visit(node) + for value in node.values: + self._unify(node.type, value.type, + node.loc, value.loc, self._makenotes_elts(node.values, "an operand")) + + def visit_UnaryOpT(self, node): + self.generic_visit(node) + operand_type = node.operand.type.find() + if 
isinstance(node.op, ast.Not): + self._unify(node.type, builtins.TBool(), + node.loc, None) + elif isinstance(node.op, ast.Invert): + if builtins.is_int(operand_type): + self._unify(node.type, operand_type, + node.loc, None) + elif not types.is_var(operand_type): + diag = diagnostic.Diagnostic("error", + "expected '~' operand to be of integer type, not {type}", + {"type": types.TypePrinter().name(operand_type)}, + node.operand.loc) + self.engine.process(diag) + else: # UAdd, USub + if builtins.is_numeric(operand_type): + self._unify(node.type, operand_type, + node.loc, None) + elif not types.is_var(operand_type): + diag = diagnostic.Diagnostic("error", + "expected unary '{op}' operand to be of numeric type, not {type}", + {"op": node.op.loc.source(), + "type": types.TypePrinter().name(operand_type)}, + node.operand.loc) + self.engine.process(diag) + + def visit_CoerceT(self, node): + self.generic_visit(node) + if builtins.is_numeric(node.type) and builtins.is_numeric(node.value.type): + pass + else: + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "expression that required coercion to {typeb}", + {"typeb": printer.name(node.type)}, + node.other_value.loc) + diag = diagnostic.Diagnostic("error", + "cannot coerce {typea} to {typeb}", + {"typea": printer.name(node.value.type), "typeb": printer.name(node.type)}, + node.loc, notes=[note]) + self.engine.process(diag) + + def _coerce_one(self, typ, coerced_node, other_node): + if coerced_node.type.find() == typ.find(): + return coerced_node + elif isinstance(coerced_node, asttyped.CoerceT): + node = coerced_node + node.type, node.other_value = typ, other_node + else: + node = asttyped.CoerceT(type=typ, value=coerced_node, other_value=other_node, + loc=coerced_node.loc) + self.visit(node) + return node + + def _coerce_numeric(self, nodes, map_return=lambda typ: typ): + # See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex. 
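+        # In short: if any operand type is still a type variable, defer; if any
+        # operand is a float, the result is float; otherwise all operands are
+        # integers and the result takes the widest known width (or an int of
+        # indeterminate width if some width is not yet known).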
+ node_types = [] + for node in nodes: + if isinstance(node, asttyped.CoerceT): + node_types.append(node.value.type) + else: + node_types.append(node.type) + if any(map(types.is_var, node_types)): # not enough info yet + return + elif not all(map(builtins.is_numeric, node_types)): + err_node = next(filter(lambda node: not builtins.is_numeric(node.type), nodes)) + diag = diagnostic.Diagnostic("error", + "cannot coerce {type} to a numeric type", + {"type": types.TypePrinter().name(err_node.type)}, + err_node.loc, []) + self.engine.process(diag) + return + elif any(map(builtins.is_float, node_types)): + typ = builtins.TFloat() + elif any(map(builtins.is_int, node_types)): + widths = list(map(builtins.get_int_width, node_types)) + if all(widths): + typ = builtins.TInt(types.TValue(max(widths))) + else: + typ = builtins.TInt() + else: + assert False + + return map_return(typ) + + def _order_by_pred(self, pred, left, right): + if pred(left.type): + return left, right + elif pred(right.type): + return right, left + else: + assert False + + def _coerce_binop(self, op, left, right): + if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor, + ast.LShift, ast.RShift)): + # bitwise operators require integers + for operand in (left, right): + if not types.is_var(operand.type) and not builtins.is_int(operand.type): + diag = diagnostic.Diagnostic("error", + "expected '{op}' operand to be of integer type, not {type}", + {"op": op.loc.source(), + "type": types.TypePrinter().name(operand.type)}, + op.loc, [operand.loc]) + self.engine.process(diag) + return + + return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) + elif isinstance(op, ast.Add): + # add works on numbers and also collections + if builtins.is_collection(left.type) or builtins.is_collection(right.type): + collection, other = \ + self._order_by_pred(builtins.is_collection, left, right) + if types.is_tuple(collection.type): + pred, kind = types.is_tuple, "tuple" + elif builtins.is_list(collection.type): + pred, kind = builtins.is_list, "list" + else: + assert False + if not pred(other.type): + printer = types.TypePrinter() + note1 = diagnostic.Diagnostic("note", + "{kind} of type {typea}", + {"typea": printer.name(collection.type), "kind": kind}, + collection.loc) + note2 = diagnostic.Diagnostic("note", + "{typeb}, which cannot be added to a {kind}", + {"typeb": printer.name(other.type), "kind": kind}, + other.loc) + diag = diagnostic.Diagnostic("error", + "expected every '+' operand to be a {kind} in this context", + {"kind": kind}, + op.loc, [other.loc, collection.loc], + [note1, note2]) + self.engine.process(diag) + return + + if types.is_tuple(collection.type): + return types.TTuple(left.type.find().elts + + right.type.find().elts), left.type, right.type + elif builtins.is_list(collection.type): + self._unify(left.type, right.type, + left.loc, right.loc) + return left.type, left.type, right.type + else: + return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) + elif isinstance(op, ast.Mult): + # mult works on numbers and also number & collection + if types.is_tuple(left.type) or types.is_tuple(right.type): + tuple_, other = self._order_by_pred(types.is_tuple, left, right) + diag = diagnostic.Diagnostic("error", + "passing tuples to '*' is not supported", {}, + op.loc, [tuple_.loc]) + self.engine.process(diag) + return + elif builtins.is_list(left.type) or builtins.is_list(right.type): + list_, other = self._order_by_pred(builtins.is_list, left, right) + if not builtins.is_int(other.type): + printer = 
types.TypePrinter() + note1 = diagnostic.Diagnostic("note", + "list operand of type {typea}", + {"typea": printer.name(list_.type)}, + list_.loc) + note2 = diagnostic.Diagnostic("note", + "operand of type {typeb}, which is not a valid repetition amount", + {"typeb": printer.name(other.type)}, + other.loc) + diag = diagnostic.Diagnostic("error", + "expected '*' operands to be a list and an integer in this context", {}, + op.loc, [list_.loc, other.loc], + [note1, note2]) + self.engine.process(diag) + return + + return list_.type, left.type, right.type + else: + return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) + elif isinstance(op, (ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)): + # numeric operators work on any kind of number + return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ)) + elif isinstance(op, ast.Div): + # division always returns a float + return self._coerce_numeric((left, right), + lambda typ: (builtins.TFloat(), builtins.TFloat(), builtins.TFloat())) + else: # MatMult + diag = diagnostic.Diagnostic("error", + "operator '{op}' is not supported", {"op": op.loc.source()}, + op.loc) + self.engine.process(diag) + return + + def visit_BinOpT(self, node): + self.generic_visit(node) + coerced = self._coerce_binop(node.op, node.left, node.right) + if coerced: + return_type, left_type, right_type = coerced + node.left = self._coerce_one(left_type, node.left, other_node=node.right) + node.right = self._coerce_one(right_type, node.right, other_node=node.left) + self._unify(node.type, return_type, + node.loc, None) + + def visit_CompareT(self, node): + self.generic_visit(node) + pairs = zip([node.left] + node.comparators, node.comparators) + if all(map(lambda op: isinstance(op, (ast.Is, ast.IsNot)), node.ops)): + for left, right in pairs: + self._unify(left.type, right.type, + left.loc, right.loc) + elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)): + for left, right in pairs: + self._unify_iterable(element=left, collection=right) + else: # Eq, NotEq, Lt, LtE, Gt, GtE + operands = [node.left] + node.comparators + operand_types = [operand.type for operand in operands] + if any(map(builtins.is_collection, operand_types)): + for left, right in pairs: + self._unify(left.type, right.type, + left.loc, right.loc) + elif any(map(builtins.is_numeric, operand_types)): + typ = self._coerce_numeric(operands) + if typ: + try: + other_node = next(filter(lambda operand: operand.type.find() == typ.find(), + operands)) + except StopIteration: + # can't find an argument with an exact type, meaning + # the return value is more generic than any of the inputs, meaning + # the type is known (typ is not None), but its width is not + def wide_enough(opreand): + return types.is_mono(opreand.type) and \ + opreand.type.find().name == typ.find().name + other_node = next(filter(wide_enough, operands)) + node.left, *node.comparators = \ + [self._coerce_one(typ, operand, other_node) for operand in operands] + else: + pass # No coercion required. 
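+        # Regardless of how the operands were coerced, the comparison expression
+        # itself always has type bool.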
+ self._unify(node.type, builtins.TBool(), + node.loc, None) + + def visit_ListCompT(self, node): + if len(node.generators) > 1: + diag = diagnostic.Diagnostic("error", + "multiple for clauses in comprehensions are not supported", {}, + node.generators[1].for_loc) + self.engine.process(diag) + + self.generic_visit(node) + self._unify(node.type, builtins.TList(node.elt.type), + node.loc, None) + + def visit_comprehension(self, node): + if any(node.ifs): + diag = diagnostic.Diagnostic("error", + "if clauses in comprehensions are not supported", {}, + node.if_locs[0]) + self.engine.process(diag) + + self.generic_visit(node) + self._unify_iterable(element=node.target, collection=node.iter) + + def visit_builtin_call(self, node): + typ = node.func.type.find() + + def valid_form(signature): + return diagnostic.Diagnostic("note", + "{func} can be invoked as: {signature}", + {"func": typ.name, "signature": signature}, + node.func.loc) + + def diagnose(valid_forms): + printer = types.TypePrinter() + args = [printer.name(arg.type) for arg in node.args] + args += ["%s=%s" % (kw.arg, printer.name(kw.value.type)) for kw in node.keywords] + + diag = diagnostic.Diagnostic("error", + "{func} cannot be invoked with the arguments ({args})", + {"func": typ.name, "args": ", ".join(args)}, + node.func.loc, notes=valid_forms) + self.engine.process(diag) + + def simple_form(info, arg_types=[], return_type=builtins.TNone()): + self._unify(node.type, return_type, + node.loc, None) + + if len(node.args) == len(arg_types) and len(node.keywords) == 0: + for index, arg_type in enumerate(arg_types): + self._unify(node.args[index].type, arg_type, + node.args[index].loc, None) + else: + diagnose([ valid_form(info) ]) + + if types.is_exn_constructor(typ): + valid_forms = lambda: [ + valid_form("{exn}() -> {exn}".format(exn=typ.name)), + valid_form("{exn}(message:str) -> {exn}".format(exn=typ.name)), + valid_form("{exn}(message:str, param1:int(width=64)) -> {exn}".format(exn=typ.name)), + valid_form("{exn}(message:str, param1:int(width=64), " + "param2:int(width=64)) -> {exn}".format(exn=typ.name)), + valid_form("{exn}(message:str, param1:int(width=64), " + "param2:int(width=64), param3:int(width=64)) " + "-> {exn}".format(exn=typ.name)), + ] + + if len(node.args) == 0 and len(node.keywords) == 0: + pass # Default message, zeroes as parameters + elif len(node.args) >= 1 and len(node.args) <= 4 and len(node.keywords) == 0: + message, *params = node.args + + self._unify(message.type, builtins.TStr(), + message.loc, None) + for param in params: + self._unify(param.type, builtins.TInt(types.TValue(64)), + param.loc, None) + else: + diagnose(valid_forms()) + + self._unify(node.type, typ.instance, + node.loc, None) + elif types.is_builtin(typ, "bool"): + valid_forms = lambda: [ + valid_form("bool() -> bool"), + valid_form("bool(x:'a) -> bool") + ] + + if len(node.args) == 0 and len(node.keywords) == 0: + pass # False + elif len(node.args) == 1 and len(node.keywords) == 0: + arg, = node.args + pass # anything goes + else: + diagnose(valid_forms()) + + self._unify(node.type, builtins.TBool(), + node.loc, None) + elif types.is_builtin(typ, "int"): + valid_forms = lambda: [ + valid_form("int() -> int(width='a)"), + valid_form("int(x:'a) -> int(width='b) where 'a is numeric"), + valid_form("int(x:'a, width='b:) -> int(width='b) where 'a is numeric") + ] + + self._unify(node.type, builtins.TInt(), + node.loc, None) + + if len(node.args) == 0 and len(node.keywords) == 0: + pass # 0 + elif len(node.args) == 1 and len(node.keywords) == 0 
and \ + types.is_var(node.args[0].type): + pass # undetermined yet + elif len(node.args) == 1 and len(node.keywords) == 0 and \ + builtins.is_numeric(node.args[0].type): + self._unify(node.type, builtins.TInt(), + node.loc, None) + elif len(node.args) == 1 and len(node.keywords) == 1 and \ + builtins.is_numeric(node.args[0].type) and \ + node.keywords[0].arg == 'width': + width = node.keywords[0].value + if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)): + diag = diagnostic.Diagnostic("error", + "the width argument of int() must be an integer literal", {}, + node.keywords[0].loc) + self.engine.process(diag) + return + + self._unify(node.type, builtins.TInt(types.TValue(width.n)), + node.loc, None) + else: + diagnose(valid_forms()) + elif types.is_builtin(typ, "float"): + valid_forms = lambda: [ + valid_form("float() -> float"), + valid_form("float(x:'a) -> float where 'a is numeric") + ] + + self._unify(node.type, builtins.TFloat(), + node.loc, None) + + if len(node.args) == 0 and len(node.keywords) == 0: + pass # 0.0 + elif len(node.args) == 1 and len(node.keywords) == 0 and \ + types.is_var(node.args[0].type): + pass # undetermined yet + elif len(node.args) == 1 and len(node.keywords) == 0 and \ + builtins.is_numeric(node.args[0].type): + pass + else: + diagnose(valid_forms()) + elif types.is_builtin(typ, "list"): + valid_forms = lambda: [ + valid_form("list() -> list(elt='a)"), + valid_form("list(x:'a) -> list(elt='b) where 'a is iterable") + ] + + self._unify(node.type, builtins.TList(), + node.loc, None) + + if len(node.args) == 0 and len(node.keywords) == 0: + pass # [] + elif len(node.args) == 1 and len(node.keywords) == 0: + arg, = node.args + + if builtins.is_iterable(arg.type): + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "iterator returning elements of type {typea}", + {"typea": printer.name(typea)}, + loca), + diagnostic.Diagnostic("note", + "iterator returning elements of type {typeb}", + {"typeb": printer.name(typeb)}, + locb) + ] + self._unify(node.type.find().params["elt"], + arg.type.find().params["elt"], + node.loc, arg.loc, makenotes=makenotes) + elif types.is_var(arg.type): + pass # undetermined yet + else: + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(arg.type)}, + arg.loc) + diag = diagnostic.Diagnostic("error", + "the argument of list() must be of an iterable type", {}, + node.func.loc, notes=[note]) + self.engine.process(diag) + else: + diagnose(valid_forms()) + elif types.is_builtin(typ, "range"): + valid_forms = lambda: [ + valid_form("range(max:int(width='a)) -> range(elt=int(width='a))"), + valid_form("range(min:int(width='a), max:int(width='a)) " + "-> range(elt=int(width='a))"), + valid_form("range(min:int(width='a), max:int(width='a), " + "step:int(width='a)) -> range(elt=int(width='a))"), + ] + + range_elt = builtins.TInt(types.TVar()) + self._unify(node.type, builtins.TRange(range_elt), + node.loc, None) + + if len(node.args) in (1, 2, 3) and len(node.keywords) == 0: + for arg in node.args: + self._unify(arg.type, range_elt, + arg.loc, None) + else: + diagnose(valid_forms()) + elif types.is_builtin(typ, "len"): + valid_forms = lambda: [ + valid_form("len(x:'a) -> int(width='b) where 'a is iterable"), + ] + + if len(node.args) == 1 and len(node.keywords) == 0: + arg, = node.args + + if builtins.is_range(arg.type): + self._unify(node.type, builtins.get_iterable_elt(arg.type), + node.loc, None) + elif 
builtins.is_list(arg.type): + # TODO: should be ssize_t-sized + self._unify(node.type, builtins.TInt(types.TValue(32)), + node.loc, None) + elif types.is_var(arg.type): + pass # undetermined yet + else: + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(arg.type)}, + arg.loc) + diag = diagnostic.Diagnostic("error", + "the argument of len() must be of an iterable type", {}, + node.func.loc, notes=[note]) + self.engine.process(diag) + else: + diagnose(valid_forms()) + elif types.is_builtin(typ, "round"): + valid_forms = lambda: [ + valid_form("round(x:float) -> int(width='a)"), + ] + + self._unify(node.type, builtins.TInt(), + node.loc, None) + + if len(node.args) == 1 and len(node.keywords) == 0: + arg, = node.args + + self._unify(arg.type, builtins.TFloat(), + arg.loc, None) + else: + diagnose(valid_forms()) + elif types.is_builtin(typ, "print"): + valid_forms = lambda: [ + valid_form("print(args...) -> None"), + ] + + self._unify(node.type, builtins.TNone(), + node.loc, None) + + if len(node.keywords) == 0: + # We can print any arguments. + pass + else: + diagnose(valid_forms()) + elif types.is_builtin(typ, "now"): + simple_form("now() -> float", + [], builtins.TFloat()) + elif types.is_builtin(typ, "delay"): + simple_form("delay(time:float) -> None", + [builtins.TFloat()]) + elif types.is_builtin(typ, "at"): + simple_form("at(time:float) -> None", + [builtins.TFloat()]) + elif types.is_builtin(typ, "now_mu"): + simple_form("now_mu() -> int(width=64)", + [], builtins.TInt(types.TValue(64))) + elif types.is_builtin(typ, "delay_mu"): + simple_form("delay_mu(time_mu:int(width=64)) -> None", + [builtins.TInt(types.TValue(64))]) + elif types.is_builtin(typ, "at_mu"): + simple_form("at_mu(time_mu:int(width=64)) -> None", + [builtins.TInt(types.TValue(64))]) + elif types.is_builtin(typ, "mu_to_seconds"): + simple_form("mu_to_seconds(time_mu:int(width=64)) -> float", + [builtins.TInt(types.TValue(64))], builtins.TFloat()) + elif types.is_builtin(typ, "seconds_to_mu"): + simple_form("seconds_to_mu(time:float) -> int(width=64)", + [builtins.TFloat()], builtins.TInt(types.TValue(64))) + elif types.is_constructor(typ): + # An user-defined class. 
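+            # Calling the constructor produces a value of the corresponding
+            # instance type.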
+ self._unify(node.type, typ.find().instance, + node.loc, None) + else: + assert False + + def visit_CallT(self, node): + self.generic_visit(node) + + for (sigil_loc, vararg) in ((node.star_loc, node.starargs), + (node.dstar_loc, node.kwargs)): + if vararg: + diag = diagnostic.Diagnostic("error", + "variadic arguments are not supported", {}, + sigil_loc, [vararg.loc]) + self.engine.process(diag) + return + + typ = node.func.type.find() + + if types.is_var(typ): + return # not enough info yet + elif types.is_builtin(typ): + return self.visit_builtin_call(node) + elif not (types.is_function(typ) or types.is_method(typ)): + diag = diagnostic.Diagnostic("error", + "cannot call this expression of type {type}", + {"type": types.TypePrinter().name(typ)}, + node.func.loc, []) + self.engine.process(diag) + return + + if types.is_function(typ): + typ_arity = typ.arity() + typ_args = typ.args + typ_optargs = typ.optargs + typ_ret = typ.ret + else: + typ = types.get_method_function(typ) + typ_arity = typ.arity() - 1 + typ_args = OrderedDict(list(typ.args.items())[1:]) + typ_optargs = typ.optargs + typ_ret = typ.ret + + passed_args = dict() + + if len(node.args) > typ_arity: + note = diagnostic.Diagnostic("note", + "extraneous argument(s)", {}, + node.args[typ_arity].loc.join(node.args[-1].loc)) + diag = diagnostic.Diagnostic("error", + "this function of type {type} accepts at most {num} arguments", + {"type": types.TypePrinter().name(node.func.type), + "num": typ_arity}, + node.func.loc, [], [note]) + self.engine.process(diag) + return + + for actualarg, (formalname, formaltyp) in \ + zip(node.args, list(typ_args.items()) + list(typ_optargs.items())): + self._unify(actualarg.type, formaltyp, + actualarg.loc, None) + passed_args[formalname] = actualarg.loc + + for keyword in node.keywords: + if keyword.arg in passed_args: + diag = diagnostic.Diagnostic("error", + "the argument '{name}' has been passed earlier as positional", + {"name": keyword.arg}, + keyword.arg_loc, [passed_args[keyword.arg]]) + self.engine.process(diag) + return + + if keyword.arg in typ_args: + self._unify(keyword.value.type, typ_args[keyword.arg], + keyword.value.loc, None) + elif keyword.arg in typ_optargs: + self._unify(keyword.value.type, typ_optargs[keyword.arg], + keyword.value.loc, None) + passed_args[keyword.arg] = keyword.arg_loc + + for formalname in typ_args: + if formalname not in passed_args: + note = diagnostic.Diagnostic("note", + "the called function is of type {type}", + {"type": types.TypePrinter().name(node.func.type)}, + node.func.loc) + diag = diagnostic.Diagnostic("error", + "mandatory argument '{name}' is not passed", + {"name": formalname}, + node.begin_loc.join(node.end_loc), [], [note]) + self.engine.process(diag) + return + + self._unify(node.type, typ_ret, + node.loc, None) + + def visit_LambdaT(self, node): + self.generic_visit(node) + signature_type = self._type_from_arguments(node.args, node.body.type) + if signature_type: + self._unify(node.type, signature_type, + node.loc, None) + + def visit_Assign(self, node): + self.generic_visit(node) + for target in node.targets: + self._unify(target.type, node.value.type, + target.loc, node.value.loc) + + def visit_AugAssign(self, node): + self.generic_visit(node) + coerced = self._coerce_binop(node.op, node.target, node.value) + if coerced: + return_type, target_type, value_type = coerced + + try: + node.target.type.unify(target_type) + except types.UnificationError as e: + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "expression of 
type {typec}", + {"typec": printer.name(node.value.type)}, + node.value.loc) + diag = diagnostic.Diagnostic("error", + "expression of type {typea} has to be coerced to {typeb}, " + "which makes assignment invalid", + {"typea": printer.name(node.target.type), + "typeb": printer.name(target_type)}, + node.op.loc, [node.target.loc], [note]) + self.engine.process(diag) + return + + try: + node.target.type.unify(return_type) + except types.UnificationError as e: + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "expression of type {typec}", + {"typec": printer.name(node.value.type)}, + node.value.loc) + diag = diagnostic.Diagnostic("error", + "the result of this operation has type {typeb}, " + "which makes assignment to a slot of type {typea} invalid", + {"typea": printer.name(node.target.type), + "typeb": printer.name(return_type)}, + node.op.loc, [node.target.loc], [note]) + self.engine.process(diag) + return + + node.value = self._coerce_one(value_type, node.value, other_node=node.target) + + def visit_For(self, node): + old_in_loop, self.in_loop = self.in_loop, True + self.generic_visit(node) + self.in_loop = old_in_loop + self._unify_iterable(node.target, node.iter) + + def visit_While(self, node): + old_in_loop, self.in_loop = self.in_loop, True + self.generic_visit(node) + self.in_loop = old_in_loop + + def visit_Break(self, node): + if not self.in_loop: + diag = diagnostic.Diagnostic("error", + "break statement outside of a loop", {}, + node.keyword_loc) + self.engine.process(diag) + + def visit_Continue(self, node): + if not self.in_loop: + diag = diagnostic.Diagnostic("error", + "continue statement outside of a loop", {}, + node.keyword_loc) + self.engine.process(diag) + + def visit_withitem(self, node): + self.generic_visit(node) + + typ = node.context_expr.type + if not (types.is_builtin(typ, "parallel") or types.is_builtin(typ, "sequential")): + diag = diagnostic.Diagnostic("error", + "value of type {type} cannot act as a context manager", + {"type": types.TypePrinter().name(typ)}, + node.context_expr.loc) + self.engine.process(diag) + + if node.optional_vars is not None: + self._unify(node.optional_vars.type, node.context_expr.type, + node.optional_vars.loc, node.context_expr.loc) + + def visit_ExceptHandlerT(self, node): + self.generic_visit(node) + + if node.filter is not None: + if not types.is_exn_constructor(node.filter.type): + diag = diagnostic.Diagnostic("error", + "this expression must refer to an exception constructor", + {"type": types.TypePrinter().name(node.filter.type)}, + node.filter.loc) + self.engine.process(diag) + else: + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "expression of type {typea}", + {"typea": printer.name(typea)}, + loca), + diagnostic.Diagnostic("note", + "constructor of an exception of type {typeb}", + {"typeb": printer.name(typeb)}, + locb) + ] + self._unify(node.name_type, builtins.TException(node.filter.type.name), + node.name_loc, node.filter.loc, makenotes) + + def _type_from_arguments(self, node, ret): + self.generic_visit(node) + + for (sigil_loc, vararg) in ((node.star_loc, node.vararg), + (node.dstar_loc, node.kwarg)): + if vararg: + diag = diagnostic.Diagnostic("error", + "variadic arguments are not supported", {}, + sigil_loc, [vararg.loc]) + self.engine.process(diag) + return + + def extract_args(arg_nodes): + args = [(arg_node.arg, arg_node.type) for arg_node in arg_nodes] + return OrderedDict(args) + + return 
types.TFunction(extract_args(node.args[:len(node.args) - len(node.defaults)]), + extract_args(node.args[len(node.args) - len(node.defaults):]), + ret) + + def visit_arguments(self, node): + self.generic_visit(node) + for arg, default in zip(node.args[len(node.args) - len(node.defaults):], node.defaults): + self._unify(arg.type, default.type, + arg.loc, default.loc) + + def visit_FunctionDefT(self, node): + for index, decorator in enumerate(node.decorator_list): + if types.is_builtin(decorator.type, "kernel"): + continue + + diag = diagnostic.Diagnostic("error", + "decorators are not supported", {}, + node.at_locs[index], [decorator.loc]) + self.engine.process(diag) + + try: + old_function, self.function = self.function, node + old_in_loop, self.in_loop = self.in_loop, False + old_has_return, self.has_return = self.has_return, False + + self.generic_visit(node) + + # Lack of return statements is not the only case where the return + # type cannot be inferred. The other one is infinite (possibly mutual) + # recursion. Since Python functions don't have to return a value, + # we ignore that one. + if not self.has_return: + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "function with return type {typea}", + {"typea": printer.name(typea)}, + node.name_loc), + ] + self._unify(node.return_type, builtins.TNone(), + node.name_loc, None, makenotes) + finally: + self.function = old_function + self.in_loop = old_in_loop + self.has_return = old_has_return + + signature_type = self._type_from_arguments(node.args, node.return_type) + if signature_type: + self._unify(node.signature_type, signature_type, + node.name_loc, None) + + def visit_ClassDefT(self, node): + if any(node.decorator_list): + diag = diagnostic.Diagnostic("error", + "decorators are not supported", {}, + node.at_locs[0], [node.decorator_list[0].loc]) + self.engine.process(diag) + + self.generic_visit(node) + + def visit_Return(self, node): + if not self.function: + diag = diagnostic.Diagnostic("error", + "return statement outside of a function", {}, + node.keyword_loc) + self.engine.process(diag) + return + + self.has_return = True + + self.generic_visit(node) + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "function with return type {typea}", + {"typea": printer.name(typea)}, + self.function.name_loc), + diagnostic.Diagnostic("note", + "a statement returning {typeb}", + {"typeb": printer.name(typeb)}, + node.loc) + ] + if node.value is None: + self._unify(self.function.return_type, builtins.TNone(), + self.function.name_loc, node.loc, makenotes) + else: + self._unify(self.function.return_type, node.value.type, + self.function.name_loc, node.value.loc, makenotes) + + def visit_Raise(self, node): + self.generic_visit(node) + + if node.exc is not None: + exc_type = node.exc.type + if not types.is_var(exc_type) and not builtins.is_exception(exc_type): + diag = diagnostic.Diagnostic("error", + "cannot raise a value of type {type}, which is not an exception", + {"type": types.TypePrinter().name(exc_type)}, + node.exc.loc) + self.engine.process(diag) + + def visit_Assert(self, node): + self.generic_visit(node) + self._unify(node.test.type, builtins.TBool(), + node.test.loc, None) + if node.msg is not None: + if not isinstance(node.msg, asttyped.StrT): + diag = diagnostic.Diagnostic("error", + "assertion message must be a string literal", {}, + node.msg.loc) + self.engine.process(diag) diff --git a/artiq/compiler/transforms/int_monomorphizer.py 
b/artiq/compiler/transforms/int_monomorphizer.py new file mode 100644 index 000000000..3aeed9140 --- /dev/null +++ b/artiq/compiler/transforms/int_monomorphizer.py @@ -0,0 +1,28 @@ +""" +:class:`IntMonomorphizer` collapses the integer literals of undetermined +width to 32 bits, assuming they fit into 32 bits, or 64 bits if they +do not. +""" + +from pythonparser import algorithm, diagnostic +from .. import types, builtins + +class IntMonomorphizer(algorithm.Visitor): + def __init__(self, engine): + self.engine = engine + + def visit_NumT(self, node): + if builtins.is_int(node.type): + if types.is_var(node.type["width"]): + if -2**31 < node.n < 2**31-1: + width = 32 + elif -2**63 < node.n < 2**63-1: + width = 64 + else: + diag = diagnostic.Diagnostic("error", + "integer literal out of range for a signed 64-bit value", {}, + node.loc) + self.engine.process(diag) + return + + node.type["width"].unify(types.TValue(width)) diff --git a/artiq/compiler/transforms/interleaver.py b/artiq/compiler/transforms/interleaver.py new file mode 100644 index 000000000..4aa3c327b --- /dev/null +++ b/artiq/compiler/transforms/interleaver.py @@ -0,0 +1,161 @@ +""" +:class:`Interleaver` reorders requests to the RTIO core so that +the timestamp would always monotonically nondecrease. +""" + +from pythonparser import diagnostic + +from .. import types, builtins, ir, iodelay +from ..analyses import domination +from ..algorithms import inline + +def delay_free_subgraph(root, limit): + visited = set() + queue = root.successors() + while len(queue) > 0: + block = queue.pop() + visited.add(block) + + if block is limit: + continue + + if isinstance(block.terminator(), ir.Delay): + return False + + for successor in block.successors(): + if successor not in visited: + queue.append(successor) + + return True + +def iodelay_of_block(block): + terminator = block.terminator() + if isinstance(terminator, ir.Delay): + # We should be able to fold everything without free variables. + folded_expr = terminator.expr.fold() + assert iodelay.is_const(folded_expr) + return folded_expr.value + else: + return 0 + +def is_pure_delay(insn): + return isinstance(insn, ir.Builtin) and insn.op in ("delay", "delay_mu") + +def is_impure_delay_block(block): + terminator = block.terminator() + return isinstance(terminator, ir.Delay) and \ + not is_pure_delay(terminator.decomposition()) + +class Interleaver: + def __init__(self, engine): + self.engine = engine + + def process(self, functions): + for func in functions: + self.process_function(func) + + def process_function(self, func): + for insn in func.instructions(): + if isinstance(insn, ir.Delay): + if any(insn.expr.free_vars()): + # If a function has free variables in delay expressions, + # that means its IO delay depends on arguments. + # Do not change such functions in any way so that it will + # be successfully inlined and then removed by DCE. + return + + postdom_tree = None + for insn in func.instructions(): + if not isinstance(insn, ir.Parallel): + continue + + # Lazily compute dominators. 
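+            # (The tree is recomputed after inlining further below, since
+            # inlining a call changes the control flow graph.)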
+ if postdom_tree is None: + postdom_tree = domination.PostDominatorTree(func) + + interleave_until = postdom_tree.immediate_dominator(insn.basic_block) + assert (interleave_until is not None) # no nonlocal flow in `with parallel` + + target_block = insn.basic_block + target_time = 0 + source_blocks = insn.basic_block.successors() + source_times = [0 for _ in source_blocks] + + while len(source_blocks) > 0: + def time_after_block(pair): + index, block = pair + return source_times[index] + iodelay_of_block(block) + + # Always prefer impure blocks (with calls) to pure blocks, because + # impure blocks may expand with smaller delays appearing, and in + # case of a tie, if a pure block is preferred, this would violate + # the timeline monotonicity. + available_source_blocks = list(filter(is_impure_delay_block, source_blocks)) + if not any(available_source_blocks): + available_source_blocks = source_blocks + + index, source_block = min(enumerate(available_source_blocks), key=time_after_block) + source_block_delay = iodelay_of_block(source_block) + + new_target_time = source_times[index] + source_block_delay + target_time_delta = new_target_time - target_time + assert target_time_delta >= 0 + + target_terminator = target_block.terminator() + if isinstance(target_terminator, ir.Parallel): + target_terminator.replace_with(ir.Branch(source_block)) + else: + assert isinstance(target_terminator, (ir.Delay, ir.Branch)) + target_terminator.set_target(source_block) + + source_terminator = source_block.terminator() + + if not isinstance(source_terminator, ir.Delay): + source_terminator.replace_with(ir.Branch(source_terminator.target())) + else: + old_decomp = source_terminator.decomposition() + if is_pure_delay(old_decomp): + if target_time_delta > 0: + new_decomp_expr = ir.Constant(int(target_time_delta), builtins.TInt64()) + new_decomp = ir.Builtin("delay_mu", [new_decomp_expr], builtins.TNone()) + new_decomp.loc = old_decomp.loc + + source_terminator.basic_block.insert(new_decomp, before=source_terminator) + source_terminator.expr = iodelay.Const(target_time_delta) + source_terminator.set_decomposition(new_decomp) + else: + source_terminator.replace_with(ir.Branch(source_terminator.target())) + old_decomp.erase() + else: # It's a call. + need_to_inline = len(source_blocks) > 1 + if need_to_inline: + if old_decomp.static_target_function is None: + diag = diagnostic.Diagnostic("fatal", + "it is not possible to interleave this function call within " + "a 'with parallel:' statement because the compiler could not " + "prove that the same function would always be called", {}, + old_decomp.loc) + self.engine.process(diag) + + inline(old_decomp) + postdom_tree = domination.PostDominatorTree(func) + continue + elif target_time_delta > 0: + source_terminator.expr = iodelay.Const(target_time_delta) + else: + source_terminator.replace_with(ir.Branch(source_terminator.target())) + + target_block = source_block + target_time = new_target_time + + new_source_block = postdom_tree.immediate_dominator(source_block) + assert (new_source_block is not None) + assert delay_free_subgraph(source_block, new_source_block) + + if new_source_block == interleave_until: + # We're finished with this branch. 
+ del source_blocks[index] + del source_times[index] + else: + source_blocks[index] = new_source_block + source_times[index] = new_target_time diff --git a/artiq/compiler/transforms/iodelay_estimator.py b/artiq/compiler/transforms/iodelay_estimator.py new file mode 100644 index 000000000..4170094c0 --- /dev/null +++ b/artiq/compiler/transforms/iodelay_estimator.py @@ -0,0 +1,286 @@ +""" +:class:`IODelayEstimator` calculates the amount of time +elapsed from the point of view of the RTIO core for +every function. +""" + +from pythonparser import ast, algorithm, diagnostic +from .. import types, iodelay, builtins, asttyped + +class _UnknownDelay(Exception): + pass + +class _IndeterminateDelay(Exception): + def __init__(self, cause): + self.cause = cause + +class IODelayEstimator(algorithm.Visitor): + def __init__(self, engine, ref_period): + self.engine = engine + self.ref_period = ref_period + self.changed = False + self.current_delay = iodelay.Const(0) + self.current_args = None + self.current_goto = None + self.current_return = None + + def evaluate(self, node, abort): + if isinstance(node, asttyped.NumT): + return iodelay.Const(node.n) + elif isinstance(node, asttyped.CoerceT): + return self.evaluate(node.value, abort) + elif isinstance(node, asttyped.NameT): + if self.current_args is None: + note = diagnostic.Diagnostic("note", + "this variable is not an argument", {}, + node.loc) + abort([note]) + elif node.id in [arg.arg for arg in self.current_args.args]: + return iodelay.Var(node.id) + else: + notes = [ + diagnostic.Diagnostic("note", + "this variable is not an argument of the innermost function", {}, + node.loc), + diagnostic.Diagnostic("note", + "only these arguments are in scope of analysis", {}, + self.current_args.loc) + ] + abort(notes) + elif isinstance(node, asttyped.BinOpT): + lhs = self.evaluate(node.left, abort) + rhs = self.evaluate(node.right, abort) + if isinstance(node.op, ast.Add): + return lhs + rhs + elif isinstance(node.op, ast.Sub): + return lhs - rhs + elif isinstance(node.op, ast.Mult): + return lhs * rhs + elif isinstance(node.op, ast.Div): + return lhs / rhs + elif isinstance(node.op, ast.FloorDiv): + return lhs // rhs + else: + note = diagnostic.Diagnostic("note", + "this operator is not supported", {}, + node.op.loc) + abort([note]) + else: + note = diagnostic.Diagnostic("note", + "this expression is not supported", {}, + node.loc) + abort([note]) + + def abort(self, message, loc, notes=[]): + diag = diagnostic.Diagnostic("error", message, {}, loc, notes=notes) + raise _IndeterminateDelay(diag) + + def visit_fixpoint(self, node): + while True: + self.changed = False + self.visit(node) + if not self.changed: + return + + def visit_ModuleT(self, node): + try: + for stmt in node.body: + try: + self.visit(stmt) + except _UnknownDelay: + pass # more luck next time? 
+ except _IndeterminateDelay: + pass # we don't care; module-level code is never interleaved + + def visit_function(self, args, body, typ, loc): + old_args, self.current_args = self.current_args, args + old_return, self.current_return = self.current_return, None + old_delay, self.current_delay = self.current_delay, iodelay.Const(0) + try: + self.visit(body) + if not iodelay.is_zero(self.current_delay) and self.current_return is not None: + self.abort("only return statement at the end of the function " + "can be interleaved", self.current_return.loc) + + delay = types.TFixedDelay(self.current_delay.fold()) + except _IndeterminateDelay as error: + delay = types.TIndeterminateDelay(error.cause) + self.current_delay = old_delay + self.current_return = old_return + self.current_args = old_args + + if types.is_indeterminate_delay(delay) and types.is_indeterminate_delay(typ.delay): + # Both delays indeterminate; no point in unifying since that will + # replace the lazy and more specific error with an eager and more generic + # error (unification error of delay(?) with delay(?), which is useless). + return + + try: + old_delay = typ.delay.find() + typ.delay.unify(delay) + if typ.delay.find() != old_delay: + self.changed = True + except types.UnificationError as e: + printer = types.TypePrinter() + diag = diagnostic.Diagnostic("fatal", + "delay {delaya} was inferred for this function, but its delay is already " + "constrained externally to {delayb}", + {"delaya": printer.name(delay), "delayb": printer.name(typ.delay)}, + loc) + self.engine.process(diag) + + def visit_FunctionDefT(self, node): + self.visit(node.args.defaults) + self.visit(node.args.kw_defaults) + + # We can only handle return in tail position. + if isinstance(node.body[-1], ast.Return): + body = node.body[:-1] + else: + body = node.body + self.visit_function(node.args, body, node.signature_type.find(), node.loc) + + def visit_LambdaT(self, node): + self.visit_function(node.args, node.body, node.type.find(), node.loc) + + def get_iterable_length(self, node): + def abort(notes): + self.abort("for statement cannot be interleaved because " + "trip count is indeterminate", + node.loc, notes) + + def evaluate(node): + return self.evaluate(node, abort) + + if isinstance(node, asttyped.CallT) and types.is_builtin(node.func.type, "range"): + range_min, range_max, range_step = iodelay.Const(0), None, iodelay.Const(1) + if len(node.args) == 3: + range_min, range_max, range_step = map(evaluate, node.args) + elif len(node.args) == 2: + range_min, range_max = map(evaluate, node.args) + elif len(node.args) == 1: + range_max, = map(evaluate, node.args) + return (range_max - range_min) // range_step + else: + note = diagnostic.Diagnostic("note", + "this value is not a constant range literal", {}, + node.loc) + abort([note]) + + def visit_For(self, node): + self.visit(node.iter) + + old_goto, self.current_goto = self.current_goto, None + old_delay, self.current_delay = self.current_delay, iodelay.Const(0) + self.visit(node.body) + if iodelay.is_zero(self.current_delay): + self.current_delay = old_delay + else: + if self.current_goto is not None: + self.abort("loop trip count is indeterminate because of control flow", + self.current_goto.loc) + + trip_count = self.get_iterable_length(node.iter) + self.current_delay = old_delay + self.current_delay * trip_count + self.current_goto = old_goto + + self.visit(node.orelse) + + def visit_goto(self, node): + self.current_goto = node + + visit_Break = visit_goto + visit_Continue = visit_goto + + def 
visit_control_flow(self, kind, node): + old_delay, self.current_delay = self.current_delay, iodelay.Const(0) + self.generic_visit(node) + if not iodelay.is_zero(self.current_delay): + self.abort("{} cannot be interleaved".format(kind), node.loc) + self.current_delay = old_delay + + visit_While = lambda self, node: self.visit_control_flow("while statement", node) + visit_If = lambda self, node: self.visit_control_flow("if statement", node) + visit_IfExpT = lambda self, node: self.visit_control_flow("if expression", node) + visit_Try = lambda self, node: self.visit_control_flow("try statement", node) + + def visit_Return(self, node): + self.current_return = node + + def visit_With(self, node): + self.visit(node.items) + + context_expr = node.items[0].context_expr + if len(node.items) == 1 and types.is_builtin(context_expr.type, "parallel"): + try: + delays = [] + for stmt in node.body: + old_delay, self.current_delay = self.current_delay, iodelay.Const(0) + self.visit(stmt) + delays.append(self.current_delay) + self.current_delay = old_delay + + if any(delays): + self.current_delay += iodelay.Max(delays) + except _IndeterminateDelay as error: + # Interleave failures inside `with` statements are hard failures, + # since there's no chance that the code will never actually execute + # inside a `with` statement after all. + self.engine.process(error.cause) + + elif len(node.items) == 1 and types.is_builtin(context_expr.type, "sequential"): + self.visit(node.body) + else: + self.abort("with statement cannot be interleaved", node.loc) + + def visit_CallT(self, node): + typ = node.func.type.find() + def abort(notes): + self.abort("this call cannot be interleaved because " + "an argument cannot be statically evaluated", + node.loc, notes) + + if types.is_builtin(typ, "delay"): + value = self.evaluate(node.args[0], abort=abort) + call_delay = iodelay.SToMU(value, ref_period=self.ref_period) + elif types.is_builtin(typ, "delay_mu"): + value = self.evaluate(node.args[0], abort=abort) + call_delay = value + elif not types.is_builtin(typ): + if types.is_function(typ): + offset = 0 + elif types.is_method(typ): + offset = 1 + typ = types.get_method_function(typ) + else: + assert False + + delay = typ.find().delay.find() + if types.is_var(delay): + raise _UnknownDelay() + elif delay.is_indeterminate(): + note = diagnostic.Diagnostic("note", + "function called here", {}, + node.loc) + cause = delay.cause + cause = diagnostic.Diagnostic(cause.level, cause.reason, cause.arguments, + cause.location, cause.highlights, + cause.notes + [note]) + raise _IndeterminateDelay(cause) + elif delay.is_fixed(): + args = {} + for kw_node in node.keywords: + args[kw_node.arg] = kw_node.value + for arg_name, arg_node in zip(list(typ.args)[offset:], node.args): + args[arg_name] = arg_node + + free_vars = delay.duration.free_vars() + call_delay = delay.duration.fold( + { arg: self.evaluate(args[arg], abort=abort) for arg in free_vars }) + else: + assert False + else: + call_delay = iodelay.Const(0) + + self.current_delay += call_delay + node.iodelay = call_delay diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py new file mode 100644 index 000000000..ac54ecf9f --- /dev/null +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -0,0 +1,1151 @@ +""" +:class:`LLVMIRGenerator` transforms ARTIQ intermediate representation +into LLVM intermediate representation. 
+""" + +import os +from pythonparser import ast, diagnostic +from llvmlite_artiq import ir as ll +from ...language import core as language_core +from .. import types, builtins, ir + + +llvoid = ll.VoidType() +lli1 = ll.IntType(1) +lli8 = ll.IntType(8) +lli32 = ll.IntType(32) +lli64 = ll.IntType(64) +lldouble = ll.DoubleType() +llptr = ll.IntType(8).as_pointer() +llmetadata = ll.MetaData() + + +DW_LANG_Python = 0x0014 +DW_TAG_compile_unit = 17 +DW_TAG_subroutine_type = 21 +DW_TAG_file_type = 41 +DW_TAG_subprogram = 46 + +def memoize(generator): + def memoized(self, *args): + result = self.cache.get((generator,) + args, None) + if result is None: + return generator(self, *args) + else: + return result + return memoized + +class DebugInfoEmitter: + def __init__(self, llmodule): + self.llmodule = llmodule + self.cache = {} + self.subprograms = [] + + def emit(self, operands): + def map_operand(operand): + if operand is None: + return ll.Constant(llmetadata, None) + elif isinstance(operand, str): + return ll.MetaDataString(self.llmodule, operand) + elif isinstance(operand, bool): + return ll.Constant(lli1, operand) + elif isinstance(operand, int): + return ll.Constant(lli32, operand) + elif isinstance(operand, (list, tuple)): + return self.emit(operand) + elif isinstance(operand, ll.Value): + return operand + else: + print(operand) + assert False + return self.llmodule.add_metadata(list(map(map_operand, operands))) + + @memoize + def emit_filename(self, source_buffer): + source_dir, source_file = os.path.split(source_buffer.name) + return self.emit([source_file, source_dir]) + + @memoize + def emit_compile_unit(self, source_buffer, llsubprograms): + return self.emit([ + DW_TAG_compile_unit, + self.emit_filename(source_buffer), # filename + DW_LANG_Python, # source language + "ARTIQ", # producer + False, # optimized? + "", # linker flags + 0, # runtime version + [], # enum types + [], # retained types + llsubprograms, # subprograms + [], # global variables + [], # imported entities + "", # split debug filename + 2, # kind (full=1, lines only=2) + ]) + + @memoize + def emit_file(self, source_buffer): + return self.emit([ + DW_TAG_file_type, + self.emit_filename(source_buffer), # filename + ]) + + @memoize + def emit_subroutine_type(self, typ): + return self.emit([ + DW_TAG_subroutine_type, + None, # filename + None, # context descriptor + "", # name + 0, # line number + 0, # (i64) size in bits + 0, # (i64) alignment in bits + 0, # (i64) offset in bits + 0, # flags + None, # derived from + [None], # members + 0, # runtime languages + None, # base type with vtable pointer + None, # template parameters + None # unique identifier + ]) + + @memoize + def emit_subprogram(self, func, llfunc): + source_buffer = func.loc.source_buffer + display_name = "{}{}".format(func.name, types.TypePrinter().name(func.type)) + subprogram = self.emit([ + DW_TAG_subprogram, + self.emit_filename(source_buffer), # filename + self.emit_file(source_buffer), # context descriptor + func.name, # name + display_name, # display name + llfunc.name, # linkage name + func.loc.line(), # line number where defined + self.emit_subroutine_type(func.type), # type descriptor + func.is_internal, # local to compile unit? + True, # global is defined in the compile unit? + 0, # virtuality + 0, # index into a virtual function + None, # base type with vtable pointer + 0, # flags + False, # optimized? 
+ llfunc, # LLVM function + None, # template parameters + None, # function declaration descriptor + [], # function variables + func.loc.line(), # line number where scope begins + ]) + self.subprograms.append(subprogram) + return subprogram + + @memoize + def emit_loc(self, loc, scope, inlined_scope=None): + return self.emit([ + loc.line(), # line + loc.column(), # column + scope, # scope + inlined_scope, # inlined scope + ]) + + def finalize(self, source_buffer): + llident = self.llmodule.add_named_metadata('llvm.ident') + llident.add(self.emit(["ARTIQ"])) + + llflags = self.llmodule.add_named_metadata('llvm.module.flags') + llflags.add(self.emit([2, "Debug Info Version", 1])) + + llcompile_units = self.llmodule.add_named_metadata('llvm.dbg.cu') + llcompile_units.add(self.emit_compile_unit(source_buffer, tuple(self.subprograms))) + + +class LLVMIRGenerator: + def __init__(self, engine, module_name, target, object_map): + self.engine = engine + self.target = target + self.object_map = object_map + self.llcontext = target.llcontext + self.llmodule = ll.Module(context=self.llcontext, name=module_name) + self.llmodule.triple = target.triple + self.llmodule.data_layout = target.data_layout + self.llfunction = None + self.llmap = {} + self.llobject_map = {} + self.phis = [] + self.debug_info_emitter = DebugInfoEmitter(self.llmodule) + + def needs_sret(self, lltyp, may_be_large=True): + if isinstance(lltyp, ll.VoidType): + return False + elif isinstance(lltyp, ll.IntType) and lltyp.width <= 32: + return False + elif isinstance(lltyp, ll.PointerType): + return False + elif may_be_large and isinstance(lltyp, ll.DoubleType): + return False + elif may_be_large and isinstance(lltyp, ll.LiteralStructType) \ + and len(lltyp.elements) <= 2: + return not any([self.needs_sret(elt, may_be_large=False) for elt in lltyp.elements]) + else: + return True + + def llty_of_type(self, typ, bare=False, for_return=False): + typ = typ.find() + if types.is_tuple(typ): + return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts]) + elif types.is_rpc_function(typ) or types.is_c_function(typ): + if for_return: + return llvoid + else: + return ll.LiteralStructType([]) + elif types._is_pointer(typ): + return llptr + elif types.is_function(typ): + sretarg = [] + llretty = self.llty_of_type(typ.ret, for_return=True) + if self.needs_sret(llretty): + sretarg = [llretty.as_pointer()] + llretty = llvoid + + envarg = llptr + llty = ll.FunctionType(args=sretarg + [envarg] + + [self.llty_of_type(typ.args[arg]) + for arg in typ.args] + + [self.llty_of_type(ir.TOption(typ.optargs[arg])) + for arg in typ.optargs], + return_type=llretty) + + # TODO: actually mark the first argument as sret (also noalias nocapture). + # llvmlite currently does not have support for this; + # https://github.com/numba/llvmlite/issues/91. 
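+ # With sret, the caller allocates space for an aggregate return value and
+ # passes a pointer to it as a hidden first argument, and the callee then
+ # returns void (see process_Return and process_Call below).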
+ if sretarg: + llty.__has_sret = True + else: + llty.__has_sret = False + + if bare: + return llty + else: + return ll.LiteralStructType([envarg, llty.as_pointer()]) + elif types.is_method(typ): + llfuncty = self.llty_of_type(types.get_method_function(typ)) + llselfty = self.llty_of_type(types.get_method_self(typ)) + return ll.LiteralStructType([llfuncty, llselfty]) + elif builtins.is_none(typ): + if for_return: + return llvoid + else: + return ll.LiteralStructType([]) + elif builtins.is_bool(typ): + return lli1 + elif builtins.is_int(typ): + return ll.IntType(builtins.get_int_width(typ)) + elif builtins.is_float(typ): + return lldouble + elif builtins.is_str(typ) or ir.is_exn_typeinfo(typ): + return llptr + elif builtins.is_list(typ): + lleltty = self.llty_of_type(builtins.get_iterable_elt(typ)) + return ll.LiteralStructType([lli32, lleltty.as_pointer()]) + elif builtins.is_range(typ): + lleltty = self.llty_of_type(builtins.get_iterable_elt(typ)) + return ll.LiteralStructType([lleltty, lleltty, lleltty]) + elif ir.is_basic_block(typ): + return llptr + elif ir.is_option(typ): + return ll.LiteralStructType([lli1, self.llty_of_type(typ.params["inner"])]) + elif ir.is_environment(typ): + llty = ll.LiteralStructType([self.llty_of_type(typ.params[name]) + for name in typ.params]) + if bare: + return llty + else: + return llty.as_pointer() + else: # Catch-all for exceptions and custom classes + if builtins.is_exception(typ): + name = "class.Exception" # they all share layout + elif types.is_constructor(typ): + name = "class.{}".format(typ.name) + else: + name = "instance.{}".format(typ.name) + + llty = self.llcontext.get_identified_type(name) + if llty.elements is None: + # First setting elements to [] will allow us to handle + # self-referential types. 
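+ # E.g. a class with an attribute of its own type: the identified struct
+ # must already exist (even with no elements) before its element types,
+ # which may refer back to it, are computed.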
+ llty.elements = [] + llty.elements = [self.llty_of_type(attrtyp) + for attrtyp in typ.attributes.values()] + + if bare or not builtins.is_allocated(typ): + return llty + else: + return llty.as_pointer() + + def llstr_of_str(self, value, name=None, + linkage="private", unnamed_addr=True): + if isinstance(value, str): + assert "\0" not in value + as_bytes = (value + "\0").encode("utf-8") + else: + as_bytes = value + + if name is None: + name = self.llmodule.get_unique_name("str") + + llstr = self.llmodule.get_global(name) + if llstr is None: + llstrty = ll.ArrayType(lli8, len(as_bytes)) + llstr = ll.GlobalVariable(self.llmodule, llstrty, name) + llstr.global_constant = True + llstr.initializer = ll.Constant(llstrty, bytearray(as_bytes)) + llstr.linkage = linkage + llstr.unnamed_addr = unnamed_addr + return llstr.bitcast(llptr) + + def llconst_of_const(self, const): + llty = self.llty_of_type(const.type) + if const.value is None: + return ll.Constant(llty, []) + elif const.value is True: + return ll.Constant(llty, True) + elif const.value is False: + return ll.Constant(llty, False) + elif isinstance(const.value, (int, float)): + return ll.Constant(llty, const.value) + elif isinstance(const.value, (str, bytes)): + if ir.is_exn_typeinfo(const.type): + # Exception typeinfo; should be merged with identical others + name = "__artiq_exn_" + const.value + linkage = "linkonce" + unnamed_addr = False + else: + # Just a string + name = None + linkage = "private" + unnamed_addr = True + + return self.llstr_of_str(const.value, name=name, + linkage=linkage, unnamed_addr=unnamed_addr) + else: + assert False + + def llbuiltin(self, name): + llglobal = self.llmodule.get_global(name) + if llglobal is not None: + return llglobal + + if name in "llvm.donothing": + llty = ll.FunctionType(llvoid, []) + elif name in "llvm.trap": + llty = ll.FunctionType(llvoid, []) + elif name == "llvm.floor.f64": + llty = ll.FunctionType(lldouble, [lldouble]) + elif name == "llvm.round.f64": + llty = ll.FunctionType(lldouble, [lldouble]) + elif name == "llvm.pow.f64": + llty = ll.FunctionType(lldouble, [lldouble, lldouble]) + elif name == "llvm.powi.f64": + llty = ll.FunctionType(lldouble, [lldouble, lli32]) + elif name == "llvm.copysign.f64": + llty = ll.FunctionType(lldouble, [lldouble, lldouble]) + elif name == "llvm.stacksave": + llty = ll.FunctionType(llptr, []) + elif name == "llvm.stackrestore": + llty = ll.FunctionType(llvoid, [llptr]) + elif name == self.target.print_function: + llty = ll.FunctionType(llvoid, [llptr], var_arg=True) + elif name == "__artiq_personality": + llty = ll.FunctionType(lli32, [], var_arg=True) + elif name == "__artiq_raise": + llty = ll.FunctionType(llvoid, [self.llty_of_type(builtins.TException())]) + elif name == "__artiq_reraise": + llty = ll.FunctionType(llvoid, []) + elif name == "send_rpc": + llty = ll.FunctionType(llvoid, [lli32, llptr], + var_arg=True) + elif name == "recv_rpc": + llty = ll.FunctionType(lli32, [llptr]) + elif name == "now": + llty = lli64 + else: + assert False + + if isinstance(llty, ll.FunctionType): + llglobal = ll.Function(self.llmodule, llty, name) + if name in ("__artiq_raise", "__artiq_reraise", "llvm.trap"): + llglobal.attributes.add("noreturn") + else: + llglobal = ll.GlobalVariable(self.llmodule, llty, name) + + return llglobal + + def map(self, value): + if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)): + return self.llmap[value] + elif isinstance(value, ir.Constant): + return self.llconst_of_const(value) + elif isinstance(value, 
ir.Function): + llfun = self.llmodule.get_global(value.name) + if llfun is None: + llfun = ll.Function(self.llmodule, self.llty_of_type(value.type, bare=True), + value.name) + return llfun + else: + assert False + + def process(self, functions): + for func in functions: + self.process_function(func) + + if any(functions): + self.debug_info_emitter.finalize(functions[0].loc.source_buffer) + + return self.llmodule + + def process_function(self, func): + try: + self.llfunction = self.llmodule.get_global(func.name) + + if self.llfunction is None: + llfunty = self.llty_of_type(func.type, bare=True) + self.llfunction = ll.Function(self.llmodule, llfunty, func.name) + + if func.is_internal: + self.llfunction.linkage = 'internal' + + self.llfunction.attributes.add('uwtable') + + self.llbuilder = ll.IRBuilder() + llblock_map = {} + + disubprogram = self.debug_info_emitter.emit_subprogram(func, self.llfunction) + + # First, map arguments. + if self.llfunction.type.pointee.__has_sret: + llactualargs = self.llfunction.args[1:] + else: + llactualargs = self.llfunction.args + + for arg, llarg in zip(func.arguments, llactualargs): + self.llmap[arg] = llarg + + # Second, create all basic blocks. + for block in func.basic_blocks: + llblock = self.llfunction.append_basic_block(block.name) + self.llmap[block] = llblock + + # Third, translate all instructions. + for block in func.basic_blocks: + self.llbuilder.position_at_end(self.llmap[block]) + for insn in block.instructions: + if insn.loc is not None: + self.llbuilder.debug_metadata = \ + self.debug_info_emitter.emit_loc(insn.loc, disubprogram) + + llinsn = getattr(self, "process_" + type(insn).__name__)(insn) + assert llinsn is not None + self.llmap[insn] = llinsn + + # There is no 1:1 correspondence between ARTIQ and LLVM + # basic blocks, because sometimes we expand a single ARTIQ + # instruction so that the result spans several LLVM basic + # blocks. This only really matters for phis, which will + # use a different map. + llblock_map[block] = self.llbuilder.basic_block + + # Fourth, add incoming values to phis. 
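+ # Incoming values are filled in only after every block has been
+ # translated, since a phi may reference a value defined in a block that
+ # had not been visited when the phi itself was emitted (process_Phi
+ # creates the LLVM phi empty).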
+ for phi, llphi in self.phis: + for value, block in phi.incoming(): + llphi.add_incoming(self.map(value), llblock_map[block]) + finally: + self.llfunction = None + self.llmap = {} + self.phis = [] + + def process_Phi(self, insn): + llinsn = self.llbuilder.phi(self.llty_of_type(insn.type), name=insn.name) + self.phis.append((insn, llinsn)) + return llinsn + + def llindex(self, index): + return ll.Constant(lli32, index) + + def process_Alloc(self, insn): + if ir.is_environment(insn.type): + return self.llbuilder.alloca(self.llty_of_type(insn.type, bare=True), + name=insn.name) + elif ir.is_option(insn.type): + if len(insn.operands) == 0: # empty + llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) + return self.llbuilder.insert_value(llvalue, ll.Constant(lli1, False), 0, + name=insn.name) + elif len(insn.operands) == 1: # full + llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) + llvalue = self.llbuilder.insert_value(llvalue, ll.Constant(lli1, True), 0) + return self.llbuilder.insert_value(llvalue, self.map(insn.operands[0]), 1, + name=insn.name) + else: + assert False + elif builtins.is_list(insn.type): + llsize = self.map(insn.operands[0]) + llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) + llvalue = self.llbuilder.insert_value(llvalue, llsize, 0) + llalloc = self.llbuilder.alloca(self.llty_of_type(builtins.get_iterable_elt(insn.type)), + size=llsize) + llvalue = self.llbuilder.insert_value(llvalue, llalloc, 1, name=insn.name) + return llvalue + elif not builtins.is_allocated(insn.type): + llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) + for index, elt in enumerate(insn.operands): + llvalue = self.llbuilder.insert_value(llvalue, self.map(elt), index) + llvalue.name = insn.name + return llvalue + else: # catchall for exceptions and custom (allocated) classes + llalloc = self.llbuilder.alloca(self.llty_of_type(insn.type, bare=True)) + for index, operand in enumerate(insn.operands): + lloperand = self.map(operand) + llfieldptr = self.llbuilder.gep(llalloc, [self.llindex(0), self.llindex(index)]) + self.llbuilder.store(lloperand, llfieldptr) + return llalloc + + def llptr_to_var(self, llenv, env_ty, var_name, var_type=None): + if var_name in env_ty.params and (var_type is None or + env_ty.params[var_name] == var_type): + var_index = list(env_ty.params.keys()).index(var_name) + return self.llbuilder.gep(llenv, [self.llindex(0), self.llindex(var_index)]) + else: + outer_index = list(env_ty.params.keys()).index("$outer") + llptr = self.llbuilder.gep(llenv, [self.llindex(0), self.llindex(outer_index)]) + llouterenv = self.llbuilder.load(llptr) + return self.llptr_to_var(llouterenv, env_ty.params["$outer"], var_name) + + def process_GetLocal(self, insn): + env = insn.environment() + llptr = self.llptr_to_var(self.map(env), env.type, insn.var_name) + return self.llbuilder.load(llptr) + + def process_GetConstructor(self, insn): + env = insn.environment() + llptr = self.llptr_to_var(self.map(env), env.type, insn.var_name, insn.type) + return self.llbuilder.load(llptr) + + def process_SetLocal(self, insn): + env = insn.environment() + llptr = self.llptr_to_var(self.map(env), env.type, insn.var_name) + llvalue = self.map(insn.value()) + if isinstance(llvalue, ll.Block): + llvalue = ll.BlockAddress(self.llfunction, llvalue) + if llptr.type.pointee != llvalue.type: + # The environment argument is an i8*, so that all closures can + # unify with each other regardless of environment type or size. 
+ # We fixup the type on assignment into the "$outer" slot. + assert insn.var_name == '$outer' + llvalue = self.llbuilder.bitcast(llvalue, llptr.type.pointee) + return self.llbuilder.store(llvalue, llptr) + + def attr_index(self, insn): + return list(insn.object().type.attributes.keys()).index(insn.attr) + + def process_GetAttr(self, insn): + if types.is_tuple(insn.object().type): + return self.llbuilder.extract_value(self.map(insn.object()), insn.attr, + name=insn.name) + elif not builtins.is_allocated(insn.object().type): + return self.llbuilder.extract_value(self.map(insn.object()), self.attr_index(insn), + name=insn.name) + else: + llptr = self.llbuilder.gep(self.map(insn.object()), + [self.llindex(0), self.llindex(self.attr_index(insn))], + name=insn.name) + return self.llbuilder.load(llptr) + + def process_SetAttr(self, insn): + assert builtins.is_allocated(insn.object().type) + llptr = self.llbuilder.gep(self.map(insn.object()), + [self.llindex(0), self.llindex(self.attr_index(insn))], + name=insn.name) + return self.llbuilder.store(self.map(insn.value()), llptr) + + def process_GetElem(self, insn): + llelts = self.llbuilder.extract_value(self.map(insn.list()), 1) + llelt = self.llbuilder.gep(llelts, [self.map(insn.index())], + inbounds=True) + return self.llbuilder.load(llelt) + + def process_SetElem(self, insn): + llelts = self.llbuilder.extract_value(self.map(insn.list()), 1) + llelt = self.llbuilder.gep(llelts, [self.map(insn.index())], + inbounds=True) + return self.llbuilder.store(self.map(insn.value()), llelt) + + def process_Coerce(self, insn): + typ, value_typ = insn.type, insn.value().type + if builtins.is_int(typ) and builtins.is_float(value_typ): + return self.llbuilder.fptosi(self.map(insn.value()), self.llty_of_type(typ), + name=insn.name) + elif builtins.is_float(typ) and builtins.is_int(value_typ): + return self.llbuilder.sitofp(self.map(insn.value()), self.llty_of_type(typ), + name=insn.name) + elif builtins.is_int(typ) and builtins.is_int(value_typ): + if builtins.get_int_width(typ) > builtins.get_int_width(value_typ): + return self.llbuilder.sext(self.map(insn.value()), self.llty_of_type(typ), + name=insn.name) + else: # builtins.get_int_width(typ) <= builtins.get_int_width(value_typ): + return self.llbuilder.trunc(self.map(insn.value()), self.llty_of_type(typ), + name=insn.name) + else: + assert False + + def process_Arith(self, insn): + if isinstance(insn.op, ast.Add): + if builtins.is_float(insn.type): + return self.llbuilder.fadd(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + else: + return self.llbuilder.add(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + elif isinstance(insn.op, ast.Sub): + if builtins.is_float(insn.type): + return self.llbuilder.fsub(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + else: + return self.llbuilder.sub(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + elif isinstance(insn.op, ast.Mult): + if builtins.is_float(insn.type): + return self.llbuilder.fmul(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + else: + return self.llbuilder.mul(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + elif isinstance(insn.op, ast.Div): + if builtins.is_float(insn.lhs().type): + return self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + else: + lllhs = self.llbuilder.sitofp(self.map(insn.lhs()), self.llty_of_type(insn.type)) + llrhs = self.llbuilder.sitofp(self.map(insn.rhs()), self.llty_of_type(insn.type)) + return 
self.llbuilder.fdiv(lllhs, llrhs, + name=insn.name) + elif isinstance(insn.op, ast.FloorDiv): + if builtins.is_float(insn.type): + llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs())) + return self.llbuilder.call(self.llbuiltin("llvm.floor.f64"), [llvalue], + name=insn.name) + else: + return self.llbuilder.sdiv(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + elif isinstance(insn.op, ast.Mod): + # Python only has the modulo operator, LLVM only has the remainder + if builtins.is_float(insn.type): + llvalue = self.llbuilder.frem(self.map(insn.lhs()), self.map(insn.rhs())) + return self.llbuilder.call(self.llbuiltin("llvm.copysign.f64"), + [llvalue, self.map(insn.rhs())], + name=insn.name) + else: + lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs())) + llxorsign = self.llbuilder.and_(self.llbuilder.xor(lllhs, llrhs), + ll.Constant(lllhs.type, 1 << lllhs.type.width - 1)) + llnegate = self.llbuilder.icmp_unsigned('!=', + llxorsign, ll.Constant(llxorsign.type, 0)) + llvalue = self.llbuilder.srem(lllhs, llrhs) + llnegvalue = self.llbuilder.sub(ll.Constant(llvalue.type, 0), llvalue) + return self.llbuilder.select(llnegate, llnegvalue, llvalue) + elif isinstance(insn.op, ast.Pow): + if builtins.is_float(insn.type): + return self.llbuilder.call(self.llbuiltin("llvm.pow.f64"), + [self.map(insn.lhs()), self.map(insn.rhs())], + name=insn.name) + else: + lllhs = self.llbuilder.sitofp(self.map(insn.lhs()), lldouble) + llrhs = self.llbuilder.trunc(self.map(insn.rhs()), lli32) + llvalue = self.llbuilder.call(self.llbuiltin("llvm.powi.f64"), [lllhs, llrhs]) + return self.llbuilder.fptosi(llvalue, self.llty_of_type(insn.type), + name=insn.name) + elif isinstance(insn.op, ast.LShift): + lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs())) + llrhs_max = ll.Constant(llrhs.type, builtins.get_int_width(insn.lhs().type)) + llrhs_overflow = self.llbuilder.icmp_signed('>=', llrhs, llrhs_max) + llvalue_zero = ll.Constant(lllhs.type, 0) + llvalue = self.llbuilder.shl(lllhs, llrhs) + return self.llbuilder.select(llrhs_overflow, llvalue_zero, llvalue, + name=insn.name) + elif isinstance(insn.op, ast.RShift): + lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs())) + llrhs_max = ll.Constant(llrhs.type, builtins.get_int_width(insn.lhs().type) - 1) + llrhs_overflow = self.llbuilder.icmp_signed('>', llrhs, llrhs_max) + llvalue = self.llbuilder.ashr(lllhs, llrhs) + llvalue_max = self.llbuilder.ashr(lllhs, llrhs_max) # preserve sign bit + return self.llbuilder.select(llrhs_overflow, llvalue_max, llvalue, + name=insn.name) + elif isinstance(insn.op, ast.BitAnd): + return self.llbuilder.and_(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + elif isinstance(insn.op, ast.BitOr): + return self.llbuilder.or_(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + elif isinstance(insn.op, ast.BitXor): + return self.llbuilder.xor(self.map(insn.lhs()), self.map(insn.rhs()), + name=insn.name) + else: + assert False + + def process_Compare(self, insn): + if isinstance(insn.op, (ast.Eq, ast.Is)): + op = '==' + elif isinstance(insn.op, (ast.NotEq, ast.IsNot)): + op = '!=' + elif isinstance(insn.op, ast.Gt): + op = '>' + elif isinstance(insn.op, ast.GtE): + op = '>=' + elif isinstance(insn.op, ast.Lt): + op = '<' + elif isinstance(insn.op, ast.LtE): + op = '<=' + else: + assert False + + lllhs, llrhs = map(self.map, (insn.lhs(), insn.rhs())) + assert lllhs.type == llrhs.type + + if isinstance(lllhs.type, ll.IntType): + return self.llbuilder.icmp_signed(op, lllhs, llrhs, + 
name=insn.name) + elif isinstance(lllhs.type, ll.PointerType): + return self.llbuilder.icmp_unsigned(op, lllhs, llrhs, + name=insn.name) + elif isinstance(lllhs.type, ll.DoubleType): + return self.llbuilder.fcmp_ordered(op, lllhs, llrhs, + name=insn.name) + elif isinstance(lllhs.type, ll.LiteralStructType): + # Compare aggregates (such as lists or ranges) element-by-element. + llvalue = ll.Constant(lli1, True) + for index in range(len(lllhs.type.elements)): + lllhselt = self.llbuilder.extract_value(lllhs, index) + llrhselt = self.llbuilder.extract_value(llrhs, index) + llresult = self.llbuilder.icmp_unsigned('==', lllhselt, llrhselt) + llvalue = self.llbuilder.select(llresult, llvalue, + ll.Constant(lli1, False)) + return self.llbuilder.icmp_unsigned(op, llvalue, ll.Constant(lli1, True), + name=insn.name) + else: + print(lllhs, llrhs) + assert False + + def process_Builtin(self, insn): + if insn.op == "nop": + return self.llbuilder.call(self.llbuiltin("llvm.donothing"), []) + if insn.op == "abort": + return self.llbuilder.call(self.llbuiltin("llvm.trap"), []) + elif insn.op == "is_some": + lloptarg = self.map(insn.operands[0]) + return self.llbuilder.extract_value(lloptarg, 0, + name=insn.name) + elif insn.op == "unwrap": + lloptarg = self.map(insn.operands[0]) + return self.llbuilder.extract_value(lloptarg, 1, + name=insn.name) + elif insn.op == "unwrap_or": + lloptarg, lldefault = map(self.map, insn.operands) + llhas_arg = self.llbuilder.extract_value(lloptarg, 0) + llarg = self.llbuilder.extract_value(lloptarg, 1) + return self.llbuilder.select(llhas_arg, llarg, lldefault, + name=insn.name) + elif insn.op == "round": + llarg = self.map(insn.operands[0]) + llvalue = self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llarg]) + return self.llbuilder.fptosi(llvalue, self.llty_of_type(insn.type), + name=insn.name) + elif insn.op == "globalenv": + def get_outer(llenv, env_ty): + if "$outer" in env_ty.params: + outer_index = list(env_ty.params.keys()).index("$outer") + llptr = self.llbuilder.gep(llenv, [self.llindex(0), self.llindex(outer_index)]) + llouterenv = self.llbuilder.load(llptr) + return self.llptr_to_var(llouterenv, env_ty.params["$outer"], var_name) + else: + return llenv + + env, = insn.operands + return get_outer(self.map(env), env.type) + elif insn.op == "len": + lst, = insn.operands + return self.llbuilder.extract_value(self.map(lst), 0) + elif insn.op == "printf": + # We only get integers, floats, pointers and strings here. + llargs = map(self.map, insn.operands) + return self.llbuilder.call(self.llbuiltin(self.target.print_function), llargs, + name=insn.name) + elif insn.op == "exncast": + # This is an identity cast at LLVM IR level. 
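+ # All exception instances share the "class.Exception" layout (see
+ # llty_of_type above), so converting to a more specific exception type
+ # does not change the pointer.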
+ return self.map(insn.operands[0]) + elif insn.op == "now_mu": + return self.llbuilder.load(self.llbuiltin("now"), name=insn.name) + elif insn.op == "at_mu": + time, = insn.operands + return self.llbuilder.store(self.map(time), self.llbuiltin("now")) + elif insn.op == "delay_mu": + interval, = insn.operands + llnowptr = self.llbuiltin("now") + llnow = self.llbuilder.load(llnowptr) + lladjusted = self.llbuilder.add(llnow, self.map(interval)) + return self.llbuilder.store(lladjusted, llnowptr) + else: + assert False + + def process_Closure(self, insn): + llvalue = ll.Constant(self.llty_of_type(insn.target_function.type), ll.Undefined) + llenv = self.llbuilder.bitcast(self.map(insn.environment()), llptr) + llvalue = self.llbuilder.insert_value(llvalue, llenv, 0) + llvalue = self.llbuilder.insert_value(llvalue, self.map(insn.target_function), 1, + name=insn.name) + return llvalue + + def _prepare_closure_call(self, insn): + llclosure = self.map(insn.target_function()) + llargs = [self.map(arg) for arg in insn.arguments()] + llenv = self.llbuilder.extract_value(llclosure, 0) + llfun = self.llbuilder.extract_value(llclosure, 1) + return llfun, [llenv] + list(llargs) + + def _prepare_ffi_call(self, insn): + llargs = [self.map(arg) for arg in insn.arguments()] + llfunname = insn.target_function().type.name + llfun = self.llmodule.get_global(llfunname) + if llfun is None: + llfunty = ll.FunctionType(self.llty_of_type(insn.type, for_return=True), + [llarg.type for llarg in llargs]) + llfun = ll.Function(self.llmodule, llfunty, + insn.target_function().type.name) + return llfun, list(llargs) + + # See session.c:{send,receive}_rpc_value and comm_generic.py:_{send,receive}_rpc_value. + def _rpc_tag(self, typ, error_handler): + if types.is_tuple(typ): + assert len(typ.elts) < 256 + return b"t" + bytes([len(typ.elts)]) + \ + b"".join([self._rpc_tag(elt_type, error_handler) + for elt_type in typ.elts]) + elif builtins.is_none(typ): + return b"n" + elif builtins.is_bool(typ): + return b"b" + elif builtins.is_int(typ, types.TValue(32)): + return b"i" + elif builtins.is_int(typ, types.TValue(64)): + return b"I" + elif builtins.is_float(typ): + return b"f" + elif builtins.is_str(typ): + return b"s" + elif builtins.is_list(typ): + return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ), + error_handler) + elif builtins.is_range(typ): + return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ), + error_handler) + elif ir.is_option(typ): + return b"o" + self._rpc_tag(typ.params["inner"], + error_handler) + elif '__objectid__' in typ.attributes: + return b"O" + else: + error_handler(typ) + + def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock): + llservice = ll.Constant(lli32, fun_type.service) + + tag = b"" + + for arg in args: + def arg_error_handler(typ): + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "value of type {type}", + {"type": printer.name(typ)}, + arg.loc) + diag = diagnostic.Diagnostic("error", + "type {type} is not supported in remote procedure calls", + {"type": printer.name(arg.type)}, + arg.loc) + self.engine.process(diag) + tag += self._rpc_tag(arg.type, arg_error_handler) + tag += b":" + + def ret_error_handler(typ): + printer = types.TypePrinter() + note = diagnostic.Diagnostic("note", + "value of type {type}", + {"type": printer.name(typ)}, + fun_loc) + diag = diagnostic.Diagnostic("error", + "return type {type} is not supported in remote procedure calls", + {"type": printer.name(fun_type.ret)}, + fun_loc) + self.engine.process(diag) + 
tag += self._rpc_tag(fun_type.ret, ret_error_handler) + tag += b"\x00" + + lltag = self.llstr_of_str(tag) + + llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), []) + + llargs = [] + for arg in args: + llarg = self.map(arg) + llargslot = self.llbuilder.alloca(llarg.type) + self.llbuilder.store(llarg, llargslot) + llargs.append(llargslot) + + self.llbuilder.call(self.llbuiltin("send_rpc"), + [llservice, lltag] + llargs) + + # Don't waste stack space on saved arguments. + self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) + + # T result = { + # void *ptr = NULL; + # loop: int size = rpc_recv("tag", ptr); + # if(size) { ptr = alloca(size); goto loop; } + # else *(T*)ptr + # } + llprehead = self.llbuilder.basic_block + llhead = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head") + if llunwindblock: + llheadu = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head.unwind") + llalloc = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.alloc") + lltail = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.tail") + + llretty = self.llty_of_type(fun_type.ret) + llslot = self.llbuilder.alloca(llretty) + llslotgen = self.llbuilder.bitcast(llslot, llptr) + self.llbuilder.branch(llhead) + + self.llbuilder.position_at_end(llhead) + llphi = self.llbuilder.phi(llslotgen.type) + llphi.add_incoming(llslotgen, llprehead) + if llunwindblock: + llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llphi], + llheadu, llunwindblock) + self.llbuilder.position_at_end(llheadu) + else: + llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llphi]) + lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0)) + self.llbuilder.cbranch(lldone, lltail, llalloc) + + self.llbuilder.position_at_end(llalloc) + llalloca = self.llbuilder.alloca(lli8, llsize) + llphi.add_incoming(llalloca, llalloc) + self.llbuilder.branch(llhead) + + self.llbuilder.position_at_end(lltail) + llret = self.llbuilder.load(llslot) + if not builtins.is_allocated(fun_type.ret): + # We didn't allocate anything except the slot for the value itself. + # Don't waste stack space. 
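+ # (The result has already been loaded into a register above; if the
+ # return type were allocated, it would still point into the receive
+ # buffer and the stack could not be released here.)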
+ self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) + if llnormalblock: + self.llbuilder.branch(llnormalblock) + return llret + + def process_Call(self, insn): + if types.is_rpc_function(insn.target_function().type): + return self._build_rpc(insn.target_function().loc, + insn.target_function().type, + insn.arguments(), + llnormalblock=None, llunwindblock=None) + elif types.is_c_function(insn.target_function().type): + llfun, llargs = self._prepare_ffi_call(insn) + return self.llbuilder.call(llfun, llargs, + name=insn.name) + else: + llfun, llargs = self._prepare_closure_call(insn) + + if llfun.type.pointee.__has_sret: + llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), []) + + llresultslot = self.llbuilder.alloca(llfun.type.pointee.args[0].pointee) + self.llbuilder.call(llfun, [llresultslot] + llargs) + llresult = self.llbuilder.load(llresultslot) + + self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr]) + + return llresult + else: + return self.llbuilder.call(llfun, llargs, + name=insn.name) + + def process_Invoke(self, insn): + llnormalblock = self.map(insn.normal_target()) + llunwindblock = self.map(insn.exception_target()) + if types.is_rpc_function(insn.target_function().type): + return self._build_rpc(insn.target_function().loc, + insn.target_function().type, + insn.arguments(), + llnormalblock, llunwindblock) + elif types.is_c_function(insn.target_function().type): + llfun, llargs = self._prepare_ffi_call(insn) + return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock, + name=insn.name) + else: + llfun, llargs = self._prepare_closure_call(insn) + return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock, + name=insn.name) + + def _quote(self, value, typ, path): + value_id = id(value) + if value_id in self.llobject_map: + return self.llobject_map[value_id] + + global_name = "" + llty = self.llty_of_type(typ) + + if types.is_constructor(typ) or types.is_instance(typ): + llfields = [] + for attr in typ.attributes: + if attr == "__objectid__": + objectid = self.object_map.store(value) + llfields.append(ll.Constant(lli32, objectid)) + global_name = "object.{}".format(objectid) + else: + llfields.append(self._quote(getattr(value, attr), typ.attributes[attr], + lambda: path() + [attr])) + llconst = ll.Constant(llty.pointee, llfields) + + llglobal = ll.GlobalVariable(self.llmodule, llconst.type, global_name) + llglobal.initializer = llconst + llglobal.linkage = "private" + self.llobject_map[value_id] = llglobal + return llglobal + elif builtins.is_none(typ): + assert value is None + return ll.Constant.literal_struct([]) + elif builtins.is_bool(typ): + assert value in (True, False) + return ll.Constant(llty, value) + elif builtins.is_int(typ): + assert isinstance(value, (int, language_core.int)) + return ll.Constant(llty, int(value)) + elif builtins.is_float(typ): + assert isinstance(value, float) + return ll.Constant(llty, value) + elif builtins.is_str(typ): + assert isinstance(value, (str, bytes)) + return self.llstr_of_str(value) + elif builtins.is_list(typ): + assert isinstance(value, list) + elt_type = builtins.get_iterable_elt(typ) + llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)]) + for i in range(len(value))] + lleltsary = ll.Constant(ll.ArrayType(llelts[0].type, len(llelts)), llelts) + + llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type, "quoted.list") + llglobal.initializer = lleltsary + llglobal.linkage = "private" + + lleltsptr = 
llglobal.bitcast(lleltsary.type.element.as_pointer()) + llconst = ll.Constant(llty, [ll.Constant(lli32, len(llelts)), lleltsptr]) + return llconst + elif types.is_function(typ): + # RPC and C functions have no runtime representation; ARTIQ + # functions are initialized explicitly. + return ll.Constant(llty, ll.Undefined) + else: + print(typ) + assert False + + def process_Quote(self, insn): + assert self.object_map is not None + return self._quote(insn.value, insn.type, lambda: [repr(insn.value)]) + + def process_Select(self, insn): + return self.llbuilder.select(self.map(insn.condition()), + self.map(insn.if_true()), self.map(insn.if_false())) + + def process_Branch(self, insn): + return self.llbuilder.branch(self.map(insn.target())) + + process_Delay = process_Branch + + def process_BranchIf(self, insn): + return self.llbuilder.cbranch(self.map(insn.condition()), + self.map(insn.if_true()), self.map(insn.if_false())) + + def process_IndirectBranch(self, insn): + llinsn = self.llbuilder.branch_indirect(self.map(insn.target())) + for dest in insn.destinations(): + llinsn.add_destination(self.map(dest)) + return llinsn + + def process_Return(self, insn): + if builtins.is_none(insn.value().type): + return self.llbuilder.ret_void() + else: + if self.llfunction.type.pointee.__has_sret: + self.llbuilder.store(self.map(insn.value()), self.llfunction.args[0]) + return self.llbuilder.ret_void() + else: + return self.llbuilder.ret(self.map(insn.value())) + + def process_Unreachable(self, insn): + return self.llbuilder.unreachable() + + def process_Raise(self, insn): + llexn = self.map(insn.value()) + if insn.exception_target() is not None: + llnormalblock = self.llfunction.append_basic_block("unreachable") + llnormalblock.terminator = ll.Unreachable(llnormalblock) + llnormalblock.instructions.append(llnormalblock.terminator) + + llunwindblock = self.map(insn.exception_target()) + llinsn = self.llbuilder.invoke(self.llbuiltin("__artiq_raise"), [llexn], + llnormalblock, llunwindblock, + name=insn.name) + else: + llinsn = self.llbuilder.call(self.llbuiltin("__artiq_raise"), [llexn], + name=insn.name) + self.llbuilder.unreachable() + llinsn.attributes.add('noreturn') + return llinsn + + def process_Reraise(self, insn): + llinsn = self.llbuilder.call(self.llbuiltin("__artiq_reraise"), [], + name=insn.name) + llinsn.attributes.add('noreturn') + self.llbuilder.unreachable() + return llinsn + + def process_LandingPad(self, insn): + # Layout on return from landing pad: {%_Unwind_Exception*, %Exception*} + lllandingpadty = ll.LiteralStructType([llptr, llptr]) + lllandingpad = self.llbuilder.landingpad(lllandingpadty, + self.llbuiltin("__artiq_personality"), + cleanup=True) + llrawexn = self.llbuilder.extract_value(lllandingpad, 1) + llexn = self.llbuilder.bitcast(llrawexn, self.llty_of_type(insn.type)) + llexnnameptr = self.llbuilder.gep(llexn, [self.llindex(0), self.llindex(0)]) + llexnname = self.llbuilder.load(llexnnameptr) + + for target, typ in insn.clauses(): + if typ is None: + llclauseexnname = ll.Constant( + self.llty_of_type(ir.TExceptionTypeInfo()), None) + else: + llclauseexnname = self.llconst_of_const( + ir.Constant(typ.name, ir.TExceptionTypeInfo())) + lllandingpad.add_clause(ll.CatchClause(llclauseexnname)) + + if typ is None: + self.llbuilder.branch(self.map(target)) + else: + llmatchingclause = self.llbuilder.icmp_unsigned('==', llexnname, llclauseexnname) + with self.llbuilder.if_then(llmatchingclause): + self.llbuilder.branch(self.map(target)) + + if self.llbuilder.basic_block.terminator 
is None: + self.llbuilder.branch(self.map(insn.cleanup())) + + return llexn diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py new file mode 100644 index 000000000..9cac1594a --- /dev/null +++ b/artiq/compiler/types.py @@ -0,0 +1,657 @@ +""" +The :mod:`types` module contains the classes describing the types +in :mod:`asttyped`. +""" + +import string +from collections import OrderedDict +from pythonparser import diagnostic +from . import iodelay + + +class UnificationError(Exception): + def __init__(self, typea, typeb): + self.typea, self.typeb = typea, typeb + + +def genalnum(): + ident = ["a"] + while True: + yield "".join(ident) + pos = len(ident) - 1 + while pos >= 0: + cur_n = string.ascii_lowercase.index(ident[pos]) + if cur_n < 25: + ident[pos] = string.ascii_lowercase[cur_n + 1] + break + else: + ident[pos] = "a" + pos -= 1 + if pos < 0: + ident = ["a"] + ident + +def _freeze(dict_): + return tuple((key, dict_[key]) for key in dict_) + +def _map_find(elts): + if isinstance(elts, list): + return [x.find() for x in elts] + elif isinstance(elts, dict): + return {k: elts[k].find() for k in elts} + else: + assert False + + +class Type(object): + def __str__(self): + return TypePrinter().name(self) + +class TVar(Type): + """ + A type variable. + + In effect, the classic union-find data structure is intrusively + folded into this class. + """ + + def __init__(self): + self.parent = self + + def find(self): + if self.parent is self: + return self + else: + root = self.parent.find() + self.parent = root # path compression + return root + + def unify(self, other): + other = other.find() + + if self.parent is self: + self.parent = other + else: + self.find().unify(other) + + def fold(self, accum, fn): + if self.parent is self: + return fn(accum, self) + else: + return self.find().fold(accum, fn) + + def __repr__(self): + if self.parent is self: + return "" % id(self) + else: + return repr(self.find()) + + # __eq__ and __hash__ are not overridden and default to + # comparison by identity. Use .find() explicitly before + # any lookups or comparisons. + +class TMono(Type): + """ + A monomorphic type, possibly parametric. + + :class:`TMono` is supposed to be subclassed by builtin types, + unlike all other :class:`Type` descendants. Similarly, + instances of :class:`TMono` should never be allocated directly, + as that will break the type-sniffing code in :mod:`builtins`. 
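+ For example, a 32-bit integer is represented as
+ ``TMono("int", {"width": TValue(32)})``.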
+ """ + + attributes = OrderedDict() + + def __init__(self, name, params={}): + assert isinstance(params, (dict, OrderedDict)) + self.name, self.params = name, OrderedDict(sorted(params.items())) + + def find(self): + return self + + def unify(self, other): + if isinstance(other, TMono) and self.name == other.name: + assert self.params.keys() == other.params.keys() + for param in self.params: + self.params[param].unify(other.params[param]) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + + def fold(self, accum, fn): + for param in self.params: + accum = self.params[param].fold(accum, fn) + return fn(accum, self) + + def __repr__(self): + return "artiq.compiler.types.TMono(%s, %s)" % (repr(self.name), repr(self.params)) + + def __getitem__(self, param): + return self.params[param] + + def __eq__(self, other): + return isinstance(other, TMono) and \ + self.name == other.name and \ + _map_find(self.params) == _map_find(other.params) + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash((self.name, _freeze(self.params))) + +class TTuple(Type): + """ + A tuple type. + + :ivar elts: (list of :class:`Type`) elements + """ + + attributes = OrderedDict() + + def __init__(self, elts=[]): + self.elts = elts + + def find(self): + return self + + def unify(self, other): + if isinstance(other, TTuple) and len(self.elts) == len(other.elts): + for selfelt, otherelt in zip(self.elts, other.elts): + selfelt.unify(otherelt) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + + def fold(self, accum, fn): + for elt in self.elts: + accum = elt.fold(accum, fn) + return fn(accum, self) + + def __repr__(self): + return "artiq.compiler.types.TTuple(%s)" % repr(self.elts) + + def __eq__(self, other): + return isinstance(other, TTuple) and \ + _map_find(self.elts) == _map_find(other.elts) + + def __ne__(self, other): + return not (self == other) + +class _TPointer(TMono): + def __init__(self): + super().__init__("pointer") + +class TFunction(Type): + """ + A function type. 
+ + :ivar args: (:class:`collections.OrderedDict` of string to :class:`Type`) + mandatory arguments + :ivar optargs: (:class:`collections.OrderedDict` of string to :class:`Type`) + optional arguments + :ivar ret: (:class:`Type`) + return type + :ivar delay: (:class:`Type`) + RTIO delay + """ + + attributes = OrderedDict([ + ('__code__', _TPointer()), + ('__closure__', _TPointer()), + ]) + + def __init__(self, args, optargs, ret): + assert isinstance(args, OrderedDict) + assert isinstance(optargs, OrderedDict) + assert isinstance(ret, Type) + self.args, self.optargs, self.ret = args, optargs, ret + self.delay = TVar() + + def arity(self): + return len(self.args) + len(self.optargs) + + def find(self): + return self + + def unify(self, other): + if isinstance(other, TFunction) and \ + self.args.keys() == other.args.keys() and \ + self.optargs.keys() == other.optargs.keys(): + for selfarg, otherarg in zip(list(self.args.values()) + list(self.optargs.values()), + list(other.args.values()) + list(other.optargs.values())): + selfarg.unify(otherarg) + self.ret.unify(other.ret) + self.delay.unify(other.delay) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + + def fold(self, accum, fn): + for arg in self.args: + accum = self.args[arg].fold(accum, fn) + for optarg in self.optargs: + accum = self.optargs[optarg].fold(accum, fn) + accum = self.ret.fold(accum, fn) + return fn(accum, self) + + def __repr__(self): + return "artiq.compiler.types.TFunction({}, {}, {})".format( + repr(self.args), repr(self.optargs), repr(self.ret)) + + def __eq__(self, other): + return isinstance(other, TFunction) and \ + _map_find(self.args) == _map_find(other.args) and \ + _map_find(self.optargs) == _map_find(other.optargs) + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash((_freeze(self.args), _freeze(self.optargs), self.ret)) + +class TRPCFunction(TFunction): + """ + A function type of a remote function. + + :ivar service: (int) RPC service number + """ + + attributes = OrderedDict() + + def __init__(self, args, optargs, ret, service): + super().__init__(args, optargs, ret) + self.service = service + self.delay = TFixedDelay(iodelay.Const(0)) + + def unify(self, other): + if isinstance(other, TRPCFunction) and \ + self.service == other.service: + super().unify(other) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + +class TCFunction(TFunction): + """ + A function type of a runtime-provided C function. + + :ivar name: (str) C function name + """ + + attributes = OrderedDict() + + def __init__(self, args, ret, name): + super().__init__(args, OrderedDict(), ret) + self.name = name + self.delay = TFixedDelay(iodelay.Const(0)) + + def unify(self, other): + if isinstance(other, TCFunction) and \ + self.name == other.name: + super().unify(other) + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + +class TBuiltin(Type): + """ + An instance of builtin type. Every instance of a builtin + type is treated specially according to its name. 
+ """ + + def __init__(self, name): + assert isinstance(name, str) + self.name = name + self.attributes = OrderedDict() + + def find(self): + return self + + def unify(self, other): + if self != other: + raise UnificationError(self, other) + + def fold(self, accum, fn): + return fn(accum, self) + + def __repr__(self): + return "artiq.compiler.types.{}({})".format(type(self).__name__, repr(self.name)) + + def __eq__(self, other): + return isinstance(other, TBuiltin) and \ + self.name == other.name + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash(self.name) + +class TBuiltinFunction(TBuiltin): + """ + A type of a builtin function. + """ + +class TConstructor(TBuiltin): + """ + A type of a constructor of a class, e.g. ``list``. + Note that this is not the same as the type of an instance of + the class, which is ``TMono("list", ...)`` (or a descendant). + + :ivar instance: (:class:`Type`) + the type of the instance created by this constructor + """ + + def __init__(self, instance): + assert isinstance(instance, TMono) + super().__init__(instance.name) + self.instance = instance + +class TExceptionConstructor(TConstructor): + """ + A type of a constructor of an exception, e.g. ``Exception``. + Note that this is not the same as the type of an instance of + the class, which is ``TMono("Exception", ...)``. + """ + +class TInstance(TMono): + """ + A type of an instance of a user-defined class. + + :ivar constructor: (:class:`TConstructor`) + the type of the constructor with which this instance + was created + """ + + def __init__(self, name, attributes): + assert isinstance(attributes, OrderedDict) + super().__init__(name) + self.attributes = attributes + + def __repr__(self): + return "artiq.compiler.types.TInstance({}, {})".format( + repr(self.name), repr(self.attributes)) + +class TMethod(TMono): + """ + A type of a method. + """ + + def __init__(self, self_type, function_type): + super().__init__("method", {"self": self_type, "fn": function_type}) + self.attributes = OrderedDict([ + ("__func__", function_type), + ("__self__", self_type), + ]) + +class TValue(Type): + """ + A type-level value (such as the integer denoting width of + a generic integer type. + """ + + def __init__(self, value): + self.value = value + + def find(self): + return self + + def unify(self, other): + if isinstance(other, TVar): + other.unify(self) + elif self != other: + raise UnificationError(self, other) + + def fold(self, accum, fn): + return fn(accum, self) + + def __repr__(self): + return "artiq.compiler.types.TValue(%s)" % repr(self.value) + + def __eq__(self, other): + return isinstance(other, TValue) and \ + self.value == other.value + + def __ne__(self, other): + return not (self == other) + +class TDelay(Type): + """ + The type-level representation of IO delay. 
+ """ + + def __init__(self, duration, cause): + assert duration is None or isinstance(duration, iodelay.Expr) + assert cause is None or isinstance(cause, diagnostic.Diagnostic) + assert (not (duration and cause)) and (duration or cause) + self.duration, self.cause = duration, cause + + def is_fixed(self): + return self.duration is not None + + def is_indeterminate(self): + return self.cause is not None + + def find(self): + return self + + def unify(self, other): + other = other.find() + + if self.is_fixed() and other.is_fixed() and \ + self.duration.fold() == other.duration.fold(): + pass + elif isinstance(other, TVar): + other.unify(self) + else: + raise UnificationError(self, other) + + def fold(self, accum, fn): + # delay types do not participate in folding + pass + + def __eq__(self, other): + return isinstance(other, TDelay) and \ + (self.duration == other.duration and \ + self.cause == other.cause) + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + if self.duration is None: + return "<{}.TIndeterminateDelay>".format(__name__) + elif self.cause is None: + return "{}.TFixedDelay({})".format(__name__, self.duration) + else: + assert False + +def TIndeterminateDelay(cause): + return TDelay(None, cause) + +def TFixedDelay(duration): + return TDelay(duration, None) + + +def is_var(typ): + return isinstance(typ.find(), TVar) + +def is_mono(typ, name=None, **params): + typ = typ.find() + params_match = True + for param in params: + if param not in typ.params: + return False + params_match = params_match and \ + typ.params[param].find() == params[param].find() + return isinstance(typ, TMono) and \ + (name is None or (typ.name == name and params_match)) + +def is_polymorphic(typ): + return typ.fold(False, lambda accum, typ: accum or is_var(typ)) + +def is_tuple(typ, elts=None): + typ = typ.find() + if elts: + return isinstance(typ, TTuple) and \ + elts == typ.elts + else: + return isinstance(typ, TTuple) + +def _is_pointer(typ): + return isinstance(typ.find(), _TPointer) + +def is_function(typ): + return isinstance(typ.find(), TFunction) + +def is_rpc_function(typ): + return isinstance(typ.find(), TRPCFunction) + +def is_c_function(typ): + return isinstance(typ.find(), TCFunction) + +def is_builtin(typ, name=None): + typ = typ.find() + if name is None: + return isinstance(typ, TBuiltin) + else: + return isinstance(typ, TBuiltin) and \ + typ.name == name + +def is_constructor(typ, name=None): + typ = typ.find() + if name is not None: + return isinstance(typ, TConstructor) and \ + typ.name == name + else: + return isinstance(typ, TConstructor) + +def is_exn_constructor(typ, name=None): + typ = typ.find() + if name is not None: + return isinstance(typ, TExceptionConstructor) and \ + typ.name == name + else: + return isinstance(typ, TExceptionConstructor) + +def is_instance(typ, name=None): + typ = typ.find() + if name is not None: + return isinstance(typ, TInstance) and \ + typ.name == name + else: + return isinstance(typ, TInstance) + +def is_method(typ): + return isinstance(typ.find(), TMethod) + +def get_method_self(typ): + if is_method(typ): + return typ.find().params["self"] + +def get_method_function(typ): + if is_method(typ): + return typ.find().params["fn"] + +def is_value(typ): + return isinstance(typ.find(), TValue) + +def get_value(typ): + typ = typ.find() + if isinstance(typ, TVar): + return None + elif isinstance(typ, TValue): + return typ.value + else: + assert False + +def is_delay(typ): + return isinstance(typ.find(), TDelay) + +def 
is_fixed_delay(typ): + return is_delay(typ) and typ.find().is_fixed() + +def is_indeterminate_delay(typ): + return is_delay(typ) and typ.find().is_indeterminate() + + +class TypePrinter(object): + """ + A class that prints types using Python-like syntax and gives + type variables sequential alphabetic names. + """ + + def __init__(self): + self.gen = genalnum() + self.map = {} + self.recurse_guard = set() + + def name(self, typ): + typ = typ.find() + if isinstance(typ, TVar): + if typ not in self.map: + self.map[typ] = "'%s" % next(self.gen) + return self.map[typ] + elif isinstance(typ, TInstance): + if typ in self.recurse_guard: + return "<instance {}>".format(typ.name) + else: + self.recurse_guard.add(typ) + attrs = ", ".join(["{}: {}".format(attr, self.name(typ.attributes[attr])) + for attr in typ.attributes]) + return "<instance {} {{{}}}>".format(typ.name, attrs) + elif isinstance(typ, TMono): + if typ.params == {}: + return typ.name + else: + return "%s(%s)" % (typ.name, ", ".join( + ["%s=%s" % (k, self.name(typ.params[k])) for k in typ.params])) + elif isinstance(typ, TTuple): + if len(typ.elts) == 1: + return "(%s,)" % self.name(typ.elts[0]) + else: + return "(%s)" % ", ".join(list(map(self.name, typ.elts))) + elif isinstance(typ, (TFunction, TRPCFunction, TCFunction)): + args = [] + args += [ "%s:%s" % (arg, self.name(typ.args[arg])) for arg in typ.args] + args += ["?%s:%s" % (arg, self.name(typ.optargs[arg])) for arg in typ.optargs] + signature = "(%s)->%s" % (", ".join(args), self.name(typ.ret)) + + delay = typ.delay.find() + if isinstance(delay, TVar): + signature += " delay({})".format(self.name(delay)) + elif not (delay.is_fixed() and iodelay.is_zero(delay.duration)): + signature += " " + self.name(delay) + + if isinstance(typ, TRPCFunction): + return "rpc({}) {}".format(typ.service, signature) + if isinstance(typ, TCFunction): + return "ffi({}) {}".format(repr(typ.name), signature) + elif isinstance(typ, TFunction): + return signature + elif isinstance(typ, TBuiltinFunction): + return "<function {}>".format(typ.name) + elif isinstance(typ, (TConstructor, TExceptionConstructor)): + if typ in self.recurse_guard: + return "<constructor {}>".format(typ.name) + else: + self.recurse_guard.add(typ) + attrs = ", ".join(["{}: {}".format(attr, self.name(typ.attributes[attr])) + for attr in typ.attributes]) + return "<constructor {} {{{}}}>".format(typ.name, attrs) + elif isinstance(typ, TValue): + return repr(typ.value) + elif isinstance(typ, TDelay): + if typ.is_fixed(): + return "delay({} mu)".format(typ.duration) + elif typ.is_indeterminate(): + return "delay(?)" + else: + assert False + else: + assert False diff --git a/artiq/compiler/validators/__init__.py b/artiq/compiler/validators/__init__.py new file mode 100644 index 000000000..7f0719ea9 --- /dev/null +++ b/artiq/compiler/validators/__init__.py @@ -0,0 +1,3 @@ +from .monomorphism import MonomorphismValidator +from .escape import EscapeValidator +from .local_access import LocalAccessValidator diff --git a/artiq/compiler/validators/escape.py b/artiq/compiler/validators/escape.py new file mode 100644 index 000000000..386cd388b --- /dev/null +++ b/artiq/compiler/validators/escape.py @@ -0,0 +1,311 @@ +""" +:class:`EscapeValidator` verifies that no mutable data escapes +the region of its allocation. +""" + +import functools +from pythonparser import algorithm, diagnostic +from .. import asttyped, types, builtins + +def has_region(typ): + return typ.fold(False, lambda accum, typ: accum or builtins.is_allocated(typ)) + +class Region: + """ + A last-in-first-out allocation region.
Tied to lexical scoping + and is internally represented simply by a source range. + + :ivar range: (:class:`pythonparser.source.Range` or None) + """ + + def __init__(self, source_range=None): + self.range = source_range + + def present(self): + return bool(self.range) + + def includes(self, other): + assert self.range + assert self.range.source_buffer == other.range.source_buffer + + return self.range.begin_pos <= other.range.begin_pos and \ + self.range.end_pos >= other.range.end_pos + + def intersects(self, other): + assert self.range.source_buffer == other.range.source_buffer + assert self.range + + return (self.range.begin_pos <= other.range.begin_pos <= self.range.end_pos and \ + other.range.end_pos > self.range.end_pos) or \ + (other.range.begin_pos <= self.range.begin_pos <= other.range.end_pos and \ + self.range.end_pos > other.range.end_pos) + + def contract(self, other): + if not self.range: + self.range = other.range + + def outlives(lhs, rhs): + if lhs is None: # lhs lives forever + return True + elif rhs is None: # rhs lives forever, lhs does not + return False + else: + assert not lhs.intersects(rhs) + return lhs.includes(rhs) + + def __repr__(self): + return "Region({})".format(repr(self.range)) + +class RegionOf(algorithm.Visitor): + """ + Visit an expression and return the list of regions that must + be alive for the expression to execute. + """ + + def __init__(self, env_stack, youngest_region): + self.env_stack, self.youngest_region = env_stack, youngest_region + + # Liveness determined by assignments + def visit_NameT(self, node): + # First, look at stack regions + for region in reversed(self.env_stack[1:]): + if node.id in region: + return region[node.id] + + # Then, look at the global region of this module + if node.id in self.env_stack[0]: + return None + + assert False + + # Value lives as long as the current scope, if it's mutable, + # or else forever + def visit_sometimes_allocating(self, node): + if has_region(node.type): + return self.youngest_region + else: + return None + + visit_BinOpT = visit_sometimes_allocating + visit_CallT = visit_sometimes_allocating + + # Value lives as long as the object/container, if it's mutable, + # or else forever + def visit_accessor(self, node): + if has_region(node.type): + return self.visit(node.value) + else: + return None + + visit_AttributeT = visit_accessor + visit_SubscriptT = visit_accessor + + # Value lives as long as the shortest living operand + def visit_selecting(self, nodes): + regions = [self.visit(node) for node in nodes] + regions = list(filter(lambda x: x, regions)) + if any(regions): + regions.sort(key=functools.cmp_to_key(Region.outlives), reverse=True) + return regions[0] + else: + return None + + def visit_BoolOpT(self, node): + return self.visit_selecting(node.values) + + def visit_IfExpT(self, node): + return self.visit_selecting([node.body, node.orelse]) + + def visit_TupleT(self, node): + return self.visit_selecting(node.elts) + + # Value lives as long as the current scope + def visit_allocating(self, node): + return self.youngest_region + + visit_DictT = visit_allocating + visit_DictCompT = visit_allocating + visit_GeneratorExpT = visit_allocating + visit_LambdaT = visit_allocating + visit_ListT = visit_allocating + visit_ListCompT = visit_allocating + visit_SetT = visit_allocating + visit_SetCompT = visit_allocating + + # Value lives forever + def visit_immutable(self, node): + assert not has_region(node.type) + return None + + visit_NameConstantT = visit_immutable + visit_NumT = visit_immutable + 
visit_EllipsisT = visit_immutable + visit_UnaryOpT = visit_immutable + visit_CompareT = visit_immutable + + # Value is mutable, but still lives forever + def visit_StrT(self, node): + return None + + # Not implemented + def visit_unimplemented(self, node): + assert False + + visit_StarredT = visit_unimplemented + visit_YieldT = visit_unimplemented + visit_YieldFromT = visit_unimplemented + + +class AssignedNamesOf(algorithm.Visitor): + """ + Visit an expression and return the list of names that appear + on the lhs of assignment, directly or through an accessor. + """ + + def visit_NameT(self, node): + return [node] + + def visit_accessor(self, node): + return self.visit(node.value) + + visit_AttributeT = visit_accessor + visit_SubscriptT = visit_accessor + + def visit_sequence(self, node): + return functools.reduce(list.__add__, map(self.visit, node.elts)) + + visit_TupleT = visit_sequence + visit_ListT = visit_sequence + + def visit_StarredT(self, node): + assert False + + +class EscapeValidator(algorithm.Visitor): + def __init__(self, engine): + self.engine = engine + self.youngest_region = None + self.env_stack = [] + self.youngest_env = None + + def _region_of(self, expr): + return RegionOf(self.env_stack, self.youngest_region).visit(expr) + + def _names_of(self, expr): + return AssignedNamesOf().visit(expr) + + def _diagnostics_for(self, region, loc, descr="the value of the expression"): + if region: + return [ + diagnostic.Diagnostic("note", + "{descr} is alive from this point...", {"descr": descr}, + region.range.begin()), + diagnostic.Diagnostic("note", + "... to this point", {}, + region.range.end()) + ] + else: + return [ + diagnostic.Diagnostic("note", + "{descr} is alive forever", {"descr": descr}, + loc) + ] + + def visit_in_region(self, node, region, typing_env): + try: + old_youngest_region = self.youngest_region + self.youngest_region = region + + old_youngest_env = self.youngest_env + self.youngest_env = {} + + for name in typing_env: + if has_region(typing_env[name]): + self.youngest_env[name] = Region(None) # not yet known + else: + self.youngest_env[name] = None # lives forever + self.env_stack.append(self.youngest_env) + + self.generic_visit(node) + finally: + self.env_stack.pop() + self.youngest_env = old_youngest_env + self.youngest_region = old_youngest_region + + def visit_ModuleT(self, node): + self.visit_in_region(node, None, node.typing_env) + + def visit_FunctionDefT(self, node): + self.youngest_env[node.name] = self.youngest_region + self.visit_in_region(node, Region(node.loc), node.typing_env) + + def visit_ClassDefT(self, node): + self.youngest_env[node.name] = self.youngest_region + self.visit_in_region(node, Region(node.loc), node.constructor_type.attributes) + + # Only three ways for a pointer to escape: + # * Assigning or op-assigning it (we ensure an outlives relationship) + # * Returning it (we only allow returning values that live forever) + # * Raising it (we forbid allocating exceptions that refer to mutable data)¹ + # + # Literals don't count: a constructed object is always + # outlived by all its constituents. + # Closures don't count: see above. + # Calling functions doesn't count: arguments never outlive + # the function body. + # + # ¹Strings are currently never allocated with a limited lifetime, + # and exceptions can only refer to strings, so we don't actually check + # this property. But we will need to, if string operations are ever added.
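A minimal sketch of the kind of code these rules reject (a hypothetical kernel function, not taken from this patch; the diagnostic text matches the one produced by visit_Return below):

    def make_data():
        xs = [1, 2, 3]   # xs is allocated in make_data's region
        return xs        # error: cannot return a mutable value that does not live forever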
+ + def visit_assignment(self, target, value, is_aug_assign=False): + target_region = self._region_of(target) + value_region = self._region_of(value) if not is_aug_assign else self.youngest_region + + # If this is a variable, we might need to contract the live range. + if value_region is not None: + for name in self._names_of(target): + region = self._region_of(name) + if region is not None: + region.contract(value_region) + + # The assigned value should outlive the assignee + if not Region.outlives(value_region, target_region): + if is_aug_assign: + target_desc = "the assignment target, allocated here," + else: + target_desc = "the assignment target" + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(value.type)}, + value.loc) + diag = diagnostic.Diagnostic("error", + "the assigned value does not outlive the assignment target", {}, + value.loc, [target.loc], + notes=self._diagnostics_for(target_region, target.loc, + target_desc) + + self._diagnostics_for(value_region, value.loc, + "the assigned value")) + self.engine.process(diag) + + def visit_Assign(self, node): + for target in node.targets: + self.visit_assignment(target, node.value) + + def visit_AugAssign(self, node): + if builtins.is_allocated(node.target.type): + # If the target is mutable, op-assignment will allocate + # in the youngest region. + self.visit_assignment(node.target, node.value, is_aug_assign=True) + + def visit_Return(self, node): + region = self._region_of(node.value) + if region: + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(node.value.type)}, + node.value.loc) + diag = diagnostic.Diagnostic("error", + "cannot return a mutable value that does not live forever", {}, + node.value.loc, notes=self._diagnostics_for(region, node.value.loc) + [note]) + self.engine.process(diag) diff --git a/artiq/compiler/validators/local_access.py b/artiq/compiler/validators/local_access.py new file mode 100644 index 000000000..fed87092b --- /dev/null +++ b/artiq/compiler/validators/local_access.py @@ -0,0 +1,175 @@ +""" +:class:`LocalAccessValidator` verifies that local variables +are not accessed before being used. +""" + +from functools import reduce +from pythonparser import diagnostic +from .. import ir, analyses + +def is_special_variable(name): + return "$" in name + +class LocalAccessValidator: + def __init__(self, engine): + self.engine = engine + + def process(self, functions): + for func in functions: + self.process_function(func) + + def process_function(self, func): + # Find all environments and closures allocated in this func. + environments, closures = [], [] + for insn in func.instructions(): + if isinstance(insn, ir.Alloc) and ir.is_environment(insn.type): + environments.append(insn) + elif isinstance(insn, ir.Closure): + closures.append(insn) + + # Compute initial state of interesting environments. + # Environments consisting only of internal variables (containing a ".") + # are ignored. + initial_state = {} + for env in environments: + env_state = {var: False for var in env.type.params if "." not in var} + if any(env_state): + initial_state[env] = env_state + + # Traverse the acyclic graph made of basic blocks and forward edges only, + # while updating the environment state. + domtree = analyses.DominatorTree(func) + state = {} + def traverse(block): + # Have we computed the state of this block already? + if block in state: + return state[block] + + # No! 
Which forward edges lead to this block? + # If we dominate a predecessor, it's a back edge instead. + forward_edge_preds = [pred for pred in block.predecessors() + if block not in domtree.dominators(pred)] + + # Figure out what the state is before the leader + # instruction of this block. + pred_states = [traverse(pred) for pred in forward_edge_preds] + block_state = {} + if len(pred_states) > 1: + for env in initial_state: + # The variable has to be initialized in all predecessors + # in order to be initialized in this block. + def merge_state(a, b): + return {var: a[var] and b[var] for var in a} + block_state[env] = reduce(merge_state, + [state[env] for state in pred_states]) + elif len(pred_states) == 1: + # The state is the same as at the terminator of predecessor. + # We'll mutate it, so copy. + pred_state = pred_states[0] + for env in initial_state: + env_state = pred_state[env] + block_state[env] = {var: env_state[var] for var in env_state} + else: + # This is the entry block. + for env in initial_state: + env_state = initial_state[env] + block_state[env] = {var: env_state[var] for var in env_state} + + # Update the state based on block contents, while validating + # that no access to uninitialized variables will be done. + for insn in block.instructions: + def pred_at_fault(env, var_name): + # Find out where the uninitialized state comes from. + for pred, pred_state in zip(forward_edge_preds, pred_states): + if not pred_state[env][var_name]: + return pred + + # It's the entry block and it was never initialized. + return None + + set_local_in_this_frame = False + if (isinstance(insn, (ir.SetLocal, ir.GetLocal)) and + not is_special_variable(insn.var_name)): + env, var_name = insn.environment(), insn.var_name + + # Make sure that the variable is defined in the scope of this function. + if env in block_state and var_name in block_state[env]: + if isinstance(insn, ir.SetLocal): + # We've just initialized it. + block_state[env][var_name] = True + set_local_in_this_frame = True + else: # isinstance(insn, ir.GetLocal) + if not block_state[env][var_name]: + # Oops, accessing it uninitialized. + self._uninitialized_access(insn, var_name, + pred_at_fault(env, var_name)) + + closures_to_check = [] + + if (isinstance(insn, (ir.SetLocal, ir.SetAttr, ir.SetElem)) and + not set_local_in_this_frame): + # Closures may escape via these mechanisms and be invoked elsewhere. + if isinstance(insn.value(), ir.Closure): + closures_to_check.append(insn.value()) + + if isinstance(insn, (ir.Call, ir.Invoke)): + # We can't always trace the flow of closures from point of + # definition to point of call; however, we know that, by transitiveness + # of this analysis, only closures defined in this function can contain + # uninitialized variables. + # + # Thus, enumerate the closures, and check all of them during any operation + # that may eventually result in the closure being called. + closures_to_check = closures + + for closure in closures_to_check: + env = closure.environment() + # Make sure this environment has any interesting variables. + if env in block_state: + for var_name in block_state[env]: + if not block_state[env][var_name] and not is_special_variable(var_name): + # A closure would capture this variable while it is not always + # initialized. Note that this check is transitive. + self._uninitialized_access(closure, var_name, + pred_at_fault(env, var_name)) + + # Save the state. 
+ state[block] = block_state + + return block_state + + for block in func.basic_blocks: + traverse(block) + + def _uninitialized_access(self, insn, var_name, pred_at_fault): + if pred_at_fault is not None: + uninitialized_loc = None + for pred_insn in reversed(pred_at_fault.instructions): + if pred_insn.loc is not None: + uninitialized_loc = pred_insn.loc.begin() + break + assert uninitialized_loc is not None + + note = diagnostic.Diagnostic("note", + "variable is not initialized when control flows from this point", {}, + uninitialized_loc) + else: + note = None + + if note is not None: + notes = [note] + else: + notes = [] + + if isinstance(insn, ir.Closure): + diag = diagnostic.Diagnostic("error", + "variable '{name}' can be captured in a closure uninitialized here", + {"name": var_name}, + insn.loc, notes=notes) + else: + diag = diagnostic.Diagnostic("error", + "variable '{name}' is not always initialized here", + {"name": var_name}, + insn.loc, notes=notes) + + self.engine.process(diag) diff --git a/artiq/compiler/validators/monomorphism.py b/artiq/compiler/validators/monomorphism.py new file mode 100644 index 000000000..e4dd1d853 --- /dev/null +++ b/artiq/compiler/validators/monomorphism.py @@ -0,0 +1,39 @@ +""" +:class:`MonomorphismValidator` verifies that all type variables have been +elided, which is necessary for code generation. +""" + +from pythonparser import algorithm, diagnostic +from .. import asttyped, types, builtins + +class MonomorphismValidator(algorithm.Visitor): + def __init__(self, engine): + self.engine = engine + + def visit_FunctionDefT(self, node): + super().generic_visit(node) + + return_type = node.signature_type.find().ret + if types.is_polymorphic(return_type): + note = diagnostic.Diagnostic("note", + "the function has return type {type}", + {"type": types.TypePrinter().name(return_type)}, + node.name_loc) + diag = diagnostic.Diagnostic("error", + "the return type of this function cannot be fully inferred", {}, + node.name_loc, notes=[note]) + self.engine.process(diag) + + def generic_visit(self, node): + super().generic_visit(node) + + if isinstance(node, asttyped.commontyped): + if types.is_polymorphic(node.type): + note = diagnostic.Diagnostic("note", + "the expression has type {type}", + {"type": types.TypePrinter().name(node.type)}, + node.loc) + diag = diagnostic.Diagnostic("error", + "the type of this expression cannot be fully inferred", {}, + node.loc, notes=[note]) + self.engine.process(diag) diff --git a/artiq/coredevice/comm_dummy.py b/artiq/coredevice/comm_dummy.py index 5b0c35c46..82a1a9575 100644 --- a/artiq/coredevice/comm_dummy.py +++ b/artiq/coredevice/comm_dummy.py @@ -3,28 +3,16 @@ from operator import itemgetter class Comm: def __init__(self, dmgr): - pass + super().__init__() def switch_clock(self, external): pass - def load(self, kcode): - print("================") - print(" LLVM IR") - print("================") - print(kcode) + def load(self, kernel_library): + pass - def run(self, kname): - print("RUN: "+kname) + def run(self): + pass - def serve(self, rpc_map, exception_map): - print("================") - print(" RPC map") - print("================") - for k, v in sorted(rpc_map.items(), key=itemgetter(0)): - print(str(k)+" -> "+str(v)) - print("================") - print(" Exception map") - print("================") - for k, v in sorted(exception_map.items(), key=itemgetter(0)): - print(str(k)+" -> "+str(v)) + def serve(self, object_map, symbolizer): + pass diff --git a/artiq/coredevice/comm_generic.py 
b/artiq/coredevice/comm_generic.py index 88fee184b..a42eeb501 100644 --- a/artiq/coredevice/comm_generic.py +++ b/artiq/coredevice/comm_generic.py @@ -1,11 +1,10 @@ import struct import logging +import traceback from enum import Enum from fractions import Fraction -from artiq.coredevice import runtime_exceptions from artiq.language import core as core_language -from artiq.coredevice.rpc_wrapper import RPCWrapper logger = logging.getLogger(__name__) @@ -13,22 +12,26 @@ logger = logging.getLogger(__name__) class _H2DMsgType(Enum): LOG_REQUEST = 1 - IDENT_REQUEST = 2 - SWITCH_CLOCK = 3 + LOG_CLEAR = 2 - LOAD_OBJECT = 4 - RUN_KERNEL = 5 + IDENT_REQUEST = 3 + SWITCH_CLOCK = 4 - RPC_REPLY = 6 + LOAD_LIBRARY = 5 + RUN_KERNEL = 6 - FLASH_READ_REQUEST = 7 - FLASH_WRITE_REQUEST = 8 - FLASH_ERASE_REQUEST = 9 - FLASH_REMOVE_REQUEST = 10 + RPC_REPLY = 7 + RPC_EXCEPTION = 8 + + FLASH_READ_REQUEST = 9 + FLASH_WRITE_REQUEST = 10 + FLASH_ERASE_REQUEST = 11 + FLASH_REMOVE_REQUEST = 12 class _D2HMsgType(Enum): LOG_REPLY = 1 + IDENT_REPLY = 2 CLOCK_SWITCH_COMPLETED = 3 CLOCK_SWITCH_FAILED = 4 @@ -50,9 +53,16 @@ class _D2HMsgType(Enum): class UnsupportedDevice(Exception): pass +class RPCReturnValueError(ValueError): + pass + class CommGeneric: - # methods for derived classes to implement + def __init__(self): + self._read_type = self._write_type = None + self._read_length = 0 + self._write_buffer = [] + def open(self): """Opens the communication channel. Must do nothing if already opened.""" @@ -72,175 +82,412 @@ class CommGeneric: """Writes exactly length bytes to the communication channel. The channel is assumed to be opened.""" raise NotImplementedError + + # + # Reader interface # def _read_header(self): self.open() + if self._read_length > 0: + raise IOError("Read underrun ({} bytes remaining)". + format(self._read_length)) + + # Wait for a synchronization sequence, 5a 5a 5a 5a. sync_count = 0 while sync_count < 4: - (c, ) = struct.unpack("B", self.read(1)) - if c == 0x5a: + (sync_byte, ) = struct.unpack("B", self.read(1)) + if sync_byte == 0x5a: sync_count += 1 else: sync_count = 0 - length = struct.unpack(">l", self.read(4))[0] - if not length: # inband connection close - raise OSError("Connection closed") - tyv = struct.unpack("B", self.read(1))[0] - ty = _D2HMsgType(tyv) - logger.debug("receiving message: type=%r length=%d", ty, length) - return length, ty - def _write_header(self, length, ty): + # Read message header. + (self._read_length, ) = struct.unpack(">l", self.read(4)) + if not self._read_length: # inband connection close + raise OSError("Connection closed") + + (raw_type, ) = struct.unpack("B", self.read(1)) + self._read_type = _D2HMsgType(raw_type) + + if self._read_length < 9: + raise IOError("Read overrun in message header ({} remaining)". + format(self._read_length)) + self._read_length -= 9 + + logger.debug("receiving message: type=%r length=%d", + self._read_type, self._read_length) + + def _read_expect(self, ty): + if self._read_type != ty: + raise IOError("Incorrect reply from device: {} (expected {})". + format(self._read_type, ty)) + + def _read_empty(self, ty): + self._read_header() + self._read_expect(ty) + + def _read_chunk(self, length): + if self._read_length < length: + raise IOError("Read overrun while trying to read {} bytes ({} remaining)" + " in packet {}". 
+ format(length, self._read_length, self._read_type)) + + self._read_length -= length + return self.read(length) + + def _read_int8(self): + (value, ) = struct.unpack("B", self._read_chunk(1)) + return value + + def _read_int32(self): + (value, ) = struct.unpack(">l", self._read_chunk(4)) + return value + + def _read_int64(self): + (value, ) = struct.unpack(">q", self._read_chunk(8)) + return value + + def _read_float64(self): + (value, ) = struct.unpack(">d", self._read_chunk(8)) + return value + + def _read_bytes(self): + return self._read_chunk(self._read_int32()) + + def _read_string(self): + return self._read_bytes()[:-1].decode('utf-8') + + # + # Writer interface + # + + def _write_header(self, ty): self.open() - logger.debug("sending message: type=%r length=%d", ty, length) - self.write(struct.pack(">ll", 0x5a5a5a5a, length)) - if ty is not None: - self.write(struct.pack("B", ty.value)) + + logger.debug("preparing to send message: type=%r", ty) + self._write_type = ty + self._write_buffer = [] + + def _write_flush(self): + # Calculate message size. + length = sum([len(chunk) for chunk in self._write_buffer]) + logger.debug("sending message: type=%r length=%d", self._write_type, length) + + # Write synchronization sequence, header and body. + self.write(struct.pack(">llB", 0x5a5a5a5a, + 9 + length, self._write_type.value)) + for chunk in self._write_buffer: + self.write(chunk) + + def _write_empty(self, ty): + self._write_header(ty) + self._write_flush() + + def _write_chunk(self, chunk): + self._write_buffer.append(chunk) + + def _write_int8(self, value): + self._write_buffer.append(struct.pack("B", value)) + + def _write_int32(self, value): + self._write_buffer.append(struct.pack(">l", value)) + + def _write_int64(self, value): + self._write_buffer.append(struct.pack(">q", value)) + + def _write_float64(self, value): + self._write_buffer.append(struct.pack(">d", value)) + + def _write_bytes(self, value): + self._write_int32(len(value)) + self._write_buffer.append(value) + + def _write_string(self, value): + self._write_bytes(value.encode("utf-8") + b"\0") + + # + # Exported APIs + # def reset_session(self): - self._write_header(0, None) + self.write(struct.pack(">ll", 0x5a5a5a5a, 0)) def check_ident(self): - self._write_header(9, _H2DMsgType.IDENT_REQUEST) - _, ty = self._read_header() - if ty != _D2HMsgType.IDENT_REPLY: - raise IOError("Incorrect reply from device: {}".format(ty)) - (reply, ) = struct.unpack("B", self.read(1)) - runtime_id = chr(reply) - for i in range(3): - (reply, ) = struct.unpack("B", self.read(1)) - runtime_id += chr(reply) - if runtime_id != "AROR": + self._write_empty(_H2DMsgType.IDENT_REQUEST) + + self._read_header() + self._read_expect(_D2HMsgType.IDENT_REPLY) + runtime_id = self._read_chunk(4) + if runtime_id != b"AROR": raise UnsupportedDevice("Unsupported runtime ID: {}" .format(runtime_id)) def switch_clock(self, external): - self._write_header(10, _H2DMsgType.SWITCH_CLOCK) - self.write(struct.pack("B", int(external))) - _, ty = self._read_header() - if ty != _D2HMsgType.CLOCK_SWITCH_COMPLETED: - raise IOError("Incorrect reply from device: {}".format(ty)) + self._write_header(_H2DMsgType.SWITCH_CLOCK) + self._write_int8(external) + self._write_flush() - def load(self, kcode): - self._write_header(len(kcode) + 9, _H2DMsgType.LOAD_OBJECT) - self.write(kcode) - _, ty = self._read_header() - if ty != _D2HMsgType.LOAD_COMPLETED: - raise IOError("Incorrect reply from device: "+str(ty)) - - def run(self, kname): - self._write_header(len(kname) + 9, 
_H2DMsgType.RUN_KERNEL) - self.write(bytes(kname, "ascii")) - logger.debug("running kernel: %s", kname) - - def flash_storage_read(self, key): - self._write_header(9+len(key), _H2DMsgType.FLASH_READ_REQUEST) - self.write(key) - length, ty = self._read_header() - if ty != _D2HMsgType.FLASH_READ_REPLY: - raise IOError("Incorrect reply from device: {}".format(ty)) - value = self.read(length - 9) - return value - - def flash_storage_write(self, key, value): - self._write_header(9+len(key)+1+len(value), - _H2DMsgType.FLASH_WRITE_REQUEST) - self.write(key) - self.write(b"\x00") - self.write(value) - _, ty = self._read_header() - if ty != _D2HMsgType.FLASH_OK_REPLY: - if ty == _D2HMsgType.FLASH_ERROR_REPLY: - raise IOError("Flash storage is full") - else: - raise IOError("Incorrect reply from device: {}".format(ty)) - - def flash_storage_erase(self): - self._write_header(9, _H2DMsgType.FLASH_ERASE_REQUEST) - _, ty = self._read_header() - if ty != _D2HMsgType.FLASH_OK_REPLY: - raise IOError("Incorrect reply from device: {}".format(ty)) - - def flash_storage_remove(self, key): - self._write_header(9+len(key), _H2DMsgType.FLASH_REMOVE_REQUEST) - self.write(key) - _, ty = self._read_header() - if ty != _D2HMsgType.FLASH_OK_REPLY: - raise IOError("Incorrect reply from device: {}".format(ty)) - - def _receive_rpc_value(self, type_tag): - if type_tag == "n": - return None - if type_tag == "b": - return bool(struct.unpack("B", self.read(1))[0]) - if type_tag == "i": - return struct.unpack(">l", self.read(4))[0] - if type_tag == "I": - return struct.unpack(">q", self.read(8))[0] - if type_tag == "f": - return struct.unpack(">d", self.read(8))[0] - if type_tag == "F": - n, d = struct.unpack(">qq", self.read(16)) - return Fraction(n, d) - - def _receive_rpc_values(self): - r = [] - while True: - type_tag = chr(struct.unpack("B", self.read(1))[0]) - if type_tag == "\x00": - return r - elif type_tag == "l": - elt_type_tag = chr(struct.unpack("B", self.read(1))[0]) - length = struct.unpack(">l", self.read(4))[0] - r.append([self._receive_rpc_value(elt_type_tag) - for i in range(length)]) - else: - r.append(self._receive_rpc_value(type_tag)) - - def _serve_rpc(self, rpc_wrapper, rpc_map, user_exception_map): - rpc_num = struct.unpack(">l", self.read(4))[0] - args = self._receive_rpc_values() - logger.debug("rpc service: %d %r", rpc_num, args) - eid, r = rpc_wrapper.run_rpc( - user_exception_map, rpc_map[rpc_num], args) - self._write_header(9+2*4, _H2DMsgType.RPC_REPLY) - self.write(struct.pack(">ll", eid, r)) - logger.debug("rpc service: %d %r == %r (eid %d)", rpc_num, args, - r, eid) - - def _serve_exception(self, rpc_wrapper, user_exception_map): - eid, p0, p1, p2 = struct.unpack(">lqqq", self.read(4+3*8)) - rpc_wrapper.filter_rpc_exception(eid) - if eid < core_language.first_user_eid: - exception = runtime_exceptions.exception_map[eid] - raise exception(self.core, p0, p1, p2) - else: - exception = user_exception_map[eid] - raise exception - - def serve(self, rpc_map, user_exception_map): - rpc_wrapper = RPCWrapper() - while True: - _, ty = self._read_header() - if ty == _D2HMsgType.RPC_REQUEST: - self._serve_rpc(rpc_wrapper, rpc_map, user_exception_map) - elif ty == _D2HMsgType.KERNEL_EXCEPTION: - self._serve_exception(rpc_wrapper, user_exception_map) - elif ty == _D2HMsgType.KERNEL_FINISHED: - return - else: - raise IOError("Incorrect request from device: "+str(ty)) + self._read_empty(_D2HMsgType.CLOCK_SWITCH_COMPLETED) def get_log(self): - self._write_header(9, _H2DMsgType.LOG_REQUEST) - length, ty = 
self._read_header() - if ty != _D2HMsgType.LOG_REPLY: - raise IOError("Incorrect request from device: "+str(ty)) - r = "" - for i in range(length - 9): - c = struct.unpack("B", self.read(1))[0] - if c: - r += chr(c) - return r + self._write_empty(_H2DMsgType.LOG_REQUEST) + + self._read_header() + self._read_expect(_D2HMsgType.LOG_REPLY) + return self._read_chunk(self._read_length).decode('utf-8') + + def clear_log(self): + self._write_empty(_H2DMsgType.LOG_CLEAR) + + self._read_empty(_D2HMsgType.LOG_REPLY) + + def flash_storage_read(self, key): + self._write_header(_H2DMsgType.FLASH_READ_REQUEST) + self._write_string(key) + self._write_flush() + + self._read_header() + self._read_expect(_D2HMsgType.FLASH_READ_REPLY) + return self._read_chunk(self._read_length) + + def flash_storage_write(self, key, value): + self._write_header(_H2DMsgType.FLASH_WRITE_REQUEST) + self._write_string(key) + self._write_bytes(value) + self._write_flush() + + self._read_header() + if self._read_type == _D2HMsgType.FLASH_ERROR_REPLY: + raise IOError("Flash storage is full") + else: + self._read_expect(_D2HMsgType.FLASH_OK_REPLY) + + def flash_storage_erase(self): + self._write_empty(_H2DMsgType.FLASH_ERASE_REQUEST) + + self._read_empty(_D2HMsgType.FLASH_OK_REPLY) + + def flash_storage_remove(self, key): + self._write_header(_H2DMsgType.FLASH_REMOVE_REQUEST) + self._write_string(key) + self._write_flush() + + self._read_empty(_D2HMsgType.FLASH_OK_REPLY) + + def load(self, kernel_library): + self._write_header(_H2DMsgType.LOAD_LIBRARY) + self._write_chunk(kernel_library) + self._write_flush() + + self._read_empty(_D2HMsgType.LOAD_COMPLETED) + + def run(self): + self._write_empty(_H2DMsgType.RUN_KERNEL) + logger.debug("running kernel") + + _rpc_sentinel = object() + + # See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. 
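A minimal sketch of the scalar framing the helpers above rely on, using only the standard library (the value 42 is illustrative): _write_int32/_read_int32 use the big-endian signed 32-bit layout ">l", and the tag string walked by _send_rpc_value below describes one value per byte, so that, for example, a list of 32-bit integers is described by the tag bytes b"li".

    import struct
    # big-endian signed 32-bit, as used by _write_int32/_read_int32
    assert struct.pack(">l", 42) == b"\x00\x00\x00*"
    (value,) = struct.unpack(">l", b"\x00\x00\x00*")
    assert value == 42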
+ def _receive_rpc_value(self, object_map): + tag = chr(self._read_int8()) + if tag == "\x00": + return self._rpc_sentinel + elif tag == "t": + length = self._read_int8() + return tuple(self._receive_rpc_value(object_map) for _ in range(length)) + elif tag == "n": + return None + elif tag == "b": + return bool(self._read_int8()) + elif tag == "i": + return self._read_int32() + elif tag == "I": + return self._read_int64() + elif tag == "f": + return self._read_float64() + elif tag == "F": + numerator = self._read_int64() + denominator = self._read_int64() + return Fraction(numerator, denominator) + elif tag == "s": + return self._read_string() + elif tag == "l": + length = self._read_int32() + return [self._receive_rpc_value(object_map) for _ in range(length)] + elif tag == "r": + start = self._receive_rpc_value(object_map) + stop = self._receive_rpc_value(object_map) + step = self._receive_rpc_value(object_map) + return range(start, stop, step) + elif tag == "o": + present = self._read_int8() + if present: + return self._receive_rpc_value(object_map) + elif tag == "O": + return object_map.retrieve(self._read_int32()) + else: + raise IOError("Unknown RPC value tag: {}".format(repr(tag))) + + def _receive_rpc_args(self, object_map): + args = [] + while True: + value = self._receive_rpc_value(object_map) + if value is self._rpc_sentinel: + return args + args.append(value) + + def _skip_rpc_value(self, tags): + tag = tags.pop(0) + if tag == "t": + length = tags.pop(0) + for _ in range(length): + self._skip_rpc_value(tags) + elif tag == "l": + self._skip_rpc_value(tags) + elif tag == "r": + self._skip_rpc_value(tags) + else: + pass + + def _send_rpc_value(self, tags, value, root, function): + def check(cond, expected): + if not cond: + raise RPCReturnValueError( + "type mismatch: cannot serialize {value} as {type}" + " ({function} has returned {root})".format( + value=repr(value), type=expected(), + function=function, root=root)) + + tag = chr(tags.pop(0)) + if tag == "t": + length = tags.pop(0) + check(isinstance(value, tuple) and length == len(value), + lambda: "tuple of {}".format(length)) + for elt in value: + self._send_rpc_value(tags, elt, root, function) + elif tag == "n": + check(value is None, + lambda: "None") + elif tag == "b": + check(isinstance(value, bool), + lambda: "bool") + self._write_int8(value) + elif tag == "i": + check(isinstance(value, int) and (-2**31 < value < 2**31-1), + lambda: "32-bit int") + self._write_int32(value) + elif tag == "I": + check(isinstance(value, int) and (-2**63 < value < 2**63-1), + lambda: "64-bit int") + self._write_int64(value) + elif tag == "f": + check(isinstance(value, float), + lambda: "float") + self._write_float64(value) + elif tag == "F": + check(isinstance(value, Fraction) and + (-2**63 < value.numerator < 2**63-1) and + (-2**63 < value.denominator < 2**63-1), + lambda: "64-bit Fraction") + self._write_int64(value.numerator) + self._write_int64(value.denominator) + elif tag == "s": + check(isinstance(value, str) and "\x00" not in value, + lambda: "str") + self._write_string(value) + elif tag == "l": + check(isinstance(value, list), + lambda: "list") + self._write_int32(len(value)) + for elt in value: + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, elt, root, function) + self._skip_rpc_value(tags) + elif tag == "r": + check(isinstance(value, range), + lambda: "range") + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, value.start, root, function) + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, 
value.stop, root, function) + tags_copy = bytearray(tags) + self._send_rpc_value(tags_copy, value.step, root, function) + tags = tags_copy + else: + raise IOError("Unknown RPC value tag: {}".format(repr(tag))) + + def _serve_rpc(self, object_map): + service = self._read_int32() + args = self._receive_rpc_args(object_map) + return_tags = self._read_bytes() + logger.debug("rpc service: %d %r -> %s", service, args, return_tags) + + try: + result = object_map.retrieve(service)(*args) + logger.debug("rpc service: %d %r == %r", service, args, result) + + self._write_header(_H2DMsgType.RPC_REPLY) + self._write_bytes(return_tags) + self._send_rpc_value(bytearray(return_tags), result, result, + object_map.retrieve(service)) + self._write_flush() + except core_language.ARTIQException as exn: + logger.debug("rpc service: %d %r ! %r", service, args, exn) + + self._write_header(_H2DMsgType.RPC_EXCEPTION) + self._write_string(exn.name) + self._write_string(exn.message) + for index in range(3): + self._write_int64(exn.param[index]) + + self._write_string(exn.filename) + self._write_int32(exn.line) + self._write_int32(exn.column) + self._write_string(exn.function) + + self._write_flush() + except Exception as exn: + logger.debug("rpc service: %d %r ! %r", service, args, exn) + + self._write_header(_H2DMsgType.RPC_EXCEPTION) + self._write_string(type(exn).__name__) + self._write_string(str(exn)) + for index in range(3): + self._write_int64(0) + + (_, (filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__, 2) + self._write_string(filename) + self._write_int32(line) + self._write_int32(-1) # column not known + self._write_string(function) + + self._write_flush() + + def _serve_exception(self, symbolizer): + name = self._read_string() + message = self._read_string() + params = [self._read_int64() for _ in range(3)] + + filename = self._read_string() + line = self._read_int32() + column = self._read_int32() + function = self._read_string() + + backtrace = [self._read_int32() for _ in range(self._read_int32())] + + traceback = list(reversed(symbolizer(backtrace))) + \ + [(filename, line, column, function, None)] + raise core_language.ARTIQException(name, message, params, traceback) + + def serve(self, object_map, symbolizer): + while True: + self._read_header() + if self._read_type == _D2HMsgType.RPC_REQUEST: + self._serve_rpc(object_map) + elif self._read_type == _D2HMsgType.KERNEL_EXCEPTION: + self._serve_exception(symbolizer) + else: + self._read_expect(_D2HMsgType.KERNEL_FINISHED) + return diff --git a/artiq/coredevice/comm_serial.py b/artiq/coredevice/comm_serial.py index 70218da14..91eda3567 100644 --- a/artiq/coredevice/comm_serial.py +++ b/artiq/coredevice/comm_serial.py @@ -10,6 +10,7 @@ logger = logging.getLogger(__name__) class Comm(CommGeneric): def __init__(self, dmgr, serial_dev, baud_rate=115200): + super().__init__() self.serial_dev = serial_dev self.baud_rate = baud_rate @@ -27,10 +28,10 @@ class Comm(CommGeneric): del self.port def read(self, length): - r = bytes() - while len(r) < length: - r += self.port.read(length - len(r)) - return r + result = bytes() + while len(result) < length: + result += self.port.read(length - len(result)) + return result def write(self, data): remaining = len(data) diff --git a/artiq/coredevice/comm_tcp.py b/artiq/coredevice/comm_tcp.py index 8c3334c8f..cd8d97e9a 100644 --- a/artiq/coredevice/comm_tcp.py +++ b/artiq/coredevice/comm_tcp.py @@ -26,6 +26,7 @@ def set_keepalive(sock, after_idle, interval, max_fails): class Comm(CommGeneric): def 
__init__(self, dmgr, host, port=1381): + super().__init__() self.host = host self.port = port diff --git a/artiq/coredevice/core.py b/artiq/coredevice/core.py index e9f6b57e3..0ce6e9dc4 100644 --- a/artiq/coredevice/core.py +++ b/artiq/coredevice/core.py @@ -1,49 +1,37 @@ import os +from pythonparser import diagnostic + from artiq.language.core import * -from artiq.language.units import ns +from artiq.language.types import * +from artiq.language.units import * -from artiq.transforms.inline import inline -from artiq.transforms.quantize_time import quantize_time -from artiq.transforms.remove_inter_assigns import remove_inter_assigns -from artiq.transforms.fold_constants import fold_constants -from artiq.transforms.remove_dead_code import remove_dead_code -from artiq.transforms.unroll_loops import unroll_loops -from artiq.transforms.interleave import interleave -from artiq.transforms.lower_time import lower_time -from artiq.transforms.unparse import unparse +from artiq.compiler import Stitcher, Module +from artiq.compiler.targets import OR1KTarget -from artiq.coredevice.runtime import Runtime - -from artiq.py2llvm import get_runtime_binary +# Import for side effects (creating the exception classes). +from artiq.coredevice import exceptions -def _announce_unparse(label, node): - print("*** Unparsing: "+label) - print(unparse(node)) +class CompileError(Exception): + def __init__(self, diagnostic): + self.diagnostic = diagnostic + + def render_string(self, colored=False): + def shorten_path(path): + return path.replace(os.path.normpath(os.path.join(__file__, "..", "..")), "") + lines = [shorten_path(path) for path in self.diagnostic.render(colored=colored)] + return "\n".join(lines) + + def __str__(self): + # Prepend a newline so that the message shows up on after + # exception class name printed by Python. + return "\n" + self.render_string(colored=True) -def _make_debug_unparse(final): - try: - env = os.environ["ARTIQ_UNPARSE"] - except KeyError: - env = "" - selected_labels = set(env.split()) - if "all" in selected_labels: - return _announce_unparse - else: - if "final" in selected_labels: - selected_labels.add(final) - - def _filtered_unparse(label, node): - if label in selected_labels: - _announce_unparse(label, node) - return _filtered_unparse - - -def _no_debug_unparse(label, node): - pass - +@syscall +def rtio_get_counter() -> TInt64: + raise NotImplementedError("syscall not simulated") class Core: """Core device driver. 
@@ -66,79 +54,46 @@ class Core: self.first_run = True self.core = self self.comm.core = self - self.runtime = Runtime() - def transform_stack(self, func_def, rpc_map, exception_map, - debug_unparse=_no_debug_unparse): - remove_inter_assigns(func_def) - debug_unparse("remove_inter_assigns_1", func_def) + def compile(self, function, args, kwargs, with_attr_writeback=True): + try: + engine = diagnostic.Engine(all_errors_are_fatal=True) - quantize_time(func_def, self.ref_period) - debug_unparse("quantize_time", func_def) + stitcher = Stitcher(engine=engine) + stitcher.stitch_call(function, args, kwargs) + stitcher.finalize() - fold_constants(func_def) - debug_unparse("fold_constants_1", func_def) + module = Module(stitcher, ref_period=self.ref_period) + target = OR1KTarget() - unroll_loops(func_def, 500) - debug_unparse("unroll_loops", func_def) + library = target.compile_and_link([module]) + stripped_library = target.strip(library) - interleave(func_def) - debug_unparse("interleave", func_def) + return stitcher.object_map, stripped_library, \ + lambda addresses: target.symbolize(library, addresses) + except diagnostic.Error as error: + raise CompileError(error.diagnostic) from error - lower_time(func_def) - debug_unparse("lower_time", func_def) + def run(self, function, args, kwargs): + object_map, kernel_library, symbolizer = self.compile(function, args, kwargs) - remove_inter_assigns(func_def) - debug_unparse("remove_inter_assigns_2", func_def) - - fold_constants(func_def) - debug_unparse("fold_constants_2", func_def) - - remove_dead_code(func_def) - debug_unparse("remove_dead_code_1", func_def) - - remove_inter_assigns(func_def) - debug_unparse("remove_inter_assigns_3", func_def) - - fold_constants(func_def) - debug_unparse("fold_constants_3", func_def) - - remove_dead_code(func_def) - debug_unparse("remove_dead_code_2", func_def) - - def compile(self, k_function, k_args, k_kwargs, with_attr_writeback=True): - debug_unparse = _make_debug_unparse("remove_dead_code_2") - - func_def, rpc_map, exception_map = inline( - self, k_function, k_args, k_kwargs, with_attr_writeback) - debug_unparse("inline", func_def) - self.transform_stack(func_def, rpc_map, exception_map, debug_unparse) - - binary = get_runtime_binary(self.runtime, func_def) - - return binary, rpc_map, exception_map - - def run(self, k_function, k_args, k_kwargs): if self.first_run: self.comm.check_ident() self.comm.switch_clock(self.external_clock) + self.first_run = False - binary, rpc_map, exception_map = self.compile( - k_function, k_args, k_kwargs) - self.comm.load(binary) - self.comm.run(k_function.__name__) - self.comm.serve(rpc_map, exception_map) - self.first_run = False + self.comm.load(kernel_library) + self.comm.run() + self.comm.serve(object_map, symbolizer) @kernel def get_rtio_counter_mu(self): - """Return the current value of the hardware RTIO counter.""" - return syscall("rtio_get_counter") + return rtio_get_counter() @kernel def break_realtime(self): """Set the timeline to the current value of the hardware RTIO counter plus a margin of 125000 machine units.""" - min_now = syscall("rtio_get_counter") + 125000 + min_now = rtio_get_counter() + 125000 if now_mu() < min_now: at_mu(min_now) diff --git a/artiq/coredevice/dds.py b/artiq/coredevice/dds.py index 4eb0872fc..8350ee3d2 100644 --- a/artiq/coredevice/dds.py +++ b/artiq/coredevice/dds.py @@ -1,4 +1,5 @@ from artiq.language.core import * +from artiq.language.types import * from artiq.language.units import * @@ -9,6 +10,24 @@ PHASE_MODE_ABSOLUTE = 1 
PHASE_MODE_TRACKING = 2 +@syscall +def dds_init(time_mu: TInt64, channel: TInt32) -> TNone: + raise NotImplementedError("syscall not simulated") + +@syscall +def dds_batch_enter(time_mu: TInt64) -> TNone: + raise NotImplementedError("syscall not simulated") + +@syscall +def dds_batch_exit() -> TNone: + raise NotImplementedError("syscall not simulated") + +@syscall +def dds_set(time_mu: TInt64, channel: TInt32, ftw: TInt32, + pow: TInt32, phase_mode: TInt32) -> TNone: + raise NotImplementedError("syscall not simulated") + + class _BatchContextManager: def __init__(self, dds_bus): self.dds_bus = dds_bus @@ -37,13 +56,13 @@ class DDSBus: The time of execution of the DDS commands is the time of entering the batch (as closely as hardware permits).""" - syscall("dds_batch_enter", now_mu()) + dds_batch_enter(now_mu()) @kernel def batch_exit(self): """Ends a DDS command batch. All buffered DDS commands are issued on the bus.""" - syscall("dds_batch_exit") + dds_batch_exit() class _DDSGeneric: @@ -105,7 +124,7 @@ class _DDSGeneric: """Resets and initializes the DDS channel. The runtime does this for all channels upon core device startup.""" - syscall("dds_init", now_mu(), self.channel) + dds_init(now_mu(), self.channel) @kernel def set_phase_mode(self, phase_mode): @@ -144,8 +163,7 @@ class _DDSGeneric: """ if phase_mode == _PHASE_MODE_DEFAULT: phase_mode = self.phase_mode - syscall("dds_set", now_mu(), self.channel, frequency, - phase, phase_mode, amplitude) + dds_set(now_mu(), self.channel, frequency, phase, phase_mode, amplitude) @kernel def set(self, frequency, phase=0.0, phase_mode=_PHASE_MODE_DEFAULT, diff --git a/artiq/coredevice/exceptions.py b/artiq/coredevice/exceptions.py new file mode 100644 index 000000000..d0f54f2fa --- /dev/null +++ b/artiq/coredevice/exceptions.py @@ -0,0 +1,41 @@ +from artiq.language.core import ARTIQException + +class ZeroDivisionError(ARTIQException): + """Python's :class:`ZeroDivisionError`, mirrored in ARTIQ.""" + +class ValueError(ARTIQException): + """Python's :class:`ValueError`, mirrored in ARTIQ.""" + +class IndexError(ARTIQException): + """Python's :class:`IndexError`, mirrored in ARTIQ.""" + +class InternalError(ARTIQException): + """Raised when the runtime encounters an internal error condition.""" + +class RTIOUnderflow(ARTIQException): + """Raised when the CPU fails to submit a RTIO event early enough + (with respect to the event's timestamp). + + The offending event is discarded and the RTIO core keeps operating. + """ + +class RTIOSequenceError(ARTIQException): + """Raised when an event is submitted on a given channel with a timestamp + not larger than the previous one. + + The offending event is discarded and the RTIO core keeps operating. + """ + +class RTIOOverflow(ARTIQException): + """Raised when at least one event could not be registered into the RTIO + input FIFO because it was full (CPU not reading fast enough). + + This does not interrupt operations further than cancelling the current + read attempt and discarding some events. Reading can be reattempted after + the exception is caught, and events will be partially retrieved. + """ + +class DDSBatchError(ARTIQException): + """Raised when attempting to start a DDS batch while already in a batch, + or when too many commands are batched. 
+ """ diff --git a/artiq/coredevice/rpc_wrapper.py b/artiq/coredevice/rpc_wrapper.py deleted file mode 100644 index eeae17286..000000000 --- a/artiq/coredevice/rpc_wrapper.py +++ /dev/null @@ -1,40 +0,0 @@ -from artiq.coredevice.runtime_exceptions import exception_map, _RPCException - - -def _lookup_exception(d, e): - for eid, exception in d.items(): - if isinstance(e, exception): - return eid - return 0 - - -class RPCWrapper: - def __init__(self): - self.last_exception = None - - def run_rpc(self, user_exception_map, fn, args): - eid = 0 - r = None - - try: - r = fn(*args) - except Exception as e: - eid = _lookup_exception(user_exception_map, e) - if not eid: - eid = _lookup_exception(exception_map, e) - if eid: - self.last_exception = None - else: - self.last_exception = e - eid = _RPCException.eid - - if r is None: - r = 0 - else: - r = int(r) - - return eid, r - - def filter_rpc_exception(self, eid): - if eid == _RPCException.eid: - raise self.last_exception diff --git a/artiq/coredevice/runtime.py b/artiq/coredevice/runtime.py index 5fcaf56d1..a9d74b63d 100644 --- a/artiq/coredevice/runtime.py +++ b/artiq/coredevice/runtime.py @@ -1,212 +1,13 @@ import os -import llvmlite_artiq.ir as ll -import llvmlite_artiq.binding as llvm +class SourceLoader: + def __init__(self, runtime_root): + self.runtime_root = runtime_root -from artiq.py2llvm import base_types, fractions, lists -from artiq.language import units + def get_source(self, filename): + print(os.path.join(self.runtime_root, filename)) + with open(os.path.join(self.runtime_root, filename)) as f: + return f.read() - -llvm.initialize() -llvm.initialize_all_targets() -llvm.initialize_all_asmprinters() - -_syscalls = { - "now_init": "n:I", - "now_save": "I:n", - "watchdog_set": "i:i", - "watchdog_clear": "i:n", - "rtio_get_counter": "n:I", - "ttl_set_o": "Iib:n", - "ttl_set_oe": "Iib:n", - "ttl_set_sensitivity": "Iii:n", - "ttl_get": "iI:I", - "ttl_clock_set": "Iii:n", - "dds_init": "Ii:n", - "dds_batch_enter": "I:n", - "dds_batch_exit": "n:n", - "dds_set": "Iiiiii:n", -} - - -def _chr_to_type(c): - if c == "n": - return ll.VoidType() - if c == "b": - return ll.IntType(1) - if c == "i": - return ll.IntType(32) - if c == "I": - return ll.IntType(64) - raise ValueError - - -def _str_to_functype(s): - assert(s[-2] == ":") - type_ret = _chr_to_type(s[-1]) - type_args = [_chr_to_type(c) for c in s[:-2] if c != "n"] - return ll.FunctionType(type_ret, type_args) - - -def _chr_to_value(c): - if c == "n": - return base_types.VNone() - if c == "b": - return base_types.VBool() - if c == "i": - return base_types.VInt() - if c == "I": - return base_types.VInt(64) - raise ValueError - - -def _value_to_str(v): - if isinstance(v, base_types.VNone): - return "n" - if isinstance(v, base_types.VBool): - return "b" - if isinstance(v, base_types.VInt): - if v.nbits == 32: - return "i" - if v.nbits == 64: - return "I" - raise ValueError - if isinstance(v, base_types.VFloat): - return "f" - if isinstance(v, fractions.VFraction): - return "F" - if isinstance(v, lists.VList): - return "l" + _value_to_str(v.el_type) - raise ValueError - - -class LinkInterface: - def init_module(self, module): - self.module = module - llvm_module = self.module.llvm_module - - # RPC - func_type = ll.FunctionType(ll.IntType(32), [ll.IntType(32)], - var_arg=1) - self.rpc = ll.Function(llvm_module, func_type, "__syscall_rpc") - - # syscalls - self.syscalls = dict() - for func_name, func_type_str in _syscalls.items(): - func_type = _str_to_functype(func_type_str) - 
self.syscalls[func_name] = ll.Function( - llvm_module, func_type, "__syscall_" + func_name) - - # exception handling - func_type = ll.FunctionType(ll.IntType(32), - [ll.PointerType(ll.IntType(8))]) - self.eh_setjmp = ll.Function(llvm_module, func_type, - "__eh_setjmp") - self.eh_setjmp.attributes.add("nounwind") - self.eh_setjmp.attributes.add("returns_twice") - - func_type = ll.FunctionType(ll.PointerType(ll.IntType(8)), []) - self.eh_push = ll.Function(llvm_module, func_type, "__eh_push") - - func_type = ll.FunctionType(ll.VoidType(), [ll.IntType(32)]) - self.eh_pop = ll.Function(llvm_module, func_type, "__eh_pop") - - func_type = ll.FunctionType(ll.IntType(32), []) - self.eh_getid = ll.Function(llvm_module, func_type, "__eh_getid") - - func_type = ll.FunctionType(ll.VoidType(), [ll.IntType(32)]) - self.eh_raise = ll.Function(llvm_module, func_type, "__eh_raise") - self.eh_raise.attributes.add("noreturn") - - def _build_rpc(self, args, builder): - r = base_types.VInt() - if builder is not None: - new_args = [] - new_args.append(args[0].auto_load(builder)) # RPC number - for arg in args[1:]: - # type tag - arg_type_str = _value_to_str(arg) - arg_type_int = 0 - for c in reversed(arg_type_str): - arg_type_int <<= 8 - arg_type_int |= ord(c) - new_args.append(ll.Constant(ll.IntType(32), arg_type_int)) - - # pointer to value - if not isinstance(arg, base_types.VNone): - if isinstance(arg.llvm_value.type, ll.PointerType): - new_args.append(arg.llvm_value) - else: - arg_ptr = arg.new() - arg_ptr.alloca(builder) - arg_ptr.auto_store(builder, arg.llvm_value) - new_args.append(arg_ptr.llvm_value) - # end marker - new_args.append(ll.Constant(ll.IntType(32), 0)) - r.auto_store(builder, builder.call(self.rpc, new_args)) - return r - - def _build_regular_syscall(self, syscall_name, args, builder): - r = _chr_to_value(_syscalls[syscall_name][-1]) - if builder is not None: - args = [arg.auto_load(builder) for arg in args] - r.auto_store(builder, builder.call(self.syscalls[syscall_name], - args)) - return r - - def build_syscall(self, syscall_name, args, builder): - if syscall_name == "rpc": - return self._build_rpc(args, builder) - else: - return self._build_regular_syscall(syscall_name, args, builder) - - def build_catch(self, builder): - jmpbuf = builder.call(self.eh_push, []) - exception_occured = builder.call(self.eh_setjmp, [jmpbuf]) - return builder.icmp_signed("!=", - exception_occured, - ll.Constant(ll.IntType(32), 0)) - - def build_pop(self, builder, levels): - builder.call(self.eh_pop, [ll.Constant(ll.IntType(32), levels)]) - - def build_getid(self, builder): - return builder.call(self.eh_getid, []) - - def build_raise(self, builder, eid): - builder.call(self.eh_raise, [eid]) - - -def _debug_dump_obj(obj): - try: - env = os.environ["ARTIQ_DUMP_OBJECT"] - except KeyError: - return - - for i in range(1000): - filename = "{}_{:03d}.elf".format(env, i) - try: - f = open(filename, "xb") - except FileExistsError: - pass - else: - f.write(obj) - f.close() - return - raise IOError - - -class Runtime(LinkInterface): - def __init__(self): - self.cpu_type = "or1k" - # allow 1ms for all initial DDS programming - self.warmup_time = 1*units.ms - - def emit_object(self): - tm = llvm.Target.from_triple(self.cpu_type).create_target_machine() - obj = tm.emit_object(self.module.llvm_module_ref) - _debug_dump_obj(obj) - return obj - - def __repr__(self): - return "".format(self.cpu_type) +artiq_root = os.path.join(os.path.dirname(__file__), "..", "..") +source_loader = SourceLoader(os.path.join(artiq_root, "soc", 
"runtime")) diff --git a/artiq/coredevice/runtime_exceptions.py b/artiq/coredevice/runtime_exceptions.py deleted file mode 100644 index e97152e71..000000000 --- a/artiq/coredevice/runtime_exceptions.py +++ /dev/null @@ -1,85 +0,0 @@ -import inspect - -from artiq.language.core import RuntimeException - - -# Must be kept in sync with soc/runtime/exceptions.h - -class InternalError(RuntimeException): - """Raised when the runtime encounters an internal error condition.""" - eid = 1 - - -class _RPCException(RuntimeException): - eid = 2 - - -class RTIOUnderflow(RuntimeException): - """Raised when the CPU fails to submit a RTIO event early enough - (with respect to the event's timestamp). - - The offending event is discarded and the RTIO core keeps operating. - """ - eid = 3 - - def __str__(self): - return "at {} on channel {}, violation {}".format( - self.p0*self.core.ref_period, - self.p1, - (self.p2 - self.p0)*self.core.ref_period) - - -class RTIOSequenceError(RuntimeException): - """Raised when an event is submitted on a given channel with a timestamp - not larger than the previous one. - - The offending event is discarded and the RTIO core keeps operating. - """ - eid = 4 - - def __str__(self): - return "at {} on channel {}".format(self.p0*self.core.ref_period, - self.p1) - -class RTIOCollisionError(RuntimeException): - """Raised when an event is submitted on a given channel with the same - coarse timestamp as the previous one but with a different fine timestamp. - - Coarse timestamps correspond to the RTIO system clock (typically around - 125MHz) whereas fine timestamps correspond to the RTIO SERDES clock - (typically around 1GHz). - - The offending event is discarded and the RTIO core keeps operating. - """ - eid = 5 - - def __str__(self): - return "at {} on channel {}".format(self.p0*self.core.ref_period, - self.p1) - - -class RTIOOverflow(RuntimeException): - """Raised when at least one event could not be registered into the RTIO - input FIFO because it was full (CPU not reading fast enough). - - This does not interrupt operations further than cancelling the current - read attempt and discarding some events. Reading can be reattempted after - the exception is caught, and events will be partially retrieved. - """ - eid = 6 - - def __str__(self): - return "on channel {}".format(self.p0) - - -class DDSBatchError(RuntimeException): - """Raised when attempting to start a DDS batch while already in a batch, - or when too many commands are batched. 
- """ - eid = 7 - - -exception_map = {e.eid: e for e in globals().values() - if inspect.isclass(e) - and issubclass(e, RuntimeException) - and hasattr(e, "eid")} diff --git a/artiq/coredevice/ttl.py b/artiq/coredevice/ttl.py index 6285678af..be410d919 100644 --- a/artiq/coredevice/ttl.py +++ b/artiq/coredevice/ttl.py @@ -1,4 +1,26 @@ from artiq.language.core import * +from artiq.language.types import * + + +@syscall +def ttl_set_o(time_mu: TInt64, channel: TInt32, enabled: TBool) -> TNone: + raise NotImplementedError("syscall not simulated") + +@syscall +def ttl_set_oe(time_mu: TInt64, channel: TInt32, enabled: TBool) -> TNone: + raise NotImplementedError("syscall not simulated") + +@syscall +def ttl_set_sensitivity(time_mu: TInt64, channel: TInt32, sensitivity: TInt32) -> TNone: + raise NotImplementedError("syscall not simulated") + +@syscall +def ttl_get(channel: TInt32, time_limit_mu: TInt64) -> TInt64: + raise NotImplementedError("syscall not simulated") + +@syscall +def ttl_clock_set(time_mu: TInt64, channel: TInt32, ftw: TInt32) -> TNone: + raise NotImplementedError("syscall not simulated") class TTLOut: @@ -13,18 +35,18 @@ class TTLOut: self.channel = channel # in RTIO cycles - self.o_previous_timestamp = int64(0) + self.o_previous_timestamp = int(0, width=64) @kernel def set_o(self, o): - syscall("ttl_set_o", now_mu(), self.channel, o) + ttl_set_o(now_mu(), self.channel, o) self.o_previous_timestamp = now_mu() @kernel def sync(self): """Busy-wait until all programmed level switches have been effected.""" - while syscall("rtio_get_counter") < self.o_previous_timestamp: + while self.core.get_rtio_counter_mu() < self.o_previous_timestamp: pass @kernel @@ -76,12 +98,12 @@ class TTLInOut: self.channel = channel # in RTIO cycles - self.o_previous_timestamp = int64(0) - self.i_previous_timestamp = int64(0) + self.o_previous_timestamp = int(0, width=64) + self.i_previous_timestamp = int(0, width=64) @kernel def set_oe(self, oe): - syscall("ttl_set_oe", now_mu(), self.channel, oe) + ttl_set_oe(now_mu(), self.channel, oe) @kernel def output(self): @@ -95,14 +117,14 @@ class TTLInOut: @kernel def set_o(self, o): - syscall("ttl_set_o", now_mu(), self.channel, o) + ttl_set_o(now_mu(), self.channel, o) self.o_previous_timestamp = now_mu() @kernel def sync(self): """Busy-wait until all programmed level switches have been effected.""" - while syscall("rtio_get_counter") < self.o_previous_timestamp: + while self.core.get_rtio_counter_mu() < self.o_previous_timestamp: pass @kernel @@ -137,7 +159,7 @@ class TTLInOut: @kernel def _set_sensitivity(self, value): - syscall("ttl_set_sensitivity", now_mu(), self.channel, value) + ttl_set_sensitivity(now_mu(), self.channel, value) self.i_previous_timestamp = now_mu() @kernel @@ -193,8 +215,7 @@ class TTLInOut: """Poll the RTIO input during all the previously programmed gate openings, and returns the number of registered events.""" count = 0 - while syscall("ttl_get", self.channel, - self.i_previous_timestamp) >= 0: + while ttl_get(self.channel, self.i_previous_timestamp) >= 0: count += 1 return count @@ -205,7 +226,7 @@ class TTLInOut: If the gate is permanently closed, returns a negative value. 
""" - return syscall("ttl_get", self.channel, self.i_previous_timestamp) + return ttl_get(self.channel, self.i_previous_timestamp) class TTLClockGen: @@ -221,7 +242,7 @@ class TTLClockGen: self.channel = channel # in RTIO cycles - self.previous_timestamp = int64(0) + self.previous_timestamp = int(0, width=64) self.acc_width = 24 @portable @@ -256,7 +277,7 @@ class TTLClockGen: that are not powers of two cause jitter of one RTIO clock cycle at the output. """ - syscall("ttl_clock_set", now_mu(), self.channel, frequency) + ttl_clock_set(now_mu(), self.channel, frequency) self.previous_timestamp = now_mu() @kernel @@ -273,5 +294,5 @@ class TTLClockGen: def sync(self): """Busy-wait until all programmed frequency switches and stops have been effected.""" - while syscall("rtio_get_counter") < self.o_previous_timestamp: + while self.core.get_rtio_counter_mu() < self.o_previous_timestamp: pass diff --git a/artiq/frontend/artiq_compile.py b/artiq/frontend/artiq_compile.py index 78ebff837..50c7d82af 100755 --- a/artiq/frontend/artiq_compile.py +++ b/artiq/frontend/artiq_compile.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3.5 -import logging -import argparse +import sys, logging, argparse from artiq.master.databases import DeviceDB, DatasetDB from artiq.master.worker_db import DeviceManager, DatasetManager +from artiq.coredevice.core import CompileError from artiq.tools import * @@ -40,34 +40,35 @@ def main(): dataset_mgr = DatasetManager(DatasetDB(args.dataset_db)) try: - module = file_import(args.file) + module = file_import(args.file, prefix="artiq_run_") exp = get_experiment(module, args.experiment) arguments = parse_arguments(args.arguments) exp_inst = exp(device_mgr, dataset_mgr, **arguments) - if (not hasattr(exp.run, "k_function_info") - or not exp.run.k_function_info): + if not hasattr(exp.run, "artiq_embedded"): raise ValueError("Experiment entry point must be a kernel") - core_name = exp.run.k_function_info.core_name + core_name = exp.run.artiq_embedded.core_name core = getattr(exp_inst, core_name) - binary, rpc_map, _ = core.compile(exp.run.k_function_info.k_function, - [exp_inst], {}, - with_attr_writeback=False) + object_map, kernel_library, symbolizer = \ + core.compile(exp.run, [exp_inst], {}, + with_attr_writeback=False) + except CompileError as error: + print(error.render_string(colored=True), file=sys.stderr) + return finally: device_mgr.close_devices() - if rpc_map: + if object_map.has_rpc(): raise ValueError("Experiment must not use RPC") output = args.output if output is None: - output = args.file - if output.endswith(".py"): - output = output[:-3] - output += ".elf" + basename, ext = os.path.splitext(args.file) + output = "{}.elf".format(basename) + with open(output, "wb") as f: - f.write(binary) + f.write(kernel_library) if __name__ == "__main__": main() diff --git a/artiq/frontend/artiq_coretool.py b/artiq/frontend/artiq_coretool.py index 53ee9a13b..735eea86c 100755 --- a/artiq/frontend/artiq_coretool.py +++ b/artiq/frontend/artiq_coretool.py @@ -26,7 +26,7 @@ def get_argparser(): # Configuration Read command p_read = subparsers.add_parser("cfg-read", help="read key from core device config") - p_read.add_argument("key", type=to_bytes, + p_read.add_argument("key", type=str, help="key to be read from core device config") # Configuration Write command @@ -34,11 +34,11 @@ def get_argparser(): help="write key-value records to core " "device config") p_write.add_argument("-s", "--string", nargs=2, action="append", - default=[], metavar=("KEY", "STRING"), type=to_bytes, + default=[], 
metavar=("KEY", "STRING"), type=str, help="key-value records to be written to core device " "config") p_write.add_argument("-f", "--file", nargs=2, action="append", - type=to_bytes, default=[], + type=str, default=[], metavar=("KEY", "FILENAME"), help="key and file whose content to be written to " "core device config") @@ -47,7 +47,7 @@ def get_argparser(): p_delete = subparsers.add_parser("cfg-delete", help="delete key from core device config") p_delete.add_argument("key", nargs=argparse.REMAINDER, - default=[], type=to_bytes, + default=[], type=str, help="key to be deleted from core device config") # Configuration Erase command @@ -61,9 +61,10 @@ def main(): device_mgr = DeviceManager(DeviceDB(args.device_db)) try: comm = device_mgr.get("comm") + comm.check_ident() if args.action == "log": - print(comm.get_log()) + print(comm.get_log(), end='') elif args.action == "cfg-read": value = comm.flash_storage_read(args.key) if not value: @@ -72,7 +73,7 @@ def main(): print(value) elif args.action == "cfg-write": for key, value in args.string: - comm.flash_storage_write(key, value) + comm.flash_storage_write(key, value.encode("utf-8")) for key, filename in args.file: with open(filename, "rb") as fi: comm.flash_storage_write(key, fi.read()) diff --git a/artiq/frontend/artiq_run.py b/artiq/frontend/artiq_run.py index fa65a6a0a..b20e7de94 100755 --- a/artiq/frontend/artiq_run.py +++ b/artiq/frontend/artiq_run.py @@ -12,9 +12,11 @@ import h5py from artiq.language.environment import EnvExperiment from artiq.master.databases import DeviceDB, DatasetDB from artiq.master.worker_db import DeviceManager, DatasetManager +from artiq.coredevice.core import CompileError +from artiq.compiler.embedding import ObjectMap +from artiq.compiler.targets import OR1KTarget from artiq.tools import * - logger = logging.getLogger(__name__) @@ -25,9 +27,13 @@ class ELFRunner(EnvExperiment): def run(self): with open(self.file, "rb") as f: - self.core.comm.load(f.read()) - self.core.comm.run("run") - self.core.comm.serve(dict(), dict()) + kernel_library = f.read() + + target = OR1KTarget() + self.core.comm.load(kernel_library) + self.core.comm.run() + self.core.comm.serve(ObjectMap(), + lambda addresses: target.symbolize(kernel_library, addresses)) class DummyScheduler: @@ -92,7 +98,7 @@ def _build_experiment(device_mgr, dataset_mgr, args): "for ELF kernels") return ELFRunner(device_mgr, dataset_mgr, file=args.file) else: - module = file_import(args.file) + module = file_import(args.file, prefix="artiq_run_") file = args.file else: module = sys.modules["__main__"] @@ -122,6 +128,9 @@ def run(with_file=False): exp_inst.prepare() exp_inst.run() exp_inst.analyze() + except CompileError as error: + print(error.render_string(colored=True), file=sys.stderr) + return finally: device_mgr.close_devices() diff --git a/artiq/gateware/targets/pipistrello.py b/artiq/gateware/targets/pipistrello.py index aa6957d61..07c682990 100755 --- a/artiq/gateware/targets/pipistrello.py +++ b/artiq/gateware/targets/pipistrello.py @@ -131,12 +131,7 @@ trce -v 12 -fastpaths -tsi {build_name}.tsi -o {build_name}.twr {build_name}.ncd """ platform.add_extension(nist_qc1.papilio_adapter_io) - self.submodules.leds = gpio.GPIOOut(Cat( - platform.request("user_led", 0), - platform.request("user_led", 1), - platform.request("user_led", 2), - platform.request("user_led", 3), - )) + self.submodules.leds = gpio.GPIOOut(platform.request("user_led", 4)) self.comb += [ platform.request("ttl_l_tx_en").eq(1), @@ -173,9 +168,10 @@ trce -v 12 -fastpaths -tsi 
{build_name}.tsi -o {build_name}.twr {build_name}.ncd self.submodules += phy rtio_channels.append(rtio.Channel.from_phy(phy, ofifo_depth=4)) - phy = ttl_simple.Output(platform.request("user_led", 4)) - self.submodules += phy - rtio_channels.append(rtio.Channel.from_phy(phy, ofifo_depth=4)) + for led_number in range(4): + phy = ttl_simple.Output(platform.request("user_led", led_number)) + self.submodules += phy + rtio_channels.append(rtio.Channel.from_phy(phy, ofifo_depth=4)) self.add_constant("RTIO_REGULAR_TTL_COUNT", len(rtio_channels)) diff --git a/artiq/language/__init__.py b/artiq/language/__init__.py index 39edc6164..763babfbd 100644 --- a/artiq/language/__init__.py +++ b/artiq/language/__init__.py @@ -1,7 +1,8 @@ # Copyright (C) 2014, 2015 Robert Jordens -from artiq.language import core, environment, units, scan +from artiq.language import core, types, environment, units, scan from artiq.language.core import * +from artiq.language.types import * from artiq.language.environment import * from artiq.language.units import * from artiq.language.scan import * @@ -9,6 +10,7 @@ from artiq.language.scan import * __all__ = [] __all__.extend(core.__all__) +__all__.extend(types.__all__) __all__.extend(environment.__all__) __all__.extend(units.__all__) __all__.extend(scan.__all__) diff --git a/artiq/language/core.py b/artiq/language/core.py index da362a0e5..b588d7887 100644 --- a/artiq/language/core.py +++ b/artiq/language/core.py @@ -2,92 +2,168 @@ Core ARTIQ extensions to the Python language. """ +import os, linecache, re from collections import namedtuple from functools import wraps +# for runtime files in backtraces +from artiq.coredevice.runtime import source_loader -__all__ = ["int64", "round64", "TerminationRequested", - "kernel", "portable", - "set_time_manager", "set_syscall_manager", "set_watchdog_factory", - "RuntimeException", "EncodedException"] + +__all__ = ["host_int", "int", + "kernel", "portable", "syscall", + "set_time_manager", "set_watchdog_factory", + "ARTIQException", + "TerminationRequested"] # global namespace for kernels -kernel_globals = ("sequential", "parallel", +kernel_globals = ( + "sequential", "parallel", "delay_mu", "now_mu", "at_mu", "delay", "seconds_to_mu", "mu_to_seconds", - "syscall", "watchdog") + "watchdog" +) __all__.extend(kernel_globals) +host_int = int -class int64(int): - """64-bit integers for static compilation. +class int: + """ + Arbitrary-precision integers for static compilation. - When this class is used instead of Python's ``int``, the static compiler - stores the corresponding variable on 64 bits instead of 32. + The static compiler does not use unlimited-precision integers, + like Python normally does, because of their unbounded memory requirements. + Instead, it allows to choose a bit width (usually 32 or 64) at compile-time, + and all computations follow wrap-around semantics on overflow. - When used in the interpreter, it behaves as ``int`` and the results of - integer operations involving it are also ``int64`` (which matches the - size promotion rules of the static compiler). This way, it is possible to - specify 64-bit size annotations on constants that are passed to the - kernels. + This class implements the same semantics on the host. 
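As a quick illustration of the wrap-around semantics described above, using only names introduced by this patch (a sketch, not part of the diff):

    from artiq.language.core import int, host_int

    a = int(0x7fffffff, width=32)    # largest positive 32-bit value
    b = a + 1                        # wraps instead of growing like a host int
    assert isinstance(b, int)
    assert host_int(b) == -2**31     # two's-complement overflow
    assert repr(b) == "int(-2147483648, width=32)"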
- Example: + For example: - >>> a = int64(1) - >>> b = int64(3) + 2 - >>> isinstance(a, int64) + >>> a = int(1, width=64) + >>> b = int(3, width=64) + 2 + >>> isinstance(a, int) True - >>> isinstance(b, int64) + >>> isinstance(b, int) True >>> a + b - 6 + int(6, width=64) + >>> int(10, width=32) + 0x7fffffff + int(9, width=32) + >>> int(0x80000000) + int(-2147483648, width=32) """ - pass -def _make_int64_op_method(int_method): - def method(self, *args): - r = int_method(self, *args) - if isinstance(r, int): - r = int64(r) - return r - return method + __slots__ = ['_value', '_width'] -for _op_name in ("neg", "pos", "abs", "invert", "round", - "add", "radd", "sub", "rsub", "mul", "rmul", "pow", "rpow", - "lshift", "rlshift", "rshift", "rrshift", - "and", "rand", "xor", "rxor", "or", "ror", - "floordiv", "rfloordiv", "mod", "rmod"): - _method_name = "__" + _op_name + "__" - _orig_method = getattr(int, _method_name) - setattr(int64, _method_name, _make_int64_op_method(_orig_method)) + def __new__(cls, value, width=32): + if isinstance(value, int): + return value + else: + sign_bit = 2 ** (width - 1) + value = host_int(value) + if value & sign_bit: + value = -1 & ~sign_bit + (value & (sign_bit - 1)) + 1 + else: + value &= sign_bit - 1 -for _op_name in ("add", "sub", "mul", "floordiv", "mod", - "pow", "lshift", "rshift", "lshift", - "and", "xor", "or"): - _op_method = getattr(int, "__" + _op_name + "__") - setattr(int64, "__i" + _op_name + "__", _make_int64_op_method(_op_method)) + self = super().__new__(cls) + self._value = value + self._width = width + return self + + @property + def width(self): + return self._width + + def __int__(self): + return self._value + + def __float__(self): + return float(self._value) + + def __str__(self): + return str(self._value) + + def __repr__(self): + return "int({}, width={})".format(self._value, self._width) + + def _unaryop(lower_fn): + def operator(self): + return int(lower_fn(self._value), self._width) + return operator + + __neg__ = _unaryop(host_int.__neg__) + __pos__ = _unaryop(host_int.__pos__) + __abs__ = _unaryop(host_int.__abs__) + __invert__ = _unaryop(host_int.__invert__) + __round__ = _unaryop(host_int.__round__) + + def _binaryop(lower_fn, rlower_fn=None): + def operator(self, other): + if isinstance(other, host_int): + return int(lower_fn(self._value, other), self._width) + elif isinstance(other, int): + width = self._width if self._width > other._width else other._width + return int(lower_fn(self._value, other._value), width) + elif rlower_fn: + return getattr(other, rlower_fn)(self._value) + else: + return NotImplemented + return operator + + __add__ = __iadd__ = _binaryop(host_int.__add__, "__radd__") + __sub__ = __isub__ = _binaryop(host_int.__sub__, "__rsub__") + __mul__ = __imul__ = _binaryop(host_int.__mul__, "__rmul__") + __truediv__ = __itruediv__ = _binaryop(host_int.__truediv__, "__rtruediv__") + __floordiv__ = __ifloordiv__ = _binaryop(host_int.__floordiv__, "__rfloordiv__") + __mod__ = __imod__ = _binaryop(host_int.__mod__, "__rmod__") + __pow__ = __ipow__ = _binaryop(host_int.__pow__, "__rpow__") + + __radd__ = _binaryop(host_int.__radd__, "__add__") + __rsub__ = _binaryop(host_int.__rsub__, "__sub__") + __rmul__ = _binaryop(host_int.__rmul__, "__mul__") + __rfloordiv__ = _binaryop(host_int.__rfloordiv__, "__floordiv__") + __rtruediv__ = _binaryop(host_int.__rtruediv__, "__truediv__") + __rmod__ = _binaryop(host_int.__rmod__, "__mod__") + __rpow__ = _binaryop(host_int.__rpow__, "__pow__") + + __lshift__ = __ilshift__ = 
_binaryop(host_int.__lshift__) + __rshift__ = __irshift__ = _binaryop(host_int.__rshift__) + __and__ = __iand__ = _binaryop(host_int.__and__) + __or__ = __ior__ = _binaryop(host_int.__or__) + __xor__ = __ixor__ = _binaryop(host_int.__xor__) + + __rlshift__ = _binaryop(host_int.__rlshift__) + __rrshift__ = _binaryop(host_int.__rrshift__) + __rand__ = _binaryop(host_int.__rand__) + __ror__ = _binaryop(host_int.__ror__) + __rxor__ = _binaryop(host_int.__rxor__) + + def _compareop(lower_fn, rlower_fn): + def operator(self, other): + if isinstance(other, host_int): + return lower_fn(self._value, other) + elif isinstance(other, int): + return lower_fn(self._value, other._value) + else: + return getattr(other, rlower_fn)(self._value) + return operator + + __eq__ = _compareop(host_int.__eq__, "__ne__") + __ne__ = _compareop(host_int.__ne__, "__eq__") + __gt__ = _compareop(host_int.__gt__, "__le__") + __ge__ = _compareop(host_int.__ge__, "__lt__") + __lt__ = _compareop(host_int.__lt__, "__ge__") + __le__ = _compareop(host_int.__le__, "__gt__") -def round64(x): - """Rounds to a 64-bit integer. - - This function is equivalent to ``int64(round(x))`` but, when targeting - static compilation, prevents overflow when the rounded value is too large - to fit in a 32-bit integer. - """ - return int64(round(x)) - - -class TerminationRequested(Exception): - """Raised by ``pause`` when the user has requested termination.""" - pass - - -_KernelFunctionInfo = namedtuple("_KernelFunctionInfo", "core_name k_function") - +_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo", + "core_name function syscall") def kernel(arg): - """This decorator marks an object's method for execution on the core + """ + This decorator marks an object's method for execution on the core device. When a decorated method is called from the Python interpreter, the ``core`` @@ -106,26 +182,20 @@ def kernel(arg): specifies the name of the attribute to use as core device driver. """ if isinstance(arg, str): - def real_decorator(k_function): - @wraps(k_function) - def run_on_core(exp, *k_args, **k_kwargs): - return getattr(exp, arg).run(k_function, - ((exp,) + k_args), k_kwargs) - run_on_core.k_function_info = _KernelFunctionInfo( - core_name=arg, k_function=k_function) + def inner_decorator(function): + @wraps(function) + def run_on_core(self, *k_args, **k_kwargs): + return getattr(self, arg).run(run_on_core, ((self,) + k_args), k_kwargs) + run_on_core.artiq_embedded = _ARTIQEmbeddedInfo( + core_name=arg, function=function, syscall=None) return run_on_core - return real_decorator + return inner_decorator else: - @wraps(arg) - def run_on_core(exp, *k_args, **k_kwargs): - return exp.core.run(arg, ((exp,) + k_args), k_kwargs) - run_on_core.k_function_info = _KernelFunctionInfo( - core_name="core", k_function=arg) - return run_on_core + return kernel("core")(arg) - -def portable(f): - """This decorator marks a function for execution on the same device as its +def portable(function): + """ + This decorator marks a function for execution on the same device as its caller. In other words, a decorated function called from the interpreter on the @@ -133,8 +203,30 @@ def portable(f): core device). A decorated function called from a kernel will be executed on the core device (no RPC). 
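A minimal sketch of how the reworked decorators fit together; the Flasher class, its device attributes and the pulse() call are assumptions for illustration:

    from artiq.language.core import kernel, portable

    class Flasher:
        def __init__(self, core, led):
            self.core = core    # "core" is the attribute @kernel looks up by default
            self.led = led      # e.g. a TTLOut channel (assumed)

        @portable
        def half(self, duration):
            # runs on whichever side calls it, host or core device, without an RPC
            return duration/2

        @kernel
        def flash(self, duration):
            # when called from the host, this method is handed to self.core.run(...)
            self.led.pulse(self.half(duration))

Passing a string instead, as in @kernel("core2"), selects a different attribute as the core device driver.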
""" - f.k_function_info = _KernelFunctionInfo(core_name="", k_function=f) - return f + function.artiq_embedded = \ + _ARTIQEmbeddedInfo(core_name=None, function=function, syscall=None) + return function + +def syscall(arg): + """ + This decorator marks a function as a system call. When executed on a core + device, a C function with the provided name (or the same name as + the Python function, if not provided) will be called. When executed on + host, the Python function will be called as usual. + + Every argument and the return value must be annotated with ARTIQ types. + + Only drivers should normally define syscalls. + """ + if isinstance(arg, str): + def inner_decorator(function): + function.artiq_embedded = \ + _ARTIQEmbeddedInfo(core_name=None, function=None, + syscall=function.__name__) + return function + return inner_decorator + else: + return syscall(arg.__name__)(arg) class _DummyTimeManager: @@ -163,22 +255,6 @@ def set_time_manager(time_manager): _time_manager = time_manager -class _DummySyscallManager: - def do(self, *args): - raise NotImplementedError( - "Attempted to interpret kernel without a syscall manager") - -_syscall_manager = _DummySyscallManager() - - -def set_syscall_manager(syscall_manager): - """Set the system call manager used for simulating the core device's - runtime in the Python interpreter. - """ - global _syscall_manager - _syscall_manager = syscall_manager - - class _Sequential: """In a sequential block, statements are executed one after another, with the time increasing as one moves down the statement list.""" @@ -251,17 +327,6 @@ def mu_to_seconds(mu, core=None): return mu*core.ref_period -def syscall(*args): - """Invokes a service of the runtime. - - Kernels use this function to interface to the outside world: program RTIO - events, make RPCs, etc. - - Only drivers should normally use ``syscall``. - """ - return _syscall_manager.do(*args) - - class _DummyWatchdog: def __init__(self, timeout): pass @@ -286,32 +351,70 @@ def watchdog(timeout): return _watchdog_factory(timeout) -_encoded_exceptions = dict() +class TerminationRequested(Exception): + """Raised by ``pause`` when the user has requested termination.""" + pass -def EncodedException(eid): - """Represents exceptions on the core device, which are identified - by a single number.""" - try: - return _encoded_exceptions[eid] - except KeyError: - class EncodedException(Exception): - def __init__(self): - Exception.__init__(self, eid) - _encoded_exceptions[eid] = EncodedException - return EncodedException +class ARTIQException(Exception): + """Base class for exceptions raised or passed through the core device.""" + # Try and create an instance of the specific class, if one exists. + def __new__(cls, name, message, params, traceback): + def find_subclass(cls): + if cls.__name__ == name: + return cls + else: + for subclass in cls.__subclasses__(): + cls = find_subclass(subclass) + if cls is not None: + return cls -class RuntimeException(Exception): - """Base class for all exceptions used by the device runtime. - Those exceptions are defined in ``artiq.coredevice.runtime_exceptions``. 
- """ - def __init__(self, core, p0, p1, p2): - Exception.__init__(self) - self.core = core - self.p0 = p0 - self.p1 = p1 - self.p2 = p2 + more_specific_cls = find_subclass(cls) + if more_specific_cls is None: + more_specific_cls = cls + exn = Exception.__new__(more_specific_cls) + exn.__init__(name, message, params, traceback) + return exn -first_user_eid = 1024 + def __init__(self, name, message, params, traceback): + Exception.__init__(self, name, message, *params) + self.name, self.message, self.params = name, message, params + self.traceback = list(traceback) + + def __str__(self): + lines = [] + + if type(self).__name__ == self.name: + lines.append(self.message.format(*self.params)) + else: + lines.append("({}) {}".format(self.name, self.message.format(*self.params))) + + lines.append("Core Device Traceback (most recent call last):") + for (filename, line, column, function, address) in self.traceback: + stub_globals = {"__name__": filename, "__loader__": source_loader} + source_line = linecache.getline(filename, line, stub_globals) + indentation = re.search(r"^\s*", source_line).end() + + if address is None: + formatted_address = "" + else: + formatted_address = " (RA=0x{:x})".format(address) + + filename = filename.replace(os.path.normpath(os.path.join(os.path.dirname(__file__), + "..")), "") + if column == -1: + lines.append(" File \"{file}\", line {line}, in {function}{address}". + format(file=filename, line=line, function=function, + address=formatted_address)) + lines.append(" {}".format(source_line.strip() if source_line else "")) + else: + lines.append(" File \"{file}\", line {line}, column {column}," + " in {function}{address}". + format(file=filename, line=line, column=column + 1, + function=function, address=formatted_address)) + lines.append(" {}".format(source_line.strip() if source_line else "")) + lines.append(" {}^".format(" " * (column - indentation))) + + return "\n".join(lines) diff --git a/artiq/language/types.py b/artiq/language/types.py new file mode 100644 index 000000000..66a8b1b89 --- /dev/null +++ b/artiq/language/types.py @@ -0,0 +1,19 @@ +""" +Values representing ARTIQ types, to be used in function type +annotations. 
+""" + +from artiq.compiler import types, builtins + +__all__ = ["TNone", "TBool", "TInt32", "TInt64", "TFloat", + "TStr", "TList", "TRange32", "TRange64"] + +TNone = builtins.TNone() +TBool = builtins.TBool() +TInt32 = builtins.TInt(types.TValue(32)) +TInt64 = builtins.TInt(types.TValue(64)) +TFloat = builtins.TFloat() +TStr = builtins.TStr() +TList = builtins.TList +TRange32 = builtins.TRange(builtins.TInt(types.TValue(32))) +TRange64 = builtins.TRange(builtins.TInt(types.TValue(64))) diff --git a/artiq/py2llvm/__init__.py b/artiq/py2llvm/__init__.py deleted file mode 100644 index ebb8a93af..000000000 --- a/artiq/py2llvm/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from artiq.py2llvm.module import Module - -def get_runtime_binary(runtime, func_def): - module = Module(runtime) - module.compile_function(func_def, dict()) - return module.emit_object() diff --git a/artiq/py2llvm/ast_body.py b/artiq/py2llvm/ast_body.py deleted file mode 100644 index 17b08c861..000000000 --- a/artiq/py2llvm/ast_body.py +++ /dev/null @@ -1,539 +0,0 @@ -import ast - -import llvmlite_artiq.ir as ll - -from artiq.py2llvm import values, base_types, fractions, lists, iterators -from artiq.py2llvm.tools import is_terminated - - -_ast_unops = { - ast.Invert: "o_inv", - ast.Not: "o_not", - ast.UAdd: "o_pos", - ast.USub: "o_neg" -} - -_ast_binops = { - ast.Add: values.operators.add, - ast.Sub: values.operators.sub, - ast.Mult: values.operators.mul, - ast.Div: values.operators.truediv, - ast.FloorDiv: values.operators.floordiv, - ast.Mod: values.operators.mod, - ast.Pow: values.operators.pow, - ast.LShift: values.operators.lshift, - ast.RShift: values.operators.rshift, - ast.BitOr: values.operators.or_, - ast.BitXor: values.operators.xor, - ast.BitAnd: values.operators.and_ -} - -_ast_cmps = { - ast.Eq: values.operators.eq, - ast.NotEq: values.operators.ne, - ast.Lt: values.operators.lt, - ast.LtE: values.operators.le, - ast.Gt: values.operators.gt, - ast.GtE: values.operators.ge -} - - -class Visitor: - def __init__(self, runtime, ns, builder=None): - self.runtime = runtime - self.ns = ns - self.builder = builder - self._break_stack = [] - self._continue_stack = [] - self._active_exception_stack = [] - self._exception_level_stack = [0] - - # builder can be None for visit_expression - def visit_expression(self, node): - method = "_visit_expr_" + node.__class__.__name__ - try: - visitor = getattr(self, method) - except AttributeError: - raise NotImplementedError("Unsupported node '{}' in expression" - .format(node.__class__.__name__)) - return visitor(node) - - def _visit_expr_Name(self, node): - try: - r = self.ns[node.id] - except KeyError: - raise NameError("Name '{}' is not defined".format(node.id)) - return r - - def _visit_expr_NameConstant(self, node): - v = node.value - if v is None: - r = base_types.VNone() - elif isinstance(v, bool): - r = base_types.VBool() - else: - raise NotImplementedError - if self.builder is not None: - r.set_const_value(self.builder, v) - return r - - def _visit_expr_Num(self, node): - n = node.n - if isinstance(n, int): - if abs(n) < 2**31: - r = base_types.VInt() - else: - r = base_types.VInt(64) - elif isinstance(n, float): - r = base_types.VFloat() - else: - raise NotImplementedError - if self.builder is not None: - r.set_const_value(self.builder, n) - return r - - def _visit_expr_UnaryOp(self, node): - value = self.visit_expression(node.operand) - return getattr(value, _ast_unops[type(node.op)])(self.builder) - - def _visit_expr_BinOp(self, node): - return 
_ast_binops[type(node.op)](self.visit_expression(node.left), - self.visit_expression(node.right), - self.builder) - - def _visit_expr_BoolOp(self, node): - if self.builder is not None: - initial_block = self.builder.basic_block - function = initial_block.function - merge_block = function.append_basic_block("b_merge") - - test_blocks = [] - test_values = [] - for i, value in enumerate(node.values): - if self.builder is not None: - test_block = function.append_basic_block("b_{}_test".format(i)) - test_blocks.append(test_block) - self.builder.position_at_end(test_block) - test_values.append(self.visit_expression(value)) - - result = test_values[0].new() - for value in test_values[1:]: - result.merge(value) - - if self.builder is not None: - self.builder.position_at_end(initial_block) - result.alloca(self.builder, "b_result") - self.builder.branch(test_blocks[0]) - - next_test_blocks = test_blocks[1:] - next_test_blocks.append(None) - for block, next_block, value in zip(test_blocks, - next_test_blocks, - test_values): - self.builder.position_at_end(block) - bval = value.o_bool(self.builder) - result.auto_store(self.builder, - value.auto_load(self.builder)) - if next_block is None: - self.builder.branch(merge_block) - else: - if isinstance(node.op, ast.Or): - self.builder.cbranch(bval.auto_load(self.builder), - merge_block, - next_block) - elif isinstance(node.op, ast.And): - self.builder.cbranch(bval.auto_load(self.builder), - next_block, - merge_block) - else: - raise NotImplementedError - self.builder.position_at_end(merge_block) - - return result - - def _visit_expr_Compare(self, node): - comparisons = [] - old_comparator = self.visit_expression(node.left) - for op, comparator_a in zip(node.ops, node.comparators): - comparator = self.visit_expression(comparator_a) - comparison = _ast_cmps[type(op)](old_comparator, comparator, - self.builder) - comparisons.append(comparison) - old_comparator = comparator - r = comparisons[0] - for comparison in comparisons[1:]: - r = values.operators.and_(r, comparison) - return r - - def _visit_expr_Call(self, node): - fn = node.func.id - if fn in {"bool", "int", "int64", "round", "round64", "float", "len"}: - value = self.visit_expression(node.args[0]) - return getattr(value, "o_" + fn)(self.builder) - elif fn == "Fraction": - r = fractions.VFraction() - if self.builder is not None: - numerator = self.visit_expression(node.args[0]) - denominator = self.visit_expression(node.args[1]) - r.set_value_nd(self.builder, numerator, denominator) - return r - elif fn == "range": - return iterators.IRange( - self.builder, - [self.visit_expression(arg) for arg in node.args]) - elif fn == "syscall": - return self.runtime.build_syscall( - node.args[0].s, - [self.visit_expression(expr) for expr in node.args[1:]], - self.builder) - else: - raise NameError("Function '{}' is not defined".format(fn)) - - def _visit_expr_Attribute(self, node): - value = self.visit_expression(node.value) - return value.o_getattr(node.attr, self.builder) - - def _visit_expr_List(self, node): - elts = [self.visit_expression(elt) for elt in node.elts] - if elts: - el_type = elts[0].new() - for elt in elts[1:]: - el_type.merge(elt) - else: - el_type = base_types.VNone() - count = len(elts) - r = lists.VList(el_type, count) - r.elts = elts - return r - - def _visit_expr_ListComp(self, node): - if len(node.generators) != 1: - raise NotImplementedError - generator = node.generators[0] - if not isinstance(generator, ast.comprehension): - raise NotImplementedError - if not isinstance(generator.target, 
ast.Name): - raise NotImplementedError - target = generator.target.id - if not isinstance(generator.iter, ast.Call): - raise NotImplementedError - if not isinstance(generator.iter.func, ast.Name): - raise NotImplementedError - if generator.iter.func.id != "range": - raise NotImplementedError - if len(generator.iter.args) != 1: - raise NotImplementedError - if not isinstance(generator.iter.args[0], ast.Num): - raise NotImplementedError - count = generator.iter.args[0].n - - # Prevent incorrect use of the generator target, if it is defined in - # the local function namespace. - if target in self.ns: - old_target_val = self.ns[target] - del self.ns[target] - else: - old_target_val = None - elt = self.visit_expression(node.elt) - if old_target_val is not None: - self.ns[target] = old_target_val - - el_type = elt.new() - r = lists.VList(el_type, count) - r.elt = elt - return r - - def _visit_expr_Subscript(self, node): - value = self.visit_expression(node.value) - if isinstance(node.slice, ast.Index): - index = self.visit_expression(node.slice.value) - else: - raise NotImplementedError - return value.o_subscript(index, self.builder) - - def visit_statements(self, stmts): - for node in stmts: - node_type = node.__class__.__name__ - method = "_visit_stmt_" + node_type - try: - visitor = getattr(self, method) - except AttributeError: - raise NotImplementedError("Unsupported node '{}' in statement" - .format(node_type)) - visitor(node) - if node_type in ("Return", "Break", "Continue"): - break - - def _bb_terminated(self): - return is_terminated(self.builder.basic_block) - - def _visit_stmt_Assign(self, node): - val = self.visit_expression(node.value) - if isinstance(node.value, ast.List): - if len(node.targets) > 1: - raise NotImplementedError - target = self.visit_expression(node.targets[0]) - target.set_count(self.builder, val.alloc_count) - for i, elt in enumerate(val.elts): - idx = base_types.VInt() - idx.set_const_value(self.builder, i) - target.o_subscript(idx, self.builder).set_value(self.builder, - elt) - elif isinstance(node.value, ast.ListComp): - if len(node.targets) > 1: - raise NotImplementedError - target = self.visit_expression(node.targets[0]) - target.set_count(self.builder, val.alloc_count) - - i = base_types.VInt() - i.alloca(self.builder) - i.auto_store(self.builder, ll.Constant(ll.IntType(32), 0)) - - function = self.builder.basic_block.function - copy_block = function.append_basic_block("ai_copy") - end_block = function.append_basic_block("ai_end") - self.builder.branch(copy_block) - - self.builder.position_at_end(copy_block) - target.o_subscript(i, self.builder).set_value(self.builder, - val.elt) - i.auto_store(self.builder, self.builder.add( - i.auto_load(self.builder), - ll.Constant(ll.IntType(32), 1))) - cont = self.builder.icmp_signed( - "<", i.auto_load(self.builder), - ll.Constant(ll.IntType(32), val.alloc_count)) - self.builder.cbranch(cont, copy_block, end_block) - - self.builder.position_at_end(end_block) - else: - for target in node.targets: - target = self.visit_expression(target) - target.set_value(self.builder, val) - - def _visit_stmt_AugAssign(self, node): - target = self.visit_expression(node.target) - right = self.visit_expression(node.value) - val = _ast_binops[type(node.op)](target, right, self.builder) - target.set_value(self.builder, val) - - def _visit_stmt_Expr(self, node): - self.visit_expression(node.value) - - def _visit_stmt_If(self, node): - function = self.builder.basic_block.function - then_block = function.append_basic_block("i_then") - 
else_block = function.append_basic_block("i_else") - merge_block = function.append_basic_block("i_merge") - - condition = self.visit_expression(node.test).o_bool(self.builder) - self.builder.cbranch(condition.auto_load(self.builder), - then_block, else_block) - - self.builder.position_at_end(then_block) - self.visit_statements(node.body) - if not self._bb_terminated(): - self.builder.branch(merge_block) - - self.builder.position_at_end(else_block) - self.visit_statements(node.orelse) - if not self._bb_terminated(): - self.builder.branch(merge_block) - - self.builder.position_at_end(merge_block) - - def _enter_loop_body(self, break_block, continue_block): - self._break_stack.append(break_block) - self._continue_stack.append(continue_block) - self._exception_level_stack.append(0) - - def _leave_loop_body(self): - self._exception_level_stack.pop() - self._continue_stack.pop() - self._break_stack.pop() - - def _visit_stmt_While(self, node): - function = self.builder.basic_block.function - - body_block = function.append_basic_block("w_body") - else_block = function.append_basic_block("w_else") - condition = self.visit_expression(node.test).o_bool(self.builder) - self.builder.cbranch( - condition.auto_load(self.builder), body_block, else_block) - - continue_block = function.append_basic_block("w_continue") - merge_block = function.append_basic_block("w_merge") - self.builder.position_at_end(body_block) - self._enter_loop_body(merge_block, continue_block) - self.visit_statements(node.body) - self._leave_loop_body() - if not self._bb_terminated(): - self.builder.branch(continue_block) - - self.builder.position_at_end(continue_block) - condition = self.visit_expression(node.test).o_bool(self.builder) - self.builder.cbranch( - condition.auto_load(self.builder), body_block, merge_block) - - self.builder.position_at_end(else_block) - self.visit_statements(node.orelse) - if not self._bb_terminated(): - self.builder.branch(merge_block) - - self.builder.position_at_end(merge_block) - - def _visit_stmt_For(self, node): - function = self.builder.basic_block.function - - it = self.visit_expression(node.iter) - target = self.visit_expression(node.target) - itval = it.get_value_ptr() - - body_block = function.append_basic_block("f_body") - else_block = function.append_basic_block("f_else") - cont = it.o_next(self.builder) - self.builder.cbranch( - cont.auto_load(self.builder), body_block, else_block) - - continue_block = function.append_basic_block("f_continue") - merge_block = function.append_basic_block("f_merge") - self.builder.position_at_end(body_block) - target.set_value(self.builder, itval) - self._enter_loop_body(merge_block, continue_block) - self.visit_statements(node.body) - self._leave_loop_body() - if not self._bb_terminated(): - self.builder.branch(continue_block) - - self.builder.position_at_end(continue_block) - cont = it.o_next(self.builder) - self.builder.cbranch( - cont.auto_load(self.builder), body_block, merge_block) - - self.builder.position_at_end(else_block) - self.visit_statements(node.orelse) - if not self._bb_terminated(): - self.builder.branch(merge_block) - - self.builder.position_at_end(merge_block) - - def _break_loop_body(self, target_block): - exception_levels = self._exception_level_stack[-1] - if exception_levels: - self.runtime.build_pop(self.builder, exception_levels) - self.builder.branch(target_block) - - def _visit_stmt_Break(self, node): - self._break_loop_body(self._break_stack[-1]) - - def _visit_stmt_Continue(self, node): - 
self._break_loop_body(self._continue_stack[-1]) - - def _visit_stmt_Return(self, node): - if node.value is None: - val = base_types.VNone() - else: - val = self.visit_expression(node.value) - exception_levels = sum(self._exception_level_stack) - if exception_levels: - self.runtime.build_pop(self.builder, exception_levels) - if isinstance(val, base_types.VNone): - self.builder.ret_void() - else: - self.builder.ret(val.auto_load(self.builder)) - - def _visit_stmt_Pass(self, node): - pass - - def _visit_stmt_Raise(self, node): - if self._active_exception_stack: - finally_block, propagate, propagate_eid = ( - self._active_exception_stack[-1]) - self.builder.store(ll.Constant(ll.IntType(1), 1), propagate) - if node.exc is not None: - eid = ll.Constant(ll.IntType(32), node.exc.args[0].n) - self.builder.store(eid, propagate_eid) - self.builder.branch(finally_block) - else: - eid = ll.Constant(ll.IntType(32), node.exc.args[0].n) - self.runtime.build_raise(self.builder, eid) - - def _handle_exception(self, function, finally_block, - propagate, propagate_eid, handlers): - eid = self.runtime.build_getid(self.builder) - self._active_exception_stack.append( - (finally_block, propagate, propagate_eid)) - self.builder.store(ll.Constant(ll.IntType(1), 1), propagate) - self.builder.store(eid, propagate_eid) - - for handler in handlers: - handled_exc_block = function.append_basic_block("try_exc_h") - cont_exc_block = function.append_basic_block("try_exc_c") - if handler.type is None: - self.builder.branch(handled_exc_block) - else: - if isinstance(handler.type, ast.Tuple): - match = self.builder.icmp_signed( - "==", eid, - ll.Constant(ll.IntType(32), - handler.type.elts[0].args[0].n)) - for elt in handler.type.elts[1:]: - match = self.builder.or_( - match, - self.builder.icmp_signed( - "==", eid, - ll.Constant(ll.IntType(32), elt.args[0].n))) - else: - match = self.builder.icmp_signed( - "==", eid, - ll.Constant(ll.IntType(32), handler.type.args[0].n)) - self.builder.cbranch(match, handled_exc_block, cont_exc_block) - self.builder.position_at_end(handled_exc_block) - self.builder.store(ll.Constant(ll.IntType(1), 0), propagate) - self.visit_statements(handler.body) - if not self._bb_terminated(): - self.builder.branch(finally_block) - self.builder.position_at_end(cont_exc_block) - self.builder.branch(finally_block) - - self._active_exception_stack.pop() - - def _visit_stmt_Try(self, node): - function = self.builder.basic_block.function - noexc_block = function.append_basic_block("try_noexc") - exc_block = function.append_basic_block("try_exc") - finally_block = function.append_basic_block("try_finally") - - propagate = self.builder.alloca(ll.IntType(1), - name="propagate") - self.builder.store(ll.Constant(ll.IntType(1), 0), propagate) - propagate_eid = self.builder.alloca(ll.IntType(32), - name="propagate_eid") - exception_occured = self.runtime.build_catch(self.builder) - self.builder.cbranch(exception_occured, exc_block, noexc_block) - - self.builder.position_at_end(noexc_block) - self._exception_level_stack[-1] += 1 - self.visit_statements(node.body) - self._exception_level_stack[-1] -= 1 - if not self._bb_terminated(): - self.runtime.build_pop(self.builder, 1) - self.visit_statements(node.orelse) - if not self._bb_terminated(): - self.builder.branch(finally_block) - self.builder.position_at_end(exc_block) - self._handle_exception(function, finally_block, - propagate, propagate_eid, node.handlers) - - propagate_block = function.append_basic_block("try_propagate") - merge_block = 
function.append_basic_block("try_merge") - self.builder.position_at_end(finally_block) - self.visit_statements(node.finalbody) - if not self._bb_terminated(): - self.builder.cbranch( - self.builder.load(propagate), - propagate_block, merge_block) - self.builder.position_at_end(propagate_block) - self.runtime.build_raise(self.builder, self.builder.load(propagate_eid)) - self.builder.branch(merge_block) - self.builder.position_at_end(merge_block) diff --git a/artiq/py2llvm/base_types.py b/artiq/py2llvm/base_types.py deleted file mode 100644 index a5690c396..000000000 --- a/artiq/py2llvm/base_types.py +++ /dev/null @@ -1,321 +0,0 @@ -import llvmlite_artiq.ir as ll - -from artiq.py2llvm.values import VGeneric - - -class VNone(VGeneric): - def get_llvm_type(self): - return ll.VoidType() - - def alloca(self, builder, name): - pass - - def set_const_value(self, builder, v): - assert v is None - - def set_value(self, builder, other): - if not isinstance(other, VNone): - raise TypeError - - def o_bool(self, builder): - r = VBool() - if builder is not None: - r.set_const_value(builder, False) - return r - - def o_not(self, builder): - r = VBool() - if builder is not None: - r.set_const_value(builder, True) - return r - - -class VInt(VGeneric): - def __init__(self, nbits=32): - VGeneric.__init__(self) - self.nbits = nbits - - def get_llvm_type(self): - return ll.IntType(self.nbits) - - def __repr__(self): - return "".format(self.nbits) - - def same_type(self, other): - return isinstance(other, VInt) and other.nbits == self.nbits - - def merge(self, other): - if isinstance(other, VInt) and not isinstance(other, VBool): - if other.nbits > self.nbits: - self.nbits = other.nbits - else: - raise TypeError("Incompatible types: {} and {}" - .format(repr(self), repr(other))) - - def set_value(self, builder, n): - self.auto_store( - builder, n.o_intx(self.nbits, builder).auto_load(builder)) - - def set_const_value(self, builder, n): - self.auto_store(builder, ll.Constant(self.get_llvm_type(), n)) - - def o_bool(self, builder, inv=False): - r = VBool() - if builder is not None: - r.auto_store( - builder, builder.icmp_signed( - "==" if inv else "!=", - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), 0))) - return r - - def o_float(self, builder): - r = VFloat() - if builder is not None: - if isinstance(self, VBool): - cf = builder.uitofp - else: - cf = builder.sitofp - r.auto_store(builder, cf(self.auto_load(builder), - r.get_llvm_type())) - return r - - def o_not(self, builder): - return self.o_bool(builder, inv=True) - - def o_neg(self, builder): - r = VInt(self.nbits) - if builder is not None: - r.auto_store( - builder, builder.mul( - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), -1))) - return r - - def o_intx(self, target_bits, builder): - r = VInt(target_bits) - if builder is not None: - if self.nbits == target_bits: - r.auto_store( - builder, self.auto_load(builder)) - if self.nbits > target_bits: - r.auto_store( - builder, builder.trunc(self.auto_load(builder), - r.get_llvm_type())) - if self.nbits < target_bits: - if isinstance(self, VBool): - ef = builder.zext - else: - ef = builder.sext - r.auto_store( - builder, ef(self.auto_load(builder), - r.get_llvm_type())) - return r - o_roundx = o_intx - - def o_truediv(self, other, builder): - if isinstance(other, VInt): - left = self.o_float(builder) - right = other.o_float(builder) - return left.o_truediv(right, builder) - else: - return NotImplemented - -def _make_vint_binop_method(builder_name, bool_op): - def 
binop_method(self, other, builder): - if isinstance(other, VInt): - target_bits = max(self.nbits, other.nbits) - if not bool_op and target_bits == 1: - target_bits = 32 - if bool_op and target_bits == 1: - r = VBool() - else: - r = VInt(target_bits) - if builder is not None: - left = self.o_intx(target_bits, builder) - right = other.o_intx(target_bits, builder) - bf = getattr(builder, builder_name) - r.auto_store( - builder, bf(left.auto_load(builder), - right.auto_load(builder))) - return r - else: - return NotImplemented - return binop_method - -for _method_name, _builder_name, _bool_op in (("o_add", "add", False), - ("o_sub", "sub", False), - ("o_mul", "mul", False), - ("o_floordiv", "sdiv", False), - ("o_mod", "srem", False), - ("o_and", "and_", True), - ("o_xor", "xor", True), - ("o_or", "or_", True)): - setattr(VInt, _method_name, _make_vint_binop_method(_builder_name, _bool_op)) - - -def _make_vint_cmp_method(icmp_val): - def cmp_method(self, other, builder): - if isinstance(other, VInt): - r = VBool() - if builder is not None: - target_bits = max(self.nbits, other.nbits) - left = self.o_intx(target_bits, builder) - right = other.o_intx(target_bits, builder) - r.auto_store( - builder, - builder.icmp_signed( - icmp_val, left.auto_load(builder), - right.auto_load(builder))) - return r - else: - return NotImplemented - return cmp_method - -for _method_name, _icmp_val in (("o_eq", "=="), - ("o_ne", "!="), - ("o_lt", "<"), - ("o_le", "<="), - ("o_gt", ">"), - ("o_ge", ">=")): - setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val)) - - -class VBool(VInt): - def __init__(self): - VInt.__init__(self, 1) - - __repr__ = VGeneric.__repr__ - same_type = VGeneric.same_type - merge = VGeneric.merge - - def set_const_value(self, builder, b): - VInt.set_const_value(self, builder, int(b)) - - -class VFloat(VGeneric): - def get_llvm_type(self): - return ll.DoubleType() - - def set_value(self, builder, v): - if not isinstance(v, VFloat): - raise TypeError - self.auto_store(builder, v.auto_load(builder)) - - def set_const_value(self, builder, n): - self.auto_store(builder, ll.Constant(self.get_llvm_type(), n)) - - def o_float(self, builder): - r = VFloat() - if builder is not None: - r.auto_store(builder, self.auto_load(builder)) - return r - - def o_bool(self, builder, inv=False): - r = VBool() - if builder is not None: - r.auto_store( - builder, builder.fcmp_ordered( - "==" if inv else "!=", - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), 0.0))) - return r - - def o_not(self, builder): - return self.o_bool(builder, True) - - def o_neg(self, builder): - r = VFloat() - if builder is not None: - r.auto_store( - builder, builder.fmul( - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), -1.0))) - return r - - def o_intx(self, target_bits, builder): - r = VInt(target_bits) - if builder is not None: - r.auto_store(builder, builder.fptosi(self.auto_load(builder), - r.get_llvm_type())) - return r - - def o_roundx(self, target_bits, builder): - r = VInt(target_bits) - if builder is not None: - function = builder.basic_block.function - neg_block = function.append_basic_block("fr_neg") - merge_block = function.append_basic_block("fr_merge") - - half = VFloat() - half.alloca(builder, "half") - half.set_const_value(builder, 0.5) - - condition = builder.fcmp_ordered( - "<", - self.auto_load(builder), - ll.Constant(self.get_llvm_type(), 0.0)) - builder.cbranch(condition, neg_block, merge_block) - - builder.position_at_end(neg_block) - half.set_const_value(builder, -0.5) - 
builder.branch(merge_block) - - builder.position_at_end(merge_block) - s = builder.fadd(self.auto_load(builder), half.auto_load(builder)) - r.auto_store(builder, builder.fptosi(s, r.get_llvm_type())) - return r - - def o_floordiv(self, other, builder): - return self.o_truediv(other, builder).o_int64(builder).o_float(builder) - -def _make_vfloat_binop_method(builder_name, reverse): - def binop_method(self, other, builder): - if not hasattr(other, "o_float"): - return NotImplemented - r = VFloat() - if builder is not None: - left = self.o_float(builder) - right = other.o_float(builder) - if reverse: - left, right = right, left - bf = getattr(builder, builder_name) - r.auto_store( - builder, bf(left.auto_load(builder), - right.auto_load(builder))) - return r - return binop_method - -for _method_name, _builder_name in (("add", "fadd"), - ("sub", "fsub"), - ("mul", "fmul"), - ("truediv", "fdiv")): - setattr(VFloat, "o_" + _method_name, - _make_vfloat_binop_method(_builder_name, False)) - setattr(VFloat, "or_" + _method_name, - _make_vfloat_binop_method(_builder_name, True)) - - -def _make_vfloat_cmp_method(fcmp_val): - def cmp_method(self, other, builder): - if not hasattr(other, "o_float"): - return NotImplemented - r = VBool() - if builder is not None: - left = self.o_float(builder) - right = other.o_float(builder) - r.auto_store( - builder, - builder.fcmp_ordered( - fcmp_val, left.auto_load(builder), - right.auto_load(builder))) - return r - return cmp_method - -for _method_name, _fcmp_val in (("o_eq", "=="), - ("o_ne", "!="), - ("o_lt", "<"), - ("o_le", "<="), - ("o_gt", ">"), - ("o_ge", ">=")): - setattr(VFloat, _method_name, _make_vfloat_cmp_method(_fcmp_val)) diff --git a/artiq/py2llvm/infer_types.py b/artiq/py2llvm/infer_types.py deleted file mode 100644 index d02aeef36..000000000 --- a/artiq/py2llvm/infer_types.py +++ /dev/null @@ -1,74 +0,0 @@ -import ast -from copy import deepcopy - -from artiq.py2llvm.ast_body import Visitor -from artiq.py2llvm import base_types - - -class _TypeScanner(ast.NodeVisitor): - def __init__(self, env, ns): - self.exprv = Visitor(env, ns) - - def _update_target(self, target, val): - ns = self.exprv.ns - if isinstance(target, ast.Name): - if target.id in ns: - ns[target.id].merge(val) - else: - ns[target.id] = deepcopy(val) - elif isinstance(target, ast.Subscript): - target = target.value - levels = 0 - while isinstance(target, ast.Subscript): - target = target.value - levels += 1 - if isinstance(target, ast.Name): - target_value = ns[target.id] - for i in range(levels): - target_value = target_value.o_subscript(None, None) - target_value.merge_subscript(val) - else: - raise NotImplementedError - else: - raise NotImplementedError - - def visit_Assign(self, node): - val = self.exprv.visit_expression(node.value) - for target in node.targets: - self._update_target(target, val) - - def visit_AugAssign(self, node): - val = self.exprv.visit_expression(ast.BinOp( - op=node.op, left=node.target, right=node.value)) - self._update_target(node.target, val) - - def visit_For(self, node): - it = self.exprv.visit_expression(node.iter) - self._update_target(node.target, it.get_value_ptr()) - self.generic_visit(node) - - def visit_Return(self, node): - if node.value is None: - val = base_types.VNone() - else: - val = self.exprv.visit_expression(node.value) - ns = self.exprv.ns - if "return" in ns: - ns["return"].merge(val) - else: - ns["return"] = deepcopy(val) - - -def infer_function_types(env, node, param_types): - ns = deepcopy(param_types) - ts = _TypeScanner(env, ns) 
- ts.visit(node) - while True: - prev_ns = deepcopy(ns) - ts = _TypeScanner(env, ns) - ts.visit(node) - if all(v.same_type(prev_ns[k]) for k, v in ns.items()): - # no more promotions - completed - if "return" not in ns: - ns["return"] = base_types.VNone() - return ns diff --git a/artiq/py2llvm/iterators.py b/artiq/py2llvm/iterators.py deleted file mode 100644 index 0e1526319..000000000 --- a/artiq/py2llvm/iterators.py +++ /dev/null @@ -1,51 +0,0 @@ -from artiq.py2llvm.values import operators -from artiq.py2llvm.base_types import VInt - -class IRange: - def __init__(self, builder, args): - minimum, step = None, None - if len(args) == 1: - maximum = args[0] - elif len(args) == 2: - minimum, maximum = args - else: - minimum, maximum, step = args - if minimum is None: - minimum = VInt() - if builder is not None: - minimum.set_const_value(builder, 0) - if step is None: - step = VInt() - if builder is not None: - step.set_const_value(builder, 1) - - self._counter = minimum.new() - self._counter.merge(maximum) - self._counter.merge(step) - self._minimum = self._counter.new() - self._maximum = self._counter.new() - self._step = self._counter.new() - - if builder is not None: - self._minimum.alloca(builder, "irange_min") - self._maximum.alloca(builder, "irange_max") - self._step.alloca(builder, "irange_step") - self._counter.alloca(builder, "irange_count") - - self._minimum.set_value(builder, minimum) - self._maximum.set_value(builder, maximum) - self._step.set_value(builder, step) - - counter_init = operators.sub(self._minimum, self._step, builder) - self._counter.set_value(builder, counter_init) - - # must be a pointer value that can be dereferenced anytime - # to get the current value of the iterator - def get_value_ptr(self): - return self._counter - - def o_next(self, builder): - self._counter.set_value( - builder, - operators.add(self._counter, self._step, builder)) - return operators.lt(self._counter, self._maximum, builder) diff --git a/artiq/py2llvm/lists.py b/artiq/py2llvm/lists.py deleted file mode 100644 index d486e7ddd..000000000 --- a/artiq/py2llvm/lists.py +++ /dev/null @@ -1,72 +0,0 @@ -import llvmlite_artiq.ir as ll - -from artiq.py2llvm.values import VGeneric -from artiq.py2llvm.base_types import VInt, VNone - - -class VList(VGeneric): - def __init__(self, el_type, alloc_count): - VGeneric.__init__(self) - self.el_type = el_type - self.alloc_count = alloc_count - - def get_llvm_type(self): - count = 0 if self.alloc_count is None else self.alloc_count - if isinstance(self.el_type, VNone): - return ll.LiteralStructType([ll.IntType(32)]) - else: - return ll.LiteralStructType([ - ll.IntType(32), ll.ArrayType(self.el_type.get_llvm_type(), - count)]) - - def __repr__(self): - return "".format( - repr(self.el_type), - "?" 
if self.alloc_count is None else self.alloc_count) - - def same_type(self, other): - return (isinstance(other, VList) - and self.el_type.same_type(other.el_type)) - - def merge(self, other): - if isinstance(other, VList): - if self.alloc_count: - if other.alloc_count: - self.el_type.merge(other.el_type) - if self.alloc_count < other.alloc_count: - self.alloc_count = other.alloc_count - else: - self.el_type = other.el_type.new() - self.alloc_count = other.alloc_count - else: - raise TypeError("Incompatible types: {} and {}" - .format(repr(self), repr(other))) - - def merge_subscript(self, other): - self.el_type.merge(other) - - def set_count(self, builder, count): - count_ptr = builder.gep(self.llvm_value, [ - ll.Constant(ll.IntType(32), 0), - ll.Constant(ll.IntType(32), 0)]) - builder.store(ll.Constant(ll.IntType(32), count), count_ptr) - - def o_len(self, builder): - r = VInt() - if builder is not None: - count_ptr = builder.gep(self.llvm_value, [ - ll.Constant(ll.IntType(32), 0), - ll.Constant(ll.IntType(32), 0)]) - r.auto_store(builder, builder.load(count_ptr)) - return r - - def o_subscript(self, index, builder): - r = self.el_type.new() - if builder is not None and not isinstance(r, VNone): - index = index.o_int(builder).auto_load(builder) - ssa_r = builder.gep(self.llvm_value, [ - ll.Constant(ll.IntType(32), 0), - ll.Constant(ll.IntType(32), 1), - index]) - r.auto_store(builder, ssa_r) - return r diff --git a/artiq/py2llvm/module.py b/artiq/py2llvm/module.py deleted file mode 100644 index b842833e9..000000000 --- a/artiq/py2llvm/module.py +++ /dev/null @@ -1,62 +0,0 @@ -import llvmlite_artiq.ir as ll -import llvmlite_artiq.binding as llvm - -from artiq.py2llvm import infer_types, ast_body, base_types, fractions, tools - - -class Module: - def __init__(self, runtime=None): - self.llvm_module = ll.Module("main") - self.runtime = runtime - - if self.runtime is not None: - self.runtime.init_module(self) - fractions.init_module(self) - - def finalize(self): - self.llvm_module_ref = llvm.parse_assembly(str(self.llvm_module)) - pmb = llvm.create_pass_manager_builder() - pmb.opt_level = 2 - pm = llvm.create_module_pass_manager() - pmb.populate(pm) - pm.run(self.llvm_module_ref) - - def get_ee(self): - self.finalize() - tm = llvm.Target.from_default_triple().create_target_machine() - ee = llvm.create_mcjit_compiler(self.llvm_module_ref, tm) - ee.finalize_object() - return ee - - def emit_object(self): - self.finalize() - return self.runtime.emit_object() - - def compile_function(self, func_def, param_types): - ns = infer_types.infer_function_types(self.runtime, func_def, param_types) - retval = ns["return"] - - function_type = ll.FunctionType(retval.get_llvm_type(), - [ns[arg.arg].get_llvm_type() for arg in func_def.args.args]) - function = ll.Function(self.llvm_module, function_type, func_def.name) - bb = function.append_basic_block("entry") - builder = ll.IRBuilder() - builder.position_at_end(bb) - - for arg_ast, arg_llvm in zip(func_def.args.args, function.args): - arg_llvm.name = arg_ast.arg - for k, v in ns.items(): - v.alloca(builder, k) - for arg_ast, arg_llvm in zip(func_def.args.args, function.args): - ns[arg_ast.arg].auto_store(builder, arg_llvm) - - visitor = ast_body.Visitor(self.runtime, ns, builder) - visitor.visit_statements(func_def.body) - - if not tools.is_terminated(builder.basic_block): - if isinstance(retval, base_types.VNone): - builder.ret_void() - else: - builder.ret(retval.auto_load(builder)) - - return function, retval diff --git a/artiq/py2llvm/tools.py 
b/artiq/py2llvm/tools.py deleted file mode 100644 index 361b82a6f..000000000 --- a/artiq/py2llvm/tools.py +++ /dev/null @@ -1,5 +0,0 @@ -import llvmlite_artiq.ir as ll - -def is_terminated(basic_block): - return (basic_block.instructions - and isinstance(basic_block.instructions[-1], ll.Terminator)) diff --git a/artiq/py2llvm/values.py b/artiq/py2llvm/values.py deleted file mode 100644 index 254d17541..000000000 --- a/artiq/py2llvm/values.py +++ /dev/null @@ -1,94 +0,0 @@ -from types import SimpleNamespace -from copy import copy - -import llvmlite_artiq.ir as ll - - -class VGeneric: - def __init__(self): - self.llvm_value = None - - def new(self): - r = copy(self) - r.llvm_value = None - return r - - def __repr__(self): - return "<" + self.__class__.__name__ + ">" - - def same_type(self, other): - return isinstance(other, self.__class__) - - def merge(self, other): - if not self.same_type(other): - raise TypeError("Incompatible types: {} and {}" - .format(repr(self), repr(other))) - - def auto_load(self, builder): - if isinstance(self.llvm_value.type, ll.PointerType): - return builder.load(self.llvm_value) - else: - return self.llvm_value - - def auto_store(self, builder, llvm_value): - if self.llvm_value is None: - self.llvm_value = llvm_value - elif isinstance(self.llvm_value.type, ll.PointerType): - builder.store(llvm_value, self.llvm_value) - else: - raise RuntimeError( - "Attempted to set LLVM SSA value multiple times") - - def alloca(self, builder, name=""): - if self.llvm_value is not None: - raise RuntimeError("Attempted to alloca existing LLVM value "+name) - self.llvm_value = builder.alloca(self.get_llvm_type(), name=name) - - def o_int(self, builder): - return self.o_intx(32, builder) - - def o_int64(self, builder): - return self.o_intx(64, builder) - - def o_round(self, builder): - return self.o_roundx(32, builder) - - def o_round64(self, builder): - return self.o_roundx(64, builder) - - -def _make_binary_operator(op_name): - def op(l, r, builder): - try: - opf = getattr(l, "o_" + op_name) - except AttributeError: - result = NotImplemented - else: - result = opf(r, builder) - if result is NotImplemented: - try: - ropf = getattr(r, "or_" + op_name) - except AttributeError: - result = NotImplemented - else: - result = ropf(l, builder) - if result is NotImplemented: - raise TypeError( - "Unsupported operand types for {}: {} and {}" - .format(op_name, type(l).__name__, type(r).__name__)) - return result - return op - - -def _make_operators(): - d = dict() - for op_name in ("add", "sub", "mul", - "truediv", "floordiv", "mod", - "pow", "lshift", "rshift", "xor", - "eq", "ne", "lt", "le", "gt", "ge"): - d[op_name] = _make_binary_operator(op_name) - d["and_"] = _make_binary_operator("and") - d["or_"] = _make_binary_operator("or") - return SimpleNamespace(**d) - -operators = _make_operators() diff --git a/artiq/py2llvm/fractions.py b/artiq/py2llvm_old/fractions.py similarity index 99% rename from artiq/py2llvm/fractions.py rename to artiq/py2llvm_old/fractions.py index d00ff74de..a2895107b 100644 --- a/artiq/py2llvm/fractions.py +++ b/artiq/py2llvm_old/fractions.py @@ -1,5 +1,5 @@ import inspect -import ast +from pythonparser import parse, ast import llvmlite_artiq.ir as ll @@ -18,7 +18,7 @@ def _gcd(a, b): def init_module(module): - func_def = ast.parse(inspect.getsource(_gcd)).body[0] + func_def = parse(inspect.getsource(_gcd)).body[0] function, _ = module.compile_function(func_def, {"a": VInt(64), "b": VInt(64)}) function.linkage = "internal" diff --git 
a/artiq/py2llvm_old/test/py2llvm.py b/artiq/py2llvm_old/test/py2llvm.py new file mode 100644 index 000000000..c6d9f0135 --- /dev/null +++ b/artiq/py2llvm_old/test/py2llvm.py @@ -0,0 +1,169 @@ +import unittest +from pythonparser import parse, ast +import inspect +from fractions import Fraction +from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double +import struct + +import llvmlite_or1k.binding as llvm + +from artiq.language.core import int64 +from artiq.py2llvm.infer_types import infer_function_types +from artiq.py2llvm import base_types, lists +from artiq.py2llvm.module import Module + +def simplify_encode(a, b): + f = Fraction(a, b) + return f.numerator*1000 + f.denominator + + +def frac_arith_encode(op, a, b, c, d): + if op == 0: + f = Fraction(a, b) - Fraction(c, d) + elif op == 1: + f = Fraction(a, b) + Fraction(c, d) + elif op == 2: + f = Fraction(a, b) * Fraction(c, d) + else: + f = Fraction(a, b) / Fraction(c, d) + return f.numerator*1000 + f.denominator + + +def frac_arith_encode_int(op, a, b, x): + if op == 0: + f = Fraction(a, b) - x + elif op == 1: + f = Fraction(a, b) + x + elif op == 2: + f = Fraction(a, b) * x + else: + f = Fraction(a, b) / x + return f.numerator*1000 + f.denominator + + +def frac_arith_encode_int_rev(op, a, b, x): + if op == 0: + f = x - Fraction(a, b) + elif op == 1: + f = x + Fraction(a, b) + elif op == 2: + f = x * Fraction(a, b) + else: + f = x / Fraction(a, b) + return f.numerator*1000 + f.denominator + + +def frac_arith_float(op, a, b, x): + if op == 0: + return Fraction(a, b) - x + elif op == 1: + return Fraction(a, b) + x + elif op == 2: + return Fraction(a, b) * x + else: + return Fraction(a, b) / x + + +def frac_arith_float_rev(op, a, b, x): + if op == 0: + return x - Fraction(a, b) + elif op == 1: + return x + Fraction(a, b) + elif op == 2: + return x * Fraction(a, b) + else: + return x / Fraction(a, b) + + +class CodeGenCase(unittest.TestCase): + def test_frac_simplify(self): + simplify_encode_c = CompiledFunction( + simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()}) + for a in _test_range(): + for b in _test_range(): + self.assertEqual( + simplify_encode_c(a, b), simplify_encode(a, b)) + + def _test_frac_arith(self, op): + frac_arith_encode_c = CompiledFunction( + frac_arith_encode, { + "op": base_types.VInt(), + "a": base_types.VInt(), "b": base_types.VInt(), + "c": base_types.VInt(), "d": base_types.VInt()}) + for a in _test_range(): + for b in _test_range(): + for c in _test_range(): + for d in _test_range(): + self.assertEqual( + frac_arith_encode_c(op, a, b, c, d), + frac_arith_encode(op, a, b, c, d)) + + def test_frac_add(self): + self._test_frac_arith(0) + + def test_frac_sub(self): + self._test_frac_arith(1) + + def test_frac_mul(self): + self._test_frac_arith(2) + + def test_frac_div(self): + self._test_frac_arith(3) + + def _test_frac_arith_int(self, op, rev): + f = frac_arith_encode_int_rev if rev else frac_arith_encode_int + f_c = CompiledFunction(f, { + "op": base_types.VInt(), + "a": base_types.VInt(), "b": base_types.VInt(), + "x": base_types.VInt()}) + for a in _test_range(): + for b in _test_range(): + for x in _test_range(): + self.assertEqual( + f_c(op, a, b, x), + f(op, a, b, x)) + + def test_frac_add_int(self): + self._test_frac_arith_int(0, False) + self._test_frac_arith_int(0, True) + + def test_frac_sub_int(self): + self._test_frac_arith_int(1, False) + self._test_frac_arith_int(1, True) + + def test_frac_mul_int(self): + self._test_frac_arith_int(2, False) + self._test_frac_arith_int(2, True) + 
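+ # The *_encode helpers above pack a Fraction into one integer as
+ # numerator*1000 + denominator (e.g. Fraction(6, 4) == 3/2 encodes to
+ # 3002), so the compiled and interpreted results can be compared with a
+ # plain integer equality check.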
+ def test_frac_div_int(self): + self._test_frac_arith_int(3, False) + self._test_frac_arith_int(3, True) + + def _test_frac_arith_float(self, op, rev): + f = frac_arith_float_rev if rev else frac_arith_float + f_c = CompiledFunction(f, { + "op": base_types.VInt(), + "a": base_types.VInt(), "b": base_types.VInt(), + "x": base_types.VFloat()}) + for a in _test_range(): + for b in _test_range(): + for x in _test_range(): + self.assertAlmostEqual( + f_c(op, a, b, x/2), + f(op, a, b, x/2)) + + def test_frac_add_float(self): + self._test_frac_arith_float(0, False) + self._test_frac_arith_float(0, True) + + def test_frac_sub_float(self): + self._test_frac_arith_float(1, False) + self._test_frac_arith_float(1, True) + + def test_frac_mul_float(self): + self._test_frac_arith_float(2, False) + self._test_frac_arith_float(2, True) + + def test_frac_div_float(self): + self._test_frac_arith_float(3, False) + self._test_frac_arith_float(3, True) diff --git a/artiq/transforms/inline.py b/artiq/py2llvm_old/transforms/inline.py similarity index 100% rename from artiq/transforms/inline.py rename to artiq/py2llvm_old/transforms/inline.py diff --git a/artiq/transforms/interleave.py b/artiq/py2llvm_old/transforms/interleave.py similarity index 100% rename from artiq/transforms/interleave.py rename to artiq/py2llvm_old/transforms/interleave.py diff --git a/artiq/py2llvm_old/transforms/quantize_time.py b/artiq/py2llvm_old/transforms/quantize_time.py new file mode 100644 index 000000000..42e04f564 --- /dev/null +++ b/artiq/py2llvm_old/transforms/quantize_time.py @@ -0,0 +1,43 @@ + def visit_With(self, node): + self.generic_visit(node) + if (isinstance(node.items[0].context_expr, ast.Call) + and node.items[0].context_expr.func.id == "watchdog"): + + idname = "__watchdog_id_" + str(self.watchdog_id_counter) + self.watchdog_id_counter += 1 + + time = ast.BinOp(left=node.items[0].context_expr.args[0], + op=ast.Mult(), + right=ast.Num(1000)) + time_int = ast.Call( + func=ast.Name("round", ast.Load()), + args=[time], + keywords=[], starargs=None, kwargs=None) + syscall_set = ast.Call( + func=ast.Name("syscall", ast.Load()), + args=[ast.Str("watchdog_set"), time_int], + keywords=[], starargs=None, kwargs=None) + stmt_set = ast.copy_location( + ast.Assign(targets=[ast.Name(idname, ast.Store())], + value=syscall_set), + node) + + syscall_clear = ast.Call( + func=ast.Name("syscall", ast.Load()), + args=[ast.Str("watchdog_clear"), + ast.Name(idname, ast.Load())], + keywords=[], starargs=None, kwargs=None) + stmt_clear = ast.copy_location(ast.Expr(syscall_clear), node) + + node.items[0] = ast.withitem( + context_expr=ast.Name(id="sequential", + ctx=ast.Load()), + optional_vars=None) + node.body = [ + stmt_set, + ast.Try(body=node.body, + handlers=[], + orelse=[], + finalbody=[stmt_clear]) + ] + return node diff --git a/artiq/transforms/unroll_loops.py b/artiq/py2llvm_old/transforms/unroll_loops.py similarity index 100% rename from artiq/transforms/unroll_loops.py rename to artiq/py2llvm_old/transforms/unroll_loops.py diff --git a/artiq/runtime/Makefile b/artiq/runtime/Makefile index b572d5f12..45b6f6ea3 100644 --- a/artiq/runtime/Makefile +++ b/artiq/runtime/Makefile @@ -3,10 +3,18 @@ include $(MISOC_DIRECTORY)/software/common.mak PYTHON ?= python3.5 -OBJECTS := isr.o flash_storage.o clock.o rtiocrg.o elf_loader.o services.o session.o log.o test_mode.o kloader.o bridge_ctl.o mailbox.o ksupport_data.o net_server.o moninj.o main.o -OBJECTS_KSUPPORT := ksupport.o exception_jmp.o exceptions.o mailbox.o bridge.o rtio.o ttl.o 
dds.o +OBJECTS := isr.o clock.o rtiocrg.o flash_storage.o mailbox.o \ + session.o log.o moninj.o net_server.o bridge_ctl.o \ + ksupport_data.o kloader.o test_mode.o main.o +OBJECTS_KSUPPORT := ksupport.o artiq_personality.o mailbox.o \ + bridge.o rtio.o ttl.o dds.o -CFLAGS += -I$(LIBLWIP_DIRECTORY)/../lwip/src/include -I$(LIBLWIP_DIRECTORY) -I. +CFLAGS += -I$(MISOC_DIRECTORY)/software/include/dyld \ + -I$(LIBDYLD_DIRECTORY)/include \ + -I$(LIBUNWIND_DIRECTORY) \ + -I$(LIBUNWIND_DIRECTORY)/../unwinder/include \ + -I$(LIBLWIP_DIRECTORY)/../lwip/src/include \ + -I$(LIBLWIP_DIRECTORY) all: runtime.bin runtime.fbi @@ -19,7 +27,7 @@ all: runtime.bin runtime.fbi runtime.elf: $(OBJECTS) $(LD) $(LDFLAGS) \ - -T $(RUNTIME_DIRECTORY)/linker.ld \ + -T $(RUNTIME_DIRECTORY)/runtime.ld \ -N -o $@ \ ../libbase/crt0-$(CPU).o \ $(OBJECTS) \ @@ -31,28 +39,21 @@ runtime.elf: $(OBJECTS) ksupport.elf: $(OBJECTS_KSUPPORT) $(LD) $(LDFLAGS) \ + --eh-frame-hdr \ -T $(RUNTIME_DIRECTORY)/ksupport.ld \ -N -o $@ \ ../libbase/crt0-$(CPU).o \ $^ \ + -L../libbase \ -L../libcompiler_rt \ - -lcompiler_rt + -L../libunwind \ + -L../libdyld \ + -lbase -lcompiler_rt -lunwind -ldyld @chmod -x $@ -ksupport_data.o: ksupport.bin +ksupport_data.o: ksupport.elf $(LD) -r -b binary -o $@ $< -service_table.h: ksupport.elf $(RUNTIME_DIRECTORY)/gen_service_table.py - @echo " GEN " $@ && $(PYTHON) $(RUNTIME_DIRECTORY)/gen_service_table.py ksupport.elf > $@ - -$(RUNTIME_DIRECTORY)/services.c: service_table.h - -services.o: $(RUNTIME_DIRECTORY)/services.c service_table.h - $(compile) - -main.o: $(RUNTIME_DIRECTORY)/main.c - $(compile) - %.o: $(RUNTIME_DIRECTORY)/%.c $(compile) @@ -62,6 +63,6 @@ main.o: $(RUNTIME_DIRECTORY)/main.c clean: $(RM) $(OBJECTS) $(OBJECTS_KSUPPORT) $(RM) runtime.elf runtime.bin runtime.fbi .*~ *~ - $(RM) service_table.h ksupport.elf ksupport.bin + $(RM) ksupport.elf ksupport.bin -.PHONY: all clean main.o +.PHONY: all clean diff --git a/artiq/runtime/artiq_personality.c b/artiq/runtime/artiq_personality.c new file mode 100644 index 000000000..d5ba26b24 --- /dev/null +++ b/artiq/runtime/artiq_personality.c @@ -0,0 +1,464 @@ +#include +#include +#include +#include +#include +#include "artiq_personality.h" + +/* Logging */ + +#ifndef NDEBUG +#define EH_LOG0(fmt) fprintf(stderr, "%s: " fmt "\n", __func__) +#define EH_LOG(fmt, ...) fprintf(stderr, "%s: " fmt "\n", __func__, __VA_ARGS__) +#else +#define EH_LOG0(fmt) +#define EH_LOG(fmt, ...) 
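+/* In a debug build (NDEBUG undefined) the EH_LOG* macros prefix every
+ * message with the calling function, e.g. EH_LOG("lsda=%p", lsda)
+ * prints "__artiq_personality: lsda=0x...\n" on stderr; with NDEBUG
+ * defined they expand to nothing and all of the tracing below is
+ * compiled out. */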
+#endif + +#define EH_FAIL(err) \ + do { \ + fprintf(stderr, "%s fatal: %s\n", __func__, err); \ + abort(); \ + } while(0) + +#define EH_ASSERT(expr) \ + if(!(expr)) EH_FAIL(#expr) + +/* DWARF format handling */ + +enum { + DW_EH_PE_absptr = 0x00, + DW_EH_PE_uleb128 = 0x01, + DW_EH_PE_udata2 = 0x02, + DW_EH_PE_udata4 = 0x03, + DW_EH_PE_udata8 = 0x04, + DW_EH_PE_sleb128 = 0x09, + DW_EH_PE_sdata2 = 0x0A, + DW_EH_PE_sdata4 = 0x0B, + DW_EH_PE_sdata8 = 0x0C, + DW_EH_PE_pcrel = 0x10, + DW_EH_PE_textrel = 0x20, + DW_EH_PE_datarel = 0x30, + DW_EH_PE_funcrel = 0x40, + DW_EH_PE_aligned = 0x50, + DW_EH_PE_indirect = 0x80, + DW_EH_PE_omit = 0xFF +}; + +// Read a uleb128 encoded value and advance pointer +// See Variable Length Data in: http://dwarfstd.org/Dwarf3.pdf +static uintptr_t readULEB128(const uint8_t **data) { + uintptr_t result = 0; + uintptr_t shift = 0; + unsigned char byte; + const uint8_t *p = *data; + + do { + byte = *p++; + result |= (byte & 0x7f) << shift; + shift += 7; + } + while (byte & 0x80); + + *data = p; + + return result; +} + +// Read a sleb128 encoded value and advance pointer +// See Variable Length Data in: http://dwarfstd.org/Dwarf3.pdf +static uintptr_t readSLEB128(const uint8_t **data) { + uintptr_t result = 0; + uintptr_t shift = 0; + unsigned char byte; + const uint8_t *p = *data; + + do { + byte = *p++; + result |= (byte & 0x7f) << shift; + shift += 7; + } + while (byte & 0x80); + + *data = p; + + if ((byte & 0x40) && (shift < (sizeof(result) << 3))) { + result |= (~0 << shift); + } + + return result; +} + +static unsigned getEncodingSize(uint8_t Encoding) { + if (Encoding == DW_EH_PE_omit) + return 0; + + switch (Encoding & 0x0F) { + case DW_EH_PE_absptr: + return sizeof(uintptr_t); + case DW_EH_PE_udata2: + return sizeof(uint16_t); + case DW_EH_PE_udata4: + return sizeof(uint32_t); + case DW_EH_PE_udata8: + return sizeof(uint64_t); + case DW_EH_PE_sdata2: + return sizeof(int16_t); + case DW_EH_PE_sdata4: + return sizeof(int32_t); + case DW_EH_PE_sdata8: + return sizeof(int64_t); + default: + // not supported + abort(); + } +} + +// Read a pointer encoded value and advance pointer +// See Variable Length Data in: http://dwarfstd.org/Dwarf3.pdf +static uintptr_t readEncodedPointer(const uint8_t **data, uint8_t encoding) { + uintptr_t result = 0; + const uint8_t *p = *data; + + if (encoding == DW_EH_PE_omit) + return(result); + + // first get value + switch (encoding & 0x0F) { + case DW_EH_PE_absptr: + memcpy(&result, p, sizeof(uintptr_t)); + p += sizeof(uintptr_t); + break; + case DW_EH_PE_uleb128: + result = readULEB128(&p); + break; + // Note: This case has not been tested + case DW_EH_PE_sleb128: + result = readSLEB128(&p); + break; + case DW_EH_PE_udata2: + { + uint16_t valu16; + memcpy(&valu16, p, sizeof(uint16_t)); + result = valu16; + } + p += sizeof(uint16_t); + break; + case DW_EH_PE_udata4: + { + uint32_t valu32; + memcpy(&valu32, p, sizeof(uint32_t)); + result = valu32; + } + p += sizeof(uint32_t); + break; + case DW_EH_PE_udata8: + { + uint64_t valu64; + memcpy(&valu64, p, sizeof(uint64_t)); + result = valu64; + } + p += sizeof(uint64_t); + break; + case DW_EH_PE_sdata2: + { + int16_t val16; + memcpy(&val16, p, sizeof(int16_t)); + result = val16; + } + p += sizeof(int16_t); + break; + case DW_EH_PE_sdata4: + { + int32_t val32; + memcpy(&val32, p, sizeof(int32_t)); + result = val32; + } + p += sizeof(int32_t); + break; + case DW_EH_PE_sdata8: + { + int64_t val64; + memcpy(&val64, p, sizeof(int64_t)); + result = val64; + } + p += sizeof(int64_t); + break; + 
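+ /* The *LEB128 readers above follow the standard DWARF variable-length
+ scheme: each byte contributes its low 7 bits, least-significant group
+ first, and the high bit marks continuation. For example the byte
+ sequence E5 8E 26 decodes to
+ 0x65 | (0x0E << 7) | (0x26 << 14) = 624485,
+ and readSLEB128 additionally sign-extends from the last group. */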
default: + // not supported + abort(); + break; + } + + // then add relative offset + switch (encoding & 0x70) { + case DW_EH_PE_absptr: + // do nothing + break; + case DW_EH_PE_pcrel: + result += (uintptr_t)(*data); + break; + case DW_EH_PE_textrel: + case DW_EH_PE_datarel: + case DW_EH_PE_funcrel: + case DW_EH_PE_aligned: + default: + // not supported + abort(); + break; + } + + // then apply indirection + if (encoding & DW_EH_PE_indirect) { + result = *((uintptr_t*)result); + } + + *data = p; + + return result; +} + + +/* Raising */ + +#define ARTIQ_EXCEPTION_CLASS 0x4152545141525451LL // 'ARTQARTQ' + +static void __artiq_cleanup(_Unwind_Reason_Code reason, struct _Unwind_Exception *exc); +static _Unwind_Reason_Code __artiq_uncaught_exception( + int version, _Unwind_Action actions, uint64_t exceptionClass, + struct _Unwind_Exception *exceptionObject, struct _Unwind_Context *context, + void *stop_parameter); + +struct artiq_raised_exception { + struct _Unwind_Exception unwind; + struct artiq_exception artiq; + int handled; + struct artiq_backtrace_item backtrace[1024]; + size_t backtrace_size; +}; + +static struct artiq_raised_exception inflight; + +void __artiq_raise(struct artiq_exception *artiq_exn) { + EH_LOG("===> raise (name=%s, msg=%s, params=[%lld,%lld,%lld])", + artiq_exn->name, artiq_exn->message, + (long long int)artiq_exn->param[0], + (long long int)artiq_exn->param[1], + (long long int)artiq_exn->param[2]); + + memmove(&inflight.artiq, artiq_exn, sizeof(struct artiq_exception)); + inflight.unwind.exception_class = ARTIQ_EXCEPTION_CLASS; + inflight.unwind.exception_cleanup = &__artiq_cleanup; + inflight.handled = 0; + inflight.backtrace_size = 0; + + _Unwind_Reason_Code result = _Unwind_RaiseException(&inflight.unwind); + EH_ASSERT((result == _URC_END_OF_STACK) && + "Unexpected error during unwinding"); + + // If we're here, there are no handlers, only cleanups. + // Force unwinding anyway; we shall stop at nothing except the end of stack. + result = _Unwind_ForcedUnwind(&inflight.unwind, &__artiq_uncaught_exception, + NULL); + EH_FAIL("_Unwind_ForcedUnwind should not return"); +} + +void __artiq_reraise() { + if(inflight.handled) { + EH_LOG0("===> reraise"); + __artiq_raise(&inflight.artiq); + } else { + EH_LOG0("===> resume"); + EH_ASSERT((inflight.artiq.typeinfo != 0) && + "Need an exception to reraise"); + _Unwind_Resume(&inflight.unwind); + abort(); + } +} + +/* Unwinding */ + +// The code below does not refer to the `inflight` global. + +static void __artiq_cleanup(_Unwind_Reason_Code reason, struct _Unwind_Exception *exc) { + EH_LOG0("===> cleanup"); + struct artiq_raised_exception *inflight = (struct artiq_raised_exception*) exc; + // The in-flight exception is statically allocated, so we don't need to free it. + // But, we clear it to mark it as processed. 
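+ // (This is what the EH_ASSERT in __artiq_reraise checks for: resuming
+ // is only legal while artiq.typeinfo is still non-zero, i.e. while the
+ // exception has not been cleaned up here yet.)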
+ memset(&inflight->artiq, 0, sizeof(struct artiq_exception)); +} + +static _Unwind_Reason_Code __artiq_uncaught_exception( + int version, _Unwind_Action actions, uint64_t exceptionClass, + struct _Unwind_Exception *exceptionObject, struct _Unwind_Context *context, + void *stop_parameter) { + struct artiq_raised_exception *inflight = + (struct artiq_raised_exception*)exceptionObject; + EH_ASSERT(inflight->backtrace_size < + sizeof(inflight->backtrace) / sizeof(inflight->backtrace[0]) && + "Out of space for backtrace"); + + uintptr_t pc = _Unwind_GetIP(context); + uintptr_t funcStart = _Unwind_GetRegionStart(context); + uintptr_t pcOffset = pc - funcStart; + EH_LOG("===> uncaught (pc=%p+%p)", (void*)funcStart, (void*)pcOffset); + + inflight->backtrace[inflight->backtrace_size].function = funcStart; + inflight->backtrace[inflight->backtrace_size].offset = pcOffset; + ++inflight->backtrace_size; + + if(actions & _UA_END_OF_STACK) { + EH_LOG0("end of stack"); + __artiq_terminate(&inflight->artiq, inflight->backtrace, inflight->backtrace_size); + } else { + EH_LOG0("continue"); + return _URC_NO_REASON; + } +} + +_Unwind_Reason_Code __artiq_personality( + int version, _Unwind_Action actions, uint64_t exceptionClass, + struct _Unwind_Exception *exceptionObject, struct _Unwind_Context *context); +_Unwind_Reason_Code __artiq_personality( + int version, _Unwind_Action actions, uint64_t exceptionClass, + struct _Unwind_Exception *exceptionObject, struct _Unwind_Context *context) { + EH_LOG("===> entry (actions =%s%s%s%s; class=%08lx; object=%p, context=%p)", + (actions & _UA_SEARCH_PHASE ? " search" : ""), + (actions & _UA_CLEANUP_PHASE ? " cleanup" : ""), + (actions & _UA_HANDLER_FRAME ? " handler" : ""), + (actions & _UA_FORCE_UNWIND ? " force-unwind" : ""), + exceptionClass, exceptionObject, context); + EH_ASSERT((exceptionClass == ARTIQ_EXCEPTION_CLASS) && + "Foreign exceptions are not supported"); + + struct artiq_raised_exception *inflight = + (struct artiq_raised_exception*)exceptionObject; + EH_LOG("=> exception name=%s", + inflight->artiq.name); + + // Get a pointer to LSDA. If there's no LSDA, this function doesn't + // actually handle any exceptions. + const uint8_t *lsda = (const uint8_t*) _Unwind_GetLanguageSpecificData(context); + if(lsda == NULL) + return _URC_CONTINUE_UNWIND; + + EH_LOG("lsda=%p", lsda); + + // Get the current instruction pointer and offset it before next + // instruction in the current frame which threw the exception. + uintptr_t pc = _Unwind_GetIP(context) - 1; + + // Get beginning of the current frame's code. + uintptr_t funcStart = _Unwind_GetRegionStart(context); + uintptr_t pcOffset = pc - funcStart; + + EH_LOG("=> pc=%p (%p+%p)", (void*)pc, (void*)funcStart, (void*)pcOffset); + + // Parse LSDA header. + uint8_t lpStartEncoding = *lsda++; + if (lpStartEncoding != DW_EH_PE_omit) { + readEncodedPointer(&lsda, lpStartEncoding); + } + + uint8_t ttypeEncoding = *lsda++; + const uint8_t *classInfo = NULL; + if (ttypeEncoding != DW_EH_PE_omit) { + // Calculate type info locations in emitted dwarf code which + // were flagged by type info arguments to llvm.eh.selector + // intrinsic + uintptr_t classInfoOffset = readULEB128(&lsda); + classInfo = lsda + classInfoOffset; + EH_LOG("classInfo=%p", classInfo); + } + + // Walk call-site table looking for range that includes current PC. 
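+ // Each record is (start, length, landingPad, actionValue): start and
+ // length delimit a PC range relative to funcStart, landingPad is the
+ // offset of the handler/cleanup code for that range (0 = none), and
+ // actionValue is 1 + the offset into the action table whose entries
+ // list the type infos this landing pad can catch.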
+ uint8_t callSiteEncoding = *lsda++; + uint32_t callSiteTableLength = readULEB128(&lsda); + const uint8_t *callSiteTableStart = lsda; + const uint8_t *callSiteTableEnd = callSiteTableStart + callSiteTableLength; + const uint8_t *actionTableStart = callSiteTableEnd; + const uint8_t *callSitePtr = callSiteTableStart; + + while(callSitePtr < callSiteTableEnd) { + uintptr_t start = readEncodedPointer(&callSitePtr, + callSiteEncoding); + uintptr_t length = readEncodedPointer(&callSitePtr, + callSiteEncoding); + uintptr_t landingPad = readEncodedPointer(&callSitePtr, + callSiteEncoding); + uintptr_t actionValue = readULEB128(&callSitePtr); + + EH_LOG("call site (start=+%p, len=%d, landingPad=+%p, actionValue=%d)", + (void*)start, (int)length, (void*)landingPad, (int)actionValue); + + if(landingPad == 0) { + EH_LOG0("no landing pad, skipping"); + continue; + } + + if((start <= pcOffset) && (pcOffset < (start + length))) { + EH_LOG0("=> call site matches pc"); + + int exceptionMatched = 0; + if(actionValue) { + const uint8_t *actionEntry = actionTableStart + (actionValue - 1); + EH_LOG("actionEntry=%p", actionEntry); + + for(;;) { + // Each emitted DWARF action corresponds to a 2 tuple of + // type info address offset, and action offset to the next + // emitted action. + intptr_t typeInfoOffset = readSLEB128(&actionEntry); + const uint8_t *tempActionEntry = actionEntry; + intptr_t actionOffset = readSLEB128(&tempActionEntry); + EH_LOG("typeInfoOffset=%p actionOffset=%p", + (void*)typeInfoOffset, (void*)actionOffset); + EH_ASSERT((typeInfoOffset >= 0) && "Filter clauses are not supported"); + + unsigned encodingSize = getEncodingSize(ttypeEncoding); + const uint8_t *typeInfoPtrPtr = classInfo - typeInfoOffset * encodingSize; + uintptr_t typeInfoPtr = readEncodedPointer(&typeInfoPtrPtr, ttypeEncoding); + EH_LOG("encodingSize=%u typeInfoPtrPtr=%p typeInfoPtr=%p", + encodingSize, typeInfoPtrPtr, (void*)typeInfoPtr); + EH_LOG("typeInfo=%s", (char*)typeInfoPtr); + + if(typeInfoPtr == 0 || inflight->artiq.typeinfo == typeInfoPtr) { + EH_LOG0("matching action found"); + exceptionMatched = 1; + break; + } + + if (!actionOffset) + break; + + actionEntry += actionOffset; + } + } + + if(!(actions & _UA_SEARCH_PHASE)) { + EH_LOG0("=> jumping to landing pad"); + + if(actions & _UA_HANDLER_FRAME) + inflight->handled = 1; + + _Unwind_SetGR(context, __builtin_eh_return_data_regno(0), + (uintptr_t)exceptionObject); + _Unwind_SetGR(context, __builtin_eh_return_data_regno(1), + (uintptr_t)&inflight->artiq); + _Unwind_SetIP(context, funcStart + landingPad); + + return _URC_INSTALL_CONTEXT; + } else if(exceptionMatched) { + EH_LOG0("=> handler found"); + + return _URC_HANDLER_FOUND; + } else { + EH_LOG0("=> handler not found"); + + return _URC_CONTINUE_UNWIND; + } + } + } + + return _URC_CONTINUE_UNWIND; +} diff --git a/artiq/runtime/artiq_personality.h b/artiq/runtime/artiq_personality.h new file mode 100644 index 000000000..9e7ddce3e --- /dev/null +++ b/artiq/runtime/artiq_personality.h @@ -0,0 +1,59 @@ +#ifndef __ARTIQ_PERSONALITY_H +#define __ARTIQ_PERSONALITY_H + +#include +#include + +struct artiq_exception { + union { + uintptr_t typeinfo; + const char *name; + }; + const char *file; + int32_t line; + int32_t column; + const char *function; + const char *message; + int64_t param[3]; +}; + +struct artiq_backtrace_item { + intptr_t function; + intptr_t offset; +}; + +#ifdef __cplusplus +extern "C" { +#endif + +/* Provided by the runtime */ +void __artiq_raise(struct artiq_exception *artiq_exn) + 
__attribute__((noreturn)); +void __artiq_reraise(void) + __attribute__((noreturn)); + +#define artiq_raise_from_c(exnname, exnmsg, exnparam0, exnparam1, exnparam2) \ + do { \ + struct artiq_exception exn = { \ + .name = exnname, \ + .message = exnmsg, \ + .param = { exnparam0, exnparam1, exnparam2 }, \ + .file = __FILE__, \ + .line = __LINE__, \ + .column = -1, \ + .function = __func__, \ + }; \ + __artiq_raise(&exn); \ + } while(0) + +/* Called by the runtime */ +void __artiq_terminate(struct artiq_exception *artiq_exn, + struct artiq_backtrace_item *backtrace, + size_t backtrace_size) + __attribute__((noreturn)); + +#ifdef __cplusplus +} +#endif + +#endif /* __ARTIQ_PERSONALITY_H */ diff --git a/artiq/runtime/bridge.c b/artiq/runtime/bridge.c index 8ae83c1cc..d810ae5e8 100644 --- a/artiq/runtime/bridge.c +++ b/artiq/runtime/bridge.c @@ -38,7 +38,7 @@ static void send_ready(void) struct msg_base msg; msg.type = MESSAGE_TYPE_BRG_READY; - mailbox_send_and_wait(&msg); + mailbox_send_and_wait(&msg); } void bridge_main(void) diff --git a/artiq/runtime/dds.c b/artiq/runtime/dds.c index 7d120a0cd..3c3f435e5 100644 --- a/artiq/runtime/dds.c +++ b/artiq/runtime/dds.c @@ -1,7 +1,7 @@ #include #include -#include "exceptions.h" +#include "artiq_personality.h" #include "rtio.h" #include "log.h" #include "dds.h" @@ -177,7 +177,7 @@ static struct dds_set_params batch[DDS_MAX_BATCH]; void dds_batch_enter(long long int timestamp) { if(batch_mode) - exception_raise(EID_DDS_BATCH_ERROR); + artiq_raise_from_c("DDSBatchError", "DDS batch error", 0, 0, 0); batch_mode = 1; batch_count = 0; batch_ref_time = timestamp; @@ -189,7 +189,7 @@ void dds_batch_exit(void) int i; if(!batch_mode) - exception_raise(EID_DDS_BATCH_ERROR); + artiq_raise_from_c("DDSBatchError", "DDS batch error", 0, 0, 0); rtio_chan_sel_write(RTIO_DDS_CHANNEL); /* + FUD time */ now = batch_ref_time - batch_count*(DURATION_PROGRAM + DURATION_WRITE); @@ -207,7 +207,7 @@ void dds_set(long long int timestamp, int channel, { if(batch_mode) { if(batch_count >= DDS_MAX_BATCH) - exception_raise(EID_DDS_BATCH_ERROR); + artiq_raise_from_c("DDSBatchError", "DDS batch error", 0, 0, 0); /* timestamp parameter ignored (determined by batch) */ batch[batch_count].channel = channel; batch[batch_count].ftw = ftw; diff --git a/artiq/runtime/elf_loader.c b/artiq/runtime/elf_loader.c deleted file mode 100644 index 8604381e8..000000000 --- a/artiq/runtime/elf_loader.c +++ /dev/null @@ -1,240 +0,0 @@ -#include - -#include "log.h" -#include "elf_loader.h" - -#define EI_NIDENT 16 - -struct elf32_ehdr { - unsigned char ident[EI_NIDENT]; /* ident bytes */ - unsigned short type; /* file type */ - unsigned short machine; /* target machine */ - unsigned int version; /* file version */ - unsigned int entry; /* start address */ - unsigned int phoff; /* phdr file offset */ - unsigned int shoff; /* shdr file offset */ - unsigned int flags; /* file flags */ - unsigned short ehsize; /* sizeof ehdr */ - unsigned short phentsize; /* sizeof phdr */ - unsigned short phnum; /* number phdrs */ - unsigned short shentsize; /* sizeof shdr */ - unsigned short shnum; /* number shdrs */ - unsigned short shstrndx; /* shdr string index */ -} __attribute__((packed)); - -static const unsigned char elf_magic_header[] = { - 0x7f, 0x45, 0x4c, 0x46, /* 0x7f, 'E', 'L', 'F' */ - 0x01, /* Only 32-bit objects. */ - 0x02, /* Only big-endian. */ - 0x01, /* Only ELF version 1. */ -}; - -#define ET_NONE 0 /* Unknown type. */ -#define ET_REL 1 /* Relocatable. */ -#define ET_EXEC 2 /* Executable. 
*/ -#define ET_DYN 3 /* Shared object. */ -#define ET_CORE 4 /* Core file. */ - -#define EM_OR1K 0x005c - -struct elf32_shdr { - unsigned int name; /* section name */ - unsigned int type; /* SHT_... */ - unsigned int flags; /* SHF_... */ - unsigned int addr; /* virtual address */ - unsigned int offset; /* file offset */ - unsigned int size; /* section size */ - unsigned int link; /* misc info */ - unsigned int info; /* misc info */ - unsigned int addralign; /* memory alignment */ - unsigned int entsize; /* entry size if table */ -} __attribute__((packed)); - -struct elf32_name { - char name[12]; -} __attribute__((packed)); - -struct elf32_rela { - unsigned int offset; /* Location to be relocated. */ - unsigned int info; /* Relocation type and symbol index. */ - int addend; /* Addend. */ -} __attribute__((packed)); - -#define ELF32_R_SYM(info) ((info) >> 8) -#define ELF32_R_TYPE(info) ((unsigned char)(info)) - -#define R_OR1K_INSN_REL_26 6 - -struct elf32_sym { - unsigned int name; /* String table index of name. */ - unsigned int value; /* Symbol value. */ - unsigned int size; /* Size of associated object. */ - unsigned char info; /* Type and binding information. */ - unsigned char other; /* Reserved (not used). */ - unsigned short shndx; /* Section index of symbol. */ -} __attribute__((packed)); - -#define STT_NOTYPE 0 -#define STT_OBJECT 1 -#define STT_FUNC 2 -#define STT_SECTION 3 -#define STT_FILE 4 - -#define ELF32_ST_TYPE(info) ((info) & 0x0f) - - -#define SANITIZE_OFFSET_SIZE(offset, size) \ - if(offset > 0x10000000) { \ - log("Incorrect offset in ELF data"); \ - return 0; \ - } \ - if((offset + size) > elf_length) { \ - log("Attempted to access past the end of ELF data"); \ - return 0; \ - } - -#define GET_POINTER_SAFE(target, target_type, offset) \ - SANITIZE_OFFSET_SIZE(offset, sizeof(target_type)); \ - target = (target_type *)((char *)elf_data + offset) - -void *find_symbol(const struct symbol *symbols, const char *name) -{ - int i; - - i = 0; - while((symbols[i].name != NULL) && (strcmp(symbols[i].name, name) != 0)) - i++; - return symbols[i].target; -} - -static int fixup(void *dest, int dest_length, struct elf32_rela *rela, void *target) -{ - int type, offset; - unsigned int *_dest = dest; - unsigned int *_target = target; - - type = ELF32_R_TYPE(rela->info); - offset = rela->offset/4; - if(type == R_OR1K_INSN_REL_26) { - int val; - - val = _target - (_dest + offset); - _dest[offset] = (_dest[offset] & 0xfc000000) | (val & 0x03ffffff); - } else - log("Unsupported relocation type: %d", type); - return 1; -} - -int load_elf(symbol_resolver resolver, symbol_callback callback, void *elf_data, int elf_length, void *dest, int dest_length) -{ - struct elf32_ehdr *ehdr; - struct elf32_shdr *strtable; - unsigned int shdrptr; - int i; - - unsigned int textoff, textsize; - unsigned int textrelaoff, textrelasize; - unsigned int symtaboff, symtabsize; - unsigned int strtaboff, strtabsize; - - - /* validate ELF */ - GET_POINTER_SAFE(ehdr, struct elf32_ehdr, 0); - if(memcmp(ehdr->ident, elf_magic_header, sizeof(elf_magic_header)) != 0) { - log("Incorrect ELF header"); - return 0; - } - if(ehdr->type != ET_REL) { - log("ELF is not relocatable"); - return 0; - } - if(ehdr->machine != EM_OR1K) { - log("ELF is for a different machine"); - return 0; - } - - /* extract section info */ - GET_POINTER_SAFE(strtable, struct elf32_shdr, ehdr->shoff + ehdr->shentsize*ehdr->shstrndx); - textoff = textsize = 0; - textrelaoff = textrelasize = 0; - symtaboff = symtabsize = 0; - strtaboff = strtabsize = 0; - 
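- /* Walk the section header table, recording where .text, .rela.text,
- .symtab and .strtab live; every other section in the object is
- ignored. */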
shdrptr = ehdr->shoff; - for(i=0;ishnum;i++) { - struct elf32_shdr *shdr; - struct elf32_name *name; - - GET_POINTER_SAFE(shdr, struct elf32_shdr, shdrptr); - GET_POINTER_SAFE(name, struct elf32_name, strtable->offset + shdr->name); - - if(strncmp(name->name, ".text", 5) == 0) { - textoff = shdr->offset; - textsize = shdr->size; - } else if(strncmp(name->name, ".rela.text", 10) == 0) { - textrelaoff = shdr->offset; - textrelasize = shdr->size; - } else if(strncmp(name->name, ".symtab", 7) == 0) { - symtaboff = shdr->offset; - symtabsize = shdr->size; - } else if(strncmp(name->name, ".strtab", 7) == 0) { - strtaboff = shdr->offset; - strtabsize = shdr->size; - } - - shdrptr += ehdr->shentsize; - } - SANITIZE_OFFSET_SIZE(textoff, textsize); - SANITIZE_OFFSET_SIZE(textrelaoff, textrelasize); - SANITIZE_OFFSET_SIZE(symtaboff, symtabsize); - SANITIZE_OFFSET_SIZE(strtaboff, strtabsize); - - /* load .text section */ - if(textsize > dest_length) { - log(".text section is too large"); - return 0; - } - memcpy(dest, (char *)elf_data + textoff, textsize); - - /* process .text relocations */ - for(i=0;iinfo)); - if(sym->name != 0) { - char *name; - void *target; - - name = (char *)elf_data + strtaboff + sym->name; - target = resolver(name); - if(target == NULL) { - log("Undefined symbol: %s", name); - return 0; - } - if(!fixup(dest, dest_length, rela, target)) - return 0; - } else { - log("Unsupported relocation"); - return 0; - } - } - - /* list provided functions via callback */ - for(i=0;iinfo) == STT_FUNC) && (sym->name != 0)) { - char *name; - void *target; - - name = (char *)elf_data + strtaboff + sym->name; - target = (char *)dest + sym->value; - if(!callback(name, target)) - return 0; - } - } - - return 1; -} diff --git a/artiq/runtime/elf_loader.h b/artiq/runtime/elf_loader.h deleted file mode 100644 index a116e0851..000000000 --- a/artiq/runtime/elf_loader.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef __ELF_LOADER_H -#define __ELF_LOADER_H - -struct symbol { - char *name; - void *target; -}; - -typedef void * (*symbol_resolver)(const char *); -typedef int (*symbol_callback)(const char *, void *); - -void *find_symbol(const struct symbol *symbols, const char *name); -/* elf_data must be aligned on a 32-bit boundary */ -int load_elf(symbol_resolver resolver, symbol_callback callback, void *elf_data, int elf_length, void *dest, int dest_length); - -#endif /* __ELF_LOADER_H */ diff --git a/artiq/runtime/exception_jmp.S b/artiq/runtime/exception_jmp.S deleted file mode 100644 index 014422960..000000000 --- a/artiq/runtime/exception_jmp.S +++ /dev/null @@ -1,37 +0,0 @@ -.global exception_setjmp -.type exception_setjmp, @function -exception_setjmp: - l.sw 0(r3), r1 - l.sw 4(r3), r2 - l.sw 8(r3), r9 - l.sw 12(r3), r10 - l.sw 16(r3), r14 - l.sw 20(r3), r16 - l.sw 24(r3), r18 - l.sw 28(r3), r20 - l.sw 32(r3), r22 - l.sw 36(r3), r24 - l.sw 40(r3), r26 - l.sw 44(r3), r28 - l.sw 48(r3), r30 - l.jr r9 - l.ori r11, r0, 0 - -.global exception_longjmp -.type exception_longjmp, @function -exception_longjmp: - l.lwz r1, 0(r3) - l.lwz r2, 4(r3) - l.lwz r9, 8(r3) - l.lwz r10, 12(r3) - l.lwz r14, 16(r3) - l.lwz r16, 20(r3) - l.lwz r18, 24(r3) - l.lwz r20, 28(r3) - l.lwz r22, 32(r3) - l.lwz r24, 36(r3) - l.lwz r26, 40(r3) - l.lwz r28, 44(r3) - l.lwz r30, 48(r3) - l.jr r9 - l.ori r11, r0, 1 diff --git a/artiq/runtime/exceptions.c b/artiq/runtime/exceptions.c deleted file mode 100644 index 5c82f5c43..000000000 --- a/artiq/runtime/exceptions.c +++ /dev/null @@ -1,58 +0,0 @@ -#include - -#include "log.h" -#include 
"exceptions.h" - -#define MAX_EXCEPTION_CONTEXTS 64 - -struct exception_context { - void *jb[13]; -}; - -static struct exception_context exception_contexts[MAX_EXCEPTION_CONTEXTS]; -static int ec_top; -static int stored_id; -static long long int stored_params[3]; - -void *exception_push(void) -{ - if(ec_top >= MAX_EXCEPTION_CONTEXTS) - exception_raise(EID_INTERNAL_ERROR); - return exception_contexts[ec_top++].jb; -} - -void exception_pop(int levels) -{ - ec_top -= levels; -} - -int exception_getid(long long int *eparams) -{ - int i; - - if(eparams) - for(i=0;i<3;i++) - eparams[i] = stored_params[i]; - return stored_id; -} - -void exception_raise(int id) -{ - exception_raise_params(id, 0, 0, 0); -} - -void exception_raise_params(int id, - long long int p0, long long int p1, - long long int p2) -{ - if(ec_top > 0) { - stored_id = id; - stored_params[0] = p0; - stored_params[1] = p1; - stored_params[2] = p2; - exception_longjmp(exception_contexts[--ec_top].jb); - } else { - log("ERROR: uncaught exception, ID=%d\n", id); - while(1); - } -} diff --git a/artiq/runtime/exceptions.h b/artiq/runtime/exceptions.h deleted file mode 100644 index ae97b9d95..000000000 --- a/artiq/runtime/exceptions.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef __EXCEPTIONS_H -#define __EXCEPTIONS_H - -enum { - EID_NONE = 0, - EID_INTERNAL_ERROR = 1, - EID_RPC_EXCEPTION = 2, - EID_RTIO_UNDERFLOW = 3, - EID_RTIO_SEQUENCE_ERROR = 4, - EID_RTIO_COLLISION_ERROR = 5, - EID_RTIO_OVERFLOW = 6, - EID_DDS_BATCH_ERROR = 7 -}; - -int exception_setjmp(void *jb) __attribute__((returns_twice)); -void exception_longjmp(void *jb) __attribute__((noreturn)); - -void *exception_push(void); -void exception_pop(int levels); -int exception_getid(long long int *eparams); -void exception_raise(int id) __attribute__((noreturn)); -void exception_raise_params(int id, - long long int p0, long long int p1, - long long int p2) __attribute__((noreturn)); - -#endif /* __EXCEPTIONS_H */ diff --git a/artiq/runtime/flash_storage.c b/artiq/runtime/flash_storage.c index 34bc3a508..a160409a8 100644 --- a/artiq/runtime/flash_storage.c +++ b/artiq/runtime/flash_storage.c @@ -115,7 +115,8 @@ static int is_empty(struct record *record) return record->value_len == 0; } -static int key_exists(char *buff, char *key, char *end, char accept_empty, struct record *found_record) +static int key_exists(char *buff, const char *key, char *end, char accept_empty, + struct record *found_record) { struct iter_state is; struct record iter_record; @@ -170,7 +171,7 @@ static char check_for_empty_records(char *buff) return 0; } -static unsigned int try_to_flush_duplicates(char *new_key, unsigned int buf_len) +static unsigned int try_to_flush_duplicates(const char *new_key, unsigned int buf_len) { unsigned int key_size, new_record_size, ret = 0, can_rollback = 0; struct record record, previous_record; @@ -210,7 +211,8 @@ static unsigned int try_to_flush_duplicates(char *new_key, unsigned int buf_len) return ret; } -static void write_at_offset(char *key, void *buffer, int buf_len, unsigned int sector_offset) +static void write_at_offset(const char *key, const void *buffer, + int buf_len, unsigned int sector_offset) { int key_len = strlen(key) + 1; unsigned int record_size = key_len + buf_len + sizeof(record_size); @@ -223,7 +225,7 @@ static void write_at_offset(char *key, void *buffer, int buf_len, unsigned int s } -int fs_write(char *key, void *buffer, unsigned int buf_len) +int fs_write(const char *key, const void *buffer, unsigned int buf_len) { struct record record; unsigned int 
key_size = strlen(key) + 1; @@ -269,7 +271,7 @@ void fs_erase(void) flush_cpu_dcache(); } -unsigned int fs_read(char *key, void *buffer, unsigned int buf_len, unsigned int *remain) +unsigned int fs_read(const char *key, void *buffer, unsigned int buf_len, unsigned int *remain) { unsigned int read_length = 0; struct iter_state is; @@ -295,7 +297,7 @@ unsigned int fs_read(char *key, void *buffer, unsigned int buf_len, unsigned int return read_length; } -void fs_remove(char *key) +void fs_remove(const char *key) { fs_write(key, NULL, 0); } diff --git a/artiq/runtime/flash_storage.h b/artiq/runtime/flash_storage.h index 9994fef37..e983de778 100644 --- a/artiq/runtime/flash_storage.h +++ b/artiq/runtime/flash_storage.h @@ -5,9 +5,9 @@ #ifndef __FLASH_STORAGE_H #define __FLASH_STORAGE_H -void fs_remove(char *key); +void fs_remove(const char *key); void fs_erase(void); -int fs_write(char *key, void *buffer, unsigned int buflen); -unsigned int fs_read(char *key, void *buffer, unsigned int buflen, unsigned int *remain); +int fs_write(const char *key, const void *buffer, unsigned int buflen); +unsigned int fs_read(const char *key, void *buffer, unsigned int buflen, unsigned int *remain); #endif /* __FLASH_STORAGE_H */ diff --git a/artiq/runtime/gen_service_table.py b/artiq/runtime/gen_service_table.py deleted file mode 100755 index 8038d87b6..000000000 --- a/artiq/runtime/gen_service_table.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3.5 - -import sys - -from elftools.elf.elffile import ELFFile - - -services = [ - ("syscalls", [ - ("now_init", "now_init"), - ("now_save", "now_save"), - - ("watchdog_set", "watchdog_set"), - ("watchdog_clear", "watchdog_clear"), - - ("rpc", "rpc"), - - ("rtio_get_counter", "rtio_get_counter"), - - ("ttl_set_o", "ttl_set_o"), - ("ttl_set_oe", "ttl_set_oe"), - ("ttl_set_sensitivity", "ttl_set_sensitivity"), - ("ttl_get", "ttl_get"), - ("ttl_clock_set", "ttl_clock_set"), - - ("dds_init", "dds_init"), - ("dds_batch_enter", "dds_batch_enter"), - ("dds_batch_exit", "dds_batch_exit"), - ("dds_set", "dds_set"), - ]), - - ("eh", [ - ("setjmp", "exception_setjmp"), - ("push", "exception_push"), - ("pop", "exception_pop"), - ("getid", "exception_getid"), - ("raise", "exception_raise"), - ]) -] - - -def print_service_table(ksupport_elf_filename): - with open(ksupport_elf_filename, "rb") as f: - elf = ELFFile(f) - symtab = elf.get_section_by_name(b".symtab") - symbols = {symbol.name: symbol.entry.st_value - for symbol in symtab.iter_symbols()} - for name, contents in services: - print("static const struct symbol {}[] = {{".format(name)) - for name, value in contents: - print(" {{\"{}\", (void *)0x{:08x}}}," - .format(name, symbols[bytes(value, "ascii")])) - print(" {NULL, NULL}") - print("};") - - -def main(): - if len(sys.argv) == 2: - print_service_table(sys.argv[1]) - else: - print("Incorrect number of command line arguments") - sys.exit(1) - -if __name__ == "__main__": - main() diff --git a/artiq/runtime/kloader.c b/artiq/runtime/kloader.c index 4fff4790a..f56bf020d 100644 --- a/artiq/runtime/kloader.c +++ b/artiq/runtime/kloader.c @@ -1,126 +1,109 @@ #include #include +#include + +#include "kloader.h" #include "log.h" #include "clock.h" #include "flash_storage.h" #include "mailbox.h" #include "messages.h" -#include "elf_loader.h" -#include "services.h" -#include "kloader.h" -static struct symbol symtab[128]; -static int _symtab_count; -static char _symtab_strings[128*16]; -static char *_symtab_strptr; - -static void symtab_init(void) +static void 
start_kernel_cpu(struct msg_load_request *msg) { - memset(symtab, 0, sizeof(symtab)); - _symtab_count = 0; - _symtab_strptr = _symtab_strings; -} + // Stop kernel CPU before messing with its code. + kernel_cpu_reset_write(1); -static int symtab_add(const char *name, void *target) -{ - if(_symtab_count >= sizeof(symtab)/sizeof(symtab[0])) { - log("Too many provided symbols in object"); - symtab_init(); - return 0; - } - symtab[_symtab_count].name = _symtab_strptr; - symtab[_symtab_count].target = target; - _symtab_count++; + // Load kernel support code. + extern void _binary_ksupport_elf_start, _binary_ksupport_elf_end; + memcpy((void *)(KERNELCPU_EXEC_ADDRESS - KSUPPORT_HEADER_SIZE), + &_binary_ksupport_elf_start, + &_binary_ksupport_elf_end - &_binary_ksupport_elf_start); - while(1) { - if(_symtab_strptr >= &_symtab_strings[sizeof(_symtab_strings)]) { - log("Provided symbol string table overflow"); - symtab_init(); - return 0; - } - *_symtab_strptr = *name; - _symtab_strptr++; - if(*name == 0) - break; - name++; - } - - return 1; -} - -int kloader_load(void *buffer, int length) -{ - if(!kernel_cpu_reset_read()) { - log("BUG: attempted to load while kernel CPU running"); - return 0; - } - symtab_init(); - return load_elf( - resolve_service_symbol, symtab_add, - buffer, length, (void *)KERNELCPU_PAYLOAD_ADDRESS, 4*1024*1024); -} - -kernel_function kloader_find(const char *name) -{ - return find_symbol(symtab, name); -} - -extern char _binary_ksupport_bin_start; -extern char _binary_ksupport_bin_end; - -static void start_kernel_cpu(void *addr) -{ - memcpy((void *)KERNELCPU_EXEC_ADDRESS, &_binary_ksupport_bin_start, - &_binary_ksupport_bin_end - &_binary_ksupport_bin_start); - mailbox_acknowledge(); - mailbox_send(addr); + // Start kernel CPU. + mailbox_send(msg); kernel_cpu_reset_write(0); } -void kloader_start_bridge(void) +void kloader_start_bridge() { start_kernel_cpu(NULL); } -void kloader_start_user_kernel(kernel_function k) +static int load_or_start_kernel(const void *library, int run_kernel) +{ + static struct dyld_info library_info; + struct msg_load_request request = { + .library = library, + .library_info = &library_info, + .run_kernel = run_kernel, + }; + start_kernel_cpu(&request); + + struct msg_load_reply *reply = mailbox_wait_and_receive(); + mailbox_acknowledge(); + + if(reply->type != MESSAGE_TYPE_LOAD_REPLY) { + log("BUG: unexpected reply to load/run request"); + return 0; + } + + if(reply->error != NULL) { + log("cannot load kernel: %s", reply->error); + return 0; + } + + return 1; +} + +int kloader_load_library(const void *library) { if(!kernel_cpu_reset_read()) { - log("BUG: attempted to start kernel CPU while already running (user kernel)"); - return; + log("BUG: attempted to load kernel library while kernel CPU is running"); + return 0; } - start_kernel_cpu((void *)k); + + return load_or_start_kernel(library, 0); +} + +void kloader_filter_backtrace(struct artiq_backtrace_item *backtrace, + size_t *backtrace_size) { + struct artiq_backtrace_item *cursor = backtrace; + + // Remove all backtrace items belonging to ksupport and subtract + // shared object base from the addresses. 
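+ // For example, with KERNELCPU_PAYLOAD_ADDRESS = 0x40420000 a raw
+ // return address of 0x40421234 is kept and rewritten to 0x1234 (an
+ // offset the host can map back into the kernel library), whereas
+ // frames below the payload base, i.e. inside ksupport itself, are
+ // dropped.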
+ for(int i = 0; i < *backtrace_size; i++) { + if(backtrace[i].function > KERNELCPU_PAYLOAD_ADDRESS) { + backtrace[i].function -= KERNELCPU_PAYLOAD_ADDRESS; + *cursor++ = backtrace[i]; + } + } + + *backtrace_size = cursor - backtrace; +} + +void kloader_start_kernel() +{ + load_or_start_kernel(NULL, 1); } static int kloader_start_flash_kernel(char *key) { - char buffer[32*1024]; - unsigned int len, remain; - kernel_function k; - - if(!kernel_cpu_reset_read()) { - log("BUG: attempted to start kernel CPU while already running (%s)", key); - return 0; - } #if (defined CSR_SPIFLASH_BASE && defined SPIFLASH_PAGE_SIZE) - len = fs_read(key, buffer, sizeof(buffer), &remain); - if(len <= 0) + char buffer[32*1024]; + unsigned int length, remain; + + length = fs_read(key, buffer, sizeof(buffer), &remain); + if(length <= 0) return 0; + if(remain) { - log("ERROR: %s too long", key); + log("ERROR: kernel %s is too large", key); return 0; } - if(!kloader_load(buffer, len)) { - log("ERROR: failed to load ELF binary (%s)", key); - return 0; - } - k = kloader_find("run"); - if(!k) { - log("ERROR: failed to find entry point for ELF kernel (%s)", key); - return 0; - } - start_kernel_cpu((void *)k); - return 1; + + return load_or_start_kernel(buffer, 1); #else return 0; #endif @@ -145,7 +128,7 @@ void kloader_stop(void) int kloader_validate_kpointer(void *p) { unsigned int v = (unsigned int)p; - if((v < 0x40400000) || (v > (0x4fffffff - 1024*1024))) { + if((v < KERNELCPU_EXEC_ADDRESS) || (v > KERNELCPU_LAST_ADDRESS)) { log("Received invalid pointer from kernel CPU: 0x%08x", v); return 0; } @@ -195,7 +178,11 @@ void kloader_service_essential_kmsg(void) case MESSAGE_TYPE_LOG: { struct msg_log *msg = (struct msg_log *)umsg; - log_va(msg->fmt, msg->args); + if(msg->no_newline) { + lognonl_va(msg->fmt, msg->args); + } else { + log_va(msg->fmt, msg->args); + } mailbox_acknowledge(); break; } diff --git a/artiq/runtime/kloader.h b/artiq/runtime/kloader.h index 8f6091d28..d2527eb71 100644 --- a/artiq/runtime/kloader.h +++ b/artiq/runtime/kloader.h @@ -1,20 +1,23 @@ #ifndef __KLOADER_H #define __KLOADER_H -#define KERNELCPU_EXEC_ADDRESS 0x40400000 -#define KERNELCPU_PAYLOAD_ADDRESS 0x40408000 +#include "artiq_personality.h" + +#define KERNELCPU_EXEC_ADDRESS 0x40400000 +#define KERNELCPU_PAYLOAD_ADDRESS 0x40420000 +#define KERNELCPU_LAST_ADDRESS (0x4fffffff - 1024*1024) +#define KSUPPORT_HEADER_SIZE 0x80 extern long long int now; -typedef void (*kernel_function)(void); - -int kloader_load(void *buffer, int length); -kernel_function kloader_find(const char *name); +int kloader_load_library(const void *code); +void kloader_filter_backtrace(struct artiq_backtrace_item *backtrace, + size_t *backtrace_size); void kloader_start_bridge(void); int kloader_start_startup_kernel(void); int kloader_start_idle_kernel(void); -void kloader_start_user_kernel(kernel_function k); +void kloader_start_kernel(void); void kloader_stop(void); int kloader_validate_kpointer(void *p); diff --git a/artiq/runtime/ksupport.c b/artiq/runtime/ksupport.c index 2331f5195..9f04b2bf9 100644 --- a/artiq/runtime/ksupport.c +++ b/artiq/runtime/ksupport.c @@ -1,67 +1,253 @@ #include +#include +#include -#include "exceptions.h" -#include "bridge.h" +#include +#include +#include +#include + +#include "ksupport.h" +#include "kloader.h" #include "mailbox.h" #include "messages.h" -#include "rtio.h" +#include "bridge.h" +#include "artiq_personality.h" +#include "ttl.h" #include "dds.h" +#include "rtio.h" -/* for the prototypes for watchdog_set() and 
watchdog_clear() */ -#include "clock.h" -/* for the prototype for rpc() */ -#include "session.h" -/* for the prototype for log() */ -#include "log.h" +void ksupport_abort(void); -void exception_handler(unsigned long vect, unsigned long *sp); -void exception_handler(unsigned long vect, unsigned long *sp) +int64_t now; + +/* compiler-rt symbols */ +extern void __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __eqdf2, __ltdf2, + __nedf2, __gtdf2, __negsf2, __negdf2, __addsf3, __subsf3, __mulsf3, + __divsf3, __lshrdi3, __muldi3, __divdi3, __ashldi3, __ashrdi3, + __udivmoddi4, __floatsisf, __floatunsisf, __fixsfsi, __fixunssfsi, + __adddf3, __subdf3, __muldf3, __divdf3, __floatsidf, __floatunsidf, + __floatdidf, __fixdfsi, __fixdfdi, __fixunsdfsi, __clzsi2, __ctzsi2, + __udivdi3, __umoddi3, __moddi3; + +/* artiq_personality symbols */ +extern void __artiq_personality; + +struct symbol { + const char *name; + void *addr; +}; + +static const struct symbol runtime_exports[] = { + /* compiler-rt */ + {"__divsi3", &__divsi3}, + {"__modsi3", &__modsi3}, + {"__ledf2", &__ledf2}, + {"__gedf2", &__gedf2}, + {"__unorddf2", &__unorddf2}, + {"__eqdf2", &__eqdf2}, + {"__ltdf2", &__ltdf2}, + {"__nedf2", &__nedf2}, + {"__gtdf2", &__gtdf2}, + {"__negsf2", &__negsf2}, + {"__negdf2", &__negdf2}, + {"__addsf3", &__addsf3}, + {"__subsf3", &__subsf3}, + {"__mulsf3", &__mulsf3}, + {"__divsf3", &__divsf3}, + {"__lshrdi3", &__lshrdi3}, + {"__muldi3", &__muldi3}, + {"__divdi3", &__divdi3}, + {"__ashldi3", &__ashldi3}, + {"__ashrdi3", &__ashrdi3}, + {"__udivmoddi4", &__udivmoddi4}, + {"__floatsisf", &__floatsisf}, + {"__floatunsisf", &__floatunsisf}, + {"__fixsfsi", &__fixsfsi}, + {"__fixunssfsi", &__fixunssfsi}, + {"__adddf3", &__adddf3}, + {"__subdf3", &__subdf3}, + {"__muldf3", &__muldf3}, + {"__divdf3", &__divdf3}, + {"__floatsidf", &__floatsidf}, + {"__floatunsidf", &__floatunsidf}, + {"__floatdidf", &__floatdidf}, + {"__fixdfsi", &__fixdfsi}, + {"__fixdfdi", &__fixdfdi}, + {"__fixunsdfsi", &__fixunsdfsi}, + {"__clzsi2", &__clzsi2}, + {"__ctzsi2", &__ctzsi2}, + {"__udivdi3", &__udivdi3}, + {"__umoddi3", &__umoddi3}, + {"__moddi3", &__moddi3}, + + /* exceptions */ + {"_Unwind_Resume", &_Unwind_Resume}, + {"__artiq_personality", &__artiq_personality}, + {"__artiq_raise", &__artiq_raise}, + {"__artiq_reraise", &__artiq_reraise}, + {"abort", &ksupport_abort}, + + /* proxified syscalls */ + {"now", &now}, + + {"watchdog_set", &watchdog_set}, + {"watchdog_clear", &watchdog_clear}, + + {"log", &log}, + {"lognonl", &lognonl}, + {"send_rpc", &send_rpc}, + {"recv_rpc", &recv_rpc}, + + /* direct syscalls */ + {"rtio_get_counter", &rtio_get_counter}, + + {"ttl_set_o", &ttl_set_o}, + {"ttl_set_oe", &ttl_set_oe}, + {"ttl_set_sensitivity", &ttl_set_sensitivity}, + {"ttl_get", &ttl_get}, + {"ttl_clock_set", &ttl_clock_set}, + + {"dds_init", &dds_init}, + {"dds_batch_enter", &dds_batch_enter}, + {"dds_batch_exit", &dds_batch_exit}, + {"dds_set", &dds_set}, + + /* end */ + {NULL, NULL} +}; + +/* called by libunwind */ +int fprintf(FILE *stream, const char *fmt, ...) 
{ - struct msg_exception msg; + struct msg_log request; - msg.type = MESSAGE_TYPE_EXCEPTION; - msg.eid = EID_INTERNAL_ERROR; - msg.eparams[0] = 256; - msg.eparams[1] = 256; - msg.eparams[2] = 256; - mailbox_send_and_wait(&msg); - while(1); + request.type = MESSAGE_TYPE_LOG; + request.fmt = fmt; + request.no_newline = 1; + va_start(request.args, fmt); + mailbox_send_and_wait(&request); + va_end(request.args); + + return 0; } -typedef void (*kernel_function)(void); +/* called by libunwind */ +int dladdr (const void *address, Dl_info *info) { + /* we don't try to resolve names */ + return 0; +} + +/* called by libunwind */ +int dl_iterate_phdr (int (*callback) (struct dl_phdr_info *, size_t, void *), void *data) { + Elf32_Ehdr *ehdr; + struct dl_phdr_info phdr_info; + int retval; + + ehdr = (Elf32_Ehdr *)(KERNELCPU_EXEC_ADDRESS - KSUPPORT_HEADER_SIZE); + phdr_info = (struct dl_phdr_info){ + .dlpi_addr = 0, /* absolutely linked */ + .dlpi_name = "", + .dlpi_phdr = (Elf32_Phdr*) ((intptr_t)ehdr + ehdr->e_phoff), + .dlpi_phnum = ehdr->e_phnum, + }; + retval = callback(&phdr_info, sizeof(phdr_info), data); + if(retval) + return retval; + + ehdr = (Elf32_Ehdr *)KERNELCPU_PAYLOAD_ADDRESS; + phdr_info = (struct dl_phdr_info){ + .dlpi_addr = KERNELCPU_PAYLOAD_ADDRESS, + .dlpi_name = "", + .dlpi_phdr = (Elf32_Phdr*) ((intptr_t)ehdr + ehdr->e_phoff), + .dlpi_phnum = ehdr->e_phnum, + }; + retval = callback(&phdr_info, sizeof(phdr_info), data); + return retval; +} + +static Elf32_Addr resolve_runtime_export(const char *name) { + const struct symbol *sym = runtime_exports; + while(sym->name) { + if(!strcmp(sym->name, name)) + return (Elf32_Addr)sym->addr; + ++sym; + } + return 0; +} + +void exception_handler(unsigned long vect, unsigned long *regs, + unsigned long pc, unsigned long ea); +void exception_handler(unsigned long vect, unsigned long *regs, + unsigned long pc, unsigned long ea) +{ + artiq_raise_from_c("InternalError", + "Hardware exception {0} at PC {1}, EA {2}", + vect, pc, ea); +} int main(void); int main(void) { - kernel_function k; - void *jb; + struct msg_load_request *request = mailbox_receive(); + struct msg_load_reply load_reply = { + .type = MESSAGE_TYPE_LOAD_REPLY, + .error = NULL + }; - k = mailbox_receive(); - - if(k == NULL) + if(request == NULL) { bridge_main(); - else { - jb = exception_push(); - if(exception_setjmp(jb)) { - struct msg_exception msg; + while(1); + } - msg.type = MESSAGE_TYPE_EXCEPTION; - msg.eid = exception_getid(msg.eparams); - mailbox_send_and_wait(&msg); - } else { - struct msg_base msg; - - k(); - exception_pop(1); - - msg.type = MESSAGE_TYPE_FINISHED; - mailbox_send_and_wait(&msg); + if(request->library != NULL) { + if(!dyld_load(request->library, KERNELCPU_PAYLOAD_ADDRESS, + resolve_runtime_export, request->library_info, + &load_reply.error)) { + mailbox_send(&load_reply); + while(1); } } + + if(request->run_kernel) { + void (*kernel_init)() = request->library_info->init; + + mailbox_send_and_wait(&load_reply); + + now = now_init(); + kernel_init(); + now_save(now); + + struct msg_base finished_reply; + finished_reply.type = MESSAGE_TYPE_FINISHED; + mailbox_send_and_wait(&finished_reply); + } else { + mailbox_send(&load_reply); + } + while(1); } -long long int now_init(void); +/* called from __artiq_personality */ +void __artiq_terminate(struct artiq_exception *artiq_exn, + struct artiq_backtrace_item *backtrace, + size_t backtrace_size) { + struct msg_exception msg; + + msg.type = MESSAGE_TYPE_EXCEPTION; + msg.exception = artiq_exn; + msg.backtrace = 
backtrace; + msg.backtrace_size = backtrace_size; + mailbox_send(&msg); + + while(1); +} + +void ksupport_abort() { + artiq_raise_from_c("InternalError", "abort() called; check device log for details", + 0, 0, 0); +} + long long int now_init(void) { struct msg_base request; @@ -72,8 +258,11 @@ long long int now_init(void) mailbox_send_and_wait(&request); reply = mailbox_wait_and_receive(); - if(reply->type != MESSAGE_TYPE_NOW_INIT_REPLY) - exception_raise_params(EID_INTERNAL_ERROR, 1, 0, 0); + if(reply->type != MESSAGE_TYPE_NOW_INIT_REPLY) { + log("Malformed MESSAGE_TYPE_NOW_INIT_REQUEST reply type %d", + reply->type); + while(1); + } now = reply->now; mailbox_acknowledge(); @@ -85,7 +274,6 @@ long long int now_init(void) return now; } -void now_save(long long int now); void now_save(long long int now) { struct msg_now_save request; @@ -106,8 +294,11 @@ int watchdog_set(int ms) mailbox_send_and_wait(&request); reply = mailbox_wait_and_receive(); - if(reply->type != MESSAGE_TYPE_WATCHDOG_SET_REPLY) - exception_raise_params(EID_INTERNAL_ERROR, 2, 0, 0); + if(reply->type != MESSAGE_TYPE_WATCHDOG_SET_REPLY) { + log("Malformed MESSAGE_TYPE_WATCHDOG_SET_REQUEST reply type %d", + reply->type); + while(1); + } id = reply->id; mailbox_acknowledge(); @@ -123,28 +314,56 @@ void watchdog_clear(int id) mailbox_send_and_wait(&request); } -int rpc(int rpc_num, ...) +void send_rpc(int service, const char *tag, ...) { - struct msg_rpc_request request; - struct msg_rpc_reply *reply; - int eid, retval; + struct msg_rpc_send request; - request.type = MESSAGE_TYPE_RPC_REQUEST; - request.rpc_num = rpc_num; - va_start(request.args, rpc_num); + request.type = MESSAGE_TYPE_RPC_SEND; + request.service = service; + request.tag = tag; + va_start(request.args, tag); mailbox_send_and_wait(&request); va_end(request.args); +} + +int recv_rpc(void *slot) { + struct msg_rpc_recv_request request; + struct msg_rpc_recv_reply *reply; + + request.type = MESSAGE_TYPE_RPC_RECV_REQUEST; + request.slot = slot; + mailbox_send_and_wait(&request); reply = mailbox_wait_and_receive(); - if(reply->type != MESSAGE_TYPE_RPC_REPLY) - exception_raise_params(EID_INTERNAL_ERROR, 3, 0, 0); - eid = reply->eid; - retval = reply->retval; - mailbox_acknowledge(); + if(reply->type != MESSAGE_TYPE_RPC_RECV_REPLY) { + log("Malformed MESSAGE_TYPE_RPC_RECV_REQUEST reply type %d", + reply->type); + while(1); + } - if(eid != EID_NONE) - exception_raise(eid); - return retval; + if(reply->exception) { + struct artiq_exception exception; + memcpy(&exception, reply->exception, + sizeof(struct artiq_exception)); + mailbox_acknowledge(); + __artiq_raise(&exception); + } else { + int alloc_size = reply->alloc_size; + mailbox_acknowledge(); + return alloc_size; + } +} + +void lognonl(const char *fmt, ...) +{ + struct msg_log request; + + request.type = MESSAGE_TYPE_LOG; + request.fmt = fmt; + request.no_newline = 1; + va_start(request.args, fmt); + mailbox_send_and_wait(&request); + va_end(request.args); } void log(const char *fmt, ...) @@ -153,6 +372,7 @@ void log(const char *fmt, ...) 
request.type = MESSAGE_TYPE_LOG; request.fmt = fmt; + request.no_newline = 0; va_start(request.args, fmt); mailbox_send_and_wait(&request); va_end(request.args); diff --git a/artiq/runtime/ksupport.h b/artiq/runtime/ksupport.h new file mode 100644 index 000000000..88dc7e2a0 --- /dev/null +++ b/artiq/runtime/ksupport.h @@ -0,0 +1,13 @@ +#ifndef __KSTARTUP_H +#define __KSTARTUP_H + +long long int now_init(void); +void now_save(long long int now); +int watchdog_set(int ms); +void watchdog_clear(int id); +void send_rpc(int service, const char *tag, ...); +int recv_rpc(void *slot); +void lognonl(const char *fmt, ...); +void log(const char *fmt, ...); + +#endif /* __KSTARTUP_H */ diff --git a/artiq/runtime/ksupport.ld b/artiq/runtime/ksupport.ld index 3cc585399..9f9ca4bb9 100644 --- a/artiq/runtime/ksupport.ld +++ b/artiq/runtime/ksupport.ld @@ -4,10 +4,10 @@ ENTRY(_start) INCLUDE generated/regions.ld /* First 4M of main memory are reserved for runtime code/data - * then comes kernel memory. First 32K of kernel memory are for support code. + * then comes kernel memory. First 128K of kernel memory are for support code. */ MEMORY { - ksupport : ORIGIN = 0x40400000, LENGTH = 0x8000 + ksupport (RWX) : ORIGIN = 0x40400000, LENGTH = 0x20000 } /* On AMP systems, kernel stack is at the end of main RAM, @@ -15,6 +15,13 @@ MEMORY { */ PROVIDE(_fstack = 0x40000000 + LENGTH(main_ram) - 1024*1024 - 4); +/* Force ld to make the ELF header as loadable. */ +PHDRS +{ + text PT_LOAD FILEHDR PHDRS; + eh_frame PT_GNU_EH_FRAME; +} + SECTIONS { .text : @@ -22,7 +29,7 @@ SECTIONS _ftext = .; *(.text .stub .text.* .gnu.linkonce.t.*) _etext = .; - } > ksupport + } :text .rodata : { @@ -33,6 +40,16 @@ SECTIONS _erodata = .; } > ksupport + .eh_frame : + { + *(.eh_frame) + } :text + + .eh_frame_hdr : + { + *(.eh_frame_hdr) + } :text :eh_frame + .data : { . = ALIGN(4); @@ -41,7 +58,7 @@ SECTIONS *(.data1) *(.sdata .sdata.* .gnu.linkonce.s.*) _edata = .; - } > ksupport + } .bss : { @@ -57,5 +74,5 @@ SECTIONS _ebss = .; . = ALIGN(8); _heapstart = .; - } > ksupport + } } diff --git a/artiq/runtime/log.c b/artiq/runtime/log.c index 4f1750f2f..6ac28fc1e 100644 --- a/artiq/runtime/log.c +++ b/artiq/runtime/log.c @@ -1,5 +1,7 @@ #include #include +#include +#include #include @@ -8,7 +10,7 @@ static int buffer_index; static char buffer[LOG_BUFFER_SIZE]; -void log_va(const char *fmt, va_list args) +void lognonl_va(const char *fmt, va_list args) { char outbuf[256]; int i, len; @@ -18,16 +20,29 @@ void log_va(const char *fmt, va_list args) buffer[buffer_index] = outbuf[i]; buffer_index = (buffer_index + 1) % LOG_BUFFER_SIZE; } - buffer[buffer_index] = '\n'; - buffer_index = (buffer_index + 1) % LOG_BUFFER_SIZE; #ifdef CSR_ETHMAC_BASE /* Since main comms are over ethernet, the serial port * is free for us to use. */ - puts(outbuf); + putsnonl(outbuf); #endif } +void lognonl(const char *fmt, ...) +{ + va_list args; + + va_start(args, fmt); + lognonl_va(fmt, args); + va_end(args); +} + +void log_va(const char *fmt, va_list args) +{ + lognonl_va(fmt, args); + lognonl("\n"); +} + void log(const char *fmt, ...) 
{ va_list args; @@ -41,9 +56,14 @@ void log_get(char *outbuf) { int i, j; - j = buffer_index + 1; - for(i=0;i 0) { - for(i=0;i +#include enum { + MESSAGE_TYPE_LOAD_REPLY, MESSAGE_TYPE_NOW_INIT_REQUEST, MESSAGE_TYPE_NOW_INIT_REPLY, MESSAGE_TYPE_NOW_SAVE, @@ -12,8 +14,9 @@ enum { MESSAGE_TYPE_WATCHDOG_SET_REQUEST, MESSAGE_TYPE_WATCHDOG_SET_REPLY, MESSAGE_TYPE_WATCHDOG_CLEAR, - MESSAGE_TYPE_RPC_REQUEST, - MESSAGE_TYPE_RPC_REPLY, + MESSAGE_TYPE_RPC_SEND, + MESSAGE_TYPE_RPC_RECV_REQUEST, + MESSAGE_TYPE_RPC_RECV_REPLY, MESSAGE_TYPE_LOG, MESSAGE_TYPE_BRG_READY, @@ -33,6 +36,17 @@ struct msg_base { /* kernel messages */ +struct msg_load_request { + const void *library; + struct dyld_info *library_info; + int run_kernel; +}; + +struct msg_load_reply { + int type; + const char *error; +}; + struct msg_now_init_reply { int type; long long int now; @@ -45,8 +59,9 @@ struct msg_now_save { struct msg_exception { int type; - int eid; - long long int eparams[3]; + struct artiq_exception *exception; + struct artiq_backtrace_item *backtrace; + size_t backtrace_size; }; struct msg_watchdog_set_request { @@ -64,21 +79,28 @@ struct msg_watchdog_clear { int id; }; -struct msg_rpc_request { +struct msg_rpc_send { int type; - int rpc_num; + int service; + const char *tag; va_list args; }; -struct msg_rpc_reply { +struct msg_rpc_recv_request { int type; - int eid; - int retval; + void *slot; +}; + +struct msg_rpc_recv_reply { + int type; + int alloc_size; + struct artiq_exception *exception; }; struct msg_log { int type; const char *fmt; + int no_newline; va_list args; }; diff --git a/artiq/runtime/net_server.c b/artiq/runtime/net_server.c index 296fded5f..cfdcbc201 100644 --- a/artiq/runtime/net_server.c +++ b/artiq/runtime/net_server.c @@ -91,7 +91,7 @@ static err_t net_server_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err static err_t net_server_sent(void *arg, struct tcp_pcb *pcb, u16_t len) { - session_ack_mem(len); + session_ack_sent(len); return ERR_OK; } @@ -208,7 +208,7 @@ void net_server_service(void) if(len > sndbuf) len = sndbuf; tcp_write(active_pcb, data, len, 0); - session_ack_data(len); + session_ack_consumed(len); } if(len < 0) net_server_close(active_cs, active_pcb); diff --git a/artiq/runtime/rtio.c b/artiq/runtime/rtio.c index 004c71a86..0bf0fde5b 100644 --- a/artiq/runtime/rtio.c +++ b/artiq/runtime/rtio.c @@ -1,6 +1,5 @@ #include -#include "exceptions.h" #include "rtio.h" void rtio_init(void) @@ -22,17 +21,20 @@ void rtio_process_exceptional_status(int status, long long int timestamp, int ch while(rtio_o_status_read() & RTIO_O_STATUS_FULL); if(status & RTIO_O_STATUS_UNDERFLOW) { rtio_o_underflow_reset_write(1); - exception_raise_params(EID_RTIO_UNDERFLOW, + artiq_raise_from_c("RTIOUnderflow", + "RTIO underflow at {0} mu, channel {1}, counter {2}", timestamp, channel, rtio_get_counter()); } if(status & RTIO_O_STATUS_SEQUENCE_ERROR) { rtio_o_sequence_error_reset_write(1); - exception_raise_params(EID_RTIO_SEQUENCE_ERROR, + artiq_raise_from_c("RTIOSequenceError", + "RTIO sequence error at {0} mu, channel {1}", timestamp, channel, 0); } if(status & RTIO_O_STATUS_COLLISION_ERROR) { rtio_o_collision_error_reset_write(1); - exception_raise_params(EID_RTIO_COLLISION_ERROR, + artiq_raise_from_c("RTIOCollisionError", + "RTIO collision error at {0} mu, channel {1}", timestamp, channel, 0); } } diff --git a/artiq/runtime/rtio.h b/artiq/runtime/rtio.h index 566f18ead..702c381c4 100644 --- a/artiq/runtime/rtio.h +++ b/artiq/runtime/rtio.h @@ -2,6 +2,7 @@ #define __RTIO_H #include +#include 
"artiq_personality.h" #define RTIO_O_STATUS_FULL 1 #define RTIO_O_STATUS_UNDERFLOW 2 diff --git a/artiq/runtime/linker.ld b/artiq/runtime/runtime.ld similarity index 71% rename from artiq/runtime/linker.ld rename to artiq/runtime/runtime.ld index 4a4217f1e..dacfe535d 100644 --- a/artiq/runtime/linker.ld +++ b/artiq/runtime/runtime.ld @@ -10,6 +10,13 @@ MEMORY { runtime : ORIGIN = 0x40000000, LENGTH = 0x400000 /* 4M */ } +/* First 4M of main memory are reserved for runtime code/data + * then comes kernel memory. First 32K of kernel memory are for support code. + */ +MEMORY { + kernel : ORIGIN = 0x40400000, LENGTH = 0x8000 +} + /* Kernel memory space start right after the runtime, * and ends before the runtime stack. * Runtime stack is always at the end of main_ram. @@ -17,6 +24,11 @@ MEMORY { */ PROVIDE(_fstack = 0x40000000 + LENGTH(main_ram) - 4); +/* On AMP systems, kernel stack is at the end of main RAM, + * before the runtime stack. Leave 1M for runtime stack. + */ +PROVIDE(_kernel_fstack = 0x40000000 + LENGTH(main_ram) - 1024*1024 - 4); + SECTIONS { .text : @@ -58,6 +70,12 @@ SECTIONS . = ALIGN(4); _ebss = .; . = ALIGN(8); - _heapstart = .; } > runtime + + /DISCARD/ : + { + *(.eh_frame) + } + + _heapstart = .; } diff --git a/artiq/runtime/services.c b/artiq/runtime/services.c deleted file mode 100644 index 39db8537b..000000000 --- a/artiq/runtime/services.c +++ /dev/null @@ -1,78 +0,0 @@ -#include - -#include "elf_loader.h" -#include "session.h" -#include "clock.h" -#include "ttl.h" -#include "dds.h" -#include "exceptions.h" -#include "services.h" - -#include - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wimplicit-int" -extern __divsi3, __modsi3, __ledf2, __gedf2, __unorddf2, __eqdf2, __ltdf2, - __nedf2, __gtdf2, __negsf2, __negdf2, __addsf3, __subsf3, __mulsf3, - __divsf3, __lshrdi3, __muldi3, __divdi3, __ashldi3, __ashrdi3, - __udivmoddi4, __floatsisf, __floatunsisf, __fixsfsi, __fixunssfsi, - __adddf3, __subdf3, __muldf3, __divdf3, __floatsidf, __floatunsidf, - __floatdidf, __fixdfsi, __fixdfdi, __fixunsdfsi, __clzsi2, __ctzsi2, - __udivdi3, __umoddi3, __moddi3; -#pragma GCC diagnostic pop - -static const struct symbol compiler_rt[] = { - {"divsi3", &__divsi3}, - {"modsi3", &__modsi3}, - {"ledf2", &__ledf2}, - {"gedf2", &__gedf2}, - {"unorddf2", &__unorddf2}, - {"eqdf2", &__eqdf2}, - {"ltdf2", &__ltdf2}, - {"nedf2", &__nedf2}, - {"gtdf2", &__gtdf2}, - {"negsf2", &__negsf2}, - {"negdf2", &__negdf2}, - {"addsf3", &__addsf3}, - {"subsf3", &__subsf3}, - {"mulsf3", &__mulsf3}, - {"divsf3", &__divsf3}, - {"lshrdi3", &__lshrdi3}, - {"muldi3", &__muldi3}, - {"divdi3", &__divdi3}, - {"ashldi3", &__ashldi3}, - {"ashrdi3", &__ashrdi3}, - {"udivmoddi4", &__udivmoddi4}, - {"floatsisf", &__floatsisf}, - {"floatunsisf", &__floatunsisf}, - {"fixsfsi", &__fixsfsi}, - {"fixunssfsi", &__fixunssfsi}, - {"adddf3", &__adddf3}, - {"subdf3", &__subdf3}, - {"muldf3", &__muldf3}, - {"divdf3", &__divdf3}, - {"floatsidf", &__floatsidf}, - {"floatunsidf", &__floatunsidf}, - {"floatdidf", &__floatdidf}, - {"fixdfsi", &__fixdfsi}, - {"fixdfdi", &__fixdfdi}, - {"fixunsdfsi", &__fixunsdfsi}, - {"clzsi2", &__clzsi2}, - {"ctzsi2", &__ctzsi2}, - {"udivdi3", &__udivdi3}, - {"umoddi3", &__umoddi3}, - {"moddi3", &__moddi3}, - {NULL, NULL} -}; - -void *resolve_service_symbol(const char *name) -{ - if(strncmp(name, "__", 2) != 0) - return NULL; - name += 2; - if(strncmp(name, "syscall_", 8) == 0) - return find_symbol(syscalls, name + 8); - if(strncmp(name, "eh_", 3) == 0) - return find_symbol(eh, name + 
3); - return find_symbol(compiler_rt, name); -} diff --git a/artiq/runtime/services.h b/artiq/runtime/services.h deleted file mode 100644 index 9c9dcf630..000000000 --- a/artiq/runtime/services.h +++ /dev/null @@ -1,6 +0,0 @@ -#ifndef __SERVICES_H -#define __SERVICES_H - -void *resolve_service_symbol(const char *name); - -#endif /* __SERVICES_H */ diff --git a/artiq/runtime/session.c b/artiq/runtime/session.c index 5fa39f8cd..e074a6f63 100644 --- a/artiq/runtime/session.c +++ b/artiq/runtime/session.c @@ -10,7 +10,7 @@ #include "clock.h" #include "log.h" #include "kloader.h" -#include "exceptions.h" +#include "artiq_personality.h" #include "flash_storage.h" #include "rtiocrg.h" #include "session.h" @@ -18,43 +18,269 @@ #define BUFFER_IN_SIZE (1024*1024) #define BUFFER_OUT_SIZE (1024*1024) -static int buffer_in_index; -/* The 9th byte (right after the header) of buffer_in must be aligned - * to a 32-bit boundary for elf_loader to work. - */ +static int process_input(); +static int out_packet_available(); + +// ============================= Reader interface ============================= + +// Align the 9th byte (right after the header) of buffer_in so that +// the payload can be deserialized directly from the buffer using word reads. static struct { char padding[3]; - char data[BUFFER_IN_SIZE]; -} __attribute__((packed)) _buffer_in __attribute__((aligned(4))); -#define buffer_in _buffer_in.data -static int buffer_out_index_data; -static int buffer_out_index_mem; -static char buffer_out[BUFFER_OUT_SIZE]; + union { + char data[BUFFER_IN_SIZE]; + struct { + int32_t sync; + int32_t length; + int8_t type; + } __attribute__((packed)) header; + }; +} __attribute__((packed, aligned(4))) buffer_in; -static int get_in_packet_len(void) +static int buffer_in_write_cursor, buffer_in_read_cursor; + +static void in_packet_reset() { - int r; - - memcpy(&r, &buffer_in[4], 4); - return r; + buffer_in_write_cursor = 0; + buffer_in_read_cursor = 0; } -static int get_out_packet_len(void) +static int in_packet_fill(uint8_t *data, int length) { - int r; + int consumed = 0; + while(consumed < length) { + /* Make sure the output buffer is available for any reply + * we might need to send. */ + if(!out_packet_available()) + break; - memcpy(&r, &buffer_out[4], 4); - return r; + if(buffer_in_write_cursor < 4) { + /* Haven't received the synchronization sequence yet. */ + buffer_in.data[buffer_in_write_cursor++] = data[consumed]; + + /* Framing error? */ + if(data[consumed++] != 0x5a) { + buffer_in_write_cursor = 0; + continue; + } + } else if(buffer_in_write_cursor < 8) { + /* Haven't received the packet length yet. */ + buffer_in.data[buffer_in_write_cursor++] = data[consumed++]; + } else if(buffer_in.header.length == 0) { + /* Zero-length packet means session reset. */ + return -2; + } else if(buffer_in.header.length > BUFFER_IN_SIZE) { + /* Packet wouldn't fit in the buffer. */ + return -1; + } else if(buffer_in.header.length > buffer_in_write_cursor) { + /* Receiving payload. */ + int remaining = buffer_in.header.length - buffer_in_write_cursor; + int amount = length - consumed > remaining ? remaining : length - consumed; + memcpy(&buffer_in.data[buffer_in_write_cursor], &data[consumed], + amount); + buffer_in_write_cursor += amount; + consumed += amount; + } + + if(buffer_in.header.length == buffer_in_write_cursor) { + /* We have a complete packet. 
*/ + + buffer_in_read_cursor = sizeof(buffer_in.header); + if(!process_input()) + return -1; + + if(buffer_in_read_cursor < buffer_in_write_cursor) { + log("session.c: read underrun (%d bytes remaining)", + buffer_in_write_cursor - buffer_in_read_cursor); + } + + in_packet_reset(); + } + } + + return consumed; } -static void submit_output(int len) +static void in_packet_chunk(void *ptr, int length) { - memset(&buffer_out[0], 0x5a, 4); - memcpy(&buffer_out[4], &len, 4); - buffer_out_index_data = 0; - buffer_out_index_mem = 0; + if(buffer_in_read_cursor + length > buffer_in_write_cursor) { + log("session.c: read overrun while trying to read %d bytes" + " (%d remaining)", + length, buffer_in_write_cursor - buffer_in_read_cursor); + } + + if(ptr != NULL) + memcpy(ptr, &buffer_in.data[buffer_in_read_cursor], length); + buffer_in_read_cursor += length; } +static int8_t in_packet_int8() +{ + int8_t result; + in_packet_chunk(&result, sizeof(result)); + return result; +} + +static int32_t in_packet_int32() +{ + int32_t result; + in_packet_chunk(&result, sizeof(result)); + return result; +} + +static int64_t in_packet_int64() +{ + int64_t result; + in_packet_chunk(&result, sizeof(result)); + return result; +} + +static const void *in_packet_bytes(int *length) +{ + *length = in_packet_int32(); + const void *ptr = &buffer_in.data[buffer_in_read_cursor]; + in_packet_chunk(NULL, *length); + return ptr; +} + +static const char *in_packet_string() +{ + int length; + const char *string = in_packet_bytes(&length); + if(string[length - 1] != 0) { + log("session.c: string is not zero-terminated"); + return ""; + } + return string; +} + +// ============================= Writer interface ============================= + +static union { + char data[BUFFER_OUT_SIZE]; + struct { + int32_t sync; + int32_t length; + int8_t type; + } __attribute__((packed)) header; +} buffer_out; + +static int buffer_out_read_cursor, buffer_out_sent_cursor, buffer_out_write_cursor; + +static void out_packet_reset() +{ + buffer_out_read_cursor = 0; + buffer_out_write_cursor = 0; + buffer_out_sent_cursor = 0; +} + +static int out_packet_available() +{ + return buffer_out_write_cursor == 0; +} + +static void out_packet_extract(void **data, int *length) +{ + if(buffer_out_write_cursor > 0 && + buffer_out.header.length > 0) { + *data = &buffer_out.data[buffer_out_read_cursor]; + *length = buffer_out_write_cursor - buffer_out_read_cursor; + } else { + *length = 0; + } +} + +static void out_packet_advance_consumed(int length) +{ + if(buffer_out_read_cursor + length > buffer_out_write_cursor) { + log("session.c: write underrun (consume) while trying to" + " acknowledge %d bytes (%d remaining)", + length, buffer_out_write_cursor - buffer_out_read_cursor); + return; + } + + buffer_out_read_cursor += length; +} + +static void out_packet_advance_sent(int length) { + if(buffer_out_sent_cursor + length > buffer_out_write_cursor) { + log("session.c: write underrun (send) while trying to" + " acknowledge %d bytes (%d remaining)", + length, buffer_out_write_cursor - buffer_out_sent_cursor); + return; + } + + buffer_out_sent_cursor += length; + if(buffer_out_sent_cursor == buffer_out_write_cursor) + out_packet_reset(); +} + +static int out_packet_chunk(const void *ptr, int length) +{ + if(buffer_out_write_cursor + length > BUFFER_OUT_SIZE) { + log("session.c: write overrun while trying to write %d bytes" + " (%d remaining)", + length, BUFFER_OUT_SIZE - buffer_out_write_cursor); + return 0; + } + + memcpy(&buffer_out.data[buffer_out_write_cursor], ptr, 
length); + buffer_out_write_cursor += length; + return 1; +} + +static void out_packet_start(int type) +{ + buffer_out.header.sync = 0x5a5a5a5a; + buffer_out.header.type = type; + buffer_out.header.length = 0; + buffer_out_write_cursor = sizeof(buffer_out.header); +} + +static void out_packet_finish() +{ + buffer_out.header.length = buffer_out_write_cursor; +} + +static void out_packet_empty(int type) +{ + out_packet_start(type); + out_packet_finish(); +} + +static int out_packet_int8(int8_t value) +{ + return out_packet_chunk(&value, sizeof(value)); +} + +static int out_packet_int32(int32_t value) +{ + return out_packet_chunk(&value, sizeof(value)); +} + +static int out_packet_int64(int64_t value) +{ + return out_packet_chunk(&value, sizeof(value)); +} + +static int out_packet_float64(double value) +{ + return out_packet_chunk(&value, sizeof(value)); +} + +static int out_packet_bytes(const void *ptr, int length) +{ + return out_packet_int32(length) && + out_packet_chunk(ptr, length); +} + +static int out_packet_string(const char *string) +{ + return out_packet_bytes(string, strlen(string) + 1); +} + +// =============================== API handling =============================== + static int user_kernel_state; enum { @@ -105,11 +331,12 @@ void session_startup_kernel(void) void session_start(void) { - buffer_in_index = 0; - memset(&buffer_out[4], 0, 4); + in_packet_reset(); + out_packet_reset(); + kloader_stop(); - user_kernel_state = USER_KERNEL_NONE; now = -1; + user_kernel_state = USER_KERNEL_NONE; } void session_end(void) @@ -123,13 +350,16 @@ void session_end(void) /* host to device */ enum { REMOTEMSG_TYPE_LOG_REQUEST = 1, + REMOTEMSG_TYPE_LOG_CLEAR, + REMOTEMSG_TYPE_IDENT_REQUEST, REMOTEMSG_TYPE_SWITCH_CLOCK, - - REMOTEMSG_TYPE_LOAD_OBJECT, + + REMOTEMSG_TYPE_LOAD_LIBRARY, REMOTEMSG_TYPE_RUN_KERNEL, REMOTEMSG_TYPE_RPC_REPLY, + REMOTEMSG_TYPE_RPC_EXCEPTION, REMOTEMSG_TYPE_FLASH_READ_REQUEST, REMOTEMSG_TYPE_FLASH_WRITE_REQUEST, @@ -140,6 +370,7 @@ enum { /* device to host */ enum { REMOTEMSG_TYPE_LOG_REPLY = 1, + REMOTEMSG_TYPE_IDENT_REPLY, REMOTEMSG_TYPE_CLOCK_SWITCH_COMPLETED, REMOTEMSG_TYPE_CLOCK_SWITCH_FAILED, @@ -158,341 +389,508 @@ enum { REMOTEMSG_TYPE_FLASH_ERROR_REPLY }; -static int check_flash_storage_key_len(char *key, unsigned int key_len) -{ - if(key_len == get_in_packet_len() - 8) { - log("Invalid key: not a null-terminated string"); - buffer_out[8] = REMOTEMSG_TYPE_FLASH_ERROR_REPLY; - submit_output(9); - return 0; - } - return 1; -} +static int receive_rpc_value(const char **tag, void **slot); static int process_input(void) { - switch(buffer_in[8]) { + switch(buffer_in.header.type) { + case REMOTEMSG_TYPE_IDENT_REQUEST: + out_packet_start(REMOTEMSG_TYPE_IDENT_REPLY); + out_packet_chunk("AROR", 4); + out_packet_finish(); + break; + + case REMOTEMSG_TYPE_SWITCH_CLOCK: { + int clk = in_packet_int8(); + + if(user_kernel_state >= USER_KERNEL_RUNNING) { + log("Attempted to switch RTIO clock while kernel running"); + out_packet_empty(REMOTEMSG_TYPE_CLOCK_SWITCH_FAILED); + break; + } + + if(rtiocrg_switch_clock(clk)) + out_packet_empty(REMOTEMSG_TYPE_CLOCK_SWITCH_COMPLETED); + else + out_packet_empty(REMOTEMSG_TYPE_CLOCK_SWITCH_FAILED); + break; + } + case REMOTEMSG_TYPE_LOG_REQUEST: #if (LOG_BUFFER_SIZE + 9) > BUFFER_OUT_SIZE #error Output buffer cannot hold the log buffer #endif - buffer_out[8] = REMOTEMSG_TYPE_LOG_REPLY; - log_get(&buffer_out[9]); - submit_output(9 + LOG_BUFFER_SIZE); + out_packet_start(REMOTEMSG_TYPE_LOG_REPLY); + 
log_get(&buffer_out.data[buffer_out_write_cursor]); + buffer_out_write_cursor += LOG_BUFFER_SIZE; + out_packet_finish(); break; - case REMOTEMSG_TYPE_IDENT_REQUEST: - buffer_out[8] = REMOTEMSG_TYPE_IDENT_REPLY; - buffer_out[9] = 'A'; - buffer_out[10] = 'R'; - buffer_out[11] = 'O'; - buffer_out[12] = 'R'; - submit_output(13); + + case REMOTEMSG_TYPE_LOG_CLEAR: + log_clear(); + out_packet_empty(REMOTEMSG_TYPE_LOG_REPLY); break; - case REMOTEMSG_TYPE_SWITCH_CLOCK: + + case REMOTEMSG_TYPE_FLASH_READ_REQUEST: { +#if SPIFLASH_SECTOR_SIZE - 4 > BUFFER_OUT_SIZE - 9 +#error Output buffer cannot hold the flash storage data +#endif + const char *key = in_packet_string(); + int value_length; + + out_packet_start(REMOTEMSG_TYPE_FLASH_READ_REPLY); + value_length = fs_read(key, &buffer_out.data[buffer_out_write_cursor], + sizeof(buffer_out.data) - buffer_out_write_cursor, NULL); + buffer_out_write_cursor += value_length; + out_packet_finish(); + break; + } + + case REMOTEMSG_TYPE_FLASH_WRITE_REQUEST: { +#if SPIFLASH_SECTOR_SIZE - 4 > BUFFER_IN_SIZE - 9 +#error Input buffer cannot hold the flash storage data +#endif + const char *key, *value; + int value_length; + key = in_packet_string(); + value = in_packet_bytes(&value_length); + + if(fs_write(key, value, value_length)) + out_packet_empty(REMOTEMSG_TYPE_FLASH_OK_REPLY); + else + out_packet_empty(REMOTEMSG_TYPE_FLASH_ERROR_REPLY); + break; + } + + case REMOTEMSG_TYPE_FLASH_ERASE_REQUEST: + fs_erase(); + out_packet_empty(REMOTEMSG_TYPE_FLASH_OK_REPLY); + break; + + case REMOTEMSG_TYPE_FLASH_REMOVE_REQUEST: { + const char *key = in_packet_string(); + + fs_remove(key); + out_packet_empty(REMOTEMSG_TYPE_FLASH_OK_REPLY); + break; + } + + case REMOTEMSG_TYPE_LOAD_LIBRARY: { + const void *kernel = &buffer_in.data[buffer_in_read_cursor]; + buffer_in_read_cursor = buffer_in_write_cursor; + if(user_kernel_state >= USER_KERNEL_RUNNING) { - log("Attempted to switch RTIO clock while kernel running"); - buffer_out[8] = REMOTEMSG_TYPE_CLOCK_SWITCH_FAILED; - submit_output(9); + log("Attempted to load new kernel library while already running"); + out_packet_empty(REMOTEMSG_TYPE_LOAD_FAILED); break; } - if(rtiocrg_switch_clock(buffer_in[9])) - buffer_out[8] = REMOTEMSG_TYPE_CLOCK_SWITCH_COMPLETED; - else - buffer_out[8] = REMOTEMSG_TYPE_CLOCK_SWITCH_FAILED; - submit_output(9); - break; - case REMOTEMSG_TYPE_LOAD_OBJECT: - if(user_kernel_state >= USER_KERNEL_RUNNING) { - log("Attempted to load new kernel while already running"); - buffer_out[8] = REMOTEMSG_TYPE_LOAD_FAILED; - submit_output(9); - break; - } - if(kloader_load(&buffer_in[9], get_in_packet_len() - 8)) { - buffer_out[8] = REMOTEMSG_TYPE_LOAD_COMPLETED; - user_kernel_state = USER_KERNEL_LOADED; - } else - buffer_out[8] = REMOTEMSG_TYPE_LOAD_FAILED; - submit_output(9); - break; - case REMOTEMSG_TYPE_RUN_KERNEL: { - kernel_function k; + if(kloader_load_library(kernel)) { + out_packet_empty(REMOTEMSG_TYPE_LOAD_COMPLETED); + user_kernel_state = USER_KERNEL_LOADED; + } else { + out_packet_empty(REMOTEMSG_TYPE_LOAD_FAILED); + } + break; + } + + case REMOTEMSG_TYPE_RUN_KERNEL: if(user_kernel_state != USER_KERNEL_LOADED) { log("Attempted to run kernel while not in the LOADED state"); - buffer_out[8] = REMOTEMSG_TYPE_KERNEL_STARTUP_FAILED; - submit_output(9); - break; - } - - if((buffer_in_index + 1) > BUFFER_OUT_SIZE) { - log("Kernel name too long"); - buffer_out[8] = REMOTEMSG_TYPE_KERNEL_STARTUP_FAILED; - submit_output(9); - break; - } - buffer_in[buffer_in_index] = 0; - - k = kloader_find((char *)&buffer_in[9]); - 
if(k == NULL) { - log("Failed to find kernel entry point '%s' in object", &buffer_in[9]); - buffer_out[8] = REMOTEMSG_TYPE_KERNEL_STARTUP_FAILED; - submit_output(9); + out_packet_empty(REMOTEMSG_TYPE_KERNEL_STARTUP_FAILED); break; } watchdog_init(); - kloader_start_user_kernel(k); + kloader_start_kernel(); + user_kernel_state = USER_KERNEL_RUNNING; break; - } + case REMOTEMSG_TYPE_RPC_REPLY: { - struct msg_rpc_reply reply; + struct msg_rpc_recv_request *request; + struct msg_rpc_recv_reply reply; if(user_kernel_state != USER_KERNEL_WAIT_RPC) { log("Unsolicited RPC reply"); - return 0; + return 0; // restart session } - reply.type = MESSAGE_TYPE_RPC_REPLY; - memcpy(&reply.eid, &buffer_in[9], 4); - memcpy(&reply.retval, &buffer_in[13], 4); + request = mailbox_wait_and_receive(); + if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) { + log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d", + request->type); + return 0; // restart session + } + + const char *tag = in_packet_string(); + void *slot = request->slot; + if(!receive_rpc_value(&tag, &slot)) { + log("Failed to receive RPC reply"); + return 0; // restart session + } + + reply.type = MESSAGE_TYPE_RPC_RECV_REPLY; + reply.alloc_size = 0; + reply.exception = NULL; mailbox_send_and_wait(&reply); + user_kernel_state = USER_KERNEL_RUNNING; break; } - case REMOTEMSG_TYPE_FLASH_READ_REQUEST: { -#if SPIFLASH_SECTOR_SIZE - 4 > BUFFER_OUT_SIZE - 9 -#error Output buffer cannot hold the flash storage data -#elif SPIFLASH_SECTOR_SIZE - 4 > BUFFER_IN_SIZE - 9 -#error Input buffer cannot hold the flash storage data -#endif - unsigned int ret, in_packet_len; - char *key; - in_packet_len = get_in_packet_len(); - key = &buffer_in[9]; - buffer_in[in_packet_len] = '\0'; + case REMOTEMSG_TYPE_RPC_EXCEPTION: { + struct msg_rpc_recv_request *request; + struct msg_rpc_recv_reply reply; - buffer_out[8] = REMOTEMSG_TYPE_FLASH_READ_REPLY; - ret = fs_read(key, &buffer_out[9], sizeof(buffer_out) - 9, NULL); - submit_output(9 + ret); + struct artiq_exception exception; + exception.name = in_packet_string(); + exception.message = in_packet_string(); + exception.param[0] = in_packet_int64(); + exception.param[1] = in_packet_int64(); + exception.param[2] = in_packet_int64(); + exception.file = in_packet_string(); + exception.line = in_packet_int32(); + exception.column = in_packet_int32(); + exception.function = in_packet_string(); + + if(user_kernel_state != USER_KERNEL_WAIT_RPC) { + log("Unsolicited RPC exception reply"); + return 0; // restart session + } + + request = mailbox_wait_and_receive(); + if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) { + log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d", + request->type); + return 0; // restart session + } + + reply.type = MESSAGE_TYPE_RPC_RECV_REPLY; + reply.alloc_size = 0; + reply.exception = &exception; + mailbox_send_and_wait(&reply); + + user_kernel_state = USER_KERNEL_RUNNING; break; } - case REMOTEMSG_TYPE_FLASH_WRITE_REQUEST: { - char *key, *value; - unsigned int key_len, value_len, in_packet_len; - int ret; - in_packet_len = get_in_packet_len(); - key = &buffer_in[9]; - key_len = strnlen(key, in_packet_len - 9) + 1; - if(!check_flash_storage_key_len(key, key_len)) - break; - - value_len = in_packet_len - key_len - 9; - value = key + key_len; - ret = fs_write(key, value, value_len); - - if(ret) - buffer_out[8] = REMOTEMSG_TYPE_FLASH_OK_REPLY; - else - buffer_out[8] = REMOTEMSG_TYPE_FLASH_ERROR_REPLY; - submit_output(9); - break; - } - case REMOTEMSG_TYPE_FLASH_ERASE_REQUEST: { - fs_erase(); - buffer_out[8] = 
REMOTEMSG_TYPE_FLASH_OK_REPLY; - submit_output(9); - break; - } - case REMOTEMSG_TYPE_FLASH_REMOVE_REQUEST: { - char *key; - unsigned int in_packet_len; - - in_packet_len = get_in_packet_len(); - key = &buffer_in[9]; - buffer_in[in_packet_len] = '\0'; - - fs_remove(key); - buffer_out[8] = REMOTEMSG_TYPE_FLASH_OK_REPLY; - submit_output(9); - break; - } default: + log("Received invalid packet type %d from host", + buffer_in.header.type); return 0; } + return 1; } -/* Returns -1 in case of irrecoverable error - * (the session must be dropped and session_end called) - */ -int session_input(void *data, int len) -{ - unsigned char *_data = data; - int consumed; - - consumed = 0; - while(len > 0) { - /* Make sure the output buffer is available for any reply - * we might need to send. */ - if(get_out_packet_len() != 0) - return consumed; - - if(buffer_in_index < 4) { - /* synchronizing */ - if(_data[consumed] == 0x5a) - buffer_in[buffer_in_index++] = 0x5a; - else - buffer_in_index = 0; - consumed++; len--; - } else if(buffer_in_index < 8) { - /* receiving length */ - buffer_in[buffer_in_index++] = _data[consumed]; - consumed++; len--; - if((buffer_in_index == 8) && (get_in_packet_len() == 0)) - /* zero-length packet = session reset */ - return -2; - } else { - /* receiving payload */ - int packet_len; - int count; - - packet_len = get_in_packet_len(); - if(packet_len > BUFFER_IN_SIZE) - return -1; - count = packet_len - buffer_in_index; - if(count > len) - count = len; - memcpy(&buffer_in[buffer_in_index], &_data[consumed], count); - buffer_in_index += count; - - if(buffer_in_index == packet_len) { - if(!process_input()) - return -1; - buffer_in_index = 0; - } - - consumed += count; len -= count; +// See comm_generic.py:_{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag. 
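The tag grammar consumed by skip_rpc_value(), sizeof_rpc_value(), receive_rpc_value() and send_rpc_value() below is only referenced here, not spelled out. The following stand-alone sketch is illustrative only and not part of this change: describe_tag() and the example tag string "ilf:n" are invented for demonstration, and the real tag strings are emitted by llvm_ir_generator.py:_rpc_tag on the compiler side. It walks a tag string the same way the helpers below do, so it shows the "<argument tags>:<return tag>" layout used by send_rpc_request().

/* Illustrative sketch only -- not part of this patch. */
#include <stdio.h>

/* Walk one value tag and print its meaning; returns the tag pointer
 * advanced past the value, mirroring skip_rpc_value() below. */
static const char *describe_tag(const char *tag)
{
    switch(*tag++) {
    case 'n': printf("None");     break;
    case 'b': printf("bool");     break;
    case 'i': printf("int32");    break;
    case 'I': printf("int64");    break;
    case 'f': printf("float64");  break;
    case 'F': printf("Fraction"); break;
    case 's': printf("string");   break;
    case 'l': printf("list(");   tag = describe_tag(tag); printf(")"); break;
    case 'r': printf("range(");  tag = describe_tag(tag); printf(")"); break;
    case 'o': printf("option("); tag = describe_tag(tag); printf(")"); break;
    case 'O': printf("host object"); break;
    case 't': {
        /* tuple: next byte is the element count, then one tag per element */
        int size = *tag++;
        printf("tuple(");
        for(int i = 0; i < size; i++) {
            if(i) printf(", ");
            tag = describe_tag(tag);
        }
        printf(")");
        break;
    }
    default:  printf("?"); break;
    }
    return tag;
}

int main(void)
{
    /* hypothetical tag for an RPC taking (int32, list of float) and
     * returning None, following the "<arg tags>:<return tag>" layout
     * used by send_rpc_request() below */
    const char *tag = "ilf:n";

    printf("args: ");
    while(*tag != ':') {
        tag = describe_tag(tag);
        if(*tag != ':') printf(", ");
    }
    printf("\nreturns: ");
    describe_tag(tag + 1);
    printf("\n");
    return 0;
}

Compiled on the host, this prints "args: int32, list(float64)" and "returns: None"; the same grammar is what receive_rpc_value() interprets when deserializing an RPC reply into kernel memory.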
+static void skip_rpc_value(const char **tag) { + switch(*(*tag)++) { + case 't': { + int size = *(*tag)++; + for(int i = 0; i < size; i++) + skip_rpc_value(tag); + break; } + + case 'l': + skip_rpc_value(tag); + break; + + case 'r': + skip_rpc_value(tag); + break; } - return consumed; } -static int add_base_rpc_value(char base_type, void *value, char *buffer_out, int available_space) +static int sizeof_rpc_value(const char **tag) { - switch(base_type) { - case 'n': + switch(*(*tag)++) { + case 't': { // tuple + int size = *(*tag)++; + + int32_t length = 0; + for(int i = 0; i < size; i++) + length += sizeof_rpc_value(tag); + return length; + } + + case 'n': // None return 0; - case 'b': - if(available_space < 1) - return -1; - if(*(char *)value) - buffer_out[0] = 1; - else - buffer_out[0] = 0; - return 1; - case 'i': - if(available_space < 4) - return -1; - memcpy(buffer_out, value, 4); - return 4; - case 'I': - case 'f': - if(available_space < 8) - return -1; - memcpy(buffer_out, value, 8); - return 8; - case 'F': - if(available_space < 16) - return -1; - memcpy(buffer_out, value, 16); - return 16; + + case 'b': // bool + return sizeof(int8_t); + + case 'i': // int(width=32) + return sizeof(int32_t); + + case 'I': // int(width=64) + return sizeof(int64_t); + + case 'f': // float + return sizeof(double); + + case 'F': // Fraction + return sizeof(struct { int64_t numerator, denominator; }); + + case 's': // string + return sizeof(char *); + + case 'l': // list(elt='a) + skip_rpc_value(tag); + return sizeof(struct { int32_t length; struct {} *elements; }); + + case 'r': // range(elt='a) + return sizeof_rpc_value(tag) * 3; + default: - return -1; - } -} - -static int add_rpc_value(int bi, int type_tag, void *value) -{ - char base_type; - int obi, r; - - obi = bi; - base_type = type_tag; - - if((bi + 1) > BUFFER_OUT_SIZE) - return -1; - buffer_out[bi++] = base_type; - - if(base_type == 'l') { - char elt_type; - int len; - int i, p; - - elt_type = type_tag >> 8; - if((bi + 1) > BUFFER_OUT_SIZE) - return -1; - buffer_out[bi++] = elt_type; - - len = *(int *)value; - if((bi + 4) > BUFFER_OUT_SIZE) - return -1; - memcpy(&buffer_out[bi], &len, 4); - bi += 4; - - p = 4; - for(i=0;i BUFFER_OUT_SIZE) - return 0; - buffer_out[bi++] = 0; +} - submit_output(bi); +static void *alloc_rpc_value(int size) +{ + struct msg_rpc_recv_request *request; + struct msg_rpc_recv_reply reply; + + reply.type = MESSAGE_TYPE_RPC_RECV_REPLY; + reply.alloc_size = size; + reply.exception = NULL; + mailbox_send_and_wait(&reply); + + request = mailbox_wait_and_receive(); + if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) { + log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d", + request->type); + return NULL; + } + return request->slot; +} + +static int receive_rpc_value(const char **tag, void **slot) +{ + switch(*(*tag)++) { + case 't': { // tuple + int size = *(*tag)++; + + for(int i = 0; i < size; i++) { + if(!receive_rpc_value(tag, slot)) + return 0; + } + break; + } + + case 'n': // None + break; + + case 'b': { // bool + *((*(int8_t**)slot)++) = in_packet_int8(); + break; + } + + case 'i': { // int(width=32) + *((*(int32_t**)slot)++) = in_packet_int32(); + break; + } + + case 'I': { // int(width=64) + *((*(int64_t**)slot)++) = in_packet_int64(); + break; + } + + case 'f': { // float + *((*(int64_t**)slot)++) = in_packet_int64(); + break; + } + + case 'F': { // Fraction + struct { int64_t numerator, denominator; } *fraction = *slot; + fraction->numerator = in_packet_int64(); + fraction->denominator = in_packet_int64(); 
+ *slot = (void*)((intptr_t)(*slot) + sizeof(*fraction)); + break; + } + + case 's': { // string + const char *in_string = in_packet_string(); + char *out_string = alloc_rpc_value(strlen(in_string) + 1); + memcpy(out_string, in_string, strlen(in_string) + 1); + *((*(char***)slot)++) = out_string; + break; + } + + case 'l': { // list(elt='a) + struct { int32_t length; struct {} *elements; } *list = *slot; + list->length = in_packet_int32(); + + const char *tag_copy = *tag; + list->elements = alloc_rpc_value(sizeof_rpc_value(&tag_copy) * list->length); + + void *element = list->elements; + for(int i = 0; i < list->length; i++) { + const char *tag_copy = *tag; + if(!receive_rpc_value(&tag_copy, &element)) + return 0; + } + skip_rpc_value(tag); + break; + } + + case 'r': { // range(elt='a) + const char *tag_copy; + tag_copy = *tag; + if(!receive_rpc_value(&tag_copy, slot)) // min + return 0; + tag_copy = *tag; + if(!receive_rpc_value(&tag_copy, slot)) // max + return 0; + tag_copy = *tag; + if(!receive_rpc_value(&tag_copy, slot)) // step + return 0; + *tag = tag_copy; + break; + } + + default: + log("receive_rpc_value: unknown tag %02x", *((*tag) - 1)); + return 0; + } + + return 1; +} + +static int send_rpc_value(const char **tag, void **value) +{ + if(!out_packet_int8(**tag)) + return 0; + + switch(*(*tag)++) { + case 't': { // tuple + int size = *(*tag)++; + if(!out_packet_int8(size)) + return 0; + + for(int i = 0; i < size; i++) { + if(!send_rpc_value(tag, value)) + return 0; + } + break; + } + + case 'n': // None + break; + + case 'b': { // bool + return out_packet_int8(*((*(int8_t**)value)++)); + } + + case 'i': { // int(width=32) + return out_packet_int32(*((*(int32_t**)value)++)); + } + + case 'I': { // int(width=64) + return out_packet_int64(*((*(int64_t**)value)++)); + } + + case 'f': { // float + return out_packet_float64(*((*(double**)value)++)); + } + + case 'F': { // Fraction + struct { int64_t numerator, denominator; } *fraction = *value; + if(!out_packet_int64(fraction->numerator)) + return 0; + if(!out_packet_int64(fraction->denominator)) + return 0; + *value = (void*)((intptr_t)(*value) + sizeof(*fraction)); + break; + } + + case 's': { // string + return out_packet_string(*((*(const char***)value)++)); + } + + case 'l': { // list(elt='a) + struct { uint32_t length; struct {} *elements; } *list = *value; + void *element = list->elements; + + if(!out_packet_int32(list->length)) + return 0; + + for(int i = 0; i < list->length; i++) { + const char *tag_copy = *tag; + if(!send_rpc_value(&tag_copy, &element)) + return 0; + } + skip_rpc_value(tag); + + *value = (void*)((intptr_t)(*value) + sizeof(*list)); + break; + } + + case 'r': { // range(elt='a) + const char *tag_copy; + tag_copy = *tag; + if(!send_rpc_value(&tag_copy, value)) // min + return 0; + tag_copy = *tag; + if(!send_rpc_value(&tag_copy, value)) // max + return 0; + tag_copy = *tag; + if(!send_rpc_value(&tag_copy, value)) // step + return 0; + *tag = tag_copy; + break; + } + + case 'o': { // option(inner='a) + struct { int8_t present; struct {} contents; } *option = *value; + void *contents = &option->contents; + + if(!out_packet_int8(option->present)) + return 0; + + // option never appears in composite types, so we don't have + // to accurately advance *value. 
+ if(option->present) { + return send_rpc_value(tag, &contents); + } else { + skip_rpc_value(tag); + break; + } + } + + case 'O': { // host object + struct { uint32_t id; } **object = *value; + return out_packet_int32((*object)->id); + } + + default: + log("send_rpc_value: unknown tag %02x", *((*tag) - 1)); + return 0; + } + + return 1; +} + +static int send_rpc_request(int service, const char *tag, va_list args) +{ + out_packet_start(REMOTEMSG_TYPE_RPC_REQUEST); + out_packet_int32(service); + + while(*tag != ':') { + void *value = va_arg(args, void*); + if(!kloader_validate_kpointer(value)) + return 0; + if(!send_rpc_value(&tag, &value)) + return 0; + } + out_packet_int8(0); + + out_packet_string(tag + 1); // return tags + out_packet_finish(); return 1; } @@ -503,6 +901,12 @@ static int process_kmsg(struct msg_base *umsg) return 0; if(kloader_is_essential_kmsg(umsg->type)) return 1; /* handled elsewhere */ + if(user_kernel_state == USER_KERNEL_WAIT_RPC && + umsg->type == MESSAGE_TYPE_RPC_RECV_REQUEST) { + // Handled and acknowledged when we receive + // REMOTEMSG_TYPE_RPC_{EXCEPTION,REPLY}. + return 1; + } if(user_kernel_state != USER_KERNEL_RUNNING) { log("Received unexpected message from kernel CPU while not in running state"); return 0; @@ -510,96 +914,120 @@ static int process_kmsg(struct msg_base *umsg) switch(umsg->type) { case MESSAGE_TYPE_FINISHED: - buffer_out[8] = REMOTEMSG_TYPE_KERNEL_FINISHED; - submit_output(9); + out_packet_empty(REMOTEMSG_TYPE_KERNEL_FINISHED); kloader_stop(); user_kernel_state = USER_KERNEL_LOADED; mailbox_acknowledge(); break; + case MESSAGE_TYPE_EXCEPTION: { struct msg_exception *msg = (struct msg_exception *)umsg; - buffer_out[8] = REMOTEMSG_TYPE_KERNEL_EXCEPTION; - memcpy(&buffer_out[9], &msg->eid, 4); - memcpy(&buffer_out[13], msg->eparams, 3*8); - submit_output(9+4+3*8); + out_packet_start(REMOTEMSG_TYPE_KERNEL_EXCEPTION); + + out_packet_string(msg->exception->name); + out_packet_string(msg->exception->message); + out_packet_int64(msg->exception->param[0]); + out_packet_int64(msg->exception->param[1]); + out_packet_int64(msg->exception->param[2]); + + out_packet_string(msg->exception->file); + out_packet_int32(msg->exception->line); + out_packet_int32(msg->exception->column); + out_packet_string(msg->exception->function); + + kloader_filter_backtrace(msg->backtrace, + &msg->backtrace_size); + + out_packet_int32(msg->backtrace_size); + for(int i = 0; i < msg->backtrace_size; i++) { + struct artiq_backtrace_item *item = &msg->backtrace[i]; + out_packet_int32(item->function + item->offset); + } + + out_packet_finish(); kloader_stop(); user_kernel_state = USER_KERNEL_LOADED; mailbox_acknowledge(); break; } - case MESSAGE_TYPE_RPC_REQUEST: { - struct msg_rpc_request *msg = (struct msg_rpc_request *)umsg; - if(!send_rpc_request(msg->rpc_num, msg->args)) - return 0; + case MESSAGE_TYPE_RPC_SEND: { + struct msg_rpc_send *msg = (struct msg_rpc_send *)umsg; + + if(!send_rpc_request(msg->service, msg->tag, msg->args)) { + log("Failed to send RPC request (service %d, tag %s)", + msg->service, msg->tag); + return 0; // restart session + } + user_kernel_state = USER_KERNEL_WAIT_RPC; mailbox_acknowledge(); break; } + default: { - log("ERROR: received invalid message type from kernel CPU"); + log("Received invalid message type %d from kernel CPU", + umsg->type); return 0; } } + return 1; } -/* len is set to -1 in case of irrecoverable error +/* Returns amount of bytes consumed on success. 
+ * Returns -1 in case of irrecoverable error + * (the session must be dropped and session_end called). + * Returns -2 if the host has requested session reset. + */ +int session_input(void *data, int length) +{ + return in_packet_fill((uint8_t*)data, length); +} + +/* *length is set to -1 in case of irrecoverable error * (the session must be dropped and session_end called) */ -void session_poll(void **data, int *len) +void session_poll(void **data, int *length) { - int l; - if(user_kernel_state == USER_KERNEL_RUNNING) { if(watchdog_expired()) { log("Watchdog expired"); - *len = -1; + *length = -1; return; } if(!rtiocrg_check()) { log("RTIO clock failure"); - *len = -1; + *length = -1; return; } } - l = get_out_packet_len(); - - /* If the output buffer is available, + /* If the output buffer is available, * check if the kernel CPU has something to transmit. */ - if(l == 0) { - struct msg_base *umsg; - - umsg = mailbox_receive(); + if(out_packet_available()) { + struct msg_base *umsg = mailbox_receive(); if(umsg) { if(!process_kmsg(umsg)) { - *len = -1; + *length = -1; return; } } - l = get_out_packet_len(); } - if(l > 0) { - *len = l - buffer_out_index_data; - *data = &buffer_out[buffer_out_index_data]; - } else - *len = 0; + out_packet_extract(data, length); } -void session_ack_data(int len) +void session_ack_consumed(int length) { - buffer_out_index_data += len; + out_packet_advance_consumed(length); } -void session_ack_mem(int len) +void session_ack_sent(int length) { - buffer_out_index_mem += len; - if(buffer_out_index_mem >= get_out_packet_len()) - memset(&buffer_out[4], 0, 4); + out_packet_advance_sent(length); } diff --git a/artiq/runtime/session.h b/artiq/runtime/session.h index 5728103ae..4184b1758 100644 --- a/artiq/runtime/session.h +++ b/artiq/runtime/session.h @@ -5,11 +5,9 @@ void session_startup_kernel(void); void session_start(void); void session_end(void); -int session_input(void *data, int len); -void session_poll(void **data, int *len); -void session_ack_data(int len); -void session_ack_mem(int len); - -int rpc(int rpc_num, ...); +int session_input(void *data, int length); +void session_poll(void **data, int *length); +void session_ack_consumed(int length); +void session_ack_sent(int length); #endif /* __SESSION_H */ diff --git a/artiq/runtime/ttl.c b/artiq/runtime/ttl.c index 387b977b1..577ab1eeb 100644 --- a/artiq/runtime/ttl.c +++ b/artiq/runtime/ttl.c @@ -1,6 +1,6 @@ #include -#include "exceptions.h" +#include "artiq_personality.h" #include "rtio.h" #include "ttl.h" @@ -40,7 +40,8 @@ long long int ttl_get(int channel, long long int time_limit) while((status = rtio_i_status_read())) { if(rtio_i_status_read() & RTIO_I_STATUS_OVERFLOW) { rtio_i_overflow_reset_write(1); - exception_raise_params(EID_RTIO_OVERFLOW, + artiq_raise_from_c("RTIOOverflow", + "RTIO overflow at channel {0}", channel, 0, 0); } if(rtio_get_counter() >= time_limit) { diff --git a/artiq/test/compiler/domination.py b/artiq/test/compiler/domination.py new file mode 100644 index 000000000..0c4334617 --- /dev/null +++ b/artiq/test/compiler/domination.py @@ -0,0 +1,166 @@ +import unittest +from artiq.compiler.analyses.domination import DominatorTree, PostDominatorTree + +class MockBasicBlock: + def __init__(self, name): + self.name = name + self._successors = [] + self._predecessors = [] + + def successors(self): + return self._successors + + def predecessors(self): + return self._predecessors + + def set_successors(self, successors): + self._successors = list(successors) + for block in self._successors: + 
block._predecessors.append(self) + +class MockFunction: + def __init__(self, entry, basic_blocks): + self._entry = entry + self.basic_blocks = basic_blocks + + def entry(self): + return self._entry + +def makefn(entry_name, graph): + blocks = {} + for block_name in graph: + blocks[block_name] = MockBasicBlock(block_name) + for block_name in graph: + successors = list(map(lambda name: blocks[name], graph[block_name])) + blocks[block_name].set_successors(successors) + return MockFunction(blocks[entry_name], blocks.values()) + +def dom(function, domtree): + dom = {} + for block in function.basic_blocks: + dom[block.name] = [dom_block.name for dom_block in domtree.dominators(block)] + return dom + +def idom(function, domtree): + idom = {} + for block in function.basic_blocks: + idom_block = domtree.immediate_dominator(block) + idom[block.name] = idom_block.name if idom_block else None + return idom + +class TestDominatorTree(unittest.TestCase): + def test_linear(self): + func = makefn('A', { + 'A': ['B'], + 'B': ['C'], + 'C': [] + }) + domtree = DominatorTree(func) + self.assertEqual({ + 'C': 'B', 'B': 'A', 'A': 'A' + }, idom(func, domtree)) + self.assertEqual({ + 'C': ['C', 'B', 'A'], 'B': ['B', 'A'], 'A': ['A'] + }, dom(func, domtree)) + + def test_diamond(self): + func = makefn('A', { + 'A': ['C', 'B'], + 'B': ['D'], + 'C': ['D'], + 'D': [] + }) + domtree = DominatorTree(func) + self.assertEqual({ + 'D': 'A', 'C': 'A', 'B': 'A', 'A': 'A' + }, idom(func, domtree)) + + def test_combined(self): + func = makefn('A', { + 'A': ['B', 'D'], + 'B': ['C'], + 'C': ['E'], + 'D': ['E'], + 'E': [] + }) + domtree = DominatorTree(func) + self.assertEqual({ + 'A': 'A', 'B': 'A', 'C': 'B', 'D': 'A', 'E': 'A' + }, idom(func, domtree)) + + def test_figure_2(self): + func = makefn(5, { + 5: [3, 4], + 4: [1], + 3: [2], + 2: [1], + 1: [2] + }) + domtree = DominatorTree(func) + self.assertEqual({ + 1: 5, 2: 5, 3: 5, 4: 5, 5: 5 + }, idom(func, domtree)) + + def test_figure_4(self): + func = makefn(6, { + 6: [4, 5], + 5: [1], + 4: [3, 2], + 3: [2], + 2: [1, 3], + 1: [2] + }) + domtree = DominatorTree(func) + self.assertEqual({ + 1: 6, 2: 6, 3: 6, 4: 6, 5: 6, 6: 6 + }, idom(func, domtree)) + +class TestPostDominatorTree(unittest.TestCase): + def test_linear(self): + func = makefn('A', { + 'A': ['B'], + 'B': ['C'], + 'C': [] + }) + domtree = PostDominatorTree(func) + self.assertEqual({ + 'A': 'B', 'B': 'C', 'C': None + }, idom(func, domtree)) + + def test_diamond(self): + func = makefn('A', { + 'A': ['B', 'D'], + 'B': ['C'], + 'C': ['E'], + 'D': ['E'], + 'E': [] + }) + domtree = PostDominatorTree(func) + self.assertEqual({ + 'E': None, 'D': 'E', 'C': 'E', 'B': 'C', 'A': 'E' + }, idom(func, domtree)) + + def test_multi_exit(self): + func = makefn('A', { + 'A': ['B', 'C'], + 'B': [], + 'C': [] + }) + domtree = PostDominatorTree(func) + self.assertEqual({ + 'A': None, 'B': None, 'C': None + }, idom(func, domtree)) + + def test_multi_exit_diamond(self): + func = makefn('A', { + 'A': ['B', 'C'], + 'B': ['D'], + 'C': ['D'], + 'D': ['E', 'F'], + 'E': [], + 'F': [] + }) + domtree = PostDominatorTree(func) + self.assertEqual({ + 'A': 'D', 'B': 'D', 'C': 'D', 'D': None, 'E': None, 'F': None + }, idom(func, domtree)) diff --git a/artiq/test/coredevice/embedding.py b/artiq/test/coredevice/embedding.py new file mode 100644 index 000000000..111f40485 --- /dev/null +++ b/artiq/test/coredevice/embedding.py @@ -0,0 +1,41 @@ +from artiq.language import * +from artiq.test.hardware_testbench import ExperimentCase + +class 
Roundtrip(EnvExperiment): + def build(self): + self.attr_device("core") + + @kernel + def roundtrip(self, obj, fn): + fn(obj) + +class RoundtripTest(ExperimentCase): + def assertRoundtrip(self, obj): + exp = self.create(Roundtrip) + def callback(objcopy): + self.assertEqual(obj, objcopy) + exp.roundtrip(obj, callback) + + def test_None(self): + self.assertRoundtrip(None) + + def test_bool(self): + self.assertRoundtrip(True) + self.assertRoundtrip(False) + + def test_int(self): + self.assertRoundtrip(42) + self.assertRoundtrip(int(42, width=64)) + + def test_float(self): + self.assertRoundtrip(42.0) + + def test_str(self): + self.assertRoundtrip("foo") + + def test_list(self): + self.assertRoundtrip([10]) + + def test_object(self): + obj = object() + self.assertRoundtrip(obj) diff --git a/artiq/test/coredevice_vs_host.py b/artiq/test/coredevice/portability.py similarity index 100% rename from artiq/test/coredevice_vs_host.py rename to artiq/test/coredevice/portability.py diff --git a/artiq/test/coredevice.py b/artiq/test/coredevice/rtio.py similarity index 91% rename from artiq/test/coredevice.py rename to artiq/test/coredevice/rtio.py index aec923bd3..ad487aade 100644 --- a/artiq/test/coredevice.py +++ b/artiq/test/coredevice/rtio.py @@ -5,8 +5,7 @@ from math import sqrt from artiq.language import * from artiq.test.hardware_testbench import ExperimentCase -from artiq.coredevice.runtime_exceptions import RTIOUnderflow -from artiq.coredevice import runtime_exceptions +from artiq.coredevice.exceptions import RTIOUnderflow, RTIOSequenceError class RTT(EnvExperiment): @@ -29,7 +28,7 @@ class RTT(EnvExperiment): delay(1*us) t0 = now_mu() self.ttl_inout.pulse(1*us) - self.set_rtt(mu_to_seconds(self.ttl_inout.timestamp_mu() - t0)) + self.set_result("rtt", mu_to_seconds(self.ttl_inout.timestamp_mu() - t0)) class Loopback(EnvExperiment): @@ -51,7 +50,7 @@ class Loopback(EnvExperiment): delay(1*us) t0 = now_mu() self.loop_out.pulse(1*us) - self.set_rtt(mu_to_seconds(self.loop_in.timestamp_mu() - t0)) + self.set_result("rtt", mu_to_seconds(self.loop_in.timestamp_mu() - t0)) class ClockGeneratorLoopback(EnvExperiment): @@ -73,7 +72,7 @@ class ClockGeneratorLoopback(EnvExperiment): with sequential: delay(200*ns) self.loop_clock_out.set(1*MHz) - self.set_count(self.loop_clock_in.count()) + self.set_result("count", self.loop_clock_in.count()) class PulseRate(EnvExperiment): @@ -96,7 +95,7 @@ class PulseRate(EnvExperiment): dt += 1 self.core.break_realtime() else: - self.set_pulse_rate(mu_to_seconds(2*dt)) + self.set_result("pulse_rate", mu_to_seconds(2*dt)) break @@ -130,7 +129,7 @@ class LoopbackCount(EnvExperiment): for i in range(self.npulses): delay(25*ns) self.ttl_inout.pulse(25*ns) - self.set_count(self.ttl_inout.count()) + self.set_result("count", self.ttl_inout.count()) class Underflow(EnvExperiment): @@ -180,7 +179,7 @@ class TimeKeepsRunning(EnvExperiment): @kernel def run(self): - self.set_time_at_start(now_mu()) + self.set_result("time_at_start", now_mu()) class Handover(EnvExperiment): @@ -188,8 +187,8 @@ class Handover(EnvExperiment): self.setattr_device("core") @kernel - def get_now(self): - self.time_at_start = now_mu() + def get_now(self, var): + self.set_result(var, now_mu()) def run(self): self.get_now() @@ -232,11 +231,11 @@ class CoredeviceTest(ExperimentCase): self.assertEqual(count, npulses) def test_underflow(self): - with self.assertRaises(runtime_exceptions.RTIOUnderflow): + with self.assertRaises(RTIOUnderflow): self.execute(Underflow) def test_sequence_error(self): - with 
self.assertRaises(runtime_exceptions.RTIOSequenceError): + with self.assertRaises(RTIOSequenceError): self.execute(SequenceError) def test_collision_error(self): @@ -256,7 +255,7 @@ class CoredeviceTest(ExperimentCase): dead_time = mu_to_seconds(t2 - t1, self.device_mgr.get("core")) print(dead_time) self.assertGreater(dead_time, 1*ms) - self.assertLess(dead_time, 300*ms) + self.assertLess(dead_time, 500*ms) def test_handover(self): self.execute(Handover) @@ -269,19 +268,19 @@ class RPCTiming(EnvExperiment): self.setattr_device("core") self.setattr_argument("repeats", FreeValue(100)) - def nop(self, x): + def nop(self): pass @kernel def bench(self): - self.ts = [0. for _ in range(self.repeats)] for i in range(self.repeats): t1 = self.core.get_rtio_counter_mu() - self.nop(1) + self.nop() t2 = self.core.get_rtio_counter_mu() self.ts[i] = mu_to_seconds(t2 - t1) def run(self): + self.ts = [0. for _ in range(self.repeats)] self.bench() mean = sum(self.ts)/self.repeats self.set_dataset("rpc_time_stddev", sqrt( diff --git a/artiq/test/hardware_testbench.py b/artiq/test/hardware_testbench.py index a9a9503af..65d35d2e8 100644 --- a/artiq/test/hardware_testbench.py +++ b/artiq/test/hardware_testbench.py @@ -7,9 +7,10 @@ import unittest import logging from artiq.language import * -from artiq.protocols import pyon from artiq.master.databases import DeviceDB, DatasetDB from artiq.master.worker_db import DeviceManager, DatasetManager +from artiq.coredevice.core import CompileError +from artiq.protocols import pyon from artiq.frontend.artiq_run import DummyScheduler @@ -41,7 +42,16 @@ class ExperimentCase(unittest.TestCase): virtual_devices={"scheduler": DummyScheduler()}) self.dataset_mgr = DatasetManager(self.dataset_db) - def execute(self, cls, **kwargs): + def create(self, cls, **kwargs): + try: + exp = cls(self.device_mgr, self.dataset_mgr, **kwargs) + exp.prepare() + return exp + except KeyError as e: + # skip if ddb does not match requirements + raise unittest.SkipTest(*e.args) + + def execute(self, cls, *args, **kwargs): expid = { "file": sys.modules[cls.__module__].__file__, "class_name": cls.__name__, @@ -49,14 +59,12 @@ class ExperimentCase(unittest.TestCase): } self.device_mgr.virtual_devices["scheduler"].expid = expid try: - try: - exp = cls(self.device_mgr, self.dataset_mgr, **kwargs) - except KeyError as e: - # skip if ddb does not match requirements - raise unittest.SkipTest(*e.args) - exp.prepare() + exp = self.create(cls, **kwargs) exp.run() exp.analyze() return exp + except CompileError as error: + # Reduce amount of text on terminal. 
+ raise error from None finally: self.device_mgr.close_devices() diff --git a/artiq/test/language.py b/artiq/test/language.py new file mode 100644 index 000000000..103df4569 --- /dev/null +++ b/artiq/test/language.py @@ -0,0 +1,44 @@ +import unittest + +from artiq.language.core import * + + +class LanguageCoreTest(unittest.TestCase): + def test_unary(self): + self.assertEqual(int(10), +int(10)) + self.assertEqual(int(-10), -int(10)) + self.assertEqual(int(~10), ~int(10)) + self.assertEqual(int(10), round(int(10))) + + def test_arith(self): + self.assertEqual(int(9), int(4) + int(5)) + self.assertEqual(int(9), int(4) + 5) + self.assertEqual(int(9), 5 + int(4)) + + self.assertEqual(9.0, int(4) + 5.0) + self.assertEqual(9.0, 5.0 + int(4)) + + a = int(5) + a += int(2) + a += 2 + self.assertEqual(int(9), a) + + def test_compare(self): + self.assertTrue(int(9) > int(8)) + self.assertTrue(int(9) > 8) + self.assertTrue(int(9) > 8.0) + self.assertTrue(9 > int(8)) + self.assertTrue(9.0 > int(8)) + + def test_bitwise(self): + self.assertEqual(int(0x100), int(0x10) << int(4)) + self.assertEqual(int(0x100), int(0x10) << 4) + self.assertEqual(int(0x100), 0x10 << int(4)) + + def test_wraparound(self): + self.assertEqual(int(0xffffffff), int(-1)) + self.assertTrue(int(0x7fffffff) > int(1)) + self.assertTrue(int(0x80000000) < int(-1)) + + self.assertEqual(int(9), int(10) + int(0xffffffff)) + self.assertEqual(-1.0, float(int(0xfffffffe) + int(1))) diff --git a/artiq/test/py2llvm.py b/artiq/test/py2llvm.py deleted file mode 100644 index a8eea8bb9..000000000 --- a/artiq/test/py2llvm.py +++ /dev/null @@ -1,374 +0,0 @@ -import unittest -import ast -import inspect -from fractions import Fraction -from ctypes import CFUNCTYPE, c_int, c_int32, c_int64, c_double -import struct -import os - -import llvmlite_artiq.binding as llvm - -from artiq.language.core import int64 -from artiq.py2llvm.infer_types import infer_function_types -from artiq.py2llvm import base_types, lists -from artiq.py2llvm.module import Module - - -llvm.initialize() -llvm.initialize_native_target() -llvm.initialize_native_asmprinter() -if struct.calcsize("P") < 8 and os.name != "nt": - from ctypes import _dlopen, RTLD_GLOBAL - _dlopen("libgcc_s.so", RTLD_GLOBAL) - - -def _base_types(choice): - a = 2 # promoted later to int64 - b = a + 1 # initially int32, becomes int64 after a is promoted - c = b//2 # initially int32, becomes int64 after b is promoted - d = 4 and 5 # stays int32 - x = int64(7) - a += x # promotes a to int64 - foo = True | True or False - bar = None - myf = 4.5 - myf2 = myf + x - - if choice and foo and not bar: - return d - elif myf2: - return x + c - else: - return int64(8) - - -def _build_function_types(f): - return infer_function_types( - None, ast.parse(inspect.getsource(f)), - dict()) - - -class FunctionBaseTypesCase(unittest.TestCase): - def setUp(self): - self.ns = _build_function_types(_base_types) - - def test_simple_types(self): - self.assertIsInstance(self.ns["foo"], base_types.VBool) - self.assertIsInstance(self.ns["bar"], base_types.VNone) - self.assertIsInstance(self.ns["d"], base_types.VInt) - self.assertEqual(self.ns["d"].nbits, 32) - self.assertIsInstance(self.ns["x"], base_types.VInt) - self.assertEqual(self.ns["x"].nbits, 64) - self.assertIsInstance(self.ns["myf"], base_types.VFloat) - self.assertIsInstance(self.ns["myf2"], base_types.VFloat) - - def test_promotion(self): - for v in "abc": - self.assertIsInstance(self.ns[v], base_types.VInt) - self.assertEqual(self.ns[v].nbits, 64) - - def test_return(self): - 
self.assertIsInstance(self.ns["return"], base_types.VInt) - self.assertEqual(self.ns["return"].nbits, 64) - - -def test_list_types(): - a = [0, 0, 0, 0, 0] - for i in range(2): - a[i] = int64(8) - return a - - -class FunctionListTypesCase(unittest.TestCase): - def setUp(self): - self.ns = _build_function_types(test_list_types) - - def test_list_types(self): - self.assertIsInstance(self.ns["a"], lists.VList) - self.assertIsInstance(self.ns["a"].el_type, base_types.VInt) - self.assertEqual(self.ns["a"].el_type.nbits, 64) - self.assertEqual(self.ns["a"].alloc_count, 5) - self.assertIsInstance(self.ns["i"], base_types.VInt) - self.assertEqual(self.ns["i"].nbits, 32) - - -def _value_to_ctype(v): - if isinstance(v, base_types.VBool): - return c_int - elif isinstance(v, base_types.VInt): - if v.nbits == 32: - return c_int32 - elif v.nbits == 64: - return c_int64 - else: - raise NotImplementedError(str(v)) - elif isinstance(v, base_types.VFloat): - return c_double - else: - raise NotImplementedError(str(v)) - - -class CompiledFunction: - def __init__(self, function, param_types): - module = Module() - - func_def = ast.parse(inspect.getsource(function)).body[0] - function, retval = module.compile_function(func_def, param_types) - argvals = [param_types[arg.arg] for arg in func_def.args.args] - - ee = module.get_ee() - cfptr = ee.get_pointer_to_global( - module.llvm_module_ref.get_function(function.name)) - retval_ctype = _value_to_ctype(retval) - argval_ctypes = [_value_to_ctype(argval) for argval in argvals] - self.cfunc = CFUNCTYPE(retval_ctype, *argval_ctypes)(cfptr) - - # HACK: prevent garbage collection of self.cfunc internals - self.ee = ee - - def __call__(self, *args): - return self.cfunc(*args) - - -def arith(op, a, b): - if op == 0: - return a + b - elif op == 1: - return a - b - elif op == 2: - return a * b - else: - return a / b - - -def is_prime(x): - d = 2 - while d*d <= x: - if not x % d: - return False - d += 1 - return True - - -def simplify_encode(a, b): - f = Fraction(a, b) - return f.numerator*1000 + f.denominator - - -def frac_arith_encode(op, a, b, c, d): - if op == 0: - f = Fraction(a, b) - Fraction(c, d) - elif op == 1: - f = Fraction(a, b) + Fraction(c, d) - elif op == 2: - f = Fraction(a, b) * Fraction(c, d) - else: - f = Fraction(a, b) / Fraction(c, d) - return f.numerator*1000 + f.denominator - - -def frac_arith_encode_int(op, a, b, x): - if op == 0: - f = Fraction(a, b) - x - elif op == 1: - f = Fraction(a, b) + x - elif op == 2: - f = Fraction(a, b) * x - else: - f = Fraction(a, b) / x - return f.numerator*1000 + f.denominator - - -def frac_arith_encode_int_rev(op, a, b, x): - if op == 0: - f = x - Fraction(a, b) - elif op == 1: - f = x + Fraction(a, b) - elif op == 2: - f = x * Fraction(a, b) - else: - f = x / Fraction(a, b) - return f.numerator*1000 + f.denominator - - -def frac_arith_float(op, a, b, x): - if op == 0: - return Fraction(a, b) - x - elif op == 1: - return Fraction(a, b) + x - elif op == 2: - return Fraction(a, b) * x - else: - return Fraction(a, b) / x - - -def frac_arith_float_rev(op, a, b, x): - if op == 0: - return x - Fraction(a, b) - elif op == 1: - return x + Fraction(a, b) - elif op == 2: - return x * Fraction(a, b) - else: - return x / Fraction(a, b) - - -def list_test(): - x = 80 - a = [3 for x in range(7)] - b = [1, 2, 4, 5, 4, 0, 5] - a[3] = x - a[0] += 6 - a[1] = b[1] + b[2] - - acc = 0 - for i in range(7): - if i and a[i]: - acc += 1 - acc += a[i] - return acc - - -def corner_cases(): - two = True + True - (not True) - three = two + 
True//True - False*True - two_float = three - True/True - one_float = two_float - (1.0 == bool(0.1)) - zero = int(one_float) + round(-0.6) - eleven_float = zero + 5.5//0.5 - ten_float = eleven_float + round(Fraction(2, -3)) - return ten_float - - -def _test_range(): - for i in range(5, 10): - yield i - yield -i - - -class CodeGenCase(unittest.TestCase): - def _test_float_arith(self, op): - arith_c = CompiledFunction(arith, { - "op": base_types.VInt(), - "a": base_types.VFloat(), "b": base_types.VFloat()}) - for a in _test_range(): - for b in _test_range(): - self.assertEqual(arith_c(op, a/2, b/2), arith(op, a/2, b/2)) - - def test_float_add(self): - self._test_float_arith(0) - - def test_float_sub(self): - self._test_float_arith(1) - - def test_float_mul(self): - self._test_float_arith(2) - - def test_float_div(self): - self._test_float_arith(3) - - @unittest.skipIf(os.name == "nt", "This test is known to fail on Windows") - def test_is_prime(self): - is_prime_c = CompiledFunction(is_prime, {"x": base_types.VInt()}) - for i in range(200): - self.assertEqual(is_prime_c(i), is_prime(i)) - - def test_frac_simplify(self): - simplify_encode_c = CompiledFunction( - simplify_encode, {"a": base_types.VInt(), "b": base_types.VInt()}) - for a in _test_range(): - for b in _test_range(): - self.assertEqual( - simplify_encode_c(a, b), simplify_encode(a, b)) - - def _test_frac_arith(self, op): - frac_arith_encode_c = CompiledFunction( - frac_arith_encode, { - "op": base_types.VInt(), - "a": base_types.VInt(), "b": base_types.VInt(), - "c": base_types.VInt(), "d": base_types.VInt()}) - for a in _test_range(): - for b in _test_range(): - for c in _test_range(): - for d in _test_range(): - self.assertEqual( - frac_arith_encode_c(op, a, b, c, d), - frac_arith_encode(op, a, b, c, d)) - - def test_frac_add(self): - self._test_frac_arith(0) - - def test_frac_sub(self): - self._test_frac_arith(1) - - def test_frac_mul(self): - self._test_frac_arith(2) - - def test_frac_div(self): - self._test_frac_arith(3) - - def _test_frac_arith_int(self, op, rev): - f = frac_arith_encode_int_rev if rev else frac_arith_encode_int - f_c = CompiledFunction(f, { - "op": base_types.VInt(), - "a": base_types.VInt(), "b": base_types.VInt(), - "x": base_types.VInt()}) - for a in _test_range(): - for b in _test_range(): - for x in _test_range(): - self.assertEqual( - f_c(op, a, b, x), - f(op, a, b, x)) - - def test_frac_add_int(self): - self._test_frac_arith_int(0, False) - self._test_frac_arith_int(0, True) - - def test_frac_sub_int(self): - self._test_frac_arith_int(1, False) - self._test_frac_arith_int(1, True) - - def test_frac_mul_int(self): - self._test_frac_arith_int(2, False) - self._test_frac_arith_int(2, True) - - def test_frac_div_int(self): - self._test_frac_arith_int(3, False) - self._test_frac_arith_int(3, True) - - def _test_frac_arith_float(self, op, rev): - f = frac_arith_float_rev if rev else frac_arith_float - f_c = CompiledFunction(f, { - "op": base_types.VInt(), - "a": base_types.VInt(), "b": base_types.VInt(), - "x": base_types.VFloat()}) - for a in _test_range(): - for b in _test_range(): - for x in _test_range(): - self.assertAlmostEqual( - f_c(op, a, b, x/2), - f(op, a, b, x/2)) - - def test_frac_add_float(self): - self._test_frac_arith_float(0, False) - self._test_frac_arith_float(0, True) - - def test_frac_sub_float(self): - self._test_frac_arith_float(1, False) - self._test_frac_arith_float(1, True) - - def test_frac_mul_float(self): - self._test_frac_arith_float(2, False) - self._test_frac_arith_float(2, 
True) - - def test_frac_div_float(self): - self._test_frac_arith_float(3, False) - self._test_frac_arith_float(3, True) - - def test_list(self): - list_test_c = CompiledFunction(list_test, dict()) - self.assertEqual(list_test_c(), list_test()) - - def test_corner_cases(self): - corner_cases_c = CompiledFunction(corner_cases, dict()) - self.assertEqual(corner_cases_c(), corner_cases()) diff --git a/artiq/test/transforms.py b/artiq/test/transforms.py deleted file mode 100644 index dffee41a2..000000000 --- a/artiq/test/transforms.py +++ /dev/null @@ -1,44 +0,0 @@ -import unittest -import ast - -from artiq import ns -from artiq.coredevice import comm_dummy, core -from artiq.transforms.unparse import unparse - - -optimize_in = """ - -def run(): - dds_sysclk = Fraction(1000000000, 1) - n = seconds_to_mu((1.2345 * Fraction(1, 1000000000))) - with sequential: - frequency = 345 * Fraction(1000000, 1) - frequency_to_ftw_return = int((((2 ** 32) * frequency) / dds_sysclk)) - ftw = frequency_to_ftw_return - with sequential: - ftw2 = ftw - ftw_to_frequency_return = ((ftw2 * dds_sysclk) / (2 ** 32)) - f = ftw_to_frequency_return - phi = ((1000 * mu_to_seconds(n)) * f) - do_something(int(phi)) -""" - -optimize_out = """ - -def run(): - now = syscall('now_init') - try: - do_something(344) - finally: - syscall('now_save', now) -""" - - -class OptimizeCase(unittest.TestCase): - def test_optimize(self): - dmgr = dict() - dmgr["comm"] = comm_dummy.Comm(dmgr) - coredev = core.Core(dmgr, ref_period=1*ns) - func_def = ast.parse(optimize_in).body[0] - coredev.transform_stack(func_def, dict(), dict()) - self.assertEqual(unparse(func_def), optimize_out) diff --git a/artiq/tools.py b/artiq/tools.py index dae3c2424..b50cc6fd3 100644 --- a/artiq/tools.py +++ b/artiq/tools.py @@ -71,7 +71,7 @@ def short_format(v): return r -def file_import(filename): +def file_import(filename, prefix="file_import_"): linecache.checkcache(filename) modname = filename @@ -81,7 +81,7 @@ def file_import(filename): i = modname.find(".") if i > 0: modname = modname[:i] - modname = "file_import_" + modname + modname = prefix + modname path = os.path.dirname(os.path.realpath(filename)) sys.path.insert(0, path) diff --git a/artiq/transforms/__init__.py b/artiq/transforms/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/artiq/transforms/fold_constants.py b/artiq/transforms/fold_constants.py deleted file mode 100644 index 402fc243b..000000000 --- a/artiq/transforms/fold_constants.py +++ /dev/null @@ -1,156 +0,0 @@ -import ast -import operator -from fractions import Fraction - -from artiq.transforms.tools import * -from artiq.language.core import int64, round64 - - -_ast_unops = { - ast.Invert: operator.inv, - ast.Not: operator.not_, - ast.UAdd: operator.pos, - ast.USub: operator.neg -} - - -_ast_binops = { - ast.Add: operator.add, - ast.Sub: operator.sub, - ast.Mult: operator.mul, - ast.Div: operator.truediv, - ast.FloorDiv: operator.floordiv, - ast.Mod: operator.mod, - ast.Pow: operator.pow, - ast.LShift: operator.lshift, - ast.RShift: operator.rshift, - ast.BitOr: operator.or_, - ast.BitXor: operator.xor, - ast.BitAnd: operator.and_ -} - -_ast_cmpops = { - ast.Eq: operator.eq, - ast.NotEq: operator.ne, - ast.Lt: operator.lt, - ast.LtE: operator.le, - ast.Gt: operator.gt, - ast.GtE: operator.ge -} - -_ast_boolops = { - ast.Or: lambda x, y: x or y, - ast.And: lambda x, y: x and y -} - - -class _ConstantFolder(ast.NodeTransformer): - def visit_UnaryOp(self, node): - self.generic_visit(node) - try: - operand = 
eval_constant(node.operand) - except NotConstant: - return node - try: - op = _ast_unops[type(node.op)] - except KeyError: - return node - try: - result = value_to_ast(op(operand)) - except: - return node - return ast.copy_location(result, node) - - def visit_BinOp(self, node): - self.generic_visit(node) - try: - left, right = eval_constant(node.left), eval_constant(node.right) - except NotConstant: - return node - try: - op = _ast_binops[type(node.op)] - except KeyError: - return node - try: - result = value_to_ast(op(left, right)) - except: - return node - return ast.copy_location(result, node) - - def visit_Compare(self, node): - self.generic_visit(node) - try: - operands = [eval_constant(node.left)] - except NotConstant: - operands = [node.left] - ops = [] - for op, right_ast in zip(node.ops, node.comparators): - try: - right = eval_constant(right_ast) - except NotConstant: - right = right_ast - if (not isinstance(operands[-1], ast.AST) - and not isinstance(right, ast.AST)): - left = operands.pop() - operands.append(_ast_cmpops[type(op)](left, right)) - else: - ops.append(op) - operands.append(right_ast) - operands = [operand if isinstance(operand, ast.AST) - else ast.copy_location(value_to_ast(operand), node) - for operand in operands] - if len(operands) == 1: - return operands[0] - else: - node.left = operands[0] - node.right = operands[1:] - node.ops = ops - return node - - def visit_BoolOp(self, node): - self.generic_visit(node) - new_values = [] - for value in node.values: - try: - value_c = eval_constant(value) - except NotConstant: - new_values.append(value) - else: - if new_values and not isinstance(new_values[-1], ast.AST): - op = _ast_boolops[type(node.op)] - new_values[-1] = op(new_values[-1], value_c) - else: - new_values.append(value_c) - new_values = [v if isinstance(v, ast.AST) else value_to_ast(v) - for v in new_values] - if len(new_values) > 1: - node.values = new_values - return node - else: - return new_values[0] - - def visit_Call(self, node): - self.generic_visit(node) - fn = node.func.id - constant_ops = { - "int": int, - "int64": int64, - "round": round, - "round64": round64, - "Fraction": Fraction - } - if fn in constant_ops: - args = [] - for arg in node.args: - try: - args.append(eval_constant(arg)) - except NotConstant: - return node - result = value_to_ast(constant_ops[fn](*args)) - return ast.copy_location(result, node) - else: - return node - - -def fold_constants(node): - _ConstantFolder().visit(node) diff --git a/artiq/transforms/remove_dead_code.py b/artiq/transforms/remove_dead_code.py deleted file mode 100644 index 9a58c851d..000000000 --- a/artiq/transforms/remove_dead_code.py +++ /dev/null @@ -1,59 +0,0 @@ -import ast - -from artiq.transforms.tools import is_ref_transparent - - -class _SourceLister(ast.NodeVisitor): - def __init__(self): - self.sources = set() - - def visit_Name(self, node): - if isinstance(node.ctx, ast.Load): - self.sources.add(node.id) - - -class _DeadCodeRemover(ast.NodeTransformer): - def __init__(self, kept_targets): - self.kept_targets = kept_targets - - def visit_Assign(self, node): - new_targets = [] - for target in node.targets: - if (not isinstance(target, ast.Name) - or target.id in self.kept_targets): - new_targets.append(target) - if not new_targets and is_ref_transparent(node.value)[0]: - return None - else: - return node - - def visit_AugAssign(self, node): - if (isinstance(node.target, ast.Name) - and node.target.id not in self.kept_targets - and is_ref_transparent(node.value)[0]): - return None - else: - return 
node - - def visit_If(self, node): - self.generic_visit(node) - if isinstance(node.test, ast.NameConstant): - if node.test.value: - return node.body - else: - return node.orelse - else: - return node - - def visit_While(self, node): - self.generic_visit(node) - if isinstance(node.test, ast.NameConstant) and not node.test.value: - return node.orelse - else: - return node - - -def remove_dead_code(func_def): - sl = _SourceLister() - sl.visit(func_def) - _DeadCodeRemover(sl.sources).visit(func_def) diff --git a/artiq/transforms/remove_inter_assigns.py b/artiq/transforms/remove_inter_assigns.py deleted file mode 100644 index 56d877215..000000000 --- a/artiq/transforms/remove_inter_assigns.py +++ /dev/null @@ -1,149 +0,0 @@ -import ast -from copy import copy, deepcopy -from collections import defaultdict - -from artiq.transforms.tools import is_ref_transparent, count_all_nodes - - -class _TargetLister(ast.NodeVisitor): - def __init__(self): - self.targets = set() - - def visit_Name(self, node): - if isinstance(node.ctx, ast.Store): - self.targets.add(node.id) - - -class _InterAssignRemover(ast.NodeTransformer): - def __init__(self): - self.replacements = dict() - self.modified_names = set() - # name -> set of names that depend on it - # i.e. when x is modified, dependencies[x] is the set of names that - # cannot be replaced anymore - self.dependencies = defaultdict(set) - - def invalidate(self, name): - try: - del self.replacements[name] - except KeyError: - pass - for d in self.dependencies[name]: - self.invalidate(d) - del self.dependencies[name] - - def visit_Name(self, node): - if isinstance(node.ctx, ast.Load): - try: - return deepcopy(self.replacements[node.id]) - except KeyError: - return node - else: - self.modified_names.add(node.id) - self.invalidate(node.id) - return node - - def visit_Assign(self, node): - node.value = self.visit(node.value) - node.targets = [self.visit(target) for target in node.targets] - rt, depends_on = is_ref_transparent(node.value) - if rt and count_all_nodes(node.value) < 100: - for target in node.targets: - if isinstance(target, ast.Name): - if target.id not in depends_on: - self.replacements[target.id] = node.value - for d in depends_on: - self.dependencies[d].add(target.id) - return node - - def visit_AugAssign(self, node): - left = deepcopy(node.target) - left.ctx = ast.Load() - newnode = ast.copy_location( - ast.Assign( - targets=[node.target], - value=ast.BinOp(left=left, op=node.op, right=node.value) - ), - node - ) - return self.visit_Assign(newnode) - - def modified_names_push(self): - prev_modified_names = self.modified_names - self.modified_names = set() - return prev_modified_names - - def modified_names_pop(self, prev_modified_names): - for name in self.modified_names: - self.invalidate(name) - self.modified_names |= prev_modified_names - - def visit_Try(self, node): - prev_modified_names = self.modified_names_push() - node.body = [self.visit(stmt) for stmt in node.body] - self.modified_names_pop(prev_modified_names) - - prev_modified_names = self.modified_names_push() - prev_replacements = self.replacements - for handler in node.handlers: - self.replacements = copy(prev_replacements) - handler.body = [self.visit(stmt) for stmt in handler.body] - self.replacements = copy(prev_replacements) - node.orelse = [self.visit(stmt) for stmt in node.orelse] - self.modified_names_pop(prev_modified_names) - - prev_modified_names = self.modified_names_push() - node.finalbody = [self.visit(stmt) for stmt in node.finalbody] - 
self.modified_names_pop(prev_modified_names) - return node - - def visit_If(self, node): - node.test = self.visit(node.test) - - prev_modified_names = self.modified_names_push() - - prev_replacements = self.replacements - self.replacements = copy(prev_replacements) - node.body = [self.visit(n) for n in node.body] - self.replacements = copy(prev_replacements) - node.orelse = [self.visit(n) for n in node.orelse] - self.replacements = prev_replacements - - self.modified_names_pop(prev_modified_names) - - return node - - def visit_loop(self, node): - prev_modified_names = self.modified_names_push() - prev_replacements = self.replacements - - self.replacements = copy(prev_replacements) - tl = _TargetLister() - for n in node.body: - tl.visit(n) - for name in tl.targets: - self.invalidate(name) - node.body = [self.visit(n) for n in node.body] - - self.replacements = copy(prev_replacements) - node.orelse = [self.visit(n) for n in node.orelse] - - self.replacements = prev_replacements - self.modified_names_pop(prev_modified_names) - - def visit_For(self, node): - prev_modified_names = self.modified_names_push() - node.target = self.visit(node.target) - self.modified_names_pop(prev_modified_names) - node.iter = self.visit(node.iter) - self.visit_loop(node) - return node - - def visit_While(self, node): - self.visit_loop(node) - node.test = self.visit(node.test) - return node - - -def remove_inter_assigns(func_def): - _InterAssignRemover().visit(func_def) diff --git a/doc/manual/core_device.rst b/doc/manual/core_device.rst index d01a992a6..942a4b844 100644 --- a/doc/manual/core_device.rst +++ b/doc/manual/core_device.rst @@ -54,21 +54,27 @@ The low-cost Pipistrello FPGA board can be used as a lower-cost but slower alter When plugged to an adapter, the NIST QC1 hardware can be used. The TTL lines are mapped to RTIO channels as follows: -+--------------+----------+------------+ -| RTIO channel | TTL line | Capability | -+==============+==========+============+ -| 0 | PMT0 | Input | -+--------------+----------+------------+ -| 1 | PMT1 | Input | -+--------------+----------+------------+ -| 2-16 | TTL0-14 | Output | -+--------------+----------+------------+ -| 17 | EXT_LED | Output | -+--------------+----------+------------+ -| 18 | USER_LED | Output | -+--------------+----------+------------+ -| 19 | TTL15 | Clock | -+--------------+----------+------------+ ++--------------+------------+------------+ +| RTIO channel | TTL line | Capability | ++==============+============+============+ +| 0 | PMT0 | Input | ++--------------+------------+------------+ +| 1 | PMT1 | Input | ++--------------+------------+------------+ +| 2-16 | TTL0-14 | Output | ++--------------+------------+------------+ +| 17 | EXT_LED | Output | ++--------------+------------+------------+ +| 18 | USER_LED_1 | Output | ++--------------+------------+------------+ +| 19 | USER_LED_2 | Output | ++--------------+------------+------------+ +| 20 | USER_LED_3 | Output | ++--------------+------------+------------+ +| 21 | USER_LED_4 | Output | ++--------------+------------+------------+ +| 22 | TTL15 | Clock | ++--------------+------------+------------+ The input only limitation on channels 0 and 1 comes from the QC-DAQ adapter. When the adapter is not used (and physically unplugged from the Pipistrello board), the corresponding pins on the Pipistrello can be used as outputs. Do not configure these channels as outputs when the adapter is plugged, as this would cause electrical contention. 
diff --git a/lit-test/harness.py b/lit-test/harness.py new file mode 100644 index 000000000..955394c3b --- /dev/null +++ b/lit-test/harness.py @@ -0,0 +1,31 @@ +""" +The purpose of this harness is to emulate the behavior of +the python executable, but add the ARTIQ root to sys.path +beforehand. + +This is necessary because eggs override the PYTHONPATH environment +variable, but not the current directory; therefore `python -m artiq...` +run from the ARTIQ root would work, but there is no simple way to +emulate the same behavior when invoked under lit. +""" + +import sys, os, argparse, importlib + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument('-m', metavar='mod', type=str, + help='run library module as a script') +parser.add_argument('args', type=str, nargs='+', + help='arguments passed to program in sys.argv[1:]') +args = parser.parse_args(sys.argv[1:]) + +artiq_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(1, artiq_path) + +if args.m: + sys.argv[1:] = args.args + importlib.import_module(args.m).main() +else: + sys.argv[1:] = args.args[1:] + with open(args.args[0]) as f: + code = compile(f.read(), args.args[0], 'exec') + exec(code) diff --git a/lit-test/libartiq_support/Makefile b/lit-test/libartiq_support/Makefile new file mode 100644 index 000000000..0406dee63 --- /dev/null +++ b/lit-test/libartiq_support/Makefile @@ -0,0 +1,4 @@ +CC ?= clang + +libartiq_support.so: ../../artiq/runtime/artiq_personality.c artiq_terminate.c artiq_time.c + $(CC) -std=c99 -Wall -Werror -I. -I../../artiq/runtime -g -fPIC -shared -o $@ $^ diff --git a/lit-test/libartiq_support/__cxxabi_config.h b/lit-test/libartiq_support/__cxxabi_config.h new file mode 100644 index 000000000..42cd6fe5c --- /dev/null +++ b/lit-test/libartiq_support/__cxxabi_config.h @@ -0,0 +1 @@ +#define LIBCXXABI_ARM_EHABI 0 diff --git a/lit-test/libartiq_support/artiq_terminate.c b/lit-test/libartiq_support/artiq_terminate.c new file mode 100644 index 000000000..5b1315131 --- /dev/null +++ b/lit-test/libartiq_support/artiq_terminate.c @@ -0,0 +1,29 @@ +#include <stdio.h> +#include <stdlib.h> +#include <stddef.h> +#include <inttypes.h> +#include <artiq_personality.h> + +#define __USE_GNU +#include <dlfcn.h> + +void __artiq_terminate(struct artiq_exception *exn, + struct artiq_backtrace_item *backtrace, + size_t backtrace_size) { + printf("Uncaught %s: %s (%"PRIi64", %"PRIi64", %"PRIi64")\n" + "at %s:%"PRIi32":%"PRIi32"\n", + exn->name, exn->message, + exn->param[0], exn->param[1], exn->param[2], + exn->file, exn->line, exn->column + 1); + + for(size_t i = 0; i < backtrace_size; i++) { + Dl_info info; + if(dladdr((void*) backtrace[i].function, &info) && info.dli_sname) { + printf("at %s+%p\n", info.dli_sname, (void*)backtrace[i].offset); + } else { + printf("at %p+%p\n", (void*)backtrace[i].function, (void*)backtrace[i].offset); + } + } + + exit(1); +} diff --git a/lit-test/libartiq_support/artiq_time.c b/lit-test/libartiq_support/artiq_time.c new file mode 100644 index 000000000..1afeadbc0 --- /dev/null +++ b/lit-test/libartiq_support/artiq_time.c @@ -0,0 +1,3 @@ +#include <stdint.h> + +int64_t now = 0; diff --git a/lit-test/libartiq_support/libartiq_personality.so b/lit-test/libartiq_support/libartiq_personality.so new file mode 100755 index 000000000..80ff44cb8 Binary files /dev/null and b/lit-test/libartiq_support/libartiq_personality.so differ diff --git a/lit-test/libartiq_support/unwind.h b/lit-test/libartiq_support/unwind.h new file mode 100644 index 000000000..86001bbb5 --- /dev/null +++ b/lit-test/libartiq_support/unwind.h @@ -0,0 +1,329 @@
+//===------------------------------- unwind.h -----------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is dual licensed under the MIT and the University of Illinois Open +// Source Licenses. See LICENSE.TXT for details. +// +// +// C++ ABI Level 1 ABI documented at: +// http://mentorembedded.github.io/cxx-abi/abi-eh.html +// +//===----------------------------------------------------------------------===// + +#ifndef __UNWIND_H__ +#define __UNWIND_H__ + +#include +#include + +#if defined(__APPLE__) +#define LIBUNWIND_UNAVAIL __attribute__ (( unavailable )) +#else +#define LIBUNWIND_UNAVAIL +#endif + +#include <__cxxabi_config.h> + +typedef enum { + _URC_NO_REASON = 0, + _URC_OK = 0, + _URC_FOREIGN_EXCEPTION_CAUGHT = 1, + _URC_FATAL_PHASE2_ERROR = 2, + _URC_FATAL_PHASE1_ERROR = 3, + _URC_NORMAL_STOP = 4, + _URC_END_OF_STACK = 5, + _URC_HANDLER_FOUND = 6, + _URC_INSTALL_CONTEXT = 7, + _URC_CONTINUE_UNWIND = 8, +#if LIBCXXABI_ARM_EHABI + _URC_FAILURE = 9 +#endif +} _Unwind_Reason_Code; + +typedef enum { + _UA_SEARCH_PHASE = 1, + _UA_CLEANUP_PHASE = 2, + _UA_HANDLER_FRAME = 4, + _UA_FORCE_UNWIND = 8, + _UA_END_OF_STACK = 16 // gcc extension to C++ ABI +} _Unwind_Action; + +typedef struct _Unwind_Context _Unwind_Context; // opaque + +#if LIBCXXABI_ARM_EHABI +typedef uint32_t _Unwind_State; + +static const _Unwind_State _US_VIRTUAL_UNWIND_FRAME = 0; +static const _Unwind_State _US_UNWIND_FRAME_STARTING = 1; +static const _Unwind_State _US_UNWIND_FRAME_RESUME = 2; +/* Undocumented flag for force unwinding. */ +static const _Unwind_State _US_FORCE_UNWIND = 8; + +typedef uint32_t _Unwind_EHT_Header; + +struct _Unwind_Control_Block; +typedef struct _Unwind_Control_Block _Unwind_Control_Block; +typedef struct _Unwind_Control_Block _Unwind_Exception; /* Alias */ + +struct _Unwind_Control_Block { + uint64_t exception_class; + void (*exception_cleanup)(_Unwind_Reason_Code, _Unwind_Control_Block*); + + /* Unwinder cache, private fields for the unwinder's use */ + struct { + uint32_t reserved1; /* init reserved1 to 0, then don't touch */ + uint32_t reserved2; + uint32_t reserved3; + uint32_t reserved4; + uint32_t reserved5; + } unwinder_cache; + + /* Propagation barrier cache (valid after phase 1): */ + struct { + uint32_t sp; + uint32_t bitpattern[5]; + } barrier_cache; + + /* Cleanup cache (preserved over cleanup): */ + struct { + uint32_t bitpattern[4]; + } cleanup_cache; + + /* Pr cache (for pr's benefit): */ + struct { + uint32_t fnstart; /* function start address */ + _Unwind_EHT_Header* ehtp; /* pointer to EHT entry header word */ + uint32_t additional; + uint32_t reserved1; + } pr_cache; + + long long int :0; /* Enforce the 8-byte alignment */ +}; + +typedef _Unwind_Reason_Code (*_Unwind_Stop_Fn) + (_Unwind_State state, + _Unwind_Exception* exceptionObject, + struct _Unwind_Context* context); + +typedef _Unwind_Reason_Code (*__personality_routine) + (_Unwind_State state, + _Unwind_Exception* exceptionObject, + struct _Unwind_Context* context); +#else +struct _Unwind_Context; // opaque +struct _Unwind_Exception; // forward declaration +typedef struct _Unwind_Exception _Unwind_Exception; + +struct _Unwind_Exception { + uint64_t exception_class; + void (*exception_cleanup)(_Unwind_Reason_Code reason, + _Unwind_Exception *exc); + uintptr_t private_1; // non-zero means forced unwind + uintptr_t private_2; // holds sp that phase1 found for phase2 to use +#ifndef __LP64__ + // The gcc implementation of _Unwind_Exception used attribute mode on the + // above fields 
which had the side effect of causing this whole struct to + // round up to 32 bytes in size. To be more explicit, we add pad fields + // added for binary compatibility. + uint32_t reserved[3]; +#endif +}; + +typedef _Unwind_Reason_Code (*_Unwind_Stop_Fn) + (int version, + _Unwind_Action actions, + uint64_t exceptionClass, + _Unwind_Exception* exceptionObject, + struct _Unwind_Context* context, + void* stop_parameter ); + +typedef _Unwind_Reason_Code (*__personality_routine) + (int version, + _Unwind_Action actions, + uint64_t exceptionClass, + _Unwind_Exception* exceptionObject, + struct _Unwind_Context* context); +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +// +// The following are the base functions documented by the C++ ABI +// +#ifdef __USING_SJLJ_EXCEPTIONS__ +extern _Unwind_Reason_Code + _Unwind_SjLj_RaiseException(_Unwind_Exception *exception_object); +extern void _Unwind_SjLj_Resume(_Unwind_Exception *exception_object); +#else +extern _Unwind_Reason_Code + _Unwind_RaiseException(_Unwind_Exception *exception_object); +extern void _Unwind_Resume(_Unwind_Exception *exception_object); +#endif +extern void _Unwind_DeleteException(_Unwind_Exception *exception_object); + +#if LIBCXXABI_ARM_EHABI +typedef enum { + _UVRSC_CORE = 0, /* integer register */ + _UVRSC_VFP = 1, /* vfp */ + _UVRSC_WMMXD = 3, /* Intel WMMX data register */ + _UVRSC_WMMXC = 4 /* Intel WMMX control register */ +} _Unwind_VRS_RegClass; + +typedef enum { + _UVRSD_UINT32 = 0, + _UVRSD_VFPX = 1, + _UVRSD_UINT64 = 3, + _UVRSD_FLOAT = 4, + _UVRSD_DOUBLE = 5 +} _Unwind_VRS_DataRepresentation; + +typedef enum { + _UVRSR_OK = 0, + _UVRSR_NOT_IMPLEMENTED = 1, + _UVRSR_FAILED = 2 +} _Unwind_VRS_Result; + +extern void _Unwind_Complete(_Unwind_Exception* exception_object); + +extern _Unwind_VRS_Result +_Unwind_VRS_Get(_Unwind_Context *context, _Unwind_VRS_RegClass regclass, + uint32_t regno, _Unwind_VRS_DataRepresentation representation, + void *valuep); + +extern _Unwind_VRS_Result +_Unwind_VRS_Set(_Unwind_Context *context, _Unwind_VRS_RegClass regclass, + uint32_t regno, _Unwind_VRS_DataRepresentation representation, + void *valuep); + +extern _Unwind_VRS_Result +_Unwind_VRS_Pop(_Unwind_Context *context, _Unwind_VRS_RegClass regclass, + uint32_t discriminator, + _Unwind_VRS_DataRepresentation representation); +#endif + +extern uintptr_t _Unwind_GetGR(struct _Unwind_Context *context, int index); +extern void _Unwind_SetGR(struct _Unwind_Context *context, int index, + uintptr_t new_value); +extern uintptr_t _Unwind_GetIP(struct _Unwind_Context *context); +extern void _Unwind_SetIP(struct _Unwind_Context *, uintptr_t new_value); + +extern uintptr_t _Unwind_GetRegionStart(struct _Unwind_Context *context); +extern uintptr_t + _Unwind_GetLanguageSpecificData(struct _Unwind_Context *context); +#ifdef __USING_SJLJ_EXCEPTIONS__ +extern _Unwind_Reason_Code + _Unwind_SjLj_ForcedUnwind(_Unwind_Exception *exception_object, + _Unwind_Stop_Fn stop, void *stop_parameter); +#else +extern _Unwind_Reason_Code + _Unwind_ForcedUnwind(_Unwind_Exception *exception_object, + _Unwind_Stop_Fn stop, void *stop_parameter); +#endif + +#ifdef __USING_SJLJ_EXCEPTIONS__ +typedef struct _Unwind_FunctionContext *_Unwind_FunctionContext_t; +extern void _Unwind_SjLj_Register(_Unwind_FunctionContext_t fc); +extern void _Unwind_SjLj_Unregister(_Unwind_FunctionContext_t fc); +#endif + +// +// The following are semi-suppoted extensions to the C++ ABI +// + +// +// called by __cxa_rethrow(). 
+// +#ifdef __USING_SJLJ_EXCEPTIONS__ +extern _Unwind_Reason_Code + _Unwind_SjLj_Resume_or_Rethrow(_Unwind_Exception *exception_object); +#else +extern _Unwind_Reason_Code + _Unwind_Resume_or_Rethrow(_Unwind_Exception *exception_object); +#endif + +// _Unwind_Backtrace() is a gcc extension that walks the stack and calls the +// _Unwind_Trace_Fn once per frame until it reaches the bottom of the stack +// or the _Unwind_Trace_Fn function returns something other than _URC_NO_REASON. +typedef _Unwind_Reason_Code (*_Unwind_Trace_Fn)(struct _Unwind_Context *, + void *); +extern _Unwind_Reason_Code _Unwind_Backtrace(_Unwind_Trace_Fn, void *); + +// _Unwind_GetCFA is a gcc extension that can be called from within a +// personality handler to get the CFA (stack pointer before call) of +// current frame. +extern uintptr_t _Unwind_GetCFA(struct _Unwind_Context *); + + +// _Unwind_GetIPInfo is a gcc extension that can be called from within a +// personality handler. Similar to _Unwind_GetIP() but also returns in +// *ipBefore a non-zero value if the instruction pointer is at or before the +// instruction causing the unwind. Normally, in a function call, the IP returned +// is the return address which is after the call instruction and may be past the +// end of the function containing the call instruction. +extern uintptr_t _Unwind_GetIPInfo(struct _Unwind_Context *context, + int *ipBefore); + + +// __register_frame() is used with dynamically generated code to register the +// FDE for a generated (JIT) code. The FDE must use pc-rel addressing to point +// to its function and optional LSDA. +// __register_frame() has existed in all versions of Mac OS X, but in 10.4 and +// 10.5 it was buggy and did not actually register the FDE with the unwinder. +// In 10.6 and later it does register properly. +extern void __register_frame(const void *fde); +extern void __deregister_frame(const void *fde); + +// _Unwind_Find_FDE() will locate the FDE if the pc is in some function that has +// an associated FDE. Note, Mac OS X 10.6 and later, introduces "compact unwind +// info" which the runtime uses in preference to dwarf unwind info. This +// function will only work if the target function has an FDE but no compact +// unwind info. +struct dwarf_eh_bases { + uintptr_t tbase; + uintptr_t dbase; + uintptr_t func; +}; +extern const void *_Unwind_Find_FDE(const void *pc, struct dwarf_eh_bases *); + + +// This function attempts to find the start (address of first instruction) of +// a function given an address inside the function. It only works if the +// function has an FDE (dwarf unwind info). +// This function is unimplemented on Mac OS X 10.6 and later. Instead, use +// _Unwind_Find_FDE() and look at the dwarf_eh_bases.func result. +extern void *_Unwind_FindEnclosingFunction(void *pc); + +// Mac OS X does not support text-rel and data-rel addressing so these functions +// are unimplemented +extern uintptr_t _Unwind_GetDataRelBase(struct _Unwind_Context *context) + LIBUNWIND_UNAVAIL; +extern uintptr_t _Unwind_GetTextRelBase(struct _Unwind_Context *context) + LIBUNWIND_UNAVAIL; + +// Mac OS X 10.4 and 10.5 had implementations of these functions in +// libgcc_s.dylib, but they never worked. +/// These functions are no longer available on Mac OS X. 
+extern void __register_frame_info_bases(const void *fde, void *ob, void *tb, + void *db) LIBUNWIND_UNAVAIL; +extern void __register_frame_info(const void *fde, void *ob) + LIBUNWIND_UNAVAIL; +extern void __register_frame_info_table_bases(const void *fde, void *ob, + void *tb, void *db) + LIBUNWIND_UNAVAIL; +extern void __register_frame_info_table(const void *fde, void *ob) + LIBUNWIND_UNAVAIL; +extern void __register_frame_table(const void *fde) + LIBUNWIND_UNAVAIL; +extern void *__deregister_frame_info(const void *fde) + LIBUNWIND_UNAVAIL; +extern void *__deregister_frame_info_bases(const void *fde) + LIBUNWIND_UNAVAIL; + +#ifdef __cplusplus +} +#endif + +#endif // __UNWIND_H__ diff --git a/lit-test/not.py b/lit-test/not.py new file mode 100644 index 000000000..8c0421623 --- /dev/null +++ b/lit-test/not.py @@ -0,0 +1,2 @@ +import sys, subprocess +exit(not subprocess.call(sys.argv[1:])) diff --git a/lit-test/test/codegen/warning_useless_bool.py b/lit-test/test/codegen/warning_useless_bool.py new file mode 100644 index 000000000..d81fa2941 --- /dev/null +++ b/lit-test/test/codegen/warning_useless_bool.py @@ -0,0 +1,5 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: warning: this expression, which is always truthful, is coerced to bool +bool(IndexError()) diff --git a/lit-test/test/devirtualization/device_db.pyon b/lit-test/test/devirtualization/device_db.pyon new file mode 100644 index 000000000..7c1bb62ef --- /dev/null +++ b/lit-test/test/devirtualization/device_db.pyon @@ -0,0 +1,8 @@ +{ + "comm": { + "type": "local", + "module": "artiq.coredevice.comm_dummy", + "class": "Comm", + "arguments": {} + } +} diff --git a/lit-test/test/devirtualization/function.py b/lit-test/test/devirtualization/function.py new file mode 100644 index 000000000..5c24b492b --- /dev/null +++ b/lit-test/test/devirtualization/function.py @@ -0,0 +1,22 @@ +# RUN: env ARTIQ_DUMP_IR=1 %python -m artiq.compiler.testbench.embedding +compile %s 2>%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: call ()->NoneType %local.testbench.entrypoint ; calls testbench.entrypoint + +@kernel +def baz(): + pass + +class foo: + @kernel + def bar(self): + # CHECK-L: call ()->NoneType %local.testbench.baz ; calls testbench.baz + baz() +x = foo() + +@kernel +def entrypoint(): + x.bar() diff --git a/lit-test/test/devirtualization/method.py b/lit-test/test/devirtualization/method.py new file mode 100644 index 000000000..2105487e3 --- /dev/null +++ b/lit-test/test/devirtualization/method.py @@ -0,0 +1,16 @@ +# RUN: env ARTIQ_DUMP_IR=1 %python -m artiq.compiler.testbench.embedding +compile %s 2>%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +class foo: + @kernel + def bar(self): + pass +x = foo() + +@kernel +def entrypoint(): + # CHECK-L: ; calls testbench.foo.bar + x.bar() diff --git a/lit-test/test/embedding/device_db.pyon b/lit-test/test/embedding/device_db.pyon new file mode 100644 index 000000000..7c1bb62ef --- /dev/null +++ b/lit-test/test/embedding/device_db.pyon @@ -0,0 +1,8 @@ +{ + "comm": { + "type": "local", + "module": "artiq.coredevice.comm_dummy", + "class": "Comm", + "arguments": {} + } +} diff --git a/lit-test/test/embedding/error_attr_absent.py b/lit-test/test/embedding/error_attr_absent.py new file mode 100644 index 000000000..296ed860a --- /dev/null +++ 
b/lit-test/test/embedding/error_attr_absent.py @@ -0,0 +1,16 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +class c: + pass + +@kernel +def entrypoint(): + # CHECK-L: :1: error: host object does not have an attribute 'x' + # CHECK-L: ${LINE:+1}: note: expanded from here + a = c + # CHECK-L: ${LINE:+1}: note: attribute accessed here + a.x diff --git a/lit-test/test/embedding/error_attr_conflict.py b/lit-test/test/embedding/error_attr_conflict.py new file mode 100644 index 000000000..a5c960d5c --- /dev/null +++ b/lit-test/test/embedding/error_attr_conflict.py @@ -0,0 +1,21 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +class c: + pass + +i1 = c() +i1.x = 1 + +i2 = c() +i2.x = 1.0 + +@kernel +def entrypoint(): + # CHECK-L: :1: error: host object has an attribute 'x' of type float, which is different from previously inferred type int(width=32) for the same attribute + i1.x + # CHECK-L: ${LINE:+1}: note: expanded from here + i2.x diff --git a/lit-test/test/embedding/error_attr_unify.py b/lit-test/test/embedding/error_attr_unify.py new file mode 100644 index 000000000..e891ec73b --- /dev/null +++ b/lit-test/test/embedding/error_attr_unify.py @@ -0,0 +1,17 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +class c: + x = [1, "x"] + +@kernel +def entrypoint(): + # CHECK-L: :1: error: cannot unify int(width='a) with str + # CHECK-NEXT-L: [1, 'x'] + # CHECK-L: ${LINE:+1}: note: expanded from here + a = c + # CHECK-L: ${LINE:+1}: note: while inferring a type for an attribute 'x' of a host object + a.x diff --git a/lit-test/test/embedding/error_rpc_annot_return.py b/lit-test/test/embedding/error_rpc_annot_return.py new file mode 100644 index 000000000..063ac6c28 --- /dev/null +++ b/lit-test/test/embedding/error_rpc_annot_return.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: ${LINE:+1}: error: type annotation for return type, '1', is not an ARTIQ type +def foo() -> 1: + pass + +@kernel +def entrypoint(): + # CHECK-L: ${LINE:+1}: note: in function called remotely here + foo() diff --git a/lit-test/test/embedding/error_rpc_default_unify.py b/lit-test/test/embedding/error_rpc_default_unify.py new file mode 100644 index 000000000..9cb483628 --- /dev/null +++ b/lit-test/test/embedding/error_rpc_default_unify.py @@ -0,0 +1,15 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: :1: error: cannot unify int(width='a) with str +# CHECK-L: ${LINE:+1}: note: expanded from here while trying to infer a type for an unannotated optional argument 'x' from its default value +def foo(x=[1,"x"]): + pass + +@kernel +def entrypoint(): + # CHECK-L: ${LINE:+1}: note: in function called remotely here + foo() diff --git a/lit-test/test/embedding/error_syscall_annot.py b/lit-test/test/embedding/error_syscall_annot.py new file mode 100644 index 000000000..cbb5aab22 --- /dev/null +++ b/lit-test/test/embedding/error_syscall_annot.py @@ -0,0 
+1,15 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: ${LINE:+2}: error: type annotation for argument 'x', '1', is not an ARTIQ type +@syscall +def foo(x: 1) -> TNone: + pass + +@kernel +def entrypoint(): + # CHECK-L: ${LINE:+1}: note: in system call here + foo() diff --git a/lit-test/test/embedding/error_syscall_annot_return.py b/lit-test/test/embedding/error_syscall_annot_return.py new file mode 100644 index 000000000..20f1d4ac5 --- /dev/null +++ b/lit-test/test/embedding/error_syscall_annot_return.py @@ -0,0 +1,15 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: ${LINE:+2}: error: type annotation for return type, '1', is not an ARTIQ type +@syscall +def foo() -> 1: + pass + +@kernel +def entrypoint(): + # CHECK-L: ${LINE:+1}: note: in system call here + foo() diff --git a/lit-test/test/embedding/error_syscall_arg.py b/lit-test/test/embedding/error_syscall_arg.py new file mode 100644 index 000000000..b944cc0e9 --- /dev/null +++ b/lit-test/test/embedding/error_syscall_arg.py @@ -0,0 +1,15 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: ${LINE:+2}: error: system call argument 'x' must have a type annotation +@syscall +def foo(x) -> TNone: + pass + +@kernel +def entrypoint(): + # CHECK-L: ${LINE:+1}: note: in system call here + foo() diff --git a/lit-test/test/embedding/error_syscall_default_arg.py b/lit-test/test/embedding/error_syscall_default_arg.py new file mode 100644 index 000000000..df025ff19 --- /dev/null +++ b/lit-test/test/embedding/error_syscall_default_arg.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: ${LINE:+2}: error: system call argument 'x' must not have a default value +@syscall +def foo(x=1) -> TNone: + pass + +@kernel +def entrypoint(): + foo() diff --git a/lit-test/test/embedding/error_syscall_return.py b/lit-test/test/embedding/error_syscall_return.py new file mode 100644 index 000000000..c82db063d --- /dev/null +++ b/lit-test/test/embedding/error_syscall_return.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.embedding %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +from artiq.language.core import * +from artiq.language.types import * + +# CHECK-L: ${LINE:+2}: error: system call must have a return type annotation +@syscall +def foo(): + pass + +@kernel +def entrypoint(): + foo() diff --git a/lit-test/test/exceptions/catch.py b/lit-test/test/exceptions/catch.py new file mode 100644 index 000000000..d6c2866c1 --- /dev/null +++ b/lit-test/test/exceptions/catch.py @@ -0,0 +1,9 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +try: + 1/0 +except ZeroDivisionError: + # CHECK-L: OK + print("OK") diff --git a/lit-test/test/exceptions/catch_all.py b/lit-test/test/exceptions/catch_all.py new file mode 100644 index 000000000..1417f5f31 --- /dev/null +++ b/lit-test/test/exceptions/catch_all.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s 
--file-to-check=%t +# REQUIRES: exceptions + +def catch(f): + try: + f() + except Exception as e: + print(e) + +# CHECK-L: ZeroDivisionError +catch(lambda: 1/0) +# CHECK-L: IndexError +catch(lambda: [1.0][10]) diff --git a/lit-test/test/exceptions/catch_multi.py b/lit-test/test/exceptions/catch_multi.py new file mode 100644 index 000000000..472086660 --- /dev/null +++ b/lit-test/test/exceptions/catch_multi.py @@ -0,0 +1,16 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +def catch(f): + try: + f() + except ZeroDivisionError as zde: + print(zde) + except IndexError as ie: + print(ie) + +# CHECK-L: ZeroDivisionError +catch(lambda: 1/0) +# CHECK-L: IndexError +catch(lambda: [1.0][10]) diff --git a/lit-test/test/exceptions/catch_outer.py b/lit-test/test/exceptions/catch_outer.py new file mode 100644 index 000000000..de7253eaf --- /dev/null +++ b/lit-test/test/exceptions/catch_outer.py @@ -0,0 +1,16 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +def f(): + try: + 1/0 + except ValueError: + # CHECK-NOT-L: FAIL + print("FAIL") + +try: + f() +except ZeroDivisionError: + # CHECK-L: OK + print("OK") diff --git a/lit-test/test/exceptions/finally.py b/lit-test/test/exceptions/finally.py new file mode 100644 index 000000000..17304fc15 --- /dev/null +++ b/lit-test/test/exceptions/finally.py @@ -0,0 +1,21 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +def f(): + try: + 1/0 + finally: + print("f-fin") + print("f-out") + +def g(): + try: + f() + except: + print("g-except") + +# CHECK-L: f-fin +# CHECK-NOT-L: f-out +# CHECK-L: g-except +g() diff --git a/lit-test/test/exceptions/finally_catch.py b/lit-test/test/exceptions/finally_catch.py new file mode 100644 index 000000000..23bc39730 --- /dev/null +++ b/lit-test/test/exceptions/finally_catch.py @@ -0,0 +1,17 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +def f(): + try: + 1/0 + except: + print("f-except") + finally: + print("f-fin") + print("f-out") + +# CHECK-L: f-except +# CHECK-L: f-fin +# CHECK-L: f-out +f() diff --git a/lit-test/test/exceptions/finally_raise.py b/lit-test/test/exceptions/finally_raise.py new file mode 100644 index 000000000..02c41ea7e --- /dev/null +++ b/lit-test/test/exceptions/finally_raise.py @@ -0,0 +1,23 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +def f(): + try: + 1/0 + finally: + print("f-fin") + raise ValueError() + +def g(): + try: + f() + except ZeroDivisionError: + print("g-except-zde") + except ValueError: + print("g-except-ve") + +# CHECK-L: f-fin +# CHECK-L: g-except-ve +# CHECK-NOT-L: g-except-zde +g() diff --git a/lit-test/test/exceptions/finally_squash.py b/lit-test/test/exceptions/finally_squash.py new file mode 100644 index 000000000..8c7b58fc3 --- /dev/null +++ b/lit-test/test/exceptions/finally_squash.py @@ -0,0 +1,21 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +def f(): + try: + 1/0 + finally: + print("f-fin") + return + +def g(): + try: + f() + except: + print("g-except") + +# CHECK-L: f-fin +# CHECK-NOT-L: f-out +# CHECK-NOT-L: g-except +g() diff --git a/lit-test/test/exceptions/finally_uncaught.py 
b/lit-test/test/exceptions/finally_uncaught.py new file mode 100644 index 000000000..1eb211663 --- /dev/null +++ b/lit-test/test/exceptions/finally_uncaught.py @@ -0,0 +1,12 @@ +# RUN: %not %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +def f(): + try: + 1/0 + finally: + print("f-fin") + +# CHECK-L: f-fin +f() diff --git a/lit-test/test/exceptions/reraise.py b/lit-test/test/exceptions/reraise.py new file mode 100644 index 000000000..2a02b523d --- /dev/null +++ b/lit-test/test/exceptions/reraise.py @@ -0,0 +1,16 @@ +# RUN: %not %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +def f(): + # CHECK-L: Uncaught ZeroDivisionError + # CHECK-L: at input.py:${LINE:+1}: + 1/0 + +def g(): + try: + f() + except: + raise + +g() diff --git a/lit-test/test/exceptions/reraise_update.py b/lit-test/test/exceptions/reraise_update.py new file mode 100644 index 000000000..891e7685d --- /dev/null +++ b/lit-test/test/exceptions/reraise_update.py @@ -0,0 +1,16 @@ +# RUN: %not %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +def f(): + 1/0 + +def g(): + try: + f() + except Exception as e: + # CHECK-L: Uncaught ZeroDivisionError + # CHECK-L: at input.py:${LINE:+1}: + raise e + +g() diff --git a/lit-test/test/exceptions/uncaught.py b/lit-test/test/exceptions/uncaught.py new file mode 100644 index 000000000..6044f8bf2 --- /dev/null +++ b/lit-test/test/exceptions/uncaught.py @@ -0,0 +1,7 @@ +# RUN: %not %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t +# REQUIRES: exceptions + +# CHECK-L: Uncaught ZeroDivisionError: cannot divide by zero (0, 0, 0) +# CHECK-L: at input.py:${LINE:+1}: +1/0 diff --git a/lit-test/test/inferencer/builtin_calls.py b/lit-test/test/inferencer/builtin_calls.py new file mode 100644 index 000000000..61a97fb4a --- /dev/null +++ b/lit-test/test/inferencer/builtin_calls.py @@ -0,0 +1,32 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: bool:():bool +bool() + +# CHECK-L: bool:([]:list(elt='a)):bool +bool([]) + +# CHECK-L: int:():int(width='b) +int() + +# CHECK-L: int:(1.0:float):int(width='c) +int(1.0) + +# CHECK-L: int:(1.0:float, width=64:int(width='d)):int(width=64) +int(1.0, width=64) + +# CHECK-L: float:():float +float() + +# CHECK-L: float:(1:int(width='e)):float +float(1) + +# CHECK-L: list:():list(elt='f) +list() + +# CHECK-L: len:([]:list(elt='g)):int(width=32) +len([]) + +# CHECK-L: round:(1.0:float):int(width='h) +round(1.0) diff --git a/lit-test/test/inferencer/class.py b/lit-test/test/inferencer/class.py new file mode 100644 index 000000000..1249474a0 --- /dev/null +++ b/lit-test/test/inferencer/class.py @@ -0,0 +1,19 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +class c: + a = 1 + def f(): + pass + def m(self): + pass + +# CHECK-L: c:)->NoneType delay('c)}> +c +# CHECK-L: .a:int(width='a) +c.a +# CHECK-L: .f:()->NoneType delay('b) +c.f + +# CHECK-L: .m:method(fn=(self:)->NoneType delay('c), self=) +c().m() diff --git a/lit-test/test/inferencer/coerce.py b/lit-test/test/inferencer/coerce.py new file mode 100644 index 000000000..04fe42bbc --- /dev/null +++ b/lit-test/test/inferencer/coerce.py @@ -0,0 +1,41 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +1 | 
2 +# CHECK-L: 1:int(width='a):int(width='b) | 2:int(width='c):int(width='b):int(width='b) + +1 + 2 +# CHECK-L: 1:int(width='d):int(width='e) + 2:int(width='f):int(width='e):int(width='e) + +(1,) + (2.0,) +# CHECK-L: (1:int(width='g),):(int(width='g),) + (2.0:float,):(float,):(int(width='g), float) + +[1] + [2] +# CHECK-L: [1:int(width='h)]:list(elt=int(width='h)) + [2:int(width='h)]:list(elt=int(width='h)):list(elt=int(width='h)) + +1 * 2 +# CHECK-L: 1:int(width='i):int(width='j) * 2:int(width='k):int(width='j):int(width='j) + +[1] * 2 +# CHECK-L: [1:int(width='l)]:list(elt=int(width='l)) * 2:int(width='m):list(elt=int(width='l)) + +1 // 2 +# CHECK-L: 1:int(width='n):int(width='o) // 2:int(width='p):int(width='o):int(width='o) + +1 + 1.0 +# CHECK-L: 1:int(width='q):float + 1.0:float:float + +a = []; a += [1] +# CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r)) + +[] is [1] +# CHECK-L: []:list(elt=int(width='s)) is [1:int(width='s)]:list(elt=int(width='s)):bool + +1 in [1] +# CHECK-L: 1:int(width='t) in [1:int(width='t)]:list(elt=int(width='t)):bool + +[] < [1] +# CHECK-L: []:list(elt=int(width='u)) < [1:int(width='u)]:list(elt=int(width='u)):bool + +1.0 < 1 +# CHECK-L: 1.0:float < 1:int(width='v):float:bool diff --git a/lit-test/test/inferencer/error_assert.py b/lit-test/test/inferencer/error_assert.py new file mode 100644 index 000000000..1e7c10284 --- /dev/null +++ b/lit-test/test/inferencer/error_assert.py @@ -0,0 +1,6 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +x = "A" +# CHECK-L: ${LINE:+1}: error: assertion message must be a string literal +assert True, x diff --git a/lit-test/test/inferencer/error_builtin_calls.py b/lit-test/test/inferencer/error_builtin_calls.py new file mode 100644 index 000000000..aae74c5cd --- /dev/null +++ b/lit-test/test/inferencer/error_builtin_calls.py @@ -0,0 +1,12 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +a = 1 +# CHECK-L: ${LINE:+1}: error: the width argument of int() must be an integer literal +int(1.0, width=a) + +# CHECK-L: ${LINE:+1}: error: the argument of len() must be of an iterable type +len(1) + +# CHECK-L: ${LINE:+1}: error: the argument of list() must be of an iterable type +list(1) diff --git a/lit-test/test/inferencer/error_call.py b/lit-test/test/inferencer/error_call.py new file mode 100644 index 000000000..739ffa249 --- /dev/null +++ b/lit-test/test/inferencer/error_call.py @@ -0,0 +1,20 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: cannot call this expression of type int +(1)() + +def f(x, y, z=1): + pass + +# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported +f(*[]) + +# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported +f(**[]) + +# CHECK-L: ${LINE:+1}: error: the argument 'x' has been passed earlier as positional +f(1, x=1) + +# CHECK-L: ${LINE:+1}: error: mandatory argument 'x' is not passed +f() diff --git a/lit-test/test/inferencer/error_class.py b/lit-test/test/inferencer/error_class.py new file mode 100644 index 000000000..0cc075b05 --- /dev/null +++ b/lit-test/test/inferencer/error_class.py @@ -0,0 +1,10 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: inheritance is not 
supported +class a(1): + pass + +class b: + # CHECK-L: ${LINE:+1}: fatal: class body must contain only assignments and function definitions + x += 1 diff --git a/lit-test/test/inferencer/error_class_redefine.py b/lit-test/test/inferencer/error_class_redefine.py new file mode 100644 index 000000000..d5556fd98 --- /dev/null +++ b/lit-test/test/inferencer/error_class_redefine.py @@ -0,0 +1,8 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +class c: + pass +# CHECK-L: ${LINE:+1}: fatal: variable 'c' is already defined +class c: + pass diff --git a/lit-test/test/inferencer/error_coerce.py b/lit-test/test/inferencer/error_coerce.py new file mode 100644 index 000000000..bf4a5dd36 --- /dev/null +++ b/lit-test/test/inferencer/error_coerce.py @@ -0,0 +1,37 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: expected '<<' operand to be of integer type, not float +1 << 2.0 + +# CHECK-L: ${LINE:+3}: error: expected every '+' operand to be a list in this context +# CHECK-L: ${LINE:+2}: note: list of type list(elt=int(width='a)) +# CHECK-L: ${LINE:+1}: note: int(width='b), which cannot be added to a list +[1] + 2 + +# CHECK-L: ${LINE:+1}: error: cannot unify list(elt=int(width='a)) with list(elt=float): int(width='a) is incompatible with float +[1] + [2.0] + +# CHECK-L: ${LINE:+3}: error: expected every '+' operand to be a tuple in this context +# CHECK-L: ${LINE:+2}: note: tuple of type (int(width='a),) +# CHECK-L: ${LINE:+1}: note: int(width='b), which cannot be added to a tuple +(1,) + 2 + +# CHECK-L: ${LINE:+1}: error: passing tuples to '*' is not supported +(1,) * 2 + +# CHECK-L: ${LINE:+3}: error: expected '*' operands to be a list and an integer in this context +# CHECK-L: ${LINE:+2}: note: list operand of type list(elt=int(width='a)) +# CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount +[1] * [] + +# CHECK-L: ${LINE:+1}: error: cannot coerce list(elt='a) to a numeric type +[] - 1.0 + +# CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid +# CHECK-L: ${LINE:+1}: note: expression of type float +a = 1; a += 1.0 + +# CHECK-L: ${LINE:+2}: error: the result of this operation has type (int(width='a), float), which makes assignment to a slot of type (int(width='a),) invalid +# CHECK-L: ${LINE:+1}: note: expression of type (float,) +b = (1,); b += (1.0,) diff --git a/lit-test/test/inferencer/error_comprehension.py b/lit-test/test/inferencer/error_comprehension.py new file mode 100644 index 000000000..d586dd657 --- /dev/null +++ b/lit-test/test/inferencer/error_comprehension.py @@ -0,0 +1,8 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: if clauses in comprehensions are not supported +[x for x in [] if x] + +# CHECK-L: ${LINE:+1}: error: multiple for clauses in comprehensions are not supported +[(x, y) for x in [] for y in []] diff --git a/lit-test/test/inferencer/error_control_flow.py b/lit-test/test/inferencer/error_control_flow.py new file mode 100644 index 000000000..65e300511 --- /dev/null +++ b/lit-test/test/inferencer/error_control_flow.py @@ -0,0 +1,19 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: return statement outside of a 
function +return + +# CHECK-L: ${LINE:+1}: error: break statement outside of a loop +break + +# CHECK-L: ${LINE:+1}: error: continue statement outside of a loop +continue + +while True: + def f(): + # CHECK-L: ${LINE:+1}: error: break statement outside of a loop + break + + # CHECK-L: ${LINE:+1}: error: continue statement outside of a loop + continue diff --git a/lit-test/test/inferencer/error_exception.py b/lit-test/test/inferencer/error_exception.py new file mode 100644 index 000000000..4e4b340b5 --- /dev/null +++ b/lit-test/test/inferencer/error_exception.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +try: + pass +# CHECK-L: ${LINE:+1}: error: this expression must refer to an exception constructor +except 1: + pass + +try: + pass +# CHECK-L: ${LINE:+1}: error: cannot unify int(width='a) with Exception +except Exception as e: + e = 1 diff --git a/lit-test/test/inferencer/error_iterable.py b/lit-test/test/inferencer/error_iterable.py new file mode 100644 index 000000000..68ed1b002 --- /dev/null +++ b/lit-test/test/inferencer/error_iterable.py @@ -0,0 +1,5 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: type int(width='a) is not iterable +for x in 1: pass diff --git a/lit-test/test/inferencer/error_local_unbound.py b/lit-test/test/inferencer/error_local_unbound.py new file mode 100644 index 000000000..7327a8548 --- /dev/null +++ b/lit-test/test/inferencer/error_local_unbound.py @@ -0,0 +1,5 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: fatal: undefined variable 'x' +x diff --git a/lit-test/test/inferencer/error_locals.py b/lit-test/test/inferencer/error_locals.py new file mode 100644 index 000000000..3a5185ca8 --- /dev/null +++ b/lit-test/test/inferencer/error_locals.py @@ -0,0 +1,35 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +x = 1 +def a(): + # CHECK-L: ${LINE:+1}: error: cannot declare name 'x' as nonlocal: it is not bound in any outer scope + nonlocal x + +def f(): + y = 1 + def b(): + nonlocal y + # CHECK-L: ${LINE:+1}: error: name 'y' cannot be nonlocal and global simultaneously + global y + + def c(): + global y + # CHECK-L: ${LINE:+1}: error: name 'y' cannot be global and nonlocal simultaneously + nonlocal y + + def d(y): + # CHECK-L: ${LINE:+1}: error: name 'y' cannot be a parameter and global simultaneously + global y + + def e(y): + # CHECK-L: ${LINE:+1}: error: name 'y' cannot be a parameter and nonlocal simultaneously + nonlocal y + +# CHECK-L: ${LINE:+1}: error: duplicate parameter 'x' +def f(x, x): + pass + +# CHECK-L: ${LINE:+1}: error: variadic arguments are not supported +def g(*x): + pass diff --git a/lit-test/test/inferencer/error_method.py b/lit-test/test/inferencer/error_method.py new file mode 100644 index 000000000..cb4f77075 --- /dev/null +++ b/lit-test/test/inferencer/error_method.py @@ -0,0 +1,16 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +class c: + def f(): + pass + + def g(self): + pass + +# CHECK-L: ${LINE:+1}: error: function 'f()->NoneType delay('a)' of class 'c' cannot accept a self argument +c().f() + +c.g(1) +# CHECK-L: ${LINE:+1}: error: cannot unify <instance c> with int(width='a) while inferring the type for self argument +c().g()
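The +diag tests above rely on two conventions visible in their RUN lines: the testbench prints one diagnostic per line, and OutputCheck matches every # CHECK-L directive literally after expanding ${LINE:+N} (or ${LINE:-N}) relative to the directive's own line, which is how each error message is pinned to the statement next to the comment. The sketch below only illustrates that expansion step, assuming nothing beyond the standard library; the actual matching is performed by the OutputCheck tool named in the RUN lines, not by this code.

    import re

    def expected_literals(test_source):
        """Yield the literal strings the # CHECK-L directives expect, with
        ${LINE:+N} / ${LINE:-N} expanded relative to each directive's line."""
        for lineno, line in enumerate(test_source.splitlines(), start=1):
            directive = re.search(r"# CHECK-L: (.*)", line)
            if directive:
                yield re.sub(r"\$\{LINE:([+-]\d+)\}",
                             lambda ref: str(lineno + int(ref.group(1))),
                             directive.group(1))

    def output_matches(test_source, output):
        # every expected literal must occur somewhere in the testbench output
        lines = output.splitlines()
        return all(any(expected in line for line in lines)
                   for expected in expected_literals(test_source))

diff --git 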
a/lit-test/test/inferencer/error_return.py b/lit-test/test/inferencer/error_return.py new file mode 100644 index 000000000..ba4ddc2b3 --- /dev/null +++ b/lit-test/test/inferencer/error_return.py @@ -0,0 +1,16 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+2}: error: cannot unify int(width='a) with NoneType +# CHECK-L: ${LINE:+1}: note: function with return type int(width='a) +def a(): + return 1 + # CHECK-L: ${LINE:+1}: note: a statement returning NoneType + return + +# CHECK-L: ${LINE:+2}: error: cannot unify int(width='a) with list(elt='b) +# CHECK-L: ${LINE:+1}: note: function with return type int(width='a) +def b(): + return 1 + # CHECK-L: ${LINE:+1}: note: a statement returning list(elt='b) + return [] diff --git a/lit-test/test/inferencer/error_subscript.py b/lit-test/test/inferencer/error_subscript.py new file mode 100644 index 000000000..0aadb3289 --- /dev/null +++ b/lit-test/test/inferencer/error_subscript.py @@ -0,0 +1,10 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +x = [] + +# CHECK-L: ${LINE:+1}: error: multi-dimensional slices are not supported +x[1,2] + +# CHECK-L: ${LINE:+1}: error: multi-dimensional slices are not supported +x[1:2,3:4] diff --git a/lit-test/test/inferencer/error_unify.py b/lit-test/test/inferencer/error_unify.py new file mode 100644 index 000000000..e81537d29 --- /dev/null +++ b/lit-test/test/inferencer/error_unify.py @@ -0,0 +1,27 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +a = 1 +b = [] + +# CHECK-L: ${LINE:+1}: error: cannot unify int(width='a) with list(elt='b) +a = b + +# CHECK-L: ${LINE:+1}: error: cannot unify int(width='a) with list(elt='b) +[1, []] +# CHECK-L: note: a list element of type int(width='a) +# CHECK-L: note: a list element of type list(elt='b) + +# CHECK-L: ${LINE:+1}: error: cannot unify int(width='a) with bool +1 and False +# CHECK-L: note: an operand of type int(width='a) +# CHECK-L: note: an operand of type bool + +# CHECK-L: ${LINE:+1}: error: expected unary '+' operand to be of numeric type, not list(elt='a) ++[] + +# CHECK-L: ${LINE:+1}: error: expected '~' operand to be of integer type, not float +~1.0 + +# CHECK-L: ${LINE:+1}: error: type int(width='a) does not have an attribute 'x' +(1).x diff --git a/lit-test/test/inferencer/exception.py b/lit-test/test/inferencer/exception.py new file mode 100644 index 000000000..e0e0f9645 --- /dev/null +++ b/lit-test/test/inferencer/exception.py @@ -0,0 +1,13 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: Exception: +Exception + +try: + pass +except Exception: + pass +except Exception as e: + # CHECK-L: e:Exception + e diff --git a/lit-test/test/inferencer/gcd.py b/lit-test/test/inferencer/gcd.py new file mode 100644 index 000000000..e2d4b4779 --- /dev/null +++ b/lit-test/test/inferencer/gcd.py @@ -0,0 +1,13 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t + +def _gcd(a, b): + if a < 0: + a = -a + while a: + c = a + a = b % a + b = c + return b + +# CHECK-L: _gcd:(a:int(width='a), b:int(width='a))->int(width='a)(10:int(width='a), 25:int(width='a)):int(width='a) +_gcd(10, 25) diff --git a/lit-test/test/inferencer/prelude.py b/lit-test/test/inferencer/prelude.py new file mode 100644 index 000000000..643895e4a --- /dev/null +++ b/lit-test/test/inferencer/prelude.py @@ -0,0 +1,10 
@@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: x: +x = len + +def f(): + global len + # CHECK-L: len:int(width='a) = + len = 1 diff --git a/lit-test/test/inferencer/scoping.py b/lit-test/test/inferencer/scoping.py new file mode 100644 index 000000000..604b2425d --- /dev/null +++ b/lit-test/test/inferencer/scoping.py @@ -0,0 +1,9 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: []:list(elt=int(width='a)) +x = [] + +def f(): + global x + x[0] = 1 diff --git a/lit-test/test/inferencer/unify.py b/lit-test/test/inferencer/unify.py new file mode 100644 index 000000000..59d23e937 --- /dev/null +++ b/lit-test/test/inferencer/unify.py @@ -0,0 +1,73 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +a = 1 +# CHECK-L: a:int(width='a) + +b = a +# CHECK-L: b:int(width='a) + +c = True +# CHECK-L: c:bool + +d = False +# CHECK-L: d:bool + +e = None +# CHECK-L: e:NoneType + +f = 1.0 +# CHECK-L: f:float + +g = [] +# CHECK-L: g:list(elt='b) + +h = [1] +# CHECK-L: h:list(elt=int(width='c)) + +i = [] +i[0] = 1 +# CHECK-L: i:list(elt=int(width='d)) + +j = [] +j += [1.0] +# CHECK-L: j:list(elt=float) + +1 if a else 2 +# CHECK-L: 1:int(width='f) if a:int(width='a) else 2:int(width='f):int(width='f) + +True and False +# CHECK-L: True:bool and False:bool:bool + +1 and 0 +# CHECK-L: 1:int(width='g) and 0:int(width='g):int(width='g) + +~1 +# CHECK-L: 1:int(width='h):int(width='h) + +not 1 +# CHECK-L: 1:int(width='i):bool + +[x for x in [1]] +# CHECK-L: [x:int(width='j) for x:int(width='j) in [1:int(width='j)]:list(elt=int(width='j))]:list(elt=int(width='j)) + +lambda x, y=1: x +# CHECK-L: lambda x:'k, y:int(width='l)=1:int(width='l): x:'k:(x:'k, ?y:int(width='l))->'k + +k = "x" +# CHECK-L: k:str + +IndexError() +# CHECK-L: IndexError:():IndexError + +IndexError("x") +# CHECK-L: IndexError:("x":str):IndexError + +IndexError("x", 1) +# CHECK-L: IndexError:("x":str, 1:int(width=64)):IndexError + +IndexError("x", 1, 1) +# CHECK-L: IndexError:("x":str, 1:int(width=64), 1:int(width=64)):IndexError + +IndexError("x", 1, 1, 1) +# CHECK-L: IndexError:("x":str, 1:int(width=64), 1:int(width=64), 1:int(width=64)):IndexError diff --git a/lit-test/test/inferencer/with.py b/lit-test/test/inferencer/with.py new file mode 100644 index 000000000..b11f3da13 --- /dev/null +++ b/lit-test/test/inferencer/with.py @@ -0,0 +1,5 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: as x: +with parallel as x: pass diff --git a/lit-test/test/integration/arithmetics.py b/lit-test/test/integration/arithmetics.py new file mode 100644 index 000000000..1963121e6 --- /dev/null +++ b/lit-test/test/integration/arithmetics.py @@ -0,0 +1,40 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s +# REQUIRES: exceptions + +assert -(-1) == 1 +assert -(-1.0) == 1.0 +assert +1 == 1 +assert +1.0 == 1.0 +assert 1 + 1 == 2 +assert 1.0 + 1.0 == 2.0 +assert 1 - 1 == 0 +assert 1.0 - 1.0 == 0.0 +assert 2 * 2 == 4 +assert 2.0 * 2.0 == 4.0 +assert 3 / 2 == 1.5 +assert 3.0 / 2.0 == 1.5 +assert 3 // 2 == 1 +assert 3.0 // 2.0 == 1.0 +assert 3 % 2 == 1 +assert -3 % 2 == 1 +assert 3 % -2 == -1 +assert -3 % -2 == -1 +assert 3.0 % 2.0 == 1.0 +assert -3.0 % 2.0 == 1.0 +assert 3.0 % -2.0 == -1.0 +assert -3.0 % -2.0 == -1.0 +assert 3 ** 2 == 9 +assert 3.0 ** 2.0 == 9.0 +assert 9.0 ** 0.5 == 
3.0 +assert 1 << 1 == 2 +assert 2 >> 1 == 1 +assert -2 >> 1 == -1 +#ARTIQ#assert 1 << 32 == 0 +assert -1 >> 32 == -1 +assert 0x18 & 0x0f == 0x08 +assert 0x18 | 0x0f == 0x1f +assert 0x18 ^ 0x0f == 0x17 + +assert [1] + [2] == [1, 2] +assert [1] * 3 == [1, 1, 1] diff --git a/lit-test/test/integration/attribute.py b/lit-test/test/integration/attribute.py new file mode 100644 index 000000000..301243b54 --- /dev/null +++ b/lit-test/test/integration/attribute.py @@ -0,0 +1,7 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +r = range(10) +assert r.start == 0 +assert r.stop == 10 +assert r.step == 1 diff --git a/lit-test/test/integration/bool.py b/lit-test/test/integration/bool.py new file mode 100644 index 000000000..1a68ebd1c --- /dev/null +++ b/lit-test/test/integration/bool.py @@ -0,0 +1,21 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +assert (not 0) == True +assert (not 1) == False + +assert (0 and 0) is 0 +assert (1 and 0) is 0 +assert (0 and 1) is 0 +assert (1 and 2) is 2 + +assert (0 or 0) is 0 +assert (1 or 0) is 1 +assert (0 or 1) is 1 +assert (1 or 2) is 1 + +assert bool(False) is False and bool(False) is False +assert bool(0) is False and bool(1) is True +assert bool(0.0) is False and bool(1.0) is True +x = []; assert bool(x) is False; x = [1]; assert bool(x) is True +assert bool(range(0)) is False and bool(range(1)) is True diff --git a/lit-test/test/integration/builtin.py b/lit-test/test/integration/builtin.py new file mode 100644 index 000000000..dbb7e471c --- /dev/null +++ b/lit-test/test/integration/builtin.py @@ -0,0 +1,26 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s +# REQUIRES: exceptions + +assert bool() is False +# bool(x) is tested in bool.py + +assert int() is 0 +assert int(1.0) is 1 +#ARTIQ#assert int(1, width=64) << 40 is 1099511627776 + +#ARTIQ#assert float() is 0.0 +#ARTIQ#assert float(1) is 1.0 + +x = list() +if False: x = [1] +assert x == [] + +#ARTIQ#assert range(10) is range(0, 10, 1) +#ARTIQ#assert range(1, 10) is range(1, 10, 1) + +assert len([1, 2, 3]) is 3 +assert len(range(10)) is 10 +assert len(range(0, 10, 2)) is 5 + +#ARTIQ#assert round(1.4) is 1 and round(1.6) is 2 diff --git a/lit-test/test/integration/class.py b/lit-test/test/integration/class.py new file mode 100644 index 000000000..205210ea0 --- /dev/null +++ b/lit-test/test/integration/class.py @@ -0,0 +1,16 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +class c: + a = 1 + def f(): + return 2 + def g(self): + return self.a + 5 + def h(self, x): + return self.a + x + +assert c.a == 1 +assert c.f() == 2 +assert c().g() == 6 +assert c().h(9) == 10 diff --git a/lit-test/test/integration/compare.py b/lit-test/test/integration/compare.py new file mode 100644 index 000000000..48a33cc09 --- /dev/null +++ b/lit-test/test/integration/compare.py @@ -0,0 +1,19 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +assert 1 < 2 and not (2 < 1) +assert 2 > 1 and not (1 > 2) +assert 1 == 1 and not (1 == 2) +assert 1 != 2 and not (1 != 1) +assert 1 <= 1 and 1 <= 2 and not (2 <= 1) +assert 1 >= 1 and 2 >= 1 and not (1 >= 2) +assert 1 is 1 and not (1 is 2) +assert 1 is not 2 and not (1 is not 1) + +x, y = [1], [1] +assert x is x and x is not y +#ARTIQ#assert range(10) is range(10) and range(10) is not range(11) + +lst = [1, 2, 3] +assert 1 in lst and 0 not in lst +assert 1 in range(10) and 11 not in range(10) and -1 not in range(10) diff --git a/lit-test/test/integration/finally.py 
b/lit-test/test/integration/finally.py new file mode 100644 index 000000000..e629602ad --- /dev/null +++ b/lit-test/test/integration/finally.py @@ -0,0 +1,78 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t.1 +# RUN: OutputCheck %s --file-to-check=%t.1 +# RUN: %python %s >%t.2 +# RUN: OutputCheck %s --file-to-check=%t.2 +# REQUIRES: exceptions + +def f(): + while True: + try: + print("f-try") + break + finally: + print("f-finally") + print("f-out") + +def g(): + x = True + while x: + try: + print("g-try") + x = False + continue + finally: + print("g-finally") + print("g-out") + +def h(): + try: + print("h-try") + return 10 + finally: + print("h-finally") + print("h-out") + return 20 + +def i(): + try: + print("i-try") + return 10 + finally: + print("i-finally") + return 30 + print("i-out") + return 20 + +def j(): + try: + print("j-try") + finally: + print("j-finally") + print("j-out") + +# CHECK-L: f-try +# CHECK-L: f-finally +# CHECK-L: f-out +f() + +# CHECK-L: g-try +# CHECK-L: g-finally +# CHECK-L: g-out +g() + +# CHECK-L: h-try +# CHECK-L: h-finally +# CHECK-NOT-L: h-out +# CHECK-L: h 10 +print("h", h()) + +# CHECK-L: i-try +# CHECK-L: i-finally +# CHECK-NOT-L: i-out +# CHECK-L: i 30 +print("i", i()) + +# CHECK-L: j-try +# CHECK-L: j-finally +# CHECK-L: j-out +print("j", j()) diff --git a/lit-test/test/integration/for.py b/lit-test/test/integration/for.py new file mode 100644 index 000000000..9c305f5da --- /dev/null +++ b/lit-test/test/integration/for.py @@ -0,0 +1,29 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +count = 0 +for x in range(10): + count += 1 +assert count == 10 + +for x in range(10): + assert True +else: + assert True + +for x in range(0): + assert False +else: + assert True + +for x in range(10): + continue + assert False +else: + assert True + +for x in range(10): + break + assert False +else: + assert False diff --git a/lit-test/test/integration/function.py b/lit-test/test/integration/function.py new file mode 100644 index 000000000..bbaca2083 --- /dev/null +++ b/lit-test/test/integration/function.py @@ -0,0 +1,11 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +def fib(x): + if x == 1: + return x + else: + return x * fib(x - 1) +assert fib(5) == 120 + +# argument combinations handled in lambda.py diff --git a/lit-test/test/integration/if.py b/lit-test/test/integration/if.py new file mode 100644 index 000000000..fab6c3df0 --- /dev/null +++ b/lit-test/test/integration/if.py @@ -0,0 +1,21 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +if True: + assert True + +if False: + assert False + +if True: + assert True +else: + assert False + +if False: + assert False +else: + assert True + +assert (0 if True else 1) == 0 +assert (0 if False else 1) == 1 diff --git a/lit-test/test/integration/instance.py b/lit-test/test/integration/instance.py new file mode 100644 index 000000000..bf255d88f --- /dev/null +++ b/lit-test/test/integration/instance.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +class c: + a = 1 + +i = c() + +assert i.a == 1 + +def f(): + c = None + assert i.a == 1 +f() diff --git a/lit-test/test/integration/lambda.py b/lit-test/test/integration/lambda.py new file mode 100644 index 000000000..a1f08763a --- /dev/null +++ b/lit-test/test/integration/lambda.py @@ -0,0 +1,10 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +assert (lambda: 1)() == 1 +assert (lambda x: x)(1) == 1 +assert (lambda x, y: x + 
y)(1, 2) == 3 +assert (lambda x, y=1: x + y)(1) == 2 +assert (lambda x, y=1: x + y)(1, 2) == 3 +assert (lambda x, y=1: x + y)(x=3) == 4 +assert (lambda x, y=1: x + y)(y=2, x=3) == 5 diff --git a/lit-test/test/integration/list.py b/lit-test/test/integration/list.py new file mode 100644 index 000000000..06f08e426 --- /dev/null +++ b/lit-test/test/integration/list.py @@ -0,0 +1,9 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s +# REQUIRES: exceptions + +[x, y] = [1, 2] +assert (x, y) == (1, 2) + +lst = [1, 2, 3] +assert [x*x for x in lst] == [1, 4, 9] diff --git a/lit-test/test/integration/locals.py b/lit-test/test/integration/locals.py new file mode 100644 index 000000000..6ad9b0763 --- /dev/null +++ b/lit-test/test/integration/locals.py @@ -0,0 +1,7 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +x = 1 +assert x == 1 +x += 1 +assert x == 2 diff --git a/lit-test/test/integration/print.py b/lit-test/test/integration/print.py new file mode 100644 index 000000000..06653d43b --- /dev/null +++ b/lit-test/test/integration/print.py @@ -0,0 +1,32 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: None +print(None) + +# CHECK-L: True False +print(True, False) + +# CHECK-L: 1 -1 +print(1, -1) + +# CHECK-L: 10000000000 +print(10000000000) + +# CHECK-L: 1.5 +print(1.5) + +# CHECK-L: (True, 1) +print((True, 1)) + +# CHECK-L: (True,) +print((True,)) + +# CHECK-L: [1, 2, 3] +print([1, 2, 3]) + +# CHECK-L: [[1, 2], [3]] +print([[1, 2], [3]]) + +# CHECK-L: range(0, 10, 1) +print(range(10)) diff --git a/lit-test/test/integration/subscript.py b/lit-test/test/integration/subscript.py new file mode 100644 index 000000000..eaa7f8455 --- /dev/null +++ b/lit-test/test/integration/subscript.py @@ -0,0 +1,19 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s +# REQUIRES: exceptions + +lst = list(range(10)) +assert lst[0] == 0 +assert lst[1] == 1 +assert lst[-1] == 9 +assert lst[0:1] == [0] +assert lst[0:2] == [0, 1] +assert lst[0:10] == lst +assert lst[1:-1] == lst[1:9] +assert lst[0:1:2] == [0] +assert lst[0:2:2] == [0] +assert lst[0:3:2] == [0, 2] + +lst = [0, 0, 0, 0, 0] +lst[0:5:2] = [1, 2, 3] +assert lst == [1, 0, 2, 0, 3] diff --git a/lit-test/test/integration/tuple.py b/lit-test/test/integration/tuple.py new file mode 100644 index 000000000..5d6c153dd --- /dev/null +++ b/lit-test/test/integration/tuple.py @@ -0,0 +1,7 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +x, y = 2, 1 +x, y = y, x +assert x == 1 and y == 2 +assert (1, 2) + (3.0,) == (1, 2, 3.0) diff --git a/lit-test/test/integration/while.py b/lit-test/test/integration/while.py new file mode 100644 index 000000000..7cd2ac626 --- /dev/null +++ b/lit-test/test/integration/while.py @@ -0,0 +1,31 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s +# RUN: %python %s + +cond, count = True, 0 +while cond: + count += 1 + cond = False +assert count == 1 + +while False: + pass +else: + assert True + +cond = True +while cond: + cond = False +else: + assert True + +while True: + break + assert False +else: + assert False + +cond = True +while cond: + cond = False + continue + assert False diff --git a/lit-test/test/interleaving/error_inlining.py b/lit-test/test/interleaving/error_inlining.py new file mode 100644 index 000000000..c34959a51 --- /dev/null +++ b/lit-test/test/interleaving/error_inlining.py @@ -0,0 +1,16 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag %s >%t +# RUN: 
OutputCheck %s --file-to-check=%t + +def f(): + delay_mu(2) + +def g(): + delay_mu(2) + +x = f if True else g + +def h(): + with parallel: + f() + # CHECK-L: ${LINE:+1}: fatal: it is not possible to interleave this function call within a 'with parallel:' statement because the compiler could not prove that the same function would always be called + x() diff --git a/lit-test/test/interleaving/indirect.py b/lit-test/test/interleaving/indirect.py new file mode 100644 index 000000000..19f59c69f --- /dev/null +++ b/lit-test/test/interleaving/indirect.py @@ -0,0 +1,28 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + delay_mu(2) + +def g(): + with parallel: + with sequential: + print("A", now_mu()) + f() + # + print("B", now_mu()) + with sequential: + print("C", now_mu()) + f() + # + print("D", now_mu()) + f() + # + print("E", now_mu()) + +# CHECK-L: A 0 +# CHECK-L: C 0 +# CHECK-L: B 2 +# CHECK-L: D 2 +# CHECK-L: E 4 +g() diff --git a/lit-test/test/interleaving/indirect_arg.py b/lit-test/test/interleaving/indirect_arg.py new file mode 100644 index 000000000..e9af00034 --- /dev/null +++ b/lit-test/test/interleaving/indirect_arg.py @@ -0,0 +1,28 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(n): + delay_mu(n) + +def g(): + with parallel: + with sequential: + print("A", now_mu()) + f(2) + # + print("B", now_mu()) + with sequential: + print("C", now_mu()) + f(2) + # + print("D", now_mu()) + f(2) + # + print("E", now_mu()) + +# CHECK-L: A 0 +# CHECK-L: C 0 +# CHECK-L: B 2 +# CHECK-L: D 2 +# CHECK-L: E 4 +g() diff --git a/lit-test/test/interleaving/nonoverlapping.py b/lit-test/test/interleaving/nonoverlapping.py new file mode 100644 index 000000000..60b3f739e --- /dev/null +++ b/lit-test/test/interleaving/nonoverlapping.py @@ -0,0 +1,25 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def g(): + with parallel: + with sequential: + print("A", now_mu()) + delay_mu(2) + # + print("B", now_mu()) + with sequential: + print("C", now_mu()) + delay_mu(2) + # + print("D", now_mu()) + delay_mu(2) + # + print("E", now_mu()) + +# CHECK-L: A 0 +# CHECK-L: C 0 +# CHECK-L: B 2 +# CHECK-L: D 2 +# CHECK-L: E 4 +g() diff --git a/lit-test/test/interleaving/overlapping.py b/lit-test/test/interleaving/overlapping.py new file mode 100644 index 000000000..630fdb638 --- /dev/null +++ b/lit-test/test/interleaving/overlapping.py @@ -0,0 +1,25 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def g(): + with parallel: + with sequential: + print("A", now_mu()) + delay_mu(3) + # + print("B", now_mu()) + with sequential: + print("C", now_mu()) + delay_mu(2) + # + print("D", now_mu()) + delay_mu(2) + # + print("E", now_mu()) + +# CHECK-L: A 0 +# CHECK-L: C 0 +# CHECK-L: D 2 +# CHECK-L: B 3 +# CHECK-L: E 4 +g() diff --git a/lit-test/test/interleaving/pure_impure_tie.py b/lit-test/test/interleaving/pure_impure_tie.py new file mode 100644 index 000000000..f2eab0502 --- /dev/null +++ b/lit-test/test/interleaving/pure_impure_tie.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + delay_mu(2) + +def g(): + with parallel: + f() + delay_mu(2) + print(now_mu()) + +# CHECK-L: 2 +g() diff --git a/lit-test/test/iodelay/argument.py b/lit-test/test/iodelay/argument.py new file mode 100644 index 000000000..f026c78f1 --- /dev/null +++ 
b/lit-test/test/iodelay/argument.py @@ -0,0 +1,7 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: f: (a:float, b:int(width=64))->NoneType delay(s->mu(a) + b mu) +def f(a, b): + delay(a) + delay_mu(b) diff --git a/lit-test/test/iodelay/arith.py b/lit-test/test/iodelay/arith.py new file mode 100644 index 000000000..0eba79715 --- /dev/null +++ b/lit-test/test/iodelay/arith.py @@ -0,0 +1,8 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: f: (a:int(width=32), b:int(width=32), c:int(width=32), d:int(width=32), e:int(width=32))->NoneType delay(s->mu(a * b // c + d - 10 / e) mu) +def f(a, b, c, d, e): + delay(a * b // c + d - 10 / e) + +f(1,2,3,4,5) diff --git a/lit-test/test/iodelay/call.py b/lit-test/test/iodelay/call.py new file mode 100644 index 000000000..e08446da1 --- /dev/null +++ b/lit-test/test/iodelay/call.py @@ -0,0 +1,10 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + delay_mu(1) + +# CHECK-L: g: ()->NoneType delay(2 mu) +def g(): + f() + f() diff --git a/lit-test/test/iodelay/call_subst.py b/lit-test/test/iodelay/call_subst.py new file mode 100644 index 000000000..fddc007c6 --- /dev/null +++ b/lit-test/test/iodelay/call_subst.py @@ -0,0 +1,13 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def pulse(len): + # "on" + delay_mu(len) + # "off" + delay_mu(len) + +# CHECK-L: f: ()->NoneType delay(600 mu) +def f(): + pulse(100) + pulse(200) diff --git a/lit-test/test/iodelay/class.py b/lit-test/test/iodelay/class.py new file mode 100644 index 000000000..6bcb9339b --- /dev/null +++ b/lit-test/test/iodelay/class.py @@ -0,0 +1,12 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: g: (i:<instance c>)->NoneType delay(1000000 mu) +def g(i): + i.f(1.0) + +class c: + def f(self, x): + delay(x) + +g(c()) diff --git a/lit-test/test/iodelay/error_arith.py b/lit-test/test/iodelay/error_arith.py new file mode 100644 index 000000000..54e30aa24 --- /dev/null +++ b/lit-test/test/iodelay/error_arith.py @@ -0,0 +1,21 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(a): + b = 1.0 + # CHECK-L: ${LINE:+3}: error: this call cannot be interleaved + # CHECK-L: ${LINE:+2}: note: this variable is not an argument of the innermost function + # CHECK-L: ${LINE:-4}: note: only these arguments are in scope of analysis + delay(b) + +def g(): + # CHECK-L: ${LINE:+2}: error: this call cannot be interleaved + # CHECK-L: ${LINE:+1}: note: this operator is not supported + delay(2.0**2) + +def h(): + # CHECK-L: ${LINE:+2}: error: this call cannot be interleaved + # CHECK-L: ${LINE:+1}: note: this expression is not supported + delay_mu(1 if False else 2) + +f(1) diff --git a/lit-test/test/iodelay/error_bad_parallel.py b/lit-test/test/iodelay/error_bad_parallel.py new file mode 100644 index 000000000..8a9a462b2 --- /dev/null +++ b/lit-test/test/iodelay/error_bad_parallel.py @@ -0,0 +1,8 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + with parallel: + # CHECK-L: ${LINE:+1}: error: while statement cannot be interleaved + while True: + delay_mu(1)
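The signature testbench above prints a delay expression for every function, and the iodelay tests check how those expressions compose: statements in sequence add their delays, a call substitutes the callee's (possibly argument-dependent) delay as in call.py and call_subst.py, a for loop over a literal range multiplies the body delay by the trip count (loop.py and range.py further down), and with parallel takes the maximum of its branches (parallel.py). Below is a small numeric model of that composition, using plain integers where the real analysis in the artiq.compiler package works on symbolic expressions; it is an illustrative sketch only.

    def seq(*delays_mu):
        # sequential composition: delays accumulate
        return sum(delays_mu)

    def par(*delays_mu):
        # parallel composition: the block lasts as long as its longest branch
        return max(delays_mu)

    def loop(trip_count, body_mu):
        # a for loop over range(trip_count) repeats its body's delay
        return trip_count * body_mu

    def pulse(length_mu):
        # mirrors call_subst.py: two delay_mu(len) calls back to back
        return seq(length_mu, length_mu)

    assert seq(1, 1) == 2                      # call.py: g is delay(2 mu)
    assert seq(pulse(100), pulse(200)) == 600  # call_subst.py: f is delay(600 mu)
    assert loop(10, 3) == 30                   # loop.py: f is delay(30 mu)
    assert par(100, 200) == 200                # parallel.py: max of the branches

diff --git a/lit-test/test/iodelay/error_builtinfn.py b/lit-test/test/iodelay/error_builtinfn.py new file 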
mode 100644 index 000000000..ad44c972c --- /dev/null +++ b/lit-test/test/iodelay/error_builtinfn.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + x = 1 + # CHECK-L: ${LINE:+1}: error: this call cannot be interleaved because an argument cannot be statically evaluated + delay_mu(x) + +def g(): + x = 1.0 + # CHECK-L: ${LINE:+1}: error: this call cannot be interleaved + delay(x) + + diff --git a/lit-test/test/iodelay/error_call_nested.py b/lit-test/test/iodelay/error_call_nested.py new file mode 100644 index 000000000..b283c0917 --- /dev/null +++ b/lit-test/test/iodelay/error_call_nested.py @@ -0,0 +1,11 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + # CHECK-L: ${LINE:+1}: error: this call cannot be interleaved + delay(1.0**2) + +def g(): + # CHECK-L: ${LINE:+1}: note: function called here + f() + f() diff --git a/lit-test/test/iodelay/error_call_subst.py b/lit-test/test/iodelay/error_call_subst.py new file mode 100644 index 000000000..62c6bb29a --- /dev/null +++ b/lit-test/test/iodelay/error_call_subst.py @@ -0,0 +1,13 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def pulse(len): + # "on" + delay_mu(len) + # "off" + delay_mu(len) + +def f(): + a = 100 + # CHECK-L: ${LINE:+1}: error: this call cannot be interleaved + pulse(a) diff --git a/lit-test/test/iodelay/error_control_flow.py b/lit-test/test/iodelay/error_control_flow.py new file mode 100644 index 000000000..c179c95b9 --- /dev/null +++ b/lit-test/test/iodelay/error_control_flow.py @@ -0,0 +1,23 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + # CHECK-L: ${LINE:+1}: error: while statement cannot be interleaved + while True: + delay_mu(1) + +def g(): + # CHECK-L: ${LINE:+1}: error: if statement cannot be interleaved + if True: + delay_mu(1) + +def h(): + # CHECK-L: ${LINE:+1}: error: if expression cannot be interleaved + delay_mu(1) if True else delay_mu(2) + +def i(): + # CHECK-L: ${LINE:+1}: error: try statement cannot be interleaved + try: + delay_mu(1) + finally: + pass diff --git a/lit-test/test/iodelay/error_for.py b/lit-test/test/iodelay/error_for.py new file mode 100644 index 000000000..5c5b10ffc --- /dev/null +++ b/lit-test/test/iodelay/error_for.py @@ -0,0 +1,9 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + r = range(10) + # CHECK-L: ${LINE:+2}: error: for statement cannot be interleaved because trip count is indeterminate + # CHECK-L: ${LINE:+1}: note: this value is not a constant range literal + for _ in r: + delay_mu(1) diff --git a/lit-test/test/iodelay/error_goto.py b/lit-test/test/iodelay/error_goto.py new file mode 100644 index 000000000..02686cd86 --- /dev/null +++ b/lit-test/test/iodelay/error_goto.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + for _ in range(10): + delay_mu(10) + # CHECK-L: ${LINE:+1}: error: loop trip count is indeterminate because of control flow + break + +def g(): + for _ in range(10): + delay_mu(10) + # CHECK-L: ${LINE:+1}: error: loop trip count is indeterminate because of control flow + continue diff --git a/lit-test/test/iodelay/error_return.py 
b/lit-test/test/iodelay/error_return.py new file mode 100644 index 000000000..c047de2b0 --- /dev/null +++ b/lit-test/test/iodelay/error_return.py @@ -0,0 +1,8 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + if True: + # CHECK-L: ${LINE:+1}: error: only return statement at the end of the function can be interleaved + return 1 + delay_mu(1) diff --git a/lit-test/test/iodelay/error_unify.py b/lit-test/test/iodelay/error_unify.py new file mode 100644 index 000000000..723b5e393 --- /dev/null +++ b/lit-test/test/iodelay/error_unify.py @@ -0,0 +1,11 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag +delay %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +def f(): + delay_mu(10) + +# CHECK-L: ${LINE:+1}: fatal: delay delay(20 mu) was inferred for this function, but its delay is already constrained externally to delay(10 mu) +def g(): + delay_mu(20) + +x = f if True else g diff --git a/lit-test/test/iodelay/goto.py b/lit-test/test/iodelay/goto.py new file mode 100644 index 000000000..d80de43f9 --- /dev/null +++ b/lit-test/test/iodelay/goto.py @@ -0,0 +1,14 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: f: ()->NoneType delay(10 mu) +def f(): + delay_mu(10) + for _ in range(10): + break + +# CHECK-L: g: ()->NoneType delay(10 mu) +def g(): + delay_mu(10) + for _ in range(10): + continue diff --git a/lit-test/test/iodelay/linear.py b/lit-test/test/iodelay/linear.py new file mode 100644 index 000000000..cb2bed204 --- /dev/null +++ b/lit-test/test/iodelay/linear.py @@ -0,0 +1,12 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: f: ()->NoneType delay(1001000 mu) +def f(): + delay(1.0) + delay_mu(1000) + +# CHECK-L: g: ()->NoneType delay(3 mu) +def g(): + delay_mu(1) + delay_mu(2) diff --git a/lit-test/test/iodelay/loop.py b/lit-test/test/iodelay/loop.py new file mode 100644 index 000000000..37bcb9e07 --- /dev/null +++ b/lit-test/test/iodelay/loop.py @@ -0,0 +1,13 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: f: ()->NoneType delay(30 mu) +def f(): + for _ in range(10): + delay_mu(3) + +# CHECK-L: g: ()->NoneType delay(60 mu) +def g(): + for _ in range(10): + for _ in range(2): + delay_mu(3) diff --git a/lit-test/test/iodelay/order_invariance.py b/lit-test/test/iodelay/order_invariance.py new file mode 100644 index 000000000..e4deb96c5 --- /dev/null +++ b/lit-test/test/iodelay/order_invariance.py @@ -0,0 +1,10 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: g: ()->NoneType delay(2 mu) +def g(): + f() + f() + +def f(): + delay_mu(1) diff --git a/lit-test/test/iodelay/parallel.py b/lit-test/test/iodelay/parallel.py new file mode 100644 index 000000000..0976f1780 --- /dev/null +++ b/lit-test/test/iodelay/parallel.py @@ -0,0 +1,15 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: f: (a:int(width=64), b:int(width=64))->NoneType delay(max(a, b) mu) +def f(a, b): + with parallel: + delay_mu(a) + delay_mu(b) + +# CHECK-L: g: (a:int(width=64))->NoneType delay(max(a, 200) mu) +def g(a): + with parallel: + delay_mu(100) + delay_mu(200) + delay_mu(a) diff --git a/lit-test/test/iodelay/range.py b/lit-test/test/iodelay/range.py new file mode 100644 index 
000000000..9c4a73a79 --- /dev/null +++ b/lit-test/test/iodelay/range.py @@ -0,0 +1,21 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: f: (a:int(width=32))->NoneType delay(3 * a mu) +def f(a): + for _ in range(a): + delay_mu(3) + +# CHECK-L: g: (a:int(width=32), b:int(width=32))->NoneType delay(3 * (b - a) mu) +def g(a, b): + for _ in range(a, b): + delay_mu(3) + +# CHECK-L: h: (a:int(width=32), b:int(width=32), c:int(width=32))->NoneType delay(3 * (b - a) // c mu) +def h(a, b, c): + for _ in range(a, b, c): + delay_mu(3) + +f(1) +g(1,2) +h(1,2,3) diff --git a/lit-test/test/iodelay/return.py b/lit-test/test/iodelay/return.py new file mode 100644 index 000000000..d07267a44 --- /dev/null +++ b/lit-test/test/iodelay/return.py @@ -0,0 +1,17 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: f: ()->int(width=32) delay(30 mu) +def f(): + for _ in range(10): + delay_mu(3) + return 10 + +# CHECK-L: g: (x:float)->int(width=32) +# CHECK-NOT-L: delay +def g(x): + if x > 1.0: + return 1 + return 0 + +g(1.0) diff --git a/lit-test/test/iodelay/sequential.py b/lit-test/test/iodelay/sequential.py new file mode 100644 index 000000000..54a776b68 --- /dev/null +++ b/lit-test/test/iodelay/sequential.py @@ -0,0 +1,8 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: f: (a:int(width=64), b:int(width=64))->NoneType delay(a + b mu) +def f(a, b): + with sequential: + delay_mu(a) + delay_mu(b) diff --git a/lit-test/test/lit.cfg b/lit-test/test/lit.cfg new file mode 100644 index 000000000..08f024e55 --- /dev/null +++ b/lit-test/test/lit.cfg @@ -0,0 +1,27 @@ +import os, subprocess +import lit.util +import lit.formats + +root = os.path.join(os.path.dirname(__file__), '..') + +config.name = 'ARTIQ' +config.test_format = lit.formats.ShTest() +config.suffixes = ['.py'] + +python_executable = 'python3.5' + +harness = '{} {}'.format(python_executable, os.path.join(root, 'harness.py')) +config.substitutions.append( ('%python', harness) ) + +not_ = '{} {}'.format(python_executable, os.path.join(root, 'not.py')) +config.substitutions.append( ('%not', not_) ) + +if os.name == 'posix': + support_build = os.path.join(root, 'libartiq_support') + if subprocess.call(['make', '-sC', support_build]) != 0: + lit_config.fatal("Unable to build JIT support library") + + support_lib = os.path.join(support_build, 'libartiq_support.so') + config.environment['LIBARTIQ_SUPPORT'] = support_lib + + config.available_features.add('exceptions') diff --git a/lit-test/test/local_access/invalid_closure.py b/lit-test/test/local_access/invalid_closure.py new file mode 100644 index 000000000..75948fb57 --- /dev/null +++ b/lit-test/test/local_access/invalid_closure.py @@ -0,0 +1,15 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +if False: + t = 1 + +# CHECK-L: ${LINE:+1}: error: variable 't' can be captured in a closure uninitialized +l = lambda: t + +# CHECK-L: ${LINE:+1}: error: variable 't' can be captured in a closure uninitialized +def f(): + return t + +l() +f() diff --git a/lit-test/test/local_access/invalid_flow.py b/lit-test/test/local_access/invalid_flow.py new file mode 100644 index 000000000..7b99958b4 --- /dev/null +++ b/lit-test/test/local_access/invalid_flow.py @@ -0,0 +1,20 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag %s >%t +# RUN: OutputCheck %s 
--file-to-check=%t + +x = 1 +if x > 10: + y = 1 +# CHECK-L: ${LINE:+1}: error: variable 'y' is not always initialized +x + y + +for z in [1]: + pass +# CHECK-L: ${LINE:+1}: error: variable 'z' is not always initialized +-z + +if True: + pass +else: + t = 1 +# CHECK-L: ${LINE:+1}: error: variable 't' is not always initialized +-t diff --git a/lit-test/test/local_access/valid.py b/lit-test/test/local_access/valid.py new file mode 100644 index 000000000..3c5fd0208 --- /dev/null +++ b/lit-test/test/local_access/valid.py @@ -0,0 +1,7 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t + +if False: + x = 1 +else: + x = 2 +-x diff --git a/lit-test/test/monomorphism/error_notmono.py b/lit-test/test/monomorphism/error_notmono.py new file mode 100644 index 000000000..9c9b02452 --- /dev/null +++ b/lit-test/test/monomorphism/error_notmono.py @@ -0,0 +1,9 @@ +# RUN: %python -m artiq.compiler.testbench.signature +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +# CHECK-L: ${LINE:+1}: error: the type of this expression cannot be fully inferred +x = int(1) + +# CHECK-L: ${LINE:+1}: error: the return type of this function cannot be fully inferred +def fn(): + return int(1) diff --git a/lit-test/test/monomorphism/integers.py b/lit-test/test/monomorphism/integers.py new file mode 100644 index 000000000..20850cb47 --- /dev/null +++ b/lit-test/test/monomorphism/integers.py @@ -0,0 +1,5 @@ +# RUN: %python -m artiq.compiler.testbench.signature %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +x = 1 +# CHECK-L: x: int(width=32) diff --git a/lit-test/test/time/advance.py b/lit-test/test/time/advance.py new file mode 100644 index 000000000..0faaaa597 --- /dev/null +++ b/lit-test/test/time/advance.py @@ -0,0 +1,9 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s + +assert now() == 0.0 +delay(100.0) +assert now() == 100.0 +at(12345.0) +assert now() == 12345.0 + +assert now_mu() == 12345000000 diff --git a/lit-test/test/time/advance_mu.py b/lit-test/test/time/advance_mu.py new file mode 100644 index 000000000..0ff97e1d1 --- /dev/null +++ b/lit-test/test/time/advance_mu.py @@ -0,0 +1,7 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s + +assert now_mu() == 0 +delay_mu(100) +assert now_mu() == 100 +at_mu(12345) +assert now_mu() == 12345 diff --git a/lit-test/test/time/conversion.py b/lit-test/test/time/conversion.py new file mode 100644 index 000000000..ccabb41cb --- /dev/null +++ b/lit-test/test/time/conversion.py @@ -0,0 +1,4 @@ +# RUN: %python -m artiq.compiler.testbench.jit %s + +assert seconds_to_mu(2.0) == 2000000 +assert mu_to_seconds(1500000) == 1.5 diff --git a/setup.py b/setup.py index 055cdfcde..7750a0f38 100755 --- a/setup.py +++ b/setup.py @@ -28,7 +28,9 @@ class PushDocCommand(Command): requirements = [ "sphinx", "sphinx-argparse", "pyserial", "numpy", "scipy", "python-dateutil", "prettytable", "h5py", "pydaqmx", "pyelftools", - "quamash", "pyqtgraph", "llvmlite_artiq", "pygit2", "aiohttp" + "quamash", "pyqtgraph", "pygit2", "aiohttp", + "llvmlite_artiq", "pythonparser", "python-Levenshtein", + "lit", "OutputCheck", ] scripts = [