Compare commits

..

1 Commits

Author SHA1 Message Date
6ad597e592 WIP 2024-07-19 18:18:08 +08:00
31 changed files with 649 additions and 2070 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
__pycache__ __pycache__
/target /target
nix/windows/msys2 nix/windows/msys2
nac3standalone/demo/externfns/target

69
Cargo.lock generated
View File

@ -126,9 +126,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.6" version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
@ -167,7 +167,7 @@ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.71", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -256,12 +256,6 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]] [[package]]
name = "dirs-next" name = "dirs-next"
version = "2.0.0" version = "2.0.0"
@ -320,6 +314,13 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "externfns"
version = "0.1.0"
dependencies = [
"nalgebra",
]
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.1.0" version = "2.1.0"
@ -436,7 +437,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.71", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -528,9 +529,9 @@ checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]] [[package]]
name = "libloading" name = "libloading"
version = "0.8.5" version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"windows-targets", "windows-targets",
@ -552,14 +553,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "linalg_externfns"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
@ -687,8 +680,8 @@ name = "nac3standalone"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"externfns",
"inkwell", "inkwell",
"linalg_externfns",
"nac3core", "nac3core",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",
@ -837,7 +830,7 @@ dependencies = [
"phf_shared 0.11.2", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.71", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -866,9 +859,9 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "1.7.0" version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
@ -938,7 +931,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.71", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -951,7 +944,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.71", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -1015,9 +1008,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.3" version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd"
dependencies = [ dependencies = [
"bitflags", "bitflags",
] ]
@ -1132,7 +1125,7 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.71", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -1234,7 +1227,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.71", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -1250,9 +1243,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.71" version = "2.0.70"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b146dcf730474b4bcd16c311627b31ede9ab149045db4d6088b3becaea046462" checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1303,22 +1296,22 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.63" version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.63" version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.71", "syn 2.0.70",
] ]
[[package]] [[package]]
@ -1592,5 +1585,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.71", "syn 2.0.70",
] ]

View File

@ -4,8 +4,8 @@ members = [
"nac3ast", "nac3ast",
"nac3parser", "nac3parser",
"nac3core", "nac3core",
"nac3standalone/linalg_externfns",
"nac3standalone", "nac3standalone",
"nac3standalone/demo/externfns",
"nac3artiq", "nac3artiq",
"runkernel", "runkernel",
] ]

View File

@ -14,12 +14,23 @@ class Demo:
@kernel @kernel
def run(self): def run(self):
self.core.reset() a = np_array([[1., 2.], [3., 4.]])
while True: b = try_invert_to(a)
with parallel: if b:
self.led0.pulse(100.*ms) # self.core.reset()
self.led1.pulse(100.*ms) # while True:
self.core.delay(100.*ms) # with parallel:
# self.led0.pulse(100.*ms)
# self.led1.pulse(100.*ms)
# self.core.delay(100.*ms)
v = try_invert_to(np_identity(2))
if v:
while True:
with parallel:
self.led0.pulse(100.*ms)
self.led1.pulse(100.*ms)
self.core.delay(100.*ms)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,24 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

BIN
nac3artiq/demo/module.elf Normal file

Binary file not shown.

View File

@ -991,15 +991,8 @@ impl InnerResolver {
} }
_ => unreachable!("must be list"), _ => unreachable!("must be list"),
}; };
let ty = ctx.get_llvm_type(generator, elem_ty);
let size_t = generator.get_size_type(ctx.ctx); let size_t = generator.get_size_type(ctx.ctx);
let ty = if len == 0
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
{
// The default type for zero-length lists of unknown element type is size_t
size_t.into()
} else {
ctx.get_llvm_type(generator, elem_ty)
};
let arr_ty = ctx let arr_ty = ctx
.ctx .ctx
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);

View File

