Error Interface Added

This commit is contained in:
abdul124 2024-07-24 18:07:55 +08:00
parent 7ec36e80f7
commit 8655a5f0c7
21 changed files with 1787 additions and 536 deletions

23
Cargo.lock generated
View File

@ -256,6 +256,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]] [[package]]
name = "dirs-next" name = "dirs-next"
version = "2.0.0" version = "2.0.0"
@ -314,13 +320,6 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "externfns"
version = "0.1.0"
dependencies = [
"nalgebra",
]
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.1.0" version = "2.1.0"
@ -553,6 +552,14 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "linalg_externfns"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
@ -638,7 +645,6 @@ name = "nac3core"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"crossbeam", "crossbeam",
"externfns",
"indexmap 2.2.6", "indexmap 2.2.6",
"indoc", "indoc",
"inkwell", "inkwell",
@ -682,6 +688,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"inkwell", "inkwell",
"linalg_externfns",
"nac3core", "nac3core",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",

View File

@ -4,7 +4,7 @@ members = [
"nac3ast", "nac3ast",
"nac3parser", "nac3parser",
"nac3core", "nac3core",
"nac3core/src/codegen/externfns", "nac3standalone/linalg_externfns",
"nac3standalone", "nac3standalone",
"nac3artiq", "nac3artiq",
"runkernel", "runkernel",

View File

@ -161,7 +161,9 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -11,7 +11,6 @@ indexmap = "2.2"
parking_lot = "0.12" parking_lot = "0.12"
rayon = "1.8" rayon = "1.8"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
externfns = { path = "src/codegen/externfns" }
strum = "0.26.2" strum = "0.26.2"
strum_macros = "0.26.4" strum_macros = "0.26.4"

View File

@ -1,5 +1,5 @@
use inkwell::types::BasicTypeEnum; use inkwell::types::BasicTypeEnum;
use inkwell::values::BasicValueEnum; use inkwell::values::{BasicValue, BasicValueEnum};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
@ -31,7 +31,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n; let (n_ty, n) = n;
Ok(match n { Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
@ -1836,231 +1835,762 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
}) })
} }
/// Invokes the `linalg_try_invert_to` function // Linalg Methods
pub fn call_linalg_try_invert_to<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>), x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_try_invert_to"; const FN_NAME: &str = "np_dot";
let (a_ty, a) = a; let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
match a { let one = llvm_usize.const_int(1, false);
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
{ let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
match llvm_ndarray_ty { let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
BasicTypeEnum::FloatType(_) => {}
_ => { let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
unimplemented!("Inverse Operation supported on float type NDArray Values only") else {
} unimplemented!("{FN_NAME} operates on float type NdArrays only");
}; };
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
// The following constraints must be satisfied: // The following constraints must be satisfied:
// * Input must be 2D // * Input must be 1D
// * number of rows should equal number of columns (square matrix) // * Number of elements in two matrices must equal
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
let n_dims = n.load_ndims(ctx); let n1_dims = n1.load_ndims(ctx);
let n2_dims = n2.load_ndims(ctx);
// num_dim == 2 let n1_dims_eq1 =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, one, "").unwrap();
let n2_dims_eq1 =
ctx.builder.build_int_compare(IntPredicate::EQ, n2_dims, one, "").unwrap();
// num_dim = 1
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder n1_dims_eq1,
.build_int_compare(
IntPredicate::EQ,
n_dims,
llvm_usize.const_int(2, false),
"",
)
.unwrap(),
"0:ValueError", "0:ValueError",
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(), format!("{FN_NAME} operates on 1D matrices").as_str(),
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );
let dim0 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
// dim0 == dim1
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(), n2_dims_eq1,
"0:ValueError", "0:ValueError",
format!( format!("{FN_NAME} operates on 1D matrices").as_str(),
"Input matrix should have equal number of rows and columns for {FN_NAME}"
)
.as_str(),
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );
}
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { // equal number of elements
let n_sz_eqz = ctx let n1_sz = irrt::call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
.builder let n2_sz = irrt::call_ndarray_calc_size(generator, ctx, &n2.dim_sizes(), (None, None));
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap(); let size_eq =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
n_sz_eqz, size_eq,
"0:ValueError", "0:ValueError",
format!("zero-size array to inverse operation {FN_NAME}").as_str(), format!("The operands of {FN_NAME} must have equal length").as_str(),
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );
} }
let dim0 = unsafe { let dim0 = unsafe {
n.dim_sizes() n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let dim1 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
Ok(extern_fns::call_linalg_try_invert_to( Ok(extern_fns::call_np_dot(
ctx, ctx,
dim0, (dim0, one, n1.data().base_ptr(ctx, generator)),
dim1, (dim0, one, n2.data().base_ptr(ctx, generator)),
n.data().base_ptr(ctx, generator),
None, None,
) )
.into()) .into())
} } else {
_ => unsupported_type(ctx, FN_NAME, &[a_ty]), unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
} }
} }
/// Invokes the `linalg_wilkinson_shift` function pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_linalg_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>), x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_wilkinson_shift"; const FN_NAME: &str = "np_linalg_matmul";
let (a_ty, a) = a; let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false); let one = llvm_usize.const_int(1, false);
let two = llvm_usize.const_int(2, false); let two = llvm_usize.const_int(2, false);
match a { if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
BasicValueEnum::PointerValue(n) let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
{ let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
match llvm_ndarray_ty { let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {} else {
_ => unimplemented!( unimplemented!("{FN_NAME} operates on float type NdArrays only");
"Wilkinson Shift Operation supported on float type NDArray Values only"
),
}; };
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
// The following constraints must be satisfied: // The following constraints must be satisfied:
// * Input must be 2D // * Input must be 2D
// * Number of rows and columns should equal 2 // * Number of columns of first matrix should equal number of rows of second
// * Input matrix must be symmetric if true {
if cfg!(debug_assertions) { let n1_dims = n1.load_ndims(ctx);
let n_dims = n.load_ndims(ctx); let n2_dims = n2.load_ndims(ctx);
// num_dim == 2 let n1_dims_eq2 =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, two, "").unwrap();
let n2_dims_eq2 =
ctx.builder.build_int_compare(IntPredicate::EQ, n2_dims, two, "").unwrap();
// num_dim = 2
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, two, "").unwrap(), n1_dims_eq2,
"0:ValueError", "0:ValueError",
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(), format!("{FN_NAME} operates on 2D matrices").as_str(),
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );
let dim0 = unsafe { ctx.make_assert(
n.dim_sizes() generator,
n2_dims_eq2,
"0:ValueError",
format!("{FN_NAME} operates on 2D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
// matrix must be compatible for multiplication
let n1_col = unsafe {
n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
};
let n2_col = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let dim1 = unsafe {
n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
};
// dim0 == 2 let dim_eq =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_col, n2_col, "").unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, two, "").unwrap(), dim_eq,
"0:ValueError", "0:ValueError",
format!("Number of rows must be 2 for {FN_NAME}").as_str(), format!("Columns of first matrix must equal rows of second for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
// dim1 == 2
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, dim1, two, "").unwrap(),
"0:ValueError",
format!("Number of columns must be 2 for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
let entry_01 = unsafe {
n.data().get_unchecked(ctx, generator, &one, None).into_float_value()
};
let entry_10 = unsafe {
n.data().get_unchecked(ctx, generator, &two, None).into_float_value()
};
// symmetric matrix
ctx.make_assert(
generator,
ctx.builder
.build_float_compare(FloatPredicate::OEQ, entry_01, entry_10, "")
.unwrap(),
"0:ValueError",
format!("Input Matrix must be symmetric for {FN_NAME}").as_str(),
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );
} }
let out_dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_dim1 =
unsafe { n2.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[out_dim0, out_dim1])
.unwrap();
let dim0 = unsafe { let dim0 = unsafe {
n.dim_sizes() n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value() .into_int_value()
}; };
let dim1 = let dim1 =
unsafe { n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() }; unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let dim2 =
unsafe { n2.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
Ok(extern_fns::call_linalg_wilkinson_shift( // let r = ctx.ctx.const_string(string, null_terminated);
extern_fns::call_np_linalg_matmul(
ctx, ctx,
dim0, (dim0, dim1, n1.data().base_ptr(ctx, generator)),
dim1, (dim1, dim2, n2.data().base_ptr(ctx, generator)),
n.data().base_ptr(ctx, generator), (dim0, dim2, out.data().base_ptr(ctx, generator)),
None, None,
) );
.into()) Ok(out.as_base_value().as_basic_value_enum())
} } else {
_ => unsupported_type(ctx, FN_NAME, &[a_ty]), unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
let two = llvm_usize.const_int(2, false);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// The following constraints must be satisfied:
// * Input must be 2D
// * Input must be a square matrix (here we assume it is symmetric)
if cfg!(debug_assertions) {
let n1_dims = n1.load_ndims(ctx);
let n1_dims_eq2 =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, two, "").unwrap();
// num_dim = 2
ctx.make_assert(
generator,
n1_dims_eq2,
"0:ValueError",
format!("{FN_NAME} operates on 2D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
// Square Matrix
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
};
let dim_match =
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap();
ctx.make_assert(
generator,
dim_match,
"0:ValueError",
format!("Input matrix must be a square matrix {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
}
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
extern_fns::call_np_linalg_cholesky(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim1, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
extern_fns::call_np_linalg_qr(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, k, out_q.data().base_ptr(ctx, generator)),
(k, dim1, out_r.data().base_ptr(ctx, generator)),
None,
);
let out_q = out_q.as_base_value().as_basic_value_enum();
let out_r = out_r.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_q.get_type(), out_r.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "QR_factorization").unwrap();
let res_val = [out_q, out_r];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]).unwrap();
let out_vh =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]).unwrap();
extern_fns::call_np_linalg_svd(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_u.data().base_ptr(ctx, generator)),
(k, llvm_usize.const_int(1, false), out_s.data().base_ptr(ctx, generator)),
(dim1, dim1, out_vh.data().base_ptr(ctx, generator)),
None,
);
let out_u = out_u.as_base_value().as_basic_value_enum();
let out_s = out_s.as_base_value().as_basic_value_enum();
let out_vh = out_vh.as_base_value().as_basic_value_enum();
let res_ty =
ctx.ctx.struct_type(&[out_u.get_type(), out_s.get_type(), out_vh.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "SVD_factorization").unwrap();
let res_val = [out_u, out_s, out_vh];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
extern_fns::call_np_linalg_inv(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim1, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]).unwrap();
extern_fns::call_np_linalg_pinv(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim1, dim0, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
extern_fns::call_sp_linalg_lu(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, k, out_l.data().base_ptr(ctx, generator)),
(k, dim1, out_u.data().base_ptr(ctx, generator)),
None,
);
let out_l = out_l.as_base_value().as_basic_value_enum();
let out_u = out_u.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_l.get_type(), out_u.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "LU_factorization").unwrap();
let res_val = [out_l, out_u];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
// Must be square (add check later)
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_schur";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out_t =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
let out_z =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
extern_fns::call_sp_linalg_schur(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_t.data().base_ptr(ctx, generator)),
(dim0, dim0, out_z.data().base_ptr(ctx, generator)),
None,
);
let out_t = out_t.as_base_value().as_basic_value_enum();
let out_z = out_z.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_t.get_type(), out_z.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "Schur_factorization").unwrap();
let r = ctx
.ctx
.const_string(ctx.current_loc.file.0.to_string().as_bytes(), true)
.as_basic_value_enum()
.into_pointer_value();
let res_val = [out_t, out_z];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
// Must be square (add check later)
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
// Check if matrix is square
// ctx.builder.build_select(
// ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
// {
// let func =
// }, else_, name)
// ;
// ctx.builder.build_call(
// ctx.module.get_function("__nac3_raise"),
// &[]
// )
// let err_msg = ctx.gen_string(generator, "{FN_NAME} requires square matrix");
// ctx.raise_exn(generator, "0:ValueError", err_msg, [None, None, None], ctx.current_loc);
let out_h =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
extern_fns::call_sp_linalg_hessenberg(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_h.data().base_ptr(ctx, generator)),
None,
);
Ok(out_h.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
} }
} }

