Reset translation progress

This commit is contained in:
Donald Sebastian Leung 2020-09-30 10:55:08 +08:00
parent a3cdb44572
commit c9857bb831
15 changed files with 13 additions and 716 deletions

View File

@ -4,18 +4,19 @@ Formally verified implementation of the ARTIQ RTIO core in nMigen
## Progress
- Devise a suitable migration strategy for `artiq.gateware.rtio` from Migen to nMigen
- [ ] Implement the core in nMigen
- - [ ] `rtio.core`
- - [ ] `rtio.cri`
- - [x] `rtio.rtlink`
- - [x] `rtio.channel`
- - [ ] `rtio.rtlink`
- - [ ] `rtio.channel`
- - [ ] `rtio.sed.core`
- - [x] `rtio.sed.layouts`
- - [x] `rtio.sed.lane_distributor`
- - [ ] `rtio.sed.layouts`
- - [ ] `rtio.sed.lane_distributor`
- - [ ] `rtio.sed.fifos`
- - [ ] `rtio.sed.gates`
- - [ ] `rtio.sed.output_driver`
- - [x] `rtio.sed.output_network`
- - [ ] `rtio.sed.output_network`
- - [ ] `rtio.input_collector`
- [ ] Add suitable assertions for verification (BMC / unbounded proof?)

View File

@ -0,0 +1 @@

View File

@ -1,36 +1 @@
import warnings
from rtio import rtlink
class Channel:
def __init__(self, interface, probes=None, overrides=None,
ofifo_depth=None, ififo_depth=64):
if probes is None:
probes = []
if overrides is None:
overrides = []
self.interface = interface
self.probes = probes
self.overrides = overrides
if ofifo_depth is None:
ofifo_depth = 64
else:
warnings.warn("ofifo_depth is deprecated", FutureWarning)
self.ofifo_depth = ofifo_depth
self.ififo_depth = ififo_depth
@classmethod
def from_phy(cls, phy, **kwargs):
probes = getattr(phy, "probes", [])
overrides = getattr(phy, "overrides", [])
return cls(phy.rtlink, probes, overrides, **kwargs)
class LogChannel:
"""A degenerate channel used to log messages into the analyzer."""
def __init__(self):
self.interface = rtlink.Interface(rtlink.OInterface(32))
self.probes = []
self.overrides = []

View File

@ -0,0 +1 @@

View File

