forked from M-Labs/artiq
Add a polymorphic print function.
This commit is contained in:
parent
0e7294db8d
commit
1e851adf4f
@ -116,6 +116,9 @@ def fn_len():
|
||||
def fn_round():
|
||||
return types.TBuiltinFunction("round")
|
||||
|
||||
def fn_print():
|
||||
return types.TBuiltinFunction("print")
|
||||
|
||||
def fn_syscall():
|
||||
return types.TBuiltinFunction("syscall")
|
||||
|
||||
|
@ -17,5 +17,6 @@ def globals():
|
||||
"ValueError": builtins.fn_ValueError(),
|
||||
"len": builtins.fn_len(),
|
||||
"round": builtins.fn_round(),
|
||||
"print": builtins.fn_print(),
|
||||
"syscall": builtins.fn_syscall(),
|
||||
}
|
||||
|
@ -338,7 +338,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.break_target = old_break
|
||||
self.continue_target = old_continue
|
||||
|
||||
def _iterable_len(self, value, typ=builtins.TInt(types.TValue(32))):
|
||||
def iterable_len(self, value, typ=builtins.TInt(types.TValue(32))):
|
||||
if builtins.is_list(value.type):
|
||||
return self.append(ir.Builtin("len", [value], typ))
|
||||
elif builtins.is_range(value.type):
|
||||
@ -350,7 +350,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
else:
|
||||
assert False
|
||||
|
||||
def _iterable_get(self, value, index):
|
||||
def iterable_get(self, value, index):
|
||||
# Assuming the value is within bounds.
|
||||
if builtins.is_list(value.type):
|
||||
return self.append(ir.GetElem(value, index))
|
||||
@ -365,7 +365,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
def visit_For(self, node):
|
||||
try:
|
||||
iterable = self.visit(node.iter)
|
||||
length = self._iterable_len(iterable)
|
||||
length = self.iterable_len(iterable)
|
||||
prehead = self.current_block
|
||||
|
||||
head = self.add_block("for.head")
|
||||
@ -388,7 +388,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
|
||||
body = self.add_block("for.body")
|
||||
self.current_block = body
|
||||
elt = self._iterable_get(iterable, phi)
|
||||
elt = self.iterable_get(iterable, phi)
|
||||
try:
|
||||
self.current_assign = elt
|
||||
self.visit(node.target)
|
||||
@ -669,17 +669,17 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
|
||||
if isinstance(node.slice, ast.Index):
|
||||
index = self.visit(node.slice.value)
|
||||
length = self._iterable_len(value, index.type)
|
||||
length = self.iterable_len(value, index.type)
|
||||
mapped_index = self._map_index(length, index)
|
||||
if self.current_assign is None:
|
||||
result = self._iterable_get(value, mapped_index)
|
||||
result = self.iterable_get(value, mapped_index)
|
||||
result.set_name("{}.at.{}".format(value.name, _readable_name(index)))
|
||||
return result
|
||||
else:
|
||||
self.append(ir.SetElem(value, mapped_index, self.current_assign,
|
||||
name="{}.at.{}".format(value.name, _readable_name(index))))
|
||||
else: # Slice
|
||||
length = self._iterable_len(value, node.slice.type)
|
||||
length = self.iterable_len(value, node.slice.type)
|
||||
|
||||
if node.slice.lower is not None:
|
||||
min_index = self.visit(node.slice.lower)
|
||||
@ -715,7 +715,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
index = self.append(ir.Arith(ast.Add(loc=None), min_index, offset))
|
||||
|
||||
if self.current_assign is None:
|
||||
elem = self._iterable_get(value, index)
|
||||
elem = self.iterable_get(value, index)
|
||||
self.append(ir.SetElem(other_value, other_index, elem))
|
||||
else:
|
||||
elem = self.append(ir.GetElem(self.current_assign, other_index))
|
||||
@ -771,7 +771,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
assert comprehension.ifs == []
|
||||
|
||||
iterable = self.visit(comprehension.iter)
|
||||
length = self._iterable_len(iterable)
|
||||
length = self.iterable_len(iterable)
|
||||
result = self.append(ir.Alloc([length], node.type))
|
||||
|
||||
try:
|
||||
@ -782,7 +782,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.append(ir.SetLocal(env, ".outer", old_env))
|
||||
|
||||
def body_gen(index):
|
||||
elt = self._iterable_get(iterable, index)
|
||||
elt = self.iterable_get(iterable, index)
|
||||
try:
|
||||
old_assign, self.current_assign = self.current_assign, elt
|
||||
print(comprehension.target, self.current_assign)
|
||||
@ -926,7 +926,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
else:
|
||||
assert False
|
||||
|
||||
def _compare_pair_order(self, op, lhs, rhs):
|
||||
def polymorphic_compare_pair_order(self, op, lhs, rhs):
|
||||
if builtins.is_numeric(lhs.type) and builtins.is_numeric(rhs.type):
|
||||
return self.append(ir.Compare(op, lhs, rhs))
|
||||
elif types.is_tuple(lhs.type) and types.is_tuple(rhs.type):
|
||||
@ -960,7 +960,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
self.current_block = loop_body
|
||||
lhs_elt = self.append(ir.GetElem(lhs, index_phi))
|
||||
rhs_elt = self.append(ir.GetElem(rhs, index_phi))
|
||||
body_result = self._compare_pair(op, lhs_elt, rhs_elt)
|
||||
body_result = self.polymorphic_compare_pair(op, lhs_elt, rhs_elt)
|
||||
|
||||
loop_body2 = self.add_block()
|
||||
self.current_block = loop_body2
|
||||
@ -989,7 +989,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
else:
|
||||
assert False
|
||||
|
||||
def _compare_pair_inclusion(self, op, needle, haystack):
|
||||
def polymorphic_compare_pair_inclusion(self, op, needle, haystack):
|
||||
if builtins.is_range(haystack.type):
|
||||
# Optimized range `in` operator
|
||||
start = self.append(ir.GetAttr(haystack, "start"))
|
||||
@ -1005,15 +1005,15 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
ir.Constant(False, builtins.TBool())))
|
||||
result = self.append(ir.Select(result, on_step,
|
||||
ir.Constant(False, builtins.TBool())))
|
||||
elif builtins.is_iterable(haystack.type):
|
||||
length = self._iterable_len(haystack)
|
||||
elif builtins.isiterable(haystack.type):
|
||||
length = self.iterable_len(haystack)
|
||||
|
||||
cmp_result = loop_body2 = None
|
||||
def body_gen(index):
|
||||
nonlocal cmp_result, loop_body2
|
||||
|
||||
elt = self._iterable_get(haystack, index)
|
||||
cmp_result = self._compare_pair(ast.Eq(loc=None), needle, elt)
|
||||
elt = self.iterable_get(haystack, index)
|
||||
cmp_result = self.polymorphic_compare_pair(ast.Eq(loc=None), needle, elt)
|
||||
|
||||
loop_body2 = self.add_block()
|
||||
self.current_block = loop_body2
|
||||
@ -1040,7 +1040,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
|
||||
return result
|
||||
|
||||
def _compare_pair_identity(self, op, lhs, rhs):
|
||||
def polymorphic_compare_pair_identity(self, op, lhs, rhs):
|
||||
if builtins.is_allocated(lhs) and builtins.is_allocated(rhs):
|
||||
# These are actually pointers, compare directly.
|
||||
return self.append(ir.Compare(op, lhs, rhs))
|
||||
@ -1053,15 +1053,15 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
op = ast.NotEq(loc=None)
|
||||
else:
|
||||
assert False
|
||||
return self._compare_pair_order(op, lhs, rhs)
|
||||
return self.polymorphic_compare_pair_order(op, lhs, rhs)
|
||||
|
||||
def _compare_pair(self, op, lhs, rhs):
|
||||
def polymorphic_compare_pair(self, op, lhs, rhs):
|
||||
if isinstance(op, (ast.Is, ast.IsNot)):
|
||||
return self._compare_pair_identity(op, lhs, rhs)
|
||||
return self.polymorphic_compare_pair_identity(op, lhs, rhs)
|
||||
elif isinstance(op, (ast.In, ast.NotIn)):
|
||||
return self._compare_pair_inclusion(op, lhs, rhs)
|
||||
return self.polymorphic_compare_pair_inclusion(op, lhs, rhs)
|
||||
else: # Eq, NotEq, Lt, LtE, Gt, GtE
|
||||
return self._compare_pair_order(op, lhs, rhs)
|
||||
return self.polymorphic_compare_pair_order(op, lhs, rhs)
|
||||
|
||||
def visit_CompareT(self, node):
|
||||
# Essentially a sequence of `and`s performed over results
|
||||
@ -1070,7 +1070,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
lhs = self.visit(node.left)
|
||||
for op, rhs_node in zip(node.ops, node.comparators):
|
||||
rhs = self.visit(rhs_node)
|
||||
result = self._compare_pair(op, lhs, rhs)
|
||||
result = self.polymorphic_compare_pair(op, lhs, rhs)
|
||||
blocks.append((result, self.current_block))
|
||||
self.current_block = self.add_block()
|
||||
lhs = rhs
|
||||
@ -1120,11 +1120,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
return self.append(ir.Alloc(node.type, length))
|
||||
elif len(node.args) == 1 and len(node.keywords) == 0:
|
||||
arg = self.visit(node.args[0])
|
||||
length = self._iterable_len(arg)
|
||||
length = self.iterable_len(arg)
|
||||
result = self.append(ir.Alloc([length], node.type))
|
||||
|
||||
def body_gen(index):
|
||||
elt = self._iterable_get(arg, index)
|
||||
elt = self.iterable_get(arg, index)
|
||||
self.append(ir.SetElem(result, index, elt))
|
||||
return self.append(ir.Arith(ast.Add(loc=None), index,
|
||||
ir.Constant(1, length.type)))
|
||||
@ -1136,7 +1136,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
else:
|
||||
assert False
|
||||
elif types.is_builtin(typ, "range"):
|
||||
elt_typ = builtins.get_iterable_elt(node.type)
|
||||
elt_typ = builtins.getiterable_elt(node.type)
|
||||
if len(node.args) == 1 and len(node.keywords) == 0:
|
||||
max_arg = self.visit(node.args[0])
|
||||
return self.append(ir.Alloc([
|
||||
@ -1166,7 +1166,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
elif types.is_builtin(typ, "len"):
|
||||
if len(node.args) == 1 and len(node.keywords) == 0:
|
||||
arg = self.visit(node.args[0])
|
||||
return self._iterable_len(arg)
|
||||
return self.iterable_len(arg)
|
||||
else:
|
||||
assert False
|
||||
elif types.is_builtin(typ, "round"):
|
||||
@ -1175,6 +1175,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
return self.append(ir.Builtin("round", [arg]))
|
||||
else:
|
||||
assert False
|
||||
elif types.is_builtin(typ, "print"):
|
||||
self.polymorphic_print([self.visit(arg) for arg in node.args],
|
||||
separator=" ", suffix="\n")
|
||||
return ir.Constant(None, builtins.TNone())
|
||||
elif types.is_exn_constructor(typ):
|
||||
return self.append(ir.Alloc([self.visit(arg) for args in node.args], node.type))
|
||||
else:
|
||||
@ -1206,3 +1210,89 @@ class ARTIQIRGenerator(algorithm.Visitor):
|
||||
invoke = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target))
|
||||
self.current_block = after_invoke
|
||||
return invoke
|
||||
|
||||
def polymorphic_print(self, values, separator, suffix=""):
|
||||
format_string = ""
|
||||
args = []
|
||||
def flush():
|
||||
nonlocal format_string, args
|
||||
if format_string != "":
|
||||
format_arg = [ir.Constant(format_string, builtins.TStr())]
|
||||
self.append(ir.Builtin("printf", format_arg + args, builtins.TNone()))
|
||||
format_string = ""
|
||||
args = []
|
||||
|
||||
for value in values:
|
||||
if format_string != "":
|
||||
format_string += separator
|
||||
|
||||
if types.is_tuple(value.type):
|
||||
format_string += "("; flush()
|
||||
self.polymorphic_print([self.append(ir.GetAttr(value, index))
|
||||
for index in range(len(value.type.elts))],
|
||||
separator=", ")
|
||||
format_string += ")"
|
||||
elif types.is_function(value.type):
|
||||
format_string += "<closure %p(%p)>"
|
||||
# We're relying on the internal layout of the closure here.
|
||||
args.append(self.append(ir.GetAttr(value, 0)))
|
||||
args.append(self.append(ir.GetAttr(value, 1)))
|
||||
elif builtins.is_none(value.type):
|
||||
format_string += "None"
|
||||
elif builtins.is_bool(value.type):
|
||||
format_string += "%s"
|
||||
args.append(self.append(ir.Select(value,
|
||||
ir.Constant("True", builtins.TStr()),
|
||||
ir.Constant("False", builtins.TStr()))))
|
||||
elif builtins.is_int(value.type):
|
||||
format_string += "%d"
|
||||
args.append(value)
|
||||
elif builtins.is_float(value.type):
|
||||
format_string += "%g"
|
||||
args.append(value)
|
||||
elif builtins.is_str(value.type):
|
||||
format_string += "%s"
|
||||
args.append(value)
|
||||
elif builtins.is_list(value.type):
|
||||
format_string += "["; flush()
|
||||
|
||||
length = self.iterable_len(value)
|
||||
last = self.append(ir.Arith(ast.Sub(loc=None), length, ir.Constant(1, length.type)))
|
||||
def body_gen(index):
|
||||
elt = self.iterable_get(value, index)
|
||||
self.polymorphic_print([elt], separator="")
|
||||
is_last = self.append(ir.Compare(ast.Lt(loc=None), index, last))
|
||||
head = self.current_block
|
||||
|
||||
if_last = self.current_block = self.add_block()
|
||||
self.append(ir.Builtin("printf",
|
||||
[ir.Constant(", ", builtins.TStr())], builtins.TNone()))
|
||||
|
||||
tail = self.current_block = self.add_block()
|
||||
if_last.append(ir.Branch(tail))
|
||||
head.append(ir.BranchIf(is_last, if_last, tail))
|
||||
|
||||
return self.append(ir.Arith(ast.Add(loc=None), index,
|
||||
ir.Constant(1, length.type)))
|
||||
self._make_loop(ir.Constant(0, length.type),
|
||||
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)),
|
||||
body_gen)
|
||||
|
||||
format_string += "]"
|
||||
elif builtins.is_range(value.type):
|
||||
format_string += "range("; flush()
|
||||
|
||||
start = self.append(ir.GetAttr(value, "start"))
|
||||
stop = self.append(ir.GetAttr(value, "stop"))
|
||||
step = self.append(ir.GetAttr(value, "step"))
|
||||
self.polymorphic_print([start, stop, step], separator=", ")
|
||||
|
||||
format_string += ")"
|
||||
elif builtins.is_exception(value.type):
|
||||
# TODO: print exceptions
|
||||
assert False
|
||||
else:
|
||||
assert False
|
||||
|
||||
format_string += suffix
|
||||
flush()
|
||||
|
@ -637,6 +637,19 @@ class Inferencer(algorithm.Visitor):
|
||||
arg.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "print"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("print(args...) -> None"),
|
||||
]
|
||||
|
||||
self._unify(node.type, builtins.TNone(),
|
||||
node.loc, None)
|
||||
|
||||
if len(node.keywords) == 0:
|
||||
# We can print any arguments.
|
||||
pass
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
# TODO: add when it is clear what interface syscall() has
|
||||
# elif types.is_builtin(typ, "syscall"):
|
||||
# valid_Forms = lambda: [
|
||||
|
@ -90,6 +90,25 @@ class LLVMIRGenerator:
|
||||
else:
|
||||
assert False
|
||||
|
||||
def llbuiltin(self, name):
|
||||
llfun = self.llmodule.get_global(name)
|
||||
if llfun is not None:
|
||||
return llfun
|
||||
|
||||
if name in ("llvm.abort", "llvm.donothing"):
|
||||
llty = ll.FunctionType(ll.VoidType(), [])
|
||||
elif name == "llvm.round.f64":
|
||||
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
|
||||
elif name == "llvm.pow.f64":
|
||||
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()])
|
||||
elif name == "llvm.powi.f64":
|
||||
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)])
|
||||
elif name == "printf":
|
||||
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True)
|
||||
else:
|
||||
assert False
|
||||
return ll.Function(self.llmodule, llty, name)
|
||||
|
||||
def map(self, value):
|
||||
if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)):
|
||||
return self.llmap[value]
|
||||
@ -214,7 +233,7 @@ class LLVMIRGenerator:
|
||||
|
||||
def process_GetAttr(self, insn):
|
||||
if types.is_tuple(insn.object().type):
|
||||
return self.llbuilder.extract_value(self.map(insn.object()), self.attr_index(insn),
|
||||
return self.llbuilder.extract_value(self.map(insn.object()), insn.attr,
|
||||
name=insn.name)
|
||||
elif not builtins.is_allocated(insn.object().type):
|
||||
return self.llbuilder.extract_value(self.map(insn.object()), self.attr_index(insn),
|
||||
@ -296,9 +315,7 @@ class LLVMIRGenerator:
|
||||
elif isinstance(insn.op, ast.FloorDiv):
|
||||
if builtins.is_float(insn.type):
|
||||
llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()))
|
||||
llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
|
||||
llfn = ll.Function(self.llmodule, llfnty, "llvm.round.f64")
|
||||
return self.llbuilder.call(llfn, [llvalue],
|
||||
return self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llvalue],
|
||||
name=insn.name)
|
||||
else:
|
||||
return self.llbuilder.sdiv(self.map(insn.lhs()), self.map(insn.rhs()),
|
||||
@ -312,15 +329,13 @@ class LLVMIRGenerator:
|
||||
name=insn.name)
|
||||
elif isinstance(insn.op, ast.Pow):
|
||||
if builtins.is_float(insn.type):
|
||||
llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()])
|
||||
llfn = ll.Function(self.llmodule, llfnty, "llvm.pow.f64")
|
||||
return self.llbuilder.call(llfn, [self.map(insn.lhs()), self.map(insn.rhs())],
|
||||
return self.llbuilder.call(self.llbuiltin("llvm.pow.f64"),
|
||||
[self.map(insn.lhs()), self.map(insn.rhs())],
|
||||
name=insn.name)
|
||||
else:
|
||||
llrhs = self.llbuilder.trunc(self.map(insn.rhs()), ll.IntType(32))
|
||||
llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)])
|
||||
llfn = ll.Function(self.llmodule, llfnty, "llvm.powi.f64")
|
||||
llvalue = self.llbuilder.call(llfn, [self.map(insn.lhs()), llrhs])
|
||||
llvalue = self.llbuilder.call(self.llbuiltin("llvm.powi.f64"),
|
||||
[self.map(insn.lhs()), llrhs])
|
||||
return self.llbuilder.fptosi(llvalue, self.llty_of_type(insn.type),
|
||||
name=insn.name)
|
||||
elif isinstance(insn.op, ast.LShift):
|
||||
@ -366,8 +381,7 @@ class LLVMIRGenerator:
|
||||
|
||||
def process_Builtin(self, insn):
|
||||
if insn.op == "nop":
|
||||
fn = ll.Function(self.llmodule, ll.FunctionType(ll.VoidType(), []), "llvm.donothing")
|
||||
return self.llbuilder.call(fn, [])
|
||||
return self.llbuilder.call(self.llbuiltin("llvm.donothing"), [])
|
||||
elif insn.op == "unwrap":
|
||||
optarg, default = map(self.map, insn.operands)
|
||||
has_arg = self.llbuilder.extract_value(optarg, 0)
|
||||
@ -375,9 +389,7 @@ class LLVMIRGenerator:
|
||||
return self.llbuilder.select(has_arg, arg, default,
|
||||
name=insn.name)
|
||||
elif insn.op == "round":
|
||||
llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
|
||||
llfn = ll.Function(self.llmodule, llfnty, "llvm.round.f64")
|
||||
return self.llbuilder.call(llfn, [llvalue],
|
||||
return self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llvalue],
|
||||
name=insn.name)
|
||||
elif insn.op == "globalenv":
|
||||
def get_outer(llenv, env_ty):
|
||||
@ -394,6 +406,11 @@ class LLVMIRGenerator:
|
||||
elif insn.op == "len":
|
||||
lst, = insn.operands
|
||||
return self.llbuilder.extract_value(self.map(lst), 0)
|
||||
elif insn.op == "printf":
|
||||
# We only get integers, floats, pointers and strings here.
|
||||
llargs = map(self.map, insn.operands)
|
||||
return self.llbuilder.call(self.llbuiltin("printf"), llargs,
|
||||
name=insn.name)
|
||||
# elif insn.op == "exncast":
|
||||
else:
|
||||
assert False
|
||||
@ -414,8 +431,8 @@ class LLVMIRGenerator:
|
||||
name=insn.name)
|
||||
|
||||
def process_Select(self, insn):
|
||||
return self.llbuilder.select(self.map(insn.cond()),
|
||||
self.map(insn.lhs()), self.map(insn.rhs()))
|
||||
return self.llbuilder.select(self.map(insn.condition()),
|
||||
self.map(insn.if_true()), self.map(insn.if_false()))
|
||||
|
||||
def process_Branch(self, insn):
|
||||
return self.llbuilder.branch(self.map(insn.target()))
|
||||
@ -438,9 +455,7 @@ class LLVMIRGenerator:
|
||||
|
||||
def process_Raise(self, insn):
|
||||
# TODO: hack before EH is working
|
||||
llfnty = ll.FunctionType(ll.VoidType(), [])
|
||||
llfn = ll.Function(self.llmodule, llfnty, "llvm.abort")
|
||||
llinsn = self.llbuilder.call(llfn, [],
|
||||
llinsn = self.llbuilder.call(self.llbuiltin("llvm.abort"), [],
|
||||
name=insn.name)
|
||||
self.llbuilder.unreachable()
|
||||
return llinsn
|
||||
|
Loading…
Reference in New Issue
Block a user