Error Refactored

This commit is contained in:
= 2024-07-25 00:47:50 +08:00
parent 7ec36e80f7
commit bc4e3a228b
19 changed files with 1675 additions and 531 deletions

23
Cargo.lock generated
View File

@ -256,6 +256,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]] [[package]]
name = "dirs-next" name = "dirs-next"
version = "2.0.0" version = "2.0.0"
@ -314,13 +320,6 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "externfns"
version = "0.1.0"
dependencies = [
"nalgebra",
]
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.1.0" version = "2.1.0"
@ -553,6 +552,14 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "linalg_externfns"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
@ -638,7 +645,6 @@ name = "nac3core"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"crossbeam", "crossbeam",
"externfns",
"indexmap 2.2.6", "indexmap 2.2.6",
"indoc", "indoc",
"inkwell", "inkwell",
@ -682,6 +688,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"inkwell", "inkwell",
"linalg_externfns",
"nac3core", "nac3core",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",

View File

@ -4,7 +4,7 @@ members = [
"nac3ast", "nac3ast",
"nac3parser", "nac3parser",
"nac3core", "nac3core",
"nac3core/src/codegen/externfns", "nac3standalone/linalg_externfns",
"nac3standalone", "nac3standalone",
"nac3artiq", "nac3artiq",
"runkernel", "runkernel",

View File

@ -161,7 +161,9 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -11,7 +11,6 @@ indexmap = "2.2"
parking_lot = "0.12" parking_lot = "0.12"
rayon = "1.8" rayon = "1.8"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
externfns = { path = "src/codegen/externfns" }
strum = "0.26.2" strum = "0.26.2"
strum_macros = "0.26.4" strum_macros = "0.26.4"

View File

@ -1,5 +1,5 @@
use inkwell::types::BasicTypeEnum; use inkwell::types::BasicTypeEnum;
use inkwell::values::BasicValueEnum; use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
@ -31,7 +31,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n; let (n_ty, n) = n;
Ok(match n { Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
@ -1836,231 +1835,547 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
}) })
} }
/// Invokes the `linalg_try_invert_to` function fn build_input_matrix<'ctx>(
pub fn call_linalg_try_invert_to<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>,
out_matrices: Vec<BasicValueEnum<'ctx>>,
) -> PointerValue<'ctx> {
let field_ty = out_matrices.iter().map(|x| x.get_type()).collect::<Vec<BasicTypeEnum>>();
let out_ty = ctx.ctx.struct_type(&field_ty, false);
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
for (i, v) in out_matrices.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
out_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
out_ptr
}
// Linalg Methods
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>), x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_try_invert_to"; const FN_NAME: &str = "np_dot";
let (a_ty, a) = a; let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) {
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matmul";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n2.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
match a { if let BasicValueEnum::PointerValue(n1) = x1 {
BasicValueEnum::PointerValue(n) let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
match llvm_ndarray_ty {
BasicTypeEnum::FloatType(_) => {}
_ => {
unimplemented!("Inverse Operation supported on float type NDArray Values only")
}
};
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
// The following constraints must be satisfied: let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// * Input must be 2D let dim0 = unsafe {
// * number of rows should equal number of columns (square matrix) n1.dim_sizes()
if cfg!(debug_assertions) { .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
let n_dims = n.load_ndims(ctx); .into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
// num_dim == 2 let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
ctx.make_assert( .unwrap()
generator, .as_base_value()
ctx.builder .as_basic_value_enum();
.build_int_compare(
IntPredicate::EQ,
n_dims,
llvm_usize.const_int(2, false),
"",
)
.unwrap(),
"0:ValueError",
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
let dim0 = unsafe { extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
n.dim_sizes() Ok(out)
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) } else {
.into_int_value() unsupported_type(ctx, FN_NAME, &[x1_ty])
}; }
let dim1 = unsafe { }
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
// dim0 == dim1 pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
ctx.make_assert( generator: &mut G,
generator, ctx: &mut CodeGenContext<'ctx, '_>,
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(), x1: (Type, BasicValueEnum<'ctx>),
"0:ValueError", ) -> Result<BasicValueEnum<'ctx>, String> {
format!( const FN_NAME: &str = "np_linalg_qr";
"Input matrix should have equal number of rows and columns for {FN_NAME}" let (x1_ty, x1) = x1;
) let llvm_usize = generator.get_size_type(ctx.ctx);
.as_str(),
[None, None, None],
ctx.current_loc,
);
}
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if let BasicValueEnum::PointerValue(n1) = x1 {
let n_sz_eqz = ctx let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
let out_ptr = build_input_matrix(ctx, vec![out_q, out_r]);
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
let out_ptr = build_input_matrix(ctx, vec![out_u, out_s, out_vh]);
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
extern_fns::call_np_linalg_inv(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim1, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]).unwrap();
extern_fns::call_np_linalg_pinv(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim1, dim0, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
extern_fns::call_sp_linalg_lu(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, k, out_l.data().base_ptr(ctx, generator)),
(k, dim1, out_u.data().base_ptr(ctx, generator)),
None,
);
let out_l = out_l.as_base_value().as_basic_value_enum();
let out_u = out_u.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_l.get_type(), out_u.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "LU_factorization").unwrap();
let res_val = [out_l, out_u];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder .builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") .build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap(); .unwrap();
ctx.builder.build_store(ptr, v).unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
format!("zero-size array to inverse operation {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
} }
let dim0 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
Ok(extern_fns::call_linalg_try_invert_to(
ctx,
dim0,
dim1,
n.data().base_ptr(ctx, generator),
None,
)
.into())
} }
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
Ok(ctx.builder.build_load(res_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
} }
/// Invokes the `linalg_wilkinson_shift` function // Must be square (add check later)
pub fn call_linalg_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>), x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_wilkinson_shift"; const FN_NAME: &str = "sp_linalg_schur";
let (a_ty, a) = a; let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false); if let BasicValueEnum::PointerValue(n1) = x1 {
let two = llvm_usize.const_int(2, false); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
match a { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
BasicValueEnum::PointerValue(n) unimplemented!("{FN_NAME} operates on float type NdArrays only");
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => };
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
match llvm_ndarray_ty {
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {}
_ => unimplemented!(
"Wilkinson Shift Operation supported on float type NDArray Values only"
),
};
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// The following constraints must be satisfied: let dim0 = unsafe {
// * Input must be 2D n1.dim_sizes()
// * Number of rows and columns should equal 2 .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
// * Input matrix must be symmetric .into_int_value()
if cfg!(debug_assertions) { };
let n_dims = n.load_ndims(ctx); let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out_t =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
let out_z =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
// num_dim == 2 extern_fns::call_sp_linalg_schur(
ctx.make_assert( ctx,
generator, (dim0, dim1, n1.data().base_ptr(ctx, generator)),
ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, two, "").unwrap(), (dim0, dim0, out_t.data().base_ptr(ctx, generator)),
"0:ValueError", (dim0, dim0, out_z.data().base_ptr(ctx, generator)),
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(), None,
[None, None, None], );
ctx.current_loc,
);
let dim0 = unsafe { let out_t = out_t.as_base_value().as_basic_value_enum();
n.dim_sizes() let out_z = out_z.as_base_value().as_basic_value_enum();
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
};
// dim0 == 2 let res_ty = ctx.ctx.struct_type(&[out_t.get_type(), out_z.get_type()], false);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, two, "").unwrap(),
"0:ValueError",
format!("Number of rows must be 2 for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
// dim1 == 2 let res_ptr = ctx.builder.build_alloca(res_ty, "Schur_factorization").unwrap();
ctx.make_assert( let res_val = [out_t, out_z];
generator, for (i, v) in res_val.into_iter().enumerate() {
ctx.builder.build_int_compare(IntPredicate::EQ, dim1, two, "").unwrap(), unsafe {
"0:ValueError", let ptr = ctx
format!("Number of columns must be 2 for {FN_NAME}").as_str(), .builder
[None, None, None], .build_in_bounds_gep(
ctx.current_loc, res_ptr,
); &[
ctx.ctx.i32_type().const_zero(),
let entry_01 = unsafe { ctx.ctx.i32_type().const_int(i as u64, false),
n.data().get_unchecked(ctx, generator, &one, None).into_float_value() ],
}; "ptr",
let entry_10 = unsafe { )
n.data().get_unchecked(ctx, generator, &two, None).into_float_value() .unwrap();
}; ctx.builder.build_store(ptr, v).unwrap();
// symmetric matrix
ctx.make_assert(
generator,
ctx.builder
.build_float_compare(FloatPredicate::OEQ, entry_01, entry_10, "")
.unwrap(),
"0:ValueError",
format!("Input Matrix must be symmetric for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
} }
let dim0 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
Ok(extern_fns::call_linalg_wilkinson_shift(
ctx,
dim0,
dim1,
n.data().base_ptr(ctx, generator),
None,
)
.into())
} }
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
Ok(ctx.builder.build_load(res_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
// Must be square (add check later)
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
// Check if matrix is square
// ctx.builder.build_select(
// ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
// {
// let func =
// }, else_, name)
// ;
// ctx.builder.build_call(
// ctx.module.get_function("__nac3_raise"),
// &[]
// )
// let err_msg = ctx.gen_string(generator, "{FN_NAME} requires square matrix");
// ctx.raise_exn(generator, "0:ValueError", err_msg, [None, None, None], ctx.current_loc);
let out_h =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
extern_fns::call_sp_linalg_hessenberg(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_h.data().base_ptr(ctx, generator)),
None,
);
Ok(out_h.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
} }

View File

@ -131,90 +131,218 @@ pub fn call_ldexp<'ctx>(
.unwrap() .unwrap()
} }
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function /// Macro to generate np_linalg external functions
pub fn call_linalg_try_invert_to<'ctx>( macro_rules! generate_np_linalg_extern_fn {
ctx: &CodeGenContext<'ctx, '_>, ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => {
dim0: IntValue<'ctx>, generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1);
dim1: IntValue<'ctx>, };
data: PointerValue<'ctx>, ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => {
name: Option<&str>, generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2);
) -> IntValue<'ctx> { };
const FN_NAME: &str = "linalg_try_invert_to"; ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: (IntValue<'ctx>, IntValue<'ctx>, PointerValue<'ctx>))*,
name: Option<&str>,
) -> $ret_ty<'ctx> {
const FN_NAME: &str = $extern_fn;
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type()); $(
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type()); debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.0.get_type()));
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.1.get_type()));
debug_assert_eq!($input_matrix.2.get_type().get_element_type().into_float_type(), llvm_f64);
)*
debug_assert!(allowed_dim0); // let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false);
debug_assert!(allowed_dim1); // let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false);
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64); // let file_name = ctx.current_loc.file.0;
// let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false);
// let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.i8_type().fn_type( // let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], // let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
false, let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
);
let func = ctx.module.add_function(FN_NAME, fn_type, None); let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
}
func
});
ctx.builder
// .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
// .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
.build_call(extern_fn, &[$($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left($map_fn))
.map(Either::unwrap_left)
.unwrap()
} }
};
func
});
ctx.builder
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
} }
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function /// Macro to generate `np_linalg` external functions
pub fn call_linalg_wilkinson_shift<'ctx>( macro_rules! generate_np_linalg_extern_fn2 {
ctx: &CodeGenContext<'ctx, '_>, ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => {
dim0: IntValue<'ctx>, generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1);
dim1: IntValue<'ctx>, };
data: PointerValue<'ctx>, ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => {
name: Option<&str>, generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2);
) -> FloatValue<'ctx> { };
const FN_NAME: &str = "linalg_wilkinson_shift"; ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => {
generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => {
generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
) -> $ret_ty<'ctx> {
const FN_NAME: &str = $extern_fn;
// let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false);
// let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false);
// let file_name = ctx.current_loc.file.0;
// let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false);
// let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true);
let llvm_f64 = ctx.ctx.f64_type(); let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; // let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.get_type().into()),*], false);
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type()); let func = ctx.module.add_function(FN_NAME, fn_type, None);
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type()); for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
debug_assert!(allowed_dim0); ctx.builder
debug_assert!(allowed_dim1); // .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64); // .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default())
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { .map(CallSiteValue::try_as_basic_value)
let fn_type = ctx.ctx.f64_type().fn_type( .map(|v| v.map_left($map_fn))
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], .map(Either::unwrap_left)
false, .unwrap()
);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
} }
};
func
});
ctx.builder
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
} }
/// Macro to generate `np_linalg` external functions
macro_rules! generate_np_linalg_extern_fn3 {
($fn_name:ident, $extern_fn:literal, 1) => {
generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1);
};
($fn_name:ident, $extern_fn:literal, 2) => {
generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1, mat2);
};
($fn_name:ident, $extern_fn:literal, 3) => {
generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $extern_fn:literal, 4) => {
generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
){
const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
}
};
}
generate_np_linalg_extern_fn2!(
call_np_dot,
FloatValue,
f64_type,
BasicValueEnum::into_float_value,
"np_dot",
2
);
generate_np_linalg_extern_fn3!(call_np_linalg_matmul, "np_linalg_matmul", 3);
generate_np_linalg_extern_fn3!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_np_linalg_extern_fn3!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_np_linalg_extern_fn3!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_np_linalg_extern_fn!(
call_np_linalg_inv,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_inv",
2
);
generate_np_linalg_extern_fn!(
call_np_linalg_pinv,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_pinv",
2
);
generate_np_linalg_extern_fn!(
call_sp_linalg_lu,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_lu",
3
);
generate_np_linalg_extern_fn!(
call_sp_linalg_schur,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_schur",
3
);
generate_np_linalg_extern_fn!(
call_sp_linalg_hessenberg,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_hessenberg",
2
);