@ -1,126 +1 @@
from nmigen import *
from nmigen.utils import *
from nmigen.hdl.rec import *
"""Common RTIO Interface"""
# CRI write happens in 3 cycles:
# 1. set timestamp and channel
# 2. set other payload elements and issue write command
# 3. check status
commands = {
"nop": 0,
"write": 1,
# i_status should have the "wait for status" bit set until
# an event is available, or timestamp is reached.
"read": 2,
# targets must assert o_buffer_space_valid in response
# to this command
"get_buffer_space": 3
}
layout = [
("cmd", 2, DIR_FANOUT),
# 8 MSBs of chan_sel = routing destination
# 16 LSBs of chan_sel = channel within the destination
("chan_sel", 24, DIR_FANOUT),
("o_timestamp", 64, DIR_FANOUT),
("o_data", 512, DIR_FANOUT),
("o_address", 8, DIR_FANOUT),
# o_status bits:
# <0:wait> <1:underflow> <2:destination unreachable>
("o_status", 3, DIR_FANIN),
# pessimistic estimate of the number of outputs events that can be
# written without waiting.
# this feature may be omitted on systems without DRTIO.
("o_buffer_space_valid", 1, DIR_FANIN),
("o_buffer_space", 16, DIR_FANIN),
("i_timeout", 64, DIR_FANOUT),
("i_data", 32, DIR_FANIN),
("i_timestamp", 64, DIR_FANIN),
# i_status bits:
# <0:wait for event (command timeout)> <1:overflow> <2:wait for status>
# <3:destination unreachable>
# <0> and <1> are mutually exclusive. <1> has higher priority.
("i_status", 4, DIR_FANIN),
]
class Interface(Record):
def __init__(self, **kwargs):
super().__init__(layout, **kwargs)
# Skip KernelInitiator for now
class CRIDecoder(Elaboratable):
def __init__(self, slaves=2, master=None, mode="async", enable_routing=False):
if isinstance(slaves, int):
slaves = [Interface() for _ in range(slaves)]
if master is None:
master = Interface()
self.slaves = slaves
self.master = master
self.mode = mode
self.enable_routing = enable_routing
def elaborate(self, platform):
m = Module()
# routing
if self.enable_routing:
destination_unreachable = Interface()
m.d.comb += destination_unreachable.o_status.eq(4)
m.d.comb += destination_unreachable.i_status.eq(8)
self.slaves = self.slaves[:]
self.slaves.append(destination_unreachable)
target_len = 2 ** (len(slaves) - 1).bit_length()
self.slaves += [destination_unreachable] * (target_len - len(slaves))
slave_bits = bits_for(len(self.slaves) - 1)
selected = Signal(slave_bits)
if self.enable_routing:
routing_table = Memory(slave_bits, 256)
if self.mode == "async":
rtp_decoder_rdport = routing_table.read_port()
rtp_decoder_wrport = routing_table.write_port()
elif self.mode == "sync":
rtp_decoder_rdport = routing_table.read_port(clock_domain="rtio")
rtp_decoder_wrport = routing_tables.write_port(clock_domain="rtio")
else:
raise ValueError
m.submodules.rtp_decoder_rdport = rtp_decoder_rdport
m.submodules.rtp_decoder_wrport = rtp_decoder_wrport
m.d.comb += rtp_decoder_rdport.addr.eq(self.master.chan_sel[16:])
m.d.comb += selected.eq(rtp_decoder_rdport.data)
else:
m.d.sync += selected.eq(self.master.chan_sel[16:])
# master -> slave
for n, slave in enumerate(self.slaves):
for name, size, direction in layout:
if direction == DIR_FANOUT and name != "cmd":
m.d.comb += getattr(slave, name).eq(getattr(self.master, name))
with m.If(selected == n):
m.d.comb += slave.cmd.eq(self.master.cmd)
# slave -> master
with m.Switch(selected):
for n, slave in enumerate(self.slaves):
with m.Case(n):
for name, size, direction in layout:
if direction == DIR_FANIN:
m.d.comb += getattr(self.master, name).eq(getattr(slave, name))
return m
# TODO: CRISwitch
# TODO: CRIInterconnectShared
# TODO: RoutingTableAccess

View File

@ -0,0 +1 @@

View File

@ -1,97 +1 @@
from nmigen import *
class OInterface:
def __init__(self, data_width, address_width=0,
fine_ts_width=0, enable_replace=True,
delay=0):
self.stb = Signal()
self.busy = Signal()
assert 0 <= data_width <= 512
assert 0 <= address_width <= 8
assert 0 <= fine_ts_width <= 4
if data_width:
self.data = Signal(data_width, reset_less=True)
if address_width:
self.address = Signal(address_width, reset_less=True)
if fine_ts_width:
self.fine_ts = Signal(fine_ts_width, reset_less=True)
self.enable_replace = enable_replace
if delay < 0:
raise ValueError("only positive delays allowed", delay)
self.delay = delay
@classmethod
def like(cls, other):
return cls(get_data_width(other),
get_address_width(other),
get_fine_ts_width(other),
other.enable_replace,
other.delay)
class IInterface:
def __init__(self, data_width,
timestamped=True, fine_ts_width=0, delay=0):
self.stb = Signal()
assert 0 <= data_width <= 32
assert 0 <= fine_ts_width <= 4
if data_width:
self.data = Signal(data_width, reset_less=True)
if fine_ts_width:
self.fine_ts = Signal(fine_ts_width, reset_less=True)
assert(not fine_ts_width or timestamped)
self.timestamped = timestamped
if delay < 0:
raise ValueError("only positive delays")
self.delay = delay
@classmethod
def like(cls, other):
return cls(get_data_width(other),
other.timestamped,
get_fine_ts_width(other),
other.delay)
class Interface:
def __init__(self, o, i=None):
self.o = o
self.i = i
@classmethod
def like(cls, other):
if other.i is None:
return cls(OInterface.like(other.o))
else:
return cls(OInterface.like(other.o),
IInterface.like(other.i))
def _get_or_zero(interface, attr):
if interface is None:
return 0
assert isinstance(interface, (OInterface, IInterface))
if hasattr(interface, attr):
return len(getattr(interface, attr))
else:
return 0
def get_data_width(interface):
return _get_or_zero(interface, "data")
def get_address_width(interface):
return _get_or_zero(interface, "address")
def get_fine_ts_width(interface):
return _get_or_zero(interface, "fine_ts")