@ -1,9 +1,9 @@
use inkwell::types::BasicTypeEnum; use inkwell::types::BasicTypeEnum;
use inkwell::values::{BasicValue, BasicValueEnum}; use inkwell::values::BasicValueEnum;
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
@ -31,6 +31,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n; let (n_ty, n) = n;
Ok(match n { Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
@ -921,6 +922,122 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
}) })
} }
pub fn call_try_invert_to<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_try_invert_to";
let (a_ty, a) = a;
let llvm_usize = generator.get_size_type(ctx.ctx);
match a {
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
match llvm_ndarray_ty {
BasicTypeEnum::FloatType(_) => {},
_ => unreachable!("Inverse Operation supported on float type NDArray Values only")
};
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
// Add asserts for dims
if cfg!(debug_assertions) {
let n_dims = n.load_ndims(ctx);
// num_dim == 2
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
"0:ValueError", format!("Inverse only supported on 2D lists").as_str(), [None, None, None], ctx.current_loc);
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
// dim0 == dim1
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
}
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
.builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
format!("zero-size array to reduction operation {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
}
// Create a call to linalg_try_invert_to
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
Ok(extern_fns::call_linalg_try_invert_to(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
}
}
pub fn call_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_wilkinson_shift";
let (a_ty, a) = a;
let llvm_usize = generator.get_size_type(ctx.ctx);
match a {
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
match llvm_ndarray_ty {
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {},
_ => unreachable!("Wilkinson Shift Operation supported on float/integer type NDArray Values only")
};
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
// Add asserts for dims
if cfg!(debug_assertions) {
let n_dims = n.load_ndims(ctx);
// num_dim == 2
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
"0:ValueError", format!("Wilkinson Shift supported only on 2D lists").as_str(), [None, None, None], ctx.current_loc);
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
// dim0 == dim1
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
// dimesions should be 2x2
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim0.get_type().const_int(2, false), "").unwrap(),
"0:ValueError", format!("Wilkinson Shift supported only on 2x2 matrices").as_str(), [None, None, None], ctx.current_loc);
}
// Create a call to linalg_try_invert_to
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
Ok(extern_fns::call_linalg_wilkinson_shift(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
}
}
/// Invokes the `np_maximum` builtin function. /// Invokes the `np_maximum` builtin function.
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
@ -1834,763 +1951,3 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
}) })
} }
// Linalg Methods
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_dot";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
// The following constraints must be satisfied:
// * Input must be 1D
// * Number of elements in two matrices must equal
if cfg!(debug_assertions) {
let n1_dims = n1.load_ndims(ctx);
let n2_dims = n2.load_ndims(ctx);
let n1_dims_eq1 =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, one, "").unwrap();
let n2_dims_eq1 =
ctx.builder.build_int_compare(IntPredicate::EQ, n2_dims, one, "").unwrap();
// num_dim = 1
ctx.make_assert(
generator,
n1_dims_eq1,
"0:ValueError",
format!("{FN_NAME} operates on 1D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
ctx.make_assert(
generator,
n2_dims_eq1,
"0:ValueError",
format!("{FN_NAME} operates on 1D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
// equal number of elements
let n1_sz = irrt::call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = irrt::call_ndarray_calc_size(generator, ctx, &n2.dim_sizes(), (None, None));
let size_eq =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap();
ctx.make_assert(
generator,
size_eq,
"0:ValueError",
format!("The operands of {FN_NAME} must have equal length").as_str(),
[None, None, None],
ctx.current_loc,
);
}
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
Ok(extern_fns::call_np_dot(
ctx,
(dim0, one, n1.data().base_ptr(ctx, generator)),
(dim0, one, n2.data().base_ptr(ctx, generator)),
None,
)
.into())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matmul";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
let two = llvm_usize.const_int(2, false);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
// The following constraints must be satisfied:
// * Input must be 2D
// * Number of columns of first matrix should equal number of rows of second
if true {
let n1_dims = n1.load_ndims(ctx);
let n2_dims = n2.load_ndims(ctx);
let n1_dims_eq2 =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, two, "").unwrap();
let n2_dims_eq2 =
ctx.builder.build_int_compare(IntPredicate::EQ, n2_dims, two, "").unwrap();
// num_dim = 2
ctx.make_assert(
generator,
n1_dims_eq2,
"0:ValueError",
format!("{FN_NAME} operates on 2D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
ctx.make_assert(
generator,
n2_dims_eq2,
"0:ValueError",
format!("{FN_NAME} operates on 2D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
// matrix must be compatible for multiplication
let n1_col = unsafe {
n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
};
let n2_col = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim_eq =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_col, n2_col, "").unwrap();
ctx.make_assert(
generator,
dim_eq,
"0:ValueError",
format!("Columns of first matrix must equal rows of second for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
}
let out_dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_dim1 =
unsafe { n2.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[out_dim0, out_dim1])
.unwrap();
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let dim2 =
unsafe { n2.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
// let r = ctx.ctx.const_string(string, null_terminated);
extern_fns::call_np_linalg_matmul(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim1, dim2, n2.data().base_ptr(ctx, generator)),
(dim0, dim2, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
let two = llvm_usize.const_int(2, false);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// The following constraints must be satisfied:
// * Input must be 2D
// * Input must be a square matrix (here we assume it is symmetric)
if cfg!(debug_assertions) {
let n1_dims = n1.load_ndims(ctx);
let n1_dims_eq2 =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, two, "").unwrap();
// num_dim = 2
ctx.make_assert(
generator,
n1_dims_eq2,
"0:ValueError",
format!("{FN_NAME} operates on 2D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
// Square Matrix
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
};
let dim_match =
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap();
ctx.make_assert(
generator,
dim_match,
"0:ValueError",
format!("Input matrix must be a square matrix {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
}
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
extern_fns::call_np_linalg_cholesky(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim1, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
extern_fns::call_np_linalg_qr(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, k, out_q.data().base_ptr(ctx, generator)),
(k, dim1, out_r.data().base_ptr(ctx, generator)),
None,
);
let out_q = out_q.as_base_value().as_basic_value_enum();
let out_r = out_r.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_q.get_type(), out_r.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "QR_factorization").unwrap();
let res_val = [out_q, out_r];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]).unwrap();
let out_vh =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]).unwrap();
extern_fns::call_np_linalg_svd(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_u.data().base_ptr(ctx, generator)),
(k, llvm_usize.const_int(1, false), out_s.data().base_ptr(ctx, generator)),
(dim1, dim1, out_vh.data().base_ptr(ctx, generator)),
None,
);
let out_u = out_u.as_base_value().as_basic_value_enum();
let out_s = out_s.as_base_value().as_basic_value_enum();
let out_vh = out_vh.as_base_value().as_basic_value_enum();
let res_ty =
ctx.ctx.struct_type(&[out_u.get_type(), out_s.get_type(), out_vh.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "SVD_factorization").unwrap();
let res_val = [out_u, out_s, out_vh];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
extern_fns::call_np_linalg_inv(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim1, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]).unwrap();
extern_fns::call_np_linalg_pinv(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim1, dim0, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
extern_fns::call_sp_linalg_lu(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, k, out_l.data().base_ptr(ctx, generator)),
(k, dim1, out_u.data().base_ptr(ctx, generator)),
None,
);
let out_l = out_l.as_base_value().as_basic_value_enum();
let out_u = out_u.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_l.get_type(), out_u.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "LU_factorization").unwrap();
let res_val = [out_l, out_u];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
// Must be square (add check later)
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_schur";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out_t =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
let out_z =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
extern_fns::call_sp_linalg_schur(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_t.data().base_ptr(ctx, generator)),
(dim0, dim0, out_z.data().base_ptr(ctx, generator)),
None,
);
let out_t = out_t.as_base_value().as_basic_value_enum();
let out_z = out_z.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_t.get_type(), out_z.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "Schur_factorization").unwrap();
let r = ctx
.ctx
.const_string(ctx.current_loc.file.0.to_string().as_bytes(), true)
.as_basic_value_enum()
.into_pointer_value();
let res_val = [out_t, out_z];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
// Must be square (add check later)
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
// Check if matrix is square
// ctx.builder.build_select(
// ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
// {
// let func =
// }, else_, name)
// ;
// ctx.builder.build_call(
// ctx.module.get_function("__nac3_raise"),
// &[]
// )
// let err_msg = ctx.gen_string(generator, "{FN_NAME} requires square matrix");
// ctx.raise_exn(generator, "0:ValueError", err_msg, [None, None, None], ctx.current_loc);
let out_h =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
extern_fns::call_sp_linalg_hessenberg(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_h.data().base_ptr(ctx, generator)),
None,
);
Ok(out_h.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}

View File

@ -951,9 +951,9 @@ pub fn destructure_range<'ctx>(
/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting /// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting
/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified. /// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified.
/// ///
/// Setting `ty` to [`None`] implies that the list is empty **and** does not have a known element /// Setting `ty` to [`None`] implies that the list does not have a known element type, which is only
/// type, and will therefore set the `list.data` type as `size_t*`. It is undefined behavior to /// valid for empty lists. It is undefined behavior to generate a sized list with an unknown element
/// generate a sized list with an unknown element type. /// type.
pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,

View File

@ -131,153 +131,81 @@ pub fn call_ldexp<'ctx>(
.unwrap() .unwrap()
} }
/// Macro to generate np_linalg external functions /// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
macro_rules! generate_np_linalg_extern_fn { pub fn call_linalg_try_invert_to<'ctx>(
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => { ctx: &CodeGenContext<'ctx, '_>,
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1); dim0: IntValue<'ctx>,
}; dim1: IntValue<'ctx>,
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => { data: PointerValue<'ctx>,
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2); name: Option<&str>,
}; ) -> IntValue<'ctx> {
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => { const FN_NAME: &str = "linalg_try_invert_to";
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: (IntValue<'ctx>, IntValue<'ctx>, PointerValue<'ctx>))*,
name: Option<&str>,
) -> $ret_ty<'ctx> {
const FN_NAME: &str = $extern_fn;
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
$( debug_assert!(allowed_dim0);
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.0.get_type())); debug_assert!(allowed_dim1);
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.1.get_type())); debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
debug_assert_eq!($input_matrix.2.get_type().get_element_type().into_float_type(), llvm_f64);
)*
// let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false); let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
// let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false); let fn_type = ctx.ctx.i8_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
// let file_name = ctx.current_loc.file.0; let func = ctx.module.add_function(FN_NAME, fn_type, None);
// let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false); for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
// let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true); func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { func
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false); });
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false); ctx.builder
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
let func = ctx.module.add_function(FN_NAME, fn_type, None); .map(CallSiteValue::try_as_basic_value)
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { .map(|v| v.map_left(BasicValueEnum::into_int_value))
func.add_attribute( .map(Either::unwrap_left)
AttributeLoc::Function, .unwrap().into()
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
// .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
// .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
.build_call(extern_fn, &[$($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left($map_fn))
.map(Either::unwrap_left)
.unwrap()
}
};
} }
generate_np_linalg_extern_fn!( /// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
call_np_dot, pub fn call_linalg_wilkinson_shift<'ctx>(
FloatValue, ctx: &CodeGenContext<'ctx, '_>,
f64_type, dim0: IntValue<'ctx>,
BasicValueEnum::into_float_value, dim1: IntValue<'ctx>,
"np_dot", data: PointerValue<'ctx>,
2 name: Option<&str>,
); ) -> FloatValue<'ctx> {
generate_np_linalg_extern_fn!( const FN_NAME: &str = "linalg_wilkinson_shift";
call_np_linalg_matmul, let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
IntValue,
i8_type, let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
BasicValueEnum::into_int_value, let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
"np_linalg_matmul",
3
);
generate_np_linalg_extern_fn!(
call_np_linalg_cholesky,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_cholesky",
2
);
generate_np_linalg_extern_fn!(
call_np_linalg_qr,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_qr",
3
);
generate_np_linalg_extern_fn!(
call_np_linalg_svd,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_svd",
4
);
generate_np_linalg_extern_fn!( debug_assert!(allowed_dim0);
call_np_linalg_inv, debug_assert!(allowed_dim1);
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_inv",
2
);
generate_np_linalg_extern_fn!( let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
call_np_linalg_pinv, let fn_type = ctx.ctx.f64_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
IntValue, let func = ctx.module.add_function(FN_NAME, fn_type, None);
i8_type, for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
BasicValueEnum::into_int_value, func.add_attribute(
"np_linalg_pinv", AttributeLoc::Function,
2 ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
}
generate_np_linalg_extern_fn!( func
call_sp_linalg_lu, });
IntValue,
i8_type, ctx.builder
BasicValueEnum::into_int_value, .build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
"sp_linalg_lu", .map(CallSiteValue::try_as_basic_value)
3 .map(|v| v.map_left(BasicValueEnum::into_float_value))
); .map(Either::unwrap_left)
.unwrap().into()
generate_np_linalg_extern_fn!( }
call_sp_linalg_schur,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_schur",
3
);
generate_np_linalg_extern_fn!(
call_sp_linalg_hessenberg,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_hessenberg",
2
);

