forked from M-Labs/nac3
Compare commits
1 Commits
8655a5f0c7
...
6ad597e592
Author | SHA1 | Date | |
---|---|---|---|
6ad597e592 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
/target
|
/target
|
||||||
nix/windows/msys2
|
nix/windows/msys2
|
||||||
|
nac3standalone/demo/externfns/target
|
||||||
|
69
Cargo.lock
generated
69
Cargo.lock
generated
@ -126,9 +126,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.1.6"
|
version = "1.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f"
|
checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cfg-if"
|
name = "cfg-if"
|
||||||
@ -167,7 +167,7 @@ dependencies = [
|
|||||||
"heck 0.5.0",
|
"heck 0.5.0",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.71",
|
"syn 2.0.70",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -256,12 +256,6 @@ version = "0.2.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "cslice"
|
|
||||||
version = "0.3.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dirs-next"
|
name = "dirs-next"
|
||||||
version = "2.0.0"
|
version = "2.0.0"
|
||||||
@ -320,6 +314,13 @@ dependencies = [
|
|||||||
"windows-sys",
|
"windows-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "externfns"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"nalgebra",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastrand"
|
name = "fastrand"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
@ -436,7 +437,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.71",
|
"syn 2.0.70",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -528,9 +529,9 @@ checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libloading"
|
name = "libloading"
|
||||||
version = "0.8.5"
|
version = "0.8.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4"
|
checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"windows-targets",
|
"windows-targets",
|
||||||
@ -552,14 +553,6 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "linalg_externfns"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
"cslice",
|
|
||||||
"nalgebra",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linked-hash-map"
|
name = "linked-hash-map"
|
||||||
version = "0.5.6"
|
version = "0.5.6"
|
||||||
@ -687,8 +680,8 @@ name = "nac3standalone"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"clap",
|
"clap",
|
||||||
|
"externfns",
|
||||||
"inkwell",
|
"inkwell",
|
||||||
"linalg_externfns",
|
|
||||||
"nac3core",
|
"nac3core",
|
||||||
"nac3parser",
|
"nac3parser",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
@ -837,7 +830,7 @@ dependencies = [
|
|||||||
"phf_shared 0.11.2",
|
"phf_shared 0.11.2",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.71",
|
"syn 2.0.70",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -866,9 +859,9 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "portable-atomic"
|
name = "portable-atomic"
|
||||||
version = "1.7.0"
|
version = "1.6.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
|
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ppv-lite86"
|
name = "ppv-lite86"
|
||||||
@ -938,7 +931,7 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"pyo3-macros-backend",
|
"pyo3-macros-backend",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.71",
|
"syn 2.0.70",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -951,7 +944,7 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"pyo3-build-config",
|
"pyo3-build-config",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.71",
|
"syn 2.0.70",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -1015,9 +1008,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.5.3"
|
version = "0.5.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
|
checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags",
|
"bitflags",
|
||||||
]
|
]
|
||||||
@ -1132,7 +1125,7 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.71",
|
"syn 2.0.70",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -1234,7 +1227,7 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"rustversion",
|
"rustversion",
|
||||||
"syn 2.0.71",
|
"syn 2.0.70",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -1250,9 +1243,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "2.0.71"
|
version = "2.0.70"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b146dcf730474b4bcd16c311627b31ede9ab149045db4d6088b3becaea046462"
|
checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
@ -1303,22 +1296,22 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.63"
|
version = "1.0.61"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
|
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"thiserror-impl",
|
"thiserror-impl",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror-impl"
|
name = "thiserror-impl"
|
||||||
version = "1.0.63"
|
version = "1.0.61"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
|
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.71",
|
"syn 2.0.70",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -1592,5 +1585,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.71",
|
"syn 2.0.70",
|
||||||
]
|
]
|
||||||
|
@ -4,8 +4,8 @@ members = [
|
|||||||
"nac3ast",
|
"nac3ast",
|
||||||
"nac3parser",
|
"nac3parser",
|
||||||
"nac3core",
|
"nac3core",
|
||||||
"nac3standalone/linalg_externfns",
|
|
||||||
"nac3standalone",
|
"nac3standalone",
|
||||||
|
"nac3standalone/demo/externfns",
|
||||||
"nac3artiq",
|
"nac3artiq",
|
||||||
"runkernel",
|
"runkernel",
|
||||||
]
|
]
|
||||||
|
@ -14,12 +14,23 @@ class Demo:
|
|||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run(self):
|
def run(self):
|
||||||
self.core.reset()
|
a = np_array([[1., 2.], [3., 4.]])
|
||||||
while True:
|
b = try_invert_to(a)
|
||||||
with parallel:
|
if b:
|
||||||
self.led0.pulse(100.*ms)
|
# self.core.reset()
|
||||||
self.led1.pulse(100.*ms)
|
# while True:
|
||||||
self.core.delay(100.*ms)
|
# with parallel:
|
||||||
|
# self.led0.pulse(100.*ms)
|
||||||
|
# self.led1.pulse(100.*ms)
|
||||||
|
# self.core.delay(100.*ms)
|
||||||
|
v = try_invert_to(np_identity(2))
|
||||||
|
if v:
|
||||||
|
while True:
|
||||||
|
with parallel:
|
||||||
|
self.led0.pulse(100.*ms)
|
||||||
|
self.led1.pulse(100.*ms)
|
||||||
|
self.core.delay(100.*ms)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,24 +0,0 @@
|
|||||||
from min_artiq import *
|
|
||||||
from numpy import int32
|
|
||||||
|
|
||||||
|
|
||||||
@nac3
|
|
||||||
class EmptyList:
|
|
||||||
core: KernelInvariant[Core]
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.core = Core()
|
|
||||||
|
|
||||||
@rpc
|
|
||||||
def get_empty(self) -> list[int32]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
@kernel
|
|
||||||
def run(self):
|
|
||||||
a: list[int32] = self.get_empty()
|
|
||||||
if a != []:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
EmptyList().run()
|
|
BIN
nac3artiq/demo/module.elf
Normal file
BIN
nac3artiq/demo/module.elf
Normal file
Binary file not shown.
@ -991,15 +991,8 @@ impl InnerResolver {
|
|||||||
}
|
}
|
||||||
_ => unreachable!("must be list"),
|
_ => unreachable!("must be list"),
|
||||||
};
|
};
|
||||||
|
let ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
let size_t = generator.get_size_type(ctx.ctx);
|
let size_t = generator.get_size_type(ctx.ctx);
|
||||||
let ty = if len == 0
|
|
||||||
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
|
|
||||||
{
|
|
||||||
// The default type for zero-length lists of unknown element type is size_t
|
|
||||||
size_t.into()
|
|
||||||
} else {
|
|
||||||
ctx.get_llvm_type(generator, elem_ty)
|
|
||||||
};
|
|
||||||
let arr_ty = ctx
|
let arr_ty = ctx
|
||||||
.ctx
|
.ctx
|
||||||
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);
|
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
use inkwell::types::BasicTypeEnum;
|
use inkwell::types::BasicTypeEnum;
|
||||||
use inkwell::values::{BasicValue, BasicValueEnum};
|
use inkwell::values::BasicValueEnum;
|
||||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor};
|
||||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||||
@ -31,6 +31,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let (n_ty, n) = n;
|
let (n_ty, n) = n;
|
||||||
|
|
||||||
Ok(match n {
|
Ok(match n {
|
||||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
|
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
|
||||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||||
@ -921,6 +922,122 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_try_invert_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "linalg_try_invert_to";
|
||||||
|
let (a_ty, a) = a;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match a {
|
||||||
|
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
|
{
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
match llvm_ndarray_ty {
|
||||||
|
BasicTypeEnum::FloatType(_) => {},
|
||||||
|
_ => unreachable!("Inverse Operation supported on float type NDArray Values only")
|
||||||
|
};
|
||||||
|
|
||||||
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
// Add asserts for dims
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
let n_dims = n.load_ndims(ctx);
|
||||||
|
|
||||||
|
// num_dim == 2
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
|
||||||
|
"0:ValueError", format!("Inverse only supported on 2D lists").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
|
||||||
|
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||||
|
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||||
|
|
||||||
|
// dim0 == dim1
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||||
|
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
|
let n_sz_eqz = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
n_sz_eqz,
|
||||||
|
"0:ValueError",
|
||||||
|
format!("zero-size array to reduction operation {FN_NAME}").as_str(),
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a call to linalg_try_invert_to
|
||||||
|
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||||
|
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||||
|
|
||||||
|
Ok(extern_fns::call_linalg_try_invert_to(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
|
||||||
|
}
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||||
|
let (a_ty, a) = a;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match a {
|
||||||
|
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
|
{
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
match llvm_ndarray_ty {
|
||||||
|
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {},
|
||||||
|
_ => unreachable!("Wilkinson Shift Operation supported on float/integer type NDArray Values only")
|
||||||
|
};
|
||||||
|
|
||||||
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
|
||||||
|
// Add asserts for dims
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
let n_dims = n.load_ndims(ctx);
|
||||||
|
|
||||||
|
// num_dim == 2
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
|
||||||
|
"0:ValueError", format!("Wilkinson Shift supported only on 2D lists").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
|
||||||
|
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||||
|
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||||
|
|
||||||
|
// dim0 == dim1
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||||
|
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
|
||||||
|
// dimesions should be 2x2
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim0.get_type().const_int(2, false), "").unwrap(),
|
||||||
|
"0:ValueError", format!("Wilkinson Shift supported only on 2x2 matrices").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a call to linalg_try_invert_to
|
||||||
|
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||||
|
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||||
|
|
||||||
|
Ok(extern_fns::call_linalg_wilkinson_shift(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
|
||||||
|
}
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the `np_maximum` builtin function.
|
/// Invokes the `np_maximum` builtin function.
|
||||||
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
@ -1834,763 +1951,3 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Linalg Methods
|
|
||||||
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
x2: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_dot";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let (x2_ty, x2) = x2;
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let one = llvm_usize.const_int(1, false);
|
|
||||||
|
|
||||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
|
|
||||||
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
|
|
||||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
|
||||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
|
||||||
|
|
||||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
|
||||||
else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
|
||||||
|
|
||||||
// The following constraints must be satisfied:
|
|
||||||
// * Input must be 1D
|
|
||||||
// * Number of elements in two matrices must equal
|
|
||||||
if cfg!(debug_assertions) {
|
|
||||||
let n1_dims = n1.load_ndims(ctx);
|
|
||||||
let n2_dims = n2.load_ndims(ctx);
|
|
||||||
|
|
||||||
let n1_dims_eq1 =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, one, "").unwrap();
|
|
||||||
let n2_dims_eq1 =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n2_dims, one, "").unwrap();
|
|
||||||
|
|
||||||
// num_dim = 1
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
n1_dims_eq1,
|
|
||||||
"0:ValueError",
|
|
||||||
format!("{FN_NAME} operates on 1D matrices").as_str(),
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
n2_dims_eq1,
|
|
||||||
"0:ValueError",
|
|
||||||
format!("{FN_NAME} operates on 1D matrices").as_str(),
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// equal number of elements
|
|
||||||
let n1_sz = irrt::call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
|
||||||
let n2_sz = irrt::call_ndarray_calc_size(generator, ctx, &n2.dim_sizes(), (None, None));
|
|
||||||
|
|
||||||
let size_eq =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap();
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
size_eq,
|
|
||||||
"0:ValueError",
|
|
||||||
format!("The operands of {FN_NAME} must have equal length").as_str(),
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(extern_fns::call_np_dot(
|
|
||||||
ctx,
|
|
||||||
(dim0, one, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, one, n2.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.into())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
x2: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_linalg_matmul";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let (x2_ty, x2) = x2;
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let one = llvm_usize.const_int(1, false);
|
|
||||||
let two = llvm_usize.const_int(2, false);
|
|
||||||
|
|
||||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
|
||||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
|
||||||
|
|
||||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
|
||||||
else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
|
||||||
|
|
||||||
// The following constraints must be satisfied:
|
|
||||||
// * Input must be 2D
|
|
||||||
// * Number of columns of first matrix should equal number of rows of second
|
|
||||||
if true {
|
|
||||||
let n1_dims = n1.load_ndims(ctx);
|
|
||||||
let n2_dims = n2.load_ndims(ctx);
|
|
||||||
|
|
||||||
let n1_dims_eq2 =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, two, "").unwrap();
|
|
||||||
let n2_dims_eq2 =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n2_dims, two, "").unwrap();
|
|
||||||
|
|
||||||
// num_dim = 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
n1_dims_eq2,
|
|
||||||
"0:ValueError",
|
|
||||||
format!("{FN_NAME} operates on 2D matrices").as_str(),
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
n2_dims_eq2,
|
|
||||||
"0:ValueError",
|
|
||||||
format!("{FN_NAME} operates on 2D matrices").as_str(),
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// matrix must be compatible for multiplication
|
|
||||||
let n1_col = unsafe {
|
|
||||||
n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
|
|
||||||
};
|
|
||||||
let n2_col = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let dim_eq =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n1_col, n2_col, "").unwrap();
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
dim_eq,
|
|
||||||
"0:ValueError",
|
|
||||||
format!("Columns of first matrix must equal rows of second for {FN_NAME}").as_str(),
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let out_dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let out_dim1 =
|
|
||||||
unsafe { n2.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
|
|
||||||
|
|
||||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[out_dim0, out_dim1])
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 =
|
|
||||||
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
|
|
||||||
let dim2 =
|
|
||||||
unsafe { n2.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
|
|
||||||
|
|
||||||
// let r = ctx.ctx.const_string(string, null_terminated);
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_matmul(
|
|
||||||
ctx,
|
|
||||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim1, dim2, n2.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, dim2, out.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
Ok(out.as_base_value().as_basic_value_enum())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_linalg_cholesky";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let one = llvm_usize.const_int(1, false);
|
|
||||||
let two = llvm_usize.const_int(2, false);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
|
|
||||||
// The following constraints must be satisfied:
|
|
||||||
// * Input must be 2D
|
|
||||||
// * Input must be a square matrix (here we assume it is symmetric)
|
|
||||||
if cfg!(debug_assertions) {
|
|
||||||
let n1_dims = n1.load_ndims(ctx);
|
|
||||||
|
|
||||||
let n1_dims_eq2 =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, two, "").unwrap();
|
|
||||||
|
|
||||||
// num_dim = 2
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
n1_dims_eq2,
|
|
||||||
"0:ValueError",
|
|
||||||
format!("{FN_NAME} operates on 2D matrices").as_str(),
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Square Matrix
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let dim_match =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap();
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
dim_match,
|
|
||||||
"0:ValueError",
|
|
||||||
format!("Input matrix must be a square matrix {FN_NAME}").as_str(),
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 =
|
|
||||||
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
|
|
||||||
|
|
||||||
let out =
|
|
||||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_cholesky(
|
|
||||||
ctx,
|
|
||||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, dim1, out.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
Ok(out.as_base_value().as_basic_value_enum())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_linalg_qr";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let one = llvm_usize.const_int(1, false);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 =
|
|
||||||
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
|
|
||||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
|
||||||
|
|
||||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
|
|
||||||
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_qr(
|
|
||||||
ctx,
|
|
||||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, k, out_q.data().base_ptr(ctx, generator)),
|
|
||||||
(k, dim1, out_r.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
let out_q = out_q.as_base_value().as_basic_value_enum();
|
|
||||||
let out_r = out_r.as_base_value().as_basic_value_enum();
|
|
||||||
let res_ty = ctx.ctx.struct_type(&[out_q.get_type(), out_r.get_type()], false);
|
|
||||||
let res_ptr = ctx.builder.build_alloca(res_ty, "QR_factorization").unwrap();
|
|
||||||
|
|
||||||
let res_val = [out_q, out_r];
|
|
||||||
for (i, v) in res_val.into_iter().enumerate() {
|
|
||||||
unsafe {
|
|
||||||
let ptr = ctx
|
|
||||||
.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
res_ptr,
|
|
||||||
&[
|
|
||||||
ctx.ctx.i32_type().const_zero(),
|
|
||||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
|
||||||
],
|
|
||||||
"ptr",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
ctx.builder.build_store(ptr, v).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_load(res_ptr, "QR_Factorization_result").map(Into::into).unwrap())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_linalg_svd";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
|
||||||
|
|
||||||
let out_u =
|
|
||||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
|
|
||||||
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]).unwrap();
|
|
||||||
|
|
||||||
let out_vh =
|
|
||||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]).unwrap();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_svd(
|
|
||||||
ctx,
|
|
||||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, dim0, out_u.data().base_ptr(ctx, generator)),
|
|
||||||
(k, llvm_usize.const_int(1, false), out_s.data().base_ptr(ctx, generator)),
|
|
||||||
(dim1, dim1, out_vh.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let out_u = out_u.as_base_value().as_basic_value_enum();
|
|
||||||
let out_s = out_s.as_base_value().as_basic_value_enum();
|
|
||||||
let out_vh = out_vh.as_base_value().as_basic_value_enum();
|
|
||||||
|
|
||||||
let res_ty =
|
|
||||||
ctx.ctx.struct_type(&[out_u.get_type(), out_s.get_type(), out_vh.get_type()], false);
|
|
||||||
let res_ptr = ctx.builder.build_alloca(res_ty, "SVD_factorization").unwrap();
|
|
||||||
|
|
||||||
let res_val = [out_u, out_s, out_vh];
|
|
||||||
for (i, v) in res_val.into_iter().enumerate() {
|
|
||||||
unsafe {
|
|
||||||
let ptr = ctx
|
|
||||||
.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
res_ptr,
|
|
||||||
&[
|
|
||||||
ctx.ctx.i32_type().const_zero(),
|
|
||||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
|
||||||
],
|
|
||||||
"ptr",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
ctx.builder.build_store(ptr, v).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_load(res_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_linalg_inv";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let out =
|
|
||||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_inv(
|
|
||||||
ctx,
|
|
||||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, dim1, out.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
Ok(out.as_base_value().as_basic_value_enum())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_linalg_pinv";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let out =
|
|
||||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]).unwrap();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_pinv(
|
|
||||||
ctx,
|
|
||||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim1, dim0, out.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
Ok(out.as_base_value().as_basic_value_enum())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "sp_linalg_lu";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
|
||||||
|
|
||||||
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
|
|
||||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
|
|
||||||
|
|
||||||
extern_fns::call_sp_linalg_lu(
|
|
||||||
ctx,
|
|
||||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, k, out_l.data().base_ptr(ctx, generator)),
|
|
||||||
(k, dim1, out_u.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let out_l = out_l.as_base_value().as_basic_value_enum();
|
|
||||||
let out_u = out_u.as_base_value().as_basic_value_enum();
|
|
||||||
|
|
||||||
let res_ty = ctx.ctx.struct_type(&[out_l.get_type(), out_u.get_type()], false);
|
|
||||||
let res_ptr = ctx.builder.build_alloca(res_ty, "LU_factorization").unwrap();
|
|
||||||
|
|
||||||
let res_val = [out_l, out_u];
|
|
||||||
for (i, v) in res_val.into_iter().enumerate() {
|
|
||||||
unsafe {
|
|
||||||
let ptr = ctx
|
|
||||||
.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
res_ptr,
|
|
||||||
&[
|
|
||||||
ctx.ctx.i32_type().const_zero(),
|
|
||||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
|
||||||
],
|
|
||||||
"ptr",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
ctx.builder.build_store(ptr, v).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_load(res_ptr, "LU_Factorization_result").map(Into::into).unwrap())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Must be square (add check later)
|
|
||||||
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "sp_linalg_schur";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let out_t =
|
|
||||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
|
|
||||||
let out_z =
|
|
||||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
|
|
||||||
|
|
||||||
extern_fns::call_sp_linalg_schur(
|
|
||||||
ctx,
|
|
||||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, dim0, out_t.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, dim0, out_z.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let out_t = out_t.as_base_value().as_basic_value_enum();
|
|
||||||
let out_z = out_z.as_base_value().as_basic_value_enum();
|
|
||||||
|
|
||||||
let res_ty = ctx.ctx.struct_type(&[out_t.get_type(), out_z.get_type()], false);
|
|
||||||
|
|
||||||
let res_ptr = ctx.builder.build_alloca(res_ty, "Schur_factorization").unwrap();
|
|
||||||
let r = ctx
|
|
||||||
.ctx
|
|
||||||
.const_string(ctx.current_loc.file.0.to_string().as_bytes(), true)
|
|
||||||
.as_basic_value_enum()
|
|
||||||
.into_pointer_value();
|
|
||||||
let res_val = [out_t, out_z];
|
|
||||||
for (i, v) in res_val.into_iter().enumerate() {
|
|
||||||
unsafe {
|
|
||||||
let ptr = ctx
|
|
||||||
.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
res_ptr,
|
|
||||||
&[
|
|
||||||
ctx.ctx.i32_type().const_zero(),
|
|
||||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
|
||||||
],
|
|
||||||
"ptr",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
ctx.builder.build_store(ptr, v).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_load(res_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Must be square (add check later)
|
|
||||||
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "sp_linalg_hessenberg";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
|
||||||
};
|
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
|
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.dim_sizes()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check if matrix is square
|
|
||||||
// ctx.builder.build_select(
|
|
||||||
// ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
|
||||||
// {
|
|
||||||
// let func =
|
|
||||||
// }, else_, name)
|
|
||||||
// ;
|
|
||||||
|
|
||||||
// ctx.builder.build_call(
|
|
||||||
// ctx.module.get_function("__nac3_raise"),
|
|
||||||
// &[]
|
|
||||||
|
|
||||||
// )
|
|
||||||
// let err_msg = ctx.gen_string(generator, "{FN_NAME} requires square matrix");
|
|
||||||
// ctx.raise_exn(generator, "0:ValueError", err_msg, [None, None, None], ctx.current_loc);
|
|
||||||
|
|
||||||
let out_h =
|
|
||||||
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
|
|
||||||
|
|
||||||
extern_fns::call_sp_linalg_hessenberg(
|
|
||||||
ctx,
|
|
||||||
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
|
|
||||||
(dim0, dim0, out_h.data().base_ptr(ctx, generator)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(out_h.as_base_value().as_basic_value_enum())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -951,9 +951,9 @@ pub fn destructure_range<'ctx>(
|
|||||||
/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting
|
/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting
|
||||||
/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified.
|
/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified.
|
||||||
///
|
///
|
||||||
/// Setting `ty` to [`None`] implies that the list is empty **and** does not have a known element
|
/// Setting `ty` to [`None`] implies that the list does not have a known element type, which is only
|
||||||
/// type, and will therefore set the `list.data` type as `size_t*`. It is undefined behavior to
|
/// valid for empty lists. It is undefined behavior to generate a sized list with an unknown element
|
||||||
/// generate a sized list with an unknown element type.
|
/// type.
|
||||||
pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -131,153 +131,81 @@ pub fn call_ldexp<'ctx>(
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Macro to generate np_linalg external functions
|
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
|
||||||
macro_rules! generate_np_linalg_extern_fn {
|
pub fn call_linalg_try_invert_to<'ctx>(
|
||||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => {
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1);
|
dim0: IntValue<'ctx>,
|
||||||
};
|
dim1: IntValue<'ctx>,
|
||||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => {
|
data: PointerValue<'ctx>,
|
||||||
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2);
|
name: Option<&str>,
|
||||||
};
|
) -> IntValue<'ctx> {
|
||||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => {
|
const FN_NAME: &str = "linalg_try_invert_to";
|
||||||
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3);
|
|
||||||
};
|
|
||||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => {
|
|
||||||
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4);
|
|
||||||
};
|
|
||||||
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => {
|
|
||||||
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
|
|
||||||
pub fn $fn_name<'ctx>(
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>
|
|
||||||
$(,$input_matrix: (IntValue<'ctx>, IntValue<'ctx>, PointerValue<'ctx>))*,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> $ret_ty<'ctx> {
|
|
||||||
const FN_NAME: &str = $extern_fn;
|
|
||||||
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||||
|
|
||||||
|
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
|
||||||
|
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
|
||||||
|
|
||||||
$(
|
debug_assert!(allowed_dim0);
|
||||||
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.0.get_type()));
|
debug_assert!(allowed_dim1);
|
||||||
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.1.get_type()));
|
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||||
debug_assert_eq!($input_matrix.2.get_type().get_element_type().into_float_type(), llvm_f64);
|
|
||||||
)*
|
|
||||||
|
|
||||||
// let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false);
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||||
// let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false);
|
let fn_type = ctx.ctx.i8_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
|
||||||
// let file_name = ctx.current_loc.file.0;
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||||
// let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false);
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||||
// let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true);
|
func.add_attribute(
|
||||||
|
AttributeLoc::Function,
|
||||||
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
func
|
||||||
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
|
});
|
||||||
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
|
|
||||||
let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
|
ctx.builder
|
||||||
|
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
func.add_attribute(
|
.map(Either::unwrap_left)
|
||||||
AttributeLoc::Function,
|
.unwrap().into()
|
||||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
func
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
// .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
|
|
||||||
// .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
|
|
||||||
.build_call(extern_fn, &[$($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left($map_fn))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
generate_np_linalg_extern_fn!(
|
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
|
||||||
call_np_dot,
|
pub fn call_linalg_wilkinson_shift<'ctx>(
|
||||||
FloatValue,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
f64_type,
|
dim0: IntValue<'ctx>,
|
||||||
BasicValueEnum::into_float_value,
|
dim1: IntValue<'ctx>,
|
||||||
"np_dot",
|
data: PointerValue<'ctx>,
|
||||||
2
|
name: Option<&str>,
|
||||||
);
|
) -> FloatValue<'ctx> {
|
||||||
generate_np_linalg_extern_fn!(
|
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||||
call_np_linalg_matmul,
|
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||||
IntValue,
|
|
||||||
i8_type,
|
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
|
||||||
BasicValueEnum::into_int_value,
|
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
|
||||||
"np_linalg_matmul",
|
|
||||||
3
|
|
||||||
);
|
|
||||||
generate_np_linalg_extern_fn!(
|
|
||||||
call_np_linalg_cholesky,
|
|
||||||
IntValue,
|
|
||||||
i8_type,
|
|
||||||
BasicValueEnum::into_int_value,
|
|
||||||
"np_linalg_cholesky",
|
|
||||||
2
|
|
||||||
);
|
|
||||||
generate_np_linalg_extern_fn!(
|
|
||||||
call_np_linalg_qr,
|
|
||||||
IntValue,
|
|
||||||
i8_type,
|
|
||||||
BasicValueEnum::into_int_value,
|
|
||||||
"np_linalg_qr",
|
|
||||||
3
|
|
||||||
);
|
|
||||||
generate_np_linalg_extern_fn!(
|
|
||||||
call_np_linalg_svd,
|
|
||||||
IntValue,
|
|
||||||
i8_type,
|
|
||||||
BasicValueEnum::into_int_value,
|
|
||||||
"np_linalg_svd",
|
|
||||||
4
|
|
||||||
);
|
|
||||||
|
|
||||||
generate_np_linalg_extern_fn!(
|
debug_assert!(allowed_dim0);
|
||||||
call_np_linalg_inv,
|
debug_assert!(allowed_dim1);
|
||||||
IntValue,
|
|
||||||
i8_type,
|
|
||||||
BasicValueEnum::into_int_value,
|
|
||||||
"np_linalg_inv",
|
|
||||||
2
|
|
||||||
);
|
|
||||||
|
|
||||||
generate_np_linalg_extern_fn!(
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||||
call_np_linalg_pinv,
|
let fn_type = ctx.ctx.f64_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
|
||||||
IntValue,
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||||
i8_type,
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||||
BasicValueEnum::into_int_value,
|
func.add_attribute(
|
||||||
"np_linalg_pinv",
|
AttributeLoc::Function,
|
||||||
2
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
generate_np_linalg_extern_fn!(
|
func
|
||||||
call_sp_linalg_lu,
|
});
|
||||||
IntValue,
|
|
||||||
i8_type,
|
ctx.builder
|
||||||
BasicValueEnum::into_int_value,
|
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||||
"sp_linalg_lu",
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
3
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||||
);
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap().into()
|
||||||
generate_np_linalg_extern_fn!(
|
}
|
||||||
call_sp_linalg_schur,
|
|
||||||
IntValue,
|
|
||||||
i8_type,
|
|
||||||
BasicValueEnum::into_int_value,
|
|
||||||
"sp_linalg_schur",
|
|
||||||
3
|
|
||||||
);
|
|
||||||
|
|
||||||
generate_np_linalg_extern_fn!(
|
|
||||||
call_sp_linalg_hessenberg,
|
|
||||||
IntValue,
|
|
||||||
i8_type,
|
|
||||||
BasicValueEnum::into_int_value,
|
|
||||||
"sp_linalg_hessenberg",
|
|
||||||
2
|
|
||||||
);
|
|
||||||
|
@ -61,7 +61,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
/// * `shape` - The shape of the `NDArray`.
|
/// * `shape` - The shape of the `NDArray`.
|
||||||
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
|
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
|
||||||
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
|
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
|
||||||
pub fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
@ -157,7 +157,7 @@ where
|
|||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
|
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
|
||||||
pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
|
@ -1637,7 +1637,7 @@ pub fn gen_stmt<G: CodeGenerator>(
|
|||||||
};
|
};
|
||||||
ctx.make_assert_impl(
|
ctx.make_assert_impl(
|
||||||
generator,
|
generator,
|
||||||
generator.bool_to_i1(ctx, test.into_int_value()),
|
test.into_int_value(),
|
||||||
"0:AssertionError",
|
"0:AssertionError",
|
||||||
err_msg,
|
err_msg,
|
||||||
[None, None, None],
|
[None, None, None],
|
||||||
|
@ -557,18 +557,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
| PrimDef::FunNpHypot
|
| PrimDef::FunNpHypot
|
||||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||||
|
|
||||||
PrimDef::FunNpDot
|
PrimDef::FunTryInvertTo => self.build_linalg_try_invert_to(prim), // Inplace invert
|
||||||
| PrimDef::FunNpLinalgMatmul
|
PrimDef::FunWilkinsonShift => self.build_linalg_wilkinson_shift(prim),
|
||||||
| PrimDef::FunNpLinalgCholesky
|
|
||||||
| PrimDef::FunNpLinalgQr
|
|
||||||
| PrimDef::FunNpLinalgSvd
|
|
||||||
| PrimDef::FunNpLinalgInv
|
|
||||||
| PrimDef::FunNpLinalgPinv
|
|
||||||
| PrimDef::FunSpLinalgLu
|
|
||||||
| PrimDef::FunSpLinalgSchur
|
|
||||||
| PrimDef::FunSpLinalgHessenberg => self.build_np_linalg_methods(prim),
|
|
||||||
// PrimDef::FunNpDot | PrimDef::FunNpLinalgMatmul => self.build_np_linalg_binary_methods(prim),
|
|
||||||
// PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr => self.build_np_linalg_unary_methods(prim),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if cfg!(debug_assertions) {
|
if cfg!(debug_assertions) {
|
||||||
@ -577,7 +567,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
match (&tld, prim.details()) {
|
match (&tld, prim.details()) {
|
||||||
(
|
(
|
||||||
TopLevelDef::Class { name, object_id, .. },
|
TopLevelDef::Class { name, object_id, .. },
|
||||||
PrimDefDetails::PrimClass { name: exp_name, .. },
|
PrimDefDetails::PrimClass { name: exp_name },
|
||||||
) => {
|
) => {
|
||||||
let exp_object_id = prim.id();
|
let exp_object_id = prim.id();
|
||||||
assert_eq!(name, &exp_name.into());
|
assert_eq!(name, &exp_name.into());
|
||||||
@ -1887,140 +1877,62 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_np_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
|
fn build_linalg_try_invert_to(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(
|
debug_assert_prim_is_allowed(
|
||||||
prim,
|
prim,
|
||||||
&[
|
&[
|
||||||
PrimDef::FunNpDot,
|
PrimDef::FunTryInvertTo,
|
||||||
PrimDef::FunNpLinalgMatmul,
|
|
||||||
PrimDef::FunNpLinalgCholesky,
|
|
||||||
PrimDef::FunNpLinalgQr,
|
|
||||||
PrimDef::FunNpLinalgSvd,
|
|
||||||
PrimDef::FunNpLinalgInv,
|
|
||||||
PrimDef::FunNpLinalgPinv,
|
|
||||||
PrimDef::FunSpLinalgLu,
|
|
||||||
PrimDef::FunSpLinalgSchur,
|
|
||||||
PrimDef::FunSpLinalgHessenberg,
|
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let var_map = self.num_or_ndarray_var_map.clone();
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&var_map,
|
||||||
|
prim.name(),
|
||||||
|
self.primitives.bool,
|
||||||
|
&[(self.ndarray_float_2d, "x")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x_ty = fun.0.args[0].ty;
|
||||||
|
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||||
|
|
||||||
match prim {
|
let func = match prim {
|
||||||
PrimDef::FunNpDot => create_fn_by_codegen(
|
PrimDef::FunTryInvertTo => builtin_fns::call_try_invert_to,
|
||||||
self.unifier,
|
_ => unreachable!(),
|
||||||
&self.num_or_ndarray_var_map,
|
};
|
||||||
prim.name(),
|
|
||||||
self.primitives.float,
|
|
||||||
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
|
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let x1_ty = fun.0.args[0].ty;
|
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
|
||||||
let x2_ty = fun.0.args[1].ty;
|
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
|
||||||
|
|
||||||
Ok(Some(builtin_fns::call_np_dot(
|
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||||
generator,
|
}),
|
||||||
ctx,
|
)
|
||||||
(x1_ty, x1_val),
|
}
|
||||||
(x2_ty, x2_val),
|
|
||||||
)?))
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
|
|
||||||
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
|
fn build_linalg_wilkinson_shift(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
self.unifier,
|
debug_assert_prim_is_allowed(
|
||||||
&VarMap::new(),
|
prim,
|
||||||
prim.name(),
|
&[
|
||||||
self.ndarray_float_2d,
|
PrimDef::FunWilkinsonShift,
|
||||||
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
|
],
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
);
|
||||||
let x1_ty = fun.0.args[0].ty;
|
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
let var_map = self.num_or_ndarray_var_map.clone();
|
||||||
let x2_ty = fun.0.args[1].ty;
|
create_fn_by_codegen(
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
self.unifier,
|
||||||
|
&var_map,
|
||||||
|
prim.name(),
|
||||||
|
self.primitives.float,
|
||||||
|
&[(self.ndarray_float_2d, "x")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x_ty = fun.0.args[0].ty;
|
||||||
|
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||||
|
|
||||||
Ok(Some(builtin_fns::call_np_linalg_matmul(
|
let func = match prim {
|
||||||
generator,
|
PrimDef::FunWilkinsonShift => builtin_fns::call_wilkinson_shift,
|
||||||
ctx,
|
_ => unreachable!(),
|
||||||
(x1_ty, x1_val),
|
};
|
||||||
(x2_ty, x2_val),
|
|
||||||
)?))
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
|
|
||||||
PrimDef::FunNpLinalgCholesky
|
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||||
| PrimDef::FunNpLinalgInv
|
}),
|
||||||
| PrimDef::FunNpLinalgPinv
|
)
|
||||||
| PrimDef::FunSpLinalgHessenberg => create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&VarMap::new(),
|
|
||||||
prim.name(),
|
|
||||||
self.ndarray_float_2d,
|
|
||||||
&[(self.ndarray_float_2d, "x1")],
|
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let x1_ty = fun.0.args[0].ty;
|
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
|
||||||
|
|
||||||
let func = match prim {
|
|
||||||
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
|
|
||||||
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
|
|
||||||
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
|
|
||||||
PrimDef::FunSpLinalgHessenberg => builtin_fns::call_sp_linalg_hessenberg,
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
|
|
||||||
PrimDef::FunNpLinalgQr | PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgSchur => {
|
|
||||||
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
|
|
||||||
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
|
|
||||||
});
|
|
||||||
create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&VarMap::new(),
|
|
||||||
prim.name(),
|
|
||||||
ret_ty,
|
|
||||||
&[(self.ndarray_float_2d, "x1")],
|
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let x1_ty = fun.0.args[0].ty;
|
|
||||||
let x1_val =
|
|
||||||
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
|
||||||
|
|
||||||
let func = match prim {
|
|
||||||
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
|
|
||||||
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
|
|
||||||
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
PrimDef::FunNpLinalgSvd => {
|
|
||||||
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
|
|
||||||
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
|
|
||||||
});
|
|
||||||
create_fn_by_codegen(
|
|
||||||
self.unifier,
|
|
||||||
&VarMap::new(),
|
|
||||||
prim.name(),
|
|
||||||
ret_ty,
|
|
||||||
&[(self.ndarray_float_2d, "x1")],
|
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
|
||||||
let x1_ty = fun.0.args[0].ty;
|
|
||||||
let x1_val =
|
|
||||||
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
|
||||||
|
|
||||||
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
println!("{:?}", prim.name());
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||||
|
@ -766,7 +766,6 @@ impl TopLevelComposer {
|
|||||||
let target_ty = get_type_from_type_annotation_kinds(
|
let target_ty = get_type_from_type_annotation_kinds(
|
||||||
&temp_def_list,
|
&temp_def_list,
|
||||||
unifier,
|
unifier,
|
||||||
primitives,
|
|
||||||
&def,
|
&def,
|
||||||
&mut subst_list,
|
&mut subst_list,
|
||||||
)?;
|
)?;
|
||||||
@ -937,7 +936,6 @@ impl TopLevelComposer {
|
|||||||
let ty = get_type_from_type_annotation_kinds(
|
let ty = get_type_from_type_annotation_kinds(
|
||||||
temp_def_list.as_ref(),
|
temp_def_list.as_ref(),
|
||||||
unifier,
|
unifier,
|
||||||
primitives_store,
|
|
||||||
&type_annotation,
|
&type_annotation,
|
||||||
&mut None,
|
&mut None,
|
||||||
)?;
|
)?;
|
||||||
@ -1004,7 +1002,6 @@ impl TopLevelComposer {
|
|||||||
get_type_from_type_annotation_kinds(
|
get_type_from_type_annotation_kinds(
|
||||||
&temp_def_list,
|
&temp_def_list,
|
||||||
unifier,
|
unifier,
|
||||||
primitives_store,
|
|
||||||
&return_ty_annotation,
|
&return_ty_annotation,
|
||||||
&mut None,
|
&mut None,
|
||||||
)?
|
)?
|
||||||
@ -1625,7 +1622,6 @@ impl TopLevelComposer {
|
|||||||
let self_type = get_type_from_type_annotation_kinds(
|
let self_type = get_type_from_type_annotation_kinds(
|
||||||
&def_list,
|
&def_list,
|
||||||
unifier,
|
unifier,
|
||||||
primitives_ty,
|
|
||||||
&make_self_type_annotation(type_vars, *object_id),
|
&make_self_type_annotation(type_vars, *object_id),
|
||||||
&mut None,
|
&mut None,
|
||||||
)?;
|
)?;
|
||||||
@ -1807,11 +1803,7 @@ impl TopLevelComposer {
|
|||||||
|
|
||||||
let ty_ann = make_self_type_annotation(type_vars, *class_id);
|
let ty_ann = make_self_type_annotation(type_vars, *class_id);
|
||||||
let self_ty = get_type_from_type_annotation_kinds(
|
let self_ty = get_type_from_type_annotation_kinds(
|
||||||
&def_list,
|
&def_list, unifier, &ty_ann, &mut None,
|
||||||
unifier,
|
|
||||||
primitives_ty,
|
|
||||||
&ty_ann,
|
|
||||||
&mut None,
|
|
||||||
)?;
|
)?;
|
||||||
vars.extend(type_vars.iter().map(|ty| {
|
vars.extend(type_vars.iter().map(|ty| {
|
||||||
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {
|
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {
|
||||||
|
@ -105,16 +105,8 @@ pub enum PrimDef {
|
|||||||
FunNpLdExp,
|
FunNpLdExp,
|
||||||
FunNpHypot,
|
FunNpHypot,
|
||||||
FunNpNextAfter,
|
FunNpNextAfter,
|
||||||
FunNpDot,
|
FunTryInvertTo,
|
||||||
FunNpLinalgMatmul,
|
FunWilkinsonShift,
|
||||||
FunNpLinalgCholesky,
|
|
||||||
FunNpLinalgQr,
|
|
||||||
FunNpLinalgSvd,
|
|
||||||
FunNpLinalgInv,
|
|
||||||
FunNpLinalgPinv,
|
|
||||||
FunSpLinalgLu,
|
|
||||||
FunSpLinalgSchur,
|
|
||||||
FunSpLinalgHessenberg,
|
|
||||||
|
|
||||||
// Top-Level Functions
|
// Top-Level Functions
|
||||||
FunSome,
|
FunSome,
|
||||||
@ -123,7 +115,7 @@ pub enum PrimDef {
|
|||||||
/// Associated details of a [`PrimDef`]
|
/// Associated details of a [`PrimDef`]
|
||||||
pub enum PrimDefDetails {
|
pub enum PrimDefDetails {
|
||||||
PrimFunction { name: &'static str, simple_name: &'static str },
|
PrimFunction { name: &'static str, simple_name: &'static str },
|
||||||
PrimClass { name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type },
|
PrimClass { name: &'static str },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PrimDef {
|
impl PrimDef {
|
||||||
@ -165,17 +157,15 @@ impl PrimDef {
|
|||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn name(&self) -> &'static str {
|
pub fn name(&self) -> &'static str {
|
||||||
match self.details() {
|
match self.details() {
|
||||||
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name, .. } => {
|
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name,
|
||||||
name
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the associated details of this [`PrimDef`]
|
/// Get the associated details of this [`PrimDef`]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn details(self) -> PrimDefDetails {
|
pub fn details(self) -> PrimDefDetails {
|
||||||
fn class(name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type) -> PrimDefDetails {
|
fn class(name: &'static str) -> PrimDefDetails {
|
||||||
PrimDefDetails::PrimClass { name, get_ty_fn }
|
PrimDefDetails::PrimClass { name }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
|
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
|
||||||
@ -183,22 +173,22 @@ impl PrimDef {
|
|||||||
}
|
}
|
||||||
|
|
||||||
match self {
|
match self {
|
||||||
PrimDef::Int32 => class("int32", |primitives| primitives.int32),
|
PrimDef::Int32 => class("int32"),
|
||||||
PrimDef::Int64 => class("int64", |primitives| primitives.int64),
|
PrimDef::Int64 => class("int64"),
|
||||||
PrimDef::Float => class("float", |primitives| primitives.float),
|
PrimDef::Float => class("float"),
|
||||||
PrimDef::Bool => class("bool", |primitives| primitives.bool),
|
PrimDef::Bool => class("bool"),
|
||||||
PrimDef::None => class("none", |primitives| primitives.none),
|
PrimDef::None => class("none"),
|
||||||
PrimDef::Range => class("range", |primitives| primitives.range),
|
PrimDef::Range => class("range"),
|
||||||
PrimDef::Str => class("str", |primitives| primitives.str),
|
PrimDef::Str => class("str"),
|
||||||
PrimDef::Exception => class("Exception", |primitives| primitives.exception),
|
PrimDef::Exception => class("Exception"),
|
||||||
PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32),
|
PrimDef::UInt32 => class("uint32"),
|
||||||
PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64),
|
PrimDef::UInt64 => class("uint64"),
|
||||||
PrimDef::Option => class("Option", |primitives| primitives.option),
|
PrimDef::Option => class("Option"),
|
||||||
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
|
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
|
||||||
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
|
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
|
||||||
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
|
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
|
||||||
PrimDef::List => class("list", |primitives| primitives.list),
|
PrimDef::List => class("list"),
|
||||||
PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray),
|
PrimDef::NDArray => class("ndarray"),
|
||||||
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
|
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
|
||||||
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
|
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
|
||||||
PrimDef::FunInt32 => fun("int32", None),
|
PrimDef::FunInt32 => fun("int32", None),
|
||||||
@ -273,17 +263,8 @@ impl PrimDef {
|
|||||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||||
PrimDef::FunNpDot => fun("np_dot", None),
|
PrimDef::FunTryInvertTo => fun("try_invert_to", None),
|
||||||
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
|
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None),
|
||||||
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
|
||||||
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
|
|
||||||
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
|
|
||||||
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
|
|
||||||
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
|
|
||||||
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
|
|
||||||
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
|
|
||||||
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
|
|
||||||
|
|
||||||
PrimDef::FunSome => fun("Some", None),
|
PrimDef::FunSome => fun("Some", None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::helper::{PrimDef, PrimDefDetails};
|
use crate::toplevel::helper::PrimDef;
|
||||||
use crate::typecheck::typedef::VarMap;
|
use crate::typecheck::typedef::VarMap;
|
||||||
use nac3parser::ast::Constant;
|
use nac3parser::ast::Constant;
|
||||||
use strum::IntoEnumIterator;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub enum TypeAnnotation {
|
pub enum TypeAnnotation {
|
||||||
@ -358,7 +357,6 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
|||||||
pub fn get_type_from_type_annotation_kinds(
|
pub fn get_type_from_type_annotation_kinds(
|
||||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
primitives: &PrimitiveStore,
|
|
||||||
ann: &TypeAnnotation,
|
ann: &TypeAnnotation,
|
||||||
subst_list: &mut Option<Vec<Type>>,
|
subst_list: &mut Option<Vec<Type>>,
|
||||||
) -> Result<Type, HashSet<String>> {
|
) -> Result<Type, HashSet<String>> {
|
||||||
@ -381,141 +379,100 @@ pub fn get_type_from_type_annotation_kinds(
|
|||||||
let param_ty = params
|
let param_ty = params
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| {
|
.map(|x| {
|
||||||
get_type_from_type_annotation_kinds(
|
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
|
||||||
top_level_defs,
|
|
||||||
unifier,
|
|
||||||
primitives,
|
|
||||||
x,
|
|
||||||
subst_list,
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) {
|
let subst = {
|
||||||
// Primitive TopLevelDefs do not contain all fields that are present in their Type
|
// check for compatible range
|
||||||
// counterparts, so directly perform subst on the Type instead.
|
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
|
||||||
|
let mut result = VarMap::new();
|
||||||
let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else {
|
for (tvar, p) in type_vars.iter().zip(param_ty) {
|
||||||
unreachable!()
|
match unifier.get_ty(*tvar).as_ref() {
|
||||||
};
|
TypeEnum::TVar {
|
||||||
|
id,
|
||||||
let base_ty = get_ty_fn(primitives);
|
range,
|
||||||
let params =
|
fields: None,
|
||||||
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) {
|
name,
|
||||||
params.clone()
|
loc,
|
||||||
} else {
|
is_const_generic: false,
|
||||||
unreachable!()
|
} => {
|
||||||
};
|
let ok: bool = {
|
||||||
|
// create a temp type var and unify to check compatibility
|
||||||
unifier
|
p == *tvar || {
|
||||||
.subst(
|
let temp = unifier.get_fresh_var_with_range(
|
||||||
get_ty_fn(primitives),
|
range.as_slice(),
|
||||||
¶ms
|
*name,
|
||||||
.iter()
|
*loc,
|
||||||
.zip(param_ty)
|
);
|
||||||
.map(|(obj_tv, param)| (*obj_tv.0, param))
|
unifier.unify(temp.ty, p).is_ok()
|
||||||
.collect(),
|
|
||||||
)
|
|
||||||
.unwrap_or(base_ty)
|
|
||||||
} else {
|
|
||||||
let subst = {
|
|
||||||
// check for compatible range
|
|
||||||
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
|
|
||||||
let mut result = VarMap::new();
|
|
||||||
for (tvar, p) in type_vars.iter().zip(param_ty) {
|
|
||||||
match unifier.get_ty(*tvar).as_ref() {
|
|
||||||
TypeEnum::TVar {
|
|
||||||
id,
|
|
||||||
range,
|
|
||||||
fields: None,
|
|
||||||
name,
|
|
||||||
loc,
|
|
||||||
is_const_generic: false,
|
|
||||||
} => {
|
|
||||||
let ok: bool = {
|
|
||||||
// create a temp type var and unify to check compatibility
|
|
||||||
p == *tvar || {
|
|
||||||
let temp = unifier.get_fresh_var_with_range(
|
|
||||||
range.as_slice(),
|
|
||||||
*name,
|
|
||||||
*loc,
|
|
||||||
);
|
|
||||||
unifier.unify(temp.ty, p).is_ok()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if ok {
|
|
||||||
result.insert(*id, p);
|
|
||||||
} else {
|
|
||||||
return Err(HashSet::from([format!(
|
|
||||||
"cannot apply type {} to type variable with id {:?}",
|
|
||||||
unifier.internal_stringify(
|
|
||||||
p,
|
|
||||||
&mut |id| format!("class{id}"),
|
|
||||||
&mut |id| format!("typevar{id}"),
|
|
||||||
&mut None
|
|
||||||
),
|
|
||||||
*id
|
|
||||||
)]));
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
if ok {
|
||||||
|
result.insert(*id, p);
|
||||||
|
} else {
|
||||||
|
return Err(HashSet::from([format!(
|
||||||
|
"cannot apply type {} to type variable with id {:?}",
|
||||||
|
unifier.internal_stringify(
|
||||||
|
p,
|
||||||
|
&mut |id| format!("class{id}"),
|
||||||
|
&mut |id| format!("typevar{id}"),
|
||||||
|
&mut None
|
||||||
|
),
|
||||||
|
*id
|
||||||
|
)]));
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeEnum::TVar {
|
|
||||||
id, range, name, loc, is_const_generic: true, ..
|
|
||||||
} => {
|
|
||||||
let ty = range[0];
|
|
||||||
let ok: bool = {
|
|
||||||
// create a temp type var and unify to check compatibility
|
|
||||||
p == *tvar || {
|
|
||||||
let temp =
|
|
||||||
unifier.get_fresh_const_generic_var(ty, *name, *loc);
|
|
||||||
unifier.unify(temp.ty, p).is_ok()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if ok {
|
|
||||||
result.insert(*id, p);
|
|
||||||
} else {
|
|
||||||
return Err(HashSet::from([format!(
|
|
||||||
"cannot apply type {} to type variable {}",
|
|
||||||
unifier.stringify(p),
|
|
||||||
name.unwrap_or_else(|| format!("typevar{id}").into()),
|
|
||||||
)]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => unreachable!("must be generic type var"),
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
result
|
|
||||||
};
|
|
||||||
// Class Attributes keep a copy with Class Definition and are not added to objects
|
|
||||||
let mut tobj_fields = methods
|
|
||||||
.iter()
|
|
||||||
.map(|(name, ty, _)| {
|
|
||||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
|
||||||
// methods are immutable
|
|
||||||
(*name, (subst_ty, false))
|
|
||||||
})
|
|
||||||
.collect::<HashMap<_, _>>();
|
|
||||||
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
|
|
||||||
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
|
||||||
(*name, (subst_ty, *mutability))
|
|
||||||
}));
|
|
||||||
let need_subst = !subst.is_empty();
|
|
||||||
let ty = unifier.add_ty(TypeEnum::TObj {
|
|
||||||
obj_id: *obj_id,
|
|
||||||
fields: tobj_fields,
|
|
||||||
params: subst,
|
|
||||||
});
|
|
||||||
|
|
||||||
if need_subst {
|
TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
|
||||||
if let Some(wl) = subst_list.as_mut() {
|
let ty = range[0];
|
||||||
wl.push(ty);
|
let ok: bool = {
|
||||||
|
// create a temp type var and unify to check compatibility
|
||||||
|
p == *tvar || {
|
||||||
|
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
|
||||||
|
unifier.unify(temp.ty, p).is_ok()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if ok {
|
||||||
|
result.insert(*id, p);
|
||||||
|
} else {
|
||||||
|
return Err(HashSet::from([format!(
|
||||||
|
"cannot apply type {} to type variable {}",
|
||||||
|
unifier.stringify(p),
|
||||||
|
name.unwrap_or_else(|| format!("typevar{id}").into()),
|
||||||
|
)]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unreachable!("must be generic type var"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
result
|
||||||
ty
|
|
||||||
};
|
};
|
||||||
|
// Class Attributes keep a copy with Class Definition and are not added to objects
|
||||||
|
let mut tobj_fields = methods
|
||||||
|
.iter()
|
||||||
|
.map(|(name, ty, _)| {
|
||||||
|
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||||
|
// methods are immutable
|
||||||
|
(*name, (subst_ty, false))
|
||||||
|
})
|
||||||
|
.collect::<HashMap<_, _>>();
|
||||||
|
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
|
||||||
|
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
|
||||||
|
(*name, (subst_ty, *mutability))
|
||||||
|
}));
|
||||||
|
let need_subst = !subst.is_empty();
|
||||||
|
let ty = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: *obj_id,
|
||||||
|
fields: tobj_fields,
|
||||||
|
params: subst,
|
||||||
|
});
|
||||||
|
if need_subst {
|
||||||
|
if let Some(wl) = subst_list.as_mut() {
|
||||||
|
wl.push(ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(ty)
|
Ok(ty)
|
||||||
}
|
}
|
||||||
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
|
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
|
||||||
@ -533,7 +490,6 @@ pub fn get_type_from_type_annotation_kinds(
|
|||||||
let ty = get_type_from_type_annotation_kinds(
|
let ty = get_type_from_type_annotation_kinds(
|
||||||
top_level_defs,
|
top_level_defs,
|
||||||
unifier,
|
unifier,
|
||||||
primitives,
|
|
||||||
ty.as_ref(),
|
ty.as_ref(),
|
||||||
subst_list,
|
subst_list,
|
||||||
)?;
|
)?;
|
||||||
@ -543,13 +499,7 @@ pub fn get_type_from_type_annotation_kinds(
|
|||||||
let tys = tys
|
let tys = tys
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| {
|
.map(|x| {
|
||||||
get_type_from_type_annotation_kinds(
|
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
|
||||||
top_level_defs,
|
|
||||||
unifier,
|
|
||||||
primitives,
|
|
||||||
x,
|
|
||||||
subst_list,
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))
|
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))
|
||||||
|
@ -8,7 +8,7 @@ edition = "2021"
|
|||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
nac3parser = { path = "../nac3parser" }
|
nac3parser = { path = "../nac3parser" }
|
||||||
nac3core = { path = "../nac3core" }
|
nac3core = { path = "../nac3core" }
|
||||||
linalg_externfns = { path = "./linalg_externfns" }
|
externfns = { path = "./demo/externfns" }
|
||||||
|
|
||||||
[dependencies.clap]
|
[dependencies.clap]
|
||||||
version = "4.5"
|
version = "4.5"
|
||||||
|
@ -15,7 +15,7 @@ done
|
|||||||
demo="$1"
|
demo="$1"
|
||||||
|
|
||||||
echo -n "Checking $demo... "
|
echo -n "Checking $demo... "
|
||||||
# ./interpret_demo.py "$demo" > interpreted.log
|
./interpret_demo.py "$demo" > interpreted.log
|
||||||
./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
|
./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
|
||||||
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
|
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
|
||||||
diff -Nau interpreted.log run.log
|
diff -Nau interpreted.log run.log
|
||||||
|
@ -107,25 +107,8 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t
|
|||||||
__builtin_unreachable();
|
__builtin_unreachable();
|
||||||
}
|
}
|
||||||
|
|
||||||
// See `struct Exception<'a>` in
|
uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) {
|
||||||
// https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs
|
printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
|
||||||
struct Exception {
|
|
||||||
uint32_t id;
|
|
||||||
struct cslice file;
|
|
||||||
uint32_t line;
|
|
||||||
uint32_t column;
|
|
||||||
struct cslice function;
|
|
||||||
struct cslice message;
|
|
||||||
int64_t param[3];
|
|
||||||
};
|
|
||||||
|
|
||||||
uint32_t __nac3_raise(struct Exception* e) {
|
|
||||||
printf("__nac3_raise called. Exception details:\n");
|
|
||||||
printf(" ID: %lld\n", e->id);
|
|
||||||
printf(" Location: %*s:%lld:%lld\n" , e->file.len, (const char*) e->file.data, e->line, e->column);
|
|
||||||
printf(" Function: %*s\n" , e->function.len, (const char*) e->function.data);
|
|
||||||
printf(" Message: \"%*s\"\n" , e->message.len, (const char*) e->message.data);
|
|
||||||
printf(" Params: {0}=%lld, {1}=%lld, {2}=%lld\n", e->param[0], e->param[1], e->param[2]);
|
|
||||||
exit(101);
|
exit(101);
|
||||||
__builtin_unreachable();
|
__builtin_unreachable();
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "linalg_externfns"
|
name = "externfns"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
@ -8,4 +8,3 @@ crate-type = ["cdylib"]
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
|
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
|
||||||
cslice = "0.3.0"
|
|
41
nac3standalone/demo/externfns/src/lib.rs
Normal file
41
nac3standalone/demo/externfns/src/lib.rs
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
#![deny(
|
||||||
|
future_incompatible,
|
||||||
|
let_underscore,
|
||||||
|
nonstandard_style,
|
||||||
|
rust_2024_compatibility,
|
||||||
|
clippy::all
|
||||||
|
)]
|
||||||
|
#![warn(clippy::pedantic)]
|
||||||
|
#![allow(clippy::semicolon_if_nothing_returned, clippy::uninlined_format_args)]
|
||||||
|
|
||||||
|
use core::slice;
|
||||||
|
use nalgebra::{DMatrix, linalg};
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
/// Provide an interface to `nalgebra::linalg::try_invert_to`
|
||||||
|
pub extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8
|
||||||
|
{
|
||||||
|
|
||||||
|
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
||||||
|
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||||
|
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
|
||||||
|
|
||||||
|
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
|
||||||
|
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
|
||||||
|
1
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
/// Provide an interface to `nalgebra::linalg::wilkinson_shift`
|
||||||
|
pub extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64
|
||||||
|
{
|
||||||
|
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
||||||
|
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||||
|
|
||||||
|
// Check if matrix is symmetric
|
||||||
|
assert!(matrix[(0, 1)] == matrix[(1, 0)], "Operation Wilkinson Shift expects symmetric matrix");
|
||||||
|
return linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)]);
|
||||||
|
}
|
@ -5,7 +5,6 @@ import importlib.util
|
|||||||
import importlib.machinery
|
import importlib.machinery
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy as sp
|
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
@ -188,7 +187,7 @@ def patch(module):
|
|||||||
module.ceil64 = _ceil
|
module.ceil64 = _ceil
|
||||||
module.np_ceil = np.ceil
|
module.np_ceil = np.ceil
|
||||||
|
|
||||||
# NumPy NDArray factory functions
|
# NumPy ndarray functions
|
||||||
module.ndarray = NDArray
|
module.ndarray = NDArray
|
||||||
module.np_ndarray = np.ndarray
|
module.np_ndarray = np.ndarray
|
||||||
module.np_empty = np.empty
|
module.np_empty = np.empty
|
||||||
@ -239,7 +238,7 @@ def patch(module):
|
|||||||
module.np_hypot = np.hypot
|
module.np_hypot = np.hypot
|
||||||
module.np_nextafter = np.nextafter
|
module.np_nextafter = np.nextafter
|
||||||
|
|
||||||
# SciPy Math functions
|
# SciPy Math Functions
|
||||||
module.sp_spec_erf = special.erf
|
module.sp_spec_erf = special.erf
|
||||||
module.sp_spec_erfc = special.erfc
|
module.sp_spec_erfc = special.erfc
|
||||||
module.sp_spec_gamma = special.gamma
|
module.sp_spec_gamma = special.gamma
|
||||||
@ -247,21 +246,16 @@ def patch(module):
|
|||||||
module.sp_spec_j0 = special.j0
|
module.sp_spec_j0 = special.j0
|
||||||
module.sp_spec_j1 = special.j1
|
module.sp_spec_j1 = special.j1
|
||||||
|
|
||||||
# Linalg functions
|
# NumPy NDArray Functions
|
||||||
module.np_dot = np.dot
|
module.np_ndarray = np.ndarray
|
||||||
module.np_linalg_matmul = np.matmul
|
module.np_empty = np.empty
|
||||||
module.np_linalg_cholesky = np.linalg.cholesky
|
module.np_zeros = np.zeros
|
||||||
module.np_linalg_qr = np.linalg.qr
|
module.np_ones = np.ones
|
||||||
module.np_linalg_svd = np.linalg.svd
|
module.np_full = np.full
|
||||||
module.np_linalg_inv = np.linalg.inv
|
module.np_eye = np.eye
|
||||||
module.np_linalg_pinv = np.linalg.pinv
|
module.np_identity = np.identity
|
||||||
|
module.try_invert_to = try_invert_to
|
||||||
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
|
module.wilkinson_shift = wilkinson_shift
|
||||||
module.sp_linalg_schur = sp.linalg.schur
|
|
||||||
# module.sp_linalg_hessenberg = sp.linalg.hessenberg
|
|
||||||
module.sp_linalg_hessenberg = lambda x: x
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def file_import(filename, prefix="file_import_"):
|
def file_import(filename, prefix="file_import_"):
|
||||||
filename = pathlib.Path(filename)
|
filename = pathlib.Path(filename)
|
||||||
|
@ -1,2 +0,0 @@
|
|||||||
Excepytiopn!! knfv 0x7fffffff9218
|
|
||||||
__nac3_personality(state: 1, exception_object: 1, context: 1381323604)
|
|
@ -42,14 +42,14 @@ done
|
|||||||
|
|
||||||
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
||||||
nac3standalone=../../target/debug/nac3standalone
|
nac3standalone=../../target/debug/nac3standalone
|
||||||
externfns=../../target/debug/deps/liblinalg_externfns.so
|
externfns=../../target/debug/deps/libexternfns.so
|
||||||
elif [ -e ../../target/release/nac3standalone ]; then
|
elif [ -e ../../target/release/nac3standalone ]; then
|
||||||
nac3standalone=../../target/release/nac3standalone
|
nac3standalone=../../target/release/nac3standalone
|
||||||
externfns=../../target/release/deps/liblinalg_externfns.so
|
externfns=../../target/release/deps/libexternfns.so
|
||||||
else
|
else
|
||||||
# used by Nix builds
|
# used by Nix builds
|
||||||
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
||||||
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg_externfns.so
|
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so
|
||||||
fi
|
fi
|
||||||
|
|
||||||
rm -f ./*.o ./*.bc demo
|
rm -f ./*.o ./*.bc demo
|
||||||
|
@ -1,12 +0,0 @@
|
|||||||
8.000000
|
|
||||||
10.000000
|
|
||||||
12.000000
|
|
||||||
4.000000
|
|
||||||
5.000000
|
|
||||||
6.000000
|
|
||||||
1.000000
|
|
||||||
2.000000
|
|
||||||
3.000000
|
|
||||||
4.000000
|
|
||||||
5.000000
|
|
||||||
6.000000
|
|
@ -531,12 +531,10 @@ def test_ndarray_ipow_broadcast_scalar():
|
|||||||
|
|
||||||
def test_ndarray_matmul():
|
def test_ndarray_matmul():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
t: ndarray[float, 2] = np_array([[1., 2., 3.,], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.]])
|
y = x @ np_ones([2, 2])
|
||||||
y = x @ t
|
|
||||||
y2 = np_linalg_matmul(x, t)
|
|
||||||
output_ndarray_float_2(y)
|
|
||||||
output_ndarray_float_2(y2)
|
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
def test_ndarray_imatmul():
|
def test_ndarray_imatmul():
|
||||||
x = np_identity(2)
|
x = np_identity(2)
|
||||||
@ -1431,289 +1429,201 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
|||||||
output_ndarray_float_2(nextafter_x_zeros)
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
output_ndarray_float_2(nextafter_x_ones)
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
def test_ndarray_dot():
|
def test_try_invert():
|
||||||
x: ndarray[float, 1] = np_array([5.0, 1.0])
|
x: ndarray[float, 2] = np_array([[1.0, 1.0], [1.0, 5.0]])
|
||||||
y: ndarray[float, 1] = np_array([5.0, 1.0])
|
|
||||||
z = np_dot(x, y)
|
|
||||||
|
|
||||||
output_ndarray_float_1(x)
|
|
||||||
output_ndarray_float_1(y)
|
|
||||||
output_float64(z)
|
|
||||||
|
|
||||||
def test_ndarray_linalg_matmul():
|
|
||||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
|
||||||
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
|
||||||
z = np_linalg_matmul(x, y)
|
|
||||||
|
|
||||||
m = np_argmax(z)
|
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
|
||||||
output_ndarray_float_2(y)
|
|
||||||
output_ndarray_float_2(z)
|
|
||||||
output_int64(m)
|
|
||||||
|
|
||||||
def test_ndarray_cholesky():
|
|
||||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
|
||||||
y = np_linalg_cholesky(x)
|
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
|
||||||
output_ndarray_float_2(y)
|
|
||||||
|
|
||||||
def test_ndarray_qr():
|
|
||||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
|
||||||
y, z = np_linalg_qr(x)
|
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
|
|
||||||
# QR Factorization in nalgebra and numpy do not give the same result
|
y = try_invert_to(x)
|
||||||
# Generating product for printing
|
|
||||||
a = np_linalg_matmul(y, z)
|
|
||||||
output_ndarray_float_2(a)
|
|
||||||
|
|
||||||
def test_ndarray_linalg_inv():
|
|
||||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
|
||||||
y = np_linalg_inv(x)
|
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
output_ndarray_float_2(y)
|
output_bool(y)
|
||||||
|
|
||||||
def test_ndarray_pinv():
|
|
||||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
|
|
||||||
y = np_linalg_pinv(x)
|
|
||||||
|
|
||||||
|
def test_wilkinson_shift():
|
||||||
|
x: ndarray[float, 2] = np_array([[5., 1.], [1., 4.]])
|
||||||
|
y = wilkinson_shift(x)
|
||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
output_ndarray_float_2(y)
|
output_float64(y)
|
||||||
|
|
||||||
def test_ndarray_schur():
|
|
||||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
|
||||||
t, z = sp_linalg_schur(x)
|
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
|
||||||
# Same as np_linalg_qr the signs are different in nalgebra and numpy
|
|
||||||
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z))
|
|
||||||
output_ndarray_float_2(a)
|
|
||||||
|
|
||||||
def test_ndarray_hessenberg():
|
|
||||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
|
|
||||||
h = sp_linalg_hessenberg(x)
|
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
|
||||||
output_ndarray_float_2(h)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ndarray_lu():
|
|
||||||
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
|
|
||||||
l, u = sp_linalg_lu(x)
|
|
||||||
|
|
||||||
output_ndarray_float_2(x)
|
|
||||||
output_ndarray_float_2(l)
|
|
||||||
output_ndarray_float_2(u)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ndarray_svd():
|
|
||||||
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
|
|
||||||
x, y, z = np_linalg_svd(w)
|
|
||||||
|
|
||||||
output_ndarray_float_2(w)
|
|
||||||
|
|
||||||
# Same as np_linalg_qr the signs are different in nalgebra and numpy
|
|
||||||
a = np_linalg_matmul(x, z)
|
|
||||||
output_ndarray_float_2(a)
|
|
||||||
output_ndarray_float_1(y)
|
|
||||||
|
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
|
test_try_invert()
|
||||||
|
test_wilkinson_shift()
|
||||||
|
test_ndarray_ctor()
|
||||||
|
test_ndarray_empty()
|
||||||
|
test_ndarray_zeros()
|
||||||
|
test_ndarray_ones()
|
||||||
|
test_ndarray_full()
|
||||||
|
test_ndarray_eye()
|
||||||
|
test_ndarray_array()
|
||||||
|
test_ndarray_identity()
|
||||||
|
test_ndarray_fill()
|
||||||
|
test_ndarray_copy()
|
||||||
|
|
||||||
|
test_ndarray_neg_idx()
|
||||||
|
test_ndarray_slices()
|
||||||
|
test_ndarray_nd_idx()
|
||||||
|
|
||||||
|
test_ndarray_add()
|
||||||
|
test_ndarray_add_broadcast()
|
||||||
|
test_ndarray_add_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_add_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_iadd()
|
||||||
|
test_ndarray_iadd_broadcast()
|
||||||
|
test_ndarray_iadd_broadcast_scalar()
|
||||||
|
test_ndarray_sub()
|
||||||
|
test_ndarray_sub_broadcast()
|
||||||
|
test_ndarray_sub_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_sub_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_isub()
|
||||||
|
test_ndarray_isub_broadcast()
|
||||||
|
test_ndarray_isub_broadcast_scalar()
|
||||||
|
test_ndarray_mul()
|
||||||
|
test_ndarray_mul_broadcast()
|
||||||
|
test_ndarray_mul_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_mul_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_imul()
|
||||||
|
test_ndarray_imul_broadcast()
|
||||||
|
test_ndarray_imul_broadcast_scalar()
|
||||||
|
test_ndarray_truediv()
|
||||||
|
test_ndarray_truediv_broadcast()
|
||||||
|
test_ndarray_truediv_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_truediv_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_itruediv()
|
||||||
|
test_ndarray_itruediv_broadcast()
|
||||||
|
test_ndarray_itruediv_broadcast_scalar()
|
||||||
|
test_ndarray_floordiv()
|
||||||
|
test_ndarray_floordiv_broadcast()
|
||||||
|
test_ndarray_floordiv_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_floordiv_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_ifloordiv()
|
||||||
|
test_ndarray_ifloordiv_broadcast()
|
||||||
|
test_ndarray_ifloordiv_broadcast_scalar()
|
||||||
|
test_ndarray_mod()
|
||||||
|
test_ndarray_mod_broadcast()
|
||||||
|
test_ndarray_mod_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_mod_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_imod()
|
||||||
|
test_ndarray_imod_broadcast()
|
||||||
|
test_ndarray_imod_broadcast_scalar()
|
||||||
|
test_ndarray_pow()
|
||||||
|
test_ndarray_pow_broadcast()
|
||||||
|
test_ndarray_pow_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_pow_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_ipow()
|
||||||
|
test_ndarray_ipow_broadcast()
|
||||||
|
test_ndarray_ipow_broadcast_scalar()
|
||||||
test_ndarray_matmul()
|
test_ndarray_matmul()
|
||||||
# test_ndarray_dot()
|
test_ndarray_imatmul()
|
||||||
# test_ndarray_linalg_matmul()
|
test_ndarray_pos()
|
||||||
# test_ndarray_cholesky()
|
test_ndarray_neg()
|
||||||
# test_ndarray_qr()
|
test_ndarray_inv()
|
||||||
# test_ndarray_svd()
|
test_ndarray_eq()
|
||||||
# test_ndarray_linalg_inv()
|
test_ndarray_eq_broadcast()
|
||||||
# test_ndarray_pinv()
|
test_ndarray_eq_broadcast_lhs_scalar()
|
||||||
# test_ndarray_lu()
|
test_ndarray_eq_broadcast_rhs_scalar()
|
||||||
# test_ndarray_schur()
|
test_ndarray_ne()
|
||||||
# test_ndarray_hessenberg()
|
test_ndarray_ne_broadcast()
|
||||||
|
test_ndarray_ne_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_ne_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_lt()
|
||||||
|
test_ndarray_lt_broadcast()
|
||||||
|
test_ndarray_lt_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_lt_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_lt()
|
||||||
|
test_ndarray_le_broadcast()
|
||||||
|
test_ndarray_le_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_le_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_gt()
|
||||||
|
test_ndarray_gt_broadcast()
|
||||||
|
test_ndarray_gt_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_gt_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_gt()
|
||||||
|
test_ndarray_ge_broadcast()
|
||||||
|
test_ndarray_ge_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_ge_broadcast_rhs_scalar()
|
||||||
|
|
||||||
# test_ndarray_ctor()
|
test_ndarray_int32()
|
||||||
# test_ndarray_empty()
|
test_ndarray_int64()
|
||||||
# test_ndarray_zeros()
|
test_ndarray_uint32()
|
||||||
# test_ndarray_ones()
|
test_ndarray_uint64()
|
||||||
# test_ndarray_full()
|
test_ndarray_float()
|
||||||
# test_ndarray_eye()
|
test_ndarray_bool()
|
||||||
# test_ndarray_array()
|
|
||||||
# test_ndarray_identity()
|
|
||||||
# test_ndarray_fill()
|
|
||||||
# test_ndarray_copy()
|
|
||||||
|
|
||||||
# test_ndarray_neg_idx()
|
test_ndarray_round()
|
||||||
# test_ndarray_slices()
|
test_ndarray_floor()
|
||||||
# test_ndarray_nd_idx()
|
test_ndarray_min()
|
||||||
|
test_ndarray_minimum()
|
||||||
|
test_ndarray_minimum_broadcast()
|
||||||
|
test_ndarray_minimum_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_minimum_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_argmin()
|
||||||
|
test_ndarray_max()
|
||||||
|
test_ndarray_maximum()
|
||||||
|
test_ndarray_maximum_broadcast()
|
||||||
|
test_ndarray_maximum_broadcast_lhs_scalar()
|
||||||
|
test_ndarray_maximum_broadcast_rhs_scalar()
|
||||||
|
test_ndarray_argmax()
|
||||||
|
test_ndarray_abs()
|
||||||
|
test_ndarray_isnan()
|
||||||
|
test_ndarray_isinf()
|
||||||
|
|
||||||
# test_ndarray_add()
|
test_ndarray_sin()
|
||||||
# test_ndarray_add_broadcast()
|
test_ndarray_cos()
|
||||||
# test_ndarray_add_broadcast_lhs_scalar()
|
test_ndarray_exp()
|
||||||
# test_ndarray_add_broadcast_rhs_scalar()
|
test_ndarray_exp2()
|
||||||
# test_ndarray_iadd()
|
test_ndarray_log()
|
||||||
# test_ndarray_iadd_broadcast()
|
test_ndarray_log10()
|
||||||
# test_ndarray_iadd_broadcast_scalar()
|
test_ndarray_log2()
|
||||||
# test_ndarray_sub()
|
test_ndarray_fabs()
|
||||||
# test_ndarray_sub_broadcast()
|
test_ndarray_sqrt()
|
||||||
# test_ndarray_sub_broadcast_lhs_scalar()
|
test_ndarray_rint()
|
||||||
# test_ndarray_sub_broadcast_rhs_scalar()
|
test_ndarray_tan()
|
||||||
# test_ndarray_isub()
|
test_ndarray_arcsin()
|
||||||
# test_ndarray_isub_broadcast()
|
test_ndarray_arccos()
|
||||||
# test_ndarray_isub_broadcast_scalar()
|
test_ndarray_arctan()
|
||||||
# test_ndarray_mul()
|
test_ndarray_sinh()
|
||||||
# test_ndarray_mul_broadcast()
|
test_ndarray_cosh()
|
||||||
# test_ndarray_mul_broadcast_lhs_scalar()
|
test_ndarray_tanh()
|
||||||
# test_ndarray_mul_broadcast_rhs_scalar()
|
test_ndarray_arcsinh()
|
||||||
# test_ndarray_imul()
|
test_ndarray_arccosh()
|
||||||
# test_ndarray_imul_broadcast()
|
test_ndarray_arctanh()
|
||||||
# test_ndarray_imul_broadcast_scalar()
|
test_ndarray_expm1()
|
||||||
# test_ndarray_truediv()
|
test_ndarray_cbrt()
|
||||||
# test_ndarray_truediv_broadcast()
|
|
||||||
# test_ndarray_truediv_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_truediv_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_itruediv()
|
|
||||||
# test_ndarray_itruediv_broadcast()
|
|
||||||
# test_ndarray_itruediv_broadcast_scalar()
|
|
||||||
# test_ndarray_floordiv()
|
|
||||||
# test_ndarray_floordiv_broadcast()
|
|
||||||
# test_ndarray_floordiv_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_floordiv_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_ifloordiv()
|
|
||||||
# test_ndarray_ifloordiv_broadcast()
|
|
||||||
# test_ndarray_ifloordiv_broadcast_scalar()
|
|
||||||
# test_ndarray_mod()
|
|
||||||
# test_ndarray_mod_broadcast()
|
|
||||||
# test_ndarray_mod_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_mod_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_imod()
|
|
||||||
# test_ndarray_imod_broadcast()
|
|
||||||
# test_ndarray_imod_broadcast_scalar()
|
|
||||||
# test_ndarray_pow()
|
|
||||||
# test_ndarray_pow_broadcast()
|
|
||||||
# test_ndarray_pow_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_pow_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_ipow()
|
|
||||||
# test_ndarray_ipow_broadcast()
|
|
||||||
# test_ndarray_ipow_broadcast_scalar()
|
|
||||||
# test_ndarray_matmul()
|
|
||||||
# test_ndarray_imatmul()
|
|
||||||
# test_ndarray_pos()
|
|
||||||
# test_ndarray_neg()
|
|
||||||
# test_ndarray_inv()
|
|
||||||
# test_ndarray_eq()
|
|
||||||
# test_ndarray_eq_broadcast()
|
|
||||||
# test_ndarray_eq_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_eq_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_ne()
|
|
||||||
# test_ndarray_ne_broadcast()
|
|
||||||
# test_ndarray_ne_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_ne_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_lt()
|
|
||||||
# test_ndarray_lt_broadcast()
|
|
||||||
# test_ndarray_lt_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_lt_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_lt()
|
|
||||||
# test_ndarray_le_broadcast()
|
|
||||||
# test_ndarray_le_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_le_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_gt()
|
|
||||||
# test_ndarray_gt_broadcast()
|
|
||||||
# test_ndarray_gt_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_gt_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_gt()
|
|
||||||
# test_ndarray_ge_broadcast()
|
|
||||||
# test_ndarray_ge_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_ge_broadcast_rhs_scalar()
|
|
||||||
|
|
||||||
# test_ndarray_int32()
|
test_ndarray_erf()
|
||||||
# test_ndarray_int64()
|
test_ndarray_erfc()
|
||||||
# test_ndarray_uint32()
|
test_ndarray_gamma()
|
||||||
# test_ndarray_uint64()
|
test_ndarray_gammaln()
|
||||||
# test_ndarray_float()
|
test_ndarray_j0()
|
||||||
# test_ndarray_bool()
|
test_ndarray_j1()
|
||||||
|
|
||||||
# test_ndarray_round()
|
test_ndarray_arctan2()
|
||||||
# test_ndarray_floor()
|
test_ndarray_arctan2_broadcast()
|
||||||
# test_ndarray_min()
|
test_ndarray_arctan2_broadcast_lhs_scalar()
|
||||||
# test_ndarray_minimum()
|
test_ndarray_arctan2_broadcast_rhs_scalar()
|
||||||
# test_ndarray_minimum_broadcast()
|
test_ndarray_copysign()
|
||||||
# test_ndarray_minimum_broadcast_lhs_scalar()
|
test_ndarray_copysign_broadcast()
|
||||||
# test_ndarray_minimum_broadcast_rhs_scalar()
|
test_ndarray_copysign_broadcast_lhs_scalar()
|
||||||
# test_ndarray_argmin()
|
test_ndarray_copysign_broadcast_rhs_scalar()
|
||||||
# test_ndarray_max()
|
test_ndarray_fmax()
|
||||||
# test_ndarray_maximum()
|
test_ndarray_fmax_broadcast()
|
||||||
# test_ndarray_maximum_broadcast()
|
test_ndarray_fmax_broadcast_lhs_scalar()
|
||||||
# test_ndarray_maximum_broadcast_lhs_scalar()
|
test_ndarray_fmax_broadcast_rhs_scalar()
|
||||||
# test_ndarray_maximum_broadcast_rhs_scalar()
|
test_ndarray_fmin()
|
||||||
# test_ndarray_argmax()
|
test_ndarray_fmin_broadcast()
|
||||||
# test_ndarray_abs()
|
test_ndarray_fmin_broadcast_lhs_scalar()
|
||||||
# test_ndarray_isnan()
|
test_ndarray_fmin_broadcast_rhs_scalar()
|
||||||
# test_ndarray_isinf()
|
test_ndarray_ldexp()
|
||||||
|
test_ndarray_ldexp_broadcast()
|
||||||
# test_ndarray_sin()
|
test_ndarray_ldexp_broadcast_lhs_scalar()
|
||||||
# test_ndarray_cos()
|
test_ndarray_ldexp_broadcast_rhs_scalar()
|
||||||
# test_ndarray_exp()
|
test_ndarray_hypot()
|
||||||
# test_ndarray_exp2()
|
test_ndarray_hypot_broadcast()
|
||||||
# test_ndarray_log()
|
test_ndarray_hypot_broadcast_lhs_scalar()
|
||||||
# test_ndarray_log10()
|
test_ndarray_hypot_broadcast_rhs_scalar()
|
||||||
# test_ndarray_log2()
|
test_ndarray_nextafter()
|
||||||
# test_ndarray_fabs()
|
test_ndarray_nextafter_broadcast()
|
||||||
# test_ndarray_sqrt()
|
test_ndarray_nextafter_broadcast_lhs_scalar()
|
||||||
# test_ndarray_rint()
|
test_ndarray_nextafter_broadcast_rhs_scalar()
|
||||||
# test_ndarray_tan()
|
|
||||||
# test_ndarray_arcsin()
|
|
||||||
# test_ndarray_arccos()
|
|
||||||
# test_ndarray_arctan()
|
|
||||||
# test_ndarray_sinh()
|
|
||||||
# test_ndarray_cosh()
|
|
||||||
# test_ndarray_tanh()
|
|
||||||
# test_ndarray_arcsinh()
|
|
||||||
# test_ndarray_arccosh()
|
|
||||||
# test_ndarray_arctanh()
|
|
||||||
# test_ndarray_expm1()
|
|
||||||
# test_ndarray_cbrt()
|
|
||||||
|
|
||||||
# test_ndarray_erf()
|
|
||||||
# test_ndarray_erfc()
|
|
||||||
# test_ndarray_gamma()
|
|
||||||
# test_ndarray_gammaln()
|
|
||||||
# test_ndarray_j0()
|
|
||||||
# test_ndarray_j1()
|
|
||||||
|
|
||||||
# test_ndarray_arctan2()
|
|
||||||
# test_ndarray_arctan2_broadcast()
|
|
||||||
# test_ndarray_arctan2_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_arctan2_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_copysign()
|
|
||||||
# test_ndarray_copysign_broadcast()
|
|
||||||
# test_ndarray_copysign_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_copysign_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_fmax()
|
|
||||||
# test_ndarray_fmax_broadcast()
|
|
||||||
# test_ndarray_fmax_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_fmax_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_fmin()
|
|
||||||
# test_ndarray_fmin_broadcast()
|
|
||||||
# test_ndarray_fmin_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_fmin_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_ldexp()
|
|
||||||
# test_ndarray_ldexp_broadcast()
|
|
||||||
# test_ndarray_ldexp_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_ldexp_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_hypot()
|
|
||||||
# test_ndarray_hypot_broadcast()
|
|
||||||
# test_ndarray_hypot_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_hypot_broadcast_rhs_scalar()
|
|
||||||
# test_ndarray_nextafter()
|
|
||||||
# test_ndarray_nextafter_broadcast()
|
|
||||||
# test_ndarray_nextafter_broadcast_lhs_scalar()
|
|
||||||
# test_ndarray_nextafter_broadcast_rhs_scalar()
|
|
||||||
|
|
||||||
# test_try_invert()
|
|
||||||
# test_wilkinson_shift()
|
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
@ -1,346 +0,0 @@
|
|||||||
mod runtime_exception;
|
|
||||||
use core::slice;
|
|
||||||
use nalgebra::{linalg, DMatrix};
|
|
||||||
|
|
||||||
macro_rules! raise_exn {
|
|
||||||
($name:expr, $message:expr, $param0:expr, $param1:expr, $param2:expr) => {{
|
|
||||||
use cslice::AsCSlice;
|
|
||||||
let name_id = $crate::runtime_exception::get_exception_id($name);
|
|
||||||
let exn = $crate::runtime_exception::Exception {
|
|
||||||
id: name_id,
|
|
||||||
file: file!().as_c_slice(),
|
|
||||||
line: line!(),
|
|
||||||
column: column!(),
|
|
||||||
// https://github.com/rust-lang/rfcs/pull/1719
|
|
||||||
function: "(Rust function)".as_c_slice(),
|
|
||||||
message: $message.as_c_slice(),
|
|
||||||
param: [$param0, $param1, $param2],
|
|
||||||
};
|
|
||||||
#[allow(unused_unsafe)]
|
|
||||||
unsafe {
|
|
||||||
$crate::runtime_exception::raise(&exn)
|
|
||||||
}
|
|
||||||
}};
|
|
||||||
($name:expr, $message:expr) => {{
|
|
||||||
raise_exn!($name, $message, 0, 0, 0)
|
|
||||||
}};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array with `dim0`x`dim1` elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 {
|
|
||||||
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
|
||||||
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
|
||||||
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
|
|
||||||
|
|
||||||
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
|
|
||||||
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
|
|
||||||
1
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 {
|
|
||||||
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
|
||||||
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
|
||||||
|
|
||||||
linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)])
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_dot(
|
|
||||||
dim0: usize,
|
|
||||||
dim1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
x2: *mut f64,
|
|
||||||
) -> f64 {
|
|
||||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim0 * dim1) };
|
|
||||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(x2, dim0 * dim1) };
|
|
||||||
|
|
||||||
let matrix1 = DMatrix::from_row_slice(dim0, dim1, data_slice1);
|
|
||||||
let matrix2 = DMatrix::from_row_slice(dim0, dim1, data_slice2);
|
|
||||||
|
|
||||||
matrix1.dot(&matrix2)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_linalg_matmul(
|
|
||||||
dim1_0: usize,
|
|
||||||
dim1_1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
dim2_0: usize,
|
|
||||||
dim2_1: usize,
|
|
||||||
x2: *mut f64,
|
|
||||||
dim3_0: usize,
|
|
||||||
dim3_1: usize,
|
|
||||||
out: *mut f64,
|
|
||||||
) -> i8 {
|
|
||||||
// let name = unsafe {slice::from_raw_parts_mut(n, l)};
|
|
||||||
// let fne = name.as_c_slice();
|
|
||||||
raise_exn!("ZeroDivisionError", "Divide by Zero");
|
|
||||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
|
||||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(x2, dim2_0 * dim2_1) };
|
|
||||||
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim3_0 * dim3_1) };
|
|
||||||
|
|
||||||
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
|
|
||||||
let matrix2 = DMatrix::from_row_slice(dim2_0, dim2_1, data_slice2);
|
|
||||||
let mut result = DMatrix::<f64>::zeros(dim3_0, dim3_1);
|
|
||||||
|
|
||||||
matrix1.mul_to(&matrix2, &mut result);
|
|
||||||
out_slice.copy_from_slice(result.transpose().as_slice());
|
|
||||||
// raise_exn!("ZeroDivisionError", "Divide by Zero", r, c, n);
|
|
||||||
|
|
||||||
1
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_linalg_cholesky(
|
|
||||||
dim1_0: usize,
|
|
||||||
dim1_1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
dim2_0: usize,
|
|
||||||
dim2_1: usize,
|
|
||||||
out: *mut f64,
|
|
||||||
) -> i8 {
|
|
||||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
|
||||||
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
|
|
||||||
|
|
||||||
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
|
|
||||||
|
|
||||||
let res = matrix1.cholesky();
|
|
||||||
match res {
|
|
||||||
None => 0,
|
|
||||||
Some(c) => {
|
|
||||||
out_slice.copy_from_slice(c.unpack().transpose().as_slice());
|
|
||||||
1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_linalg_qr(
|
|
||||||
dim1_0: usize,
|
|
||||||
dim1_1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
dim2_0: usize,
|
|
||||||
dim2_1: usize,
|
|
||||||
out_q: *mut f64,
|
|
||||||
dim3_0: usize,
|
|
||||||
dim3_1: usize,
|
|
||||||
out_r: *mut f64,
|
|
||||||
) -> i8 {
|
|
||||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
|
||||||
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q, dim2_0 * dim2_1) };
|
|
||||||
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r, dim3_0 * dim3_1) };
|
|
||||||
|
|
||||||
// Refer to https://github.com/dimforge/nalgebra/issues/735
|
|
||||||
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
|
|
||||||
|
|
||||||
let res = matrix1.qr();
|
|
||||||
let (q, r) = res.unpack();
|
|
||||||
|
|
||||||
// Uses different algo need to match numpy
|
|
||||||
out_q_slice.copy_from_slice(q.transpose().as_slice());
|
|
||||||
out_r_slice.copy_from_slice(r.transpose().as_slice());
|
|
||||||
1
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_linalg_svd(
|
|
||||||
dim1_0: usize,
|
|
||||||
dim1_1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
dim2_0: usize,
|
|
||||||
dim2_1: usize,
|
|
||||||
out_u: *mut f64,
|
|
||||||
dim3_0: usize,
|
|
||||||
dim3_1: usize,
|
|
||||||
out_s: *mut f64,
|
|
||||||
dim4_0: usize,
|
|
||||||
dim4_1: usize,
|
|
||||||
out_vh: *mut f64,
|
|
||||||
) -> i8 {
|
|
||||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
|
||||||
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim2_0 * dim2_1) };
|
|
||||||
let out_s_slice = unsafe { slice::from_raw_parts_mut(out_s, dim3_0 * dim3_1) };
|
|
||||||
let out_vh_slice = unsafe { slice::from_raw_parts_mut(out_vh, dim4_0 * dim4_1) };
|
|
||||||
|
|
||||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
|
||||||
let res = matrix.svd(true, true);
|
|
||||||
|
|
||||||
out_u_slice.copy_from_slice(res.u.unwrap().transpose().as_slice());
|
|
||||||
out_s_slice.copy_from_slice(res.singular_values.as_slice());
|
|
||||||
out_vh_slice.copy_from_slice(res.v_t.unwrap().transpose().as_slice());
|
|
||||||
|
|
||||||
1
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_linalg_inv(
|
|
||||||
dim1_0: usize,
|
|
||||||
dim1_1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
dim2_0: usize,
|
|
||||||
dim2_1: usize,
|
|
||||||
out: *mut f64,
|
|
||||||
) -> i8 {
|
|
||||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
|
||||||
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
|
|
||||||
|
|
||||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
|
||||||
if !matrix.is_invertible() {
|
|
||||||
// raise error
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
let inv = matrix.try_inverse().unwrap();
|
|
||||||
|
|
||||||
out_slice.copy_from_slice(inv.transpose().as_slice());
|
|
||||||
1
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_linalg_pinv(
|
|
||||||
dim1_0: usize,
|
|
||||||
dim1_1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
dim2_0: usize,
|
|
||||||
dim2_1: usize,
|
|
||||||
out: *mut f64,
|
|
||||||
) -> i8 {
|
|
||||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
|
||||||
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
|
|
||||||
|
|
||||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
|
||||||
let svd = matrix.svd(true, true);
|
|
||||||
let inv = svd.pseudo_inverse(1e-15);
|
|
||||||
|
|
||||||
match inv {
|
|
||||||
Ok(m) => {
|
|
||||||
out_slice.copy_from_slice(m.transpose().as_slice());
|
|
||||||
1
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
// raise exception here
|
|
||||||
assert!(false, "{e}");
|
|
||||||
0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn sp_linalg_lu(
|
|
||||||
dim1_0: usize,
|
|
||||||
dim1_1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
dim2_0: usize,
|
|
||||||
dim2_1: usize,
|
|
||||||
out_l: *mut f64,
|
|
||||||
dim3_0: usize,
|
|
||||||
dim3_1: usize,
|
|
||||||
out_u: *mut f64,
|
|
||||||
) -> i8 {
|
|
||||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
|
||||||
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l, dim2_0 * dim2_1) };
|
|
||||||
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim3_0 * dim3_1) };
|
|
||||||
|
|
||||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
|
||||||
let (_, l, u) = matrix.lu().unpack();
|
|
||||||
|
|
||||||
out_l_slice.copy_from_slice(l.transpose().as_slice());
|
|
||||||
out_u_slice.copy_from_slice(u.transpose().as_slice());
|
|
||||||
|
|
||||||
1
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn sp_linalg_schur(
|
|
||||||
dim1_0: usize,
|
|
||||||
dim1_1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
dim2_0: usize,
|
|
||||||
dim2_1: usize,
|
|
||||||
out_t: *mut f64,
|
|
||||||
dim3_0: usize,
|
|
||||||
dim3_1: usize,
|
|
||||||
out_z: *mut f64,
|
|
||||||
) -> i8 {
|
|
||||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
|
||||||
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t, dim2_0 * dim2_1) };
|
|
||||||
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z, dim3_0 * dim3_1) };
|
|
||||||
|
|
||||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
|
||||||
if !matrix.is_square() {
|
|
||||||
// Throw error here
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
let (z, t) = matrix.schur().unpack();
|
|
||||||
|
|
||||||
out_t_slice.copy_from_slice(t.transpose().as_slice());
|
|
||||||
out_z_slice.copy_from_slice(z.transpose().as_slice());
|
|
||||||
|
|
||||||
1
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `data` must point to an array of 4 elements in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn sp_linalg_hessenberg(
|
|
||||||
dim1_0: usize,
|
|
||||||
dim1_1: usize,
|
|
||||||
x1: *mut f64,
|
|
||||||
dim2_0: usize,
|
|
||||||
dim2_1: usize,
|
|
||||||
out_h: *mut f64,
|
|
||||||
) -> i8 {
|
|
||||||
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
|
|
||||||
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h, dim2_0 * dim2_1) };
|
|
||||||
|
|
||||||
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
|
|
||||||
if !matrix.is_square() {
|
|
||||||
// Throw error here
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
let (_, h) = matrix.hessenberg().unpack();
|
|
||||||
|
|
||||||
out_h_slice.copy_from_slice(h.transpose().as_slice());
|
|
||||||
|
|
||||||
1
|
|
||||||
}
|
|
@ -1,80 +0,0 @@
|
|||||||
#![allow(non_camel_case_types)]
|
|
||||||
#![allow(unused)]
|
|
||||||
|
|
||||||
// ARTIQ Exception struct declaration
|
|
||||||
use cslice::CSlice;
|
|
||||||
|
|
||||||
// Note: CSlice within an exception may not be actual cslice, they may be strings that exist only
|
|
||||||
// in the host. If the length == usize:MAX, the pointer is actually a string key in the host.
|
|
||||||
#[repr(C)]
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Exception<'a> {
|
|
||||||
pub id: u32,
|
|
||||||
pub file: CSlice<'a, u8>,
|
|
||||||
pub line: u32,
|
|
||||||
pub column: u32,
|
|
||||||
pub function: CSlice<'a, u8>,
|
|
||||||
pub message: CSlice<'a, u8>,
|
|
||||||
pub param: [i64; 3],
|
|
||||||
}
|
|
||||||
|
|
||||||
fn str_err(_: core::str::Utf8Error) -> core::fmt::Error {
|
|
||||||
core::fmt::Error
|
|
||||||
}
|
|
||||||
|
|
||||||
fn exception_str<'a>(s: &'a CSlice<'a, u8>) -> Result<&'a str, core::str::Utf8Error> {
|
|
||||||
if s.len() == usize::MAX {
|
|
||||||
Ok("<host string>")
|
|
||||||
} else {
|
|
||||||
core::str::from_utf8(s.as_ref())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> core::fmt::Debug for Exception<'a> {
|
|
||||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
|
||||||
write!(
|
|
||||||
f,
|
|
||||||
"Exception {} from {} in {}:{}:{}, message: {}",
|
|
||||||
self.id,
|
|
||||||
exception_str(&self.function).map_err(str_err)?,
|
|
||||||
exception_str(&self.file).map_err(str_err)?,
|
|
||||||
self.line,
|
|
||||||
self.column,
|
|
||||||
exception_str(&self.message).map_err(str_err)?
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub unsafe fn raise(exception: *const Exception) -> ! {
|
|
||||||
println!("Excepytiopn!! knfv {:?}", exception);
|
|
||||||
let e = &*exception;
|
|
||||||
let f1 = exception_str(&e.function).map_err(str_err).unwrap();
|
|
||||||
let f2 = exception_str(&e.file).map_err(str_err).unwrap();
|
|
||||||
let f3 = exception_str(&e.message).map_err(str_err).unwrap();
|
|
||||||
|
|
||||||
panic!("Exception {} from {} in {}:{}:{}, message: {}", e.id, f1, f2, e.line, e.column, f3);
|
|
||||||
}
|
|
||||||
|
|
||||||
static EXCEPTION_ID_LOOKUP: [(&str, u32); 12] = [
|
|
||||||
("RuntimeError", 0),
|
|
||||||
("RTIOUnderflow", 1),
|
|
||||||
("RTIOOverflow", 2),
|
|
||||||
("RTIODestinationUnreachable", 3),
|
|
||||||
("DMAError", 4),
|
|
||||||
("I2CError", 5),
|
|
||||||
("CacheError", 6),
|
|
||||||
("SPIError", 7),
|
|
||||||
("ZeroDivisionError", 8),
|
|
||||||
("IndexError", 9),
|
|
||||||
("UnwrapNoneError", 10),
|
|
||||||
("Value", 11),
|
|
||||||
];
|
|
||||||
|
|
||||||
pub fn get_exception_id(name: &str) -> u32 {
|
|
||||||
for (n, id) in EXCEPTION_ID_LOOKUP.iter() {
|
|
||||||
if *n == name {
|
|
||||||
return *id;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
unimplemented!("unallocated internal exception id")
|
|
||||||
}
|
|
@ -113,9 +113,7 @@ fn handle_typevar_definition(
|
|||||||
x,
|
x,
|
||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
)?;
|
)?;
|
||||||
get_type_from_type_annotation_kinds(
|
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)
|
||||||
def_list, unifier, primitives, &ty, &mut None,
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
let loc = func.location;
|
let loc = func.location;
|
||||||
@ -154,7 +152,7 @@ fn handle_typevar_definition(
|
|||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
)?;
|
)?;
|
||||||
let constraint =
|
let constraint =
|
||||||
get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty, &mut None)?;
|
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?;
|
||||||
let loc = func.location;
|
let loc = func.location;
|
||||||
|
|
||||||
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)
|
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)
|
||||||
|
BIN
pyo3_output/nac3artiq.so
Executable file
BIN
pyo3_output/nac3artiq.so
Executable file
Binary file not shown.
Loading…
Reference in New Issue
Block a user