standalone: add nalgebra::linalg methods

This commit is contained in:
abdul124 2024-07-22 13:22:27 +08:00 committed by =
parent 2dddab1fcf
commit cc2c53f328
2 changed files with 40 additions and 0 deletions

View File

@ -141,6 +141,26 @@ def patch(module):
else: else:
raise NotImplementedError raise NotImplementedError
def try_invert_to(x):
try:
y = np.linalg.inv(x)
x[:] = y
except np.linalg.LinAlgError:
return False
return True
def wilkinson_shift(x):
assert (len(x.flatten()) == 4) and (x[0, 1] == x[1, 0]), f"Operation Wilkinson Shift expects symmetric matrix"
tmm, tnn, tmn = x[0, 0], x[1, 1], x[0, 1]
sq_tmn = tmn * tmn
if sq_tmn != 0:
d = (tmm - tnn) * 0.5
if d > 0:
return tnn - sq_tmn / (d + np.sqrt(d*d + sq_tmn))
else:
return tnn - sq_tmn / (d - np.sqrt(d*d + sq_tmn))
return tnn
module.int32 = int32 module.int32 = int32
module.int64 = int64 module.int64 = int64
module.uint32 = uint32 module.uint32 = uint32
@ -226,6 +246,9 @@ def patch(module):
module.sp_spec_j0 = special.j0 module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1 module.sp_spec_j1 = special.j1
module.try_invert_to = try_invert_to
module.wilkinson_shift = wilkinson_shift
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)
modname = prefix + filename.stem modname = prefix + filename.stem

View File

@ -1429,6 +1429,20 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_try_invert():
x: ndarray[float, 2] = np_array([[1.0, 2.0], [3.0, 4.0]])
output_ndarray_float_2(x)
y = try_invert_to(x)
output_ndarray_float_2(x)
output_bool(y)
def test_wilkinson_shift():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = wilkinson_shift(x)
output_ndarray_float_2(x)
output_float64(y)
def run() -> int32: def run() -> int32:
test_ndarray_ctor() test_ndarray_ctor()
test_ndarray_empty() test_ndarray_empty()
@ -1608,4 +1622,7 @@ def run() -> int32:
test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
test_try_invert()
test_wilkinson_shift()
return 0 return 0