diff --git a/rtio/test/sed/output_network.py b/rtio/test/sed/output_network.py index e7e5ae6..9e61aa2 100644 --- a/rtio/test/sed/output_network.py +++ b/rtio/test/sed/output_network.py @@ -74,22 +74,19 @@ class OutputNetworkSpec(Elaboratable): m.d.comb += Assert(appeared) # Otherwise: # - There is a channel number collision among the valid inputs - # - All channel numbers in valid inputs appear exactly once as a + # - All channel numbers in valid inputs appear at least once as a # valid output # - All valid outputs correspond to a valid input modulo accounting # information with m.Else(): m.d.comb += Assert(~valid_channels_unique) for input_node in output_network.input: - input_channel_valid_once = Const(0) - for node1 in range(len(output_network.output)): - accum = (Past(input_node.payload.channel, clocks=network_latency) == output_network.output[node1].payload.channel) & output_network.output[node1].valid - for node2 in range(len(output_network.output)): - if node1 != node2: - accum = accum & ((Past(input_node.payload.channel, clocks=network_latency) != output_network.output[node2].payload.channel) | ~output_network.output[node2].valid) - input_channel_valid_once = input_channel_valid_once | accum with m.If(Past(input_node.valid, clocks=network_latency)): - m.d.comb += Assert(input_channel_valid_once) + appeared = Signal() + for output_node in output_network.output: + with m.If(output_node.valid & (output_node.payload.channel == Past(input_node.payload.channel, clocks=network_latency))): + m.d.comb += appeared.eq(1) + m.d.comb += Assert(appeared) for output_node in output_network.output: with m.If(output_node.valid): found_input = Signal() @@ -110,19 +107,14 @@ class OutputNetworkSpec(Elaboratable): class OutputNetworkTestCase(FHDLTestCase): def verify(self): - # Bounded proofs - # 8 lanes (failing) - # self.assertFormal( - # OutputNetworkSpec(8, 2, [("data", 32), ("channel", 3)]), - # mode="bmc", depth=40) - # Unbounded proofs - # 2 lanes self.assertFormal( OutputNetworkSpec(2, 2, [("data", 32), ("channel", 3)]), mode="prove", depth=latency(2)) - # 4 lanes self.assertFormal( OutputNetworkSpec(4, 2, [("data", 32), ("channel", 3)]), mode="prove", depth=latency(4)) + self.assertFormal( + OutputNetworkSpec(8, 2, [("data", 32), ("channel", 3)]), + mode="prove", depth=latency(8)) OutputNetworkTestCase().verify()