forked from M-Labs/artiq
1
0
Fork 0

wrpll/thls: simple simulation demo

This commit is contained in:
Sebastien Bourdeauducq 2019-07-20 18:50:57 +08:00
parent 831b3514d3
commit 623446f82c
1 changed files with 187 additions and 13 deletions

View File

@ -1,6 +1,10 @@
import inspect
import ast
from copy import copy
import operator
from functools import reduce
from migen import *
class Isn:
@ -40,17 +44,14 @@ class SubIsn(Isn):
class MulIsn(Isn):
opcode = 3
class ShiftIsn(Isn):
class CopyIsn(Isn):
opcode = 4
class CopyIsn(Isn):
class InputIsn(Isn):
opcode = 5
class InputIsn(Isn):
opcode = 6
class OutputIsn(Isn):
opcode = 7
opcode = 6
class ASTCompiler:
@ -132,7 +133,6 @@ class Processor:
AddIsn: 2,
SubIsn: 2,
MulIsn: 1 + self.multiplier_stages,
ShiftIsn: 2,
CopyIsn: 1,
InputIsn: 1
}[isn.__class__]
@ -160,6 +160,12 @@ class Processor:
r |= value
return r
def instruction_bits(self):
return 3*self.reg_bits + self.opcode_bits
def implement(self, program, data):
return ProcessorImpl(self, program, data)
class Scheduler:
def __init__(self, processor, reserved_data, program):
@ -292,14 +298,167 @@ def compile(processor, function):
scheduler = Scheduler(processor, len(astcompiler.data), astcompiler.program)
scheduler.schedule()
max_reg = max(max(max(isn.inputs + [0]) for isn in scheduler.output), max(v[1] for k, v in scheduler.exits.items()))
return CompiledProgram(
processor=processor,
program=scheduler.output,
exits={k: v[1] for k,v in scheduler.exits.items()},
data=astcompiler.data,
exits={k: v[1] for k, v in scheduler.exits.items()},
data=astcompiler.data + [0]*(max_reg - len(astcompiler.data) + 1),
glbs=astcompiler.globals)
class BaseUnit(Module):
def __init__(self, data_width):
self.stb_i = Signal()
self.i0 = Signal(data_width)
self.i1 = Signal(data_width)
self.stb_o = Signal()
self.o = Signal(data_width)
class NopUnit(BaseUnit):
pass
class OpUnit(BaseUnit):
def __init__(self, op, data_width, stages):
BaseUnit.__init__(self, data_width)
o = op(self.i0, self.i1)
stb_o = self.stb_i
for i in range(stages):
n_o = Signal(data_width)
n_stb_o = Signal()
self.sync += [
n_o.eq(o),
n_stb_o.eq(stb_o)
]
o = n_o
stb_o = n_stb_o
self.comb += [
self.o.eq(o),
self.stb_o.eq(stb_o)
]
class CopyUnit(BaseUnit):
def __init__(self, data_width):
BaseUnit.__init__(self, data_width)
self.comb += [
self.stb_o.eq(self.stb_i),
self.o.eq(self.i0)
]
class InputUnit(BaseUnit):
def __init__(self, data_width, input_stb, input):
BaseUnit.__init__(self, data_width)
# TODO
self.comb += [
self.stb_o.eq(self.stb_i),
self.o.eq(42)
]
class OutputUnit(BaseUnit):
def __init__(self, data_width, output_stb, output):
BaseUnit.__init__(self, data_width)
self.sync += [
output_stb.eq(self.stb_i),
output.eq(self.i0)
]
class ProcessorImpl(Module):
def __init__(self, pd, program, data):
self.input_stb = Signal()
self.input = Signal(pd.data_width)
self.output_stb = Signal()
self.output = Signal(pd.data_width)
# # #
program_mem = Memory(pd.instruction_bits(), pd.program_rom_size, init=program)
data_mem0 = Memory(pd.data_width, pd.data_ram_size, init=data)
data_mem1 = Memory(pd.data_width, pd.data_ram_size, init=data)
self.specials += program_mem, data_mem0, data_mem1
pc = Signal(pd.instruction_bits())
pc_next = Signal.like(pc)
pc_en = Signal()
self.sync += pc.eq(pc_next)
self.comb += [
If(pc_en,
pc_next.eq(pc + 1)
).Else(
pc_next.eq(0)
)
]
program_mem_port = program_mem.get_port()
self.specials += program_mem_port
self.comb += program_mem_port.adr.eq(pc_next)
# TODO
self.comb += pc_en.eq(1)
s = 0
opcode = Signal(pd.opcode_bits)
self.comb += opcode.eq(program_mem_port.dat_r[s:s+pd.opcode_bits])
s += pd.opcode_bits
r0 = Signal(pd.reg_bits)
self.comb += r0.eq(program_mem_port.dat_r[s:s+pd.reg_bits])
s += pd.reg_bits
r1 = Signal(pd.reg_bits)
self.comb += r1.eq(program_mem_port.dat_r[s:s+pd.reg_bits])
s += pd.reg_bits
exit = Signal(pd.reg_bits)
self.comb += exit.eq(program_mem_port.dat_r[s:s+pd.reg_bits])
data_read_port0 = data_mem0.get_port()
data_read_port1 = data_mem1.get_port()
self.specials += data_read_port0, data_read_port1
self.comb += [
data_read_port0.adr.eq(r0),
data_read_port1.adr.eq(r1)
]
data_write_port = data_mem0.get_port(write_capable=True)
data_write_port_dup = data_mem1.get_port(write_capable=True)
self.specials += data_write_port, data_write_port_dup
self.comb += [
data_write_port_dup.we.eq(data_write_port.we),
data_write_port_dup.adr.eq(data_write_port.adr),
data_write_port_dup.dat_w.eq(data_write_port.dat_w),
data_write_port.adr.eq(exit)
]
nop = NopUnit(pd.data_width)
adder = OpUnit(operator.add, pd.data_width, 1)
subtractor = OpUnit(operator.sub, pd.data_width, 1)
multiplier = OpUnit(operator.mul, pd.data_width, pd.multiplier_stages)
copier = CopyUnit(pd.data_width)
inu = InputUnit(pd.data_width, self.input_stb, self.input)
outu = OutputUnit(pd.data_width, self.output_stb, self.output)
units = [nop, adder, subtractor, multiplier, copier, inu, outu]
self.submodules += units
for n, unit in enumerate(units):
self.sync += unit.stb_i.eq(opcode == n)
self.comb += [
unit.i0.eq(data_read_port0.dat_r),
unit.i1.eq(data_read_port1.dat_r),
If(unit.stb_o,
data_write_port.we.eq(1),
data_write_port.dat_w.eq(unit.o)
)
]
a = 0
b = 0
c = 0
@ -312,7 +471,22 @@ def foo(x):
return 4748*a + 259*b - 155*c
cp = compile(Processor(), foo)
cp.pretty_print()
cp.dimension_memories()
print(cp.encode())
def simple_test(x):
a = 5 + 3
return a*4
if __name__ == "__main__":
proc = Processor()
cp = compile(proc, simple_test)
cp.pretty_print()
cp.dimension_memories()
print(cp.encode())
proc_impl = proc.implement(cp.encode(), cp.data)
def wait_result():
while not (yield proc_impl.output_stb):
yield
result = yield proc_impl.output
print(result)
run_simulation(proc_impl, [wait_result()], vcd_name="test.vcd")