forked from M-Labs/artiq
1
0
Fork 0

Add basic IR generator.

This commit is contained in:
whitequark 2015-07-14 06:44:16 +03:00
parent f417ef31a4
commit bdcb24108b
9 changed files with 332 additions and 10 deletions

View File

@ -8,7 +8,7 @@ from . import types, builtins
# Generic SSA IR classes # Generic SSA IR classes
def escape_name(name): def escape_name(name):
if all([isalnum(x) or x == "." for x in name]): if all([str.isalnum(x) or x == "." for x in name]):
return name return name
else: else:
return "\"{}\"".format(name.replace("\"", "\\\"")) return "\"{}\"".format(name.replace("\"", "\\\""))
@ -36,6 +36,24 @@ class Value:
for user in self.uses: for user in self.uses:
user.replace_uses_of(self, value) user.replace_uses_of(self, value)
class Constant(Value):
"""
A constant value.
:ivar value: (None, True or False) value
"""
def __init__(self, value, typ):
super().__init__(typ)
self.value = value
def as_operand(self):
return str(self)
def __str__(self):
return "{} {}".format(types.TypePrinter().name(self.type),
repr(self.value))
class NamedValue(Value): class NamedValue(Value):
""" """
An SSA value that has a name. An SSA value that has a name.
@ -150,10 +168,10 @@ class Instruction(User):
types.TypePrinter().name(self.type)) types.TypePrinter().name(self.type))
if any(self.operands): if any(self.operands):
return "{} {} {}".format(prefix, self.opcode(), return "{}{} {}".format(prefix, self.opcode(),
", ".join([operand.as_operand() for operand in self.operands])) ", ".join([operand.as_operand() for operand in self.operands]))
else: else:
return "{} {}".format(prefix, self.opcode()) return "{}{}".format(prefix, self.opcode())
class Phi(Instruction): class Phi(Instruction):
""" """
@ -201,7 +219,7 @@ class Phi(Instruction):
if any(self.operands): if any(self.operands):
operand_list = ["%{} => %{}".format(escape_name(block.name), escape_name(value.name)) operand_list = ["%{} => %{}".format(escape_name(block.name), escape_name(value.name))
for operand in self.operands] for operand in self.operands]
return "{} {} [{}]".format(prefix, self.opcode(), ", ".join(operand_list)) return "{}{} [{}]".format(prefix, self.opcode(), ", ".join(operand_list))
class Terminator(Instruction): class Terminator(Instruction):
""" """
@ -241,6 +259,7 @@ class BasicBlock(NamedValue):
def append(self, insn): def append(self, insn):
insn.set_basic_block(self) insn.set_basic_block(self)
self.instructions.append(insn) self.instructions.append(insn)
return insn
def index(self, insn): def index(self, insn):
return self.instructions.index(insn) return self.instructions.index(insn)
@ -248,17 +267,22 @@ class BasicBlock(NamedValue):
def insert(self, before, insn): def insert(self, before, insn):
insn.set_basic_block(self) insn.set_basic_block(self)
self.instructions.insert(self.index(before), insn) self.instructions.insert(self.index(before), insn)
return insn
def remove(self, insn): def remove(self, insn):
insn._detach() insn._detach()
self.instructions.remove(insn) self.instructions.remove(insn)
return insn
def replace(self, insn, replacement): def replace(self, insn, replacement):
self.insert(insn, replacement) self.insert(insn, replacement)
self.remove(insn) self.remove(insn)
def is_terminated(self):
return any(self.instructions) and isinstance(self.instructions[-1], Terminator)
def terminator(self): def terminator(self):
assert isinstance(self.instructions[-1], Terminator) assert self.is_terminated()
return self.instructions[-1] return self.instructions[-1]
def successors(self): def successors(self):
@ -271,7 +295,7 @@ class BasicBlock(NamedValue):
def __str__(self): def __str__(self):
lines = ["{}:".format(escape_name(self.name))] lines = ["{}:".format(escape_name(self.name))]
for insn in self.instructions: for insn in self.instructions:
lines.append(str(insn)) lines.append(" " + str(insn))
return "\n".join(lines) return "\n".join(lines)
class Argument(NamedValue): class Argument(NamedValue):
@ -290,7 +314,7 @@ class Function(Value):
def __init__(self, typ, name, arguments): def __init__(self, typ, name, arguments):
self.type, self.name = typ, name self.type, self.name = typ, name
self.arguments = [] self.arguments = []
self.basic_blocks = set() self.basic_blocks = []
self.names = set() self.names = set()
self.set_arguments(arguments) self.set_arguments(arguments)
@ -318,14 +342,14 @@ class Function(Value):
def add(self, basic_block): def add(self, basic_block):
basic_block._set_function(self) basic_block._set_function(self)
self.basic_blocks.add(basic_blocks) self.basic_blocks.append(basic_block)
def remove(self, basic_block): def remove(self, basic_block):
basic_block._detach() basic_block._detach()
self.basic_block.remove(basic_block) self.basic_block.remove(basic_block)
def predecessors_of(self, successor): def predecessors_of(self, successor):
return set(block for block in self.basic_blocks if successor in block.successors()) return [block for block in self.basic_blocks if successor in block.successors()]
def as_operand(self): def as_operand(self):
return "{} @{}".format(types.TypePrinter().name(self.type), return "{} @{}".format(types.TypePrinter().name(self.type),
@ -344,3 +368,65 @@ class Function(Value):
return "\n".join(lines) return "\n".join(lines)
# Python-specific SSA IR classes # Python-specific SSA IR classes
class Branch(Terminator):
"""
An unconditional branch instruction.
"""
"""
:param target: (:class:`BasicBlock`) branch target
"""
def __init__(self, target, name=""):
super().__init__([target], builtins.TNone(), name)
def opcode(self):
return "branch"
class BranchIf(Terminator):
"""
A conditional branch instruction.
"""
"""
:param cond: (:class:`Value`) branch condition
:param if_true: (:class:`BasicBlock`) branch target if expression is truthful
:param if_false: (:class:`BasicBlock`) branch target if expression is falseful
"""
def __init__(self, cond, if_true, if_false, name=""):
super().__init__([cond, if_true, if_false], builtins.TNone(), name)
def opcode(self):
return "branch_if"
class Return(Terminator):
"""
A return instruction.
"""
"""
:param value: (:class:`Value`) return value
"""
def __init__(self, value, name=""):
super().__init__([value], builtins.TNone(), name)
def opcode(self):
return "return"
class Eval(Instruction):
"""
An instruction that evaluates an AST fragment.
"""
"""
:param ast: (:class:`.asttyped.AST`) return value
"""
def __init__(self, ast, name=""):
super().__init__([], ast.type, name)
self.ast = ast
def opcode(self):
return "eval"
def __str__(self):
return super().__str__() + " `{}`".format(self.ast.loc.source())

View File

@ -11,11 +11,14 @@ class Module:
if engine is None: if engine is None:
engine = diagnostic.Engine(all_errors_are_fatal=True) engine = diagnostic.Engine(all_errors_are_fatal=True)
module_name, _ = os.path.splitext(os.path.basename(source_buffer.name))
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine) asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine)
inferencer = transforms.Inferencer(engine=engine) inferencer = transforms.Inferencer(engine=engine)
int_monomorphizer = transforms.IntMonomorphizer(engine=engine) int_monomorphizer = transforms.IntMonomorphizer(engine=engine)
monomorphism_validator = validators.MonomorphismValidator(engine=engine) monomorphism_validator = validators.MonomorphismValidator(engine=engine)
escape_validator = validators.EscapeValidator(engine=engine) escape_validator = validators.EscapeValidator(engine=engine)
ir_generator = transforms.IRGenerator(engine=engine, module_name=module_name)
parsetree, comments = parse_buffer(source_buffer, engine=engine) parsetree, comments = parse_buffer(source_buffer, engine=engine)
typedtree = asttyped_rewriter.visit(parsetree) typedtree = asttyped_rewriter.visit(parsetree)
@ -24,9 +27,11 @@ class Module:
inferencer.visit(typedtree) inferencer.visit(typedtree)
monomorphism_validator.visit(typedtree) monomorphism_validator.visit(typedtree)
escape_validator.visit(typedtree) escape_validator.visit(typedtree)
ir_generator.visit(typedtree)
self.name = os.path.basename(source_buffer.name) self.name = module_name
self.globals = asttyped_rewriter.globals self.globals = asttyped_rewriter.globals
self.ir = ir_generator.functions
@classmethod @classmethod
def from_string(cls, source_string, name="input.py", first_line=1, engine=None): def from_string(cls, source_string, name="input.py", first_line=1, engine=None):

