Add a polymorphic print function.

This commit is contained in:
whitequark 2015-07-21 22:32:10 +03:00
parent 0e7294db8d
commit 1e851adf4f
5 changed files with 170 additions and 48 deletions

View File

@ -116,6 +116,9 @@ def fn_len():
def fn_round():
return types.TBuiltinFunction("round")
def fn_print():
return types.TBuiltinFunction("print")
def fn_syscall():
return types.TBuiltinFunction("syscall")

View File

@ -17,5 +17,6 @@ def globals():
"ValueError": builtins.fn_ValueError(),
"len": builtins.fn_len(),
"round": builtins.fn_round(),
"print": builtins.fn_print(),
"syscall": builtins.fn_syscall(),
}

View File

@ -338,7 +338,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.break_target = old_break
self.continue_target = old_continue
def _iterable_len(self, value, typ=builtins.TInt(types.TValue(32))):
def iterable_len(self, value, typ=builtins.TInt(types.TValue(32))):
if builtins.is_list(value.type):
return self.append(ir.Builtin("len", [value], typ))
elif builtins.is_range(value.type):
@ -350,7 +350,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else:
assert False
def _iterable_get(self, value, index):
def iterable_get(self, value, index):
# Assuming the value is within bounds.
if builtins.is_list(value.type):
return self.append(ir.GetElem(value, index))
@ -365,7 +365,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
def visit_For(self, node):
try:
iterable = self.visit(node.iter)
length = self._iterable_len(iterable)
length = self.iterable_len(iterable)
prehead = self.current_block
head = self.add_block("for.head")
@ -388,7 +388,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
body = self.add_block("for.body")
self.current_block = body
elt = self._iterable_get(iterable, phi)
elt = self.iterable_get(iterable, phi)
try:
self.current_assign = elt
self.visit(node.target)
@ -669,17 +669,17 @@ class ARTIQIRGenerator(algorithm.Visitor):
if isinstance(node.slice, ast.Index):
index = self.visit(node.slice.value)
length = self._iterable_len(value, index.type)
length = self.iterable_len(value, index.type)
mapped_index = self._map_index(length, index)
if self.current_assign is None:
result = self._iterable_get(value, mapped_index)
result = self.iterable_get(value, mapped_index)
result.set_name("{}.at.{}".format(value.name, _readable_name(index)))
return result
else:
self.append(ir.SetElem(value, mapped_index, self.current_assign,
name="{}.at.{}".format(value.name, _readable_name(index))))
else: # Slice
length = self._iterable_len(value, node.slice.type)
length = self.iterable_len(value, node.slice.type)
if node.slice.lower is not None:
min_index = self.visit(node.slice.lower)
@ -715,7 +715,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
index = self.append(ir.Arith(ast.Add(loc=None), min_index, offset))
if self.current_assign is None:
elem = self._iterable_get(value, index)
elem = self.iterable_get(value, index)
self.append(ir.SetElem(other_value, other_index, elem))
else:
elem = self.append(ir.GetElem(self.current_assign, other_index))
@ -771,7 +771,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
assert comprehension.ifs == []
iterable = self.visit(comprehension.iter)
length = self._iterable_len(iterable)
length = self.iterable_len(iterable)
result = self.append(ir.Alloc([length], node.type))
try:
@ -782,7 +782,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.append(ir.SetLocal(env, ".outer", old_env))
def body_gen(index):
elt = self._iterable_get(iterable, index)
elt = self.iterable_get(iterable, index)
try:
old_assign, self.current_assign = self.current_assign, elt
print(comprehension.target, self.current_assign)
@ -926,7 +926,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else:
assert False
def _compare_pair_order(self, op, lhs, rhs):
def polymorphic_compare_pair_order(self, op, lhs, rhs):
if builtins.is_numeric(lhs.type) and builtins.is_numeric(rhs.type):
return self.append(ir.Compare(op, lhs, rhs))
elif types.is_tuple(lhs.type) and types.is_tuple(rhs.type):
@ -960,7 +960,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.current_block = loop_body
lhs_elt = self.append(ir.GetElem(lhs, index_phi))
rhs_elt = self.append(ir.GetElem(rhs, index_phi))
body_result = self._compare_pair(op, lhs_elt, rhs_elt)
body_result = self.polymorphic_compare_pair(op, lhs_elt, rhs_elt)
loop_body2 = self.add_block()
self.current_block = loop_body2
@ -989,7 +989,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else:
assert False
def _compare_pair_inclusion(self, op, needle, haystack):
def polymorphic_compare_pair_inclusion(self, op, needle, haystack):
if builtins.is_range(haystack.type):
# Optimized range `in` operator
start = self.append(ir.GetAttr(haystack, "start"))
@ -1005,15 +1005,15 @@ class ARTIQIRGenerator(algorithm.Visitor):
ir.Constant(False, builtins.TBool())))
result = self.append(ir.Select(result, on_step,
ir.Constant(False, builtins.TBool())))
elif builtins.is_iterable(haystack.type):
length = self._iterable_len(haystack)
elif builtins.isiterable(haystack.type):
length = self.iterable_len(haystack)
cmp_result = loop_body2 = None
def body_gen(index):
nonlocal cmp_result, loop_body2
elt = self._iterable_get(haystack, index)
cmp_result = self._compare_pair(ast.Eq(loc=None), needle, elt)
elt = self.iterable_get(haystack, index)
cmp_result = self.polymorphic_compare_pair(ast.Eq(loc=None), needle, elt)
loop_body2 = self.add_block()
self.current_block = loop_body2
@ -1040,7 +1040,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
return result
def _compare_pair_identity(self, op, lhs, rhs):
def polymorphic_compare_pair_identity(self, op, lhs, rhs):
if builtins.is_allocated(lhs) and builtins.is_allocated(rhs):
# These are actually pointers, compare directly.
return self.append(ir.Compare(op, lhs, rhs))
@ -1053,15 +1053,15 @@ class ARTIQIRGenerator(algorithm.Visitor):
op = ast.NotEq(loc=None)
else:
assert False
return self._compare_pair_order(op, lhs, rhs)
return self.polymorphic_compare_pair_order(op, lhs, rhs)
def _compare_pair(self, op, lhs, rhs):
def polymorphic_compare_pair(self, op, lhs, rhs):
if isinstance(op, (ast.Is, ast.IsNot)):
return self._compare_pair_identity(op, lhs, rhs)
return self.polymorphic_compare_pair_identity(op, lhs, rhs)
elif isinstance(op, (ast.In, ast.NotIn)):
return self._compare_pair_inclusion(op, lhs, rhs)
return self.polymorphic_compare_pair_inclusion(op, lhs, rhs)
else: # Eq, NotEq, Lt, LtE, Gt, GtE
return self._compare_pair_order(op, lhs, rhs)
return self.polymorphic_compare_pair_order(op, lhs, rhs)
def visit_CompareT(self, node):
# Essentially a sequence of `and`s performed over results
@ -1070,7 +1070,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
lhs = self.visit(node.left)
for op, rhs_node in zip(node.ops, node.comparators):
rhs = self.visit(rhs_node)
result = self._compare_pair(op, lhs, rhs)
result = self.polymorphic_compare_pair(op, lhs, rhs)
blocks.append((result, self.current_block))
self.current_block = self.add_block()
lhs = rhs
@ -1120,11 +1120,11 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Alloc(node.type, length))
elif len(node.args) == 1 and len(node.keywords) == 0:
arg = self.visit(node.args[0])
length = self._iterable_len(arg)
length = self.iterable_len(arg)
result = self.append(ir.Alloc([length], node.type))
def body_gen(index):
elt = self._iterable_get(arg, index)
elt = self.iterable_get(arg, index)
self.append(ir.SetElem(result, index, elt))
return self.append(ir.Arith(ast.Add(loc=None), index,
ir.Constant(1, length.type)))
@ -1136,7 +1136,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
else:
assert False
elif types.is_builtin(typ, "range"):
elt_typ = builtins.get_iterable_elt(node.type)
elt_typ = builtins.getiterable_elt(node.type)
if len(node.args) == 1 and len(node.keywords) == 0:
max_arg = self.visit(node.args[0])
return self.append(ir.Alloc([
@ -1166,7 +1166,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
elif types.is_builtin(typ, "len"):
if len(node.args) == 1 and len(node.keywords) == 0:
arg = self.visit(node.args[0])
return self._iterable_len(arg)
return self.iterable_len(arg)
else:
assert False
elif types.is_builtin(typ, "round"):
@ -1175,6 +1175,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Builtin("round", [arg]))
else:
assert False
elif types.is_builtin(typ, "print"):
self.polymorphic_print([self.visit(arg) for arg in node.args],
separator=" ", suffix="\n")
return ir.Constant(None, builtins.TNone())
elif types.is_exn_constructor(typ):
return self.append(ir.Alloc([self.visit(arg) for args in node.args], node.type))
else:
@ -1206,3 +1210,89 @@ class ARTIQIRGenerator(algorithm.Visitor):
invoke = self.append(ir.Invoke(func, args, after_invoke, self.unwind_target))
self.current_block = after_invoke
return invoke
def polymorphic_print(self, values, separator, suffix=""):
format_string = ""
args = []
def flush():
nonlocal format_string, args
if format_string != "":
format_arg = [ir.Constant(format_string, builtins.TStr())]
self.append(ir.Builtin("printf", format_arg + args, builtins.TNone()))
format_string = ""
args = []
for value in values:
if format_string != "":
format_string += separator
if types.is_tuple(value.type):
format_string += "("; flush()
self.polymorphic_print([self.append(ir.GetAttr(value, index))
for index in range(len(value.type.elts))],
separator=", ")
format_string += ")"
elif types.is_function(value.type):
format_string += "<closure %p(%p)>"
# We're relying on the internal layout of the closure here.
args.append(self.append(ir.GetAttr(value, 0)))
args.append(self.append(ir.GetAttr(value, 1)))
elif builtins.is_none(value.type):
format_string += "None"
elif builtins.is_bool(value.type):
format_string += "%s"
args.append(self.append(ir.Select(value,
ir.Constant("True", builtins.TStr()),
ir.Constant("False", builtins.TStr()))))
elif builtins.is_int(value.type):
format_string += "%d"
args.append(value)
elif builtins.is_float(value.type):
format_string += "%g"
args.append(value)
elif builtins.is_str(value.type):
format_string += "%s"
args.append(value)
elif builtins.is_list(value.type):
format_string += "["; flush()
length = self.iterable_len(value)
last = self.append(ir.Arith(ast.Sub(loc=None), length, ir.Constant(1, length.type)))
def body_gen(index):
elt = self.iterable_get(value, index)
self.polymorphic_print([elt], separator="")
is_last = self.append(ir.Compare(ast.Lt(loc=None), index, last))
head = self.current_block
if_last = self.current_block = self.add_block()
self.append(ir.Builtin("printf",
[ir.Constant(", ", builtins.TStr())], builtins.TNone()))
tail = self.current_block = self.add_block()
if_last.append(ir.Branch(tail))
head.append(ir.BranchIf(is_last, if_last, tail))
return self.append(ir.Arith(ast.Add(loc=None), index,
ir.Constant(1, length.type)))
self._make_loop(ir.Constant(0, length.type),
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)),
body_gen)
format_string += "]"
elif builtins.is_range(value.type):
format_string += "range("; flush()
start = self.append(ir.GetAttr(value, "start"))
stop = self.append(ir.GetAttr(value, "stop"))
step = self.append(ir.GetAttr(value, "step"))
self.polymorphic_print([start, stop, step], separator=", ")
format_string += ")"
elif builtins.is_exception(value.type):
# TODO: print exceptions
assert False
else:
assert False
format_string += suffix
flush()

