diff --git a/README.md b/README.md index 0dfe53f..b0bdfe5 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Formally verified implementation of the ARTIQ RTIO core in nMigen - - [ ] `rtio.sed.fifos` - - [ ] `rtio.sed.gates` - - [ ] `rtio.sed.output_driver` -- - [ ] `rtio.sed.output_network` +- - [x] `rtio.sed.output_network` - - [ ] `rtio.input_collector` - [ ] Add suitable assertions for verification (BMC / unbounded proof?) diff --git a/rtio/sed/output_network.py b/rtio/sed/output_network.py index e69de29..8c60624 100644 --- a/rtio/sed/output_network.py +++ b/rtio/sed/output_network.py @@ -0,0 +1,105 @@ +from nmigen import * +from nmigen.utils import * + +from rtio.sed import layouts + +__all__ = ["latency", "OutputNetwork"] + +# Based on: https://github.com/Bekbolatov/SortingNetworks/blob/master/src/main/js/gr.js +def boms_get_partner(n, l, p): + if p == 1: + return n ^ (1 << (l - 1)) + scale = 1 << (l - p) + box = 1 << p + sn = n//scale - n//scale//box*box + if sn == 0 or sn == (box - 1): + return n + if (sn % 2) == 0: + return n - scale + return n + scale + +def boms_steps_pairs(lane_count): + d = log2_int(lane_count) + steps = [] + for l in range(1, d+1): + for p in range(1, l+1): + pairs = [] + for n in range(2**d): + partner = boms_get_partner(n, l, p) + if partner != n: + if partner > n: + pair = (n, partner) + else: + pair = (partner, n) + if pair not in pairs: + pairs.append(pair) + steps.append(pairs) + return steps + +def latency(lane_count): + d = log2_int(lane_count) + return sum(l for l in range(1, d+1)) + +def cmp_wrap(a, b): + return Mux((a[-2] == a[-1]) & (b[-2] == b[-1]) & (a[-1] != b[-1]), a[-1], a < b) + +class OutputNetwork(Elaboratable): + def __init__(self, lane_count, seqn_width, layout_payload): + self.lane_count = lane_count + self.seqn_width = seqn_width + self.layout_payload = layout_payload + self.input = [Record(layouts.output_network_node(seqn_width, layout_payload)) + for _ in range(lane_count)] + self.output = None + + def elaborate(self, platform): + m = Module() + + step_input = self.input + for step in boms_steps_pairs(self.lane_count): + step_output = [] + for i in range(lane_count): + rec = Record(layouts.output_network_node(seqn_width, layout_payload), + reset_less=True) + rec.valid.reset_less = False + step_output.append(rec) + + for node1, node2 in step: + nondata_difference = Signal() + for field, _ in layout_payload: + if field != "data": + f1 = getattr(step_input[node1].payload, field) + f2 = getattr(step_input[node2].payload, field) + with m.If(f1 != f2): + m.d.comb += nondata_difference.eq(1) + + k1 = Cat(step_input[node1].payload.channel, ~step_input[node1].valid) + k2 = Cat(step_input[node2].payload.channel, ~step_input[node2].valid) + with m.If(k1 == k2): + with m.If(cmp_wrap(step_input[node1].seqn, step_input[node2].seqn)): + m.d.sync += step_output[node1].eq(step_input[node2]) + m.d.sync += step_output[node2].eq(step_input[node1]) + with m.Else(): + m.d.sync += step_output[node1].eq(step_input[node1]) + m.d.sync += step_output[node2].eq(step_input[node2]) + m.d.sync += step_output[node1].replace_occurred.eq(1) + m.d.sync += step_output[node1].nondata_replace_occurred.eq(nondata_difference), + m.d.sync += step_output[node2].valid.eq(0) + with m.Elif(k1 < k2): + m.d.sync += step_output[node1].eq(step_input[node1]) + m.d.sync += step_output[node2].eq(step_input[node2]) + with m.Else(): + m.d.sync += step_output[node1].eq(step_input[node2]) + m.d.sync += step_output[node2].eq(step_input[node1]) + + unchanged = list(range(lane_count)) + for node1, node2 in step: + unchanged.remove(node1) + unchanged.remove(node2) + for node in unchanged: + m.d.sync += step_output[node].eq(step_input[node]) + + self.output = step_output + step_input = step_output + + return m