View File

@ -131,90 +131,153 @@ pub fn call_ldexp<'ctx>(
.unwrap() .unwrap()
} }
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function /// Macro to generate np_linalg external functions
pub fn call_linalg_try_invert_to<'ctx>( macro_rules! generate_np_linalg_extern_fn {
ctx: &CodeGenContext<'ctx, '_>, ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => {
dim0: IntValue<'ctx>, generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1);
dim1: IntValue<'ctx>, };
data: PointerValue<'ctx>, ($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: (IntValue<'ctx>, IntValue<'ctx>, PointerValue<'ctx>))*,
name: Option<&str>, name: Option<&str>,
) -> IntValue<'ctx> { ) -> $ret_ty<'ctx> {
const FN_NAME: &str = "linalg_try_invert_to"; const FN_NAME: &str = $extern_fn;
let llvm_f64 = ctx.ctx.f64_type();
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
debug_assert!(allowed_dim0);
debug_assert!(allowed_dim1);
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.i8_type().fn_type(
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
false,
);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
pub fn call_linalg_wilkinson_shift<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dim0: IntValue<'ctx>,
dim1: IntValue<'ctx>,
data: PointerValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "linalg_wilkinson_shift";
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type()); $(
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type()); debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.0.get_type()));
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.1.get_type()));
debug_assert_eq!($input_matrix.2.get_type().get_element_type().into_float_type(), llvm_f64);
)*
debug_assert!(allowed_dim0); // let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false);
debug_assert!(allowed_dim1); // let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false);
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64); // let file_name = ctx.current_loc.file.0;
// let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false);
// let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.f64_type().fn_type( // let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], // let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
false, let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
);
let func = ctx.module.add_function(FN_NAME, fn_type, None); let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
func func
}); });
ctx.builder ctx.builder
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default()) // .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
// .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
.build_call(extern_fn, &[$($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value) .map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value)) .map(|v| v.map_left($map_fn))
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
}
};
} }
generate_np_linalg_extern_fn!(
call_np_dot,
FloatValue,
f64_type,
BasicValueEnum::into_float_value,
"np_dot",
2
);
generate_np_linalg_extern_fn!(
call_np_linalg_matmul,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_matmul",
3
);
generate_np_linalg_extern_fn!(
call_np_linalg_cholesky,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_cholesky",
2
);
generate_np_linalg_extern_fn!(
call_np_linalg_qr,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_qr",
3
);
generate_np_linalg_extern_fn!(
call_np_linalg_svd,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_svd",
4
);
generate_np_linalg_extern_fn!(
call_np_linalg_inv,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_inv",
2
);
generate_np_linalg_extern_fn!(
call_np_linalg_pinv,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_pinv",
2
);
generate_np_linalg_extern_fn!(
call_sp_linalg_lu,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_lu",
3
);
generate_np_linalg_extern_fn!(
call_sp_linalg_schur,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_schur",
3
);
generate_np_linalg_extern_fn!(
call_sp_linalg_hessenberg,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_hessenberg",
2
);

