diff --git a/artiq/gateware/wrpll/thls.py b/artiq/gateware/wrpll/thls.py index dc95e343d..99ffbc17b 100644 --- a/artiq/gateware/wrpll/thls.py +++ b/artiq/gateware/wrpll/thls.py @@ -1,6 +1,10 @@ import inspect import ast from copy import copy +import operator +from functools import reduce + +from migen import * class Isn: @@ -40,17 +44,14 @@ class SubIsn(Isn): class MulIsn(Isn): opcode = 3 -class ShiftIsn(Isn): +class CopyIsn(Isn): opcode = 4 -class CopyIsn(Isn): +class InputIsn(Isn): opcode = 5 -class InputIsn(Isn): - opcode = 6 - class OutputIsn(Isn): - opcode = 7 + opcode = 6 class ASTCompiler: @@ -132,7 +133,6 @@ class Processor: AddIsn: 2, SubIsn: 2, MulIsn: 1 + self.multiplier_stages, - ShiftIsn: 2, CopyIsn: 1, InputIsn: 1 }[isn.__class__] @@ -160,6 +160,12 @@ class Processor: r |= value return r + def instruction_bits(self): + return 3*self.reg_bits + self.opcode_bits + + def implement(self, program, data): + return ProcessorImpl(self, program, data) + class Scheduler: def __init__(self, processor, reserved_data, program): @@ -292,14 +298,167 @@ def compile(processor, function): scheduler = Scheduler(processor, len(astcompiler.data), astcompiler.program) scheduler.schedule() + max_reg = max(max(max(isn.inputs + [0]) for isn in scheduler.output), max(v[1] for k, v in scheduler.exits.items())) + return CompiledProgram( processor=processor, program=scheduler.output, - exits={k: v[1] for k,v in scheduler.exits.items()}, - data=astcompiler.data, + exits={k: v[1] for k, v in scheduler.exits.items()}, + data=astcompiler.data + [0]*(max_reg - len(astcompiler.data) + 1), glbs=astcompiler.globals) +class BaseUnit(Module): + def __init__(self, data_width): + self.stb_i = Signal() + self.i0 = Signal(data_width) + self.i1 = Signal(data_width) + self.stb_o = Signal() + self.o = Signal(data_width) + + +class NopUnit(BaseUnit): + pass + + +class OpUnit(BaseUnit): + def __init__(self, op, data_width, stages): + BaseUnit.__init__(self, data_width) + + o = op(self.i0, self.i1) + stb_o = self.stb_i + for i in range(stages): + n_o = Signal(data_width) + n_stb_o = Signal() + self.sync += [ + n_o.eq(o), + n_stb_o.eq(stb_o) + ] + o = n_o + stb_o = n_stb_o + self.comb += [ + self.o.eq(o), + self.stb_o.eq(stb_o) + ] + + +class CopyUnit(BaseUnit): + def __init__(self, data_width): + BaseUnit.__init__(self, data_width) + + self.comb += [ + self.stb_o.eq(self.stb_i), + self.o.eq(self.i0) + ] + + +class InputUnit(BaseUnit): + def __init__(self, data_width, input_stb, input): + BaseUnit.__init__(self, data_width) + + # TODO + self.comb += [ + self.stb_o.eq(self.stb_i), + self.o.eq(42) + ] + + +class OutputUnit(BaseUnit): + def __init__(self, data_width, output_stb, output): + BaseUnit.__init__(self, data_width) + + self.sync += [ + output_stb.eq(self.stb_i), + output.eq(self.i0) + ] + + +class ProcessorImpl(Module): + def __init__(self, pd, program, data): + self.input_stb = Signal() + self.input = Signal(pd.data_width) + + self.output_stb = Signal() + self.output = Signal(pd.data_width) + + # # # + + program_mem = Memory(pd.instruction_bits(), pd.program_rom_size, init=program) + data_mem0 = Memory(pd.data_width, pd.data_ram_size, init=data) + data_mem1 = Memory(pd.data_width, pd.data_ram_size, init=data) + self.specials += program_mem, data_mem0, data_mem1 + + pc = Signal(pd.instruction_bits()) + pc_next = Signal.like(pc) + pc_en = Signal() + self.sync += pc.eq(pc_next) + self.comb += [ + If(pc_en, + pc_next.eq(pc + 1) + ).Else( + pc_next.eq(0) + ) + ] + program_mem_port = program_mem.get_port() + self.specials += program_mem_port + self.comb += program_mem_port.adr.eq(pc_next) + + # TODO + self.comb += pc_en.eq(1) + + s = 0 + opcode = Signal(pd.opcode_bits) + self.comb += opcode.eq(program_mem_port.dat_r[s:s+pd.opcode_bits]) + s += pd.opcode_bits + r0 = Signal(pd.reg_bits) + self.comb += r0.eq(program_mem_port.dat_r[s:s+pd.reg_bits]) + s += pd.reg_bits + r1 = Signal(pd.reg_bits) + self.comb += r1.eq(program_mem_port.dat_r[s:s+pd.reg_bits]) + s += pd.reg_bits + exit = Signal(pd.reg_bits) + self.comb += exit.eq(program_mem_port.dat_r[s:s+pd.reg_bits]) + + data_read_port0 = data_mem0.get_port() + data_read_port1 = data_mem1.get_port() + self.specials += data_read_port0, data_read_port1 + self.comb += [ + data_read_port0.adr.eq(r0), + data_read_port1.adr.eq(r1) + ] + + data_write_port = data_mem0.get_port(write_capable=True) + data_write_port_dup = data_mem1.get_port(write_capable=True) + self.specials += data_write_port, data_write_port_dup + self.comb += [ + data_write_port_dup.we.eq(data_write_port.we), + data_write_port_dup.adr.eq(data_write_port.adr), + data_write_port_dup.dat_w.eq(data_write_port.dat_w), + data_write_port.adr.eq(exit) + ] + + nop = NopUnit(pd.data_width) + adder = OpUnit(operator.add, pd.data_width, 1) + subtractor = OpUnit(operator.sub, pd.data_width, 1) + multiplier = OpUnit(operator.mul, pd.data_width, pd.multiplier_stages) + copier = CopyUnit(pd.data_width) + inu = InputUnit(pd.data_width, self.input_stb, self.input) + outu = OutputUnit(pd.data_width, self.output_stb, self.output) + units = [nop, adder, subtractor, multiplier, copier, inu, outu] + self.submodules += units + + for n, unit in enumerate(units): + self.sync += unit.stb_i.eq(opcode == n) + self.comb += [ + unit.i0.eq(data_read_port0.dat_r), + unit.i1.eq(data_read_port1.dat_r), + If(unit.stb_o, + data_write_port.we.eq(1), + data_write_port.dat_w.eq(unit.o) + ) + ] + + a = 0 b = 0 c = 0 @@ -312,7 +471,22 @@ def foo(x): return 4748*a + 259*b - 155*c -cp = compile(Processor(), foo) -cp.pretty_print() -cp.dimension_memories() -print(cp.encode()) +def simple_test(x): + a = 5 + 3 + return a*4 + + +if __name__ == "__main__": + proc = Processor() + cp = compile(proc, simple_test) + cp.pretty_print() + cp.dimension_memories() + print(cp.encode()) + proc_impl = proc.implement(cp.encode(), cp.data) + + def wait_result(): + while not (yield proc_impl.output_stb): + yield + result = yield proc_impl.output + print(result) + run_simulation(proc_impl, [wait_result()], vcd_name="test.vcd")