From 36fb6306b04245e5e3f9043f6a52eb11d476e1aa Mon Sep 17 00:00:00 2001
From: Donald Sebastian Leung
Date: Tue, 29 Sep 2020 16:35:59 +0800
Subject: [PATCH] Add rtio.sed.output_network
---
README.md | 2 +-
rtio/sed/output_network.py | 105 +++++++++++++++++++++++++++++++++++++
2 files changed, 106 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 0dfe53f..b0bdfe5 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ Formally verified implementation of the ARTIQ RTIO core in nMigen
- - [ ] `rtio.sed.fifos`
- - [ ] `rtio.sed.gates`
- - [ ] `rtio.sed.output_driver`
-- - [ ] `rtio.sed.output_network`
+- - [x] `rtio.sed.output_network`
- - [ ] `rtio.input_collector`
- [ ] Add suitable assertions for verification (BMC / unbounded proof?)
diff --git a/rtio/sed/output_network.py b/rtio/sed/output_network.py
index e69de29..8c60624 100644
--- a/rtio/sed/output_network.py
+++ b/rtio/sed/output_network.py
@@ -0,0 +1,105 @@
+from nmigen import *
+from nmigen.utils import *
+
+from rtio.sed import layouts
+
+__all__ = ["latency", "OutputNetwork"]
+
+# Based on: https://github.com/Bekbolatov/SortingNetworks/blob/master/src/main/js/gr.js
+def boms_get_partner(n, l, p):
+ if p == 1:
+ return n ^ (1 << (l - 1))
+ scale = 1 << (l - p)
+ box = 1 << p
+ sn = n//scale - n//scale//box*box
+ if sn == 0 or sn == (box - 1):
+ return n
+ if (sn % 2) == 0:
+ return n - scale
+ return n + scale
+
+def boms_steps_pairs(lane_count):
+ d = log2_int(lane_count)
+ steps = []
+ for l in range(1, d+1):
+ for p in range(1, l+1):
+ pairs = []
+ for n in range(2**d):
+ partner = boms_get_partner(n, l, p)
+ if partner != n:
+ if partner > n:
+ pair = (n, partner)
+ else:
+ pair = (partner, n)
+ if pair not in pairs:
+ pairs.append(pair)
+ steps.append(pairs)
+ return steps
+
+def latency(lane_count):
+ d = log2_int(lane_count)
+ return sum(l for l in range(1, d+1))
+
+def cmp_wrap(a, b):
+ return Mux((a[-2] == a[-1]) & (b[-2] == b[-1]) & (a[-1] != b[-1]), a[-1], a < b)
+
+class OutputNetwork(Elaboratable):
+ def __init__(self, lane_count, seqn_width, layout_payload):
+ self.lane_count = lane_count
+ self.seqn_width = seqn_width
+ self.layout_payload = layout_payload
+ self.input = [Record(layouts.output_network_node(seqn_width, layout_payload))
+ for _ in range(lane_count)]
+ self.output = None
+
+ def elaborate(self, platform):
+ m = Module()
+
+ step_input = self.input
+ for step in boms_steps_pairs(self.lane_count):
+ step_output = []
+ for i in range(lane_count):
+ rec = Record(layouts.output_network_node(seqn_width, layout_payload),
+ reset_less=True)
+ rec.valid.reset_less = False
+ step_output.append(rec)
+
+ for node1, node2 in step:
+ nondata_difference = Signal()
+ for field, _ in layout_payload:
+ if field != "data":
+ f1 = getattr(step_input[node1].payload, field)
+ f2 = getattr(step_input[node2].payload, field)
+ with m.If(f1 != f2):
+ m.d.comb += nondata_difference.eq(1)
+
+ k1 = Cat(step_input[node1].payload.channel, ~step_input[node1].valid)
+ k2 = Cat(step_input[node2].payload.channel, ~step_input[node2].valid)
+ with m.If(k1 == k2):
+ with m.If(cmp_wrap(step_input[node1].seqn, step_input[node2].seqn)):
+ m.d.sync += step_output[node1].eq(step_input[node2])
+ m.d.sync += step_output[node2].eq(step_input[node1])
+ with m.Else():
+ m.d.sync += step_output[node1].eq(step_input[node1])
+ m.d.sync += step_output[node2].eq(step_input[node2])
+ m.d.sync += step_output[node1].replace_occurred.eq(1)
+ m.d.sync += step_output[node1].nondata_replace_occurred.eq(nondata_difference),
+ m.d.sync += step_output[node2].valid.eq(0)
+ with m.Elif(k1 < k2):
+ m.d.sync += step_output[node1].eq(step_input[node1])
+ m.d.sync += step_output[node2].eq(step_input[node2])
+ with m.Else():
+ m.d.sync += step_output[node1].eq(step_input[node2])
+ m.d.sync += step_output[node2].eq(step_input[node1])
+
+ unchanged = list(range(lane_count))
+ for node1, node2 in step:
+ unchanged.remove(node1)
+ unchanged.remove(node2)
+ for node in unchanged:
+ m.d.sync += step_output[node].eq(step_input[node])
+
+ self.output = step_output
+ step_input = step_output
+
+ return m