View File

@ -0,0 +1,19 @@
import sys, fileinput
from pythonparser import diagnostic
from .. import Module
def main():
def process_diagnostic(diag):
print("\n".join(diag.render()))
if diag.level in ("fatal", "error"):
exit(1)
engine = diagnostic.Engine()
engine.process = process_diagnostic
mod = Module.from_string("".join(fileinput.input()).expandtabs(), engine=engine)
for fn in mod.ir:
print(fn)
if __name__ == "__main__":
main()

View File

@ -1,3 +1,4 @@
from .asttyped_rewriter import ASTTypedRewriter from .asttyped_rewriter import ASTTypedRewriter
from .inferencer import Inferencer from .inferencer import Inferencer
from .int_monomorphizer import IntMonomorphizer from .int_monomorphizer import IntMonomorphizer
from .ir_generator import IRGenerator

View File

@ -0,0 +1,149 @@
"""
:class:`IRGenerator` transforms typed AST into ARTIQ intermediate
representation.
"""
from collections import OrderedDict
from pythonparser import algorithm, diagnostic, ast
from .. import types, builtins, ir
# We put some effort in keeping generated IR readable,
# i.e. with a more or less linear correspondence to the source.
# This is why basic blocks sometimes seem to be produced in an odd order.
class IRGenerator(algorithm.Visitor):
def __init__(self, module_name, engine):
self.engine = engine
self.functions = []
self.name = [module_name]
self.current_function = None
self.current_block = None
self.break_target, self.continue_target = None, None
def add_block(self):
block = ir.BasicBlock([])
self.current_function.add(block)
return block
def append(self, insn):
return self.current_block.append(insn)
def terminate(self, insn):
if not self.current_block.is_terminated():
self.append(insn)
def visit(self, obj):
if isinstance(obj, list):
for elt in obj:
self.visit(elt)
if self.current_block.is_terminated():
break
elif isinstance(obj, ast.AST):
return self._visit_one(obj)
def visit_function(self, name, typ, inner):
try:
old_name, self.name = self.name, self.name + [name]
args = []
for arg_name in typ.args:
args.append(ir.Argument(typ.args[arg_name], arg_name))
for arg_name in typ.optargs:
args.append(ir.Argument(ir.TSSAOption(typ.optargs[arg_name]), arg_name))
func = ir.Function(typ, ".".join(self.name), args)
self.functions.append(func)
old_func, self.current_function = self.current_function, func
self.current_block = self.add_block()
inner()
finally:
self.name = old_name
self.current_function = old_func
def visit_ModuleT(self, node):
def inner():
self.generic_visit(node)
return_value = ir.Constant(None, builtins.TNone())
self.terminate(ir.Return(return_value))
typ = types.TFunction(OrderedDict(), OrderedDict(), builtins.TNone())
self.visit_function('__modinit__', typ, inner)
def visit_FunctionDefT(self, node):
self.visit_function(node.name, node.signature_type.find(),
lambda: self.generic_visit(node))
def visit_Return(self, node):
if node.value is None:
return_value = ir.Constant(None, builtins.TNone())
self.append(ir.Return(return_value))
else:
expr = self.append(ir.Eval(node.value))
self.append(ir.Return(expr))
def visit_Expr(self, node):
self.append(ir.Eval(node.value))
# Assign
# AugAssign
def visit_If(self, node):
cond = self.append(ir.Eval(node.test))
head = self.current_block
if_true = self.add_block()
self.current_block = if_true
self.visit(node.body)
if_false = self.add_block()
self.current_block = if_false
self.visit(node.orelse)
tail = self.add_block()
self.current_block = tail
if not if_true.is_terminated():
if_true.append(ir.Branch(tail))
if not if_false.is_terminated():
if_false.append(ir.Branch(tail))
head.append(ir.BranchIf(cond, if_true, if_false))
def visit_While(self, node):
try:
head = self.add_block()
self.append(ir.Branch(head))
self.current_block = head
tail_tramp = self.add_block()
old_break, self.break_target = self.break_target, tail_tramp
body = self.add_block()
old_continue, self.continue_target = self.continue_target, body
self.current_block = body
self.visit(node.body)
tail = self.add_block()
self.current_block = tail
self.visit(node.orelse)
cond = head.append(ir.Eval(node.test))
head.append(ir.BranchIf(cond, body, tail))
if not body.is_terminated():
body.append(ir.Branch(tail))
tail_tramp.append(ir.Branch(tail))
finally:
self.break_target = old_break
self.continue_target = old_continue
# For
def visit_Break(self, node):
self.append(ir.Branch(self.break_target))
def visit_Continue(self, node):
self.append(ir.Branch(self.continue_target))
# Raise
# Try
# With

