This commit is contained in:
abdul124 2024-07-19 18:18:08 +08:00
parent 6c10e3d056
commit 6ad597e592
17 changed files with 480 additions and 11 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
__pycache__ __pycache__
/target /target
nix/windows/msys2 nix/windows/msys2
nac3standalone/demo/externfns/target

99
Cargo.lock generated
View File

@ -73,6 +73,15 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "ascii-canvas" name = "ascii-canvas"
version = "3.0.0" version = "3.0.0"
@ -305,6 +314,13 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "externfns"
version = "0.1.0"
dependencies = [
"nalgebra",
]
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.1.0" version = "2.1.0"
@ -521,6 +537,12 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.3" version = "0.1.3"
@ -658,18 +680,71 @@ name = "nac3standalone"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"externfns",
"inkwell", "inkwell",
"nac3core", "nac3core",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",
] ]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]] [[package]]
name = "new_debug_unreachable" name = "new_debug_unreachable"
version = "1.0.6" version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.19.0" version = "1.19.0"
@ -699,6 +774,12 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]] [[package]]
name = "petgraph" name = "petgraph"
version = "0.6.5" version = "0.6.5"
@ -1070,6 +1151,18 @@ dependencies = [
"yaml-rust", "yaml-rust",
] ]
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]] [[package]]
name = "similar" name = "similar"
version = "2.5.0" version = "2.5.0"
@ -1230,6 +1323,12 @@ dependencies = [
"crunchy", "crunchy",
] ]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]] [[package]]
name = "unic-char-property" name = "unic-char-property"
version = "0.9.0" version = "0.9.0"

View File

@ -5,6 +5,7 @@ members = [
"nac3parser", "nac3parser",
"nac3core", "nac3core",
"nac3standalone", "nac3standalone",
"nac3standalone/demo/externfns",
"nac3artiq", "nac3artiq",
"runkernel", "runkernel",
] ]

View File

@ -161,7 +161,9 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -14,12 +14,23 @@ class Demo:
@kernel @kernel
def run(self): def run(self):
self.core.reset() a = np_array([[1., 2.], [3., 4.]])
while True: b = try_invert_to(a)
with parallel: if b:
self.led0.pulse(100.*ms) # self.core.reset()
self.led1.pulse(100.*ms) # while True:
self.core.delay(100.*ms) # with parallel:
# self.led0.pulse(100.*ms)
# self.led1.pulse(100.*ms)
# self.core.delay(100.*ms)
v = try_invert_to(np_identity(2))
if v:
while True:
with parallel:
self.led0.pulse(100.*ms)
self.led1.pulse(100.*ms)
self.core.delay(100.*ms)
if __name__ == "__main__": if __name__ == "__main__":

BIN
nac3artiq/demo/module.elf Normal file

Binary file not shown.

View File

@ -3,7 +3,7 @@ use inkwell::values::BasicValueEnum;
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
@ -922,6 +922,122 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
}) })
} }
pub fn call_try_invert_to<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_try_invert_to";
let (a_ty, a) = a;
let llvm_usize = generator.get_size_type(ctx.ctx);
match a {
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
match llvm_ndarray_ty {
BasicTypeEnum::FloatType(_) => {},
_ => unreachable!("Inverse Operation supported on float type NDArray Values only")
};
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
// Add asserts for dims
if cfg!(debug_assertions) {
let n_dims = n.load_ndims(ctx);
// num_dim == 2
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
"0:ValueError", format!("Inverse only supported on 2D lists").as_str(), [None, None, None], ctx.current_loc);
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
// dim0 == dim1
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
}
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
.builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
format!("zero-size array to reduction operation {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
}
// Create a call to linalg_try_invert_to
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
Ok(extern_fns::call_linalg_try_invert_to(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
}
}
pub fn call_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_wilkinson_shift";
let (a_ty, a) = a;
let llvm_usize = generator.get_size_type(ctx.ctx);
match a {
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
match llvm_ndarray_ty {
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {},
_ => unreachable!("Wilkinson Shift Operation supported on float/integer type NDArray Values only")
};
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
// Add asserts for dims
if cfg!(debug_assertions) {
let n_dims = n.load_ndims(ctx);
// num_dim == 2
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
"0:ValueError", format!("Wilkinson Shift supported only on 2D lists").as_str(), [None, None, None], ctx.current_loc);
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
// dim0 == dim1
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
// dimesions should be 2x2
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim0.get_type().const_int(2, false), "").unwrap(),
"0:ValueError", format!("Wilkinson Shift supported only on 2x2 matrices").as_str(), [None, None, None], ctx.current_loc);
}
// Create a call to linalg_try_invert_to
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
Ok(extern_fns::call_linalg_wilkinson_shift(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
}
}
/// Invokes the `np_maximum` builtin function. /// Invokes the `np_maximum` builtin function.
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,

View File

@ -1,5 +1,5 @@
use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}; use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use itertools::Either; use itertools::Either;
use crate::codegen::CodeGenContext; use crate::codegen::CodeGenContext;
@ -130,3 +130,82 @@ pub fn call_ldexp<'ctx>(
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
} }
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
pub fn call_linalg_try_invert_to<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dim0: IntValue<'ctx>,
dim1: IntValue<'ctx>,
data: PointerValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
const FN_NAME: &str = "linalg_try_invert_to";
let llvm_f64 = ctx.ctx.f64_type();
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
debug_assert!(allowed_dim0);
debug_assert!(allowed_dim1);
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.i8_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap().into()
}
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
pub fn call_linalg_wilkinson_shift<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dim0: IntValue<'ctx>,
dim1: IntValue<'ctx>,
data: PointerValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "linalg_wilkinson_shift";
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
debug_assert!(allowed_dim0);
debug_assert!(allowed_dim1);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.f64_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap().into()
}