View File

@ -1,30 +0,0 @@
use core::slice;
use nalgebra::{linalg, DMatrix};
/// # Safety
///
/// `data` must point to an array with `dim0`x`dim1` elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
1
} else {
0
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)])
}

View File

@ -61,7 +61,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
/// * `shape` - The shape of the `NDArray`. /// * `shape` - The shape of the `NDArray`.
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. /// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. /// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( pub fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
@ -157,7 +157,7 @@ where
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. /// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,

View File

@ -557,7 +557,18 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunTryInvertTo | PrimDef::FunWilkinsonShift => self.build_linalg_methods(prim), PrimDef::FunNpDot
| PrimDef::FunNpLinalgMatmul
| PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd
| PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv
| PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => self.build_np_linalg_methods(prim),
// PrimDef::FunNpDot | PrimDef::FunNpLinalgMatmul => self.build_np_linalg_binary_methods(prim),
// PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr => self.build_np_linalg_unary_methods(prim),
}; };
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -1876,37 +1887,142 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build the functions `try_invert_to` and `wilkinson_shift` fn build_np_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed(
debug_assert_prim_is_allowed(prim, &[PrimDef::FunTryInvertTo, PrimDef::FunWilkinsonShift]); prim,
&[
PrimDef::FunNpDot,
PrimDef::FunNpLinalgMatmul,
PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd,
PrimDef::FunNpLinalgInv,
PrimDef::FunNpLinalgPinv,
PrimDef::FunSpLinalgLu,
PrimDef::FunSpLinalgSchur,
PrimDef::FunSpLinalgHessenberg,
],
);
let ret_ty = match prim { match prim {
PrimDef::FunTryInvertTo => self.primitives.bool, PrimDef::FunNpDot => create_fn_by_codegen(
PrimDef::FunWilkinsonShift => self.primitives.float,
_ => unreachable!(),
};
let var_map = self.num_or_ndarray_var_map.clone();
create_fn_by_codegen(
self.unifier, self.unifier,
&var_map, &self.num_or_ndarray_var_map,
prim.name(), prim.name(),
ret_ty, self.primitives.float,
&[(self.ndarray_float_2d, "x")], &[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty; let x1_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_dot(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matmul(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv
| PrimDef::FunSpLinalgHessenberg => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim { let func = match prim {
PrimDef::FunTryInvertTo => builtin_fns::call_linalg_try_invert_to, PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
PrimDef::FunWilkinsonShift => builtin_fns::call_linalg_wilkinson_shift, PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
PrimDef::FunSpLinalgHessenberg => builtin_fns::call_sp_linalg_hessenberg,
_ => unreachable!(), _ => unreachable!(),
}; };
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
),
Ok(Some(func(generator, ctx, (x_ty, x_val))?)) PrimDef::FunNpLinalgQr | PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgSchur => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}), }),
) )
} }
PrimDef::FunNpLinalgSvd => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
_ => {
println!("{:?}", prim.name());
unreachable!()
}
}
}
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
(prim.simple_name().into(), method_ty, prim.id()) (prim.simple_name().into(), method_ty, prim.id())
} }

