from nmigen import *
from nmigen.utils import log2_int

from ..cache import *
from ..wishbone import *


__all__ = ["PCSelector", "FetchUnitInterface", "BareFetchUnit", "CachedFetchUnit"]


class PCSelector(Elaboratable):
    """Combinationally select the next fetch address ``a_pc``.

    Only bits ``[2:]`` of ``a_pc`` are driven, i.e. the selection operates on
    32-bit word addresses. The sources are examined in this priority order:

    1. a redirect qualified by ``m_valid`` (exception, ``mret``, or a branch
       whose prediction disagrees with its outcome);
    2. ``x_fence_i`` qualified by ``x_valid``, which selects ``d_pc``;
    3. ``d_branch_predict_taken`` qualified by ``d_valid``, which selects
       ``d_branch_target``;
    4. otherwise the word following ``f_pc``.

    NOTE(review): the ``f_``/``d_``/``x_``/``m_``/``a_`` prefixes presumably
    name pipeline stages -- confirm against the core that instantiates this.
    """

    def __init__(self):
        self.f_pc = Signal(32)
        self.d_pc = Signal(32)
        self.d_branch_predict_taken = Signal()
        self.d_branch_target = Signal(32)
        self.d_valid = Signal()
        self.x_pc = Signal(32)
        self.x_fence_i = Signal()
        self.x_valid = Signal()
        self.m_branch_predict_taken = Signal()
        self.m_branch_taken = Signal()
        self.m_branch_target = Signal(32)
        self.m_exception = Signal()
        self.m_mret = Signal()
        self.m_valid = Signal()
        # Word-aligned bases (30 bits each); they are assigned to bits [2:]
        # of the candidate address below.
        self.mtvec_r_base = Signal(30)
        self.mepc_r_base = Signal(30)

        self.a_pc = Signal(32)

    def elaborate(self, platform):
        """Build and return the selection logic as a ``Module``."""
        m = Module()

        # m_sel resets to 1, so it reads as 1 in every branch that does not
        # assign it; only the final Else (no redirect condition) clears it.
        m_sel = Signal(reset=1)
        m_a_pc = Signal(32)

        with m.If(self.m_exception):
            m.d.comb += m_a_pc[2:].eq(self.mtvec_r_base)
        with m.Elif(self.m_mret):
            m.d.comb += m_a_pc[2:].eq(self.mepc_r_base)
        with m.Elif(self.m_branch_predict_taken & ~self.m_branch_taken):
            # Predicted taken but not taken: resume from x_pc.
            m.d.comb += m_a_pc[2:].eq(self.x_pc[2:])
        with m.Elif(~self.m_branch_predict_taken & self.m_branch_taken):
            # Predicted not taken but taken: go to the real target.
            m.d.comb += m_a_pc[2:].eq(self.m_branch_target[2:])
        with m.Else():
            m.d.comb += m_sel.eq(0)

        with m.If(m_sel & self.m_valid):
            m.d.comb += self.a_pc[2:].eq(m_a_pc[2:])
        with m.Elif(self.x_fence_i & self.x_valid):
            m.d.comb += self.a_pc[2:].eq(self.d_pc[2:])
        with m.Elif(self.d_branch_predict_taken & self.d_valid):
            m.d.comb += self.a_pc[2:].eq(self.d_branch_target[2:])
        with m.Else():
            # Adding 1 to the word address advances the PC by 4 bytes.
            m.d.comb += self.a_pc[2:].eq(self.f_pc[2:] + 1)

        return m


class FetchUnitInterface:
    """Signals shared by the fetch unit implementations in this module."""

    def __init__(self):
        self.ibus = Record(wishbone_layout)

        self.a_pc = Signal(32)
        self.a_stall = Signal()
        self.a_valid = Signal()
        self.f_stall = Signal()
        self.f_valid = Signal()

        self.a_busy = Signal()
        self.f_busy = Signal()
        self.f_instruction = Signal(32, reset=0x00000013)  # nop (addi x0, x0, 0)
        self.f_fetch_error = Signal()
        # 30 bits wide: it captures ``ibus.adr`` when a bus error is seen.
        self.f_badaddr = Signal(30)


