diff --git a/src/base/matrix.rs b/src/base/matrix.rs index 78ec2dd1..983f9a86 100644 --- a/src/base/matrix.rs +++ b/src/base/matrix.rs @@ -2033,16 +2033,26 @@ impl + SameNumberOfColumns, { - assert!( - self.shape() == (2, 1), - "2D perpendicular product requires (2, 1) vector but found {:?}", - self.shape() + let shape = self.shape(); + assert_eq!( + shape, + b.shape(), + "2D vector perpendicular product dimension mismatch." + ); + assert_eq!( + shape, + (2, 1), + "2D perpendicular product requires (2, 1) vectors {:?}", + shape ); - unsafe { - self.get_unchecked((0, 0)).clone() * b.get_unchecked((1, 0)).clone() - - self.get_unchecked((1, 0)).clone() * b.get_unchecked((0, 0)).clone() - } + // SAFETY: assertion above ensures correct shape + let ax = unsafe { self.get_unchecked((0, 0)).clone() }; + let ay = unsafe { self.get_unchecked((1, 0)).clone() }; + let bx = unsafe { b.get_unchecked((0, 0)).clone() }; + let by = unsafe { b.get_unchecked((1, 0)).clone() }; + + ax * by - ay * bx } // TODO: use specialization instead of an assertion. @@ -2063,17 +2073,14 @@ impl::from_usize(3); - let ncols = SameShapeC::::from_usize(1); - let mut res = Matrix::uninit(nrows, ncols); + let mut res = Matrix::uninit(Dim::from_usize(3), Dim::from_usize(1)); let ax = self.get_unchecked((0, 0)); let ay = self.get_unchecked((1, 0)); @@ -2095,10 +2102,7 @@ impl::from_usize(1); - let ncols = SameShapeC::::from_usize(3); - let mut res = Matrix::uninit(nrows, ncols); + let mut res = Matrix::uninit(Dim::from_usize(1), Dim::from_usize(3)); let ax = self.get_unchecked((0, 0)); let ay = self.get_unchecked((0, 1));