diff --git a/rtio/sed/output_network.py b/rtio/sed/output_network.py index ed7737d..9b4755d 100644 --- a/rtio/sed/output_network.py +++ b/rtio/sed/output_network.py @@ -45,11 +45,10 @@ def cmp_wrap(a, b): return Mux((a[-2] == a[-1]) & (b[-2] == b[-1]) & (a[-1] != b[-1]), a[-1], a < b) class OutputNetwork(Elaboratable): - def __init__(self, lane_count, seqn_width, layout_payload, *, fv_mode=False): + def __init__(self, lane_count, seqn_width, layout_payload): self.lane_count = lane_count self.seqn_width = seqn_width self.layout_payload = layout_payload - self.fv_mode = fv_mode self.steps = boms_steps_pairs(lane_count) self.network = [[Record(layouts.output_network_node(seqn_width, layout_payload)) for _ in range(lane_count)] for _ in range(len(self.steps) + 1)] @@ -102,99 +101,4 @@ class OutputNetwork(Elaboratable): for node in unchanged: m.d.sync += self.network[i + 1][node].eq(self.network[i][node]) - if self.fv_mode: - # Model arbitrary inputs for network nodes - for i in range(self.lane_count): - m.d.comb += self.input[i].valid.eq(1) - m.d.comb += self.input[i].seqn.eq(AnySeq(self.seqn_width)) - m.d.comb += self.input[i].replace_occured.eq(0) - m.d.comb += self.input[i].nondata_replace_occured.eq(0) - for field, width in self.layout_payload: - m.d.comb += getattr(self.input[i].payload, field).eq(AnySeq(width)) - - # Indicator of when inputs from the first clock cycle make it - # through the sorting network - network_latency = latency(self.lane_count) - counter = Signal(range(network_latency + 1)) - m.d.sync += counter.eq(counter + 1) - with m.If(counter == network_latency): - m.d.sync += counter.eq(counter) - f_output_valid = Signal() - m.d.comb += f_output_valid.eq(counter == network_latency) - - with m.If(f_output_valid): - replacement_occurred = Signal() - for node in self.output: - with m.If(node.replace_occured): - m.d.comb += replacement_occurred.eq(1) - channels_unique = Signal(reset=1) - for node1 in range(len(self.input)): - for node2 in range(node1): - k1 = Past(self.input[node1].payload.channel, clocks=network_latency) - k2 = Past(self.input[node2].payload.channel, clocks=network_latency) - with m.If(k1 == k2): - m.d.comb += channels_unique.eq(0) - # If there are no replacements then: - # - Input channel numbers are unique - # - All outputs are valid - # - All inputs make it through the sorting network - with m.If(~replacement_occurred): - m.d.comb += Assert(channels_unique) - for node in self.output: - m.d.comb += Assert(node.valid) - for input_node in self.input: - appeared = Signal() - for output_node in self.output: - match = Signal(reset=1) - with m.If(Past(input_node.valid, clocks=network_latency) != output_node.valid): - m.d.comb += match.eq(0) - with m.If(Past(input_node.seqn, clocks=network_latency) != output_node.seqn): - m.d.comb += match.eq(0) - with m.If(Past(input_node.replace_occured, clocks=network_latency) != output_node.replace_occured): - m.d.comb += match.eq(0) - with m.If(Past(input_node.nondata_replace_occured, clocks=network_latency) != output_node.nondata_replace_occured): - m.d.comb += match.eq(0) - for field, _ in self.layout_payload: - with m.If(Past(getattr(input_node.payload, field), clocks=network_latency) != getattr(output_node.payload, field)): - m.d.comb += match.eq(0) - with m.If(match): - m.d.comb += appeared.eq(1) - m.d.comb += Assert(appeared) - # Otherwise, if there are replacements: - # - Channel numbers are not unique - # - Not all outputs are valid - # - All channel numbers in the input appear exactly once as a - # valid output - # - All valid outputs match an input modulo accounting - # information - with m.Else(): - m.d.comb += Assert(~channels_unique) - all_valid = Signal(reset=1) - for node in self.output: - with m.If(~node.valid): - m.d.comb += all_valid.eq(0) - m.d.comb += Assert(~all_valid) - for input_node in self.input: - input_channel_valid_once = Const(0) - for node1 in range(len(self.output)): - accum = (Past(input_node.payload.channel, clocks=network_latency) == self.output[node1].payload.channel) & self.output[node1].valid - for node2 in range(len(self.output)): - if node1 != node2: - accum = accum & ((Past(input_node.payload.channel, clocks=network_latency) != self.output[node2].payload.channel) | ~self.output[node2].valid) - input_channel_valid_once = input_channel_valid_once | accum - m.d.comb += Assert(input_channel_valid_once) - for output_node in self.output: - with m.If(output_node.valid): - found_input = Signal() - for input_node in self.input: - match = Signal(reset=1) - with m.If(Past(input_node.seqn, clocks=network_latency) != output_node.seqn): - m.d.comb += match.eq(0) - for field, _ in self.layout_payload: - with m.If(Past(getattr(input_node.payload, field), clocks=network_latency) != getattr(output_node.payload, field)): - m.d.comb += match.eq(0) - with m.If(match): - m.d.comb += found_input.eq(1) - m.d.comb += Assert(found_input) - return m diff --git a/rtio/test/sed/output_network.py b/rtio/test/sed/output_network.py index 9e4cd6c..b878d66 100644 --- a/rtio/test/sed/output_network.py +++ b/rtio/test/sed/output_network.py @@ -11,7 +11,7 @@ class OutputNetworkTestCase(FHDLTestCase): def verify(self): # Bounded model check self.assertFormal( - OutputNetwork(4, 2, [("data", 32), ("channel", 3)], fv_mode=True), + OutputNetwork(4, 2, [("data", 32), ("channel", 3)]), mode="bmc", depth=40) # TODO: unbounded proof