Merge pull request #991 from MaxVerevkin/fix-ub

Fix UB in `Matrix::perp()`
This commit is contained in:
Sébastien Crozet 2021-09-26 11:05:44 +02:00 committed by GitHub
commit 7f236d88aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 17 deletions

View File

@ -2033,16 +2033,26 @@ impl<T: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: RawStorag
+ SameNumberOfRows<R2, U2> + SameNumberOfRows<R2, U2>
+ SameNumberOfColumns<C2, U1>, + SameNumberOfColumns<C2, U1>,
{ {
assert!( let shape = self.shape();
self.shape() == (2, 1), assert_eq!(
"2D perpendicular product requires (2, 1) vector but found {:?}", shape,
self.shape() b.shape(),
"2D vector perpendicular product dimension mismatch."
);
assert_eq!(
shape,
(2, 1),
"2D perpendicular product requires (2, 1) vectors {:?}",
shape
); );
unsafe { // SAFETY: assertion above ensures correct shape
self.get_unchecked((0, 0)).clone() * b.get_unchecked((1, 0)).clone() let ax = unsafe { self.get_unchecked((0, 0)).clone() };
- self.get_unchecked((1, 0)).clone() * b.get_unchecked((0, 0)).clone() let ay = unsafe { self.get_unchecked((1, 0)).clone() };
} let bx = unsafe { b.get_unchecked((0, 0)).clone() };
let by = unsafe { b.get_unchecked((1, 0)).clone() };
ax * by - ay * bx
} }
// TODO: use specialization instead of an assertion. // TODO: use specialization instead of an assertion.
@ -2063,17 +2073,14 @@ impl<T: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: RawStorag
let shape = self.shape(); let shape = self.shape();
assert_eq!(shape, b.shape(), "Vector cross product dimension mismatch."); assert_eq!(shape, b.shape(), "Vector cross product dimension mismatch.");
assert!( assert!(
(shape.0 == 3 && shape.1 == 1) || (shape.0 == 1 && shape.1 == 3), shape == (3, 1) || shape == (1, 3),
"Vector cross product dimension mismatch: must be (3, 1) or (1, 3) but found {:?}.", "Vector cross product dimension mismatch: must be (3, 1) or (1, 3) but found {:?}.",
shape shape
); );
if shape.0 == 3 { if shape.0 == 3 {
unsafe { unsafe {
// TODO: soooo ugly! let mut res = Matrix::uninit(Dim::from_usize(3), Dim::from_usize(1));
let nrows = SameShapeR::<R, R2>::from_usize(3);
let ncols = SameShapeC::<C, C2>::from_usize(1);
let mut res = Matrix::uninit(nrows, ncols);
let ax = self.get_unchecked((0, 0)); let ax = self.get_unchecked((0, 0));
let ay = self.get_unchecked((1, 0)); let ay = self.get_unchecked((1, 0));
@ -2095,10 +2102,7 @@ impl<T: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: RawStorag
} }
} else { } else {
unsafe { unsafe {
// TODO: ugly! let mut res = Matrix::uninit(Dim::from_usize(1), Dim::from_usize(3));
let nrows = SameShapeR::<R, R2>::from_usize(1);
let ncols = SameShapeC::<C, C2>::from_usize(3);
let mut res = Matrix::uninit(nrows, ncols);
let ax = self.get_unchecked((0, 0)); let ax = self.get_unchecked((0, 0));
let ay = self.get_unchecked((0, 1)); let ay = self.get_unchecked((0, 1));