View File

@ -1,30 +0,0 @@
use core::slice;
use nalgebra::{linalg, DMatrix};
/// # Safety
///
/// `data` must point to an array with `dim0`x`dim1` elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
1
} else {
0
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)])
}

View File

@ -61,7 +61,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
/// * `shape` - The shape of the `NDArray`. /// * `shape` - The shape of the `NDArray`.
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. /// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. /// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( pub fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
@ -157,7 +157,7 @@ where
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. /// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,

View File

@ -557,7 +557,18 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunTryInvertTo | PrimDef::FunWilkinsonShift => self.build_linalg_methods(prim), PrimDef::FunNpDot
| PrimDef::FunNpLinalgMatmul
| PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd
| PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv
| PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => self.build_np_linalg_methods(prim),
// PrimDef::FunNpDot | PrimDef::FunNpLinalgMatmul => self.build_np_linalg_binary_methods(prim),
// PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr => self.build_np_linalg_unary_methods(prim),
}; };
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -1876,35 +1887,140 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build the functions `try_invert_to` and `wilkinson_shift` fn build_np_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed(
debug_assert_prim_is_allowed(prim, &[PrimDef::FunTryInvertTo, PrimDef::FunWilkinsonShift]); prim,
&[
PrimDef::FunNpDot,
PrimDef::FunNpLinalgMatmul,
PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd,
PrimDef::FunNpLinalgInv,
PrimDef::FunNpLinalgPinv,
PrimDef::FunSpLinalgLu,
PrimDef::FunSpLinalgSchur,
PrimDef::FunSpLinalgHessenberg,
],
);
let ret_ty = match prim { match prim {
PrimDef::FunTryInvertTo => self.primitives.bool, PrimDef::FunNpDot => create_fn_by_codegen(
PrimDef::FunWilkinsonShift => self.primitives.float, self.unifier,
_ => unreachable!(), &self.num_or_ndarray_var_map,
}; prim.name(),
let var_map = self.num_or_ndarray_var_map.clone(); self.primitives.float,
create_fn_by_codegen( &[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
self.unifier, Box::new(move |ctx, _, fun, args, generator| {
&var_map, let x1_ty = fun.0.args[0].ty;
prim.name(), let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
ret_ty, let x2_ty = fun.0.args[1].ty;
&[(self.ndarray_float_2d, "x")], let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Box::new(move |ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
let func = match prim { Ok(Some(builtin_fns::call_np_dot(
PrimDef::FunTryInvertTo => builtin_fns::call_linalg_try_invert_to, generator,
PrimDef::FunWilkinsonShift => builtin_fns::call_linalg_wilkinson_shift, ctx,
_ => unreachable!(), (x1_ty, x1_val),
}; (x2_ty, x2_val),
)?))
}),
),
Ok(Some(func(generator, ctx, (x_ty, x_val))?)) PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
}), self.unifier,
) &VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matmul(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv
| PrimDef::FunSpLinalgHessenberg => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
PrimDef::FunSpLinalgHessenberg => builtin_fns::call_sp_linalg_hessenberg,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
),
PrimDef::FunNpLinalgQr | PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgSchur => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
PrimDef::FunNpLinalgSvd => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
_ => {
println!("{:?}", prim.name());
unreachable!()
}
}
} }
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {

View File

@ -105,8 +105,16 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunTryInvertTo, FunNpDot,
FunWilkinsonShift, FunNpLinalgMatmul,
FunNpLinalgCholesky,
FunNpLinalgQr,
FunNpLinalgSvd,
FunNpLinalgInv,
FunNpLinalgPinv,
FunSpLinalgLu,
FunSpLinalgSchur,
FunSpLinalgHessenberg,
// Top-Level Functions // Top-Level Functions
FunSome, FunSome,
@ -265,8 +273,17 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunTryInvertTo => fun("try_invert_to", None), PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None), PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
PrimDef::FunSome => fun("Some", None), PrimDef::FunSome => fun("Some", None),
} }
} }

