From 1e851adf4f6803b5792ec8a08c4c47b56970b655 Mon Sep 17 00:00:00 2001 From: whitequark Date: Tue, 21 Jul 2015 22:32:10 +0300 Subject: [PATCH] Add a polymorphic print function. --- artiq/compiler/builtins.py | 3 + artiq/compiler/prelude.py | 1 + .../compiler/transforms/artiq_ir_generator.py | 146 ++++++++++++++---- artiq/compiler/transforms/inferencer.py | 13 ++ .../compiler/transforms/llvm_ir_generator.py | 55 ++++--- 5 files changed, 170 insertions(+), 48 deletions(-) diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index 4b914063f..dde80af12 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -116,6 +116,9 @@ def fn_len(): def fn_round(): return types.TBuiltinFunction("round") +def fn_print(): + return types.TBuiltinFunction("print") + def fn_syscall(): return types.TBuiltinFunction("syscall") diff --git a/artiq/compiler/prelude.py b/artiq/compiler/prelude.py index 98432b753..bed44af6c 100644 --- a/artiq/compiler/prelude.py +++ b/artiq/compiler/prelude.py @@ -17,5 +17,6 @@ def globals(): "ValueError": builtins.fn_ValueError(), "len": builtins.fn_len(), "round": builtins.fn_round(), + "print": builtins.fn_print(), "syscall": builtins.fn_syscall(), } diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index c18b987a2..715845b45 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -338,7 +338,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.break_target = old_break self.continue_target = old_continue - def _iterable_len(self, value, typ=builtins.TInt(types.TValue(32))): + def iterable_len(self, value, typ=builtins.TInt(types.TValue(32))): if builtins.is_list(value.type): return self.append(ir.Builtin("len", [value], typ)) elif builtins.is_range(value.type): @@ -350,7 +350,7 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False - def _iterable_get(self, value, index): + def iterable_get(self, value, index): # Assuming the value is within bounds. if builtins.is_list(value.type): return self.append(ir.GetElem(value, index)) @@ -365,7 +365,7 @@ class ARTIQIRGenerator(algorithm.Visitor): def visit_For(self, node): try: iterable = self.visit(node.iter) - length = self._iterable_len(iterable) + length = self.iterable_len(iterable) prehead = self.current_block head = self.add_block("for.head") @@ -388,7 +388,7 @@ class ARTIQIRGenerator(algorithm.Visitor): body = self.add_block("for.body") self.current_block = body - elt = self._iterable_get(iterable, phi) + elt = self.iterable_get(iterable, phi) try: self.current_assign = elt self.visit(node.target) @@ -669,17 +669,17 @@ class ARTIQIRGenerator(algorithm.Visitor): if isinstance(node.slice, ast.Index): index = self.visit(node.slice.value) - length = self._iterable_len(value, index.type) + length = self.iterable_len(value, index.type) mapped_index = self._map_index(length, index) if self.current_assign is None: - result = self._iterable_get(value, mapped_index) + result = self.iterable_get(value, mapped_index) result.set_name("{}.at.{}".format(value.name, _readable_name(index))) return result else: self.append(ir.SetElem(value, mapped_index, self.current_assign, name="{}.at.{}".format(value.name, _readable_name(index)))) else: # Slice - length = self._iterable_len(value, node.slice.type) + length = self.iterable_len(value, node.slice.type) if node.slice.lower is not None: min_index = self.visit(node.slice.lower) @@ -715,7 +715,7 @@ class ARTIQIRGenerator(algorithm.Visitor): index = self.append(ir.Arith(ast.Add(loc=None), min_index, offset)) if self.current_assign is None: - elem = self._iterable_get(value, index) + elem = self.iterable_get(value, index) self.append(ir.SetElem(other_value, other_index, elem)) else: elem = self.append(ir.GetElem(self.current_assign, other_index)) @@ -771,7 +771,7 @@ class ARTIQIRGenerator(algorithm.Visitor): assert comprehension.ifs == [] iterable = self.visit(comprehension.iter) - length = self._iterable_len(iterable) + length = self.iterable_len(iterable) result = self.append(ir.Alloc([length], node.type)) try: @@ -782,7 +782,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.append(ir.SetLocal(env, ".outer", old_env)) def body_gen(index): - elt = self._iterable_get(iterable, index) + elt = self.iterable_get(iterable, index) try: old_assign, self.current_assign = self.current_assign, elt print(comprehension.target, self.current_assign) @@ -926,7 +926,7 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False - def _compare_pair_order(self, op, lhs, rhs): + def polymorphic_compare_pair_order(self, op, lhs, rhs): if builtins.is_numeric(lhs.type) and builtins.is_numeric(rhs.type): return self.append(ir.Compare(op, lhs, rhs)) elif types.is_tuple(lhs.type) and types.is_tuple(rhs.type): @@ -960,7 +960,7 @@ class ARTIQIRGenerator(algorithm.Visitor): self.current_block = loop_body lhs_elt = self.append(ir.GetElem(lhs, index_phi)) rhs_elt = self.append(ir.GetElem(rhs, index_phi)) - body_result = self._compare_pair(op, lhs_elt, rhs_elt) + body_result = self.polymorphic_compare_pair(op, lhs_elt, rhs_elt) loop_body2 = self.add_block() self.current_block = loop_body2 @@ -989,7 +989,7 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False - def _compare_pair_inclusion(self, op, needle, haystack): + def polymorphic_compare_pair_inclusion(self, op, needle, haystack): if builtins.is_range(haystack.type): # Optimized range `in` operator start = self.append(ir.GetAttr(haystack, "start")) @@ -1005,15 +1005,15 @@ class ARTIQIRGenerator(algorithm.Visitor): ir.Constant(False, builtins.TBool()))) result = self.append(ir.Select(result, on_step, ir.Constant(False, builtins.TBool()))) - elif builtins.is_iterable(haystack.type): - length = self._iterable_len(haystack) + elif builtins.isiterable(haystack.type): + length = self.iterable_len(haystack) cmp_result = loop_body2 = None def body_gen(index): nonlocal cmp_result, loop_body2 - elt = self._iterable_get(haystack, index) - cmp_result = self._compare_pair(ast.Eq(loc=None), needle, elt) + elt = self.iterable_get(haystack, index) + cmp_result = self.polymorphic_compare_pair(ast.Eq(loc=None), needle, elt) loop_body2 = self.add_block() self.current_block = loop_body2 @@ -1040,7 +1040,7 @@ class ARTIQIRGenerator(algorithm.Visitor): return result - def _compare_pair_identity(self, op, lhs, rhs): + def polymorphic_compare_pair_identity(self, op, lhs, rhs): if builtins.is_allocated(lhs) and builtins.is_allocated(rhs): # These are actually pointers, compare directly. return self.append(ir.Compare(op, lhs, rhs)) @@ -1053,15 +1053,15 @@ class ARTIQIRGenerator(algorithm.Visitor): op = ast.NotEq(loc=None) else: assert False - return self._compare_pair_order(op, lhs, rhs) + return self.polymorphic_compare_pair_order(op, lhs, rhs) - def _compare_pair(self, op, lhs, rhs): + def polymorphic_compare_pair(self, op, lhs, rhs): if isinstance(op, (ast.Is, ast.IsNot)): - return self._compare_pair_identity(op, lhs, rhs) + return self.polymorphic_compare_pair_identity(op, lhs, rhs) elif isinstance(op, (ast.In, ast.NotIn)): - return self._compare_pair_inclusion(op, lhs, rhs) + return self.polymorphic_compare_pair_inclusion(op, lhs, rhs) else: # Eq, NotEq, Lt, LtE, Gt, GtE - return self._compare_pair_order(op, lhs, rhs) + return self.polymorphic_compare_pair_order(op, lhs, rhs) def visit_CompareT(self, node): # Essentially a sequence of `and`s performed over results @@ -1070,7 +1070,7 @@ class ARTIQIRGenerator(algorithm.Visitor): lhs = self.visit(node.left) for op, rhs_node in zip(node.ops, node.comparators): rhs = self.visit(rhs_node) - result = self._compare_pair(op, lhs, rhs) + result = self.polymorphic_compare_pair(op, lhs, rhs) blocks.append((result, self.current_block)) self.current_block = self.add_block() lhs = rhs @@ -1120,11 +1120,11 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Alloc(node.type, length)) elif len(node.args) == 1 and len(node.keywords) == 0: arg = self.visit(node.args[0]) - length = self._iterable_len(arg) + length = self.iterable_len(arg) result = self.append(ir.Alloc([length], node.type)) def body_gen(index): - elt = self._iterable_get(arg, index) + elt = self.iterable_get(arg, index) self.append(ir.SetElem(result, index, elt)) return self.append(ir.Arith(ast.Add(loc=None), index, ir.Constant(1, length.type))) @@ -1136,7 +1136,7 @@ class ARTIQIRGenerator(algorithm.Visitor): else: assert False elif types.is_builtin(typ, "range"): - elt_typ = builtins.get_iterable_elt(node.type) + elt_typ = builtins.getiterable_elt(node.type) if len(node.args) == 1 and len(node.keywords) == 0: max_arg = self.visit(node.args[0]) return self.append(ir.Alloc([ @@ -1166,7 +1166,7 @@ class ARTIQIRGenerator(algorithm.Visitor): elif types.is_builtin(typ, "len"): if len(node.args) == 1 and len(node.keywords) == 0: arg = self.visit(node.args[0]) - return self._iterable_len(arg) + return self.iterable_len(arg) else: assert False elif types.is_builtin(typ, "round"): @@ -1175,6 +1175,10 @@ class ARTIQIRGenerator(algorithm.Visitor): return self.append(ir.Builtin("round", [arg])) else: assert False + elif types.is_builtin(typ, "print"): + self.polymorphic_print([self.visit(arg) for arg in node.args], + separator=" ", suffix="\n") + return ir.Constant(None, builtins.TNone()) elif types.is_exn_constructor(typ): return self.append(ir.Alloc([self.visit(arg) for args in node.args], node.type)) else: @@ -1206,3 +1210,89 @@ class ARTIQIRGenerator(algorithm.Visitor): invoke = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target)) self.current_block = after_invoke return invoke + + def polymorphic_print(self, values, separator, suffix=""): + format_string = "" + args = [] + def flush(): + nonlocal format_string, args + if format_string != "": + format_arg = [ir.Constant(format_string, builtins.TStr())] + self.append(ir.Builtin("printf", format_arg + args, builtins.TNone())) + format_string = "" + args = [] + + for value in values: + if format_string != "": + format_string += separator + + if types.is_tuple(value.type): + format_string += "("; flush() + self.polymorphic_print([self.append(ir.GetAttr(value, index)) + for index in range(len(value.type.elts))], + separator=", ") + format_string += ")" + elif types.is_function(value.type): + format_string += "" + # We're relying on the internal layout of the closure here. + args.append(self.append(ir.GetAttr(value, 0))) + args.append(self.append(ir.GetAttr(value, 1))) + elif builtins.is_none(value.type): + format_string += "None" + elif builtins.is_bool(value.type): + format_string += "%s" + args.append(self.append(ir.Select(value, + ir.Constant("True", builtins.TStr()), + ir.Constant("False", builtins.TStr())))) + elif builtins.is_int(value.type): + format_string += "%d" + args.append(value) + elif builtins.is_float(value.type): + format_string += "%g" + args.append(value) + elif builtins.is_str(value.type): + format_string += "%s" + args.append(value) + elif builtins.is_list(value.type): + format_string += "["; flush() + + length = self.iterable_len(value) + last = self.append(ir.Arith(ast.Sub(loc=None), length, ir.Constant(1, length.type))) + def body_gen(index): + elt = self.iterable_get(value, index) + self.polymorphic_print([elt], separator="") + is_last = self.append(ir.Compare(ast.Lt(loc=None), index, last)) + head = self.current_block + + if_last = self.current_block = self.add_block() + self.append(ir.Builtin("printf", + [ir.Constant(", ", builtins.TStr())], builtins.TNone())) + + tail = self.current_block = self.add_block() + if_last.append(ir.Branch(tail)) + head.append(ir.BranchIf(is_last, if_last, tail)) + + return self.append(ir.Arith(ast.Add(loc=None), index, + ir.Constant(1, length.type))) + self._make_loop(ir.Constant(0, length.type), + lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)), + body_gen) + + format_string += "]" + elif builtins.is_range(value.type): + format_string += "range("; flush() + + start = self.append(ir.GetAttr(value, "start")) + stop = self.append(ir.GetAttr(value, "stop")) + step = self.append(ir.GetAttr(value, "step")) + self.polymorphic_print([start, stop, step], separator=", ") + + format_string += ")" + elif builtins.is_exception(value.type): + # TODO: print exceptions + assert False + else: + assert False + + format_string += suffix + flush() diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 3dd95aee0..a8b87491a 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -637,6 +637,19 @@ class Inferencer(algorithm.Visitor): arg.loc, None) else: diagnose(valid_forms()) + elif types.is_builtin(typ, "print"): + valid_forms = lambda: [ + valid_form("print(args...) -> None"), + ] + + self._unify(node.type, builtins.TNone(), + node.loc, None) + + if len(node.keywords) == 0: + # We can print any arguments. + pass + else: + diagnose(valid_forms()) # TODO: add when it is clear what interface syscall() has # elif types.is_builtin(typ, "syscall"): # valid_Forms = lambda: [ diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index 8405c147c..80908e10b 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -90,6 +90,25 @@ class LLVMIRGenerator: else: assert False + def llbuiltin(self, name): + llfun = self.llmodule.get_global(name) + if llfun is not None: + return llfun + + if name in ("llvm.abort", "llvm.donothing"): + llty = ll.FunctionType(ll.VoidType(), []) + elif name == "llvm.round.f64": + llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()]) + elif name == "llvm.pow.f64": + llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()]) + elif name == "llvm.powi.f64": + llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)]) + elif name == "printf": + llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True) + else: + assert False + return ll.Function(self.llmodule, llty, name) + def map(self, value): if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)): return self.llmap[value] @@ -214,7 +233,7 @@ class LLVMIRGenerator: def process_GetAttr(self, insn): if types.is_tuple(insn.object().type): - return self.llbuilder.extract_value(self.map(insn.object()), self.attr_index(insn), + return self.llbuilder.extract_value(self.map(insn.object()), insn.attr, name=insn.name) elif not builtins.is_allocated(insn.object().type): return self.llbuilder.extract_value(self.map(insn.object()), self.attr_index(insn), @@ -296,9 +315,7 @@ class LLVMIRGenerator: elif isinstance(insn.op, ast.FloorDiv): if builtins.is_float(insn.type): llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs())) - llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()]) - llfn = ll.Function(self.llmodule, llfnty, "llvm.round.f64") - return self.llbuilder.call(llfn, [llvalue], + return self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llvalue], name=insn.name) else: return self.llbuilder.sdiv(self.map(insn.lhs()), self.map(insn.rhs()), @@ -312,15 +329,13 @@ class LLVMIRGenerator: name=insn.name) elif isinstance(insn.op, ast.Pow): if builtins.is_float(insn.type): - llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()]) - llfn = ll.Function(self.llmodule, llfnty, "llvm.pow.f64") - return self.llbuilder.call(llfn, [self.map(insn.lhs()), self.map(insn.rhs())], + return self.llbuilder.call(self.llbuiltin("llvm.pow.f64"), + [self.map(insn.lhs()), self.map(insn.rhs())], name=insn.name) else: llrhs = self.llbuilder.trunc(self.map(insn.rhs()), ll.IntType(32)) - llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)]) - llfn = ll.Function(self.llmodule, llfnty, "llvm.powi.f64") - llvalue = self.llbuilder.call(llfn, [self.map(insn.lhs()), llrhs]) + llvalue = self.llbuilder.call(self.llbuiltin("llvm.powi.f64"), + [self.map(insn.lhs()), llrhs]) return self.llbuilder.fptosi(llvalue, self.llty_of_type(insn.type), name=insn.name) elif isinstance(insn.op, ast.LShift): @@ -366,8 +381,7 @@ class LLVMIRGenerator: def process_Builtin(self, insn): if insn.op == "nop": - fn = ll.Function(self.llmodule, ll.FunctionType(ll.VoidType(), []), "llvm.donothing") - return self.llbuilder.call(fn, []) + return self.llbuilder.call(self.llbuiltin("llvm.donothing"), []) elif insn.op == "unwrap": optarg, default = map(self.map, insn.operands) has_arg = self.llbuilder.extract_value(optarg, 0) @@ -375,9 +389,7 @@ class LLVMIRGenerator: return self.llbuilder.select(has_arg, arg, default, name=insn.name) elif insn.op == "round": - llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()]) - llfn = ll.Function(self.llmodule, llfnty, "llvm.round.f64") - return self.llbuilder.call(llfn, [llvalue], + return self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llvalue], name=insn.name) elif insn.op == "globalenv": def get_outer(llenv, env_ty): @@ -394,6 +406,11 @@ class LLVMIRGenerator: elif insn.op == "len": lst, = insn.operands return self.llbuilder.extract_value(self.map(lst), 0) + elif insn.op == "printf": + # We only get integers, floats, pointers and strings here. + llargs = map(self.map, insn.operands) + return self.llbuilder.call(self.llbuiltin("printf"), llargs, + name=insn.name) # elif insn.op == "exncast": else: assert False @@ -414,8 +431,8 @@ class LLVMIRGenerator: name=insn.name) def process_Select(self, insn): - return self.llbuilder.select(self.map(insn.cond()), - self.map(insn.lhs()), self.map(insn.rhs())) + return self.llbuilder.select(self.map(insn.condition()), + self.map(insn.if_true()), self.map(insn.if_false())) def process_Branch(self, insn): return self.llbuilder.branch(self.map(insn.target())) @@ -438,9 +455,7 @@ class LLVMIRGenerator: def process_Raise(self, insn): # TODO: hack before EH is working - llfnty = ll.FunctionType(ll.VoidType(), []) - llfn = ll.Function(self.llmodule, llfnty, "llvm.abort") - llinsn = self.llbuilder.call(llfn, [], + llinsn = self.llbuilder.call(self.llbuiltin("llvm.abort"), [], name=insn.name) self.llbuilder.unreachable() return llinsn