diff --git a/Cargo.lock b/Cargo.lock index eacd7238..3c702e9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -256,6 +256,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "cslice" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a" + [[package]] name = "dirs-next" version = "2.0.0" @@ -314,13 +320,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "externfns" -version = "0.1.0" -dependencies = [ - "nalgebra", -] - [[package]] name = "fastrand" version = "2.1.0" @@ -553,6 +552,14 @@ dependencies = [ "libc", ] +[[package]] +name = "linalg_externfns" +version = "0.1.0" +dependencies = [ + "cslice", + "nalgebra", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -638,7 +645,6 @@ name = "nac3core" version = "0.1.0" dependencies = [ "crossbeam", - "externfns", "indexmap 2.2.6", "indoc", "inkwell", @@ -682,6 +688,7 @@ version = "0.1.0" dependencies = [ "clap", "inkwell", + "linalg_externfns", "nac3core", "nac3parser", "parking_lot", diff --git a/Cargo.toml b/Cargo.toml index 3b15c464..96fc975c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ members = [ "nac3ast", "nac3parser", "nac3core", - "nac3core/src/codegen/externfns", + "nac3standalone/linalg_externfns", "nac3standalone", "nac3artiq", "runkernel", diff --git a/flake.nix b/flake.nix index 4febca24..7bd28c70 100644 --- a/flake.nix +++ b/flake.nix @@ -161,7 +161,9 @@ clippy pre-commit rustfmt + rust-analyzer ]; + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; }; devShells.x86_64-linux.msys2 = pkgs.mkShell { name = "nac3-dev-shell-msys2"; diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 2659ddc4..724e0c8c 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -11,7 +11,6 @@ indexmap = "2.2" parking_lot = "0.12" rayon = "1.8" nac3parser = { path = "../nac3parser" } -externfns = { path = "src/codegen/externfns" } strum = "0.26.2" strum_macros = "0.26.4" diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 16a970bb..c218b1e8 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,5 +1,5 @@ use inkwell::types::BasicTypeEnum; -use inkwell::values::BasicValueEnum; +use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; @@ -31,7 +31,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); @@ -1836,231 +1835,547 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( }) } -/// Invokes the `linalg_try_invert_to` function -pub fn call_linalg_try_invert_to<'ctx, G: CodeGenerator + ?Sized>( +fn build_input_matrix<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + out_matrices: Vec>, +) -> PointerValue<'ctx> { + let field_ty = out_matrices.iter().map(|x| x.get_type()).collect::>(); + let out_ty = ctx.ctx.struct_type(&field_ty, false); + let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); + + for (i, v) in out_matrices.into_iter().enumerate() { + unsafe { + let ptr = ctx + .builder + .build_in_bounds_gep( + out_ptr, + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(i as u64, false), + ], + "", + ) + .unwrap(); + ctx.builder.build_store(ptr, v).unwrap(); + } + } + out_ptr +} + +// Linalg Methods +pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - a: (Type, BasicValueEnum<'ctx>), + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "linalg_try_invert_to"; - let (a_ty, a) = a; + const FN_NAME: &str = "np_dot"; + let (x1_ty, x1) = x1; + let (x2_ty, x2) = x2; + + if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) { + let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty); + let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty); + + let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty) + else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + } +} + +pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_matmul"; + let (x1_ty, x1) = x1; + let (x2_ty, x2) = x2; + + let llvm_usize = generator.get_size_type(ctx.ctx); + if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty); + + let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty) + else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); + + let outdim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let outdim1 = unsafe { + n2.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None); + Ok(out) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + } +} + +pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_cholesky"; + let (x1_ty, x1) = x1; let llvm_usize = generator.get_size_type(ctx.ctx); - match a { - BasicValueEnum::PointerValue(n) - if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - match llvm_ndarray_ty { - BasicTypeEnum::FloatType(_) => {} - _ => { - unimplemented!("Inverse Operation supported on float type NDArray Values only") - } - }; + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; - // The following constraints must be satisfied: - // * Input must be 2D - // * number of rows should equal number of columns (square matrix) - if cfg!(debug_assertions) { - let n_dims = n.load_ndims(ctx); + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; - // num_dim == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::EQ, - n_dims, - llvm_usize.const_int(2, false), - "", - ) - .unwrap(), - "0:ValueError", - format!("Input matrix must have two dimensions for {FN_NAME}").as_str(), - [None, None, None], - ctx.current_loc, - ); + let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); - let dim0 = unsafe { - n.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; + extern_fns::call_np_linalg_cholesky(ctx, x1, out, None); + Ok(out) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} - // dim0 == dim1 - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(), - "0:ValueError", - format!( - "Input matrix should have equal number of rows and columns for {FN_NAME}" - ) - .as_str(), - [None, None, None], - ctx.current_loc, - ); - } +pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_qr"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + + let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None); + + let out_ptr = build_input_matrix(ctx, vec![out_q, out_r]); + + Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_svd"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + + let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None); + + let out_ptr = build_input_matrix(ctx, vec![out_u, out_s, out_vh]); + + Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_inv"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + let out = + numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap(); + + extern_fns::call_np_linalg_inv( + ctx, + (dim0, dim1, n1.data().base_ptr(ctx, generator)), + (dim0, dim1, out.data().base_ptr(ctx, generator)), + None, + ); + Ok(out.as_base_value().as_basic_value_enum()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_pinv"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + let out = + numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]).unwrap(); + + extern_fns::call_np_linalg_pinv( + ctx, + (dim0, dim1, n1.data().base_ptr(ctx, generator)), + (dim1, dim0, out.data().base_ptr(ctx, generator)), + None, + ); + Ok(out.as_base_value().as_basic_value_enum()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_linalg_lu"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + + let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap(); + let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap(); + + extern_fns::call_sp_linalg_lu( + ctx, + (dim0, dim1, n1.data().base_ptr(ctx, generator)), + (dim0, k, out_l.data().base_ptr(ctx, generator)), + (k, dim1, out_u.data().base_ptr(ctx, generator)), + None, + ); + + let out_l = out_l.as_base_value().as_basic_value_enum(); + let out_u = out_u.as_base_value().as_basic_value_enum(); + + let res_ty = ctx.ctx.struct_type(&[out_l.get_type(), out_u.get_type()], false); + let res_ptr = ctx.builder.build_alloca(res_ty, "LU_factorization").unwrap(); + + let res_val = [out_l, out_u]; + for (i, v) in res_val.into_iter().enumerate() { + unsafe { + let ptr = ctx .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") + .build_in_bounds_gep( + res_ptr, + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(i as u64, false), + ], + "ptr", + ) .unwrap(); - - ctx.make_assert( - generator, - n_sz_eqz, - "0:ValueError", - format!("zero-size array to inverse operation {FN_NAME}").as_str(), - [None, None, None], - ctx.current_loc, - ); + ctx.builder.build_store(ptr, v).unwrap(); } - - let dim0 = unsafe { - n.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - Ok(extern_fns::call_linalg_try_invert_to( - ctx, - dim0, - dim1, - n.data().base_ptr(ctx, generator), - None, - ) - .into()) } - _ => unsupported_type(ctx, FN_NAME, &[a_ty]), + + Ok(ctx.builder.build_load(res_ptr, "LU_Factorization_result").map(Into::into).unwrap()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) } } -/// Invokes the `linalg_wilkinson_shift` function -pub fn call_linalg_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>( +// Must be square (add check later) +pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - a: (Type, BasicValueEnum<'ctx>), + x1: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "linalg_wilkinson_shift"; - let (a_ty, a) = a; + const FN_NAME: &str = "sp_linalg_schur"; + let (x1_ty, x1) = x1; let llvm_usize = generator.get_size_type(ctx.ctx); - let one = llvm_usize.const_int(1, false); - let two = llvm_usize.const_int(2, false); + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - match a { - BasicValueEnum::PointerValue(n) - if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - match llvm_ndarray_ty { - BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {} - _ => unimplemented!( - "Wilkinson Shift Operation supported on float type NDArray Values only" - ), - }; + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - // The following constraints must be satisfied: - // * Input must be 2D - // * Number of rows and columns should equal 2 - // * Input matrix must be symmetric - if cfg!(debug_assertions) { - let n_dims = n.load_ndims(ctx); + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + let out_t = + numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap(); + let out_z = + numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap(); - // num_dim == 2 - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, two, "").unwrap(), - "0:ValueError", - format!("Input matrix must have two dimensions for {FN_NAME}").as_str(), - [None, None, None], - ctx.current_loc, - ); + extern_fns::call_sp_linalg_schur( + ctx, + (dim0, dim1, n1.data().base_ptr(ctx, generator)), + (dim0, dim0, out_t.data().base_ptr(ctx, generator)), + (dim0, dim0, out_z.data().base_ptr(ctx, generator)), + None, + ); - let dim0 = unsafe { - n.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() - }; + let out_t = out_t.as_base_value().as_basic_value_enum(); + let out_z = out_z.as_base_value().as_basic_value_enum(); - // dim0 == 2 - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, dim0, two, "").unwrap(), - "0:ValueError", - format!("Number of rows must be 2 for {FN_NAME}").as_str(), - [None, None, None], - ctx.current_loc, - ); + let res_ty = ctx.ctx.struct_type(&[out_t.get_type(), out_z.get_type()], false); - // dim1 == 2 - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, dim1, two, "").unwrap(), - "0:ValueError", - format!("Number of columns must be 2 for {FN_NAME}").as_str(), - [None, None, None], - ctx.current_loc, - ); - - let entry_01 = unsafe { - n.data().get_unchecked(ctx, generator, &one, None).into_float_value() - }; - let entry_10 = unsafe { - n.data().get_unchecked(ctx, generator, &two, None).into_float_value() - }; - - // symmetric matrix - ctx.make_assert( - generator, - ctx.builder - .build_float_compare(FloatPredicate::OEQ, entry_01, entry_10, "") - .unwrap(), - "0:ValueError", - format!("Input Matrix must be symmetric for {FN_NAME}").as_str(), - [None, None, None], - ctx.current_loc, - ); + let res_ptr = ctx.builder.build_alloca(res_ty, "Schur_factorization").unwrap(); + let res_val = [out_t, out_z]; + for (i, v) in res_val.into_iter().enumerate() { + unsafe { + let ptr = ctx + .builder + .build_in_bounds_gep( + res_ptr, + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(i as u64, false), + ], + "ptr", + ) + .unwrap(); + ctx.builder.build_store(ptr, v).unwrap(); } - - let dim0 = unsafe { - n.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = - unsafe { n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() }; - - Ok(extern_fns::call_linalg_wilkinson_shift( - ctx, - dim0, - dim1, - n.data().base_ptr(ctx, generator), - None, - ) - .into()) } - _ => unsupported_type(ctx, FN_NAME, &[a_ty]), + + Ok(ctx.builder.build_load(res_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +// Must be square (add check later) +pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_linalg_hessenberg"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + // Check if matrix is square + // ctx.builder.build_select( + // ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(), + // { + // let func = + // }, else_, name) + // ; + + // ctx.builder.build_call( + // ctx.module.get_function("__nac3_raise"), + // &[] + + // ) + // let err_msg = ctx.gen_string(generator, "{FN_NAME} requires square matrix"); + // ctx.raise_exn(generator, "0:ValueError", err_msg, [None, None, None], ctx.current_loc); + + let out_h = + numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap(); + + extern_fns::call_sp_linalg_hessenberg( + ctx, + (dim0, dim1, n1.data().base_ptr(ctx, generator)), + (dim0, dim0, out_h.data().base_ptr(ctx, generator)), + None, + ); + + Ok(out_h.as_base_value().as_basic_value_enum()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) } } diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index 09e97c5a..1b8274e6 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -131,90 +131,218 @@ pub fn call_ldexp<'ctx>( .unwrap() } -/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function -pub fn call_linalg_try_invert_to<'ctx>( - ctx: &CodeGenContext<'ctx, '_>, - dim0: IntValue<'ctx>, - dim1: IntValue<'ctx>, - data: PointerValue<'ctx>, - name: Option<&str>, -) -> IntValue<'ctx> { - const FN_NAME: &str = "linalg_try_invert_to"; +/// Macro to generate np_linalg external functions +macro_rules! generate_np_linalg_extern_fn { + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => { + generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1); + }; + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => { + generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2); + }; + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => { + generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3); + }; + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => { + generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4); + }; + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => { + #[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )] + pub fn $fn_name<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_> + $(,$input_matrix: (IntValue<'ctx>, IntValue<'ctx>, PointerValue<'ctx>))*, + name: Option<&str>, + ) -> $ret_ty<'ctx> { + const FN_NAME: &str = $extern_fn; - let llvm_f64 = ctx.ctx.f64_type(); - let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; + let llvm_f64 = ctx.ctx.f64_type(); + let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; - let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type()); - let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type()); + $( + debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.0.get_type())); + debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.1.get_type())); + debug_assert_eq!($input_matrix.2.get_type().get_element_type().into_float_type(), llvm_f64); + )* - debug_assert!(allowed_dim0); - debug_assert!(allowed_dim1); - debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64); + // let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false); + // let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false); + // let file_name = ctx.current_loc.file.0; + // let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false); + // let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true); - let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { - let fn_type = ctx.ctx.i8_type().fn_type( - &[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], - false, - ); - let func = ctx.module.add_function(FN_NAME, fn_type, None); - for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); + let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + // let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false); + // let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false); + let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false); + + let func = ctx.module.add_function(FN_NAME, fn_type, None); + for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + func + }); + + ctx.builder + // .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default()) + // .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default()) + .build_call(extern_fn, &[$($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left($map_fn)) + .map(Either::unwrap_left) + .unwrap() } - - func - }); - - ctx.builder - .build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default()) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + }; } -/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function -pub fn call_linalg_wilkinson_shift<'ctx>( - ctx: &CodeGenContext<'ctx, '_>, - dim0: IntValue<'ctx>, - dim1: IntValue<'ctx>, - data: PointerValue<'ctx>, - name: Option<&str>, -) -> FloatValue<'ctx> { - const FN_NAME: &str = "linalg_wilkinson_shift"; +/// Macro to generate `np_linalg` external functions +macro_rules! generate_np_linalg_extern_fn2 { + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => { + generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1); + }; + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => { + generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2); + }; + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => { + generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3); + }; + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => { + generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4); + }; + ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => { + #[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )] + pub fn $fn_name<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_> + $(,$input_matrix: BasicValueEnum<'ctx>)*, + name: Option<&str>, + ) -> $ret_ty<'ctx> { + const FN_NAME: &str = $extern_fn; + // let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false); + // let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false); + // let file_name = ctx.current_loc.file.0; + // let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false); + // let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true); - let llvm_f64 = ctx.ctx.f64_type(); - let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; + let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + // let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false); + // let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false); + let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.get_type().into()),*], false); - let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type()); - let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type()); + let func = ctx.module.add_function(FN_NAME, fn_type, None); + for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + func + }); - debug_assert!(allowed_dim0); - debug_assert!(allowed_dim1); - debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64); - - let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { - let fn_type = ctx.ctx.f64_type().fn_type( - &[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], - false, - ); - let func = ctx.module.add_function(FN_NAME, fn_type, None); - for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); + ctx.builder + // .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default()) + // .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default()) + .build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left($map_fn)) + .map(Either::unwrap_left) + .unwrap() } - - func - }); - - ctx.builder - .build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default()) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() + }; } +/// Macro to generate `np_linalg` external functions +macro_rules! generate_np_linalg_extern_fn3 { + ($fn_name:ident, $extern_fn:literal, 1) => { + generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1); + }; + ($fn_name:ident, $extern_fn:literal, 2) => { + generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1, mat2); + }; + ($fn_name:ident, $extern_fn:literal, 3) => { + generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1, mat2, mat3); + }; + ($fn_name:ident, $extern_fn:literal, 4) => { + generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1, mat2, mat3, mat4); + }; + ($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => { + #[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )] + pub fn $fn_name<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_> + $(,$input_matrix: BasicValueEnum<'ctx>)*, + name: Option<&str>, + ){ + const FN_NAME: &str = $extern_fn; + let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false); + + let func = ctx.module.add_function(FN_NAME, fn_type, None); + for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + func + }); + + ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap(); + } + }; +} +generate_np_linalg_extern_fn2!( + call_np_dot, + FloatValue, + f64_type, + BasicValueEnum::into_float_value, + "np_dot", + 2 +); +generate_np_linalg_extern_fn3!(call_np_linalg_matmul, "np_linalg_matmul", 3); +generate_np_linalg_extern_fn3!(call_np_linalg_cholesky, "np_linalg_cholesky", 2); +generate_np_linalg_extern_fn3!(call_np_linalg_qr, "np_linalg_qr", 3); +generate_np_linalg_extern_fn3!(call_np_linalg_svd, "np_linalg_svd", 4); + +generate_np_linalg_extern_fn!( + call_np_linalg_inv, + IntValue, + i8_type, + BasicValueEnum::into_int_value, + "np_linalg_inv", + 2 +); + +generate_np_linalg_extern_fn!( + call_np_linalg_pinv, + IntValue, + i8_type, + BasicValueEnum::into_int_value, + "np_linalg_pinv", + 2 +); + +generate_np_linalg_extern_fn!( + call_sp_linalg_lu, + IntValue, + i8_type, + BasicValueEnum::into_int_value, + "sp_linalg_lu", + 3 +); + +generate_np_linalg_extern_fn!( + call_sp_linalg_schur, + IntValue, + i8_type, + BasicValueEnum::into_int_value, + "sp_linalg_schur", + 3 +); + +generate_np_linalg_extern_fn!( + call_sp_linalg_hessenberg, + IntValue, + i8_type, + BasicValueEnum::into_int_value, + "sp_linalg_hessenberg", + 2 +); diff --git a/nac3core/src/codegen/externfns/src/lib.rs b/nac3core/src/codegen/externfns/src/lib.rs deleted file mode 100644 index 504aa670..00000000 --- a/nac3core/src/codegen/externfns/src/lib.rs +++ /dev/null @@ -1,30 +0,0 @@ -use core::slice; -use nalgebra::{linalg, DMatrix}; - -/// # Safety -/// -/// `data` must point to an array with `dim0`x`dim1` elements in row-major order -#[no_mangle] -pub unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 { - let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) }; - let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice); - let mut inverted_matrix = DMatrix::::zeros(dim0, dim1); - - if linalg::try_invert_to(matrix, &mut inverted_matrix) { - data_slice.copy_from_slice(inverted_matrix.transpose().as_slice()); - 1 - } else { - 0 - } -} - -/// # Safety -/// -/// `data` must point to an array of 4 elements in row-major order -#[no_mangle] -pub unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 { - let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) }; - let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice); - - linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)]) -} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 7421c894..6b5b3ace 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -61,7 +61,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( /// * `shape` - The shape of the `NDArray`. /// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. /// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. -fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( +pub fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, @@ -157,7 +157,7 @@ where /// /// * `elem_ty` - The element type of the `NDArray`. /// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. -fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( +pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 783aa5fb..2bad8bca 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -557,7 +557,18 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), - PrimDef::FunTryInvertTo | PrimDef::FunWilkinsonShift => self.build_linalg_methods(prim), + PrimDef::FunNpDot + | PrimDef::FunNpLinalgMatmul + | PrimDef::FunNpLinalgCholesky + | PrimDef::FunNpLinalgQr + | PrimDef::FunNpLinalgSvd + | PrimDef::FunNpLinalgInv + | PrimDef::FunNpLinalgPinv + | PrimDef::FunSpLinalgLu + | PrimDef::FunSpLinalgSchur + | PrimDef::FunSpLinalgHessenberg => self.build_np_linalg_methods(prim), + // PrimDef::FunNpDot | PrimDef::FunNpLinalgMatmul => self.build_np_linalg_binary_methods(prim), + // PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr => self.build_np_linalg_unary_methods(prim), }; if cfg!(debug_assertions) { @@ -1876,35 +1887,140 @@ impl<'a> BuiltinBuilder<'a> { } } - /// Build the functions `try_invert_to` and `wilkinson_shift` - fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunTryInvertTo, PrimDef::FunWilkinsonShift]); + fn build_np_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[ + PrimDef::FunNpDot, + PrimDef::FunNpLinalgMatmul, + PrimDef::FunNpLinalgCholesky, + PrimDef::FunNpLinalgQr, + PrimDef::FunNpLinalgSvd, + PrimDef::FunNpLinalgInv, + PrimDef::FunNpLinalgPinv, + PrimDef::FunSpLinalgLu, + PrimDef::FunSpLinalgSchur, + PrimDef::FunSpLinalgHessenberg, + ], + ); - let ret_ty = match prim { - PrimDef::FunTryInvertTo => self.primitives.bool, - PrimDef::FunWilkinsonShift => self.primitives.float, - _ => unreachable!(), - }; - let var_map = self.num_or_ndarray_var_map.clone(); - create_fn_by_codegen( - self.unifier, - &var_map, - prim.name(), - ret_ty, - &[(self.ndarray_float_2d, "x")], - Box::new(move |ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + match prim { + PrimDef::FunNpDot => create_fn_by_codegen( + self.unifier, + &self.num_or_ndarray_var_map, + prim.name(), + self.primitives.float, + &[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - let func = match prim { - PrimDef::FunTryInvertTo => builtin_fns::call_linalg_try_invert_to, - PrimDef::FunWilkinsonShift => builtin_fns::call_linalg_wilkinson_shift, - _ => unreachable!(), - }; + Ok(Some(builtin_fns::call_np_dot( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }), + ), - Ok(Some(func(generator, ctx, (x_ty, x_val))?)) - }), - ) + PrimDef::FunNpLinalgMatmul => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_float_2d, + &[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_np_linalg_matmul( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }), + ), + + PrimDef::FunNpLinalgCholesky + | PrimDef::FunNpLinalgInv + | PrimDef::FunNpLinalgPinv + | PrimDef::FunSpLinalgHessenberg => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_float_2d, + &[(self.ndarray_float_2d, "x1")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + + let func = match prim { + PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky, + PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv, + PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv, + PrimDef::FunSpLinalgHessenberg => builtin_fns::call_sp_linalg_hessenberg, + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (x1_ty, x1_val))?)) + }), + ), + + PrimDef::FunNpLinalgQr | PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgSchur => { + let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { + ty: vec![self.ndarray_float_2d, self.ndarray_float_2d], + }); + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + ret_ty, + &[(self.ndarray_float_2d, "x1")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + + let func = match prim { + PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr, + PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu, + PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur, + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (x1_ty, x1_val))?)) + }), + ) + } + + PrimDef::FunNpLinalgSvd => { + let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { + ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d], + }); + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + ret_ty, + &[(self.ndarray_float_2d, "x1")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + + Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?)) + }), + ) + } + _ => { + println!("{:?}", prim.name()); + unreachable!() + } + } } fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index cd6c4975..674d0a8d 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -105,8 +105,16 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, - FunTryInvertTo, - FunWilkinsonShift, + FunNpDot, + FunNpLinalgMatmul, + FunNpLinalgCholesky, + FunNpLinalgQr, + FunNpLinalgSvd, + FunNpLinalgInv, + FunNpLinalgPinv, + FunSpLinalgLu, + FunSpLinalgSchur, + FunSpLinalgHessenberg, // Top-Level Functions FunSome, @@ -265,8 +273,17 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), - PrimDef::FunTryInvertTo => fun("try_invert_to", None), - PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None), + PrimDef::FunNpDot => fun("np_dot", None), + PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None), + PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None), + PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None), + PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None), + PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None), + PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None), + PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None), + PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None), + PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None), + PrimDef::FunSome => fun("Some", None), } } diff --git a/nac3standalone/Cargo.toml b/nac3standalone/Cargo.toml index a55a26b9..ccf69ba5 100644 --- a/nac3standalone/Cargo.toml +++ b/nac3standalone/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" parking_lot = "0.12" nac3parser = { path = "../nac3parser" } nac3core = { path = "../nac3core" } +linalg_externfns = { path = "./linalg_externfns" } [dependencies.clap] version = "4.5" diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 57d60717..2db011ea 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -5,6 +5,7 @@ import importlib.util import importlib.machinery import math import numpy as np +import scipy as sp import numpy.typing as npt import pathlib @@ -246,8 +247,21 @@ def patch(module): module.sp_spec_j0 = special.j0 module.sp_spec_j1 = special.j1 - module.try_invert_to = try_invert_to - module.wilkinson_shift = wilkinson_shift + # Linalg functions + module.np_dot = np.dot + module.np_linalg_matmul = np.matmul + module.np_linalg_cholesky = np.linalg.cholesky + module.np_linalg_qr = np.linalg.qr + module.np_linalg_svd = np.linalg.svd + module.np_linalg_inv = np.linalg.inv + module.np_linalg_pinv = np.linalg.pinv + + module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True) + module.sp_linalg_schur = sp.linalg.schur + # module.sp_linalg_hessenberg = sp.linalg.hessenberg + module.sp_linalg_hessenberg = lambda x: x + + def file_import(filename, prefix="file_import_"): filename = pathlib.Path(filename) diff --git a/nac3standalone/demo/run_demo.sh b/nac3standalone/demo/run_demo.sh index 5a4665e0..b34a1210 100755 --- a/nac3standalone/demo/run_demo.sh +++ b/nac3standalone/demo/run_demo.sh @@ -42,14 +42,14 @@ done if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then nac3standalone=../../target/debug/nac3standalone - externfns=../../target/debug/deps/libexternfns.so + externfns=../../target/debug/deps/liblinalg_externfns.so elif [ -e ../../target/release/nac3standalone ]; then nac3standalone=../../target/release/nac3standalone - externfns=../../target/release/deps/libexternfns.so + externfns=../../target/release/deps/liblinalg_externfns.so else # used by Nix builds nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone - externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so + externfns=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg_externfns.so fi rm -f ./*.o ./*.bc demo diff --git a/nac3standalone/demo/sample b/nac3standalone/demo/sample new file mode 100644 index 00000000..bbaaf4cd --- /dev/null +++ b/nac3standalone/demo/sample @@ -0,0 +1,12 @@ +Checking src/ndarray.py... Function Called np_dot +Module { data_layout: RefCell { value: Some(DataLayout { address: 0x555559cda5b0, repr: "" }) }, module: Cell { value: 0x555559cda3a0 }, owned_by_ee: RefCell { value: None }, _marker: PhantomData<&inkwell::context::Context> } +Function Called np_dot +Module { data_layout: RefCell { value: Some(DataLayout { address: 0x555559cda740, repr: "" }) }, module: Cell { value: 0x555559cda530 }, owned_by_ee: RefCell { value: None }, _marker: PhantomData<&inkwell::context::Context> } +--- interpreted.log 2024-07-24 19:39:47.480093947 +0800 ++++ run.log 2024-07-24 22:39:50.183382396 +0800 +@@ -0,0 +1,5 @@ ++5.000000 ++1.000000 ++5.000000 ++1.000000 ++26.000000 diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 5c469650..548e2388 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1429,200 +1429,289 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) -def test_try_invert(): - x: ndarray[float, 2] = np_array([[1.0, 2.0], [3.0, 4.0]]) - output_ndarray_float_2(x) - y = try_invert_to(x) +def test_ndarray_dot(): + x: ndarray[float, 1] = np_array([5.0, 1.0]) + y: ndarray[float, 1] = np_array([5.0, 1.0]) + z = np_dot(x, y) - output_ndarray_float_2(x) - output_bool(y) + output_ndarray_float_1(x) + output_ndarray_float_1(y) + output_float64(z) -def test_wilkinson_shift(): +def test_ndarray_linalg_matmul(): x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) - y = wilkinson_shift(x) + y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) + z = np_linalg_matmul(x, y) + + m = np_argmax(z) + output_ndarray_float_2(x) - output_float64(y) + output_ndarray_float_2(y) + output_ndarray_float_2(z) + output_int64(m) + +def test_ndarray_cholesky(): + x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) + y = np_linalg_cholesky(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_qr(): + x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]]) + y, z = np_linalg_qr(x) + + output_ndarray_float_2(x) + + # QR Factorization in nalgebra and numpy do not give the same result + # Generating product for printing + a = np_linalg_matmul(y, z) + output_ndarray_float_2(a) + +def test_ndarray_linalg_inv(): + x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]]) + y = np_linalg_inv(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_pinv(): + x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]]) + y = np_linalg_pinv(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_schur(): + x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]]) + t, z = sp_linalg_schur(x) + + output_ndarray_float_2(x) + # Same as np_linalg_qr the signs are different in nalgebra and numpy + a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z)) + output_ndarray_float_2(a) + +def test_ndarray_hessenberg(): + x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]]) + h = sp_linalg_hessenberg(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(h) + + +def test_ndarray_lu(): + x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]]) + l, u = sp_linalg_lu(x) + + output_ndarray_float_2(x) + output_ndarray_float_2(l) + output_ndarray_float_2(u) + + +def test_ndarray_svd(): + w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]]) + x, y, z = np_linalg_svd(w) + + output_ndarray_float_2(w) + + # Same as np_linalg_qr the signs are different in nalgebra and numpy + a = np_linalg_matmul(x, z) + output_ndarray_float_2(a) + output_ndarray_float_1(y) + def run() -> int32: - test_ndarray_ctor() - test_ndarray_empty() - test_ndarray_zeros() - test_ndarray_ones() - test_ndarray_full() - test_ndarray_eye() - test_ndarray_array() - test_ndarray_identity() - test_ndarray_fill() - test_ndarray_copy() + # test_ndarray_matmul() + test_ndarray_dot() + test_ndarray_linalg_matmul() + test_ndarray_cholesky() + test_ndarray_qr() + test_ndarray_svd() + # test_ndarray_linalg_inv() + # test_ndarray_pinv() + # test_ndarray_lu() + # test_ndarray_schur() + # test_ndarray_hessenberg() - test_ndarray_neg_idx() - test_ndarray_slices() - test_ndarray_nd_idx() + # test_ndarray_ctor() + # test_ndarray_empty() + # test_ndarray_zeros() + # test_ndarray_ones() + # test_ndarray_full() + # test_ndarray_eye() + # test_ndarray_array() + # test_ndarray_identity() + # test_ndarray_fill() + # test_ndarray_copy() - test_ndarray_add() - test_ndarray_add_broadcast() - test_ndarray_add_broadcast_lhs_scalar() - test_ndarray_add_broadcast_rhs_scalar() - test_ndarray_iadd() - test_ndarray_iadd_broadcast() - test_ndarray_iadd_broadcast_scalar() - test_ndarray_sub() - test_ndarray_sub_broadcast() - test_ndarray_sub_broadcast_lhs_scalar() - test_ndarray_sub_broadcast_rhs_scalar() - test_ndarray_isub() - test_ndarray_isub_broadcast() - test_ndarray_isub_broadcast_scalar() - test_ndarray_mul() - test_ndarray_mul_broadcast() - test_ndarray_mul_broadcast_lhs_scalar() - test_ndarray_mul_broadcast_rhs_scalar() - test_ndarray_imul() - test_ndarray_imul_broadcast() - test_ndarray_imul_broadcast_scalar() - test_ndarray_truediv() - test_ndarray_truediv_broadcast() - test_ndarray_truediv_broadcast_lhs_scalar() - test_ndarray_truediv_broadcast_rhs_scalar() - test_ndarray_itruediv() - test_ndarray_itruediv_broadcast() - test_ndarray_itruediv_broadcast_scalar() - test_ndarray_floordiv() - test_ndarray_floordiv_broadcast() - test_ndarray_floordiv_broadcast_lhs_scalar() - test_ndarray_floordiv_broadcast_rhs_scalar() - test_ndarray_ifloordiv() - test_ndarray_ifloordiv_broadcast() - test_ndarray_ifloordiv_broadcast_scalar() - test_ndarray_mod() - test_ndarray_mod_broadcast() - test_ndarray_mod_broadcast_lhs_scalar() - test_ndarray_mod_broadcast_rhs_scalar() - test_ndarray_imod() - test_ndarray_imod_broadcast() - test_ndarray_imod_broadcast_scalar() - test_ndarray_pow() - test_ndarray_pow_broadcast() - test_ndarray_pow_broadcast_lhs_scalar() - test_ndarray_pow_broadcast_rhs_scalar() - test_ndarray_ipow() - test_ndarray_ipow_broadcast() - test_ndarray_ipow_broadcast_scalar() - test_ndarray_matmul() - test_ndarray_imatmul() - test_ndarray_pos() - test_ndarray_neg() - test_ndarray_inv() - test_ndarray_eq() - test_ndarray_eq_broadcast() - test_ndarray_eq_broadcast_lhs_scalar() - test_ndarray_eq_broadcast_rhs_scalar() - test_ndarray_ne() - test_ndarray_ne_broadcast() - test_ndarray_ne_broadcast_lhs_scalar() - test_ndarray_ne_broadcast_rhs_scalar() - test_ndarray_lt() - test_ndarray_lt_broadcast() - test_ndarray_lt_broadcast_lhs_scalar() - test_ndarray_lt_broadcast_rhs_scalar() - test_ndarray_lt() - test_ndarray_le_broadcast() - test_ndarray_le_broadcast_lhs_scalar() - test_ndarray_le_broadcast_rhs_scalar() - test_ndarray_gt() - test_ndarray_gt_broadcast() - test_ndarray_gt_broadcast_lhs_scalar() - test_ndarray_gt_broadcast_rhs_scalar() - test_ndarray_gt() - test_ndarray_ge_broadcast() - test_ndarray_ge_broadcast_lhs_scalar() - test_ndarray_ge_broadcast_rhs_scalar() + # test_ndarray_neg_idx() + # test_ndarray_slices() + # test_ndarray_nd_idx() - test_ndarray_int32() - test_ndarray_int64() - test_ndarray_uint32() - test_ndarray_uint64() - test_ndarray_float() - test_ndarray_bool() + # test_ndarray_add() + # test_ndarray_add_broadcast() + # test_ndarray_add_broadcast_lhs_scalar() + # test_ndarray_add_broadcast_rhs_scalar() + # test_ndarray_iadd() + # test_ndarray_iadd_broadcast() + # test_ndarray_iadd_broadcast_scalar() + # test_ndarray_sub() + # test_ndarray_sub_broadcast() + # test_ndarray_sub_broadcast_lhs_scalar() + # test_ndarray_sub_broadcast_rhs_scalar() + # test_ndarray_isub() + # test_ndarray_isub_broadcast() + # test_ndarray_isub_broadcast_scalar() + # test_ndarray_mul() + # test_ndarray_mul_broadcast() + # test_ndarray_mul_broadcast_lhs_scalar() + # test_ndarray_mul_broadcast_rhs_scalar() + # test_ndarray_imul() + # test_ndarray_imul_broadcast() + # test_ndarray_imul_broadcast_scalar() + # test_ndarray_truediv() + # test_ndarray_truediv_broadcast() + # test_ndarray_truediv_broadcast_lhs_scalar() + # test_ndarray_truediv_broadcast_rhs_scalar() + # test_ndarray_itruediv() + # test_ndarray_itruediv_broadcast() + # test_ndarray_itruediv_broadcast_scalar() + # test_ndarray_floordiv() + # test_ndarray_floordiv_broadcast() + # test_ndarray_floordiv_broadcast_lhs_scalar() + # test_ndarray_floordiv_broadcast_rhs_scalar() + # test_ndarray_ifloordiv() + # test_ndarray_ifloordiv_broadcast() + # test_ndarray_ifloordiv_broadcast_scalar() + # test_ndarray_mod() + # test_ndarray_mod_broadcast() + # test_ndarray_mod_broadcast_lhs_scalar() + # test_ndarray_mod_broadcast_rhs_scalar() + # test_ndarray_imod() + # test_ndarray_imod_broadcast() + # test_ndarray_imod_broadcast_scalar() + # test_ndarray_pow() + # test_ndarray_pow_broadcast() + # test_ndarray_pow_broadcast_lhs_scalar() + # test_ndarray_pow_broadcast_rhs_scalar() + # test_ndarray_ipow() + # test_ndarray_ipow_broadcast() + # test_ndarray_ipow_broadcast_scalar() + # test_ndarray_matmul() + # test_ndarray_imatmul() + # test_ndarray_pos() + # test_ndarray_neg() + # test_ndarray_inv() + # test_ndarray_eq() + # test_ndarray_eq_broadcast() + # test_ndarray_eq_broadcast_lhs_scalar() + # test_ndarray_eq_broadcast_rhs_scalar() + # test_ndarray_ne() + # test_ndarray_ne_broadcast() + # test_ndarray_ne_broadcast_lhs_scalar() + # test_ndarray_ne_broadcast_rhs_scalar() + # test_ndarray_lt() + # test_ndarray_lt_broadcast() + # test_ndarray_lt_broadcast_lhs_scalar() + # test_ndarray_lt_broadcast_rhs_scalar() + # test_ndarray_lt() + # test_ndarray_le_broadcast() + # test_ndarray_le_broadcast_lhs_scalar() + # test_ndarray_le_broadcast_rhs_scalar() + # test_ndarray_gt() + # test_ndarray_gt_broadcast() + # test_ndarray_gt_broadcast_lhs_scalar() + # test_ndarray_gt_broadcast_rhs_scalar() + # test_ndarray_gt() + # test_ndarray_ge_broadcast() + # test_ndarray_ge_broadcast_lhs_scalar() + # test_ndarray_ge_broadcast_rhs_scalar() - test_ndarray_round() - test_ndarray_floor() - test_ndarray_min() - test_ndarray_minimum() - test_ndarray_minimum_broadcast() - test_ndarray_minimum_broadcast_lhs_scalar() - test_ndarray_minimum_broadcast_rhs_scalar() - test_ndarray_argmin() - test_ndarray_max() - test_ndarray_maximum() - test_ndarray_maximum_broadcast() - test_ndarray_maximum_broadcast_lhs_scalar() - test_ndarray_maximum_broadcast_rhs_scalar() - test_ndarray_argmax() - test_ndarray_abs() - test_ndarray_isnan() - test_ndarray_isinf() + # test_ndarray_int32() + # test_ndarray_int64() + # test_ndarray_uint32() + # test_ndarray_uint64() + # test_ndarray_float() + # test_ndarray_bool() - test_ndarray_sin() - test_ndarray_cos() - test_ndarray_exp() - test_ndarray_exp2() - test_ndarray_log() - test_ndarray_log10() - test_ndarray_log2() - test_ndarray_fabs() - test_ndarray_sqrt() - test_ndarray_rint() - test_ndarray_tan() - test_ndarray_arcsin() - test_ndarray_arccos() - test_ndarray_arctan() - test_ndarray_sinh() - test_ndarray_cosh() - test_ndarray_tanh() - test_ndarray_arcsinh() - test_ndarray_arccosh() - test_ndarray_arctanh() - test_ndarray_expm1() - test_ndarray_cbrt() + # test_ndarray_round() + # test_ndarray_floor() + # test_ndarray_min() + # test_ndarray_minimum() + # test_ndarray_minimum_broadcast() + # test_ndarray_minimum_broadcast_lhs_scalar() + # test_ndarray_minimum_broadcast_rhs_scalar() + # test_ndarray_argmin() + # test_ndarray_max() + # test_ndarray_maximum() + # test_ndarray_maximum_broadcast() + # test_ndarray_maximum_broadcast_lhs_scalar() + # test_ndarray_maximum_broadcast_rhs_scalar() + # test_ndarray_argmax() + # test_ndarray_abs() + # test_ndarray_isnan() + # test_ndarray_isinf() - test_ndarray_erf() - test_ndarray_erfc() - test_ndarray_gamma() - test_ndarray_gammaln() - test_ndarray_j0() - test_ndarray_j1() + # test_ndarray_sin() + # test_ndarray_cos() + # test_ndarray_exp() + # test_ndarray_exp2() + # test_ndarray_log() + # test_ndarray_log10() + # test_ndarray_log2() + # test_ndarray_fabs() + # test_ndarray_sqrt() + # test_ndarray_rint() + # test_ndarray_tan() + # test_ndarray_arcsin() + # test_ndarray_arccos() + # test_ndarray_arctan() + # test_ndarray_sinh() + # test_ndarray_cosh() + # test_ndarray_tanh() + # test_ndarray_arcsinh() + # test_ndarray_arccosh() + # test_ndarray_arctanh() + # test_ndarray_expm1() + # test_ndarray_cbrt() - test_ndarray_arctan2() - test_ndarray_arctan2_broadcast() - test_ndarray_arctan2_broadcast_lhs_scalar() - test_ndarray_arctan2_broadcast_rhs_scalar() - test_ndarray_copysign() - test_ndarray_copysign_broadcast() - test_ndarray_copysign_broadcast_lhs_scalar() - test_ndarray_copysign_broadcast_rhs_scalar() - test_ndarray_fmax() - test_ndarray_fmax_broadcast() - test_ndarray_fmax_broadcast_lhs_scalar() - test_ndarray_fmax_broadcast_rhs_scalar() - test_ndarray_fmin() - test_ndarray_fmin_broadcast() - test_ndarray_fmin_broadcast_lhs_scalar() - test_ndarray_fmin_broadcast_rhs_scalar() - test_ndarray_ldexp() - test_ndarray_ldexp_broadcast() - test_ndarray_ldexp_broadcast_lhs_scalar() - test_ndarray_ldexp_broadcast_rhs_scalar() - test_ndarray_hypot() - test_ndarray_hypot_broadcast() - test_ndarray_hypot_broadcast_lhs_scalar() - test_ndarray_hypot_broadcast_rhs_scalar() - test_ndarray_nextafter() - test_ndarray_nextafter_broadcast() - test_ndarray_nextafter_broadcast_lhs_scalar() - test_ndarray_nextafter_broadcast_rhs_scalar() + # test_ndarray_erf() + # test_ndarray_erfc() + # test_ndarray_gamma() + # test_ndarray_gammaln() + # test_ndarray_j0() + # test_ndarray_j1() - test_try_invert() - test_wilkinson_shift() + # test_ndarray_arctan2() + # test_ndarray_arctan2_broadcast() + # test_ndarray_arctan2_broadcast_lhs_scalar() + # test_ndarray_arctan2_broadcast_rhs_scalar() + # test_ndarray_copysign() + # test_ndarray_copysign_broadcast() + # test_ndarray_copysign_broadcast_lhs_scalar() + # test_ndarray_copysign_broadcast_rhs_scalar() + # test_ndarray_fmax() + # test_ndarray_fmax_broadcast() + # test_ndarray_fmax_broadcast_lhs_scalar() + # test_ndarray_fmax_broadcast_rhs_scalar() + # test_ndarray_fmin() + # test_ndarray_fmin_broadcast() + # test_ndarray_fmin_broadcast_lhs_scalar() + # test_ndarray_fmin_broadcast_rhs_scalar() + # test_ndarray_ldexp() + # test_ndarray_ldexp_broadcast() + # test_ndarray_ldexp_broadcast_lhs_scalar() + # test_ndarray_ldexp_broadcast_rhs_scalar() + # test_ndarray_hypot() + # test_ndarray_hypot_broadcast() + # test_ndarray_hypot_broadcast_lhs_scalar() + # test_ndarray_hypot_broadcast_rhs_scalar() + # test_ndarray_nextafter() + # test_ndarray_nextafter_broadcast() + # test_ndarray_nextafter_broadcast_lhs_scalar() + # test_ndarray_nextafter_broadcast_rhs_scalar() + + # test_try_invert() + # test_wilkinson_shift() return 0 diff --git a/nac3core/src/codegen/externfns/Cargo.toml b/nac3standalone/linalg_externfns/Cargo.toml similarity index 80% rename from nac3core/src/codegen/externfns/Cargo.toml rename to nac3standalone/linalg_externfns/Cargo.toml index 9d4592b4..5b697fbb 100644 --- a/nac3core/src/codegen/externfns/Cargo.toml +++ b/nac3standalone/linalg_externfns/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "externfns" +name = "linalg_externfns" version = "0.1.0" edition = "2021" @@ -8,3 +8,4 @@ crate-type = ["cdylib"] [dependencies] nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]} +cslice = "0.3.0" diff --git a/nac3standalone/linalg_externfns/src/lib.rs b/nac3standalone/linalg_externfns/src/lib.rs new file mode 100644 index 00000000..651a901c --- /dev/null +++ b/nac3standalone/linalg_externfns/src/lib.rs @@ -0,0 +1,407 @@ +mod runtime_exception; +use core::slice; +use nalgebra::{linalg, DMatrix}; + +pub struct InputMatrix { + pub ndims: usize, + pub dims: *const usize, + pub data: *mut f64, +} +impl InputMatrix { + fn get_dims(&mut self) -> Vec { + let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) }; + dims.to_vec() + } +} + +macro_rules! raise_exn { + ($name:expr, $fn_name:expr, $message:expr, $param0:expr, $param1:expr, $param2:expr) => {{ + use cslice::AsCSlice; + let name_id = $crate::runtime_exception::get_exception_id($name); + let exn = $crate::runtime_exception::Exception { + id: name_id, + file: file!().as_c_slice(), + line: line!(), + column: column!(), + // https://github.com/rust-lang/rfcs/pull/1719 + function: $fn_name.as_c_slice(), + message: $message.as_c_slice(), + param: [$param0, $param1, $param2], + }; + #[allow(unused_unsafe)] + unsafe { + $crate::runtime_exception::raise(&exn) + } + }}; + ($name:expr, $fn_name:expr, $message:expr) => {{ + raise_exn!($name, $fn_name, $message, 0, 0, 0) + }}; +} + +/// # Safety +/// +/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 { + let mat1 = mat1.as_mut().unwrap(); + let mat2 = mat2.as_mut().unwrap(); + + if !(mat1.ndims == 1 && mat2.ndims == 1) { + raise_exn!( + "ValueError", + "np_dot", + "expected 1D Vector Input, but received {0}-D and {1}-D input", + mat1.ndims.try_into().unwrap(), + mat2.ndims.try_into().unwrap(), + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + let dim2 = (*mat2).get_dims(); + + if dim1[0] != dim2[0] { + raise_exn!( + "ValueError", + "np_dot", + "shapes ({},) and ({},) not aligned", + dim1[0].try_into().unwrap(), + dim2[0].try_into().unwrap(), + 0 + ); + } + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) }; + let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) }; + + let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1); + let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2); + + matrix1.dot(&matrix2) +} + +/// # Safety +/// +/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_matmul( + mat1: *mut InputMatrix, + mat2: *mut InputMatrix, + out: *mut InputMatrix, +) { + let mat1 = mat1.as_mut().unwrap(); + let mat2 = mat2.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if !(mat1.ndims == 2 && mat2.ndims == 2) { + raise_exn!( + "ValueError", + "np_matmul", + "expected 2D Vector Input, but received {0}-D and {1}-D input", + mat1.ndims.try_into().unwrap(), + mat2.ndims.try_into().unwrap(), + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + let dim2 = (*mat2).get_dims(); + + if dim1[1] != dim2[0] { + let err_msg = format!( + "shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)", + dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0] + ); + raise_exn!("ValueError", "np_matmul", err_msg); + } + + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) }; + + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2); + let mut result = DMatrix::::zeros(outdim[0], outdim[1]); + + matrix1.mul_to(&matrix2, &mut result); + out_slice.copy_from_slice(result.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) { + let mat1 = mat1.as_mut().unwrap(); + let out = out.as_mut().unwrap(); + + if mat1.ndims != 2 { + raise_exn!( + "ValueError", + "np_linalg_cholesky", + "expected 2D Vector Input, but received {0}-D input", + mat1.ndims.try_into().unwrap(), + 0, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + if dim1[0] != dim1[1] { + raise_exn!( + "LinAlgError", + "np_linalg_cholesky", + "Last 2 dimensions of the array must be square: {0} != {1}", + dim1[0].try_into().unwrap(), + dim1[1].try_into().unwrap(), + 0 + ); + } + + let outdim = out.get_dims(); + let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let result = matrix1.cholesky(); + match result { + Some(res) => { + out_slice.copy_from_slice(res.unpack().transpose().as_slice()); + } + None => { + raise_exn!("LinAlgError", "np_linalg_cholesky", "Matrix is not positive definite"); + } + }; +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_qr( + mat1: *mut InputMatrix, + outq: *mut InputMatrix, + outr: *mut InputMatrix, +) { + let mat1 = mat1.as_mut().unwrap(); + let outq = outq.as_mut().unwrap(); + let outr = outr.as_mut().unwrap(); + + if mat1.ndims != 2 { + raise_exn!( + "ValueError", + "np_linalg_cholesky", + "expected 2D Vector Input, but received {0}-D input", + mat1.ndims.try_into().unwrap(), + 0, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + let outq_dim = (*outq).get_dims(); + let outr_dim = (*outr).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_q_slice = unsafe { slice::from_raw_parts_mut(outq.data, outq_dim[0] * outq_dim[1]) }; + let out_r_slice = unsafe { slice::from_raw_parts_mut(outr.data, outr_dim[0] * outr_dim[1]) }; + + // Refer to https://github.com/dimforge/nalgebra/issues/735 + let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + + let res = matrix1.qr(); + let (q, r) = res.unpack(); + + // Uses different algo need to match numpy + out_q_slice.copy_from_slice(q.transpose().as_slice()); + out_r_slice.copy_from_slice(r.transpose().as_slice()); +} + +/// # Safety +/// +/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_svd( + mat1: *mut InputMatrix, + outu: *mut InputMatrix, + outs: *mut InputMatrix, + outvh: *mut InputMatrix, +) { + let mat1 = mat1.as_mut().unwrap(); + let outu = outu.as_mut().unwrap(); + let outs = outs.as_mut().unwrap(); + let outvh = outvh.as_mut().unwrap(); + + if mat1.ndims != 2 { + raise_exn!( + "ValueError", + "np_linalg_svd", + "expected 2D Vector Input, but received {0}-D input", + mat1.ndims.try_into().unwrap(), + 0, + 0 + ); + } + + let dim1 = (*mat1).get_dims(); + let outu_dim = (*outu).get_dims(); + let outs_dim = (*outs).get_dims(); + let outvh_dim = (*outvh).get_dims(); + + let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; + let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) }; + let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) }; + let out_vh_slice = + unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) }; + + let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); + let result = matrix.svd(true, true); + out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice()); + out_s_slice.copy_from_slice(result.singular_values.as_slice()); + out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice()); +} + +/// # Safety +/// +/// `data` must point to an array of 4 elements in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_inv( + dim1_0: usize, + dim1_1: usize, + x1: *mut f64, + dim2_0: usize, + dim2_1: usize, + out: *mut f64, +) -> i8 { + let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) }; + let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) }; + + let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice); + if !matrix.is_invertible() { + // raise error + return 0; + } + let inv = matrix.try_inverse().unwrap(); + + out_slice.copy_from_slice(inv.transpose().as_slice()); + 1 +} + +/// # Safety +/// +/// `data` must point to an array of 4 elements in row-major order +#[no_mangle] +pub unsafe extern "C" fn np_linalg_pinv( + dim1_0: usize, + dim1_1: usize, + x1: *mut f64, + dim2_0: usize, + dim2_1: usize, + out: *mut f64, +) -> i8 { + let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) }; + let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) }; + + let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice); + let svd = matrix.svd(true, true); + let inv = svd.pseudo_inverse(1e-15); + + match inv { + Ok(m) => { + out_slice.copy_from_slice(m.transpose().as_slice()); + 1 + } + Err(_e) => { + // raise exception here + 0 + } + } +} + +/// # Safety +/// +/// `data` must point to an array of 4 elements in row-major order +#[no_mangle] +pub unsafe extern "C" fn sp_linalg_lu( + dim1_0: usize, + dim1_1: usize, + x1: *mut f64, + dim2_0: usize, + dim2_1: usize, + out_l: *mut f64, + dim3_0: usize, + dim3_1: usize, + out_u: *mut f64, +) -> i8 { + let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) }; + let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l, dim2_0 * dim2_1) }; + let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim3_0 * dim3_1) }; + + let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice); + let (_, l, u) = matrix.lu().unpack(); + + out_l_slice.copy_from_slice(l.transpose().as_slice()); + out_u_slice.copy_from_slice(u.transpose().as_slice()); + + 1 +} + +/// # Safety +/// +/// `data` must point to an array of 4 elements in row-major order +#[no_mangle] +pub unsafe extern "C" fn sp_linalg_schur( + dim1_0: usize, + dim1_1: usize, + x1: *mut f64, + dim2_0: usize, + dim2_1: usize, + out_t: *mut f64, + dim3_0: usize, + dim3_1: usize, + out_z: *mut f64, +) -> i8 { + let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) }; + let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t, dim2_0 * dim2_1) }; + let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z, dim3_0 * dim3_1) }; + + let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice); + if !matrix.is_square() { + // Throw error here + return 0; + } + let (z, t) = matrix.schur().unpack(); + + out_t_slice.copy_from_slice(t.transpose().as_slice()); + out_z_slice.copy_from_slice(z.transpose().as_slice()); + + 1 +} + +/// # Safety +/// +/// `data` must point to an array of 4 elements in row-major order +#[no_mangle] +pub unsafe extern "C" fn sp_linalg_hessenberg( + dim1_0: usize, + dim1_1: usize, + x1: *mut f64, + dim2_0: usize, + dim2_1: usize, + out_h: *mut f64, +) -> i8 { + let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) }; + let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h, dim2_0 * dim2_1) }; + + let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice); + if !matrix.is_square() { + // Throw error here + + return 0; + } + let (_, h) = matrix.hessenberg().unpack(); + + out_h_slice.copy_from_slice(h.transpose().as_slice()); + + 1 +} diff --git a/nac3standalone/linalg_externfns/src/runtime_exception.rs b/nac3standalone/linalg_externfns/src/runtime_exception.rs new file mode 100644 index 00000000..7179f521 --- /dev/null +++ b/nac3standalone/linalg_externfns/src/runtime_exception.rs @@ -0,0 +1,66 @@ +#![allow(non_camel_case_types)] +#![allow(unused)] + +// ARTIQ Exception struct declaration +use cslice::CSlice; + +// Note: CSlice within an exception may not be actual cslice, they may be strings that exist only +// in the host. If the length == usize:MAX, the pointer is actually a string key in the host. +#[repr(C)] +#[derive(Clone)] +pub struct Exception<'a> { + pub id: u32, + pub file: CSlice<'a, u8>, + pub line: u32, + pub column: u32, + pub function: CSlice<'a, u8>, + pub message: CSlice<'a, u8>, + pub param: [i64; 3], +} + +fn str_err(_: core::str::Utf8Error) -> core::fmt::Error { + core::fmt::Error +} + +fn exception_str<'a>(s: &'a CSlice<'a, u8>) -> Result<&'a str, core::str::Utf8Error> { + if s.len() == usize::MAX { + Ok("") + } else { + core::str::from_utf8(s.as_ref()) + } +} + +pub unsafe fn raise(exception: *const Exception) -> ! { + let e = &*exception; + let f1 = exception_str(&e.function).map_err(str_err).unwrap(); + let f2 = exception_str(&e.file).map_err(str_err).unwrap(); + let f3 = exception_str(&e.message).map_err(str_err).unwrap(); + + panic!("Exception {} from {} in {}:{}:{}, message: {}", e.id, f1, f2, e.line, e.column, f3); +} + +static EXCEPTION_ID_LOOKUP: [(&str, u32); 14] = [ + ("RuntimeError", 0), + ("RTIOUnderflow", 1), + ("RTIOOverflow", 2), + ("RTIODestinationUnreachable", 3), + ("DMAError", 4), + ("I2CError", 5), + ("CacheError", 6), + ("SPIError", 7), + ("ZeroDivisionError", 8), + ("IndexError", 9), + ("UnwrapNoneError", 10), + ("Value", 11), + ("ValueError", 12), + ("LinAlgError", 13), +]; + +pub fn get_exception_id(name: &str) -> u32 { + for (n, id) in EXCEPTION_ID_LOOKUP.iter() { + if *n == name { + return *id; + } + } + unimplemented!("unallocated internal exception id") +} diff --git a/pyo3_output/nac3artiq.so b/pyo3_output/nac3artiq.so new file mode 100755 index 00000000..beb4f236 Binary files /dev/null and b/pyo3_output/nac3artiq.so differ