View File

@ -8,6 +8,7 @@ edition = "2021"
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
linalg_externfns = { path = "./linalg_externfns" }
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"

View File

@ -5,6 +5,7 @@ import importlib.util
import importlib.machinery import importlib.machinery
import math import math
import numpy as np import numpy as np
import scipy as sp
import numpy.typing as npt import numpy.typing as npt
import pathlib import pathlib
@ -246,8 +247,21 @@ def patch(module):
module.sp_spec_j0 = special.j0 module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1 module.sp_spec_j1 = special.j1
module.try_invert_to = try_invert_to # Linalg functions
module.wilkinson_shift = wilkinson_shift module.np_dot = np.dot
module.np_linalg_matmul = np.matmul
module.np_linalg_cholesky = np.linalg.cholesky
module.np_linalg_qr = np.linalg.qr
module.np_linalg_svd = np.linalg.svd
module.np_linalg_inv = np.linalg.inv
module.np_linalg_pinv = np.linalg.pinv
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
module.sp_linalg_schur = sp.linalg.schur
# module.sp_linalg_hessenberg = sp.linalg.hessenberg
module.sp_linalg_hessenberg = lambda x: x
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

@ -42,14 +42,14 @@ done
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
nac3standalone=../../target/debug/nac3standalone nac3standalone=../../target/debug/nac3standalone
externfns=../../target/debug/deps/libexternfns.so externfns=../../target/debug/deps/liblinalg_externfns.so
elif [ -e ../../target/release/nac3standalone ]; then elif [ -e ../../target/release/nac3standalone ]; then
nac3standalone=../../target/release/nac3standalone nac3standalone=../../target/release/nac3standalone
externfns=../../target/release/deps/libexternfns.so externfns=../../target/release/deps/liblinalg_externfns.so
else else
# used by Nix builds # used by Nix builds
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so externfns=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg_externfns.so
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo

View File

@ -0,0 +1,12 @@
Checking src/ndarray.py... Function Called np_dot
Module { data_layout: RefCell { value: Some(DataLayout { address: 0x555559cda5b0, repr: "" }) }, module: Cell { value: 0x555559cda3a0 }, owned_by_ee: RefCell { value: None }, _marker: PhantomData<&inkwell::context::Context> }
Function Called np_dot
Module { data_layout: RefCell { value: Some(DataLayout { address: 0x555559cda740, repr: "" }) }, module: Cell { value: 0x555559cda530 }, owned_by_ee: RefCell { value: None }, _marker: PhantomData<&inkwell::context::Context> }
--- interpreted.log 2024-07-24 19:39:47.480093947 +0800
+++ run.log 2024-07-24 22:39:50.183382396 +0800
@@ -0,0 +1,5 @@
+5.000000
+1.000000
+5.000000
+1.000000
+26.000000

View File

@ -1429,200 +1429,289 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_try_invert(): def test_ndarray_dot():
x: ndarray[float, 2] = np_array([[1.0, 2.0], [3.0, 4.0]]) x: ndarray[float, 1] = np_array([5.0, 1.0])
output_ndarray_float_2(x) y: ndarray[float, 1] = np_array([5.0, 1.0])
y = try_invert_to(x) z = np_dot(x, y)
output_ndarray_float_2(x) output_ndarray_float_1(x)
output_bool(y) output_ndarray_float_1(y)
output_float64(z)
def test_wilkinson_shift(): def test_ndarray_linalg_matmul():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = wilkinson_shift(x) y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
z = np_linalg_matmul(x, y)
m = np_argmax(z)
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_float64(y) output_ndarray_float_2(y)
output_ndarray_float_2(z)
output_int64(m)
def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = np_linalg_cholesky(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_qr():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y, z = np_linalg_qr(x)
output_ndarray_float_2(x)
# QR Factorization in nalgebra and numpy do not give the same result
# Generating product for printing
a = np_linalg_matmul(y, z)
output_ndarray_float_2(a)
def test_ndarray_linalg_inv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_inv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pinv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
y = np_linalg_pinv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_schur():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
t, z = sp_linalg_schur(x)
output_ndarray_float_2(x)
# Same as np_linalg_qr the signs are different in nalgebra and numpy
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z))
output_ndarray_float_2(a)
def test_ndarray_hessenberg():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
h = sp_linalg_hessenberg(x)
output_ndarray_float_2(x)
output_ndarray_float_2(h)
def test_ndarray_lu():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
l, u = sp_linalg_lu(x)
output_ndarray_float_2(x)
output_ndarray_float_2(l)
output_ndarray_float_2(u)
def test_ndarray_svd():
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
x, y, z = np_linalg_svd(w)
output_ndarray_float_2(w)
# Same as np_linalg_qr the signs are different in nalgebra and numpy
a = np_linalg_matmul(x, z)
output_ndarray_float_2(a)
output_ndarray_float_1(y)
def run() -> int32: def run() -> int32:
test_ndarray_ctor() # test_ndarray_matmul()
test_ndarray_empty() test_ndarray_dot()
test_ndarray_zeros() test_ndarray_linalg_matmul()
test_ndarray_ones() test_ndarray_cholesky()
test_ndarray_full() test_ndarray_qr()
test_ndarray_eye() test_ndarray_svd()
test_ndarray_array() # test_ndarray_linalg_inv()
test_ndarray_identity() # test_ndarray_pinv()
test_ndarray_fill() # test_ndarray_lu()
test_ndarray_copy() # test_ndarray_schur()
# test_ndarray_hessenberg()
test_ndarray_neg_idx() # test_ndarray_ctor()
test_ndarray_slices() # test_ndarray_empty()
test_ndarray_nd_idx() # test_ndarray_zeros()
# test_ndarray_ones()
# test_ndarray_full()
# test_ndarray_eye()
# test_ndarray_array()
# test_ndarray_identity()
# test_ndarray_fill()
# test_ndarray_copy()
test_ndarray_add() # test_ndarray_neg_idx()
test_ndarray_add_broadcast() # test_ndarray_slices()
test_ndarray_add_broadcast_lhs_scalar() # test_ndarray_nd_idx()
test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_iadd()
test_ndarray_iadd_broadcast()
test_ndarray_iadd_broadcast_scalar()
test_ndarray_sub()
test_ndarray_sub_broadcast()
test_ndarray_sub_broadcast_lhs_scalar()
test_ndarray_sub_broadcast_rhs_scalar()
test_ndarray_isub()
test_ndarray_isub_broadcast()
test_ndarray_isub_broadcast_scalar()
test_ndarray_mul()
test_ndarray_mul_broadcast()
test_ndarray_mul_broadcast_lhs_scalar()
test_ndarray_mul_broadcast_rhs_scalar()
test_ndarray_imul()
test_ndarray_imul_broadcast()
test_ndarray_imul_broadcast_scalar()
test_ndarray_truediv()
test_ndarray_truediv_broadcast()
test_ndarray_truediv_broadcast_lhs_scalar()
test_ndarray_truediv_broadcast_rhs_scalar()
test_ndarray_itruediv()
test_ndarray_itruediv_broadcast()
test_ndarray_itruediv_broadcast_scalar()
test_ndarray_floordiv()
test_ndarray_floordiv_broadcast()
test_ndarray_floordiv_broadcast_lhs_scalar()
test_ndarray_floordiv_broadcast_rhs_scalar()
test_ndarray_ifloordiv()
test_ndarray_ifloordiv_broadcast()
test_ndarray_ifloordiv_broadcast_scalar()
test_ndarray_mod()
test_ndarray_mod_broadcast()
test_ndarray_mod_broadcast_lhs_scalar()
test_ndarray_mod_broadcast_rhs_scalar()
test_ndarray_imod()
test_ndarray_imod_broadcast()
test_ndarray_imod_broadcast_scalar()
test_ndarray_pow()
test_ndarray_pow_broadcast()
test_ndarray_pow_broadcast_lhs_scalar()
test_ndarray_pow_broadcast_rhs_scalar()
test_ndarray_ipow()
test_ndarray_ipow_broadcast()
test_ndarray_ipow_broadcast_scalar()
test_ndarray_matmul()
test_ndarray_imatmul()
test_ndarray_pos()
test_ndarray_neg()
test_ndarray_inv()
test_ndarray_eq()
test_ndarray_eq_broadcast()
test_ndarray_eq_broadcast_lhs_scalar()
test_ndarray_eq_broadcast_rhs_scalar()
test_ndarray_ne()
test_ndarray_ne_broadcast()
test_ndarray_ne_broadcast_lhs_scalar()
test_ndarray_ne_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_lt_broadcast()
test_ndarray_lt_broadcast_lhs_scalar()
test_ndarray_lt_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_le_broadcast()
test_ndarray_le_broadcast_lhs_scalar()
test_ndarray_le_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_gt_broadcast()
test_ndarray_gt_broadcast_lhs_scalar()
test_ndarray_gt_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_ge_broadcast()
test_ndarray_ge_broadcast_lhs_scalar()
test_ndarray_ge_broadcast_rhs_scalar()
test_ndarray_int32() # test_ndarray_add()
test_ndarray_int64() # test_ndarray_add_broadcast()
test_ndarray_uint32() # test_ndarray_add_broadcast_lhs_scalar()
test_ndarray_uint64() # test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_float() # test_ndarray_iadd()
test_ndarray_bool() # test_ndarray_iadd_broadcast()
# test_ndarray_iadd_broadcast_scalar()
# test_ndarray_sub()
# test_ndarray_sub_broadcast()
# test_ndarray_sub_broadcast_lhs_scalar()
# test_ndarray_sub_broadcast_rhs_scalar()
# test_ndarray_isub()
# test_ndarray_isub_broadcast()
# test_ndarray_isub_broadcast_scalar()
# test_ndarray_mul()
# test_ndarray_mul_broadcast()
# test_ndarray_mul_broadcast_lhs_scalar()
# test_ndarray_mul_broadcast_rhs_scalar()
# test_ndarray_imul()
# test_ndarray_imul_broadcast()
# test_ndarray_imul_broadcast_scalar()
# test_ndarray_truediv()
# test_ndarray_truediv_broadcast()
# test_ndarray_truediv_broadcast_lhs_scalar()
# test_ndarray_truediv_broadcast_rhs_scalar()
# test_ndarray_itruediv()
# test_ndarray_itruediv_broadcast()
# test_ndarray_itruediv_broadcast_scalar()
# test_ndarray_floordiv()
# test_ndarray_floordiv_broadcast()
# test_ndarray_floordiv_broadcast_lhs_scalar()
# test_ndarray_floordiv_broadcast_rhs_scalar()
# test_ndarray_ifloordiv()
# test_ndarray_ifloordiv_broadcast()
# test_ndarray_ifloordiv_broadcast_scalar()
# test_ndarray_mod()
# test_ndarray_mod_broadcast()
# test_ndarray_mod_broadcast_lhs_scalar()
# test_ndarray_mod_broadcast_rhs_scalar()
# test_ndarray_imod()
# test_ndarray_imod_broadcast()
# test_ndarray_imod_broadcast_scalar()
# test_ndarray_pow()
# test_ndarray_pow_broadcast()
# test_ndarray_pow_broadcast_lhs_scalar()
# test_ndarray_pow_broadcast_rhs_scalar()
# test_ndarray_ipow()
# test_ndarray_ipow_broadcast()
# test_ndarray_ipow_broadcast_scalar()
# test_ndarray_matmul()
# test_ndarray_imatmul()
# test_ndarray_pos()
# test_ndarray_neg()
# test_ndarray_inv()
# test_ndarray_eq()
# test_ndarray_eq_broadcast()
# test_ndarray_eq_broadcast_lhs_scalar()
# test_ndarray_eq_broadcast_rhs_scalar()
# test_ndarray_ne()
# test_ndarray_ne_broadcast()
# test_ndarray_ne_broadcast_lhs_scalar()
# test_ndarray_ne_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_lt_broadcast()
# test_ndarray_lt_broadcast_lhs_scalar()
# test_ndarray_lt_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_le_broadcast()
# test_ndarray_le_broadcast_lhs_scalar()
# test_ndarray_le_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_gt_broadcast()
# test_ndarray_gt_broadcast_lhs_scalar()
# test_ndarray_gt_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_ge_broadcast()
# test_ndarray_ge_broadcast_lhs_scalar()
# test_ndarray_ge_broadcast_rhs_scalar()
test_ndarray_round() # test_ndarray_int32()
test_ndarray_floor() # test_ndarray_int64()
test_ndarray_min() # test_ndarray_uint32()
test_ndarray_minimum() # test_ndarray_uint64()
test_ndarray_minimum_broadcast() # test_ndarray_float()
test_ndarray_minimum_broadcast_lhs_scalar() # test_ndarray_bool()
test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_argmin()
test_ndarray_max()
test_ndarray_maximum()
test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_argmax()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()
test_ndarray_sin() # test_ndarray_round()
test_ndarray_cos() # test_ndarray_floor()
test_ndarray_exp() # test_ndarray_min()
test_ndarray_exp2() # test_ndarray_minimum()
test_ndarray_log() # test_ndarray_minimum_broadcast()
test_ndarray_log10() # test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_log2() # test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_fabs() # test_ndarray_argmin()
test_ndarray_sqrt() # test_ndarray_max()
test_ndarray_rint() # test_ndarray_maximum()
test_ndarray_tan() # test_ndarray_maximum_broadcast()
test_ndarray_arcsin() # test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_arccos() # test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_arctan() # test_ndarray_argmax()
test_ndarray_sinh() # test_ndarray_abs()
test_ndarray_cosh() # test_ndarray_isnan()
test_ndarray_tanh() # test_ndarray_isinf()
test_ndarray_arcsinh()
test_ndarray_arccosh()
test_ndarray_arctanh()
test_ndarray_expm1()
test_ndarray_cbrt()
test_ndarray_erf() # test_ndarray_sin()
test_ndarray_erfc() # test_ndarray_cos()
test_ndarray_gamma() # test_ndarray_exp()
test_ndarray_gammaln() # test_ndarray_exp2()
test_ndarray_j0() # test_ndarray_log()
test_ndarray_j1() # test_ndarray_log10()
# test_ndarray_log2()
# test_ndarray_fabs()
# test_ndarray_sqrt()
# test_ndarray_rint()
# test_ndarray_tan()
# test_ndarray_arcsin()
# test_ndarray_arccos()
# test_ndarray_arctan()
# test_ndarray_sinh()
# test_ndarray_cosh()
# test_ndarray_tanh()
# test_ndarray_arcsinh()
# test_ndarray_arccosh()
# test_ndarray_arctanh()
# test_ndarray_expm1()
# test_ndarray_cbrt()
test_ndarray_arctan2() # test_ndarray_erf()
test_ndarray_arctan2_broadcast() # test_ndarray_erfc()
test_ndarray_arctan2_broadcast_lhs_scalar() # test_ndarray_gamma()
test_ndarray_arctan2_broadcast_rhs_scalar() # test_ndarray_gammaln()
test_ndarray_copysign() # test_ndarray_j0()
test_ndarray_copysign_broadcast() # test_ndarray_j1()
test_ndarray_copysign_broadcast_lhs_scalar()
test_ndarray_copysign_broadcast_rhs_scalar()
test_ndarray_fmax()
test_ndarray_fmax_broadcast()
test_ndarray_fmax_broadcast_lhs_scalar()
test_ndarray_fmax_broadcast_rhs_scalar()
test_ndarray_fmin()
test_ndarray_fmin_broadcast()
test_ndarray_fmin_broadcast_lhs_scalar()
test_ndarray_fmin_broadcast_rhs_scalar()
test_ndarray_ldexp()
test_ndarray_ldexp_broadcast()
test_ndarray_ldexp_broadcast_lhs_scalar()
test_ndarray_ldexp_broadcast_rhs_scalar()
test_ndarray_hypot()
test_ndarray_hypot_broadcast()
test_ndarray_hypot_broadcast_lhs_scalar()
test_ndarray_hypot_broadcast_rhs_scalar()
test_ndarray_nextafter()
test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar()
test_try_invert() # test_ndarray_arctan2()
test_wilkinson_shift() # test_ndarray_arctan2_broadcast()
# test_ndarray_arctan2_broadcast_lhs_scalar()
# test_ndarray_arctan2_broadcast_rhs_scalar()
# test_ndarray_copysign()
# test_ndarray_copysign_broadcast()
# test_ndarray_copysign_broadcast_lhs_scalar()
# test_ndarray_copysign_broadcast_rhs_scalar()
# test_ndarray_fmax()
# test_ndarray_fmax_broadcast()
# test_ndarray_fmax_broadcast_lhs_scalar()
# test_ndarray_fmax_broadcast_rhs_scalar()
# test_ndarray_fmin()
# test_ndarray_fmin_broadcast()
# test_ndarray_fmin_broadcast_lhs_scalar()
# test_ndarray_fmin_broadcast_rhs_scalar()
# test_ndarray_ldexp()
# test_ndarray_ldexp_broadcast()
# test_ndarray_ldexp_broadcast_lhs_scalar()
# test_ndarray_ldexp_broadcast_rhs_scalar()
# test_ndarray_hypot()
# test_ndarray_hypot_broadcast()
# test_ndarray_hypot_broadcast_lhs_scalar()
# test_ndarray_hypot_broadcast_rhs_scalar()
# test_ndarray_nextafter()
# test_ndarray_nextafter_broadcast()
# test_ndarray_nextafter_broadcast_lhs_scalar()
# test_ndarray_nextafter_broadcast_rhs_scalar()
# test_try_invert()
# test_wilkinson_shift()
return 0 return 0