View File

@ -0,0 +1,7 @@
# RUN: %python -m artiq.compiler.testbench.irgen %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: NoneType input.__modinit__() {
# CHECK-L: 1:
# CHECK-L: return NoneType None
# CHECK-L: }

View File

@ -0,0 +1,9 @@
# RUN: %python -m artiq.compiler.testbench.irgen %s >%t
# RUN: OutputCheck %s --file-to-check=%t
2 + 2
# CHECK-L: NoneType input.__modinit__() {
# CHECK-L: 1:
# CHECK-L: %2 = int(width=32) eval `2 + 2`
# CHECK-L: return NoneType None
# CHECK-L: }

View File

@ -0,0 +1,21 @@
# RUN: %python -m artiq.compiler.testbench.irgen %s >%t
# RUN: OutputCheck %s --file-to-check=%t
if 1:
2
else:
3
# CHECK-L: NoneType input.__modinit__() {
# CHECK-L: 1:
# CHECK-L: %2 = int(width=32) eval `1`
# CHECK-L: branch_if int(width=32) %2, ssa.basic_block %3, ssa.basic_block %5
# CHECK-L: 3:
# CHECK-L: %4 = int(width=32) eval `2`
# CHECK-L: branch ssa.basic_block %7
# CHECK-L: 5:
# CHECK-L: %6 = int(width=32) eval `3`
# CHECK-L: branch ssa.basic_block %7
# CHECK-L: 7:
# CHECK-L: return NoneType None
# CHECK-L: }

View File

@ -0,0 +1,25 @@
# RUN: %python -m artiq.compiler.testbench.irgen %s >%t
# RUN: OutputCheck %s --file-to-check=%t
while 1:
2
else:
3
4
# CHECK-L: NoneType input.__modinit__() {
# CHECK-L: 1:
# CHECK-L: branch ssa.basic_block %2
# CHECK-L: 2:
# CHECK-L: %9 = int(width=32) eval `1`
# CHECK-L: branch_if int(width=32) %9, ssa.basic_block %5, ssa.basic_block %7
# CHECK-L: 4:
# CHECK-L: branch ssa.basic_block %7
# CHECK-L: 5:
# CHECK-L: %6 = int(width=32) eval `2`
# CHECK-L: branch ssa.basic_block %7
# CHECK-L: 7:
# CHECK-L: %8 = int(width=32) eval `3`
# CHECK-L: %13 = int(width=32) eval `4`
# CHECK-L: return NoneType None
# CHECK-L: }