View File

@ -0,0 +1 @@

View File

@ -0,0 +1 @@

View File

@ -0,0 +1 @@

View File

@ -0,0 +1 @@

View File

@ -1,170 +1 @@
from nmigen import *
from rtio import cri
from rtio.sed import layouts
__all__ = ["LaneDistributor"]
class LaneDistributor(Elaboratable):
def __init__(self, lane_count, seqn_width, layout_payload,
compensation, glbl_fine_ts_width,
enable_spread=True, quash_channels=[], interface=None):
if lane_count & (lane_count - 1):
raise NotImplementedError("lane count must be a power of 2")
if interface is None:
interface = cri.Interface()
self.cri = interface
self.sequence_error = Signal()
self.sequence_error_channel = Signal(16, reset_less=True)
# The minimum timestamp that an event must have to avoid triggering
# an underflow, at the time when the CRI write happens, and to a channel
# with zero latency compensation. This is synchronous to the system clock
# domain.
us_timestamp_width = 64 - glbl_fine_ts_width
self.minimum_coarse_timestamp = Signal(us_timestamp_width)
self.output = [Record(layouts.fifo_ingress(seqn_width, layout_payload))
for _ in range(lane_count)]
self.lane_count = lane_count
self.us_timestamp_width = us_timestamp_width
self.seqn_width = seqn_width
self.glbl_fine_ts_width = glbl_fine_ts_width
self.quash_channels = quash_channels
self.compensation = compensation
self.enable_spread = enable_spread
def elaborate(self, platform):
m = Module()
o_status_wait = Signal()
o_status_underflow = Signal()
m.d.comb += self.cri.o_status.eq(Cat(o_status_wait, o_status_underflow))
# The core keeps writing events into the current lane as long as timestamps
# (after compensation) are strictly increasing, otherwise it switches to
# the next lane.
# If spread is enabled, it also switches to the next lane after the current
# lane has been full, in order to maximize lane utilization.
# The current lane is called lane "A". The next lane (which may be chosen
# at a later stage by the core) is called lane "B".
# Computations for both lanes are prepared in advance to increase performance.
current_lane = Signal(range(self.lane_count))
# The last coarse timestamp received from the CRI, after compensation.
# Used to determine when to switch lanes.
last_coarse_timestamp = Signal(self.us_timestamp_width)
# The last coarse timestamp written to each lane. Used to detect
# sequence errors.
last_lane_coarse_timestamps = Array(Signal(self.us_timestamp_width)
for _ in range(self.lane_count))
# Sequence number counter. The sequence number is used to determine which
# event wins during a replace.
seqn = Signal(self.seqn_width)
# distribute data to lanes
for lio in self.output:
m.d.comb += lio.seqn.eq(seqn)
m.d.comb += lio.payload.channel.eq(self.cri.chan_sel[:16])
m.d.comb += lio.payload.timestamp.eq(self.cri.o_timestamp)
if hasattr(lio.payload, "address"):
m.d.comb += lio.payload.address.eq(self.cri.o_address)
if hasattr(lio.payload, "data"):
m.d.comb += lio.payload.data.eq(self.cri.o_data)
# when timestamp and channel arrive in cycle #1, prepare computations
coarse_timestamp = Signal(self.us_timestamp_width)
m.d.comb += coarse_timestamp.eq(self.cri.o_timestamp[self.glbl_fine_ts_width:])
min_minus_timestamp = Signal(Shape(self.us_timestamp_width + 1, True),
reset_less=True)
laneAmin_minus_timestamp = Signal.like(min_minus_timestamp)
laneBmin_minus_timestamp = Signal.like(min_minus_timestamp)
last_minus_timestamp = Signal.like(min_minus_timestamp)
current_lane_plus_one = Signal(range(self.lane_count))
m.d.comb += current_lane_plus_one.eq(current_lane + 1)
m.d.sync += min_minus_timestamp.eq(self.minimum_coarse_timestamp - coarse_timestamp)
m.d.sync += laneAmin_minus_timestamp.eq(last_lane_coarse_timestamps[current_lane] - coarse_timestamp)
m.d.sync += laneBmin_minus_timestamp.eq(last_lane_coarse_timestamps[current_lane_plus_one] - coarse_timestamp)
m.d.sync += last_minus_timestamp.eq(last_coarse_timestamp - coarse_timestamp)
# Quash channels are "dummy" channels to which writes are completely ignored.
# This is used by the RTIO log channel, which is taken into account
# by the analyzer but does not enter the lanes.
quash = Signal()
m.d.sync += quash.eq(0)
for channel in self.quash_channels:
with m.If(self.cri.chan_sel[:16] == channel):
m.d.sync += quash.eq(1)
assert all(abs(c) < 1 << 14 - 1 for c in self.compensation)
latency_compensation = Memory(width=14, depth=len(self.compensation), init=self.compensation)
latency_compensation_rdport = latency_compensation.read_port()
m.submodules.latency_compensation_rdport = latency_compensation_rdport
m.d.comb += latency_compensation_rdport.addr.eq(self.cri.chan_sel[:16])
# cycle #2, write
compensation = Signal(Shape(14, True))
m.d.comb += compensation.eq(latency_compensation_rdport.data)
timestamp_above_min = Signal()
timestamp_above_last = Signal()
timestamp_above_laneA_min = Signal()
timestamp_above_laneB_min = Signal()
timestamp_above_lane_min = Signal()
force_laneB = Signal()
use_laneB = Signal()
use_lanen = Signal(range(self.lane_count))
do_write = Signal()
do_underflow = Signal()
do_sequence_error = Signal()
m.d.comb += timestamp_above_min.eq(min_minus_timestamp - compensation < 0)
m.d.comb += timestamp_above_laneA_min.eq(laneAmin_minus_timestamp - compensation < 0)
m.d.comb += timestamp_above_laneB_min.eq(laneBmin_minus_timestamp - compensation < 0)
m.d.comb += timestamp_above_last.eq(last_minus_timestamp - compensation < 0)
with m.If(force_laneB | ~timestamp_above_last):
m.d.comb += use_lanen.eq(current_lane_plus_one)
m.d.comb += use_laneB.eq(1)
with m.Else():
m.d.comb += use_lanen.eq(current_lane)
m.d.comb += use_laneB.eq(0)
m.d.comb += timestamp_above_lane_min.eq(Mux(use_laneB, timestamp_above_laneB_min, timestamp_above_laneA_min))
with m.If(~quash & (self.cri.cmd == cri.commands["write"])):
with m.If(timestamp_above_min):
with m.If(timestamp_above_lane_min):
m.d.comb += do_write.eq(1)
with m.Else():
m.d.comb += do_sequence_error.eq(1)
with m.Else():
m.d.comb += do_underflow.eq(1)
m.d.comb += Array(lio.we for lio in self.output)[use_lanen].eq(do_write)
compensated_timestamp = Signal(64)
m.d.comb += compensated_timestamp.eq(self.cri.o_timestamp + (compensation << self.glbl_fine_ts_width))
with m.If(do_write):
m.d.sync += current_lane.eq(use_lanen)
m.d.sync += last_coarse_timestamp.eq(compensated_timestamp[self.glbl_fine_ts_width:])
m.d.sync += last_lane_coarse_timestamps[use_lanen].eq(compensated_timestamp[self.glbl_fine_ts_width:])
m.d.sync += seqn.eq(seqn + 1)
for lio in self.output:
m.d.comb += lio.payload.timestamp.eq(compensated_timestamp)
# cycle #3, read status
current_lane_writable = Signal()
m.d.comb += current_lane_writable.eq(Array(lio.writable for lio in self.output)[current_lane])
m.d.comb += o_status_wait.eq(~current_lane_writable)
with m.If(self.cri.cmd == cri.commands["write"]):
m.d.sync += o_status_underflow.eq(0)
with m.If(do_underflow):
m.d.sync += o_status_underflow.eq(1)
m.d.sync += self.sequence_error.eq(do_sequence_error)
m.d.sync += self.sequence_error_channel.eq(self.cri.chan_sel[:16])
# current lane has been full, spread events by switching to the next.
if self.enable_spread:
current_lane_writable_r = Signal(reset=1)
m.d.sync += current_lane_writable_r.eq(current_lane_writable)
with m.If(~current_lane_writable_r & current_lane_writable):
m.d.sync += force_laneB.eq(1)
with m.If(do_write):
m.d.sync += force_laneB.eq(0)
return m

