diff --git a/artiq/gateware/wrpll/thls.py b/artiq/gateware/wrpll/thls.py
index e28d812bd..ba2dde7f6 100644
--- a/artiq/gateware/wrpll/thls.py
+++ b/artiq/gateware/wrpll/thls.py
@@ -42,7 +42,7 @@ class AddIsn(Isn):
 class SubIsn(Isn):
     opcode = 2
 
-class MulIsn(Isn):
+class MulShiftIsn(Isn):
     opcode = 3
 
 class CopyIsn(Isn):
@@ -86,18 +86,27 @@ class ASTCompiler:
 
     def emit(self, node):
         if isinstance(node, ast.BinOp):
-            left = self.emit(node.left)
-            right = self.emit(node.right)
-            if isinstance(node.op, ast.Add):
-                cls = AddIsn
-            elif isinstance(node.op, ast.Sub):
-                cls = SubIsn
-            elif isinstance(node.op, ast.Mult):
-                cls = MulIsn
+            if isinstance(node.op, ast.RShift):
+                if not isinstance(node.left, ast.BinOp) or not isinstance(node.left.op, ast.Mult):
+                    raise NotImplementedError
+                if not isinstance(node.right, ast.Num):
+                    raise NotImplementedError
+                left = self.emit(node.left.left)
+                right = self.emit(node.left.right)
+                cons = lambda **kwargs: MulShiftIsn(immediate=node.right.n, **kwargs)
             else:
-                raise NotImplementedError
+                left = self.emit(node.left)
+                right = self.emit(node.right)
+                if isinstance(node.op, ast.Add):
+                    cons = AddIsn
+                elif isinstance(node.op, ast.Sub):
+                    cons = SubIsn
+                elif isinstance(node.op, ast.Mult):
+                    cons = lambda **kwargs: MulShiftIsn(immediate=0, **kwargs)
+                else:
+                    raise NotImplementedError
             output = self.get_ssa_reg()
-            self.program.append(cls(inputs=[left, right], outputs=[output]))
+            self.program.append(cons(inputs=[left, right], outputs=[output]))
             return output
         elif isinstance(node, ast.Num):
             if node.n in self.constants:
@@ -127,6 +136,7 @@ class Processor:
     def __init__(self, data_width=32, multiplier_stages=2):
         self.data_width = data_width
         self.multiplier_stages = multiplier_stages
+        self.multiplier_shifts = []
         self.program_rom_size = None
         self.data_ram_size = None
         self.opcode_bits = 3
@@ -136,14 +146,14 @@ class Processor:
         return {
             AddIsn: 2,
             SubIsn: 2,
-            MulIsn: 1 + self.multiplier_stages,
+            MulShiftIsn: 1 + self.multiplier_stages,
             CopyIsn: 1,
             InputIsn: 1
         }[isn.__class__]
 
     def encode_instruction(self, isn, exit):
         opcode = isn.opcode
-        if isn.immediate is not None:
+        if isn.immediate is not None and not isinstance(isn, MulShiftIsn):
             r0 = isn.immediate
             if len(isn.inputs) >= 1:
                 r1 = isn.inputs[0]
@@ -265,10 +275,13 @@ class CompiledProgram:
                 l += " -> r{}".format(self.exits[cycle])
             print(l)
 
-    def dimension_memories(self):
+    def dimension_processor(self):
         self.processor.program_rom_size = len(self.program)
         self.processor.data_ram_size = len(self.data)
         self.processor.reg_bits = (self.processor.data_ram_size - 1).bit_length()
+        for isn in self.program:
+            if isinstance(isn, MulShiftIsn) and isn.immediate not in self.processor.multiplier_shifts:
+                self.processor.multiplier_shifts.append(isn.immediate)
 
     def encode(self):
         r = []
@@ -446,7 +459,13 @@ class ProcessorImpl(Module):
         nop = NopUnit(pd.data_width)
         adder = OpUnit(operator.add, pd.data_width, 1)
         subtractor = OpUnit(operator.sub, pd.data_width, 1)
-        multiplier = OpUnit(operator.mul, pd.data_width, pd.multiplier_stages)
+        if pd.multiplier_shifts:
+            if len(pd.multiplier_shifts) != 1:
+                raise NotImplementedError
+            multiplier = OpUnit(lambda a, b: a * b >> pd.multiplier_shifts[0],
+                                pd.data_width, pd.multiplier_stages)
+        else:
+            multiplier = NopUnit(pd.data_width)
         copier = CopyUnit(pd.data_width)
         inu = InputUnit(pd.data_width, self.input_stb, self.input)
         outu = OutputUnit(pd.data_width, self.output_stb, self.output)
@@ -494,14 +513,14 @@ def foo(x):
 
 
 def simple_test(x):
-    return x*2+2
+    return (x*2 >> 1) + 2
 
 
 if __name__ == "__main__":
     proc = Processor()
     cp = compile(proc, simple_test)
     cp.pretty_print()
-    cp.dimension_memories()
+    cp.dimension_processor()
     print(cp.encode())
     proc_impl = proc.implement(cp.encode(), cp.data)
 