From 583a8fb1102f075e4457208f3c1cf68b21707384 Mon Sep 17 00:00:00 2001 From: Fredrik Jansson Date: Sun, 12 Apr 2020 13:27:11 +0200 Subject: [PATCH] Removed unsafe immutable hack --- src/linalg/exp.rs | 146 ++++++++++++++++++++++++++-------------------- 1 file changed, 83 insertions(+), 63 deletions(-) diff --git a/src/linalg/exp.rs b/src/linalg/exp.rs index e07ddd2f..86bd81a7 100644 --- a/src/linalg/exp.rs +++ b/src/linalg/exp.rs @@ -66,87 +66,78 @@ where } } - fn a2(&self) -> &MatrixN { + fn calc_a2(&mut self) { if self.a2.is_none() { - let ap = &self.a2 as *const Option> as *mut Option>; - unsafe { - *ap = Some(&self.a * &self.a); - } + self.a2 = Some(&self.a * &self.a); } - self.a2.as_ref().unwrap() } - fn a4(&self) -> &MatrixN { + fn calc_a4(&mut self) { if self.a4.is_none() { - let ap = &self.a4 as *const Option> as *mut Option>; - let a2 = self.a2(); - unsafe { - *ap = Some(a2 * a2); - } + self.calc_a2(); + let a2 = self.a2.as_ref().unwrap(); + self.a4 = Some(a2 * a2); } - self.a4.as_ref().unwrap() } - fn a6(&self) -> &MatrixN { + fn calc_a6(&mut self) { if self.a6.is_none() { - let a2 = self.a2(); - let a4 = self.a4(); - let ap = &self.a6 as *const Option> as *mut Option>; - unsafe { - *ap = Some(a4 * a2); - } + self.calc_a2(); + self.calc_a4(); + let a2 = self.a2.as_ref().unwrap(); + let a4 = self.a4.as_ref().unwrap(); + self.a6 = Some(a4 * a2); } - self.a6.as_ref().unwrap() } - fn a8(&self) -> &MatrixN { + fn calc_a8(&mut self) { if self.a8.is_none() { - let a2 = self.a2(); - let a6 = self.a6(); - let ap = &self.a8 as *const Option> as *mut Option>; - unsafe { - *ap = Some(a6 * a2); - } + self.calc_a2(); + self.calc_a6(); + let a2 = self.a2.as_ref().unwrap(); + let a6 = self.a6.as_ref().unwrap(); + self.a8 = Some(a6 * a2); } - self.a8.as_ref().unwrap() } - fn a10(&mut self) -> &MatrixN { + fn calc_a10(&mut self) { if self.a10.is_none() { - let a4 = self.a4(); - let a6 = self.a6(); - let ap = &self.a10 as *const Option> as *mut Option>; - unsafe { - *ap = Some(a6 * a4); - } + self.calc_a4(); + self.calc_a6(); + let a4 = self.a4.as_ref().unwrap(); + let a6 = self.a6.as_ref().unwrap(); + self.a10 = Some(a6 * a4); } - self.a10.as_ref().unwrap() } fn d4_tight(&mut self) -> N { if self.d4_exact.is_none() { - self.d4_exact = Some(one_norm(self.a4()).powf(convert(0.25))); + self.calc_a4(); + self.d4_exact = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25))); } self.d4_exact.unwrap() } fn d6_tight(&mut self) -> N { if self.d6_exact.is_none() { - self.d6_exact = Some(one_norm(self.a6()).powf(convert(1.0 / 6.0))); + self.calc_a6(); + self.d6_exact = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0))); } self.d6_exact.unwrap() } fn d8_tight(&mut self) -> N { if self.d8_exact.is_none() { - self.d8_exact = Some(one_norm(self.a8()).powf(convert(1.0 / 8.0))); + self.calc_a8(); + self.d8_exact = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0))); } self.d8_exact.unwrap() } fn d10_tight(&mut self) -> N { if self.d10_exact.is_none() { - self.d10_exact = Some(one_norm(self.a10()).powf(convert(1.0 / 10.0))); + self.calc_a10(); + self.d10_exact = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0))); } self.d10_exact.unwrap() } @@ -161,7 +152,8 @@ where } if self.d4_approx.is_none() { - self.d4_approx = Some(one_norm(self.a4()).powf(convert(0.25))); + self.calc_a4(); + self.d4_approx = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25))); } self.d4_approx.unwrap() @@ -177,7 +169,8 @@ where } if self.d6_approx.is_none() { - self.d6_approx = Some(one_norm(self.a6()).powf(convert(1.0 / 6.0))); + self.calc_a6(); + self.d6_approx = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0))); } self.d6_approx.unwrap() @@ -193,7 +186,8 @@ where } if self.d8_approx.is_none() { - self.d8_approx = Some(one_norm(self.a8()).powf(convert(1.0 / 8.0))); + self.calc_a8(); + self.d8_approx = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0))); } self.d8_approx.unwrap() @@ -209,7 +203,8 @@ where } if self.d10_approx.is_none() { - self.d10_approx = Some(one_norm(self.a10()).powf(convert(1.0 / 10.0))); + self.calc_a10(); + self.d10_approx = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0))); } self.d10_approx.unwrap() @@ -217,8 +212,10 @@ where fn pade3(&mut self) -> (MatrixN, MatrixN) { let b: [N; 4] = [convert(120.0), convert(60.0), convert(12.0), convert(1.0)]; - let u = &self.a * (self.a2() * b[3] + &self.ident * b[1]); - let v = self.a2() * b[2] + &self.ident * b[0]; + self.calc_a2(); + let a2 = self.a2.as_ref().unwrap(); + let u = &self.a * (a2 * b[3] + &self.ident * b[1]); + let v = a2 * b[2] + &self.ident * b[0]; (u, v) } @@ -231,8 +228,15 @@ where convert(30.0), convert(1.0), ]; - let u = &self.a * (self.a4() * b[5] + self.a2() * b[3] + &self.ident * b[1]); - let v = self.a4() * b[4] + self.a2() * b[2] + &self.ident * b[0]; + self.calc_a2(); + self.calc_a6(); + let u = &self.a + * (self.a4.as_ref().unwrap() * b[5] + + self.a2.as_ref().unwrap() * b[3] + + &self.ident * b[1]); + let v = self.a4.as_ref().unwrap() * b[4] + + self.a2.as_ref().unwrap() * b[2] + + &self.ident * b[0]; (u, v) } @@ -247,9 +251,18 @@ where convert(56.0), convert(1.0), ]; - let u = - &self.a * (self.a6() * b[7] + self.a4() * b[5] + self.a2() * b[3] + &self.ident * b[1]); - let v = self.a6() * b[6] + self.a4() * b[4] + self.a2() * b[2] + &self.ident * b[0]; + self.calc_a2(); + self.calc_a4(); + self.calc_a6(); + let u = &self.a + * (self.a6.as_ref().unwrap() * b[7] + + self.a4.as_ref().unwrap() * b[5] + + self.a2.as_ref().unwrap() * b[3] + + &self.ident * b[1]); + let v = self.a6.as_ref().unwrap() * b[6] + + self.a4.as_ref().unwrap() * b[4] + + self.a2.as_ref().unwrap() * b[2] + + &self.ident * b[0]; (u, v) } @@ -266,16 +279,20 @@ where convert(90.0), convert(1.0), ]; + self.calc_a2(); + self.calc_a4(); + self.calc_a6(); + self.calc_a8(); let u = &self.a - * (self.a8() * b[9] - + self.a6() * b[7] - + self.a4() * b[5] - + self.a2() * b[3] + * (self.a8.as_ref().unwrap() * b[9] + + self.a6.as_ref().unwrap() * b[7] + + self.a4.as_ref().unwrap() * b[5] + + self.a2.as_ref().unwrap() * b[3] + &self.ident * b[1]); - let v = self.a8() * b[8] - + self.a6() * b[6] - + self.a4() * b[4] - + self.a2() * b[2] + let v = self.a8.as_ref().unwrap() * b[8] + + self.a6.as_ref().unwrap() * b[6] + + self.a4.as_ref().unwrap() * b[4] + + self.a2.as_ref().unwrap() * b[2] + &self.ident * b[0]; (u, v) } @@ -300,9 +317,12 @@ where let s = s as f64; let mb = &self.a * convert::(2.0_f64.powf(-s)); - let mb2 = self.a2() * convert::(2.0_f64.powf(-2.0 * s)); - let mb4 = self.a4() * convert::(2.0.powf(-4.0 * s)); - let mb6 = self.a6() * convert::(2.0.powf(-6.0 * s)); + self.calc_a2(); + self.calc_a4(); + self.calc_a6(); + let mb2 = self.a2.as_ref().unwrap() * convert::(2.0_f64.powf(-2.0 * s)); + let mb4 = self.a4.as_ref().unwrap() * convert::(2.0.powf(-4.0 * s)); + let mb6 = self.a6.as_ref().unwrap() * convert::(2.0.powf(-6.0 * s)); let u2 = &mb6 * (&mb6 * b[13] + &mb4 * b[11] + &mb2 * b[9]); let u = &mb * (&u2 + &mb6 * b[7] + &mb4 * b[5] + &mb2 * b[3] + &self.ident * b[1]);