View File

@ -637,6 +637,19 @@ class Inferencer(algorithm.Visitor):
arg.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "print"):
valid_forms = lambda: [
valid_form("print(args...) -> None"),
]
self._unify(node.type, builtins.TNone(),
node.loc, None)
if len(node.keywords) == 0:
# We can print any arguments.
pass
else:
diagnose(valid_forms())
# TODO: add when it is clear what interface syscall() has
# elif types.is_builtin(typ, "syscall"):
# valid_Forms = lambda: [

View File

@ -90,6 +90,25 @@ class LLVMIRGenerator:
else:
assert False
def llbuiltin(self, name):
llfun = self.llmodule.get_global(name)
if llfun is not None:
return llfun
if name in ("llvm.abort", "llvm.donothing"):
llty = ll.FunctionType(ll.VoidType(), [])
elif name == "llvm.round.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
elif name == "llvm.pow.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()])
elif name == "llvm.powi.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)])
elif name == "printf":
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True)
else:
assert False
return ll.Function(self.llmodule, llty, name)
def map(self, value):
if isinstance(value, (ir.Argument, ir.Instruction, ir.BasicBlock)):
return self.llmap[value]
@ -214,7 +233,7 @@ class LLVMIRGenerator:
def process_GetAttr(self, insn):
if types.is_tuple(insn.object().type):
return self.llbuilder.extract_value(self.map(insn.object()), self.attr_index(insn),
return self.llbuilder.extract_value(self.map(insn.object()), insn.attr,
name=insn.name)
elif not builtins.is_allocated(insn.object().type):
return self.llbuilder.extract_value(self.map(insn.object()), self.attr_index(insn),
@ -296,9 +315,7 @@ class LLVMIRGenerator:
elif isinstance(insn.op, ast.FloorDiv):
if builtins.is_float(insn.type):
llvalue = self.llbuilder.fdiv(self.map(insn.lhs()), self.map(insn.rhs()))
llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
llfn = ll.Function(self.llmodule, llfnty, "llvm.round.f64")
return self.llbuilder.call(llfn, [llvalue],
return self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llvalue],
name=insn.name)
else:
return self.llbuilder.sdiv(self.map(insn.lhs()), self.map(insn.rhs()),
@ -312,15 +329,13 @@ class LLVMIRGenerator:
name=insn.name)
elif isinstance(insn.op, ast.Pow):
if builtins.is_float(insn.type):
llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()])
llfn = ll.Function(self.llmodule, llfnty, "llvm.pow.f64")
return self.llbuilder.call(llfn, [self.map(insn.lhs()), self.map(insn.rhs())],
return self.llbuilder.call(self.llbuiltin("llvm.pow.f64"),
[self.map(insn.lhs()), self.map(insn.rhs())],
name=insn.name)
else:
llrhs = self.llbuilder.trunc(self.map(insn.rhs()), ll.IntType(32))
llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)])
llfn = ll.Function(self.llmodule, llfnty, "llvm.powi.f64")
llvalue = self.llbuilder.call(llfn, [self.map(insn.lhs()), llrhs])
llvalue = self.llbuilder.call(self.llbuiltin("llvm.powi.f64"),
[self.map(insn.lhs()), llrhs])
return self.llbuilder.fptosi(llvalue, self.llty_of_type(insn.type),
name=insn.name)
elif isinstance(insn.op, ast.LShift):
@ -366,8 +381,7 @@ class LLVMIRGenerator:
def process_Builtin(self, insn):
if insn.op == "nop":
fn = ll.Function(self.llmodule, ll.FunctionType(ll.VoidType(), []), "llvm.donothing")
return self.llbuilder.call(fn, [])
return self.llbuilder.call(self.llbuiltin("llvm.donothing"), [])
elif insn.op == "unwrap":
optarg, default = map(self.map, insn.operands)
has_arg = self.llbuilder.extract_value(optarg, 0)
@ -375,9 +389,7 @@ class LLVMIRGenerator:
return self.llbuilder.select(has_arg, arg, default,
name=insn.name)
elif insn.op == "round":
llfnty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
llfn = ll.Function(self.llmodule, llfnty, "llvm.round.f64")
return self.llbuilder.call(llfn, [llvalue],
return self.llbuilder.call(self.llbuiltin("llvm.round.f64"), [llvalue],
name=insn.name)
elif insn.op == "globalenv":
def get_outer(llenv, env_ty):
@ -394,6 +406,11 @@ class LLVMIRGenerator:
elif insn.op == "len":
lst, = insn.operands
return self.llbuilder.extract_value(self.map(lst), 0)
elif insn.op == "printf":
# We only get integers, floats, pointers and strings here.
llargs = map(self.map, insn.operands)
return self.llbuilder.call(self.llbuiltin("printf"), llargs,
name=insn.name)
# elif insn.op == "exncast":
else:
assert False
@ -414,8 +431,8 @@ class LLVMIRGenerator:
name=insn.name)
def process_Select(self, insn):
return self.llbuilder.select(self.map(insn.cond()),
self.map(insn.lhs()), self.map(insn.rhs()))
return self.llbuilder.select(self.map(insn.condition()),
self.map(insn.if_true()), self.map(insn.if_false()))
def process_Branch(self, insn):
return self.llbuilder.branch(self.map(insn.target()))
@ -438,9 +455,7 @@ class LLVMIRGenerator:
def process_Raise(self, insn):
# TODO: hack before EH is working
llfnty = ll.FunctionType(ll.VoidType(), [])
llfn = ll.Function(self.llmodule, llfnty, "llvm.abort")
llinsn = self.llbuilder.call(llfn, [],
llinsn = self.llbuilder.call(self.llbuiltin("llvm.abort"), [],
name=insn.name)
self.llbuilder.unreachable()
return llinsn