View File

@ -61,7 +61,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
/// * `shape` - The shape of the `NDArray`. /// * `shape` - The shape of the `NDArray`.
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. /// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. /// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
pub fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
@ -157,7 +157,7 @@ where
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. /// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,

View File

@ -1637,7 +1637,7 @@ pub fn gen_stmt<G: CodeGenerator>(
}; };
ctx.make_assert_impl( ctx.make_assert_impl(
generator, generator,
generator.bool_to_i1(ctx, test.into_int_value()), test.into_int_value(),
"0:AssertionError", "0:AssertionError",
err_msg, err_msg,
[None, None, None], [None, None, None],

View File

@ -557,18 +557,8 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunNpDot PrimDef::FunTryInvertTo => self.build_linalg_try_invert_to(prim), // Inplace invert
| PrimDef::FunNpLinalgMatmul PrimDef::FunWilkinsonShift => self.build_linalg_wilkinson_shift(prim),
| PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd
| PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv
| PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => self.build_np_linalg_methods(prim),
// PrimDef::FunNpDot | PrimDef::FunNpLinalgMatmul => self.build_np_linalg_binary_methods(prim),
// PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr => self.build_np_linalg_unary_methods(prim),
}; };
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -577,7 +567,7 @@ impl<'a> BuiltinBuilder<'a> {
match (&tld, prim.details()) { match (&tld, prim.details()) {
( (
TopLevelDef::Class { name, object_id, .. }, TopLevelDef::Class { name, object_id, .. },
PrimDefDetails::PrimClass { name: exp_name, .. }, PrimDefDetails::PrimClass { name: exp_name },
) => { ) => {
let exp_object_id = prim.id(); let exp_object_id = prim.id();
assert_eq!(name, &exp_name.into()); assert_eq!(name, &exp_name.into());
@ -1887,140 +1877,62 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
fn build_np_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef { fn build_linalg_try_invert_to(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed( debug_assert_prim_is_allowed(
prim, prim,
&[ &[
PrimDef::FunNpDot, PrimDef::FunTryInvertTo,
PrimDef::FunNpLinalgMatmul,
PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd,
PrimDef::FunNpLinalgInv,
PrimDef::FunNpLinalgPinv,
PrimDef::FunSpLinalgLu,
PrimDef::FunSpLinalgSchur,
PrimDef::FunSpLinalgHessenberg,
], ],
); );
let var_map = self.num_or_ndarray_var_map.clone();
create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
self.primitives.bool,
&[(self.ndarray_float_2d, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
match prim { let func = match prim {
PrimDef::FunNpDot => create_fn_by_codegen( PrimDef::FunTryInvertTo => builtin_fns::call_try_invert_to,
self.unifier, _ => unreachable!(),
&self.num_or_ndarray_var_map, };
prim.name(),
self.primitives.float,
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_dot( Ok(Some(func(generator, ctx, (x_ty, x_val))?))
generator, }),
ctx, )
(x1_ty, x1_val), }
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen( fn build_linalg_wilkinson_shift(&mut self, prim: PrimDef) -> TopLevelDef {
self.unifier, debug_assert_prim_is_allowed(
&VarMap::new(), prim,
prim.name(), &[
self.ndarray_float_2d, PrimDef::FunWilkinsonShift,
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")], ],
Box::new(move |ctx, _, fun, args, generator| { );
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; let var_map = self.num_or_ndarray_var_map.clone();
let x2_ty = fun.0.args[1].ty; create_fn_by_codegen(
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; self.unifier,
&var_map,
prim.name(),
self.primitives.float,
&[(self.ndarray_float_2d, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matmul( let func = match prim {
generator, PrimDef::FunWilkinsonShift => builtin_fns::call_wilkinson_shift,
ctx, _ => unreachable!(),
(x1_ty, x1_val), };
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgCholesky Ok(Some(func(generator, ctx, (x_ty, x_val))?))
| PrimDef::FunNpLinalgInv }),
| PrimDef::FunNpLinalgPinv )
| PrimDef::FunSpLinalgHessenberg => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
PrimDef::FunSpLinalgHessenberg => builtin_fns::call_sp_linalg_hessenberg,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
),
PrimDef::FunNpLinalgQr | PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgSchur => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
PrimDef::FunNpLinalgSvd => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
_ => {
println!("{:?}", prim.name());
unreachable!()
}
}
} }
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {

View File

@ -766,7 +766,6 @@ impl TopLevelComposer {
let target_ty = get_type_from_type_annotation_kinds( let target_ty = get_type_from_type_annotation_kinds(
&temp_def_list, &temp_def_list,
unifier, unifier,
primitives,
&def, &def,
&mut subst_list, &mut subst_list,
)?; )?;
@ -937,7 +936,6 @@ impl TopLevelComposer {
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
temp_def_list.as_ref(), temp_def_list.as_ref(),
unifier, unifier,
primitives_store,
&type_annotation, &type_annotation,
&mut None, &mut None,
)?; )?;
@ -1004,7 +1002,6 @@ impl TopLevelComposer {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(
&temp_def_list, &temp_def_list,
unifier, unifier,
primitives_store,
&return_ty_annotation, &return_ty_annotation,
&mut None, &mut None,
)? )?
@ -1625,7 +1622,6 @@ impl TopLevelComposer {
let self_type = get_type_from_type_annotation_kinds( let self_type = get_type_from_type_annotation_kinds(
&def_list, &def_list,
unifier, unifier,
primitives_ty,
&make_self_type_annotation(type_vars, *object_id), &make_self_type_annotation(type_vars, *object_id),
&mut None, &mut None,
)?; )?;
@ -1807,11 +1803,7 @@ impl TopLevelComposer {
let ty_ann = make_self_type_annotation(type_vars, *class_id); let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds( let self_ty = get_type_from_type_annotation_kinds(
&def_list, &def_list, unifier, &ty_ann, &mut None,
unifier,
primitives_ty,
&ty_ann,
&mut None,
)?; )?;
vars.extend(type_vars.iter().map(|ty| { vars.extend(type_vars.iter().map(|ty| {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {

View File

@ -105,16 +105,8 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunNpDot, FunTryInvertTo,
FunNpLinalgMatmul, FunWilkinsonShift,
FunNpLinalgCholesky,
FunNpLinalgQr,
FunNpLinalgSvd,
FunNpLinalgInv,
FunNpLinalgPinv,
FunSpLinalgLu,
FunSpLinalgSchur,
FunSpLinalgHessenberg,
// Top-Level Functions // Top-Level Functions
FunSome, FunSome,
@ -123,7 +115,7 @@ pub enum PrimDef {
/// Associated details of a [`PrimDef`] /// Associated details of a [`PrimDef`]
pub enum PrimDefDetails { pub enum PrimDefDetails {
PrimFunction { name: &'static str, simple_name: &'static str }, PrimFunction { name: &'static str, simple_name: &'static str },
PrimClass { name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type }, PrimClass { name: &'static str },
} }
impl PrimDef { impl PrimDef {
@ -165,17 +157,15 @@ impl PrimDef {
#[must_use] #[must_use]
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self.details() { match self.details() {
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name, .. } => { PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name,
name
}
} }
} }
/// Get the associated details of this [`PrimDef`] /// Get the associated details of this [`PrimDef`]
#[must_use] #[must_use]
pub fn details(self) -> PrimDefDetails { pub fn details(self) -> PrimDefDetails {
fn class(name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type) -> PrimDefDetails { fn class(name: &'static str) -> PrimDefDetails {
PrimDefDetails::PrimClass { name, get_ty_fn } PrimDefDetails::PrimClass { name }
} }
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails { fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
@ -183,22 +173,22 @@ impl PrimDef {
} }
match self { match self {
PrimDef::Int32 => class("int32", |primitives| primitives.int32), PrimDef::Int32 => class("int32"),
PrimDef::Int64 => class("int64", |primitives| primitives.int64), PrimDef::Int64 => class("int64"),
PrimDef::Float => class("float", |primitives| primitives.float), PrimDef::Float => class("float"),
PrimDef::Bool => class("bool", |primitives| primitives.bool), PrimDef::Bool => class("bool"),
PrimDef::None => class("none", |primitives| primitives.none), PrimDef::None => class("none"),
PrimDef::Range => class("range", |primitives| primitives.range), PrimDef::Range => class("range"),
PrimDef::Str => class("str", |primitives| primitives.str), PrimDef::Str => class("str"),
PrimDef::Exception => class("Exception", |primitives| primitives.exception), PrimDef::Exception => class("Exception"),
PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32), PrimDef::UInt32 => class("uint32"),
PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64), PrimDef::UInt64 => class("uint64"),
PrimDef::Option => class("Option", |primitives| primitives.option), PrimDef::Option => class("Option"),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")), PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")), PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")), PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::List => class("list", |primitives| primitives.list), PrimDef::List => class("list"),
PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray), PrimDef::NDArray => class("ndarray"),
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")), PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")), PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
PrimDef::FunInt32 => fun("int32", None), PrimDef::FunInt32 => fun("int32", None),
@ -273,17 +263,8 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpDot => fun("np_dot", None), PrimDef::FunTryInvertTo => fun("try_invert_to", None),
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None), PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
PrimDef::FunSome => fun("Some", None), PrimDef::FunSome => fun("Some", None),
} }
} }

View File

@ -1,9 +1,8 @@
use super::*; use super::*;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::{PrimDef, PrimDefDetails}; use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::VarMap; use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant; use nac3parser::ast::Constant;
use strum::IntoEnumIterator;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum TypeAnnotation { pub enum TypeAnnotation {
@ -358,7 +357,6 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
pub fn get_type_from_type_annotation_kinds( pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore,
ann: &TypeAnnotation, ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>, subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> { ) -> Result<Type, HashSet<String>> {
@ -381,141 +379,100 @@ pub fn get_type_from_type_annotation_kinds(
let param_ty = params let param_ty = params
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) { let subst = {
// Primitive TopLevelDefs do not contain all fields that are present in their Type // check for compatible range
// counterparts, so directly perform subst on the Type instead. // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else { for (tvar, p) in type_vars.iter().zip(param_ty) {
unreachable!() match unifier.get_ty(*tvar).as_ref() {
}; TypeEnum::TVar {
id,
let base_ty = get_ty_fn(primitives); range,
let params = fields: None,
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) { name,
params.clone() loc,
} else { is_const_generic: false,
unreachable!() } => {
}; let ok: bool = {
// create a temp type var and unify to check compatibility
unifier p == *tvar || {
.subst( let temp = unifier.get_fresh_var_with_range(
get_ty_fn(primitives), range.as_slice(),
&params *name,
.iter() *loc,
.zip(param_ty) );
.map(|(obj_tv, param)| (*obj_tv.0, param)) unifier.unify(temp.ty, p).is_ok()
.collect(),
)
.unwrap_or(base_ty)
} else {
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
} }
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
} }
TypeEnum::TVar {
id, range, name, loc, is_const_generic: true, ..
} => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp =
unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
} }
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst { TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
if let Some(wl) = subst_list.as_mut() { let ty = range[0];
wl.push(ty); let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
} }
} }
result
ty
}; };
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
}
Ok(ty) Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
@ -533,7 +490,6 @@ pub fn get_type_from_type_annotation_kinds(
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
unifier, unifier,
primitives,
ty.as_ref(), ty.as_ref(),
subst_list, subst_list,
)?; )?;
@ -543,13 +499,7 @@ pub fn get_type_from_type_annotation_kinds(
let tys = tys let tys = tys
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys })) Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))

View File

@ -8,7 +8,7 @@ edition = "2021"
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
linalg_externfns = { path = "./linalg_externfns" } externfns = { path = "./demo/externfns" }
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"

View File

@ -15,7 +15,7 @@ done
demo="$1" demo="$1"
echo -n "Checking $demo... " echo -n "Checking $demo... "
# ./interpret_demo.py "$demo" > interpreted.log ./interpret_demo.py "$demo" > interpreted.log
./run_demo.sh --out run.log "${nac3args[@]}" "$demo" ./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo" ./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run.log diff -Nau interpreted.log run.log

View File

@ -107,25 +107,8 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t
__builtin_unreachable(); __builtin_unreachable();
} }
// See `struct Exception<'a>` in uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) {
// https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
struct Exception {
uint32_t id;
struct cslice file;
uint32_t line;
uint32_t column;
struct cslice function;
struct cslice message;
int64_t param[3];
};
uint32_t __nac3_raise(struct Exception* e) {
printf("__nac3_raise called. Exception details:\n");
printf(" ID: %lld\n", e->id);
printf(" Location: %*s:%lld:%lld\n" , e->file.len, (const char*) e->file.data, e->line, e->column);
printf(" Function: %*s\n" , e->function.len, (const char*) e->function.data);
printf(" Message: \"%*s\"\n" , e->message.len, (const char*) e->message.data);
printf(" Params: {0}=%lld, {1}=%lld, {2}=%lld\n", e->param[0], e->param[1], e->param[2]);
exit(101); exit(101);
__builtin_unreachable(); __builtin_unreachable();
} }

View File

@ -1,5 +1,5 @@
[package] [package]
name = "linalg_externfns" name = "externfns"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
@ -8,4 +8,3 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]} nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0"

View File

@ -0,0 +1,41 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(clippy::semicolon_if_nothing_returned, clippy::uninlined_format_args)]
use core::slice;
use nalgebra::{DMatrix, linalg};
#[no_mangle]
/// Provide an interface to `nalgebra::linalg::try_invert_to`
pub extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8
{
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
1
} else {
0
}
}
#[no_mangle]
/// Provide an interface to `nalgebra::linalg::wilkinson_shift`
pub extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64
{
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
// Check if matrix is symmetric
assert!(matrix[(0, 1)] == matrix[(1, 0)], "Operation Wilkinson Shift expects symmetric matrix");
return linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)]);
}

