wrpll: encode thls program

This commit is contained in:
Sebastien Bourdeauducq 2019-07-09 17:56:14 +08:00
parent 5f461d08cd
commit 34222b3f38
1 changed files with 73 additions and 24 deletions

View File

@ -29,35 +29,28 @@ class Isn:
class NopIsn(Isn): class NopIsn(Isn):
pass opcode = 0
class AddIsn(Isn): class AddIsn(Isn):
pass opcode = 1
class SubIsn(Isn): class SubIsn(Isn):
pass opcode = 2
class MulIsn(Isn): class MulIsn(Isn):
pass opcode = 3
class ShiftIsn(Isn): class ShiftIsn(Isn):
pass opcode = 4
class CopyIsn(Isn): class CopyIsn(Isn):
pass opcode = 5
class InputIsn(Isn): class InputIsn(Isn):
pass opcode = 6
class OutputIsn(Isn): class OutputIsn(Isn):
pass opcode = 7
class ASTCompiler: class ASTCompiler:
@ -126,8 +119,13 @@ class ASTCompiler:
class Processor: class Processor:
def __init__(self, multiplier_stages=2): def __init__(self, data_width=32, multiplier_stages=2):
self.data_width = data_width
self.multiplier_stages = multiplier_stages self.multiplier_stages = multiplier_stages
self.program_rom_size = None
self.data_ram_size = None
self.opcode_bits = 3
self.reg_bits = None
def get_instruction_latency(self, isn): def get_instruction_latency(self, isn):
return { return {
@ -139,6 +137,29 @@ class Processor:
InputIsn: 1 InputIsn: 1
}[isn.__class__] }[isn.__class__]
def encode_instruction(self, isn, exit):
opcode = isn.opcode
if isn.immediate is not None:
r0 = isn.immediate
if len(isn.inputs) >= 1:
r1 = isn.inputs[0]
else:
r1 = 0
else:
if len(isn.inputs) >= 1:
r0 = isn.inputs[0]
else:
r0 = 0
if len(isn.inputs) >= 2:
r1 = isn.inputs[1]
else:
r1 = 0
r = 0
for value, bits in ((exit, self.reg_bits), (r1, self.reg_bits), (r0, self.reg_bits), (opcode, self.opcode_bits)):
r <<= bits
r |= value
return r
class Scheduler: class Scheduler:
def __init__(self, processor, reserved_data, program): def __init__(self, processor, reserved_data, program):
@ -218,15 +239,36 @@ class Scheduler:
self.output += [NopIsn()]*(max(self.exits.keys()) - len(self.output) + 1) self.output += [NopIsn()]*(max(self.exits.keys()) - len(self.output) + 1)
return self.output return self.output
class CompiledProgram:
def __init__(self, processor, program, exits, data, glbs):
self.processor = processor
self.program = program
self.exits = exits
self.data = data
self.globals = glbs
def pretty_print(self): def pretty_print(self):
for cycle, isn in enumerate(self.output): for cycle, isn in enumerate(self.program):
l = "{:4d} {:15}".format(cycle, str(isn)) l = "{:4d} {:15}".format(cycle, str(isn))
if cycle in self.exits: if cycle in self.exits:
l += " -> r{}".format(self.exits[cycle][1]) l += " -> r{}".format(self.exits[cycle])
print(l) print(l)
def dimension_memories(self):
self.processor.program_rom_size = len(self.program)
self.processor.data_ram_size = len(self.data)
self.processor.reg_bits = (self.processor.data_ram_size - 1).bit_length()
def compile(function): def encode(self):
r = []
for i, isn in enumerate(self.program):
exit = self.exits.get(i, 0)
r.append(self.processor.encode_instruction(isn, exit))
return r
def compile(processor, function):
node = ast.parse(inspect.getsource(function)) node = ast.parse(inspect.getsource(function))
assert isinstance(node, ast.Module) assert isinstance(node, ast.Module)
assert len(node.body) == 1 assert len(node.body) == 1
@ -244,12 +286,16 @@ def compile(function):
arg_r = astcompiler.input(arg) arg_r = astcompiler.input(arg)
for node in body: for node in body:
astcompiler.emit(node) astcompiler.emit(node)
print(astcompiler.data)
print(astcompiler.program)
scheduler = Scheduler(Processor(), len(astcompiler.data), astcompiler.program) scheduler = Scheduler(processor, len(astcompiler.data), astcompiler.program)
scheduler.schedule() scheduler.schedule()
scheduler.pretty_print()
return CompiledProgram(
processor=processor,
program=scheduler.output,
exits={k: v[1] for k,v in scheduler.exits.items()},
data=astcompiler.data,
glbs=astcompiler.globals)
a = 0 a = 0
@ -264,4 +310,7 @@ def foo(x):
return 4748*a + 259*b - 155*c return 4748*a + 259*b - 155*c
compile(foo) cp = compile(Processor(), foo)
cp.pretty_print()
cp.dimension_memories()
print(cp.encode())