class BareFetchUnit(FetchUnitInterface, Elaboratable):
    """Fetch unit that issues one bus read on ``ibus`` per fetched word."""

    def elaborate(self, platform):
        """Build and return the bus-access logic as a ``Module``."""
        m = Module()

        ibus_rdata = Signal.like(self.ibus.dat_r)
        with m.If(self.ibus.cyc):
            # A transaction is in flight: end it on ack, on err, or when
            # f_valid is low, and latch whatever is on dat_r.
            with m.If(self.ibus.ack | self.ibus.err | ~self.f_valid):
                m.d.sync += [
                    self.ibus.cyc.eq(0),
                    self.ibus.stb.eq(0),
                    ibus_rdata.eq(self.ibus.dat_r)
                ]
        with m.Elif(self.a_valid & ~self.a_stall):
            # Bus idle: start a read of the word addressed by a_pc.
            m.d.sync += [
                self.ibus.adr.eq(self.a_pc[2:]),
                self.ibus.cyc.eq(1),
                self.ibus.stb.eq(1)
            ]
        m.d.comb += self.ibus.sel.eq(0b1111)

        # f_fetch_error is set on a bus error and held until f_stall is low.
        with m.If(self.ibus.cyc & self.ibus.err):
            m.d.sync += [
                self.f_fetch_error.eq(1),
                self.f_badaddr.eq(self.ibus.adr)
            ]
        with m.Elif(~self.f_stall):
            m.d.sync += self.f_fetch_error.eq(0)

        m.d.comb += self.a_busy.eq(self.ibus.cyc)

        with m.If(self.f_fetch_error):
            # f_instruction is not assigned in this branch, so it reads as
            # its reset value (the nop encoding).
            m.d.comb += self.f_busy.eq(0)
        with m.Else():
            m.d.comb += [
                self.f_busy.eq(self.ibus.cyc),
                self.f_instruction.eq(ibus_rdata)
            ]

        return m