View File

@ -1,79 +1 @@
from nmigen import *
from nmigen.utils import *
from nmigen.hdl.rec import *
from rtio import rtlink
def fifo_payload(channels):
address_width = max(rtlink.get_address_width(channel.interface.o)
for channel in channels)
data_width = max(rtlink.get_data_width(channel.interface.o)
for channel in channels)
layout = [
("channel", bits_for(len(channels)-1)),
("timestamp", 64)
]
if address_width:
layout.append(("address", address_width))
if data_width:
layout.append(("data", data_width))
return layout
def seqn_width(lane_count, fifo_depth):
# There must be a unique sequence number for every possible event in every FIFO.
# Plus 2 bits to detect and handle wraparounds.
return bits_for(lane_count*fifo_depth-1) + 2
def fifo_ingress(seqn_width, layout_payload):
return [
("we", 1, DIR_FANOUT),
("writable", 1, DIR_FANIN),
("seqn", seqn_width, DIR_FANOUT),
("payload", [(a, b, DIR_FANOUT) for a, b in layout_payload])
]
def fifo_egress(seqn_width, layout_payload):
return [
("re", 1, DIR_FANIN),
("readable", 1, DIR_FANOUT),
("seqn", seqn_width, DIR_FANOUT),
("payload", [(a, b, DIR_FANOUT) for a, b in layout_payload])
]
# We use glbl_fine_ts_width in the output network so that collisions due
# to insufficiently increasing timestamps are always reliably detected.
# We can still have undetected collisions on the address by making it wrap
# around, but those are more rare and easier to debug, and addresses are
# not normally exposed directly to the ARTIQ user.
def output_network_payload(channels, glbl_fine_ts_width):
address_width = max(rtlink.get_address_width(channel.interface.o)
for channel in channels)
data_width = max(rtlink.get_data_width(channel.interface.o)
for channel in channels)
layout = [("channel", bits_for(len(channels)-1))]
if glbl_fine_ts_width:
layout.append(("fine_ts", glbl_fine_ts_width))
if address_width:
layout.append(("address", address_width))
if data_width:
layout.append(("data", data_width))
return layout
def output_network_node(seqn_width, layout_payload):
return [
("valid", 1),
("seqn", seqn_width),
("replace_occured", 1),
("nondata_replace_occured", 1),
("payload", layout_payload)
]