View File

@ -105,8 +105,16 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunTryInvertTo, FunNpDot,
FunWilkinsonShift, FunNpLinalgMatmul,
FunNpLinalgCholesky,
FunNpLinalgQr,
FunNpLinalgSvd,
FunNpLinalgInv,
FunNpLinalgPinv,
FunSpLinalgLu,
FunSpLinalgSchur,
FunSpLinalgHessenberg,
// Top-Level Functions // Top-Level Functions
FunSome, FunSome,
@ -265,8 +273,17 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunTryInvertTo => fun("try_invert_to", None), PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None), PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
PrimDef::FunSome => fun("Some", None), PrimDef::FunSome => fun("Some", None),
} }
} }

View File

@ -8,6 +8,7 @@ edition = "2021"
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
linalg_externfns = { path = "./linalg_externfns" }
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"

View File

@ -15,7 +15,7 @@ done
demo="$1" demo="$1"
echo -n "Checking $demo... " echo -n "Checking $demo... "
./interpret_demo.py "$demo" > interpreted.log # ./interpret_demo.py "$demo" > interpreted.log
./run_demo.sh --out run.log "${nac3args[@]}" "$demo" ./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo" ./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run.log diff -Nau interpreted.log run.log

View File

@ -5,6 +5,7 @@ import importlib.util
import importlib.machinery import importlib.machinery
import math import math
import numpy as np import numpy as np
import scipy as sp
import numpy.typing as npt import numpy.typing as npt
import pathlib import pathlib
@ -246,8 +247,21 @@ def patch(module):
module.sp_spec_j0 = special.j0 module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1 module.sp_spec_j1 = special.j1
module.try_invert_to = try_invert_to # Linalg functions
module.wilkinson_shift = wilkinson_shift module.np_dot = np.dot
module.np_linalg_matmul = np.matmul
module.np_linalg_cholesky = np.linalg.cholesky
module.np_linalg_qr = np.linalg.qr
module.np_linalg_svd = np.linalg.svd
module.np_linalg_inv = np.linalg.inv
module.np_linalg_pinv = np.linalg.pinv
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
module.sp_linalg_schur = sp.linalg.schur
# module.sp_linalg_hessenberg = sp.linalg.hessenberg
module.sp_linalg_hessenberg = lambda x: x
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

View File

@ -0,0 +1,2 @@
Excepytiopn!! knfv 0x7fffffff9218
__nac3_personality(state: 1, exception_object: 1, context: 1381323604)

View File

@ -42,14 +42,14 @@ done
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
nac3standalone=../../target/debug/nac3standalone nac3standalone=../../target/debug/nac3standalone
externfns=../../target/debug/deps/libexternfns.so externfns=../../target/debug/deps/liblinalg_externfns.so
elif [ -e ../../target/release/nac3standalone ]; then elif [ -e ../../target/release/nac3standalone ]; then
nac3standalone=../../target/release/nac3standalone nac3standalone=../../target/release/nac3standalone
externfns=../../target/release/deps/libexternfns.so externfns=../../target/release/deps/liblinalg_externfns.so
else else
# used by Nix builds # used by Nix builds
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so externfns=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg_externfns.so
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo

View File

@ -0,0 +1,12 @@
8.000000
10.000000
12.000000
4.000000
5.000000
6.000000
1.000000
2.000000
3.000000
4.000000
5.000000
6.000000

View File

