diff --git a/.gitignore b/.gitignore index fbf6a2ef..c98b1ad1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __pycache__ /target nix/windows/msys2 +nac3standalone/demo/externfns/target diff --git a/Cargo.lock b/Cargo.lock index aaa531b2..12dfddef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -73,6 +73,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "ascii-canvas" version = "3.0.0" @@ -305,6 +314,13 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "externfns" +version = "0.1.0" +dependencies = [ + "nalgebra", +] + [[package]] name = "fastrand" version = "2.1.0" @@ -521,6 +537,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + [[package]] name = "libredox" version = "0.1.3" @@ -658,18 +680,71 @@ name = "nac3standalone" version = "0.1.0" dependencies = [ "clap", + "externfns", "inkwell", "nac3core", "nac3parser", "parking_lot", ] +[[package]] +name = "nalgebra" +version = "0.32.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4" +dependencies = [ + "approx", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -699,6 +774,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "petgraph" version = "0.6.5" @@ -1070,6 +1151,18 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "simba" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", +] + [[package]] name = "similar" version = "2.5.0" @@ -1230,6 +1323,12 @@ dependencies = [ "crunchy", ] +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + [[package]] name = "unic-char-property" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 765ab391..7b702502 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "nac3parser", "nac3core", "nac3standalone", + "nac3standalone/demo/externfns", "nac3artiq", "runkernel", ] diff --git a/flake.nix b/flake.nix index 4febca24..7bd28c70 100644 --- a/flake.nix +++ b/flake.nix @@ -161,7 +161,9 @@ clippy pre-commit rustfmt + rust-analyzer ]; + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; }; devShells.x86_64-linux.msys2 = pkgs.mkShell { name = "nac3-dev-shell-msys2"; diff --git a/nac3artiq/demo/demo.py b/nac3artiq/demo/demo.py index aa135757..3c6aab9e 100644 --- a/nac3artiq/demo/demo.py +++ b/nac3artiq/demo/demo.py @@ -14,12 +14,23 @@ class Demo: @kernel def run(self): - self.core.reset() - while True: - with parallel: - self.led0.pulse(100.*ms) - self.led1.pulse(100.*ms) - self.core.delay(100.*ms) + a = np_array([[1., 2.], [3., 4.]]) + b = try_invert_to(a) + if b: + # self.core.reset() + # while True: + # with parallel: + # self.led0.pulse(100.*ms) + # self.led1.pulse(100.*ms) + # self.core.delay(100.*ms) + v = try_invert_to(np_identity(2)) + if v: + while True: + with parallel: + self.led0.pulse(100.*ms) + self.led1.pulse(100.*ms) + self.core.delay(100.*ms) + if __name__ == "__main__": diff --git a/nac3artiq/demo/module.elf b/nac3artiq/demo/module.elf new file mode 100644 index 00000000..eb74b7cd Binary files /dev/null and b/nac3artiq/demo/module.elf differ diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 63078107..a7fc3a94 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -3,7 +3,7 @@ use inkwell::values::BasicValueEnum; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; -use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; +use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor}; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; @@ -922,6 +922,122 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( }) } +pub fn call_try_invert_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "linalg_try_invert_to"; + let (a_ty, a) = a; + let llvm_usize = generator.get_size_type(ctx.ctx); + + match a { + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + match llvm_ndarray_ty { + BasicTypeEnum::FloatType(_) => {}, + _ => unreachable!("Inverse Operation supported on float type NDArray Values only") + }; + + let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); + + // Add asserts for dims + if cfg!(debug_assertions) { + let n_dims = n.load_ndims(ctx); + + // num_dim == 2 + ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(), + "0:ValueError", format!("Inverse only supported on 2D lists").as_str(), [None, None, None], ctx.current_loc); + + let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)}; + let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)}; + + // dim0 == dim1 + ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(), + "0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc); + + } + + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + let n_sz_eqz = ctx + .builder + .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") + .unwrap(); + + ctx.make_assert( + generator, + n_sz_eqz, + "0:ValueError", + format!("zero-size array to reduction operation {FN_NAME}").as_str(), + [None, None, None], + ctx.current_loc, + ); + } + + // Create a call to linalg_try_invert_to + let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)}; + let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)}; + + Ok(extern_fns::call_linalg_try_invert_to(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[a_ty]), + } +} + +pub fn call_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "linalg_wilkinson_shift"; + let (a_ty, a) = a; + let llvm_usize = generator.get_size_type(ctx.ctx); + + match a { + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + match llvm_ndarray_ty { + BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {}, + _ => unreachable!("Wilkinson Shift Operation supported on float/integer type NDArray Values only") + }; + + let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + + // Add asserts for dims + if cfg!(debug_assertions) { + let n_dims = n.load_ndims(ctx); + + // num_dim == 2 + ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(), + "0:ValueError", format!("Wilkinson Shift supported only on 2D lists").as_str(), [None, None, None], ctx.current_loc); + + let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)}; + let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)}; + + // dim0 == dim1 + ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(), + "0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc); + + // dimesions should be 2x2 + ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim0.get_type().const_int(2, false), "").unwrap(), + "0:ValueError", format!("Wilkinson Shift supported only on 2x2 matrices").as_str(), [None, None, None], ctx.current_loc); + } + + // Create a call to linalg_try_invert_to + let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)}; + let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)}; + + Ok(extern_fns::call_linalg_wilkinson_shift(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into()) + } + _ => unsupported_type(ctx, FN_NAME, &[a_ty]), + } +} + /// Invokes the `np_maximum` builtin function. pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index 8b510ed9..73190520 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -1,5 +1,5 @@ use inkwell::attributes::{Attribute, AttributeLoc}; -use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}; +use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}; use itertools::Either; use crate::codegen::CodeGenContext; @@ -130,3 +130,82 @@ pub fn call_ldexp<'ctx>( .map(Either::unwrap_left) .unwrap() } + +/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function +pub fn call_linalg_try_invert_to<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + dim0: IntValue<'ctx>, + dim1: IntValue<'ctx>, + data: PointerValue<'ctx>, + name: Option<&str>, +) -> IntValue<'ctx> { + const FN_NAME: &str = "linalg_try_invert_to"; + + let llvm_f64 = ctx.ctx.f64_type(); + let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; + + let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type()); + let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type()); + + debug_assert!(allowed_dim0); + debug_assert!(allowed_dim1); + debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64); + + let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + let fn_type = ctx.ctx.i8_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false); + let func = ctx.module.add_function(FN_NAME, fn_type, None); + for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + + func + }); + + ctx.builder + .build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap().into() +} + +/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function +pub fn call_linalg_wilkinson_shift<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + dim0: IntValue<'ctx>, + dim1: IntValue<'ctx>, + data: PointerValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + const FN_NAME: &str = "linalg_wilkinson_shift"; + let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; + + let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type()); + let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type()); + + debug_assert!(allowed_dim0); + debug_assert!(allowed_dim1); + + let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + let fn_type = ctx.ctx.f64_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false); + let func = ctx.module.add_function(FN_NAME, fn_type, None); + for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + + func + }); + + ctx.builder + .build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap().into() +} diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index bcb3c433..c07bd48c 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -556,6 +556,9 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpLdExp | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), + + PrimDef::FunTryInvertTo => self.build_linalg_try_invert_to(prim), // Inplace invert + PrimDef::FunWilkinsonShift => self.build_linalg_wilkinson_shift(prim), }; if cfg!(debug_assertions) { @@ -1874,6 +1877,64 @@ impl<'a> BuiltinBuilder<'a> { } } + fn build_linalg_try_invert_to(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[ + PrimDef::FunTryInvertTo, + ], + ); + + let var_map = self.num_or_ndarray_var_map.clone(); + create_fn_by_codegen( + self.unifier, + &var_map, + prim.name(), + self.primitives.bool, + &[(self.ndarray_float_2d, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let x_ty = fun.0.args[0].ty; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + + let func = match prim { + PrimDef::FunTryInvertTo => builtin_fns::call_try_invert_to, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (x_ty, x_val))?)) + }), + ) + } + + fn build_linalg_wilkinson_shift(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[ + PrimDef::FunWilkinsonShift, + ], + ); + + let var_map = self.num_or_ndarray_var_map.clone(); + create_fn_by_codegen( + self.unifier, + &var_map, + prim.name(), + self.primitives.float, + &[(self.ndarray_float_2d, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let x_ty = fun.0.args[0].ty; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + + let func = match prim { + PrimDef::FunWilkinsonShift => builtin_fns::call_wilkinson_shift, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (x_ty, x_val))?)) + }), + ) + } + fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { (prim.simple_name().into(), method_ty, prim.id()) } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 5560f41c..96a50dad 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -105,6 +105,8 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, + FunTryInvertTo, + FunWilkinsonShift, // Top-Level Functions FunSome, @@ -261,6 +263,8 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), + PrimDef::FunTryInvertTo => fun("try_invert_to", None), + PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None), PrimDef::FunSome => fun("Some", None), } } diff --git a/nac3standalone/Cargo.toml b/nac3standalone/Cargo.toml index a55a26b9..01ec3018 100644 --- a/nac3standalone/Cargo.toml +++ b/nac3standalone/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" parking_lot = "0.12" nac3parser = { path = "../nac3parser" } nac3core = { path = "../nac3core" } +externfns = { path = "./demo/externfns" } [dependencies.clap] version = "4.5" diff --git a/nac3standalone/demo/externfns/Cargo.toml b/nac3standalone/demo/externfns/Cargo.toml new file mode 100644 index 00000000..9d4592b4 --- /dev/null +++ b/nac3standalone/demo/externfns/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "externfns" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]} diff --git a/nac3standalone/demo/externfns/src/lib.rs b/nac3standalone/demo/externfns/src/lib.rs new file mode 100644 index 00000000..ecde7b6f --- /dev/null +++ b/nac3standalone/demo/externfns/src/lib.rs @@ -0,0 +1,41 @@ +#![deny( + future_incompatible, + let_underscore, + nonstandard_style, + rust_2024_compatibility, + clippy::all +)] +#![warn(clippy::pedantic)] +#![allow(clippy::semicolon_if_nothing_returned, clippy::uninlined_format_args)] + +use core::slice; +use nalgebra::{DMatrix, linalg}; + +#[no_mangle] +/// Provide an interface to `nalgebra::linalg::try_invert_to` +pub extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 +{ + + let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) }; + let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice); + let mut inverted_matrix = DMatrix::::zeros(dim0, dim1); + + if linalg::try_invert_to(matrix, &mut inverted_matrix) { + data_slice.copy_from_slice(inverted_matrix.transpose().as_slice()); + 1 + } else { + 0 + } +} + +#[no_mangle] +/// Provide an interface to `nalgebra::linalg::wilkinson_shift` +pub extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 +{ + let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) }; + let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice); + + // Check if matrix is symmetric + assert!(matrix[(0, 1)] == matrix[(1, 0)], "Operation Wilkinson Shift expects symmetric matrix"); + return linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)]); +} diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index b5c0095f..7b84e150 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -141,6 +141,26 @@ def patch(module): else: raise NotImplementedError + def try_invert_to(x): + try: + y = np.linalg.inv(x) + x[:] = y + except np.linalg.LinAlgError: + return False + return True + + def wilkinson_shift(x): + assert (len(x.flatten()) == 4) and (x[0, 1] == x[1, 0]), f"Operation Wilkinson Shift expects symmetric matrix" + tmm, tnn, tmn = x[0, 0], x[1, 1], x[0, 1] + sq_tmn = tmn * tmn + if sq_tmn != 0: + d = (tmm - tnn) * 0.5 + if d > 0: + return tnn - sq_tmn / (d + np.sqrt(d*d + sq_tmn)) + else: + return tnn - sq_tmn / (d - np.sqrt(d*d + sq_tmn)) + return tnn + module.int32 = int32 module.int64 = int64 module.uint32 = uint32 @@ -234,6 +254,8 @@ def patch(module): module.np_full = np.full module.np_eye = np.eye module.np_identity = np.identity + module.try_invert_to = try_invert_to + module.wilkinson_shift = wilkinson_shift def file_import(filename, prefix="file_import_"): filename = pathlib.Path(filename) diff --git a/nac3standalone/demo/run_demo.sh b/nac3standalone/demo/run_demo.sh index 68132545..5a4665e0 100755 --- a/nac3standalone/demo/run_demo.sh +++ b/nac3standalone/demo/run_demo.sh @@ -42,11 +42,14 @@ done if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then nac3standalone=../../target/debug/nac3standalone + externfns=../../target/debug/deps/libexternfns.so elif [ -e ../../target/release/nac3standalone ]; then nac3standalone=../../target/release/nac3standalone + externfns=../../target/release/deps/libexternfns.so else # used by Nix builds nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone + externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so fi rm -f ./*.o ./*.bc demo @@ -54,7 +57,7 @@ if [ -z "$use_lli" ]; then $nac3standalone "${nac3args[@]}" clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c - clang -lm -o demo module.o demo.o + clang -lm -o demo module.o demo.o $externfns if [ -z "$outfile" ]; then ./demo @@ -71,8 +74,8 @@ else shopt -u nullglob if [ -z "$outfile" ]; then - lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc + lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc else - lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile" + lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile" fi fi diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 7501cb5d..c2c5669f 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1429,7 +1429,25 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) +def test_try_invert(): + x: ndarray[float, 2] = np_array([[1.0, 1.0], [1.0, 5.0]]) + + output_ndarray_float_2(x) + + y = try_invert_to(x) + + output_ndarray_float_2(x) + output_bool(y) + +def test_wilkinson_shift(): + x: ndarray[float, 2] = np_array([[5., 1.], [1., 4.]]) + y = wilkinson_shift(x) + output_ndarray_float_2(x) + output_float64(y) + def run() -> int32: + test_try_invert() + test_wilkinson_shift() test_ndarray_ctor() test_ndarray_empty() test_ndarray_zeros() diff --git a/pyo3_output/nac3artiq.so b/pyo3_output/nac3artiq.so new file mode 100755 index 00000000..beb4f236 Binary files /dev/null and b/pyo3_output/nac3artiq.so differ