diff --git a/rtio/sed/output_network.py b/rtio/sed/output_network.py index 22f9238..affdb0d 100644 --- a/rtio/sed/output_network.py +++ b/rtio/sed/output_network.py @@ -129,21 +129,22 @@ class OutputNetwork(Elaboratable): f_output_valid = Signal() m.d.comb += f_output_valid.eq(counter == network_latency) - # If there are no replacements, all input data are unique and they - # all make it through the sorting network with m.If(f_output_valid): replacement_occurred = Signal() for node in self.output: with m.If(node.replace_occured): m.d.comb += replacement_occurred.eq(1) + nodes_unique = Signal(reset=1) + for node1 in range(len(self.input)): + for node2 in range(node1): + k1 = Cat(Past(self.input[node1].payload.channel, clocks=network_latency), ~Past(self.input[node1].valid, clocks=network_latency)) + k2 = Cat(Past(self.input[node2].payload.channel, clocks=network_latency), ~Past(self.input[node2].valid, clocks=network_latency)) + with m.If(k1 == k2): + m.d.comb += nodes_unique.eq(0) + # If there are no replacements then: + # - All input data are unique + # - They all make it through the sorting network with m.If(~replacement_occurred): - nodes_unique = Signal(reset=1) - for node1 in range(len(self.input)): - for node2 in range(node1): - k1 = Cat(Past(self.input[node1].payload.channel, clocks=network_latency), ~Past(self.input[node1].valid, clocks=network_latency)) - k2 = Cat(Past(self.input[node2].payload.channel, clocks=network_latency), ~Past(self.input[node2].valid, clocks=network_latency)) - with m.If(k1 == k2): - m.d.comb += nodes_unique.eq(0) m.d.comb += Assert(nodes_unique) for input_node in self.input: appeared = Signal() @@ -164,8 +165,6 @@ class OutputNetwork(Elaboratable): m.d.comb += appeared.eq(1) m.d.comb += Assert(appeared) - # If the valid bit / channel no. combinations of all input data are - # unique then there should be no replacements with m.If(f_output_valid): nodes_unique = Signal(reset=1) for node1 in range(len(self.input)): @@ -174,11 +173,13 @@ class OutputNetwork(Elaboratable): k2 = Cat(Past(self.input[node2].payload.channel, clocks=network_latency), ~Past(self.input[node2].valid, clocks=network_latency)) with m.If(k1 == k2): m.d.comb += nodes_unique.eq(0) + replacement_occurred = Signal() + for output_node in self.output: + with m.If(output_node.replace_occured): + m.d.comb += replacement_occurred.eq(1) + # If the valid bit / channel no. combinations of all input data + # are unique then there should be no replacements with m.If(nodes_unique): - replacement_occurred = Signal() - for output_node in self.output: - with m.If(output_node.replace_occured): - m.d.comb += replacement_occurred.eq(1) m.d.comb += Assert(~replacement_occurred) def elaborate(self, platform):