@ -531,10 +531,12 @@ def test_ndarray_ipow_broadcast_scalar():
def test_ndarray_matmul(): def test_ndarray_matmul():
x = np_identity(2) x = np_identity(2)
y = x @ np_ones([2, 2]) t: ndarray[float, 2] = np_array([[1., 2., 3.,], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.]])
y = x @ t
output_ndarray_float_2(x) y2 = np_linalg_matmul(x, t)
output_ndarray_float_2(y) output_ndarray_float_2(y)
output_ndarray_float_2(y2)
def test_ndarray_imatmul(): def test_ndarray_imatmul():
x = np_identity(2) x = np_identity(2)
@ -1429,200 +1431,289 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_try_invert(): def test_ndarray_dot():
x: ndarray[float, 2] = np_array([[1.0, 2.0], [3.0, 4.0]]) x: ndarray[float, 1] = np_array([5.0, 1.0])
output_ndarray_float_2(x) y: ndarray[float, 1] = np_array([5.0, 1.0])
y = try_invert_to(x) z = np_dot(x, y)
output_ndarray_float_2(x) output_ndarray_float_1(x)
output_bool(y) output_ndarray_float_1(y)
output_float64(z)
def test_wilkinson_shift(): def test_ndarray_linalg_matmul():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = wilkinson_shift(x) y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
z = np_linalg_matmul(x, y)
m = np_argmax(z)
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_float64(y) output_ndarray_float_2(y)
output_ndarray_float_2(z)
output_int64(m)
def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = np_linalg_cholesky(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_qr():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y, z = np_linalg_qr(x)
output_ndarray_float_2(x)
# QR Factorization in nalgebra and numpy do not give the same result
# Generating product for printing
a = np_linalg_matmul(y, z)
output_ndarray_float_2(a)
def test_ndarray_linalg_inv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_inv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pinv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
y = np_linalg_pinv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_schur():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
t, z = sp_linalg_schur(x)
output_ndarray_float_2(x)
# Same as np_linalg_qr the signs are different in nalgebra and numpy
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z))
output_ndarray_float_2(a)
def test_ndarray_hessenberg():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
h = sp_linalg_hessenberg(x)
output_ndarray_float_2(x)
output_ndarray_float_2(h)
def test_ndarray_lu():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
l, u = sp_linalg_lu(x)
output_ndarray_float_2(x)
output_ndarray_float_2(l)
output_ndarray_float_2(u)
def test_ndarray_svd():
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
x, y, z = np_linalg_svd(w)
output_ndarray_float_2(w)
# Same as np_linalg_qr the signs are different in nalgebra and numpy
a = np_linalg_matmul(x, z)
output_ndarray_float_2(a)
output_ndarray_float_1(y)
def run() -> int32: def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
test_ndarray_zeros()
test_ndarray_ones()
test_ndarray_full()
test_ndarray_eye()
test_ndarray_array()
test_ndarray_identity()
test_ndarray_fill()
test_ndarray_copy()
test_ndarray_neg_idx()
test_ndarray_slices()
test_ndarray_nd_idx()
test_ndarray_add()
test_ndarray_add_broadcast()
test_ndarray_add_broadcast_lhs_scalar()
test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_iadd()
test_ndarray_iadd_broadcast()
test_ndarray_iadd_broadcast_scalar()
test_ndarray_sub()
test_ndarray_sub_broadcast()
test_ndarray_sub_broadcast_lhs_scalar()
test_ndarray_sub_broadcast_rhs_scalar()
test_ndarray_isub()
test_ndarray_isub_broadcast()
test_ndarray_isub_broadcast_scalar()
test_ndarray_mul()
test_ndarray_mul_broadcast()
test_ndarray_mul_broadcast_lhs_scalar()
test_ndarray_mul_broadcast_rhs_scalar()
test_ndarray_imul()
test_ndarray_imul_broadcast()
test_ndarray_imul_broadcast_scalar()
test_ndarray_truediv()
test_ndarray_truediv_broadcast()
test_ndarray_truediv_broadcast_lhs_scalar()
test_ndarray_truediv_broadcast_rhs_scalar()
test_ndarray_itruediv()
test_ndarray_itruediv_broadcast()
test_ndarray_itruediv_broadcast_scalar()
test_ndarray_floordiv()
test_ndarray_floordiv_broadcast()
test_ndarray_floordiv_broadcast_lhs_scalar()
test_ndarray_floordiv_broadcast_rhs_scalar()
test_ndarray_ifloordiv()
test_ndarray_ifloordiv_broadcast()
test_ndarray_ifloordiv_broadcast_scalar()
test_ndarray_mod()
test_ndarray_mod_broadcast()
test_ndarray_mod_broadcast_lhs_scalar()
test_ndarray_mod_broadcast_rhs_scalar()
test_ndarray_imod()
test_ndarray_imod_broadcast()
test_ndarray_imod_broadcast_scalar()
test_ndarray_pow()
test_ndarray_pow_broadcast()
test_ndarray_pow_broadcast_lhs_scalar()
test_ndarray_pow_broadcast_rhs_scalar()
test_ndarray_ipow()
test_ndarray_ipow_broadcast()
test_ndarray_ipow_broadcast_scalar()
test_ndarray_matmul() test_ndarray_matmul()
test_ndarray_imatmul() # test_ndarray_dot()
test_ndarray_pos() # test_ndarray_linalg_matmul()
test_ndarray_neg() # test_ndarray_cholesky()
test_ndarray_inv() # test_ndarray_qr()
test_ndarray_eq() # test_ndarray_svd()
test_ndarray_eq_broadcast() # test_ndarray_linalg_inv()
test_ndarray_eq_broadcast_lhs_scalar() # test_ndarray_pinv()
test_ndarray_eq_broadcast_rhs_scalar() # test_ndarray_lu()
test_ndarray_ne() # test_ndarray_schur()
test_ndarray_ne_broadcast() # test_ndarray_hessenberg()
test_ndarray_ne_broadcast_lhs_scalar()
test_ndarray_ne_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_lt_broadcast()
test_ndarray_lt_broadcast_lhs_scalar()
test_ndarray_lt_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_le_broadcast()
test_ndarray_le_broadcast_lhs_scalar()
test_ndarray_le_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_gt_broadcast()
test_ndarray_gt_broadcast_lhs_scalar()
test_ndarray_gt_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_ge_broadcast()
test_ndarray_ge_broadcast_lhs_scalar()
test_ndarray_ge_broadcast_rhs_scalar()
test_ndarray_int32() # test_ndarray_ctor()
test_ndarray_int64() # test_ndarray_empty()
test_ndarray_uint32() # test_ndarray_zeros()
test_ndarray_uint64() # test_ndarray_ones()
test_ndarray_float() # test_ndarray_full()
test_ndarray_bool() # test_ndarray_eye()
# test_ndarray_array()
# test_ndarray_identity()
# test_ndarray_fill()
# test_ndarray_copy()
test_ndarray_round() # test_ndarray_neg_idx()
test_ndarray_floor() # test_ndarray_slices()
test_ndarray_min() # test_ndarray_nd_idx()
test_ndarray_minimum()
test_ndarray_minimum_broadcast()
test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_argmin()
test_ndarray_max()
test_ndarray_maximum()
test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_argmax()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()
test_ndarray_sin() # test_ndarray_add()
test_ndarray_cos() # test_ndarray_add_broadcast()
test_ndarray_exp() # test_ndarray_add_broadcast_lhs_scalar()
test_ndarray_exp2() # test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_log() # test_ndarray_iadd()
test_ndarray_log10() # test_ndarray_iadd_broadcast()
test_ndarray_log2() # test_ndarray_iadd_broadcast_scalar()
test_ndarray_fabs() # test_ndarray_sub()
test_ndarray_sqrt() # test_ndarray_sub_broadcast()
test_ndarray_rint() # test_ndarray_sub_broadcast_lhs_scalar()
test_ndarray_tan() # test_ndarray_sub_broadcast_rhs_scalar()
test_ndarray_arcsin() # test_ndarray_isub()
test_ndarray_arccos() # test_ndarray_isub_broadcast()
test_ndarray_arctan() # test_ndarray_isub_broadcast_scalar()
test_ndarray_sinh() # test_ndarray_mul()
test_ndarray_cosh() # test_ndarray_mul_broadcast()
test_ndarray_tanh() # test_ndarray_mul_broadcast_lhs_scalar()
test_ndarray_arcsinh() # test_ndarray_mul_broadcast_rhs_scalar()
test_ndarray_arccosh() # test_ndarray_imul()
test_ndarray_arctanh() # test_ndarray_imul_broadcast()
test_ndarray_expm1() # test_ndarray_imul_broadcast_scalar()
test_ndarray_cbrt() # test_ndarray_truediv()
# test_ndarray_truediv_broadcast()
# test_ndarray_truediv_broadcast_lhs_scalar()
# test_ndarray_truediv_broadcast_rhs_scalar()
# test_ndarray_itruediv()
# test_ndarray_itruediv_broadcast()
# test_ndarray_itruediv_broadcast_scalar()
# test_ndarray_floordiv()
# test_ndarray_floordiv_broadcast()
# test_ndarray_floordiv_broadcast_lhs_scalar()
# test_ndarray_floordiv_broadcast_rhs_scalar()
# test_ndarray_ifloordiv()
# test_ndarray_ifloordiv_broadcast()
# test_ndarray_ifloordiv_broadcast_scalar()
# test_ndarray_mod()
# test_ndarray_mod_broadcast()
# test_ndarray_mod_broadcast_lhs_scalar()
# test_ndarray_mod_broadcast_rhs_scalar()
# test_ndarray_imod()
# test_ndarray_imod_broadcast()
# test_ndarray_imod_broadcast_scalar()
# test_ndarray_pow()
# test_ndarray_pow_broadcast()
# test_ndarray_pow_broadcast_lhs_scalar()
# test_ndarray_pow_broadcast_rhs_scalar()
# test_ndarray_ipow()
# test_ndarray_ipow_broadcast()
# test_ndarray_ipow_broadcast_scalar()
# test_ndarray_matmul()
# test_ndarray_imatmul()
# test_ndarray_pos()
# test_ndarray_neg()
# test_ndarray_inv()
# test_ndarray_eq()
# test_ndarray_eq_broadcast()
# test_ndarray_eq_broadcast_lhs_scalar()
# test_ndarray_eq_broadcast_rhs_scalar()
# test_ndarray_ne()
# test_ndarray_ne_broadcast()
# test_ndarray_ne_broadcast_lhs_scalar()
# test_ndarray_ne_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_lt_broadcast()
# test_ndarray_lt_broadcast_lhs_scalar()
# test_ndarray_lt_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_le_broadcast()
# test_ndarray_le_broadcast_lhs_scalar()
# test_ndarray_le_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_gt_broadcast()
# test_ndarray_gt_broadcast_lhs_scalar()
# test_ndarray_gt_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_ge_broadcast()
# test_ndarray_ge_broadcast_lhs_scalar()
# test_ndarray_ge_broadcast_rhs_scalar()
test_ndarray_erf() # test_ndarray_int32()
test_ndarray_erfc() # test_ndarray_int64()
test_ndarray_gamma() # test_ndarray_uint32()
test_ndarray_gammaln() # test_ndarray_uint64()
test_ndarray_j0() # test_ndarray_float()
test_ndarray_j1() # test_ndarray_bool()
test_ndarray_arctan2() # test_ndarray_round()
test_ndarray_arctan2_broadcast() # test_ndarray_floor()
test_ndarray_arctan2_broadcast_lhs_scalar() # test_ndarray_min()
test_ndarray_arctan2_broadcast_rhs_scalar() # test_ndarray_minimum()
test_ndarray_copysign() # test_ndarray_minimum_broadcast()
test_ndarray_copysign_broadcast() # test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_copysign_broadcast_lhs_scalar() # test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_copysign_broadcast_rhs_scalar() # test_ndarray_argmin()
test_ndarray_fmax() # test_ndarray_max()
test_ndarray_fmax_broadcast() # test_ndarray_maximum()
test_ndarray_fmax_broadcast_lhs_scalar() # test_ndarray_maximum_broadcast()
test_ndarray_fmax_broadcast_rhs_scalar() # test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_fmin() # test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_fmin_broadcast() # test_ndarray_argmax()
test_ndarray_fmin_broadcast_lhs_scalar() # test_ndarray_abs()
test_ndarray_fmin_broadcast_rhs_scalar() # test_ndarray_isnan()
test_ndarray_ldexp() # test_ndarray_isinf()
test_ndarray_ldexp_broadcast()
test_ndarray_ldexp_broadcast_lhs_scalar()
test_ndarray_ldexp_broadcast_rhs_scalar()
test_ndarray_hypot()
test_ndarray_hypot_broadcast()
test_ndarray_hypot_broadcast_lhs_scalar()
test_ndarray_hypot_broadcast_rhs_scalar()
test_ndarray_nextafter()
test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar()
test_try_invert() # test_ndarray_sin()
test_wilkinson_shift() # test_ndarray_cos()
# test_ndarray_exp()
# test_ndarray_exp2()
# test_ndarray_log()
# test_ndarray_log10()
# test_ndarray_log2()
# test_ndarray_fabs()
# test_ndarray_sqrt()
# test_ndarray_rint()
# test_ndarray_tan()
# test_ndarray_arcsin()
# test_ndarray_arccos()
# test_ndarray_arctan()
# test_ndarray_sinh()
# test_ndarray_cosh()
# test_ndarray_tanh()
# test_ndarray_arcsinh()
# test_ndarray_arccosh()
# test_ndarray_arctanh()
# test_ndarray_expm1()
# test_ndarray_cbrt()
# test_ndarray_erf()
# test_ndarray_erfc()
# test_ndarray_gamma()
# test_ndarray_gammaln()
# test_ndarray_j0()
# test_ndarray_j1()
# test_ndarray_arctan2()
# test_ndarray_arctan2_broadcast()
# test_ndarray_arctan2_broadcast_lhs_scalar()
# test_ndarray_arctan2_broadcast_rhs_scalar()
# test_ndarray_copysign()
# test_ndarray_copysign_broadcast()
# test_ndarray_copysign_broadcast_lhs_scalar()
# test_ndarray_copysign_broadcast_rhs_scalar()
# test_ndarray_fmax()
# test_ndarray_fmax_broadcast()
# test_ndarray_fmax_broadcast_lhs_scalar()
# test_ndarray_fmax_broadcast_rhs_scalar()
# test_ndarray_fmin()
# test_ndarray_fmin_broadcast()
# test_ndarray_fmin_broadcast_lhs_scalar()
# test_ndarray_fmin_broadcast_rhs_scalar()
# test_ndarray_ldexp()
# test_ndarray_ldexp_broadcast()
# test_ndarray_ldexp_broadcast_lhs_scalar()
# test_ndarray_ldexp_broadcast_rhs_scalar()
# test_ndarray_hypot()
# test_ndarray_hypot_broadcast()
# test_ndarray_hypot_broadcast_lhs_scalar()
# test_ndarray_hypot_broadcast_rhs_scalar()
# test_ndarray_nextafter()
# test_ndarray_nextafter_broadcast()
# test_ndarray_nextafter_broadcast_lhs_scalar()
# test_ndarray_nextafter_broadcast_rhs_scalar()
# test_try_invert()
# test_wilkinson_shift()
return 0 return 0