View File

@ -1,105 +1 @@
from functools import reduce
from operator import or_
from nmigen import *
from rtio.sed import layouts
from rtio.sed.output_network import OutputNetwork
__all__ = ["OutputDriver"]
class OutputDriver(Elaboratable):
def __init__(self, channels, glbl_fine_ts_width, lane_count, seqn_width):
self.channels = channels
self.glbl_fine_ts_width = glbl_fine_ts_width
self.lane_count = lane_count
self.seqn_width = seqn_width
self.collision = Signal()
self.collision_channel = Signal(range(len(channels)), reset_less=True)
self.busy = Signal()
self.busy_channel = Signal(range(len(channels)), reset_less=True)
self.input = None
def elaborate(self, platform):
m = Module()
# output network
layout_on_payload = layouts.output_network_payload(self.channels, self.glbl_fine_ts_width)
output_network = OutputNetwork(self.lane_count, self.seqn_width, layout_on_payload)
m.submodules.output_network = output_network
self.input = output_network.input
# detect collisions (adds one pipeline stage)
layout_lane_data = [
("valid", 1),
("collision", 1),
("payload", layout_on_payload)
]
lane_datas = [Record(layout_lane_data, reset_less=True) for _ in range(lane_count)]
en_replaces = [channel.interface.o.enable_replace for channel in channels]
for lane_data, on_output in zip(lane_datas, output_network.output):
lane_data.valid.reset_less = False
lane_data.collision.reset_less = False
replace_occurred_r = Signal()
nondata_replace_occurred_r = Signal()
m.d.sync += lane_data.valid.eq(on_output.valid)
m.d.sync += lane_data.payload.eq(on_output.payload)
m.d.sync += replace_occurred_r.eq(on_output.replace_occurred)
m.d.sync += nondata_replace_occurred_r.eq(on_output.nondata_replace_occurred)
en_replaces_rom = Memory(1, len(en_replaces), init=en_replaces)
en_replaces_rom_rdport = en_replaces_rom.read_port()
m.submodules.en_replaces_rom_rdport = en_replaces_rom_rdport
m.d.comb += en_replaces_rom_rdport.addr.eq(on_output.payload.channel)
m.d.comb += lane_data.collision.eq(replace_occurred_r & (~en_replaces_rom_rdport.data | nondata_replace_occurred_r))
m.d.sync += self.collision.eq(0)
m.d.sync += self.collision_channel.eq(0)
for lane_data in lane_datas:
with m.If(lane_data.valid & lane_data.collision):
m.d.sync += self.collision.eq(1)
m.d.sync += self.collision_channel.eq(lane_data.payload.channel)
# demultiplex channels (adds one pipeline stage)
for n, channel in enumerate(channels):
oif = channel.interface.o
onehot_stb = []
onehot_fine_ts = []
onehot_address = []
onehot_data = []
for lane_data in lane_datas:
selected = Signal()
m.d.comb += selected.eq(lane_data.valid & ~lane_data.collision & (lane_data.payload.channel == n))
onehot_stb.append(selected)
if hasattr(lane_data.payload, "fine_ts") and hasattr(oif, "fine_ts"):
ts_shift = len(lane_data.payload.fine_ts) - len(oif.fine_ts)
onehot_fine_ts.append(Mux(selected, lane_data.payload.fine_ts[ts_shift:], 0))
if hasattr(lane_data.payload, "address"):
onehot_address.append(Mux(selected, lane_data.payload.address, 0))
if hasattr(lane_data.payload, "data"):
onehot_data.append(Mux(selected, lane_data.payload.data, 0))
m.d.sync += oif.stb.eq(reduce(or_, onehot_stb))
if hasattr(oif, "fine_ts"):
m.d.sync += oif.fine_ts.eq(reduce(or_, onehot_fine_ts))
if hasattr(oif, "address"):
m.d.sync += oif.address.eq(reduce(or_, onehot_address))
if hasattr(oif, "data"):
m.d.sync += oif.data.eq(reduce(or_, onehot_data))
# detect busy errors, at lane level to reduce muxing
m.d.sync += self.busy.eq(0)
m.d.sync += self.busy_channel.eq(0)
for lane_data in lane_datas:
stb_r = Signal()
channel_r = Signal(range(len(channels)), reset_less=True)
m.d.sync += stb_r.eq(lane_data.valid & ~lane_data.collision)
m.d.sync += channel_r.eq(lane_data.payload.channel)
with m.If(stb_r & Array(channel.interface.o.busy for channel in channels)[channel_r]):
m.d.sync += self.busy.eq(1)
m.d.sync += self.busy_channel.eq(channel_r)
return m
OutputDriver([], 1, 1, 1).elaborate(None)