View File

@ -5,7 +5,6 @@ import importlib.util
import importlib.machinery import importlib.machinery
import math import math
import numpy as np import numpy as np
import scipy as sp
import numpy.typing as npt import numpy.typing as npt
import pathlib import pathlib
@ -188,7 +187,7 @@ def patch(module):
module.ceil64 = _ceil module.ceil64 = _ceil
module.np_ceil = np.ceil module.np_ceil = np.ceil
# NumPy NDArray factory functions # NumPy ndarray functions
module.ndarray = NDArray module.ndarray = NDArray
module.np_ndarray = np.ndarray module.np_ndarray = np.ndarray
module.np_empty = np.empty module.np_empty = np.empty
@ -239,7 +238,7 @@ def patch(module):
module.np_hypot = np.hypot module.np_hypot = np.hypot
module.np_nextafter = np.nextafter module.np_nextafter = np.nextafter
# SciPy Math functions # SciPy Math Functions
module.sp_spec_erf = special.erf module.sp_spec_erf = special.erf
module.sp_spec_erfc = special.erfc module.sp_spec_erfc = special.erfc
module.sp_spec_gamma = special.gamma module.sp_spec_gamma = special.gamma
@ -247,21 +246,16 @@ def patch(module):
module.sp_spec_j0 = special.j0 module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1 module.sp_spec_j1 = special.j1
# Linalg functions # NumPy NDArray Functions
module.np_dot = np.dot module.np_ndarray = np.ndarray
module.np_linalg_matmul = np.matmul module.np_empty = np.empty
module.np_linalg_cholesky = np.linalg.cholesky module.np_zeros = np.zeros
module.np_linalg_qr = np.linalg.qr module.np_ones = np.ones
module.np_linalg_svd = np.linalg.svd module.np_full = np.full
module.np_linalg_inv = np.linalg.inv module.np_eye = np.eye
module.np_linalg_pinv = np.linalg.pinv module.np_identity = np.identity
module.try_invert_to = try_invert_to
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True) module.wilkinson_shift = wilkinson_shift
module.sp_linalg_schur = sp.linalg.schur
# module.sp_linalg_hessenberg = sp.linalg.hessenberg
module.sp_linalg_hessenberg = lambda x: x
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