View File

@ -1,5 +1,5 @@
[package] [package]
name = "externfns" name = "linalg_externfns"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
@ -8,3 +8,4 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]} nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0"

View File

@ -0,0 +1,407 @@
mod runtime_exception;
use core::slice;
use nalgebra::{linalg, DMatrix};
pub struct InputMatrix {
pub ndims: usize,
pub dims: *const usize,
pub data: *mut f64,
}
impl InputMatrix {
fn get_dims(&mut self) -> Vec<usize> {
let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) };
dims.to_vec()
}
}
macro_rules! raise_exn {
($name:expr, $fn_name:expr, $message:expr, $param0:expr, $param1:expr, $param2:expr) => {{
use cslice::AsCSlice;
let name_id = $crate::runtime_exception::get_exception_id($name);
let exn = $crate::runtime_exception::Exception {
id: name_id,
file: file!().as_c_slice(),
line: line!(),
column: column!(),
// https://github.com/rust-lang/rfcs/pull/1719
function: $fn_name.as_c_slice(),
message: $message.as_c_slice(),
param: [$param0, $param1, $param2],
};
#[allow(unused_unsafe)]
unsafe {
$crate::runtime_exception::raise(&exn)
}
}};
($name:expr, $fn_name:expr, $message:expr) => {{
raise_exn!($name, $fn_name, $message, 0, 0, 0)
}};
}
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
if !(mat1.ndims == 1 && mat2.ndims == 1) {
raise_exn!(
"ValueError",
"np_dot",
"expected 1D Vector Input, but received {0}-D and {1}-D input",
mat1.ndims.try_into().unwrap(),
mat2.ndims.try_into().unwrap(),
0
);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[0] != dim2[0] {
raise_exn!(
"ValueError",
"np_dot",
"shapes ({},) and ({},) not aligned",
dim1[0].try_into().unwrap(),
dim2[0].try_into().unwrap(),
0
);
}
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
matrix1.dot(&matrix2)
}
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matmul(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if !(mat1.ndims == 2 && mat2.ndims == 2) {
raise_exn!(
"ValueError",
"np_matmul",
"expected 2D Vector Input, but received {0}-D and {1}-D input",
mat1.ndims.try_into().unwrap(),
mat2.ndims.try_into().unwrap(),
0
);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[1] != dim2[0] {
let err_msg = format!(
"shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)",
dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0]
);
raise_exn!("ValueError", "np_matmul", err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2);
let mut result = DMatrix::<f64>::zeros(outdim[0], outdim[1]);
matrix1.mul_to(&matrix2, &mut result);
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
raise_exn!(
"ValueError",
"np_linalg_cholesky",
"expected 2D Vector Input, but received {0}-D input",
mat1.ndims.try_into().unwrap(),
0,
0
);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
raise_exn!(
"LinAlgError",
"np_linalg_cholesky",
"Last 2 dimensions of the array must be square: {0} != {1}",
dim1[0].try_into().unwrap(),
dim1[1].try_into().unwrap(),
0
);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix1.cholesky();
match result {
Some(res) => {
out_slice.copy_from_slice(res.unpack().transpose().as_slice());
}
None => {
raise_exn!("LinAlgError", "np_linalg_cholesky", "Matrix is not positive definite");
}
};
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_qr(
mat1: *mut InputMatrix,
outq: *mut InputMatrix,
outr: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let outq = outq.as_mut().unwrap();
let outr = outr.as_mut().unwrap();
if mat1.ndims != 2 {
raise_exn!(
"ValueError",
"np_linalg_cholesky",
"expected 2D Vector Input, but received {0}-D input",
mat1.ndims.try_into().unwrap(),
0,
0
);
}
let dim1 = (*mat1).get_dims();
let outq_dim = (*outq).get_dims();
let outr_dim = (*outr).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(outq.data, outq_dim[0] * outq_dim[1]) };
let out_r_slice = unsafe { slice::from_raw_parts_mut(outr.data, outr_dim[0] * outr_dim[1]) };
// Refer to https://github.com/dimforge/nalgebra/issues/735
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let res = matrix1.qr();
let (q, r) = res.unpack();
// Uses different algo need to match numpy
out_q_slice.copy_from_slice(q.transpose().as_slice());
out_r_slice.copy_from_slice(r.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_svd(
mat1: *mut InputMatrix,
outu: *mut InputMatrix,
outs: *mut InputMatrix,
outvh: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let outu = outu.as_mut().unwrap();
let outs = outs.as_mut().unwrap();
let outvh = outvh.as_mut().unwrap();
if mat1.ndims != 2 {
raise_exn!(
"ValueError",
"np_linalg_svd",
"expected 2D Vector Input, but received {0}-D input",
mat1.ndims.try_into().unwrap(),
0,
0
);
}
let dim1 = (*mat1).get_dims();
let outu_dim = (*outu).get_dims();
let outs_dim = (*outs).get_dims();
let outvh_dim = (*outvh).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) };
let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) };
let out_vh_slice =
unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix.svd(true, true);
out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice());
out_s_slice.copy_from_slice(result.singular_values.as_slice());
out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice());
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_inv(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_invertible() {
// raise error
return 0;
}
let inv = matrix.try_inverse().unwrap();
out_slice.copy_from_slice(inv.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_pinv(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let svd = matrix.svd(true, true);
let inv = svd.pseudo_inverse(1e-15);
match inv {
Ok(m) => {
out_slice.copy_from_slice(m.transpose().as_slice());
1
}
Err(_e) => {
// raise exception here
0
}
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_lu(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_l: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_u: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l, dim2_0 * dim2_1) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim3_0 * dim3_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let (_, l, u) = matrix.lu().unpack();
out_l_slice.copy_from_slice(l.transpose().as_slice());
out_u_slice.copy_from_slice(u.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_schur(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_t: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_z: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t, dim2_0 * dim2_1) };
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z, dim3_0 * dim3_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_square() {
// Throw error here
return 0;
}
let (z, t) = matrix.schur().unpack();
out_t_slice.copy_from_slice(t.transpose().as_slice());
out_z_slice.copy_from_slice(z.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_hessenberg(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_h: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_square() {
// Throw error here
return 0;
}
let (_, h) = matrix.hessenberg().unpack();
out_h_slice.copy_from_slice(h.transpose().as_slice());
1
}

View File

@ -0,0 +1,66 @@
#![allow(non_camel_case_types)]
#![allow(unused)]
// ARTIQ Exception struct declaration
use cslice::CSlice;
// Note: CSlice within an exception may not be actual cslice, they may be strings that exist only
// in the host. If the length == usize:MAX, the pointer is actually a string key in the host.
#[repr(C)]
#[derive(Clone)]
pub struct Exception<'a> {
pub id: u32,
pub file: CSlice<'a, u8>,
pub line: u32,
pub column: u32,
pub function: CSlice<'a, u8>,
pub message: CSlice<'a, u8>,
pub param: [i64; 3],
}
fn str_err(_: core::str::Utf8Error) -> core::fmt::Error {
core::fmt::Error
}
fn exception_str<'a>(s: &'a CSlice<'a, u8>) -> Result<&'a str, core::str::Utf8Error> {
if s.len() == usize::MAX {
Ok("<host string>")
} else {
core::str::from_utf8(s.as_ref())
}
}
pub unsafe fn raise(exception: *const Exception) -> ! {
let e = &*exception;
let f1 = exception_str(&e.function).map_err(str_err).unwrap();
let f2 = exception_str(&e.file).map_err(str_err).unwrap();
let f3 = exception_str(&e.message).map_err(str_err).unwrap();
panic!("Exception {} from {} in {}:{}:{}, message: {}", e.id, f1, f2, e.line, e.column, f3);
}
static EXCEPTION_ID_LOOKUP: [(&str, u32); 14] = [
("RuntimeError", 0),
("RTIOUnderflow", 1),
("RTIOOverflow", 2),
("RTIODestinationUnreachable", 3),
("DMAError", 4),
("I2CError", 5),
("CacheError", 6),
("SPIError", 7),
("ZeroDivisionError", 8),
("IndexError", 9),
("UnwrapNoneError", 10),
("Value", 11),
("ValueError", 12),
("LinAlgError", 13),
];
pub fn get_exception_id(name: &str) -> u32 {
for (n, id) in EXCEPTION_ID_LOOKUP.iter() {
if *n == name {
return *id;
}
}
unimplemented!("unallocated internal exception id")
}

BIN
pyo3_output/nac3artiq.so Executable file

Binary file not shown.