From e3a8229ac24d93e3a2cc6b2994b7e26877fad415 Mon Sep 17 00:00:00 2001
From: Donald Sebastian Leung
Date: Fri, 23 Oct 2020 14:57:29 +0800
Subject: [PATCH] Rewrite sorting network to follow nMigen convention
---
README.md | 2 +-
rtio/sed/output_network.py | 124 +++++++++++++++++--------------------
2 files changed, 59 insertions(+), 67 deletions(-)
diff --git a/README.md b/README.md
index 44fd9a0..72569c9 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ $ python -m rtio.test.sed.output_network
- - [ ] `rtio.cri` (`Interface` and `CRIDecoder` only)
- - [ ] `rtio.rtlink`
- - [ ] `rtio.sed.layouts`
-- - [ ] `rtio.sed.output_network`
+- - [x] `rtio.sed.output_network`
- - [ ] `rtio.sed.output_driver`
## License
diff --git a/rtio/sed/output_network.py b/rtio/sed/output_network.py
index 9e68987..df76ce2 100644
--- a/rtio/sed/output_network.py
+++ b/rtio/sed/output_network.py
@@ -46,82 +46,75 @@ def cmp_wrap(a, b):
class OutputNetwork(Elaboratable):
def __init__(self, lane_count, seqn_width, layout_payload, *, fv_mode=False):
- m = Module()
- self.m = m
- self.input = [Record(layouts.output_network_node(seqn_width, layout_payload))
- for _ in range(lane_count)]
- self.output = None
+ self.lane_count = lane_count
+ self.seqn_width = seqn_width
+ self.layout_payload = layout_payload
+ self.fv_mode = fv_mode
- if fv_mode:
- # Model arbitrary inputs for network nodes
- for i in range(lane_count):
- m.d.comb += self.input[i].valid.eq(1)
- m.d.comb += self.input[i].seqn.eq(AnySeq(seqn_width))
- m.d.comb += self.input[i].replace_occured.eq(0)
- m.d.comb += self.input[i].nondata_replace_occured.eq(0)
- for field, width in layout_payload:
- m.d.comb += getattr(self.input[i].payload, field).eq(AnySeq(width))
-
- step_input = self.input
- for step in boms_steps_pairs(lane_count):
- step_output = []
- for i in range(lane_count):
- rec = Record(layouts.output_network_node(seqn_width, layout_payload),
- reset_less=True)
- rec.valid.reset_less = False
- step_output.append(rec)
-
- for node1, node2 in step:
- nondata_difference = Signal()
+ self.steps = boms_steps_pairs(lane_count)
+ self.network = [[Record(layouts.output_network_node(seqn_width, layout_payload)) for _ in range(lane_count)] for _ in range(len(self.steps) + 1)]
+ for i in range(1, len(self.steps) + 1):
+ for rec in self.network[i]:
+ rec.seqn.reset_less = True
+ rec.replace_occured.reset_less = True
+ rec.nondata_replace_occured.reset_less = True
for field, _ in layout_payload:
- if field != "data":
- f1 = getattr(step_input[node1].payload, field)
- f2 = getattr(step_input[node2].payload, field)
+ getattr(rec.payload, field).reset_less = True
+ self.input = self.network[0]
+ self.output = self.network[-1]
+
+ def elaborate(self, platform):
+ m = Module()
+
+ for i in range(len(self.steps)):
+ for node1, node2 in self.steps[i]:
+ nondata_difference = Signal()
+ for field, _ in self.layout_payload:
+ if field != 'data':
+ f1 = getattr(self.network[i][node1].payload, field)
+ f2 = getattr(self.network[i][node2].payload, field)
with m.If(f1 != f2):
m.d.comb += nondata_difference.eq(1)
- k1 = Cat(step_input[node1].payload.channel, ~step_input[node1].valid)
- k2 = Cat(step_input[node2].payload.channel, ~step_input[node2].valid)
+ k1 = Cat(self.network[i][node1].payload.channel, ~self.network[i][node1].valid)
+ k2 = Cat(self.network[i][node2].payload.channel, ~self.network[i][node2].valid)
with m.If(k1 == k2):
- with m.If(cmp_wrap(step_input[node1].seqn, step_input[node2].seqn)):
- m.d.sync += step_output[node1].eq(step_input[node2])
- m.d.sync += step_output[node2].eq(step_input[node1])
+ with m.If(cmp_wrap(self.network[i][node1].seqn, self.network[i][node2].seqn)):
+ m.d.sync += self.network[i + 1][node1].eq(self.network[i][node2])
+ m.d.sync += self.network[i + 1][node2].eq(self.network[i][node1])
with m.Else():
- m.d.sync += step_output[node1].eq(step_input[node1])
- m.d.sync += step_output[node2].eq(step_input[node2])
- m.d.sync += step_output[node1].replace_occured.eq(1)
- m.d.sync += step_output[node1].nondata_replace_occured.eq(nondata_difference)
- m.d.sync += step_output[node2].valid.eq(0)
+ m.d.sync += self.network[i + 1][node1].eq(self.network[i][node1])
+ m.d.sync += self.network[i + 1][node2].eq(self.network[i][node2])
+ m.d.sync += self.network[i + 1][node1].replace_occured.eq(1)
+ m.d.sync += self.network[i + 1][node1].nondata_replace_occured.eq(nondata_difference)
+ m.d.sync += self.network[i + 1][node2].valid.eq(0)
with m.Elif(k1 < k2):
- m.d.sync += step_output[node1].eq(step_input[node1])
- m.d.sync += step_output[node2].eq(step_input[node2])
+ m.d.sync += self.network[i + 1][node1].eq(self.network[i][node1])
+ m.d.sync += self.network[i + 1][node2].eq(self.network[i][node2])
with m.Else():
- m.d.sync += step_output[node1].eq(step_input[node2])
- m.d.sync += step_output[node2].eq(step_input[node1])
-
- unchanged = list(range(lane_count))
- for node1, node2 in step:
+ m.d.sync += self.network[i + 1][node1].eq(self.network[i][node2])
+ m.d.sync += self.network[i + 1][node2].eq(self.network[i][node1])
+
+ unchanged = list(range(self.lane_count))
+ for node1, node2 in self.steps[i]:
unchanged.remove(node1)
unchanged.remove(node2)
for node in unchanged:
- m.d.sync += step_output[node].eq(step_input[node])
+ m.d.sync += self.network[i + 1][node].eq(self.network[i][node])
- self.output = step_output
- step_input = step_output
-
- if fv_mode:
- # Sanity checks
- assert self.output is not None
- assert len(self.input) == lane_count
- assert len(self.output) == lane_count
-
- # Indicator of when Past() is valid
- f_past_valid = Signal()
- m.d.sync += f_past_valid.eq(1)
+ if self.fv_mode:
+ # Model arbitrary inputs for network nodes
+ for i in range(self.lane_count):
+ m.d.comb += self.input[i].valid.eq(1)
+ m.d.comb += self.input[i].seqn.eq(AnySeq(self.seqn_width))
+ m.d.comb += self.input[i].replace_occured.eq(0)
+ m.d.comb += self.input[i].nondata_replace_occured.eq(0)
+ for field, width in self.layout_payload:
+ m.d.comb += getattr(self.input[i].payload, field).eq(AnySeq(width))
# Indicator of when inputs from the first clock cycle make it
# through the sorting network
- network_latency = latency(lane_count)
+ network_latency = latency(self.lane_count)
counter = Signal(range(network_latency + 1))
m.d.sync += counter.eq(counter + 1)
with m.If(counter == network_latency):
@@ -142,7 +135,7 @@ class OutputNetwork(Elaboratable):
with m.If(k1 == k2):
m.d.comb += channels_unique.eq(0)
# If there are no replacements then:
- # - (Input) channel numbers are unique
+ # - Channel numbers are unique
# - All outputs are valid
# - All inputs make it through the sorting network
with m.If(~replacement_occurred):
@@ -161,14 +154,14 @@ class OutputNetwork(Elaboratable):
m.d.comb += match.eq(0)
with m.If(Past(input_node.nondata_replace_occured, clocks=network_latency) != output_node.nondata_replace_occured):
m.d.comb += match.eq(0)
- for field, _ in layout_payload:
+ for field, _ in self.layout_payload:
with m.If(Past(getattr(input_node.payload, field), clocks=network_latency) != getattr(output_node.payload, field)):
m.d.comb += match.eq(0)
with m.If(match):
m.d.comb += appeared.eq(1)
m.d.comb += Assert(appeared)
# Otherwise, if there are replacements:
- # - Channel numbers are not unique
+ # - Channel number are not unique
# - Not all outputs are valid
# - All channel numbers in the input appear exactly once as a
# valid output
@@ -197,12 +190,11 @@ class OutputNetwork(Elaboratable):
match = Signal(reset=1)
with m.If(Past(input_node.seqn, clocks=network_latency) != output_node.seqn):
m.d.comb += match.eq(0)
- for field, _ in layout_payload:
+ for field, _ in self.layout_payload:
with m.If(Past(getattr(input_node.payload, field), clocks=network_latency) != getattr(output_node.payload, field)):
m.d.comb += match.eq(0)
with m.If(match):
m.d.comb += found_input.eq(1)
m.d.comb += Assert(found_input)
- def elaborate(self, platform):
- return self.m
+ return m