diff --git a/rtio/sed/output_network.py b/rtio/sed/output_network.py index ff33c20..c789627 100644 --- a/rtio/sed/output_network.py +++ b/rtio/sed/output_network.py @@ -55,10 +55,10 @@ class OutputNetwork(Elaboratable): if fv_mode: # Model arbitrary inputs for network nodes for i in range(lane_count): - m.d.comb += self.input[i].valid.eq(AnySeq(1)) + m.d.comb += self.input[i].valid.eq(1) m.d.comb += self.input[i].seqn.eq(AnySeq(seqn_width)) - m.d.comb += self.input[i].replace_occured.eq(AnySeq(1)) - m.d.comb += self.input[i].nondata_replace_occured.eq(AnySeq(1)) + m.d.comb += self.input[i].replace_occured.eq(0) + m.d.comb += self.input[i].nondata_replace_occured.eq(0) for field, width in layout_payload: m.d.comb += getattr(self.input[i].payload, field).eq(AnySeq(width)) @@ -119,9 +119,51 @@ class OutputNetwork(Elaboratable): f_past_valid = Signal() m.d.sync += f_past_valid.eq(1) - # Valid nodes always come first in outputs - for i in range(lane_count - 1): - m.d.comb += Assert(self.output[i].valid | ~self.output[i + 1].valid) # TODO: Figure out why this is failing + # Indicator of when inputs from the first clock cycle make it + # through the sorting network + network_latency = latency(lane_count) + counter = Signal(range(network_latency + 1)) + m.d.sync += counter.eq(counter + 1) + with m.If(counter == network_latency): + m.d.sync += counter.eq(counter) + f_output_valid = Signal() + m.d.comb += f_output_valid.eq(counter == network_latency) + + # If there are no replacements, all input data are unique and they + # all make it through the sorting network + with m.If(f_output_valid): + replacement_occurred = Signal() + for node in self.output: + with m.If(node.replace_occured): + m.d.comb += replacement_occurred.eq(1) + with m.If(~replacement_occurred): + nodes_unique = Signal(reset=1) + for node1 in range(len(self.input)): + for node2 in range(node1): + k1 = Cat(Past(self.input[node1].payload.channel, clocks=network_latency), ~Past(self.input[node1].valid, clocks=network_latency)) + k2 = Cat(Past(self.input[node2].payload.channel, clocks=network_latency), ~Past(self.input[node2].valid, clocks=network_latency)) + with m.If(k1 == k2): + m.d.comb += nodes_unique.eq(0) + m.d.comb += Assert(nodes_unique) + # TODO: figure out why the rest is failing + # appeared = Signal(len(self.input)) + # for input_node in range(len(self.input)): + # for output_node in self.output: + # identical = Signal(reset=1) + # with m.If(Past(self.input[input_node].valid, clocks=network_latency) != output_node.valid): + # m.d.comb += identical.eq(0) + # with m.If(Past(self.input[input_node].seqn, clocks=network_latency) != output_node.seqn): + # m.d.comb += identical.eq(0) + # with m.If(Past(self.input[input_node].replace_occured, clocks=network_latency) != output_node.replace_occured): + # m.d.comb += identical.eq(0) + # with m.If(Past(self.input[input_node].nondata_replace_occured, clocks=network_latency) != output_node.nondata_replace_occured): + # m.d.comb += identical.eq(0) + # for field, _ in layout_payload: + # with m.If(Past(getattr(self.input[input_node].payload, field), clocks=network_latency) != getattr(output_node.payload, field)): + # m.d.comb += identical.eq(0) + # m.d.comb += appeared[input_node].eq(identical) + # for i in range(len(self.input)): + # m.d.comb += Assert(appeared[i]) def elaborate(self, platform): return self.m