class CachedFetchUnit(FetchUnitInterface, Elaboratable):
    """Fetch unit that reads through an ``L1Cache`` inside a cached region.

    Addresses outside the region use single bus reads, like
    ``BareFetchUnit``. Both paths share ``ibus`` through a
    ``WishboneArbiter``.

    Parameters
    ----------
    *icache_args
        Positional arguments forwarded unchanged to ``L1Cache``.
    """

    def __init__(self, *icache_args):
        super().__init__()

        self.icache_args = icache_args

        self.f_pc = Signal(32)
        self.a_flush = Signal()

    def elaborate(self, platform):
        """Build and return the cache, arbiter and glue logic as a ``Module``."""
        m = Module()

        icache = m.submodules.icache = L1Cache(*self.icache_args)

        a_icache_select = Signal()

        # Test whether the target address is inside the L1 cache region. We use bit masks in order
        # to avoid carry chains from arithmetic comparisons. This restricts the region boundaries
        # to powers of 2.
        with m.Switch(self.a_pc[2:]):
            def addr_below(limit):
                # Build a 30-character Switch pattern: leading "0" bits
                # followed by "-" (don't care) for the low log2(limit) bits.
                assert limit in range(1, 2**30 + 1)
                range_bits = log2_int(limit)
                const_bits = 30 - range_bits
                return "{}{}".format("0" * const_bits, "-" * range_bits)

            # Cases are listed from the lowest boundary up, so addresses
            # below the base are excluded before the limit is tested.
            if icache.base >= 4:
                with m.Case(addr_below(icache.base >> 2)):
                    m.d.comb += a_icache_select.eq(0)
            with m.Case(addr_below(icache.limit >> 2)):
                m.d.comb += a_icache_select.eq(1)
            with m.Default():
                m.d.comb += a_icache_select.eq(0)

        f_icache_select = Signal()
        f_flush = Signal()

        # Register the select and flush requests whenever a_stall is low.
        with m.If(~self.a_stall):
            m.d.sync += [
                f_icache_select.eq(a_icache_select),
                f_flush.eq(self.a_flush),
            ]

        m.d.comb += [
            icache.s1_addr.eq(self.a_pc[2:]),
            icache.s1_stall.eq(self.a_stall),
            icache.s1_valid.eq(self.a_valid),
            icache.s2_addr.eq(self.f_pc[2:]),
            icache.s2_re.eq(f_icache_select),
            icache.s2_evict.eq(Const(0)),
            icache.s2_flush.eq(f_flush),
            icache.s2_valid.eq(self.f_valid),
        ]

        ibus_arbiter = m.submodules.ibus_arbiter = WishboneArbiter()
        m.d.comb += ibus_arbiter.bus.connect(self.ibus)

        # Cache refill port (registered with priority=0).
        icache_port = ibus_arbiter.port(priority=0)
        m.d.comb += [
            icache_port.cyc.eq(icache.bus_re),
            icache_port.stb.eq(icache.bus_re),
            icache_port.adr.eq(icache.bus_addr),
            icache_port.sel.eq(0b1111),
            # NOTE(review): presumably an incrementing burst that is closed
            # on the last word of the refill -- confirm against ``Cycle``
            # and the ``L1Cache`` bus interface.
            icache_port.cti.eq(Mux(icache.bus_last, Cycle.END, Cycle.INCREMENT)),
            icache_port.bte.eq(Const(log2_int(icache.nwords) - 1)),
            icache.bus_valid.eq(icache_port.ack),
            icache.bus_error.eq(icache_port.err),
            icache.bus_rdata.eq(icache_port.dat_r)
        ]

        # Uncached port (registered with priority=1): same handshake as
        # BareFetchUnit, started only for addresses outside the cached region.
        bare_port = ibus_arbiter.port(priority=1)
        bare_rdata = Signal.like(bare_port.dat_r)
        with m.If(bare_port.cyc):
            with m.If(bare_port.ack | bare_port.err | ~self.f_valid):
                m.d.sync += [
                    bare_port.cyc.eq(0),
                    bare_port.stb.eq(0),
                    bare_rdata.eq(bare_port.dat_r)
                ]
        with m.Elif(~a_icache_select & self.a_valid & ~self.a_stall):
            m.d.sync += [
                bare_port.cyc.eq(1),
                bare_port.stb.eq(1),
                bare_port.adr.eq(self.a_pc[2:])
            ]
        m.d.comb += bare_port.sel.eq(0b1111)

        m.d.comb += self.a_busy.eq(bare_port.cyc)

        # Errors are observed on the shared ibus, so they are recorded
        # whichever port is active.
        with m.If(self.ibus.cyc & self.ibus.err):
            m.d.sync += [
                self.f_fetch_error.eq(1),
                self.f_badaddr.eq(self.ibus.adr)
            ]
        with m.Elif(~self.f_stall):
            m.d.sync += self.f_fetch_error.eq(0)

        with m.If(f_flush):
            m.d.comb += self.f_busy.eq(~icache.s2_flush_ack)
        with m.Elif(self.f_fetch_error):
            # f_instruction is not assigned in this branch, so it reads as
            # its reset value (the nop encoding).
            m.d.comb += self.f_busy.eq(0)
        with m.Elif(f_icache_select):
            m.d.comb += [
                self.f_busy.eq(icache.s2_miss),
                self.f_instruction.eq(icache.s2_rdata)
            ]
        with m.Else():
            m.d.comb += [
                self.f_busy.eq(bare_port.cyc),
                self.f_instruction.eq(bare_rdata)
            ]

        return m