forked from M-Labs/artiq
wrpll: add simple thls compiler
This commit is contained in:
parent
e4fff390a8
commit
5f461d08cd
|
@ -0,0 +1,267 @@
|
||||||
|
import inspect
|
||||||
|
import ast
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
|
||||||
|
class Isn:
|
||||||
|
def __init__(self, immediate=None, inputs=None, outputs=None):
|
||||||
|
if inputs is None:
|
||||||
|
inputs = []
|
||||||
|
if outputs is None:
|
||||||
|
outputs = []
|
||||||
|
self.immediate = immediate
|
||||||
|
self.inputs = inputs
|
||||||
|
self.outputs = outputs
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
r = "<"
|
||||||
|
r += self.__class__.__name__
|
||||||
|
if self.immediate is not None:
|
||||||
|
r += " (" + str(self.immediate) + ")"
|
||||||
|
for inp in self.inputs:
|
||||||
|
r += " r" + str(inp)
|
||||||
|
if self.outputs:
|
||||||
|
r += " ->"
|
||||||
|
for outp in self.outputs:
|
||||||
|
r += " r" + str(outp)
|
||||||
|
r += ">"
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
class NopIsn(Isn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AddIsn(Isn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SubIsn(Isn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MulIsn(Isn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ShiftIsn(Isn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CopyIsn(Isn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InputIsn(Isn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OutputIsn(Isn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ASTCompiler:
|
||||||
|
def __init__(self):
|
||||||
|
self.program = []
|
||||||
|
self.data = []
|
||||||
|
self.next_ssa_reg = -1
|
||||||
|
self.constants = dict()
|
||||||
|
self.names = dict()
|
||||||
|
self.globals = dict()
|
||||||
|
|
||||||
|
def get_ssa_reg(self):
|
||||||
|
r = self.next_ssa_reg
|
||||||
|
self.next_ssa_reg -= 1
|
||||||
|
return r
|
||||||
|
|
||||||
|
def add_global(self, name):
|
||||||
|
r = len(self.data)
|
||||||
|
self.data.append(0)
|
||||||
|
self.names[name] = r
|
||||||
|
self.globals[name] = r
|
||||||
|
return r
|
||||||
|
|
||||||
|
def input(self, name):
|
||||||
|
target = self.get_ssa_reg()
|
||||||
|
self.program.append(InputIsn(outputs=[target]))
|
||||||
|
self.names[name] = target
|
||||||
|
|
||||||
|
def emit(self, node):
|
||||||
|
if isinstance(node, ast.BinOp):
|
||||||
|
left = self.emit(node.left)
|
||||||
|
right = self.emit(node.right)
|
||||||
|
if isinstance(node.op, ast.Add):
|
||||||
|
cls = AddIsn
|
||||||
|
elif isinstance(node.op, ast.Sub):
|
||||||
|
cls = SubIsn
|
||||||
|
elif isinstance(node.op, ast.Mult):
|
||||||
|
cls = MulIsn
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
output = self.get_ssa_reg()
|
||||||
|
self.program.append(cls(inputs=[left, right], outputs=[output]))
|
||||||
|
return output
|
||||||
|
elif isinstance(node, ast.Num):
|
||||||
|
if node.n in self.constants:
|
||||||
|
return self.constants[node.n]
|
||||||
|
else:
|
||||||
|
r = len(self.data)
|
||||||
|
self.data.append(node.n)
|
||||||
|
self.constants[node.n] = r
|
||||||
|
return r
|
||||||
|
elif isinstance(node, ast.Name):
|
||||||
|
return self.names[node.id]
|
||||||
|
elif isinstance(node, ast.Assign):
|
||||||
|
output = self.emit(node.value)
|
||||||
|
for target in node.targets:
|
||||||
|
assert isinstance(target, ast.Name)
|
||||||
|
self.names[target.id] = output
|
||||||
|
elif isinstance(node, ast.Return):
|
||||||
|
value = self.emit(node.value)
|
||||||
|
self.program.append(OutputIsn(inputs=[value]))
|
||||||
|
elif isinstance(node, ast.Global):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class Processor:
|
||||||
|
def __init__(self, multiplier_stages=2):
|
||||||
|
self.multiplier_stages = multiplier_stages
|
||||||
|
|
||||||
|
def get_instruction_latency(self, isn):
|
||||||
|
return {
|
||||||
|
AddIsn: 2,
|
||||||
|
SubIsn: 2,
|
||||||
|
MulIsn: 1 + self.multiplier_stages,
|
||||||
|
ShiftIsn: 2,
|
||||||
|
CopyIsn: 1,
|
||||||
|
InputIsn: 1
|
||||||
|
}[isn.__class__]
|
||||||
|
|
||||||
|
|
||||||
|
class Scheduler:
|
||||||
|
def __init__(self, processor, reserved_data, program):
|
||||||
|
self.processor = processor
|
||||||
|
self.reserved_data = reserved_data
|
||||||
|
self.used_registers = set(range(self.reserved_data))
|
||||||
|
self.exits = dict()
|
||||||
|
self.program = program
|
||||||
|
self.remaining = copy(program)
|
||||||
|
self.output = []
|
||||||
|
|
||||||
|
def allocate_register(self):
|
||||||
|
r = min(set(range(max(self.used_registers) + 2)) - self.used_registers)
|
||||||
|
self.used_registers.add(r)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def free_register(self, r):
|
||||||
|
assert r >= self.reserved_data
|
||||||
|
self.used_registers.discard(r)
|
||||||
|
|
||||||
|
def find_inputs(self, cycle, isn):
|
||||||
|
mapped_inputs = []
|
||||||
|
for inp in isn.inputs:
|
||||||
|
if inp >= 0:
|
||||||
|
mapped_inputs.append(inp)
|
||||||
|
else:
|
||||||
|
found = False
|
||||||
|
for i in range(cycle):
|
||||||
|
if i in self.exits:
|
||||||
|
r, rm = self.exits[i]
|
||||||
|
if r == inp:
|
||||||
|
mapped_inputs.append(rm)
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
if not found:
|
||||||
|
return None
|
||||||
|
return mapped_inputs
|
||||||
|
|
||||||
|
def schedule_one(self, isn):
|
||||||
|
cycle = len(self.output)
|
||||||
|
mapped_inputs = self.find_inputs(cycle, isn)
|
||||||
|
if mapped_inputs is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if isn.outputs:
|
||||||
|
latency = self.processor.get_instruction_latency(isn)
|
||||||
|
exit = cycle + latency
|
||||||
|
if exit in self.exits:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Instruction can be scheduled
|
||||||
|
|
||||||
|
self.remaining.remove(isn)
|
||||||
|
|
||||||
|
for inp, minp in zip(isn.inputs, mapped_inputs):
|
||||||
|
can_free = inp < 0 and all(inp != rinp for risn in self.remaining for rinp in risn.inputs)
|
||||||
|
if can_free:
|
||||||
|
self.free_register(minp)
|
||||||
|
|
||||||
|
if isn.outputs:
|
||||||
|
assert len(isn.outputs) == 1
|
||||||
|
output = self.allocate_register()
|
||||||
|
self.exits[exit] = (isn.outputs[0], output)
|
||||||
|
self.output.append(isn.__class__(immediate=isn.immediate, inputs=mapped_inputs))
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def schedule(self):
|
||||||
|
while self.remaining:
|
||||||
|
success = False
|
||||||
|
for isn in self.remaining:
|
||||||
|
if self.schedule_one(isn):
|
||||||
|
success = True
|
||||||
|
break
|
||||||
|
if not success:
|
||||||
|
self.output.append(NopIsn())
|
||||||
|
self.output += [NopIsn()]*(max(self.exits.keys()) - len(self.output) + 1)
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
def pretty_print(self):
|
||||||
|
for cycle, isn in enumerate(self.output):
|
||||||
|
l = "{:4d} {:15}".format(cycle, str(isn))
|
||||||
|
if cycle in self.exits:
|
||||||
|
l += " -> r{}".format(self.exits[cycle][1])
|
||||||
|
print(l)
|
||||||
|
|
||||||
|
|
||||||
|
def compile(function):
|
||||||
|
node = ast.parse(inspect.getsource(function))
|
||||||
|
assert isinstance(node, ast.Module)
|
||||||
|
assert len(node.body) == 1
|
||||||
|
node = node.body[0]
|
||||||
|
assert isinstance(node, ast.FunctionDef)
|
||||||
|
assert len(node.args.args) == 1
|
||||||
|
arg = node.args.args[0].arg
|
||||||
|
body = node.body
|
||||||
|
|
||||||
|
astcompiler = ASTCompiler()
|
||||||
|
for node in body:
|
||||||
|
if isinstance(node, ast.Global):
|
||||||
|
for name in node.names:
|
||||||
|
astcompiler.add_global(name)
|
||||||
|
arg_r = astcompiler.input(arg)
|
||||||
|
for node in body:
|
||||||
|
astcompiler.emit(node)
|
||||||
|
print(astcompiler.data)
|
||||||
|
print(astcompiler.program)
|
||||||
|
|
||||||
|
scheduler = Scheduler(Processor(), len(astcompiler.data), astcompiler.program)
|
||||||
|
scheduler.schedule()
|
||||||
|
scheduler.pretty_print()
|
||||||
|
|
||||||
|
|
||||||
|
a = 0
|
||||||
|
b = 0
|
||||||
|
c = 0
|
||||||
|
|
||||||
|
def foo(x):
|
||||||
|
global a, b, c
|
||||||
|
c = b
|
||||||
|
b = a
|
||||||
|
a = x
|
||||||
|
return 4748*a + 259*b - 155*c
|
||||||
|
|
||||||
|
|
||||||
|
compile(foo)
|
Loading…
Reference in New Issue