@ -1,2 +0,0 @@
Excepytiopn!! knfv 0x7fffffff9218
__nac3_personality(state: 1, exception_object: 1, context: 1381323604)

View File

@ -42,14 +42,14 @@ done
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
nac3standalone=../../target/debug/nac3standalone nac3standalone=../../target/debug/nac3standalone
externfns=../../target/debug/deps/liblinalg_externfns.so externfns=../../target/debug/deps/libexternfns.so
elif [ -e ../../target/release/nac3standalone ]; then elif [ -e ../../target/release/nac3standalone ]; then
nac3standalone=../../target/release/nac3standalone nac3standalone=../../target/release/nac3standalone
externfns=../../target/release/deps/liblinalg_externfns.so externfns=../../target/release/deps/libexternfns.so
else else
# used by Nix builds # used by Nix builds
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg_externfns.so externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo

View File

@ -1,12 +0,0 @@
8.000000
10.000000
12.000000
4.000000
5.000000
6.000000
1.000000
2.000000
3.000000
4.000000
5.000000
6.000000

View File

@ -531,12 +531,10 @@ def test_ndarray_ipow_broadcast_scalar():
def test_ndarray_matmul(): def test_ndarray_matmul():
x = np_identity(2) x = np_identity(2)
t: ndarray[float, 2] = np_array([[1., 2., 3.,], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.]]) y = x @ np_ones([2, 2])
y = x @ t
y2 = np_linalg_matmul(x, t)
output_ndarray_float_2(y)
output_ndarray_float_2(y2)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_imatmul(): def test_ndarray_imatmul():
x = np_identity(2) x = np_identity(2)
@ -1431,289 +1429,201 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_dot(): def test_try_invert():
x: ndarray[float, 1] = np_array([5.0, 1.0]) x: ndarray[float, 2] = np_array([[1.0, 1.0], [1.0, 5.0]])
y: ndarray[float, 1] = np_array([5.0, 1.0])
z = np_dot(x, y)
output_ndarray_float_1(x)
output_ndarray_float_1(y)
output_float64(z)
def test_ndarray_linalg_matmul():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
z = np_linalg_matmul(x, y)
m = np_argmax(z)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
output_ndarray_float_2(z)
output_int64(m)
def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = np_linalg_cholesky(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_qr():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y, z = np_linalg_qr(x)
output_ndarray_float_2(x) output_ndarray_float_2(x)
# QR Factorization in nalgebra and numpy do not give the same result y = try_invert_to(x)
# Generating product for printing
a = np_linalg_matmul(y, z)
output_ndarray_float_2(a)
def test_ndarray_linalg_inv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_inv(x)
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_ndarray_float_2(y) output_bool(y)
def test_ndarray_pinv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
y = np_linalg_pinv(x)
def test_wilkinson_shift():
x: ndarray[float, 2] = np_array([[5., 1.], [1., 4.]])
y = wilkinson_shift(x)
output_ndarray_float_2(x) output_ndarray_float_2(x)
output_ndarray_float_2(y) output_float64(y)
def test_ndarray_schur():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
t, z = sp_linalg_schur(x)
output_ndarray_float_2(x)
# Same as np_linalg_qr the signs are different in nalgebra and numpy
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z))
output_ndarray_float_2(a)
def test_ndarray_hessenberg():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
h = sp_linalg_hessenberg(x)
output_ndarray_float_2(x)
output_ndarray_float_2(h)
def test_ndarray_lu():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
l, u = sp_linalg_lu(x)
output_ndarray_float_2(x)
output_ndarray_float_2(l)
output_ndarray_float_2(u)
def test_ndarray_svd():
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
x, y, z = np_linalg_svd(w)
output_ndarray_float_2(w)
# Same as np_linalg_qr the signs are different in nalgebra and numpy
a = np_linalg_matmul(x, z)
output_ndarray_float_2(a)
output_ndarray_float_1(y)
def run() -> int32: def run() -> int32:
test_try_invert()
test_wilkinson_shift()
test_ndarray_ctor()
test_ndarray_empty()
test_ndarray_zeros()
test_ndarray_ones()
test_ndarray_full()
test_ndarray_eye()
test_ndarray_array()
test_ndarray_identity()
test_ndarray_fill()
test_ndarray_copy()
test_ndarray_neg_idx()
test_ndarray_slices()
test_ndarray_nd_idx()
test_ndarray_add()
test_ndarray_add_broadcast()
test_ndarray_add_broadcast_lhs_scalar()
test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_iadd()
test_ndarray_iadd_broadcast()
test_ndarray_iadd_broadcast_scalar()
test_ndarray_sub()
test_ndarray_sub_broadcast()
test_ndarray_sub_broadcast_lhs_scalar()
test_ndarray_sub_broadcast_rhs_scalar()
test_ndarray_isub()
test_ndarray_isub_broadcast()
test_ndarray_isub_broadcast_scalar()
test_ndarray_mul()
test_ndarray_mul_broadcast()
test_ndarray_mul_broadcast_lhs_scalar()
test_ndarray_mul_broadcast_rhs_scalar()
test_ndarray_imul()
test_ndarray_imul_broadcast()
test_ndarray_imul_broadcast_scalar()
test_ndarray_truediv()
test_ndarray_truediv_broadcast()
test_ndarray_truediv_broadcast_lhs_scalar()
test_ndarray_truediv_broadcast_rhs_scalar()
test_ndarray_itruediv()
test_ndarray_itruediv_broadcast()
test_ndarray_itruediv_broadcast_scalar()
test_ndarray_floordiv()
test_ndarray_floordiv_broadcast()
test_ndarray_floordiv_broadcast_lhs_scalar()
test_ndarray_floordiv_broadcast_rhs_scalar()
test_ndarray_ifloordiv()
test_ndarray_ifloordiv_broadcast()
test_ndarray_ifloordiv_broadcast_scalar()
test_ndarray_mod()
test_ndarray_mod_broadcast()
test_ndarray_mod_broadcast_lhs_scalar()
test_ndarray_mod_broadcast_rhs_scalar()
test_ndarray_imod()
test_ndarray_imod_broadcast()
test_ndarray_imod_broadcast_scalar()
test_ndarray_pow()
test_ndarray_pow_broadcast()
test_ndarray_pow_broadcast_lhs_scalar()
test_ndarray_pow_broadcast_rhs_scalar()
test_ndarray_ipow()
test_ndarray_ipow_broadcast()
test_ndarray_ipow_broadcast_scalar()
test_ndarray_matmul() test_ndarray_matmul()
# test_ndarray_dot() test_ndarray_imatmul()
# test_ndarray_linalg_matmul() test_ndarray_pos()
# test_ndarray_cholesky() test_ndarray_neg()
# test_ndarray_qr() test_ndarray_inv()
# test_ndarray_svd() test_ndarray_eq()
# test_ndarray_linalg_inv() test_ndarray_eq_broadcast()
# test_ndarray_pinv() test_ndarray_eq_broadcast_lhs_scalar()
# test_ndarray_lu() test_ndarray_eq_broadcast_rhs_scalar()
# test_ndarray_schur() test_ndarray_ne()
# test_ndarray_hessenberg() test_ndarray_ne_broadcast()
test_ndarray_ne_broadcast_lhs_scalar()
test_ndarray_ne_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_lt_broadcast()
test_ndarray_lt_broadcast_lhs_scalar()
test_ndarray_lt_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_le_broadcast()
test_ndarray_le_broadcast_lhs_scalar()
test_ndarray_le_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_gt_broadcast()
test_ndarray_gt_broadcast_lhs_scalar()
test_ndarray_gt_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_ge_broadcast()
test_ndarray_ge_broadcast_lhs_scalar()
test_ndarray_ge_broadcast_rhs_scalar()
# test_ndarray_ctor() test_ndarray_int32()
# test_ndarray_empty() test_ndarray_int64()
# test_ndarray_zeros() test_ndarray_uint32()
# test_ndarray_ones() test_ndarray_uint64()
# test_ndarray_full() test_ndarray_float()
# test_ndarray_eye() test_ndarray_bool()
# test_ndarray_array()
# test_ndarray_identity()
# test_ndarray_fill()
# test_ndarray_copy()
# test_ndarray_neg_idx() test_ndarray_round()
# test_ndarray_slices() test_ndarray_floor()
# test_ndarray_nd_idx() test_ndarray_min()
test_ndarray_minimum()
test_ndarray_minimum_broadcast()
test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_argmin()
test_ndarray_max()
test_ndarray_maximum()
test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_argmax()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()
# test_ndarray_add() test_ndarray_sin()
# test_ndarray_add_broadcast() test_ndarray_cos()
# test_ndarray_add_broadcast_lhs_scalar() test_ndarray_exp()
# test_ndarray_add_broadcast_rhs_scalar() test_ndarray_exp2()
# test_ndarray_iadd() test_ndarray_log()
# test_ndarray_iadd_broadcast() test_ndarray_log10()
# test_ndarray_iadd_broadcast_scalar() test_ndarray_log2()
# test_ndarray_sub() test_ndarray_fabs()
# test_ndarray_sub_broadcast() test_ndarray_sqrt()
# test_ndarray_sub_broadcast_lhs_scalar() test_ndarray_rint()
# test_ndarray_sub_broadcast_rhs_scalar() test_ndarray_tan()
# test_ndarray_isub() test_ndarray_arcsin()
# test_ndarray_isub_broadcast() test_ndarray_arccos()
# test_ndarray_isub_broadcast_scalar() test_ndarray_arctan()
# test_ndarray_mul() test_ndarray_sinh()
# test_ndarray_mul_broadcast() test_ndarray_cosh()
# test_ndarray_mul_broadcast_lhs_scalar() test_ndarray_tanh()
# test_ndarray_mul_broadcast_rhs_scalar() test_ndarray_arcsinh()
# test_ndarray_imul() test_ndarray_arccosh()
# test_ndarray_imul_broadcast() test_ndarray_arctanh()
# test_ndarray_imul_broadcast_scalar() test_ndarray_expm1()
# test_ndarray_truediv() test_ndarray_cbrt()
# test_ndarray_truediv_broadcast()
# test_ndarray_truediv_broadcast_lhs_scalar()
# test_ndarray_truediv_broadcast_rhs_scalar()
# test_ndarray_itruediv()
# test_ndarray_itruediv_broadcast()
# test_ndarray_itruediv_broadcast_scalar()
# test_ndarray_floordiv()
# test_ndarray_floordiv_broadcast()
# test_ndarray_floordiv_broadcast_lhs_scalar()
# test_ndarray_floordiv_broadcast_rhs_scalar()
# test_ndarray_ifloordiv()
# test_ndarray_ifloordiv_broadcast()
# test_ndarray_ifloordiv_broadcast_scalar()
# test_ndarray_mod()
# test_ndarray_mod_broadcast()
# test_ndarray_mod_broadcast_lhs_scalar()
# test_ndarray_mod_broadcast_rhs_scalar()
# test_ndarray_imod()
# test_ndarray_imod_broadcast()
# test_ndarray_imod_broadcast_scalar()
# test_ndarray_pow()
# test_ndarray_pow_broadcast()
# test_ndarray_pow_broadcast_lhs_scalar()
# test_ndarray_pow_broadcast_rhs_scalar()
# test_ndarray_ipow()
# test_ndarray_ipow_broadcast()
# test_ndarray_ipow_broadcast_scalar()
# test_ndarray_matmul()
# test_ndarray_imatmul()
# test_ndarray_pos()
# test_ndarray_neg()
# test_ndarray_inv()
# test_ndarray_eq()
# test_ndarray_eq_broadcast()
# test_ndarray_eq_broadcast_lhs_scalar()
# test_ndarray_eq_broadcast_rhs_scalar()
# test_ndarray_ne()
# test_ndarray_ne_broadcast()
# test_ndarray_ne_broadcast_lhs_scalar()
# test_ndarray_ne_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_lt_broadcast()
# test_ndarray_lt_broadcast_lhs_scalar()
# test_ndarray_lt_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_le_broadcast()
# test_ndarray_le_broadcast_lhs_scalar()
# test_ndarray_le_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_gt_broadcast()
# test_ndarray_gt_broadcast_lhs_scalar()
# test_ndarray_gt_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_ge_broadcast()
# test_ndarray_ge_broadcast_lhs_scalar()
# test_ndarray_ge_broadcast_rhs_scalar()
# test_ndarray_int32() test_ndarray_erf()
# test_ndarray_int64() test_ndarray_erfc()
# test_ndarray_uint32() test_ndarray_gamma()
# test_ndarray_uint64() test_ndarray_gammaln()
# test_ndarray_float() test_ndarray_j0()
# test_ndarray_bool() test_ndarray_j1()
# test_ndarray_round() test_ndarray_arctan2()
# test_ndarray_floor() test_ndarray_arctan2_broadcast()
# test_ndarray_min() test_ndarray_arctan2_broadcast_lhs_scalar()
# test_ndarray_minimum() test_ndarray_arctan2_broadcast_rhs_scalar()
# test_ndarray_minimum_broadcast() test_ndarray_copysign()
# test_ndarray_minimum_broadcast_lhs_scalar() test_ndarray_copysign_broadcast()
# test_ndarray_minimum_broadcast_rhs_scalar() test_ndarray_copysign_broadcast_lhs_scalar()
# test_ndarray_argmin() test_ndarray_copysign_broadcast_rhs_scalar()
# test_ndarray_max() test_ndarray_fmax()
# test_ndarray_maximum() test_ndarray_fmax_broadcast()
# test_ndarray_maximum_broadcast() test_ndarray_fmax_broadcast_lhs_scalar()
# test_ndarray_maximum_broadcast_lhs_scalar() test_ndarray_fmax_broadcast_rhs_scalar()
# test_ndarray_maximum_broadcast_rhs_scalar() test_ndarray_fmin()
# test_ndarray_argmax() test_ndarray_fmin_broadcast()
# test_ndarray_abs() test_ndarray_fmin_broadcast_lhs_scalar()
# test_ndarray_isnan() test_ndarray_fmin_broadcast_rhs_scalar()
# test_ndarray_isinf() test_ndarray_ldexp()
test_ndarray_ldexp_broadcast()
# test_ndarray_sin() test_ndarray_ldexp_broadcast_lhs_scalar()
# test_ndarray_cos() test_ndarray_ldexp_broadcast_rhs_scalar()
# test_ndarray_exp() test_ndarray_hypot()
# test_ndarray_exp2() test_ndarray_hypot_broadcast()
# test_ndarray_log() test_ndarray_hypot_broadcast_lhs_scalar()
# test_ndarray_log10() test_ndarray_hypot_broadcast_rhs_scalar()
# test_ndarray_log2() test_ndarray_nextafter()
# test_ndarray_fabs() test_ndarray_nextafter_broadcast()
# test_ndarray_sqrt() test_ndarray_nextafter_broadcast_lhs_scalar()
# test_ndarray_rint() test_ndarray_nextafter_broadcast_rhs_scalar()
# test_ndarray_tan()
# test_ndarray_arcsin()
# test_ndarray_arccos()
# test_ndarray_arctan()
# test_ndarray_sinh()
# test_ndarray_cosh()
# test_ndarray_tanh()
# test_ndarray_arcsinh()
# test_ndarray_arccosh()
# test_ndarray_arctanh()
# test_ndarray_expm1()
# test_ndarray_cbrt()
# test_ndarray_erf()
# test_ndarray_erfc()
# test_ndarray_gamma()
# test_ndarray_gammaln()
# test_ndarray_j0()
# test_ndarray_j1()
# test_ndarray_arctan2()
# test_ndarray_arctan2_broadcast()
# test_ndarray_arctan2_broadcast_lhs_scalar()
# test_ndarray_arctan2_broadcast_rhs_scalar()
# test_ndarray_copysign()
# test_ndarray_copysign_broadcast()
# test_ndarray_copysign_broadcast_lhs_scalar()
# test_ndarray_copysign_broadcast_rhs_scalar()
# test_ndarray_fmax()
# test_ndarray_fmax_broadcast()
# test_ndarray_fmax_broadcast_lhs_scalar()
# test_ndarray_fmax_broadcast_rhs_scalar()
# test_ndarray_fmin()
# test_ndarray_fmin_broadcast()
# test_ndarray_fmin_broadcast_lhs_scalar()
# test_ndarray_fmin_broadcast_rhs_scalar()
# test_ndarray_ldexp()
# test_ndarray_ldexp_broadcast()
# test_ndarray_ldexp_broadcast_lhs_scalar()
# test_ndarray_ldexp_broadcast_rhs_scalar()
# test_ndarray_hypot()
# test_ndarray_hypot_broadcast()
# test_ndarray_hypot_broadcast_lhs_scalar()
# test_ndarray_hypot_broadcast_rhs_scalar()
# test_ndarray_nextafter()
# test_ndarray_nextafter_broadcast()
# test_ndarray_nextafter_broadcast_lhs_scalar()
# test_ndarray_nextafter_broadcast_rhs_scalar()
# test_try_invert()
# test_wilkinson_shift()
return 0 return 0

