262 lines
9.3 KiB
Python
262 lines
9.3 KiB
Python
from functools import reduce
|
|
from itertools import starmap
|
|
from operator import or_
|
|
|
|
from nmigen import *
|
|
|
|
from ..isa import Opcode, Funct3, Funct7, Funct12
|
|
|
|
|
|
__all__ = ["InstructionDecoder"]
|
|
|
|
|
|
class Type:
|
|
R = 0
|
|
I = 1
|
|
S = 2
|
|
B = 3
|
|
U = 4
|
|
J = 5
|
|
|
|
|
|
class InstructionDecoder(Elaboratable):
|
|
def __init__(self, with_muldiv):
|
|
self.with_muldiv = with_muldiv
|
|
|
|
self.instruction = Signal(32)
|
|
|
|
self.rd = Signal(5)
|
|
self.rd_we = Signal()
|
|
self.rs1 = Signal(5)
|
|
self.rs1_re = Signal()
|
|
self.rs2 = Signal(5)
|
|
self.rs2_re = Signal()
|
|
self.immediate = Signal((32, True))
|
|
self.bypass_x = Signal()
|
|
self.bypass_m = Signal()
|
|
self.load = Signal()
|
|
self.store = Signal()
|
|
self.fence_i = Signal()
|
|
self.adder = Signal()
|
|
self.adder_sub = Signal()
|
|
self.logic = Signal()
|
|
self.multiply = Signal()
|
|
self.divide = Signal()
|
|
self.shift = Signal()
|
|
self.direction = Signal()
|
|
self.sext = Signal()
|
|
self.lui = Signal()
|
|
self.auipc = Signal()
|
|
self.jump = Signal()
|
|
self.branch = Signal()
|
|
self.compare = Signal()
|
|
self.csr = Signal()
|
|
self.csr_we = Signal()
|
|
self.privileged = Signal()
|
|
self.ecall = Signal()
|
|
self.ebreak = Signal()
|
|
self.mret = Signal()
|
|
self.funct3 = Signal(3)
|
|
self.illegal = Signal()
|
|
|
|
def elaborate(self, platform):
|
|
m = Module()
|
|
|
|
opcode = Signal(5)
|
|
funct3 = Signal(3)
|
|
funct7 = Signal(7)
|
|
funct12 = Signal(12)
|
|
|
|
iimm12 = Signal((12, True))
|
|
simm12 = Signal((12, True))
|
|
bimm12 = Signal((13, True))
|
|
uimm20 = Signal(20)
|
|
jimm20 = Signal((21, True))
|
|
|
|
insn = self.instruction
|
|
fmt = Signal(range(Type.J + 1))
|
|
|
|
m.d.comb += [
|
|
opcode.eq(insn[2:7]),
|
|
funct3.eq(insn[12:15]),
|
|
funct7.eq(insn[25:32]),
|
|
funct12.eq(insn[20:32]),
|
|
|
|
iimm12.eq(insn[20:32]),
|
|
simm12.eq(Cat(insn[7:12], insn[25:32])),
|
|
bimm12.eq(Cat(0, insn[8:12], insn[25:31], insn[7], insn[31])),
|
|
uimm20.eq(insn[12:32]),
|
|
jimm20.eq(Cat(0, insn[21:31], insn[20], insn[12:20], insn[31])),
|
|
]
|
|
|
|
with m.Switch(opcode):
|
|
with m.Case(Opcode.LUI):
|
|
m.d.comb += fmt.eq(Type.U)
|
|
with m.Case(Opcode.AUIPC):
|
|
m.d.comb += fmt.eq(Type.U)
|
|
with m.Case(Opcode.JAL):
|
|
m.d.comb += fmt.eq(Type.J)
|
|
with m.Case(Opcode.JALR):
|
|
m.d.comb += fmt.eq(Type.I)
|
|
with m.Case(Opcode.BRANCH):
|
|
m.d.comb += fmt.eq(Type.B)
|
|
with m.Case(Opcode.LOAD):
|
|
m.d.comb += fmt.eq(Type.I)
|
|
with m.Case(Opcode.STORE):
|
|
m.d.comb += fmt.eq(Type.S)
|
|
with m.Case(Opcode.OP_IMM_32):
|
|
m.d.comb += fmt.eq(Type.I)
|
|
with m.Case(Opcode.OP_32):
|
|
m.d.comb += fmt.eq(Type.R)
|
|
with m.Case(Opcode.MISC_MEM):
|
|
m.d.comb += fmt.eq(Type.I)
|
|
with m.Case(Opcode.SYSTEM):
|
|
m.d.comb += fmt.eq(Type.I)
|
|
|
|
with m.Switch(fmt):
|
|
with m.Case(Type.I):
|
|
m.d.comb += self.immediate.eq(iimm12)
|
|
with m.Case(Type.S):
|
|
m.d.comb += self.immediate.eq(simm12)
|
|
with m.Case(Type.B):
|
|
m.d.comb += self.immediate.eq(bimm12)
|
|
with m.Case(Type.U):
|
|
m.d.comb += self.immediate.eq(uimm20 << 12)
|
|
with m.Case(Type.J):
|
|
m.d.comb += self.immediate.eq(jimm20)
|
|
|
|
m.d.comb += [
|
|
self.rd.eq(insn[7:12]),
|
|
self.rs1.eq(insn[15:20]),
|
|
self.rs2.eq(insn[20:25]),
|
|
|
|
self.rd_we.eq(reduce(or_, (fmt == T for T in (Type.R, Type.I, Type.U, Type.J)))),
|
|
self.rs1_re.eq(reduce(or_, (fmt == T for T in (Type.R, Type.I, Type.S, Type.B)))),
|
|
self.rs2_re.eq(reduce(or_, (fmt == T for T in (Type.R, Type.S, Type.B)))),
|
|
|
|
self.funct3.eq(funct3)
|
|
]
|
|
|
|
def matcher(encodings):
|
|
return reduce(or_, starmap(
|
|
lambda opc, f3=None, f7=None, f12=None:
|
|
(opcode == opc if opc is not None else 1) \
|
|
& (funct3 == f3 if f3 is not None else 1) \
|
|
& (funct7 == f7 if f7 is not None else 1) \
|
|
& (funct12 == f12 if f12 is not None else 1),
|
|
encodings))
|
|
|
|
m.d.comb += [
|
|
self.compare.eq(matcher([
|
|
(Opcode.OP_IMM_32, Funct3.SLT, None), # slti
|
|
(Opcode.OP_IMM_32, Funct3.SLTU, None), # sltiu
|
|
(Opcode.OP_32, Funct3.SLT, 0), # slt
|
|
(Opcode.OP_32, Funct3.SLTU, 0) # sltu
|
|
])),
|
|
self.branch.eq(matcher([
|
|
(Opcode.BRANCH, Funct3.BEQ, None), # beq
|
|
(Opcode.BRANCH, Funct3.BNE, None), # bne
|
|
(Opcode.BRANCH, Funct3.BLT, None), # blt
|
|
(Opcode.BRANCH, Funct3.BGE, None), # bge
|
|
(Opcode.BRANCH, Funct3.BLTU, None), # bltu
|
|
(Opcode.BRANCH, Funct3.BGEU, None) # bgeu
|
|
])),
|
|
|
|
self.adder.eq(matcher([
|
|
(Opcode.OP_IMM_32, Funct3.ADD, None), # addi
|
|
(Opcode.OP_32, Funct3.ADD, Funct7.ADD), # add
|
|
(Opcode.OP_32, Funct3.ADD, Funct7.SUB) # sub
|
|
])),
|
|
self.adder_sub.eq(self.rs2_re & (funct7 == Funct7.SUB)),
|
|
|
|
self.logic.eq(matcher([
|
|
(Opcode.OP_IMM_32, Funct3.XOR, None), # xori
|
|
(Opcode.OP_IMM_32, Funct3.OR, None), # ori
|
|
(Opcode.OP_IMM_32, Funct3.AND, None), # andi
|
|
(Opcode.OP_32, Funct3.XOR, 0), # xor
|
|
(Opcode.OP_32, Funct3.OR, 0), # or
|
|
(Opcode.OP_32, Funct3.AND, 0) # and
|
|
])),
|
|
]
|
|
|
|
if self.with_muldiv:
|
|
m.d.comb += [
|
|
self.multiply.eq(matcher([
|
|
(Opcode.OP_32, Funct3.MUL, Funct7.MULDIV), # mul
|
|
(Opcode.OP_32, Funct3.MULH, Funct7.MULDIV), # mulh
|
|
(Opcode.OP_32, Funct3.MULHSU, Funct7.MULDIV), # mulhsu
|
|
(Opcode.OP_32, Funct3.MULHU, Funct7.MULDIV), # mulhu
|
|
])),
|
|
|
|
self.divide.eq(matcher([
|
|
(Opcode.OP_32, Funct3.DIV, Funct7.MULDIV), # div
|
|
(Opcode.OP_32, Funct3.DIVU, Funct7.MULDIV), # divu
|
|
(Opcode.OP_32, Funct3.REM, Funct7.MULDIV), # rem
|
|
(Opcode.OP_32, Funct3.REMU, Funct7.MULDIV) # remu
|
|
])),
|
|
]
|
|
|
|
m.d.comb += [
|
|
self.shift.eq(matcher([
|
|
(Opcode.OP_IMM_32, Funct3.SLL, 0), # slli
|
|
(Opcode.OP_IMM_32, Funct3.SR, Funct7.SRL), # srli
|
|
(Opcode.OP_IMM_32, Funct3.SR, Funct7.SRA), # srai
|
|
(Opcode.OP_32, Funct3.SLL, 0), # sll
|
|
(Opcode.OP_32, Funct3.SR, Funct7.SRL), # srl
|
|
(Opcode.OP_32, Funct3.SR, Funct7.SRA) # sra
|
|
])),
|
|
self.direction.eq(funct3 == Funct3.SR),
|
|
self.sext.eq(funct7 == Funct7.SRA),
|
|
|
|
self.lui.eq(opcode == Opcode.LUI),
|
|
self.auipc.eq(opcode == Opcode.AUIPC),
|
|
|
|
self.jump.eq(matcher([
|
|
(Opcode.JAL, None), # jal
|
|
(Opcode.JALR, 0) # jalr
|
|
])),
|
|
|
|
self.load.eq(matcher([
|
|
(Opcode.LOAD, Funct3.B), # lb
|
|
(Opcode.LOAD, Funct3.BU), # lbu
|
|
(Opcode.LOAD, Funct3.H), # lh
|
|
(Opcode.LOAD, Funct3.HU), # lhu
|
|
(Opcode.LOAD, Funct3.W) # lw
|
|
])),
|
|
self.store.eq(matcher([
|
|
(Opcode.STORE, Funct3.B), # sb
|
|
(Opcode.STORE, Funct3.H), # sh
|
|
(Opcode.STORE, Funct3.W) # sw
|
|
])),
|
|
|
|
self.fence_i.eq(matcher([
|
|
(Opcode.MISC_MEM, Funct3.FENCEI) # fence.i
|
|
])),
|
|
|
|
self.csr.eq(matcher([
|
|
(Opcode.SYSTEM, Funct3.CSRRW), # csrrw
|
|
(Opcode.SYSTEM, Funct3.CSRRS), # csrrs
|
|
(Opcode.SYSTEM, Funct3.CSRRC), # csrrc
|
|
(Opcode.SYSTEM, Funct3.CSRRWI), # csrrwi
|
|
(Opcode.SYSTEM, Funct3.CSRRSI), # csrrsi
|
|
(Opcode.SYSTEM, Funct3.CSRRCI) # csrrci
|
|
])),
|
|
self.csr_we.eq(~funct3[1] | (self.rs1 != 0)),
|
|
|
|
self.privileged.eq((opcode == Opcode.SYSTEM) & (funct3 == Funct3.PRIV)),
|
|
self.ecall.eq(self.privileged & (funct12 == Funct12.ECALL)),
|
|
self.ebreak.eq(self.privileged & (funct12 == Funct12.EBREAK)),
|
|
self.mret.eq(self.privileged & (funct12 == Funct12.MRET)),
|
|
|
|
self.bypass_x.eq(self.adder | self.logic | self.lui | self.auipc | self.csr),
|
|
self.bypass_m.eq(self.compare | self.divide | self.shift),
|
|
|
|
self.illegal.eq((self.instruction[:2] != 0b11) | ~reduce(or_, (
|
|
self.compare, self.branch, self.adder, self.logic, self.multiply, self.divide, self.shift,
|
|
self.lui, self.auipc, self.jump, self.load, self.store,
|
|
self.csr, self.ecall, self.ebreak, self.mret, self.fence_i,
|
|
)))
|
|
]
|
|
|
|
return m
|