forked from M-Labs/artiq
1
0
Fork 0

wrpll/thls: support mulshift

This commit is contained in:
Sebastien Bourdeauducq 2019-08-15 15:07:13 +08:00
parent f861459ace
commit 2776c5b16b
1 changed files with 36 additions and 17 deletions

View File

@ -42,7 +42,7 @@ class AddIsn(Isn):
class SubIsn(Isn): class SubIsn(Isn):
opcode = 2 opcode = 2
class MulIsn(Isn): class MulShiftIsn(Isn):
opcode = 3 opcode = 3
class CopyIsn(Isn): class CopyIsn(Isn):
@ -86,18 +86,27 @@ class ASTCompiler:
def emit(self, node): def emit(self, node):
if isinstance(node, ast.BinOp): if isinstance(node, ast.BinOp):
left = self.emit(node.left) if isinstance(node.op, ast.RShift):
right = self.emit(node.right) if not isinstance(node.left, ast.BinOp) or not isinstance(node.left.op, ast.Mult):
if isinstance(node.op, ast.Add): raise NotImplementedError
cls = AddIsn if not isinstance(node.right, ast.Num):
elif isinstance(node.op, ast.Sub): raise NotImplementedError
cls = SubIsn left = self.emit(node.left.left)
elif isinstance(node.op, ast.Mult): right = self.emit(node.left.right)
cls = MulIsn cons = lambda **kwargs: MulShiftIsn(immediate=node.right.n, **kwargs)
else: else:
raise NotImplementedError left = self.emit(node.left)
right = self.emit(node.right)
if isinstance(node.op, ast.Add):
cons = AddIsn
elif isinstance(node.op, ast.Sub):
cons = SubIsn
elif isinstance(node.op, ast.Mult):
cons = lambda **kwargs: MulShiftIsn(immediate=0, **kwargs)
else:
raise NotImplementedError
output = self.get_ssa_reg() output = self.get_ssa_reg()
self.program.append(cls(inputs=[left, right], outputs=[output])) self.program.append(cons(inputs=[left, right], outputs=[output]))
return output return output
elif isinstance(node, ast.Num): elif isinstance(node, ast.Num):
if node.n in self.constants: if node.n in self.constants:
@ -127,6 +136,7 @@ class Processor:
def __init__(self, data_width=32, multiplier_stages=2): def __init__(self, data_width=32, multiplier_stages=2):
self.data_width = data_width self.data_width = data_width
self.multiplier_stages = multiplier_stages self.multiplier_stages = multiplier_stages
self.multiplier_shifts = []
self.program_rom_size = None self.program_rom_size = None
self.data_ram_size = None self.data_ram_size = None
self.opcode_bits = 3 self.opcode_bits = 3
@ -136,14 +146,14 @@ class Processor:
return { return {
AddIsn: 2, AddIsn: 2,
SubIsn: 2, SubIsn: 2,
MulIsn: 1 + self.multiplier_stages, MulShiftIsn: 1 + self.multiplier_stages,
CopyIsn: 1, CopyIsn: 1,
InputIsn: 1 InputIsn: 1
}[isn.__class__] }[isn.__class__]
def encode_instruction(self, isn, exit): def encode_instruction(self, isn, exit):
opcode = isn.opcode opcode = isn.opcode
if isn.immediate is not None: if isn.immediate is not None and not isinstance(isn, MulShiftIsn):
r0 = isn.immediate r0 = isn.immediate
if len(isn.inputs) >= 1: if len(isn.inputs) >= 1:
r1 = isn.inputs[0] r1 = isn.inputs[0]
@ -265,10 +275,13 @@ class CompiledProgram:
l += " -> r{}".format(self.exits[cycle]) l += " -> r{}".format(self.exits[cycle])
print(l) print(l)
def dimension_memories(self): def dimension_processor(self):
self.processor.program_rom_size = len(self.program) self.processor.program_rom_size = len(self.program)
self.processor.data_ram_size = len(self.data) self.processor.data_ram_size = len(self.data)
self.processor.reg_bits = (self.processor.data_ram_size - 1).bit_length() self.processor.reg_bits = (self.processor.data_ram_size - 1).bit_length()
for isn in self.program:
if isinstance(isn, MulShiftIsn) and isn.immediate not in self.processor.multiplier_shifts:
self.processor.multiplier_shifts.append(isn.immediate)
def encode(self): def encode(self):
r = [] r = []
@ -446,7 +459,13 @@ class ProcessorImpl(Module):
nop = NopUnit(pd.data_width) nop = NopUnit(pd.data_width)
adder = OpUnit(operator.add, pd.data_width, 1) adder = OpUnit(operator.add, pd.data_width, 1)
subtractor = OpUnit(operator.sub, pd.data_width, 1) subtractor = OpUnit(operator.sub, pd.data_width, 1)
multiplier = OpUnit(operator.mul, pd.data_width, pd.multiplier_stages) if pd.multiplier_shifts:
if len(pd.multiplier_shifts) != 1:
raise NotImplementedError
multiplier = OpUnit(lambda a, b: a * b >> pd.multiplier_shifts[0],
pd.data_width, pd.multiplier_stages)
else:
multiplier = NopUnit(pd.data_width)
copier = CopyUnit(pd.data_width) copier = CopyUnit(pd.data_width)
inu = InputUnit(pd.data_width, self.input_stb, self.input) inu = InputUnit(pd.data_width, self.input_stb, self.input)
outu = OutputUnit(pd.data_width, self.output_stb, self.output) outu = OutputUnit(pd.data_width, self.output_stb, self.output)
@ -494,14 +513,14 @@ def foo(x):
def simple_test(x): def simple_test(x):
return x*2+2 return (x*2 >> 1) + 2
if __name__ == "__main__": if __name__ == "__main__":
proc = Processor() proc = Processor()
cp = compile(proc, simple_test) cp = compile(proc, simple_test)
cp.pretty_print() cp.pretty_print()
cp.dimension_memories() cp.dimension_processor()
print(cp.encode()) print(cp.encode())
proc_impl = proc.implement(cp.encode(), cp.data) proc_impl = proc.implement(cp.encode(), cp.data)