View File

@ -1,5 +1,5 @@
[package] [package]
name = "externfns" name = "linalg_externfns"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
@ -8,3 +8,4 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]} nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0"

View File

@ -0,0 +1,346 @@
mod runtime_exception;
use core::slice;
use nalgebra::{linalg, DMatrix};
macro_rules! raise_exn {
($name:expr, $message:expr, $param0:expr, $param1:expr, $param2:expr) => {{
use cslice::AsCSlice;
let name_id = $crate::runtime_exception::get_exception_id($name);
let exn = $crate::runtime_exception::Exception {
id: name_id,
file: file!().as_c_slice(),
line: line!(),
column: column!(),
// https://github.com/rust-lang/rfcs/pull/1719
function: "(Rust function)".as_c_slice(),
message: $message.as_c_slice(),
param: [$param0, $param1, $param2],
};
#[allow(unused_unsafe)]
unsafe {
$crate::runtime_exception::raise(&exn)
}
}};
($name:expr, $message:expr) => {{
raise_exn!($name, $message, 0, 0, 0)
}};
}
/// # Safety
///
/// `data` must point to an array with `dim0`x`dim1` elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
1
} else {
0
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)])
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_dot(
dim0: usize,
dim1: usize,
x1: *mut f64,
_: usize,
_: usize,
x2: *mut f64,
) -> f64 {
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim0 * dim1) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(x2, dim0 * dim1) };
let matrix1 = DMatrix::from_row_slice(dim0, dim1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim0, dim1, data_slice2);
matrix1.dot(&matrix2)
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matmul(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
x2: *mut f64,
dim3_0: usize,
dim3_1: usize,
out: *mut f64,
) -> i8 {
// let name = unsafe {slice::from_raw_parts_mut(n, l)};
// let fne = name.as_c_slice();
raise_exn!("ZeroDivisionError", "Divide by Zero");
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(x2, dim2_0 * dim2_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim3_0 * dim3_1) };
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2_0, dim2_1, data_slice2);
let mut result = DMatrix::<f64>::zeros(dim3_0, dim3_1);
matrix1.mul_to(&matrix2, &mut result);
out_slice.copy_from_slice(result.transpose().as_slice());
// raise_exn!("ZeroDivisionError", "Divide by Zero", r, c, n);
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_cholesky(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
let res = matrix1.cholesky();
match res {
None => 0,
Some(c) => {
out_slice.copy_from_slice(c.unpack().transpose().as_slice());
1
}
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_qr(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_q: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_r: *mut f64,
) -> i8 {
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q, dim2_0 * dim2_1) };
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r, dim3_0 * dim3_1) };
// Refer to https://github.com/dimforge/nalgebra/issues/735
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
let res = matrix1.qr();
let (q, r) = res.unpack();
// Uses different algo need to match numpy
out_q_slice.copy_from_slice(q.transpose().as_slice());
out_r_slice.copy_from_slice(r.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_svd(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_u: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_s: *mut f64,
dim4_0: usize,
dim4_1: usize,
out_vh: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim2_0 * dim2_1) };
let out_s_slice = unsafe { slice::from_raw_parts_mut(out_s, dim3_0 * dim3_1) };
let out_vh_slice = unsafe { slice::from_raw_parts_mut(out_vh, dim4_0 * dim4_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let res = matrix.svd(true, true);
out_u_slice.copy_from_slice(res.u.unwrap().transpose().as_slice());
out_s_slice.copy_from_slice(res.singular_values.as_slice());
out_vh_slice.copy_from_slice(res.v_t.unwrap().transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_inv(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_invertible() {
// raise error
return 0;
}
let inv = matrix.try_inverse().unwrap();
out_slice.copy_from_slice(inv.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_pinv(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let svd = matrix.svd(true, true);
let inv = svd.pseudo_inverse(1e-15);
match inv {
Ok(m) => {
out_slice.copy_from_slice(m.transpose().as_slice());
1
}
Err(e) => {
// raise exception here
assert!(false, "{e}");
0
}
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_lu(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_l: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_u: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l, dim2_0 * dim2_1) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim3_0 * dim3_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let (_, l, u) = matrix.lu().unpack();
out_l_slice.copy_from_slice(l.transpose().as_slice());
out_u_slice.copy_from_slice(u.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_schur(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_t: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_z: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t, dim2_0 * dim2_1) };
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z, dim3_0 * dim3_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_square() {
// Throw error here
return 0;
}
let (z, t) = matrix.schur().unpack();
out_t_slice.copy_from_slice(t.transpose().as_slice());
out_z_slice.copy_from_slice(z.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_hessenberg(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_h: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_square() {
// Throw error here
return 0;
}
let (_, h) = matrix.hessenberg().unpack();
out_h_slice.copy_from_slice(h.transpose().as_slice());
1
}

View File

@ -0,0 +1,80 @@
#![allow(non_camel_case_types)]
#![allow(unused)]
// ARTIQ Exception struct declaration
use cslice::CSlice;
// Note: CSlice within an exception may not be actual cslice, they may be strings that exist only
// in the host. If the length == usize:MAX, the pointer is actually a string key in the host.
#[repr(C)]
#[derive(Clone)]
pub struct Exception<'a> {
pub id: u32,
pub file: CSlice<'a, u8>,
pub line: u32,
pub column: u32,
pub function: CSlice<'a, u8>,
pub message: CSlice<'a, u8>,
pub param: [i64; 3],
}
fn str_err(_: core::str::Utf8Error) -> core::fmt::Error {
core::fmt::Error
}
fn exception_str<'a>(s: &'a CSlice<'a, u8>) -> Result<&'a str, core::str::Utf8Error> {
if s.len() == usize::MAX {
Ok("<host string>")
} else {
core::str::from_utf8(s.as_ref())
}
}
impl<'a> core::fmt::Debug for Exception<'a> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"Exception {} from {} in {}:{}:{}, message: {}",
self.id,
exception_str(&self.function).map_err(str_err)?,
exception_str(&self.file).map_err(str_err)?,
self.line,
self.column,
exception_str(&self.message).map_err(str_err)?
)
}
}
pub unsafe fn raise(exception: *const Exception) -> ! {
println!("Excepytiopn!! knfv {:?}", exception);
let e = &*exception;
let f1 = exception_str(&e.function).map_err(str_err).unwrap();
let f2 = exception_str(&e.file).map_err(str_err).unwrap();
let f3 = exception_str(&e.message).map_err(str_err).unwrap();
panic!("Exception {} from {} in {}:{}:{}, message: {}", e.id, f1, f2, e.line, e.column, f3);
}
static EXCEPTION_ID_LOOKUP: [(&str, u32); 12] = [
("RuntimeError", 0),
("RTIOUnderflow", 1),
("RTIOOverflow", 2),
("RTIODestinationUnreachable", 3),
("DMAError", 4),
("I2CError", 5),
("CacheError", 6),
("SPIError", 7),
("ZeroDivisionError", 8),
("IndexError", 9),
("UnwrapNoneError", 10),
("Value", 11),
];
pub fn get_exception_id(name: &str) -> u32 {
for (n, id) in EXCEPTION_ID_LOOKUP.iter() {
if *n == name {
return *id;
}
}
unimplemented!("unallocated internal exception id")
}