wrpll/thls: implement min/max

This commit is contained in:
Sebastien Bourdeauducq 2019-08-15 16:42:59 +08:00
parent 44969b03ad
commit efc43142a6

View File

@ -47,6 +47,12 @@ class MulShiftIsn(Isn):
# opcode = 4: MulShift with alternate shift
class MinIsn(Isn):
opcode = 5
class MaxIsn(Isn):
opcode = 6
class CopyIsn(Isn):
opcode = 7
@ -110,6 +116,22 @@ class ASTCompiler:
output = self.get_ssa_reg()
self.program.append(cons(inputs=[left, right], outputs=[output]))
return output
elif isinstance(node, ast.Call):
if not isinstance(node.func, ast.Name):
raise NotImplementedError
funcname = node.func.id
if node.keywords:
raise NotImplementedError
inputs = [self.emit(x) for x in node.args]
if funcname == "min":
cons = MinIsn
elif funcname == "max":
cons = MaxIsn
else:
raise NotImplementedError
output = self.get_ssa_reg()
self.program.append(cons(inputs=inputs, outputs=[output]))
return output
elif isinstance(node, ast.Num):
if node.n in self.constants:
return self.constants[node.n]
@ -149,6 +171,8 @@ class Processor:
AddIsn: 2,
SubIsn: 2,
MulShiftIsn: 1 + self.multiplier_stages,
MinIsn: 2,
MaxIsn: 2,
CopyIsn: 1,
InputIsn: 1
}[isn.__class__]
@ -364,6 +388,20 @@ class OpUnit(BaseUnit):
]
class SelectUnit(BaseUnit):
def __init__(self, op, data_width):
BaseUnit.__init__(self, data_width)
self.sync += [
self.stb_o.eq(self.stb_i),
If(op(self.i0, self.i1),
self.o.eq(self.i0)
).Else(
self.o.eq(self.i1)
)
]
class CopyUnit(BaseUnit):
def __init__(self, data_width):
BaseUnit.__init__(self, data_width)
@ -468,10 +506,12 @@ class ProcessorImpl(Module):
pd.data_width, pd.multiplier_stages)
else:
multiplier = NopUnit(pd.data_width)
minu = SelectUnit(operator.lt, pd.data_width)
maxu = SelectUnit(operator.gt, pd.data_width)
copier = CopyUnit(pd.data_width)
inu = InputUnit(pd.data_width, self.input_stb, self.input)
outu = OutputUnit(pd.data_width, self.output_stb, self.output)
units = [nop, adder, subtractor, multiplier, copier, inu, outu]
units = [nop, adder, subtractor, multiplier, minu, maxu, copier, inu, outu]
self.submodules += units
for unit in units:
@ -490,6 +530,8 @@ class ProcessorImpl(Module):
(SubIsn.opcode, subtractor),
(MulShiftIsn.opcode, multiplier),
(MulShiftIsn.opcode + 1, multiplier),
(MinIsn.opcode, minu),
(MaxIsn.opcode, maxu),
(CopyIsn.opcode, copier),
(InputIsn.opcode, inu),
(OutputIsn.opcode, outu)
@ -527,7 +569,7 @@ def foo(x):
def simple_test(x):
return (x*2 >> 1) + 2
return min((x*2 >> 1) + 2, 10)
if __name__ == "__main__":