From cc2c53f328dd783afb95121293f1d9a3e657887c Mon Sep 17 00:00:00 2001 From: abdul124 Date: Mon, 22 Jul 2024 13:22:27 +0800 Subject: [PATCH] standalone: add nalgebra::linalg methods --- nac3standalone/demo/interpret_demo.py | 23 +++++++++++++++++++++++ nac3standalone/demo/src/ndarray.py | 17 +++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index b948edee..57d60717 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -141,6 +141,26 @@ def patch(module): else: raise NotImplementedError + def try_invert_to(x): + try: + y = np.linalg.inv(x) + x[:] = y + except np.linalg.LinAlgError: + return False + return True + + def wilkinson_shift(x): + assert (len(x.flatten()) == 4) and (x[0, 1] == x[1, 0]), f"Operation Wilkinson Shift expects symmetric matrix" + tmm, tnn, tmn = x[0, 0], x[1, 1], x[0, 1] + sq_tmn = tmn * tmn + if sq_tmn != 0: + d = (tmm - tnn) * 0.5 + if d > 0: + return tnn - sq_tmn / (d + np.sqrt(d*d + sq_tmn)) + else: + return tnn - sq_tmn / (d - np.sqrt(d*d + sq_tmn)) + return tnn + module.int32 = int32 module.int64 = int64 module.uint32 = uint32 @@ -226,6 +246,9 @@ def patch(module): module.sp_spec_j0 = special.j0 module.sp_spec_j1 = special.j1 + module.try_invert_to = try_invert_to + module.wilkinson_shift = wilkinson_shift + def file_import(filename, prefix="file_import_"): filename = pathlib.Path(filename) modname = prefix + filename.stem diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 7501cb5d..5c469650 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1429,6 +1429,20 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) +def test_try_invert(): + x: ndarray[float, 2] = np_array([[1.0, 2.0], [3.0, 4.0]]) + output_ndarray_float_2(x) + y = try_invert_to(x) + + output_ndarray_float_2(x) + output_bool(y) + +def test_wilkinson_shift(): + x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) + y = wilkinson_shift(x) + output_ndarray_float_2(x) + output_float64(y) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -1608,4 +1622,7 @@ def run() -> int32: test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() + test_try_invert() + test_wilkinson_shift() + return 0