View File

@ -1,346 +0,0 @@
mod runtime_exception;
use core::slice;
use nalgebra::{linalg, DMatrix};
macro_rules! raise_exn {
($name:expr, $message:expr, $param0:expr, $param1:expr, $param2:expr) => {{
use cslice::AsCSlice;
let name_id = $crate::runtime_exception::get_exception_id($name);
let exn = $crate::runtime_exception::Exception {
id: name_id,
file: file!().as_c_slice(),
line: line!(),
column: column!(),
// https://github.com/rust-lang/rfcs/pull/1719
function: "(Rust function)".as_c_slice(),
message: $message.as_c_slice(),
param: [$param0, $param1, $param2],
};
#[allow(unused_unsafe)]
unsafe {
$crate::runtime_exception::raise(&exn)
}
}};
($name:expr, $message:expr) => {{
raise_exn!($name, $message, 0, 0, 0)
}};
}
/// # Safety
///
/// `data` must point to an array with `dim0`x`dim1` elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
1
} else {
0
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)])
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_dot(
dim0: usize,
dim1: usize,
x1: *mut f64,
_: usize,
_: usize,
x2: *mut f64,
) -> f64 {
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim0 * dim1) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(x2, dim0 * dim1) };
let matrix1 = DMatrix::from_row_slice(dim0, dim1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim0, dim1, data_slice2);
matrix1.dot(&matrix2)
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matmul(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
x2: *mut f64,
dim3_0: usize,
dim3_1: usize,
out: *mut f64,
) -> i8 {
// let name = unsafe {slice::from_raw_parts_mut(n, l)};
// let fne = name.as_c_slice();
raise_exn!("ZeroDivisionError", "Divide by Zero");
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(x2, dim2_0 * dim2_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim3_0 * dim3_1) };
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2_0, dim2_1, data_slice2);
let mut result = DMatrix::<f64>::zeros(dim3_0, dim3_1);
matrix1.mul_to(&matrix2, &mut result);
out_slice.copy_from_slice(result.transpose().as_slice());
// raise_exn!("ZeroDivisionError", "Divide by Zero", r, c, n);
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_cholesky(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
let res = matrix1.cholesky();
match res {
None => 0,
Some(c) => {
out_slice.copy_from_slice(c.unpack().transpose().as_slice());
1
}
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_qr(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_q: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_r: *mut f64,
) -> i8 {
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q, dim2_0 * dim2_1) };
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r, dim3_0 * dim3_1) };
// Refer to https://github.com/dimforge/nalgebra/issues/735
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
let res = matrix1.qr();
let (q, r) = res.unpack();
// Uses different algo need to match numpy
out_q_slice.copy_from_slice(q.transpose().as_slice());
out_r_slice.copy_from_slice(r.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_svd(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_u: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_s: *mut f64,
dim4_0: usize,
dim4_1: usize,
out_vh: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim2_0 * dim2_1) };
let out_s_slice = unsafe { slice::from_raw_parts_mut(out_s, dim3_0 * dim3_1) };
let out_vh_slice = unsafe { slice::from_raw_parts_mut(out_vh, dim4_0 * dim4_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let res = matrix.svd(true, true);
out_u_slice.copy_from_slice(res.u.unwrap().transpose().as_slice());
out_s_slice.copy_from_slice(res.singular_values.as_slice());
out_vh_slice.copy_from_slice(res.v_t.unwrap().transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_inv(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_invertible() {
// raise error
return 0;
}
let inv = matrix.try_inverse().unwrap();
out_slice.copy_from_slice(inv.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_pinv(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let svd = matrix.svd(true, true);
let inv = svd.pseudo_inverse(1e-15);
match inv {
Ok(m) => {
out_slice.copy_from_slice(m.transpose().as_slice());
1
}
Err(e) => {
// raise exception here
assert!(false, "{e}");
0
}
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_lu(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_l: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_u: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l, dim2_0 * dim2_1) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim3_0 * dim3_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let (_, l, u) = matrix.lu().unpack();
out_l_slice.copy_from_slice(l.transpose().as_slice());
out_u_slice.copy_from_slice(u.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_schur(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_t: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_z: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t, dim2_0 * dim2_1) };
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z, dim3_0 * dim3_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_square() {
// Throw error here
return 0;
}
let (z, t) = matrix.schur().unpack();
out_t_slice.copy_from_slice(t.transpose().as_slice());
out_z_slice.copy_from_slice(z.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_hessenberg(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_h: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_square() {
// Throw error here
return 0;
}
let (_, h) = matrix.hessenberg().unpack();
out_h_slice.copy_from_slice(h.transpose().as_slice());
1
}

View File

@ -1,80 +0,0 @@
#![allow(non_camel_case_types)]
#![allow(unused)]
// ARTIQ Exception struct declaration
use cslice::CSlice;
// Note: CSlice within an exception may not be actual cslice, they may be strings that exist only
// in the host. If the length == usize:MAX, the pointer is actually a string key in the host.
#[repr(C)]
#[derive(Clone)]
pub struct Exception<'a> {
pub id: u32,
pub file: CSlice<'a, u8>,
pub line: u32,
pub column: u32,
pub function: CSlice<'a, u8>,
pub message: CSlice<'a, u8>,
pub param: [i64; 3],
}
fn str_err(_: core::str::Utf8Error) -> core::fmt::Error {
core::fmt::Error
}
fn exception_str<'a>(s: &'a CSlice<'a, u8>) -> Result<&'a str, core::str::Utf8Error> {
if s.len() == usize::MAX {
Ok("<host string>")
} else {
core::str::from_utf8(s.as_ref())
}
}
impl<'a> core::fmt::Debug for Exception<'a> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"Exception {} from {} in {}:{}:{}, message: {}",
self.id,
exception_str(&self.function).map_err(str_err)?,
exception_str(&self.file).map_err(str_err)?,
self.line,
self.column,
exception_str(&self.message).map_err(str_err)?
)
}
}
pub unsafe fn raise(exception: *const Exception) -> ! {
println!("Excepytiopn!! knfv {:?}", exception);
let e = &*exception;
let f1 = exception_str(&e.function).map_err(str_err).unwrap();
let f2 = exception_str(&e.file).map_err(str_err).unwrap();
let f3 = exception_str(&e.message).map_err(str_err).unwrap();
panic!("Exception {} from {} in {}:{}:{}, message: {}", e.id, f1, f2, e.line, e.column, f3);
}
static EXCEPTION_ID_LOOKUP: [(&str, u32); 12] = [
("RuntimeError", 0),
("RTIOUnderflow", 1),
("RTIOOverflow", 2),
("RTIODestinationUnreachable", 3),
("DMAError", 4),
("I2CError", 5),
("CacheError", 6),
("SPIError", 7),
("ZeroDivisionError", 8),
("IndexError", 9),
("UnwrapNoneError", 10),
("Value", 11),
];
pub fn get_exception_id(name: &str) -> u32 {
for (n, id) in EXCEPTION_ID_LOOKUP.iter() {
if *n == name {
return *id;
}
}
unimplemented!("unallocated internal exception id")
}

View File

@ -113,9 +113,7 @@ fn handle_typevar_definition(
x, x,
HashMap::new(), HashMap::new(),
)?; )?;
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)
def_list, unifier, primitives, &ty, &mut None,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let loc = func.location; let loc = func.location;
@ -154,7 +152,7 @@ fn handle_typevar_definition(
HashMap::new(), HashMap::new(),
)?; )?;
let constraint = let constraint =
get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty, &mut None)?; get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?;
let loc = func.location; let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty) Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)

BIN
pyo3_output/nac3artiq.so Executable file

Binary file not shown.