Matrix::perp() fix UB

This commit is contained in:
Max Verevkin 2021-09-19 17:00:49 +03:00
parent 654eca7f80
commit b91eecebcd
1 changed files with 21 additions and 17 deletions

View File

@ -2036,16 +2036,26 @@ impl<T: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: RawStorag
+ SameNumberOfRows<R2, U2> + SameNumberOfRows<R2, U2>
+ SameNumberOfColumns<C2, U1>, + SameNumberOfColumns<C2, U1>,
{ {
assert!( let shape = self.shape();
self.shape() == (2, 1), assert_eq!(
"2D perpendicular product requires (2, 1) vector but found {:?}", shape,
self.shape() b.shape(),
"2D vector perpendicular product dimension mismatch."
);
assert_eq!(
shape,
(2, 1),
"2D perpendicular product requires (2, 1) vectors {:?}",
shape
); );
unsafe { // SAFETY: assertion above ensures correct shape
self.get_unchecked((0, 0)).clone() * b.get_unchecked((1, 0)).clone() let ax = unsafe { self.get_unchecked((0, 0)).clone() };
- self.get_unchecked((1, 0)).clone() * b.get_unchecked((0, 0)).clone() let ay = unsafe { self.get_unchecked((1, 0)).clone() };
} let bx = unsafe { b.get_unchecked((0, 0)).clone() };
let by = unsafe { b.get_unchecked((1, 0)).clone() };
ax * by - ay * bx
} }
// TODO: use specialization instead of an assertion. // TODO: use specialization instead of an assertion.
@ -2066,17 +2076,14 @@ impl<T: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: RawStorag
let shape = self.shape(); let shape = self.shape();
assert_eq!(shape, b.shape(), "Vector cross product dimension mismatch."); assert_eq!(shape, b.shape(), "Vector cross product dimension mismatch.");
assert!( assert!(
(shape.0 == 3 && shape.1 == 1) || (shape.0 == 1 && shape.1 == 3), shape == (3, 1) || shape == (1, 3),
"Vector cross product dimension mismatch: must be (3, 1) or (1, 3) but found {:?}.", "Vector cross product dimension mismatch: must be (3, 1) or (1, 3) but found {:?}.",
shape shape
); );
if shape.0 == 3 { if shape.0 == 3 {
unsafe { unsafe {
// TODO: soooo ugly! let mut res = Matrix::uninit(Dim::from_usize(3), Dim::from_usize(1));
let nrows = SameShapeR::<R, R2>::from_usize(3);
let ncols = SameShapeC::<C, C2>::from_usize(1);
let mut res = Matrix::uninit(nrows, ncols);
let ax = self.get_unchecked((0, 0)); let ax = self.get_unchecked((0, 0));
let ay = self.get_unchecked((1, 0)); let ay = self.get_unchecked((1, 0));
@ -2098,10 +2105,7 @@ impl<T: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: RawStorag
} }
} else { } else {
unsafe { unsafe {
// TODO: ugly! let mut res = Matrix::uninit(Dim::from_usize(1), Dim::from_usize(3));
let nrows = SameShapeR::<R, R2>::from_usize(1);
let ncols = SameShapeC::<C, C2>::from_usize(3);
let mut res = Matrix::uninit(nrows, ncols);
let ax = self.get_unchecked((0, 0)); let ax = self.get_unchecked((0, 0));
let ay = self.get_unchecked((0, 1)); let ay = self.get_unchecked((0, 1));