forked from M-Labs/nac3
Error Refactored
This commit is contained in:
parent
7ec36e80f7
commit
bc4e3a228b
23
Cargo.lock
generated
23
Cargo.lock
generated
@ -256,6 +256,12 @@ version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||
|
||||
[[package]]
|
||||
name = "cslice"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
|
||||
|
||||
[[package]]
|
||||
name = "dirs-next"
|
||||
version = "2.0.0"
|
||||
@ -314,13 +320,6 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "externfns"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"nalgebra",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.1.0"
|
||||
@ -553,6 +552,14 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linalg_externfns"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"cslice",
|
||||
"nalgebra",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linked-hash-map"
|
||||
version = "0.5.6"
|
||||
@ -638,7 +645,6 @@ name = "nac3core"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"crossbeam",
|
||||
"externfns",
|
||||
"indexmap 2.2.6",
|
||||
"indoc",
|
||||
"inkwell",
|
||||
@ -682,6 +688,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"inkwell",
|
||||
"linalg_externfns",
|
||||
"nac3core",
|
||||
"nac3parser",
|
||||
"parking_lot",
|
||||
|
@ -4,7 +4,7 @@ members = [
|
||||
"nac3ast",
|
||||
"nac3parser",
|
||||
"nac3core",
|
||||
"nac3core/src/codegen/externfns",
|
||||
"nac3standalone/linalg_externfns",
|
||||
"nac3standalone",
|
||||
"nac3artiq",
|
||||
"runkernel",
|
||||
|
@ -161,7 +161,9 @@
|
||||
clippy
|
||||
pre-commit
|
||||
rustfmt
|
||||
rust-analyzer
|
||||
];
|
||||
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||
};
|
||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||
name = "nac3-dev-shell-msys2";
|
||||
|
@ -11,7 +11,6 @@ indexmap = "2.2"
|
||||
parking_lot = "0.12"
|
||||
rayon = "1.8"
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
externfns = { path = "src/codegen/externfns" }
|
||||
strum = "0.26.2"
|
||||
strum_macros = "0.26.4"
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
use inkwell::types::BasicTypeEnum;
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
|
||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||
use itertools::Itertools;
|
||||
|
||||
@ -31,7 +31,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||
@ -1836,231 +1835,547 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `linalg_try_invert_to` function
|
||||
pub fn call_linalg_try_invert_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
fn build_input_matrix<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
a: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "linalg_try_invert_to";
|
||||
let (a_ty, a) = a;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
out_matrices: Vec<BasicValueEnum<'ctx>>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let field_ty = out_matrices.iter().map(|x| x.get_type()).collect::<Vec<BasicTypeEnum>>();
|
||||
let out_ty = ctx.ctx.struct_type(&field_ty, false);
|
||||
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
|
||||
|
||||
match a {
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
match llvm_ndarray_ty {
|
||||
BasicTypeEnum::FloatType(_) => {}
|
||||
_ => {
|
||||
unimplemented!("Inverse Operation supported on float type NDArray Values only")
|
||||
}
|
||||
};
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
||||
|
||||
// The following constraints must be satisfied:
|
||||
// * Input must be 2D
|
||||
// * number of rows should equal number of columns (square matrix)
|
||||
if cfg!(debug_assertions) {
|
||||
let n_dims = n.load_ndims(ctx);
|
||||
|
||||
// num_dim == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
n_dims,
|
||||
llvm_usize.const_int(2, false),
|
||||
for (i, v) in out_matrices.into_iter().enumerate() {
|
||||
unsafe {
|
||||
let ptr = ctx
|
||||
.builder
|
||||
.build_in_bounds_gep(
|
||||
out_ptr,
|
||||
&[
|
||||
ctx.ctx.i32_type().const_zero(),
|
||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
// dim0 == dim1
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||
"0:ValueError",
|
||||
format!(
|
||||
"Input matrix should have equal number of rows and columns for {FN_NAME}"
|
||||
)
|
||||
.as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let n_sz_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
n_sz_eqz,
|
||||
"0:ValueError",
|
||||
format!("zero-size array to inverse operation {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
ctx.builder.build_store(ptr, v).unwrap();
|
||||
}
|
||||
|
||||
let dim0 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
Ok(extern_fns::call_linalg_try_invert_to(
|
||||
ctx,
|
||||
dim0,
|
||||
dim1,
|
||||
n.data().base_ptr(ctx, generator),
|
||||
None,
|
||||
)
|
||||
.into())
|
||||
}
|
||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||
}
|
||||
out_ptr
|
||||
}
|
||||
|
||||
/// Invokes the `linalg_wilkinson_shift` function
|
||||
pub fn call_linalg_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
|
||||
// Linalg Methods
|
||||
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
a: (Type, BasicValueEnum<'ctx>),
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||
let (a_ty, a) = a;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
const FN_NAME: &str = "np_dot";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
let one = llvm_usize.const_int(1, false);
|
||||
let two = llvm_usize.const_int(2, false);
|
||||
if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) {
|
||||
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
|
||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
||||
|
||||
match a {
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
match llvm_ndarray_ty {
|
||||
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {}
|
||||
_ => unimplemented!(
|
||||
"Wilkinson Shift Operation supported on float type NDArray Values only"
|
||||
),
|
||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
||||
else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
}
|
||||
}
|
||||
|
||||
// The following constraints must be satisfied:
|
||||
// * Input must be 2D
|
||||
// * Number of rows and columns should equal 2
|
||||
// * Input matrix must be symmetric
|
||||
if cfg!(debug_assertions) {
|
||||
let n_dims = n.load_ndims(ctx);
|
||||
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matmul";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
// num_dim == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, two, "").unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
||||
|
||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
||||
else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
||||
|
||||
let outdim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let outdim1 = unsafe {
|
||||
n2.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_cholesky";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n.dim_sizes()
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
// dim0 == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, two, "").unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Number of rows must be 2 for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
// dim1 == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, dim1, two, "").unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Number of columns must be 2 for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let entry_01 = unsafe {
|
||||
n.data().get_unchecked(ctx, generator, &one, None).into_float_value()
|
||||
};
|
||||
let entry_10 = unsafe {
|
||||
n.data().get_unchecked(ctx, generator, &two, None).into_float_value()
|
||||
};
|
||||
|
||||
// symmetric matrix
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_float_compare(FloatPredicate::OEQ, entry_01, entry_10, "")
|
||||
.unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Input Matrix must be symmetric for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_qr";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n.dim_sizes()
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 =
|
||||
unsafe { n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
|
||||
Ok(extern_fns::call_linalg_wilkinson_shift(
|
||||
ctx,
|
||||
dim0,
|
||||
dim1,
|
||||
n.data().base_ptr(ctx, generator),
|
||||
None,
|
||||
)
|
||||
.into())
|
||||
}
|
||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
|
||||
|
||||
let out_ptr = build_input_matrix(ctx, vec![out_q, out_r]);
|
||||
|
||||
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_svd";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
|
||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
|
||||
|
||||
let out_ptr = build_input_matrix(ctx, vec![out_u, out_s, out_vh]);
|
||||
|
||||
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_inv";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out =
|
||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_inv(
|
||||
ctx,
|
||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
||||
(dim0, dim1, out.data().base_ptr(ctx, generator)),
|
||||
None,
|
||||
);
|
||||
Ok(out.as_base_value().as_basic_value_enum())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_pinv";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out =
|
||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]).unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_pinv(
|
||||
ctx,
|
||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
||||
(dim1, dim0, out.data().base_ptr(ctx, generator)),
|
||||
None,
|
||||
);
|
||||
Ok(out.as_base_value().as_basic_value_enum())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_lu";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
|
||||
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
|
||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
|
||||
|
||||
extern_fns::call_sp_linalg_lu(
|
||||
ctx,
|
||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
||||
(dim0, k, out_l.data().base_ptr(ctx, generator)),
|
||||
(k, dim1, out_u.data().base_ptr(ctx, generator)),
|
||||
None,
|
||||
);
|
||||
|
||||
let out_l = out_l.as_base_value().as_basic_value_enum();
|
||||
let out_u = out_u.as_base_value().as_basic_value_enum();
|
||||
|
||||
let res_ty = ctx.ctx.struct_type(&[out_l.get_type(), out_u.get_type()], false);
|
||||
let res_ptr = ctx.builder.build_alloca(res_ty, "LU_factorization").unwrap();
|
||||
|
||||
let res_val = [out_l, out_u];
|
||||
for (i, v) in res_val.into_iter().enumerate() {
|
||||
unsafe {
|
||||
let ptr = ctx
|
||||
.builder
|
||||
.build_in_bounds_gep(
|
||||
res_ptr,
|
||||
&[
|
||||
ctx.ctx.i32_type().const_zero(),
|
||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
||||
],
|
||||
"ptr",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(ptr, v).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ctx.builder.build_load(res_ptr, "LU_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
// Must be square (add check later)
|
||||
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_schur";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let out_t =
|
||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
|
||||
let out_z =
|
||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
|
||||
|
||||
extern_fns::call_sp_linalg_schur(
|
||||
ctx,
|
||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
||||
(dim0, dim0, out_t.data().base_ptr(ctx, generator)),
|
||||
(dim0, dim0, out_z.data().base_ptr(ctx, generator)),
|
||||
None,
|
||||
);
|
||||
|
||||
let out_t = out_t.as_base_value().as_basic_value_enum();
|
||||
let out_z = out_z.as_base_value().as_basic_value_enum();
|
||||
|
||||
let res_ty = ctx.ctx.struct_type(&[out_t.get_type(), out_z.get_type()], false);
|
||||
|
||||
let res_ptr = ctx.builder.build_alloca(res_ty, "Schur_factorization").unwrap();
|
||||
let res_val = [out_t, out_z];
|
||||
for (i, v) in res_val.into_iter().enumerate() {
|
||||
unsafe {
|
||||
let ptr = ctx
|
||||
.builder
|
||||
.build_in_bounds_gep(
|
||||
res_ptr,
|
||||
&[
|
||||
ctx.ctx.i32_type().const_zero(),
|
||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
||||
],
|
||||
"ptr",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(ptr, v).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ctx.builder.build_load(res_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
// Must be square (add check later)
|
||||
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_hessenberg";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
// Check if matrix is square
|
||||
// ctx.builder.build_select(
|
||||
// ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||
// {
|
||||
// let func =
|
||||
// }, else_, name)
|
||||
// ;
|
||||
|
||||
// ctx.builder.build_call(
|
||||
// ctx.module.get_function("__nac3_raise"),
|
||||
// &[]
|
||||
|
||||
// )
|
||||
// let err_msg = ctx.gen_string(generator, "{FN_NAME} requires square matrix");
|
||||
// ctx.raise_exn(generator, "0:ValueError", err_msg, [None, None, None], ctx.current_loc);
|
||||
|
||||
let out_h =
|
||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
|
||||
|
||||
extern_fns::call_sp_linalg_hessenberg(
|
||||
ctx,
|
||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
||||
(dim0, dim0, out_h.data().base_ptr(ctx, generator)),
|
||||
None,
|
||||
);
|
||||
|
||||
Ok(out_h.as_base_value().as_basic_value_enum())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
@ -131,90 +131,218 @@ pub fn call_ldexp<'ctx>(
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
|
||||
pub fn call_linalg_try_invert_to<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dim0: IntValue<'ctx>,
|
||||
dim1: IntValue<'ctx>,
|
||||
data: PointerValue<'ctx>,
|
||||
/// Macro to generate np_linalg external functions
|
||||
macro_rules! generate_np_linalg_extern_fn {
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => {
|
||||
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1);
|
||||
};
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => {
|
||||
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2);
|
||||
};
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => {
|
||||
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3);
|
||||
};
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => {
|
||||
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4);
|
||||
};
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => {
|
||||
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
|
||||
pub fn $fn_name<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>
|
||||
$(,$input_matrix: (IntValue<'ctx>, IntValue<'ctx>, PointerValue<'ctx>))*,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "linalg_try_invert_to";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||
|
||||
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
|
||||
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
|
||||
|
||||
debug_assert!(allowed_dim0);
|
||||
debug_assert!(allowed_dim1);
|
||||
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.i8_type().fn_type(
|
||||
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
|
||||
false,
|
||||
);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
|
||||
pub fn call_linalg_wilkinson_shift<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dim0: IntValue<'ctx>,
|
||||
dim1: IntValue<'ctx>,
|
||||
data: PointerValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||
) -> $ret_ty<'ctx> {
|
||||
const FN_NAME: &str = $extern_fn;
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||
|
||||
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
|
||||
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
|
||||
$(
|
||||
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.0.get_type()));
|
||||
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.1.get_type()));
|
||||
debug_assert_eq!($input_matrix.2.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||
)*
|
||||
|
||||
debug_assert!(allowed_dim0);
|
||||
debug_assert!(allowed_dim1);
|
||||
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||
// let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false);
|
||||
// let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false);
|
||||
// let file_name = ctx.current_loc.file.0;
|
||||
// let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false);
|
||||
// let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.f64_type().fn_type(
|
||||
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
|
||||
false,
|
||||
);
|
||||
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
|
||||
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
|
||||
let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
|
||||
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||
// .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
|
||||
// .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
|
||||
.build_call(extern_fn, &[$($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(|v| v.map_left($map_fn))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro to generate `np_linalg` external functions
|
||||
macro_rules! generate_np_linalg_extern_fn2 {
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => {
|
||||
generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1);
|
||||
};
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => {
|
||||
generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2);
|
||||
};
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => {
|
||||
generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3);
|
||||
};
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => {
|
||||
generate_np_linalg_extern_fn2!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4);
|
||||
};
|
||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => {
|
||||
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
|
||||
pub fn $fn_name<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>
|
||||
$(,$input_matrix: BasicValueEnum<'ctx>)*,
|
||||
name: Option<&str>,
|
||||
) -> $ret_ty<'ctx> {
|
||||
const FN_NAME: &str = $extern_fn;
|
||||
// let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false);
|
||||
// let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false);
|
||||
// let file_name = ctx.current_loc.file.0;
|
||||
// let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false);
|
||||
// let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
|
||||
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
|
||||
let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.get_type().into()),*], false);
|
||||
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
// .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
|
||||
// .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
|
||||
.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left($map_fn))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
};
|
||||
}
|
||||
/// Macro to generate `np_linalg` external functions
|
||||
macro_rules! generate_np_linalg_extern_fn3 {
|
||||
($fn_name:ident, $extern_fn:literal, 1) => {
|
||||
generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1);
|
||||
};
|
||||
($fn_name:ident, $extern_fn:literal, 2) => {
|
||||
generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1, mat2);
|
||||
};
|
||||
($fn_name:ident, $extern_fn:literal, 3) => {
|
||||
generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1, mat2, mat3);
|
||||
};
|
||||
($fn_name:ident, $extern_fn:literal, 4) => {
|
||||
generate_np_linalg_extern_fn3!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
|
||||
};
|
||||
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
|
||||
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
|
||||
pub fn $fn_name<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>
|
||||
$(,$input_matrix: BasicValueEnum<'ctx>)*,
|
||||
name: Option<&str>,
|
||||
){
|
||||
const FN_NAME: &str = $extern_fn;
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
|
||||
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
|
||||
}
|
||||
};
|
||||
}
|
||||
generate_np_linalg_extern_fn2!(
|
||||
call_np_dot,
|
||||
FloatValue,
|
||||
f64_type,
|
||||
BasicValueEnum::into_float_value,
|
||||
"np_dot",
|
||||
2
|
||||
);
|
||||
generate_np_linalg_extern_fn3!(call_np_linalg_matmul, "np_linalg_matmul", 3);
|
||||
generate_np_linalg_extern_fn3!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
|
||||
generate_np_linalg_extern_fn3!(call_np_linalg_qr, "np_linalg_qr", 3);
|
||||
generate_np_linalg_extern_fn3!(call_np_linalg_svd, "np_linalg_svd", 4);
|
||||
|
||||
generate_np_linalg_extern_fn!(
|
||||
call_np_linalg_inv,
|
||||
IntValue,
|
||||
i8_type,
|
||||
BasicValueEnum::into_int_value,
|
||||
"np_linalg_inv",
|
||||
2
|
||||
);
|
||||
|
||||
generate_np_linalg_extern_fn!(
|
||||
call_np_linalg_pinv,
|
||||
IntValue,
|
||||
i8_type,
|
||||
BasicValueEnum::into_int_value,
|
||||
"np_linalg_pinv",
|
||||
2
|
||||
);
|
||||
|
||||
generate_np_linalg_extern_fn!(
|
||||
call_sp_linalg_lu,
|
||||
IntValue,
|
||||
i8_type,
|
||||
BasicValueEnum::into_int_value,
|
||||
"sp_linalg_lu",
|
||||
3
|
||||
);
|
||||
|
||||
generate_np_linalg_extern_fn!(
|
||||
call_sp_linalg_schur,
|
||||
IntValue,
|
||||
i8_type,
|
||||
BasicValueEnum::into_int_value,
|
||||
"sp_linalg_schur",
|
||||
3
|
||||
);
|
||||
|
||||
generate_np_linalg_extern_fn!(
|
||||
call_sp_linalg_hessenberg,
|
||||
IntValue,
|
||||
i8_type,
|
||||
BasicValueEnum::into_int_value,
|
||||
"sp_linalg_hessenberg",
|
||||
2
|
||||
);
|
||||
|
@ -1,30 +0,0 @@
|
||||
use core::slice;
|
||||
use nalgebra::{linalg, DMatrix};
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `data` must point to an array with `dim0`x`dim1` elements in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 {
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
||||
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
|
||||
|
||||
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
|
||||
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `data` must point to an array of 4 elements in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 {
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
||||
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||
|
||||
linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)])
|
||||
}
|
@ -61,7 +61,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// * `shape` - The shape of the `NDArray`.
|
||||
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
|
||||
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
|
||||
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
||||
pub fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
@ -157,7 +157,7 @@ where
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
|
||||
fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
|
@ -557,7 +557,18 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
| PrimDef::FunNpHypot
|
||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||
|
||||
PrimDef::FunTryInvertTo | PrimDef::FunWilkinsonShift => self.build_linalg_methods(prim),
|
||||
PrimDef::FunNpDot
|
||||
| PrimDef::FunNpLinalgMatmul
|
||||
| PrimDef::FunNpLinalgCholesky
|
||||
| PrimDef::FunNpLinalgQr
|
||||
| PrimDef::FunNpLinalgSvd
|
||||
| PrimDef::FunNpLinalgInv
|
||||
| PrimDef::FunNpLinalgPinv
|
||||
| PrimDef::FunSpLinalgLu
|
||||
| PrimDef::FunSpLinalgSchur
|
||||
| PrimDef::FunSpLinalgHessenberg => self.build_np_linalg_methods(prim),
|
||||
// PrimDef::FunNpDot | PrimDef::FunNpLinalgMatmul => self.build_np_linalg_binary_methods(prim),
|
||||
// PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr => self.build_np_linalg_unary_methods(prim),
|
||||
};
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
@ -1876,37 +1887,142 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the functions `try_invert_to` and `wilkinson_shift`
|
||||
fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunTryInvertTo, PrimDef::FunWilkinsonShift]);
|
||||
fn build_np_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(
|
||||
prim,
|
||||
&[
|
||||
PrimDef::FunNpDot,
|
||||
PrimDef::FunNpLinalgMatmul,
|
||||
PrimDef::FunNpLinalgCholesky,
|
||||
PrimDef::FunNpLinalgQr,
|
||||
PrimDef::FunNpLinalgSvd,
|
||||
PrimDef::FunNpLinalgInv,
|
||||
PrimDef::FunNpLinalgPinv,
|
||||
PrimDef::FunSpLinalgLu,
|
||||
PrimDef::FunSpLinalgSchur,
|
||||
PrimDef::FunSpLinalgHessenberg,
|
||||
],
|
||||
);
|
||||
|
||||
let ret_ty = match prim {
|
||||
PrimDef::FunTryInvertTo => self.primitives.bool,
|
||||
PrimDef::FunWilkinsonShift => self.primitives.float,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let var_map = self.num_or_ndarray_var_map.clone();
|
||||
create_fn_by_codegen(
|
||||
match prim {
|
||||
PrimDef::FunNpDot => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&var_map,
|
||||
&self.num_or_ndarray_var_map,
|
||||
prim.name(),
|
||||
ret_ty,
|
||||
&[(self.ndarray_float_2d, "x")],
|
||||
self.primitives.float,
|
||||
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_np_dot(
|
||||
generator,
|
||||
ctx,
|
||||
(x1_ty, x1_val),
|
||||
(x2_ty, x2_val),
|
||||
)?))
|
||||
}),
|
||||
),
|
||||
|
||||
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.ndarray_float_2d,
|
||||
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_np_linalg_matmul(
|
||||
generator,
|
||||
ctx,
|
||||
(x1_ty, x1_val),
|
||||
(x2_ty, x2_val),
|
||||
)?))
|
||||
}),
|
||||
),
|
||||
|
||||
PrimDef::FunNpLinalgCholesky
|
||||
| PrimDef::FunNpLinalgInv
|
||||
| PrimDef::FunNpLinalgPinv
|
||||
| PrimDef::FunSpLinalgHessenberg => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.ndarray_float_2d,
|
||||
&[(self.ndarray_float_2d, "x1")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunTryInvertTo => builtin_fns::call_linalg_try_invert_to,
|
||||
PrimDef::FunWilkinsonShift => builtin_fns::call_linalg_wilkinson_shift,
|
||||
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
|
||||
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
|
||||
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
|
||||
PrimDef::FunSpLinalgHessenberg => builtin_fns::call_sp_linalg_hessenberg,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
|
||||
}),
|
||||
),
|
||||
|
||||
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||
PrimDef::FunNpLinalgQr | PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgSchur => {
|
||||
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
|
||||
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
|
||||
});
|
||||
create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
ret_ty,
|
||||
&[(self.ndarray_float_2d, "x1")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
|
||||
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
|
||||
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
PrimDef::FunNpLinalgSvd => {
|
||||
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
|
||||
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
|
||||
});
|
||||
create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
ret_ty,
|
||||
&[(self.ndarray_float_2d, "x1")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
|
||||
}),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
println!("{:?}", prim.name());
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||
(prim.simple_name().into(), method_ty, prim.id())
|
||||
}
|
||||
|
@ -105,8 +105,16 @@ pub enum PrimDef {
|
||||
FunNpLdExp,
|
||||
FunNpHypot,
|
||||
FunNpNextAfter,
|
||||
FunTryInvertTo,
|
||||
FunWilkinsonShift,
|
||||
FunNpDot,
|
||||
FunNpLinalgMatmul,
|
||||
FunNpLinalgCholesky,
|
||||
FunNpLinalgQr,
|
||||
FunNpLinalgSvd,
|
||||
FunNpLinalgInv,
|
||||
FunNpLinalgPinv,
|
||||
FunSpLinalgLu,
|
||||
FunSpLinalgSchur,
|
||||
FunSpLinalgHessenberg,
|
||||
|
||||
// Top-Level Functions
|
||||
FunSome,
|
||||
@ -265,8 +273,17 @@ impl PrimDef {
|
||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||
PrimDef::FunTryInvertTo => fun("try_invert_to", None),
|
||||
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None),
|
||||
PrimDef::FunNpDot => fun("np_dot", None),
|
||||
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
|
||||
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
||||
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
|
||||
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
||||
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
|
||||
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
|
||||
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
|
||||
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
|
||||
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
|
||||
|
||||
PrimDef::FunSome => fun("Some", None),
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ edition = "2021"
|
||||
parking_lot = "0.12"
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
nac3core = { path = "../nac3core" }
|
||||
linalg_externfns = { path = "./linalg_externfns" }
|
||||
|
||||
[dependencies.clap]
|
||||
version = "4.5"
|
||||
|
@ -5,6 +5,7 @@ import importlib.util
|
||||
import importlib.machinery
|
||||
import math
|
||||
import numpy as np
|
||||
import scipy as sp
|
||||
import numpy.typing as npt
|
||||
import pathlib
|
||||
|
||||
@ -246,8 +247,21 @@ def patch(module):
|
||||
module.sp_spec_j0 = special.j0
|
||||
module.sp_spec_j1 = special.j1
|
||||
|
||||
module.try_invert_to = try_invert_to
|
||||
module.wilkinson_shift = wilkinson_shift
|
||||
# Linalg functions
|
||||
module.np_dot = np.dot
|
||||
module.np_linalg_matmul = np.matmul
|
||||
module.np_linalg_cholesky = np.linalg.cholesky
|
||||
module.np_linalg_qr = np.linalg.qr
|
||||
module.np_linalg_svd = np.linalg.svd
|
||||
module.np_linalg_inv = np.linalg.inv
|
||||
module.np_linalg_pinv = np.linalg.pinv
|
||||
|
||||
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
|
||||
module.sp_linalg_schur = sp.linalg.schur
|
||||
# module.sp_linalg_hessenberg = sp.linalg.hessenberg
|
||||
module.sp_linalg_hessenberg = lambda x: x
|
||||
|
||||
|
||||
|
||||
def file_import(filename, prefix="file_import_"):
|
||||
filename = pathlib.Path(filename)
|
||||
|
@ -42,14 +42,14 @@ done
|
||||
|
||||
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
||||
nac3standalone=../../target/debug/nac3standalone
|
||||
externfns=../../target/debug/deps/libexternfns.so
|
||||
externfns=../../target/debug/deps/liblinalg_externfns.so
|
||||
elif [ -e ../../target/release/nac3standalone ]; then
|
||||
nac3standalone=../../target/release/nac3standalone
|
||||
externfns=../../target/release/deps/libexternfns.so
|
||||
externfns=../../target/release/deps/liblinalg_externfns.so
|
||||
else
|
||||
# used by Nix builds
|
||||
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
||||
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so
|
||||
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg_externfns.so
|
||||
fi
|
||||
|
||||
rm -f ./*.o ./*.bc demo
|
||||
|
12
nac3standalone/demo/sample
Normal file
12
nac3standalone/demo/sample
Normal file
@ -0,0 +1,12 @@
|
||||
Checking src/ndarray.py... Function Called np_dot
|
||||
Module { data_layout: RefCell { value: Some(DataLayout { address: 0x555559cda5b0, repr: "" }) }, module: Cell { value: 0x555559cda3a0 }, owned_by_ee: RefCell { value: None }, _marker: PhantomData<&inkwell::context::Context> }
|
||||
Function Called np_dot
|
||||
Module { data_layout: RefCell { value: Some(DataLayout { address: 0x555559cda740, repr: "" }) }, module: Cell { value: 0x555559cda530 }, owned_by_ee: RefCell { value: None }, _marker: PhantomData<&inkwell::context::Context> }
|
||||
--- interpreted.log 2024-07-24 19:39:47.480093947 +0800
|
||||
+++ run.log 2024-07-24 22:39:50.183382396 +0800
|
||||
@@ -0,0 +1,5 @@
|
||||
+5.000000
|
||||
+1.000000
|
||||
+5.000000
|
||||
+1.000000
|
||||
+26.000000
|
@ -1429,200 +1429,289 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
||||
output_ndarray_float_2(nextafter_x_zeros)
|
||||
output_ndarray_float_2(nextafter_x_ones)
|
||||
|
||||
def test_try_invert():
|
||||
x: ndarray[float, 2] = np_array([[1.0, 2.0], [3.0, 4.0]])
|
||||
output_ndarray_float_2(x)
|
||||
y = try_invert_to(x)
|
||||
def test_ndarray_dot():
|
||||
x: ndarray[float, 1] = np_array([5.0, 1.0])
|
||||
y: ndarray[float, 1] = np_array([5.0, 1.0])
|
||||
z = np_dot(x, y)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_bool(y)
|
||||
output_ndarray_float_1(x)
|
||||
output_ndarray_float_1(y)
|
||||
output_float64(z)
|
||||
|
||||
def test_wilkinson_shift():
|
||||
def test_ndarray_linalg_matmul():
|
||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||
y = wilkinson_shift(x)
|
||||
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||
z = np_linalg_matmul(x, y)
|
||||
|
||||
m = np_argmax(z)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_float64(y)
|
||||
output_ndarray_float_2(y)
|
||||
output_ndarray_float_2(z)
|
||||
output_int64(m)
|
||||
|
||||
def test_ndarray_cholesky():
|
||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||
y = np_linalg_cholesky(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_qr():
|
||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||
y, z = np_linalg_qr(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
# QR Factorization in nalgebra and numpy do not give the same result
|
||||
# Generating product for printing
|
||||
a = np_linalg_matmul(y, z)
|
||||
output_ndarray_float_2(a)
|
||||
|
||||
def test_ndarray_linalg_inv():
|
||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||
y = np_linalg_inv(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_pinv():
|
||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
|
||||
y = np_linalg_pinv(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_schur():
|
||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||
t, z = sp_linalg_schur(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
# Same as np_linalg_qr the signs are different in nalgebra and numpy
|
||||
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z))
|
||||
output_ndarray_float_2(a)
|
||||
|
||||
def test_ndarray_hessenberg():
|
||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
|
||||
h = sp_linalg_hessenberg(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(h)
|
||||
|
||||
|
||||
def test_ndarray_lu():
|
||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
|
||||
l, u = sp_linalg_lu(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(l)
|
||||
output_ndarray_float_2(u)
|
||||
|
||||
|
||||
def test_ndarray_svd():
|
||||
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
||||
x, y, z = np_linalg_svd(w)
|
||||
|
||||
output_ndarray_float_2(w)
|
||||
|
||||
# Same as np_linalg_qr the signs are different in nalgebra and numpy
|
||||
a = np_linalg_matmul(x, z)
|
||||
output_ndarray_float_2(a)
|
||||
output_ndarray_float_1(y)
|
||||
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
test_ndarray_zeros()
|
||||
test_ndarray_ones()
|
||||
test_ndarray_full()
|
||||
test_ndarray_eye()
|
||||
test_ndarray_array()
|
||||
test_ndarray_identity()
|
||||
test_ndarray_fill()
|
||||
test_ndarray_copy()
|
||||
# test_ndarray_matmul()
|
||||
test_ndarray_dot()
|
||||
test_ndarray_linalg_matmul()
|
||||
test_ndarray_cholesky()
|
||||
test_ndarray_qr()
|
||||
test_ndarray_svd()
|
||||
# test_ndarray_linalg_inv()
|
||||
# test_ndarray_pinv()
|
||||
# test_ndarray_lu()
|
||||
# test_ndarray_schur()
|
||||
# test_ndarray_hessenberg()
|
||||
|
||||
test_ndarray_neg_idx()
|
||||
test_ndarray_slices()
|
||||
test_ndarray_nd_idx()
|
||||
# test_ndarray_ctor()
|
||||
# test_ndarray_empty()
|
||||
# test_ndarray_zeros()
|
||||
# test_ndarray_ones()
|
||||
# test_ndarray_full()
|
||||
# test_ndarray_eye()
|
||||
# test_ndarray_array()
|
||||
# test_ndarray_identity()
|
||||
# test_ndarray_fill()
|
||||
# test_ndarray_copy()
|
||||
|
||||
test_ndarray_add()
|
||||
test_ndarray_add_broadcast()
|
||||
test_ndarray_add_broadcast_lhs_scalar()
|
||||
test_ndarray_add_broadcast_rhs_scalar()
|
||||
test_ndarray_iadd()
|
||||
test_ndarray_iadd_broadcast()
|
||||
test_ndarray_iadd_broadcast_scalar()
|
||||
test_ndarray_sub()
|
||||
test_ndarray_sub_broadcast()
|
||||
test_ndarray_sub_broadcast_lhs_scalar()
|
||||
test_ndarray_sub_broadcast_rhs_scalar()
|
||||
test_ndarray_isub()
|
||||
test_ndarray_isub_broadcast()
|
||||
test_ndarray_isub_broadcast_scalar()
|
||||
test_ndarray_mul()
|
||||
test_ndarray_mul_broadcast()
|
||||
test_ndarray_mul_broadcast_lhs_scalar()
|
||||
test_ndarray_mul_broadcast_rhs_scalar()
|
||||
test_ndarray_imul()
|
||||
test_ndarray_imul_broadcast()
|
||||
test_ndarray_imul_broadcast_scalar()
|
||||
test_ndarray_truediv()
|
||||
test_ndarray_truediv_broadcast()
|
||||
test_ndarray_truediv_broadcast_lhs_scalar()
|
||||
test_ndarray_truediv_broadcast_rhs_scalar()
|
||||
test_ndarray_itruediv()
|
||||
test_ndarray_itruediv_broadcast()
|
||||
test_ndarray_itruediv_broadcast_scalar()
|
||||
test_ndarray_floordiv()
|
||||
test_ndarray_floordiv_broadcast()
|
||||
test_ndarray_floordiv_broadcast_lhs_scalar()
|
||||
test_ndarray_floordiv_broadcast_rhs_scalar()
|
||||
test_ndarray_ifloordiv()
|
||||
test_ndarray_ifloordiv_broadcast()
|
||||
test_ndarray_ifloordiv_broadcast_scalar()
|
||||
test_ndarray_mod()
|
||||
test_ndarray_mod_broadcast()
|
||||
test_ndarray_mod_broadcast_lhs_scalar()
|
||||
test_ndarray_mod_broadcast_rhs_scalar()
|
||||
test_ndarray_imod()
|
||||
test_ndarray_imod_broadcast()
|
||||
test_ndarray_imod_broadcast_scalar()
|
||||
test_ndarray_pow()
|
||||
test_ndarray_pow_broadcast()
|
||||
test_ndarray_pow_broadcast_lhs_scalar()
|
||||
test_ndarray_pow_broadcast_rhs_scalar()
|
||||
test_ndarray_ipow()
|
||||
test_ndarray_ipow_broadcast()
|
||||
test_ndarray_ipow_broadcast_scalar()
|
||||
test_ndarray_matmul()
|
||||
test_ndarray_imatmul()
|
||||
test_ndarray_pos()
|
||||
test_ndarray_neg()
|
||||
test_ndarray_inv()
|
||||
test_ndarray_eq()
|
||||
test_ndarray_eq_broadcast()
|
||||
test_ndarray_eq_broadcast_lhs_scalar()
|
||||
test_ndarray_eq_broadcast_rhs_scalar()
|
||||
test_ndarray_ne()
|
||||
test_ndarray_ne_broadcast()
|
||||
test_ndarray_ne_broadcast_lhs_scalar()
|
||||
test_ndarray_ne_broadcast_rhs_scalar()
|
||||
test_ndarray_lt()
|
||||
test_ndarray_lt_broadcast()
|
||||
test_ndarray_lt_broadcast_lhs_scalar()
|
||||
test_ndarray_lt_broadcast_rhs_scalar()
|
||||
test_ndarray_lt()
|
||||
test_ndarray_le_broadcast()
|
||||
test_ndarray_le_broadcast_lhs_scalar()
|
||||
test_ndarray_le_broadcast_rhs_scalar()
|
||||
test_ndarray_gt()
|
||||
test_ndarray_gt_broadcast()
|
||||
test_ndarray_gt_broadcast_lhs_scalar()
|
||||
test_ndarray_gt_broadcast_rhs_scalar()
|
||||
test_ndarray_gt()
|
||||
test_ndarray_ge_broadcast()
|
||||
test_ndarray_ge_broadcast_lhs_scalar()
|
||||
test_ndarray_ge_broadcast_rhs_scalar()
|
||||
# test_ndarray_neg_idx()
|
||||
# test_ndarray_slices()
|
||||
# test_ndarray_nd_idx()
|
||||
|
||||
test_ndarray_int32()
|
||||
test_ndarray_int64()
|
||||
test_ndarray_uint32()
|
||||
test_ndarray_uint64()
|
||||
test_ndarray_float()
|
||||
test_ndarray_bool()
|
||||
# test_ndarray_add()
|
||||
# test_ndarray_add_broadcast()
|
||||
# test_ndarray_add_broadcast_lhs_scalar()
|
||||
# test_ndarray_add_broadcast_rhs_scalar()
|
||||
# test_ndarray_iadd()
|
||||
# test_ndarray_iadd_broadcast()
|
||||
# test_ndarray_iadd_broadcast_scalar()
|
||||
# test_ndarray_sub()
|
||||
# test_ndarray_sub_broadcast()
|
||||
# test_ndarray_sub_broadcast_lhs_scalar()
|
||||
# test_ndarray_sub_broadcast_rhs_scalar()
|
||||
# test_ndarray_isub()
|
||||
# test_ndarray_isub_broadcast()
|
||||
# test_ndarray_isub_broadcast_scalar()
|
||||
# test_ndarray_mul()
|
||||
# test_ndarray_mul_broadcast()
|
||||
# test_ndarray_mul_broadcast_lhs_scalar()
|
||||
# test_ndarray_mul_broadcast_rhs_scalar()
|
||||
# test_ndarray_imul()
|
||||
# test_ndarray_imul_broadcast()
|
||||
# test_ndarray_imul_broadcast_scalar()
|
||||
# test_ndarray_truediv()
|
||||
# test_ndarray_truediv_broadcast()
|
||||
# test_ndarray_truediv_broadcast_lhs_scalar()
|
||||
# test_ndarray_truediv_broadcast_rhs_scalar()
|
||||
# test_ndarray_itruediv()
|
||||
# test_ndarray_itruediv_broadcast()
|
||||
# test_ndarray_itruediv_broadcast_scalar()
|
||||
# test_ndarray_floordiv()
|
||||
# test_ndarray_floordiv_broadcast()
|
||||
# test_ndarray_floordiv_broadcast_lhs_scalar()
|
||||
# test_ndarray_floordiv_broadcast_rhs_scalar()
|
||||
# test_ndarray_ifloordiv()
|
||||
# test_ndarray_ifloordiv_broadcast()
|
||||
# test_ndarray_ifloordiv_broadcast_scalar()
|
||||
# test_ndarray_mod()
|
||||
# test_ndarray_mod_broadcast()
|
||||
# test_ndarray_mod_broadcast_lhs_scalar()
|
||||
# test_ndarray_mod_broadcast_rhs_scalar()
|
||||
# test_ndarray_imod()
|
||||
# test_ndarray_imod_broadcast()
|
||||
# test_ndarray_imod_broadcast_scalar()
|
||||
# test_ndarray_pow()
|
||||
# test_ndarray_pow_broadcast()
|
||||
# test_ndarray_pow_broadcast_lhs_scalar()
|
||||
# test_ndarray_pow_broadcast_rhs_scalar()
|
||||
# test_ndarray_ipow()
|
||||
# test_ndarray_ipow_broadcast()
|
||||
# test_ndarray_ipow_broadcast_scalar()
|
||||
# test_ndarray_matmul()
|
||||
# test_ndarray_imatmul()
|
||||
# test_ndarray_pos()
|
||||
# test_ndarray_neg()
|
||||
# test_ndarray_inv()
|
||||
# test_ndarray_eq()
|
||||
# test_ndarray_eq_broadcast()
|
||||
# test_ndarray_eq_broadcast_lhs_scalar()
|
||||
# test_ndarray_eq_broadcast_rhs_scalar()
|
||||
# test_ndarray_ne()
|
||||
# test_ndarray_ne_broadcast()
|
||||
# test_ndarray_ne_broadcast_lhs_scalar()
|
||||
# test_ndarray_ne_broadcast_rhs_scalar()
|
||||
# test_ndarray_lt()
|
||||
# test_ndarray_lt_broadcast()
|
||||
# test_ndarray_lt_broadcast_lhs_scalar()
|
||||
# test_ndarray_lt_broadcast_rhs_scalar()
|
||||
# test_ndarray_lt()
|
||||
# test_ndarray_le_broadcast()
|
||||
# test_ndarray_le_broadcast_lhs_scalar()
|
||||
# test_ndarray_le_broadcast_rhs_scalar()
|
||||
# test_ndarray_gt()
|
||||
# test_ndarray_gt_broadcast()
|
||||
# test_ndarray_gt_broadcast_lhs_scalar()
|
||||
# test_ndarray_gt_broadcast_rhs_scalar()
|
||||
# test_ndarray_gt()
|
||||
# test_ndarray_ge_broadcast()
|
||||
# test_ndarray_ge_broadcast_lhs_scalar()
|
||||
# test_ndarray_ge_broadcast_rhs_scalar()
|
||||
|
||||
test_ndarray_round()
|
||||
test_ndarray_floor()
|
||||
test_ndarray_min()
|
||||
test_ndarray_minimum()
|
||||
test_ndarray_minimum_broadcast()
|
||||
test_ndarray_minimum_broadcast_lhs_scalar()
|
||||
test_ndarray_minimum_broadcast_rhs_scalar()
|
||||
test_ndarray_argmin()
|
||||
test_ndarray_max()
|
||||
test_ndarray_maximum()
|
||||
test_ndarray_maximum_broadcast()
|
||||
test_ndarray_maximum_broadcast_lhs_scalar()
|
||||
test_ndarray_maximum_broadcast_rhs_scalar()
|
||||
test_ndarray_argmax()
|
||||
test_ndarray_abs()
|
||||
test_ndarray_isnan()
|
||||
test_ndarray_isinf()
|
||||
# test_ndarray_int32()
|
||||
# test_ndarray_int64()
|
||||
# test_ndarray_uint32()
|
||||
# test_ndarray_uint64()
|
||||
# test_ndarray_float()
|
||||
# test_ndarray_bool()
|
||||
|
||||
test_ndarray_sin()
|
||||
test_ndarray_cos()
|
||||
test_ndarray_exp()
|
||||
test_ndarray_exp2()
|
||||
test_ndarray_log()
|
||||
test_ndarray_log10()
|
||||
test_ndarray_log2()
|
||||
test_ndarray_fabs()
|
||||
test_ndarray_sqrt()
|
||||
test_ndarray_rint()
|
||||
test_ndarray_tan()
|
||||
test_ndarray_arcsin()
|
||||
test_ndarray_arccos()
|
||||
test_ndarray_arctan()
|
||||
test_ndarray_sinh()
|
||||
test_ndarray_cosh()
|
||||
test_ndarray_tanh()
|
||||
test_ndarray_arcsinh()
|
||||
test_ndarray_arccosh()
|
||||
test_ndarray_arctanh()
|
||||
test_ndarray_expm1()
|
||||
test_ndarray_cbrt()
|
||||
# test_ndarray_round()
|
||||
# test_ndarray_floor()
|
||||
# test_ndarray_min()
|
||||
# test_ndarray_minimum()
|
||||
# test_ndarray_minimum_broadcast()
|
||||
# test_ndarray_minimum_broadcast_lhs_scalar()
|
||||
# test_ndarray_minimum_broadcast_rhs_scalar()
|
||||
# test_ndarray_argmin()
|
||||
# test_ndarray_max()
|
||||
# test_ndarray_maximum()
|
||||
# test_ndarray_maximum_broadcast()
|
||||
# test_ndarray_maximum_broadcast_lhs_scalar()
|
||||
# test_ndarray_maximum_broadcast_rhs_scalar()
|
||||
# test_ndarray_argmax()
|
||||
# test_ndarray_abs()
|
||||
# test_ndarray_isnan()
|
||||
# test_ndarray_isinf()
|
||||
|
||||
test_ndarray_erf()
|
||||
test_ndarray_erfc()
|
||||
test_ndarray_gamma()
|
||||
test_ndarray_gammaln()
|
||||
test_ndarray_j0()
|
||||
test_ndarray_j1()
|
||||
# test_ndarray_sin()
|
||||
# test_ndarray_cos()
|
||||
# test_ndarray_exp()
|
||||
# test_ndarray_exp2()
|
||||
# test_ndarray_log()
|
||||
# test_ndarray_log10()
|
||||
# test_ndarray_log2()
|
||||
# test_ndarray_fabs()
|
||||
# test_ndarray_sqrt()
|
||||
# test_ndarray_rint()
|
||||
# test_ndarray_tan()
|
||||
# test_ndarray_arcsin()
|
||||
# test_ndarray_arccos()
|
||||
# test_ndarray_arctan()
|
||||
# test_ndarray_sinh()
|
||||
# test_ndarray_cosh()
|
||||
# test_ndarray_tanh()
|
||||
# test_ndarray_arcsinh()
|
||||
# test_ndarray_arccosh()
|
||||
# test_ndarray_arctanh()
|
||||
# test_ndarray_expm1()
|
||||
# test_ndarray_cbrt()
|
||||
|
||||
test_ndarray_arctan2()
|
||||
test_ndarray_arctan2_broadcast()
|
||||
test_ndarray_arctan2_broadcast_lhs_scalar()
|
||||
test_ndarray_arctan2_broadcast_rhs_scalar()
|
||||
test_ndarray_copysign()
|
||||
test_ndarray_copysign_broadcast()
|
||||
test_ndarray_copysign_broadcast_lhs_scalar()
|
||||
test_ndarray_copysign_broadcast_rhs_scalar()
|
||||
test_ndarray_fmax()
|
||||
test_ndarray_fmax_broadcast()
|
||||
test_ndarray_fmax_broadcast_lhs_scalar()
|
||||
test_ndarray_fmax_broadcast_rhs_scalar()
|
||||
test_ndarray_fmin()
|
||||
test_ndarray_fmin_broadcast()
|
||||
test_ndarray_fmin_broadcast_lhs_scalar()
|
||||
test_ndarray_fmin_broadcast_rhs_scalar()
|
||||
test_ndarray_ldexp()
|
||||
test_ndarray_ldexp_broadcast()
|
||||
test_ndarray_ldexp_broadcast_lhs_scalar()
|
||||
test_ndarray_ldexp_broadcast_rhs_scalar()
|
||||
test_ndarray_hypot()
|
||||
test_ndarray_hypot_broadcast()
|
||||
test_ndarray_hypot_broadcast_lhs_scalar()
|
||||
test_ndarray_hypot_broadcast_rhs_scalar()
|
||||
test_ndarray_nextafter()
|
||||
test_ndarray_nextafter_broadcast()
|
||||
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||
# test_ndarray_erf()
|
||||
# test_ndarray_erfc()
|
||||
# test_ndarray_gamma()
|
||||
# test_ndarray_gammaln()
|
||||
# test_ndarray_j0()
|
||||
# test_ndarray_j1()
|
||||
|
||||
test_try_invert()
|
||||
test_wilkinson_shift()
|
||||
# test_ndarray_arctan2()
|
||||
# test_ndarray_arctan2_broadcast()
|
||||
# test_ndarray_arctan2_broadcast_lhs_scalar()
|
||||
# test_ndarray_arctan2_broadcast_rhs_scalar()
|
||||
# test_ndarray_copysign()
|
||||
# test_ndarray_copysign_broadcast()
|
||||
# test_ndarray_copysign_broadcast_lhs_scalar()
|
||||
# test_ndarray_copysign_broadcast_rhs_scalar()
|
||||
# test_ndarray_fmax()
|
||||
# test_ndarray_fmax_broadcast()
|
||||
# test_ndarray_fmax_broadcast_lhs_scalar()
|
||||
# test_ndarray_fmax_broadcast_rhs_scalar()
|
||||
# test_ndarray_fmin()
|
||||
# test_ndarray_fmin_broadcast()
|
||||
# test_ndarray_fmin_broadcast_lhs_scalar()
|
||||
# test_ndarray_fmin_broadcast_rhs_scalar()
|
||||
# test_ndarray_ldexp()
|
||||
# test_ndarray_ldexp_broadcast()
|
||||
# test_ndarray_ldexp_broadcast_lhs_scalar()
|
||||
# test_ndarray_ldexp_broadcast_rhs_scalar()
|
||||
# test_ndarray_hypot()
|
||||
# test_ndarray_hypot_broadcast()
|
||||
# test_ndarray_hypot_broadcast_lhs_scalar()
|
||||
# test_ndarray_hypot_broadcast_rhs_scalar()
|
||||
# test_ndarray_nextafter()
|
||||
# test_ndarray_nextafter_broadcast()
|
||||
# test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||
# test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||
|
||||
# test_try_invert()
|
||||
# test_wilkinson_shift()
|
||||
|
||||
return 0
|
||||
|
@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "externfns"
|
||||
name = "linalg_externfns"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
@ -8,3 +8,4 @@ crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
|
||||
cslice = "0.3.0"
|
407
nac3standalone/linalg_externfns/src/lib.rs
Normal file
407
nac3standalone/linalg_externfns/src/lib.rs
Normal file
@ -0,0 +1,407 @@
|
||||
mod runtime_exception;
|
||||
use core::slice;
|
||||
use nalgebra::{linalg, DMatrix};
|
||||
|
||||
pub struct InputMatrix {
|
||||
pub ndims: usize,
|
||||
pub dims: *const usize,
|
||||
pub data: *mut f64,
|
||||
}
|
||||
impl InputMatrix {
|
||||
fn get_dims(&mut self) -> Vec<usize> {
|
||||
let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) };
|
||||
dims.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! raise_exn {
|
||||
($name:expr, $fn_name:expr, $message:expr, $param0:expr, $param1:expr, $param2:expr) => {{
|
||||
use cslice::AsCSlice;
|
||||
let name_id = $crate::runtime_exception::get_exception_id($name);
|
||||
let exn = $crate::runtime_exception::Exception {
|
||||
id: name_id,
|
||||
file: file!().as_c_slice(),
|
||||
line: line!(),
|
||||
column: column!(),
|
||||
// https://github.com/rust-lang/rfcs/pull/1719
|
||||
function: $fn_name.as_c_slice(),
|
||||
message: $message.as_c_slice(),
|
||||
param: [$param0, $param1, $param2],
|
||||
};
|
||||
#[allow(unused_unsafe)]
|
||||
unsafe {
|
||||
$crate::runtime_exception::raise(&exn)
|
||||
}
|
||||
}};
|
||||
($name:expr, $fn_name:expr, $message:expr) => {{
|
||||
raise_exn!($name, $fn_name, $message, 0, 0, 0)
|
||||
}};
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let mat2 = mat2.as_mut().unwrap();
|
||||
|
||||
if !(mat1.ndims == 1 && mat2.ndims == 1) {
|
||||
raise_exn!(
|
||||
"ValueError",
|
||||
"np_dot",
|
||||
"expected 1D Vector Input, but received {0}-D and {1}-D input",
|
||||
mat1.ndims.try_into().unwrap(),
|
||||
mat2.ndims.try_into().unwrap(),
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let dim2 = (*mat2).get_dims();
|
||||
|
||||
if dim1[0] != dim2[0] {
|
||||
raise_exn!(
|
||||
"ValueError",
|
||||
"np_dot",
|
||||
"shapes ({},) and ({},) not aligned",
|
||||
dim1[0].try_into().unwrap(),
|
||||
dim2[0].try_into().unwrap(),
|
||||
0
|
||||
);
|
||||
}
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
|
||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
|
||||
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
|
||||
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
|
||||
|
||||
matrix1.dot(&matrix2)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_matmul(
|
||||
mat1: *mut InputMatrix,
|
||||
mat2: *mut InputMatrix,
|
||||
out: *mut InputMatrix,
|
||||
) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let mat2 = mat2.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if !(mat1.ndims == 2 && mat2.ndims == 2) {
|
||||
raise_exn!(
|
||||
"ValueError",
|
||||
"np_matmul",
|
||||
"expected 2D Vector Input, but received {0}-D and {1}-D input",
|
||||
mat1.ndims.try_into().unwrap(),
|
||||
mat2.ndims.try_into().unwrap(),
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let dim2 = (*mat2).get_dims();
|
||||
|
||||
if dim1[1] != dim2[0] {
|
||||
let err_msg = format!(
|
||||
"shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)",
|
||||
dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0]
|
||||
);
|
||||
raise_exn!("ValueError", "np_matmul", err_msg);
|
||||
}
|
||||
|
||||
let outdim = out.get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) };
|
||||
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2);
|
||||
let mut result = DMatrix::<f64>::zeros(outdim[0], outdim[1]);
|
||||
|
||||
matrix1.mul_to(&matrix2, &mut result);
|
||||
out_slice.copy_from_slice(result.transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let out = out.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
raise_exn!(
|
||||
"ValueError",
|
||||
"np_linalg_cholesky",
|
||||
"expected 2D Vector Input, but received {0}-D input",
|
||||
mat1.ndims.try_into().unwrap(),
|
||||
0,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
if dim1[0] != dim1[1] {
|
||||
raise_exn!(
|
||||
"LinAlgError",
|
||||
"np_linalg_cholesky",
|
||||
"Last 2 dimensions of the array must be square: {0} != {1}",
|
||||
dim1[0].try_into().unwrap(),
|
||||
dim1[1].try_into().unwrap(),
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let outdim = out.get_dims();
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let result = matrix1.cholesky();
|
||||
match result {
|
||||
Some(res) => {
|
||||
out_slice.copy_from_slice(res.unpack().transpose().as_slice());
|
||||
}
|
||||
None => {
|
||||
raise_exn!("LinAlgError", "np_linalg_cholesky", "Matrix is not positive definite");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_qr(
|
||||
mat1: *mut InputMatrix,
|
||||
outq: *mut InputMatrix,
|
||||
outr: *mut InputMatrix,
|
||||
) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let outq = outq.as_mut().unwrap();
|
||||
let outr = outr.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
raise_exn!(
|
||||
"ValueError",
|
||||
"np_linalg_cholesky",
|
||||
"expected 2D Vector Input, but received {0}-D input",
|
||||
mat1.ndims.try_into().unwrap(),
|
||||
0,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let outq_dim = (*outq).get_dims();
|
||||
let outr_dim = (*outr).get_dims();
|
||||
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let out_q_slice = unsafe { slice::from_raw_parts_mut(outq.data, outq_dim[0] * outq_dim[1]) };
|
||||
let out_r_slice = unsafe { slice::from_raw_parts_mut(outr.data, outr_dim[0] * outr_dim[1]) };
|
||||
|
||||
// Refer to https://github.com/dimforge/nalgebra/issues/735
|
||||
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
|
||||
let res = matrix1.qr();
|
||||
let (q, r) = res.unpack();
|
||||
|
||||
// Uses different algo need to match numpy
|
||||
out_q_slice.copy_from_slice(q.transpose().as_slice());
|
||||
out_r_slice.copy_from_slice(r.transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_svd(
|
||||
mat1: *mut InputMatrix,
|
||||
outu: *mut InputMatrix,
|
||||
outs: *mut InputMatrix,
|
||||
outvh: *mut InputMatrix,
|
||||
) {
|
||||
let mat1 = mat1.as_mut().unwrap();
|
||||
let outu = outu.as_mut().unwrap();
|
||||
let outs = outs.as_mut().unwrap();
|
||||
let outvh = outvh.as_mut().unwrap();
|
||||
|
||||
if mat1.ndims != 2 {
|
||||
raise_exn!(
|
||||
"ValueError",
|
||||
"np_linalg_svd",
|
||||
"expected 2D Vector Input, but received {0}-D input",
|
||||
mat1.ndims.try_into().unwrap(),
|
||||
0,
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
let dim1 = (*mat1).get_dims();
|
||||
let outu_dim = (*outu).get_dims();
|
||||
let outs_dim = (*outs).get_dims();
|
||||
let outvh_dim = (*outvh).get_dims();
|
||||
|
||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
|
||||
let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) };
|
||||
let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) };
|
||||
let out_vh_slice =
|
||||
unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
|
||||
let result = matrix.svd(true, true);
|
||||
out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice());
|
||||
out_s_slice.copy_from_slice(result.singular_values.as_slice());
|
||||
out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice());
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `data` must point to an array of 4 elements in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_inv(
|
||||
dim1_0: usize,
|
||||
dim1_1: usize,
|
||||
x1: *mut f64,
|
||||
dim2_0: usize,
|
||||
dim2_1: usize,
|
||||
out: *mut f64,
|
||||
) -> i8 {
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
||||
if !matrix.is_invertible() {
|
||||
// raise error
|
||||
return 0;
|
||||
}
|
||||
let inv = matrix.try_inverse().unwrap();
|
||||
|
||||
out_slice.copy_from_slice(inv.transpose().as_slice());
|
||||
1
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `data` must point to an array of 4 elements in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn np_linalg_pinv(
|
||||
dim1_0: usize,
|
||||
dim1_1: usize,
|
||||
x1: *mut f64,
|
||||
dim2_0: usize,
|
||||
dim2_1: usize,
|
||||
out: *mut f64,
|
||||
) -> i8 {
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
||||
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
||||
let svd = matrix.svd(true, true);
|
||||
let inv = svd.pseudo_inverse(1e-15);
|
||||
|
||||
match inv {
|
||||
Ok(m) => {
|
||||
out_slice.copy_from_slice(m.transpose().as_slice());
|
||||
1
|
||||
}
|
||||
Err(_e) => {
|
||||
// raise exception here
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `data` must point to an array of 4 elements in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sp_linalg_lu(
|
||||
dim1_0: usize,
|
||||
dim1_1: usize,
|
||||
x1: *mut f64,
|
||||
dim2_0: usize,
|
||||
dim2_1: usize,
|
||||
out_l: *mut f64,
|
||||
dim3_0: usize,
|
||||
dim3_1: usize,
|
||||
out_u: *mut f64,
|
||||
) -> i8 {
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
||||
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l, dim2_0 * dim2_1) };
|
||||
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim3_0 * dim3_1) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
||||
let (_, l, u) = matrix.lu().unpack();
|
||||
|
||||
out_l_slice.copy_from_slice(l.transpose().as_slice());
|
||||
out_u_slice.copy_from_slice(u.transpose().as_slice());
|
||||
|
||||
1
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `data` must point to an array of 4 elements in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sp_linalg_schur(
|
||||
dim1_0: usize,
|
||||
dim1_1: usize,
|
||||
x1: *mut f64,
|
||||
dim2_0: usize,
|
||||
dim2_1: usize,
|
||||
out_t: *mut f64,
|
||||
dim3_0: usize,
|
||||
dim3_1: usize,
|
||||
out_z: *mut f64,
|
||||
) -> i8 {
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
||||
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t, dim2_0 * dim2_1) };
|
||||
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z, dim3_0 * dim3_1) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
||||
if !matrix.is_square() {
|
||||
// Throw error here
|
||||
return 0;
|
||||
}
|
||||
let (z, t) = matrix.schur().unpack();
|
||||
|
||||
out_t_slice.copy_from_slice(t.transpose().as_slice());
|
||||
out_z_slice.copy_from_slice(z.transpose().as_slice());
|
||||
|
||||
1
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `data` must point to an array of 4 elements in row-major order
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sp_linalg_hessenberg(
|
||||
dim1_0: usize,
|
||||
dim1_1: usize,
|
||||
x1: *mut f64,
|
||||
dim2_0: usize,
|
||||
dim2_1: usize,
|
||||
out_h: *mut f64,
|
||||
) -> i8 {
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
||||
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h, dim2_0 * dim2_1) };
|
||||
|
||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
||||
if !matrix.is_square() {
|
||||
// Throw error here
|
||||
|
||||
return 0;
|
||||
}
|
||||
let (_, h) = matrix.hessenberg().unpack();
|
||||
|
||||
out_h_slice.copy_from_slice(h.transpose().as_slice());
|
||||
|
||||
1
|
||||
}
|
66
nac3standalone/linalg_externfns/src/runtime_exception.rs
Normal file
66
nac3standalone/linalg_externfns/src/runtime_exception.rs
Normal file
@ -0,0 +1,66 @@
|
||||
#![allow(non_camel_case_types)]
|
||||
#![allow(unused)]
|
||||
|
||||
// ARTIQ Exception struct declaration
|
||||
use cslice::CSlice;
|
||||
|
||||
// Note: CSlice within an exception may not be actual cslice, they may be strings that exist only
|
||||
// in the host. If the length == usize:MAX, the pointer is actually a string key in the host.
|
||||
#[repr(C)]
|
||||
#[derive(Clone)]
|
||||
pub struct Exception<'a> {
|
||||
pub id: u32,
|
||||
pub file: CSlice<'a, u8>,
|
||||
pub line: u32,
|
||||
pub column: u32,
|
||||
pub function: CSlice<'a, u8>,
|
||||
pub message: CSlice<'a, u8>,
|
||||
pub param: [i64; 3],
|
||||
}
|
||||
|
||||
fn str_err(_: core::str::Utf8Error) -> core::fmt::Error {
|
||||
core::fmt::Error
|
||||
}
|
||||
|
||||
fn exception_str<'a>(s: &'a CSlice<'a, u8>) -> Result<&'a str, core::str::Utf8Error> {
|
||||
if s.len() == usize::MAX {
|
||||
Ok("<host string>")
|
||||
} else {
|
||||
core::str::from_utf8(s.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn raise(exception: *const Exception) -> ! {
|
||||
let e = &*exception;
|
||||
let f1 = exception_str(&e.function).map_err(str_err).unwrap();
|
||||
let f2 = exception_str(&e.file).map_err(str_err).unwrap();
|
||||
let f3 = exception_str(&e.message).map_err(str_err).unwrap();
|
||||
|
||||
panic!("Exception {} from {} in {}:{}:{}, message: {}", e.id, f1, f2, e.line, e.column, f3);
|
||||
}
|
||||
|
||||
static EXCEPTION_ID_LOOKUP: [(&str, u32); 14] = [
|
||||
("RuntimeError", 0),
|
||||
("RTIOUnderflow", 1),
|
||||
("RTIOOverflow", 2),
|
||||
("RTIODestinationUnreachable", 3),
|
||||
("DMAError", 4),
|
||||
("I2CError", 5),
|
||||
("CacheError", 6),
|
||||
("SPIError", 7),
|
||||
("ZeroDivisionError", 8),
|
||||
("IndexError", 9),
|
||||
("UnwrapNoneError", 10),
|
||||
("Value", 11),
|
||||
("ValueError", 12),
|
||||
("LinAlgError", 13),
|
||||
];
|
||||
|
||||
pub fn get_exception_id(name: &str) -> u32 {
|
||||
for (n, id) in EXCEPTION_ID_LOOKUP.iter() {
|
||||
if *n == name {
|
||||
return *id;
|
||||
}
|
||||
}
|
||||
unimplemented!("unallocated internal exception id")
|
||||
}
|
BIN
pyo3_output/nac3artiq.so
Executable file
BIN
pyo3_output/nac3artiq.so
Executable file
Binary file not shown.
Loading…
Reference in New Issue
Block a user