View File

@ -1,105 +1 @@
from nmigen import *
from nmigen.utils import *
from rtio.sed import layouts
__all__ = ["latency", "OutputNetwork"]
# Based on: https://github.com/Bekbolatov/SortingNetworks/blob/master/src/main/js/gr.js
def boms_get_partner(n, l, p):
if p == 1:
return n ^ (1 << (l - 1))
scale = 1 << (l - p)
box = 1 << p
sn = n//scale - n//scale//box*box
if sn == 0 or sn == (box - 1):
return n
if (sn % 2) == 0:
return n - scale
return n + scale
def boms_steps_pairs(lane_count):
d = log2_int(lane_count)
steps = []
for l in range(1, d+1):
for p in range(1, l+1):
pairs = []
for n in range(2**d):
partner = boms_get_partner(n, l, p)
if partner != n:
if partner > n:
pair = (n, partner)
else:
pair = (partner, n)
if pair not in pairs:
pairs.append(pair)
steps.append(pairs)
return steps
def latency(lane_count):
d = log2_int(lane_count)
return sum(l for l in range(1, d+1))
def cmp_wrap(a, b):
return Mux((a[-2] == a[-1]) & (b[-2] == b[-1]) & (a[-1] != b[-1]), a[-1], a < b)
class OutputNetwork(Elaboratable):
def __init__(self, lane_count, seqn_width, layout_payload):
self.lane_count = lane_count
self.seqn_width = seqn_width
self.layout_payload = layout_payload
self.input = [Record(layouts.output_network_node(seqn_width, layout_payload))
for _ in range(lane_count)]
self.output = None
def elaborate(self, platform):
m = Module()
step_input = self.input
for step in boms_steps_pairs(self.lane_count):
step_output = []
for i in range(lane_count):
rec = Record(layouts.output_network_node(seqn_width, layout_payload),
reset_less=True)
rec.valid.reset_less = False
step_output.append(rec)
for node1, node2 in step:
nondata_difference = Signal()
for field, _ in layout_payload:
if field != "data":
f1 = getattr(step_input[node1].payload, field)
f2 = getattr(step_input[node2].payload, field)
with m.If(f1 != f2):
m.d.comb += nondata_difference.eq(1)
k1 = Cat(step_input[node1].payload.channel, ~step_input[node1].valid)
k2 = Cat(step_input[node2].payload.channel, ~step_input[node2].valid)
with m.If(k1 == k2):
with m.If(cmp_wrap(step_input[node1].seqn, step_input[node2].seqn)):
m.d.sync += step_output[node1].eq(step_input[node2])
m.d.sync += step_output[node2].eq(step_input[node1])
with m.Else():
m.d.sync += step_output[node1].eq(step_input[node1])
m.d.sync += step_output[node2].eq(step_input[node2])
m.d.sync += step_output[node1].replace_occurred.eq(1)
m.d.sync += step_output[node1].nondata_replace_occurred.eq(nondata_difference),
m.d.sync += step_output[node2].valid.eq(0)
with m.Elif(k1 < k2):
m.d.sync += step_output[node1].eq(step_input[node1])
m.d.sync += step_output[node2].eq(step_input[node2])
with m.Else():
m.d.sync += step_output[node1].eq(step_input[node2])
m.d.sync += step_output[node2].eq(step_input[node1])
unchanged = list(range(lane_count))
for node1, node2 in step:
unchanged.remove(node1)
unchanged.remove(node2)
for node in unchanged:
m.d.sync += step_output[node].eq(step_input[node])
self.output = step_output
step_input = step_output
return m