View File

@ -556,6 +556,9 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpLdExp | PrimDef::FunNpLdExp
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunTryInvertTo => self.build_linalg_try_invert_to(prim), // Inplace invert
PrimDef::FunWilkinsonShift => self.build_linalg_wilkinson_shift(prim),
}; };
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -1874,6 +1877,64 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
fn build_linalg_try_invert_to(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[
PrimDef::FunTryInvertTo,
],
);
let var_map = self.num_or_ndarray_var_map.clone();
create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
self.primitives.bool,
&[(self.ndarray_float_2d, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
let func = match prim {
PrimDef::FunTryInvertTo => builtin_fns::call_try_invert_to,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
}),
)
}
fn build_linalg_wilkinson_shift(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[
PrimDef::FunWilkinsonShift,
],
);
let var_map = self.num_or_ndarray_var_map.clone();
create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
self.primitives.float,
&[(self.ndarray_float_2d, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
let func = match prim {
PrimDef::FunWilkinsonShift => builtin_fns::call_wilkinson_shift,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
}),
)
}
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
(prim.simple_name().into(), method_ty, prim.id()) (prim.simple_name().into(), method_ty, prim.id())
} }

View File

@ -105,6 +105,8 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunTryInvertTo,
FunWilkinsonShift,
// Top-Level Functions // Top-Level Functions
FunSome, FunSome,
@ -261,6 +263,8 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunTryInvertTo => fun("try_invert_to", None),
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None),
PrimDef::FunSome => fun("Some", None), PrimDef::FunSome => fun("Some", None),
} }
} }

View File

@ -8,6 +8,7 @@ edition = "2021"
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
externfns = { path = "./demo/externfns" }
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"

View File

@ -0,0 +1,10 @@
[package]
name = "externfns"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib"]
[dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}

View File

@ -0,0 +1,41 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(clippy::semicolon_if_nothing_returned, clippy::uninlined_format_args)]
use core::slice;
use nalgebra::{DMatrix, linalg};
#[no_mangle]
/// Provide an interface to `nalgebra::linalg::try_invert_to`
pub extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8
{
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
1
} else {
0
}
}
#[no_mangle]
/// Provide an interface to `nalgebra::linalg::wilkinson_shift`
pub extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64
{
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
// Check if matrix is symmetric
assert!(matrix[(0, 1)] == matrix[(1, 0)], "Operation Wilkinson Shift expects symmetric matrix");
return linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)]);
}

View File

@ -141,6 +141,26 @@ def patch(module):
else: else:
raise NotImplementedError raise NotImplementedError
def try_invert_to(x):
try:
y = np.linalg.inv(x)
x[:] = y
except np.linalg.LinAlgError:
return False
return True
def wilkinson_shift(x):
assert (len(x.flatten()) == 4) and (x[0, 1] == x[1, 0]), f"Operation Wilkinson Shift expects symmetric matrix"
tmm, tnn, tmn = x[0, 0], x[1, 1], x[0, 1]
sq_tmn = tmn * tmn
if sq_tmn != 0:
d = (tmm - tnn) * 0.5
if d > 0:
return tnn - sq_tmn / (d + np.sqrt(d*d + sq_tmn))
else:
return tnn - sq_tmn / (d - np.sqrt(d*d + sq_tmn))
return tnn
module.int32 = int32 module.int32 = int32
module.int64 = int64 module.int64 = int64
module.uint32 = uint32 module.uint32 = uint32
@ -234,6 +254,8 @@ def patch(module):
module.np_full = np.full module.np_full = np.full
module.np_eye = np.eye module.np_eye = np.eye
module.np_identity = np.identity module.np_identity = np.identity
module.try_invert_to = try_invert_to
module.wilkinson_shift = wilkinson_shift
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

@ -42,11 +42,14 @@ done
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
nac3standalone=../../target/debug/nac3standalone nac3standalone=../../target/debug/nac3standalone
externfns=../../target/debug/deps/libexternfns.so
elif [ -e ../../target/release/nac3standalone ]; then elif [ -e ../../target/release/nac3standalone ]; then
nac3standalone=../../target/release/nac3standalone nac3standalone=../../target/release/nac3standalone
externfns=../../target/release/deps/libexternfns.so
else else
# used by Nix builds # used by Nix builds
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo
@ -54,7 +57,7 @@ if [ -z "$use_lli" ]; then
$nac3standalone "${nac3args[@]}" $nac3standalone "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -lm -o demo module.o demo.o clang -lm -o demo module.o demo.o $externfns
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then
./demo ./demo
@ -71,8 +74,8 @@ else
shopt -u nullglob shopt -u nullglob
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc
else else
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile" lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
fi fi
fi fi

View File

@ -1429,7 +1429,25 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_try_invert():
x: ndarray[float, 2] = np_array([[1.0, 1.0], [1.0, 5.0]])
output_ndarray_float_2(x)
y = try_invert_to(x)
output_ndarray_float_2(x)
output_bool(y)
def test_wilkinson_shift():
x: ndarray[float, 2] = np_array([[5., 1.], [1., 4.]])
y = wilkinson_shift(x)
output_ndarray_float_2(x)
output_float64(y)
def run() -> int32: def run() -> int32:
test_try_invert()
test_wilkinson_shift()
test_ndarray_ctor() test_ndarray_ctor()
test_ndarray_empty() test_ndarray_empty()
test_ndarray_zeros() test_ndarray_zeros()

BIN
pyo3_output/nac3artiq.so Executable file

Binary file not shown.