From e3a8229ac24d93e3a2cc6b2994b7e26877fad415 Mon Sep 17 00:00:00 2001 From: Donald Sebastian Leung Date: Fri, 23 Oct 2020 14:57:29 +0800 Subject: [PATCH] Rewrite sorting network to follow nMigen convention --- README.md | 2 +- rtio/sed/output_network.py | 124 +++++++++++++++++-------------------- 2 files changed, 59 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 44fd9a0..72569c9 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ $ python -m rtio.test.sed.output_network - - [ ] `rtio.cri` (`Interface` and `CRIDecoder` only) - - [ ] `rtio.rtlink` - - [ ] `rtio.sed.layouts` -- - [ ] `rtio.sed.output_network` +- - [x] `rtio.sed.output_network` - - [ ] `rtio.sed.output_driver` ## License diff --git a/rtio/sed/output_network.py b/rtio/sed/output_network.py index 9e68987..df76ce2 100644 --- a/rtio/sed/output_network.py +++ b/rtio/sed/output_network.py @@ -46,82 +46,75 @@ def cmp_wrap(a, b): class OutputNetwork(Elaboratable): def __init__(self, lane_count, seqn_width, layout_payload, *, fv_mode=False): - m = Module() - self.m = m - self.input = [Record(layouts.output_network_node(seqn_width, layout_payload)) - for _ in range(lane_count)] - self.output = None + self.lane_count = lane_count + self.seqn_width = seqn_width + self.layout_payload = layout_payload + self.fv_mode = fv_mode - if fv_mode: - # Model arbitrary inputs for network nodes - for i in range(lane_count): - m.d.comb += self.input[i].valid.eq(1) - m.d.comb += self.input[i].seqn.eq(AnySeq(seqn_width)) - m.d.comb += self.input[i].replace_occured.eq(0) - m.d.comb += self.input[i].nondata_replace_occured.eq(0) - for field, width in layout_payload: - m.d.comb += getattr(self.input[i].payload, field).eq(AnySeq(width)) - - step_input = self.input - for step in boms_steps_pairs(lane_count): - step_output = [] - for i in range(lane_count): - rec = Record(layouts.output_network_node(seqn_width, layout_payload), - reset_less=True) - rec.valid.reset_less = False - step_output.append(rec) - - for node1, node2 in step: - nondata_difference = Signal() + self.steps = boms_steps_pairs(lane_count) + self.network = [[Record(layouts.output_network_node(seqn_width, layout_payload)) for _ in range(lane_count)] for _ in range(len(self.steps) + 1)] + for i in range(1, len(self.steps) + 1): + for rec in self.network[i]: + rec.seqn.reset_less = True + rec.replace_occured.reset_less = True + rec.nondata_replace_occured.reset_less = True for field, _ in layout_payload: - if field != "data": - f1 = getattr(step_input[node1].payload, field) - f2 = getattr(step_input[node2].payload, field) + getattr(rec.payload, field).reset_less = True + self.input = self.network[0] + self.output = self.network[-1] + + def elaborate(self, platform): + m = Module() + + for i in range(len(self.steps)): + for node1, node2 in self.steps[i]: + nondata_difference = Signal() + for field, _ in self.layout_payload: + if field != 'data': + f1 = getattr(self.network[i][node1].payload, field) + f2 = getattr(self.network[i][node2].payload, field) with m.If(f1 != f2): m.d.comb += nondata_difference.eq(1) - k1 = Cat(step_input[node1].payload.channel, ~step_input[node1].valid) - k2 = Cat(step_input[node2].payload.channel, ~step_input[node2].valid) + k1 = Cat(self.network[i][node1].payload.channel, ~self.network[i][node1].valid) + k2 = Cat(self.network[i][node2].payload.channel, ~self.network[i][node2].valid) with m.If(k1 == k2): - with m.If(cmp_wrap(step_input[node1].seqn, step_input[node2].seqn)): - m.d.sync += step_output[node1].eq(step_input[node2]) - m.d.sync += step_output[node2].eq(step_input[node1]) + with m.If(cmp_wrap(self.network[i][node1].seqn, self.network[i][node2].seqn)): + m.d.sync += self.network[i + 1][node1].eq(self.network[i][node2]) + m.d.sync += self.network[i + 1][node2].eq(self.network[i][node1]) with m.Else(): - m.d.sync += step_output[node1].eq(step_input[node1]) - m.d.sync += step_output[node2].eq(step_input[node2]) - m.d.sync += step_output[node1].replace_occured.eq(1) - m.d.sync += step_output[node1].nondata_replace_occured.eq(nondata_difference) - m.d.sync += step_output[node2].valid.eq(0) + m.d.sync += self.network[i + 1][node1].eq(self.network[i][node1]) + m.d.sync += self.network[i + 1][node2].eq(self.network[i][node2]) + m.d.sync += self.network[i + 1][node1].replace_occured.eq(1) + m.d.sync += self.network[i + 1][node1].nondata_replace_occured.eq(nondata_difference) + m.d.sync += self.network[i + 1][node2].valid.eq(0) with m.Elif(k1 < k2): - m.d.sync += step_output[node1].eq(step_input[node1]) - m.d.sync += step_output[node2].eq(step_input[node2]) + m.d.sync += self.network[i + 1][node1].eq(self.network[i][node1]) + m.d.sync += self.network[i + 1][node2].eq(self.network[i][node2]) with m.Else(): - m.d.sync += step_output[node1].eq(step_input[node2]) - m.d.sync += step_output[node2].eq(step_input[node1]) - - unchanged = list(range(lane_count)) - for node1, node2 in step: + m.d.sync += self.network[i + 1][node1].eq(self.network[i][node2]) + m.d.sync += self.network[i + 1][node2].eq(self.network[i][node1]) + + unchanged = list(range(self.lane_count)) + for node1, node2 in self.steps[i]: unchanged.remove(node1) unchanged.remove(node2) for node in unchanged: - m.d.sync += step_output[node].eq(step_input[node]) + m.d.sync += self.network[i + 1][node].eq(self.network[i][node]) - self.output = step_output - step_input = step_output - - if fv_mode: - # Sanity checks - assert self.output is not None - assert len(self.input) == lane_count - assert len(self.output) == lane_count - - # Indicator of when Past() is valid - f_past_valid = Signal() - m.d.sync += f_past_valid.eq(1) + if self.fv_mode: + # Model arbitrary inputs for network nodes + for i in range(self.lane_count): + m.d.comb += self.input[i].valid.eq(1) + m.d.comb += self.input[i].seqn.eq(AnySeq(self.seqn_width)) + m.d.comb += self.input[i].replace_occured.eq(0) + m.d.comb += self.input[i].nondata_replace_occured.eq(0) + for field, width in self.layout_payload: + m.d.comb += getattr(self.input[i].payload, field).eq(AnySeq(width)) # Indicator of when inputs from the first clock cycle make it # through the sorting network - network_latency = latency(lane_count) + network_latency = latency(self.lane_count) counter = Signal(range(network_latency + 1)) m.d.sync += counter.eq(counter + 1) with m.If(counter == network_latency): @@ -142,7 +135,7 @@ class OutputNetwork(Elaboratable): with m.If(k1 == k2): m.d.comb += channels_unique.eq(0) # If there are no replacements then: - # - (Input) channel numbers are unique + # - Channel numbers are unique # - All outputs are valid # - All inputs make it through the sorting network with m.If(~replacement_occurred): @@ -161,14 +154,14 @@ class OutputNetwork(Elaboratable): m.d.comb += match.eq(0) with m.If(Past(input_node.nondata_replace_occured, clocks=network_latency) != output_node.nondata_replace_occured): m.d.comb += match.eq(0) - for field, _ in layout_payload: + for field, _ in self.layout_payload: with m.If(Past(getattr(input_node.payload, field), clocks=network_latency) != getattr(output_node.payload, field)): m.d.comb += match.eq(0) with m.If(match): m.d.comb += appeared.eq(1) m.d.comb += Assert(appeared) # Otherwise, if there are replacements: - # - Channel numbers are not unique + # - Channel number are not unique # - Not all outputs are valid # - All channel numbers in the input appear exactly once as a # valid output @@ -197,12 +190,11 @@ class OutputNetwork(Elaboratable): match = Signal(reset=1) with m.If(Past(input_node.seqn, clocks=network_latency) != output_node.seqn): m.d.comb += match.eq(0) - for field, _ in layout_payload: + for field, _ in self.layout_payload: with m.If(Past(getattr(input_node.payload, field), clocks=network_latency) != getattr(output_node.payload, field)): m.d.comb += match.eq(0) with m.If(match): m.d.comb += found_input.eq(1) m.d.comb += Assert(found_input) - def elaborate(self, platform): - return self.m + return m