forked from M-Labs/nac3
standalone: add nalgebra::linalg methods
This commit is contained in:
parent
2dddab1fcf
commit
cc2c53f328
@ -141,6 +141,26 @@ def patch(module):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def try_invert_to(x):
|
||||
try:
|
||||
y = np.linalg.inv(x)
|
||||
x[:] = y
|
||||
except np.linalg.LinAlgError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def wilkinson_shift(x):
|
||||
assert (len(x.flatten()) == 4) and (x[0, 1] == x[1, 0]), f"Operation Wilkinson Shift expects symmetric matrix"
|
||||
tmm, tnn, tmn = x[0, 0], x[1, 1], x[0, 1]
|
||||
sq_tmn = tmn * tmn
|
||||
if sq_tmn != 0:
|
||||
d = (tmm - tnn) * 0.5
|
||||
if d > 0:
|
||||
return tnn - sq_tmn / (d + np.sqrt(d*d + sq_tmn))
|
||||
else:
|
||||
return tnn - sq_tmn / (d - np.sqrt(d*d + sq_tmn))
|
||||
return tnn
|
||||
|
||||
module.int32 = int32
|
||||
module.int64 = int64
|
||||
module.uint32 = uint32
|
||||
@ -226,6 +246,9 @@ def patch(module):
|
||||
module.sp_spec_j0 = special.j0
|
||||
module.sp_spec_j1 = special.j1
|
||||
|
||||
module.try_invert_to = try_invert_to
|
||||
module.wilkinson_shift = wilkinson_shift
|
||||
|
||||
def file_import(filename, prefix="file_import_"):
|
||||
filename = pathlib.Path(filename)
|
||||
modname = prefix + filename.stem
|
||||
|
@ -1429,6 +1429,20 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
||||
output_ndarray_float_2(nextafter_x_zeros)
|
||||
output_ndarray_float_2(nextafter_x_ones)
|
||||
|
||||
def test_try_invert():
|
||||
x: ndarray[float, 2] = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||
output_ndarray_float_2(x)
|
||||
y = try_invert_to(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_bool(y)
|
||||
|
||||
def test_wilkinson_shift():
|
||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||
y = wilkinson_shift(x)
|
||||
output_ndarray_float_2(x)
|
||||
output_float64(y)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
@ -1608,4 +1622,7 @@ def run() -> int32:
|
||||
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||
|
||||
test_try_invert()
|
||||
test_wilkinson_shift()
|
||||
|
||||
return 0
|
||||
|
Loading…
Reference in New Issue
Block a user