Compare commits

...

12 Commits

27 changed files with 2223 additions and 333 deletions

152
Cargo.lock generated
View File

@ -73,6 +73,15 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "ascii-canvas" name = "ascii-canvas"
version = "3.0.0" version = "3.0.0"
@ -117,9 +126,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.0" version = "1.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8" checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
@ -158,7 +167,7 @@ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -247,6 +256,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]] [[package]]
name = "dirs-next" name = "dirs-next"
version = "2.0.0" version = "2.0.0"
@ -421,7 +436,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -513,14 +528,20 @@ checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]] [[package]]
name = "libloading" name = "libloading"
version = "0.8.4" version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"windows-targets", "windows-targets",
] ]
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.3" version = "0.1.3"
@ -531,6 +552,14 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "linalg_externfns"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
@ -659,17 +688,70 @@ version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"inkwell", "inkwell",
"linalg_externfns",
"nac3core", "nac3core",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",
] ]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]] [[package]]
name = "new_debug_unreachable" name = "new_debug_unreachable"
version = "1.0.6" version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.19.0" version = "1.19.0"
@ -699,6 +781,12 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]] [[package]]
name = "petgraph" name = "petgraph"
version = "0.6.5" version = "0.6.5"
@ -749,7 +837,7 @@ dependencies = [
"phf_shared 0.11.2", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -778,9 +866,9 @@ checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315"
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "1.6.0" version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
@ -850,7 +938,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.70", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -863,7 +951,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.70", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -927,9 +1015,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.2" version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
dependencies = [ dependencies = [
"bitflags", "bitflags",
] ]
@ -1044,7 +1132,7 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -1070,6 +1158,18 @@ dependencies = [
"yaml-rust", "yaml-rust",
] ]
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]] [[package]]
name = "similar" name = "similar"
version = "2.5.0" version = "2.5.0"
@ -1134,7 +1234,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.70", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -1150,9 +1250,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.70" version = "2.0.71"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" checksum = "b146dcf730474b4bcd16c311627b31ede9ab149045db4d6088b3becaea046462"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1203,22 +1303,22 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.61" version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.61" version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.71",
] ]
[[package]] [[package]]
@ -1230,6 +1330,12 @@ dependencies = [
"crunchy", "crunchy",
] ]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]] [[package]]
name = "unic-char-property" name = "unic-char-property"
version = "0.9.0" version = "0.9.0"
@ -1486,5 +1592,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.70", "syn 2.0.71",
] ]

View File

@ -4,6 +4,7 @@ members = [
"nac3ast", "nac3ast",
"nac3parser", "nac3parser",
"nac3core", "nac3core",
"nac3standalone/linalg_externfns",
"nac3standalone", "nac3standalone",
"nac3artiq", "nac3artiq",
"runkernel", "runkernel",

View File

@ -161,7 +161,9 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -0,0 +1,24 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

View File

@ -991,8 +991,15 @@ impl InnerResolver {
} }
_ => unreachable!("must be list"), _ => unreachable!("must be list"),
}; };
let ty = ctx.get_llvm_type(generator, elem_ty);
let size_t = generator.get_size_type(ctx.ctx); let size_t = generator.get_size_type(ctx.ctx);
let ty = if len == 0
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
{
// The default type for zero-length lists of unknown element type is size_t
size_t.into()
} else {
ctx.get_llvm_type(generator, elem_ty)
};
let arr_ty = ctx let arr_ty = ctx
.ctx .ctx
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);

View File

@ -1,9 +1,9 @@
use inkwell::types::BasicTypeEnum; use inkwell::types::BasicTypeEnum;
use inkwell::values::BasicValueEnum; use inkwell::values::{BasicValue, BasicValueEnum};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
@ -31,7 +31,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n; let (n_ty, n) = n;
Ok(match n { Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
@ -1835,3 +1834,763 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
}) })
} }
// Linalg Methods
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_dot";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
// The following constraints must be satisfied:
// * Input must be 1D
// * Number of elements in two matrices must equal
if cfg!(debug_assertions) {
let n1_dims = n1.load_ndims(ctx);
let n2_dims = n2.load_ndims(ctx);
let n1_dims_eq1 =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, one, "").unwrap();
let n2_dims_eq1 =
ctx.builder.build_int_compare(IntPredicate::EQ, n2_dims, one, "").unwrap();
// num_dim = 1
ctx.make_assert(
generator,
n1_dims_eq1,
"0:ValueError",
format!("{FN_NAME} operates on 1D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
ctx.make_assert(
generator,
n2_dims_eq1,
"0:ValueError",
format!("{FN_NAME} operates on 1D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
// equal number of elements
let n1_sz = irrt::call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = irrt::call_ndarray_calc_size(generator, ctx, &n2.dim_sizes(), (None, None));
let size_eq =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap();
ctx.make_assert(
generator,
size_eq,
"0:ValueError",
format!("The operands of {FN_NAME} must have equal length").as_str(),
[None, None, None],
ctx.current_loc,
);
}
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
Ok(extern_fns::call_np_dot(
ctx,
(dim0, one, n1.data().base_ptr(ctx, generator)),
(dim0, one, n2.data().base_ptr(ctx, generator)),
None,
)
.into())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matmul";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
let two = llvm_usize.const_int(2, false);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
// The following constraints must be satisfied:
// * Input must be 2D
// * Number of columns of first matrix should equal number of rows of second
if true {
let n1_dims = n1.load_ndims(ctx);
let n2_dims = n2.load_ndims(ctx);
let n1_dims_eq2 =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, two, "").unwrap();
let n2_dims_eq2 =
ctx.builder.build_int_compare(IntPredicate::EQ, n2_dims, two, "").unwrap();
// num_dim = 2
ctx.make_assert(
generator,
n1_dims_eq2,
"0:ValueError",
format!("{FN_NAME} operates on 2D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
ctx.make_assert(
generator,
n2_dims_eq2,
"0:ValueError",
format!("{FN_NAME} operates on 2D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
// matrix must be compatible for multiplication
let n1_col = unsafe {
n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
};
let n2_col = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim_eq =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_col, n2_col, "").unwrap();
ctx.make_assert(
generator,
dim_eq,
"0:ValueError",
format!("Columns of first matrix must equal rows of second for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
}
let out_dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_dim1 =
unsafe { n2.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[out_dim0, out_dim1])
.unwrap();
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let dim2 =
unsafe { n2.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
// let r = ctx.ctx.const_string(string, null_terminated);
extern_fns::call_np_linalg_matmul(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim1, dim2, n2.data().base_ptr(ctx, generator)),
(dim0, dim2, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
let two = llvm_usize.const_int(2, false);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// The following constraints must be satisfied:
// * Input must be 2D
// * Input must be a square matrix (here we assume it is symmetric)
if cfg!(debug_assertions) {
let n1_dims = n1.load_ndims(ctx);
let n1_dims_eq2 =
ctx.builder.build_int_compare(IntPredicate::EQ, n1_dims, two, "").unwrap();
// num_dim = 2
ctx.make_assert(
generator,
n1_dims_eq2,
"0:ValueError",
format!("{FN_NAME} operates on 2D matrices").as_str(),
[None, None, None],
ctx.current_loc,
);
// Square Matrix
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
};
let dim_match =
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap();
ctx.make_assert(
generator,
dim_match,
"0:ValueError",
format!("Input matrix must be a square matrix {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
}
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
extern_fns::call_np_linalg_cholesky(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim1, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n1.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
extern_fns::call_np_linalg_qr(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, k, out_q.data().base_ptr(ctx, generator)),
(k, dim1, out_r.data().base_ptr(ctx, generator)),
None,
);
let out_q = out_q.as_base_value().as_basic_value_enum();
let out_r = out_r.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_q.get_type(), out_r.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "QR_factorization").unwrap();
let res_val = [out_q, out_r];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]).unwrap();
let out_vh =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]).unwrap();
extern_fns::call_np_linalg_svd(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_u.data().base_ptr(ctx, generator)),
(k, llvm_usize.const_int(1, false), out_s.data().base_ptr(ctx, generator)),
(dim1, dim1, out_vh.data().base_ptr(ctx, generator)),
None,
);
let out_u = out_u.as_base_value().as_basic_value_enum();
let out_s = out_s.as_base_value().as_basic_value_enum();
let out_vh = out_vh.as_base_value().as_basic_value_enum();
let res_ty =
ctx.ctx.struct_type(&[out_u.get_type(), out_s.get_type(), out_vh.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "SVD_factorization").unwrap();
let res_val = [out_u, out_s, out_vh];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]).unwrap();
extern_fns::call_np_linalg_inv(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim1, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]).unwrap();
extern_fns::call_np_linalg_pinv(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim1, dim0, out.data().base_ptr(ctx, generator)),
None,
);
Ok(out.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]).unwrap();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]).unwrap();
extern_fns::call_sp_linalg_lu(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, k, out_l.data().base_ptr(ctx, generator)),
(k, dim1, out_u.data().base_ptr(ctx, generator)),
None,
);
let out_l = out_l.as_base_value().as_basic_value_enum();
let out_u = out_u.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_l.get_type(), out_u.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "LU_factorization").unwrap();
let res_val = [out_l, out_u];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
// Must be square (add check later)
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_schur";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out_t =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
let out_z =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
extern_fns::call_sp_linalg_schur(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_t.data().base_ptr(ctx, generator)),
(dim0, dim0, out_z.data().base_ptr(ctx, generator)),
None,
);
let out_t = out_t.as_base_value().as_basic_value_enum();
let out_z = out_z.as_base_value().as_basic_value_enum();
let res_ty = ctx.ctx.struct_type(&[out_t.get_type(), out_z.get_type()], false);
let res_ptr = ctx.builder.build_alloca(res_ty, "Schur_factorization").unwrap();
let r = ctx
.ctx
.const_string(ctx.current_loc.file.0.to_string().as_bytes(), true)
.as_basic_value_enum()
.into_pointer_value();
let res_val = [out_t, out_z];
for (i, v) in res_val.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
res_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"ptr",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
Ok(ctx.builder.build_load(res_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
// Must be square (add check later)
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
// Check if matrix is square
// ctx.builder.build_select(
// ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
// {
// let func =
// }, else_, name)
// ;
// ctx.builder.build_call(
// ctx.module.get_function("__nac3_raise"),
// &[]
// )
// let err_msg = ctx.gen_string(generator, "{FN_NAME} requires square matrix");
// ctx.raise_exn(generator, "0:ValueError", err_msg, [None, None, None], ctx.current_loc);
let out_h =
numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]).unwrap();
extern_fns::call_sp_linalg_hessenberg(
ctx,
(dim0, dim1, n1.data().base_ptr(ctx, generator)),
(dim0, dim0, out_h.data().base_ptr(ctx, generator)),
None,
);
Ok(out_h.as_base_value().as_basic_value_enum())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}

View File

@ -951,9 +951,9 @@ pub fn destructure_range<'ctx>(
/// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting /// Allocates a List structure with the given [type][ty] and [length]. The name of the resulting
/// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified. /// LLVM value is `{name}.addr`, or `list.addr` if [name] is not specified.
/// ///
/// Setting `ty` to [`None`] implies that the list does not have a known element type, which is only /// Setting `ty` to [`None`] implies that the list is empty **and** does not have a known element
/// valid for empty lists. It is undefined behavior to generate a sized list with an unknown element /// type, and will therefore set the `list.data` type as `size_t*`. It is undefined behavior to
/// type. /// generate a sized list with an unknown element type.
pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>( pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,

View File

@ -1,5 +1,5 @@
use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}; use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use itertools::Either; use itertools::Either;
use crate::codegen::CodeGenContext; use crate::codegen::CodeGenContext;
@ -130,3 +130,154 @@ pub fn call_ldexp<'ctx>(
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
} }
/// Macro to generate np_linalg external functions
macro_rules! generate_np_linalg_extern_fn {
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 1) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 2) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 3) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal, 4) => {
generate_np_linalg_extern_fn!($fn_name, $ret_ty, $extern_ret_ty, $map_fn, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $ret_ty:ident, $extern_ret_ty:ident, $map_fn:expr, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the numpy `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: (IntValue<'ctx>, IntValue<'ctx>, PointerValue<'ctx>))*,
name: Option<&str>,
) -> $ret_ty<'ctx> {
const FN_NAME: &str = $extern_fn;
let llvm_f64 = ctx.ctx.f64_type();
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
$(
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.0.get_type()));
debug_assert!(allowed_index_types.iter().any(|p| *p == $input_matrix.1.get_type()));
debug_assert_eq!($input_matrix.2.get_type().get_element_type().into_float_type(), llvm_f64);
)*
// let row = ctx.ctx.i32_type().const_int(ctx.current_loc.row.try_into().unwrap(), false);
// let col = ctx.ctx.i32_type().const_int(ctx.current_loc.column.try_into().unwrap(), false);
// let file_name = ctx.current_loc.file.0;
// let name_len = ctx.ctx.i32_type().const_int(file_name.to_string().len().try_into().unwrap(), false);
// let file_name = ctx.ctx.const_string(&ctx.current_loc.file.0.to_string().into_bytes(), true);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
// let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[row.get_type().into(), col.get_type().into(), file_name.get_type().into(), name_len.get_type().into(), $($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
let fn_type = ctx.ctx.$extern_ret_ty().fn_type(&[$($input_matrix.0.get_type().into(), $input_matrix.1.get_type().into(), $input_matrix.2.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
// .build_call(extern_fn, &[row.into(), col.into(), file_name.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
// .build_call(extern_fn, &[name_len.into(), col.into(), file_name.into(), row.into(), $($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
.build_call(extern_fn, &[$($input_matrix.0.into(), $input_matrix.1.into(), $input_matrix.2.into(),)*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left($map_fn))
.map(Either::unwrap_left)
.unwrap()
}
};
}
generate_np_linalg_extern_fn!(
call_np_dot,
FloatValue,
f64_type,
BasicValueEnum::into_float_value,
"np_dot",
2
);
generate_np_linalg_extern_fn!(
call_np_linalg_matmul,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_matmul",
3
);
generate_np_linalg_extern_fn!(
call_np_linalg_cholesky,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_cholesky",
2
);
generate_np_linalg_extern_fn!(
call_np_linalg_qr,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_qr",
3
);
generate_np_linalg_extern_fn!(
call_np_linalg_svd,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_svd",
4
);
generate_np_linalg_extern_fn!(
call_np_linalg_inv,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_inv",
2
);
generate_np_linalg_extern_fn!(
call_np_linalg_pinv,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"np_linalg_pinv",
2
);
generate_np_linalg_extern_fn!(
call_sp_linalg_lu,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_lu",
3
);
generate_np_linalg_extern_fn!(
call_sp_linalg_schur,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_schur",
3
);
generate_np_linalg_extern_fn!(
call_sp_linalg_hessenberg,
IntValue,
i8_type,
BasicValueEnum::into_int_value,
"sp_linalg_hessenberg",
2
);

View File

@ -61,7 +61,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
/// * `shape` - The shape of the `NDArray`. /// * `shape` - The shape of the `NDArray`.
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. /// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. /// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( pub fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
@ -157,7 +157,7 @@ where
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. /// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,

View File

@ -1637,7 +1637,7 @@ pub fn gen_stmt<G: CodeGenerator>(
}; };
ctx.make_assert_impl( ctx.make_assert_impl(
generator, generator,
test.into_int_value(), generator.bool_to_i1(ctx, test.into_int_value()),
"0:AssertionError", "0:AssertionError",
err_msg, err_msg,
[None, None, None], [None, None, None],

View File

@ -556,6 +556,19 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpLdExp | PrimDef::FunNpLdExp
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunNpDot
| PrimDef::FunNpLinalgMatmul
| PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgQr
| PrimDef::FunNpLinalgSvd
| PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv
| PrimDef::FunSpLinalgLu
| PrimDef::FunSpLinalgSchur
| PrimDef::FunSpLinalgHessenberg => self.build_np_linalg_methods(prim),
// PrimDef::FunNpDot | PrimDef::FunNpLinalgMatmul => self.build_np_linalg_binary_methods(prim),
// PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr => self.build_np_linalg_unary_methods(prim),
}; };
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -564,7 +577,7 @@ impl<'a> BuiltinBuilder<'a> {
match (&tld, prim.details()) { match (&tld, prim.details()) {
( (
TopLevelDef::Class { name, object_id, .. }, TopLevelDef::Class { name, object_id, .. },
PrimDefDetails::PrimClass { name: exp_name }, PrimDefDetails::PrimClass { name: exp_name, .. },
) => { ) => {
let exp_object_id = prim.id(); let exp_object_id = prim.id();
assert_eq!(name, &exp_name.into()); assert_eq!(name, &exp_name.into());
@ -1874,6 +1887,142 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
fn build_np_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[
PrimDef::FunNpDot,
PrimDef::FunNpLinalgMatmul,
PrimDef::FunNpLinalgCholesky,
PrimDef::FunNpLinalgQr,
PrimDef::FunNpLinalgSvd,
PrimDef::FunNpLinalgInv,
PrimDef::FunNpLinalgPinv,
PrimDef::FunSpLinalgLu,
PrimDef::FunSpLinalgSchur,
PrimDef::FunSpLinalgHessenberg,
],
);
match prim {
PrimDef::FunNpDot => create_fn_by_codegen(
self.unifier,
&self.num_or_ndarray_var_map,
prim.name(),
self.primitives.float,
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_dot(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgMatmul => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(builtin_fns::call_np_linalg_matmul(
generator,
ctx,
(x1_ty, x1_val),
(x2_ty, x2_val),
)?))
}),
),
PrimDef::FunNpLinalgCholesky
| PrimDef::FunNpLinalgInv
| PrimDef::FunNpLinalgPinv
| PrimDef::FunSpLinalgHessenberg => create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky,
PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv,
PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv,
PrimDef::FunSpLinalgHessenberg => builtin_fns::call_sp_linalg_hessenberg,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
),
PrimDef::FunNpLinalgQr | PrimDef::FunSpLinalgLu | PrimDef::FunSpLinalgSchur => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float_2d],
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let func = match prim {
PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr,
PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu,
PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
PrimDef::FunNpLinalgSvd => {
let ret_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d],
});
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x1")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?))
}),
)
}
_ => {
println!("{:?}", prim.name());
unreachable!()
}
}
}
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
(prim.simple_name().into(), method_ty, prim.id()) (prim.simple_name().into(), method_ty, prim.id())
} }

View File

@ -766,6 +766,7 @@ impl TopLevelComposer {
let target_ty = get_type_from_type_annotation_kinds( let target_ty = get_type_from_type_annotation_kinds(
&temp_def_list, &temp_def_list,
unifier, unifier,
primitives,
&def, &def,
&mut subst_list, &mut subst_list,
)?; )?;
@ -936,6 +937,7 @@ impl TopLevelComposer {
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
temp_def_list.as_ref(), temp_def_list.as_ref(),
unifier, unifier,
primitives_store,
&type_annotation, &type_annotation,
&mut None, &mut None,
)?; )?;
@ -1002,6 +1004,7 @@ impl TopLevelComposer {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(
&temp_def_list, &temp_def_list,
unifier, unifier,
primitives_store,
&return_ty_annotation, &return_ty_annotation,
&mut None, &mut None,
)? )?
@ -1622,6 +1625,7 @@ impl TopLevelComposer {
let self_type = get_type_from_type_annotation_kinds( let self_type = get_type_from_type_annotation_kinds(
&def_list, &def_list,
unifier, unifier,
primitives_ty,
&make_self_type_annotation(type_vars, *object_id), &make_self_type_annotation(type_vars, *object_id),
&mut None, &mut None,
)?; )?;
@ -1803,7 +1807,11 @@ impl TopLevelComposer {
let ty_ann = make_self_type_annotation(type_vars, *class_id); let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds( let self_ty = get_type_from_type_annotation_kinds(
&def_list, unifier, &ty_ann, &mut None, &def_list,
unifier,
primitives_ty,
&ty_ann,
&mut None,
)?; )?;
vars.extend(type_vars.iter().map(|ty| { vars.extend(type_vars.iter().map(|ty| {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {

View File

@ -105,6 +105,16 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunNpDot,
FunNpLinalgMatmul,
FunNpLinalgCholesky,
FunNpLinalgQr,
FunNpLinalgSvd,
FunNpLinalgInv,
FunNpLinalgPinv,
FunSpLinalgLu,
FunSpLinalgSchur,
FunSpLinalgHessenberg,
// Top-Level Functions // Top-Level Functions
FunSome, FunSome,
@ -113,7 +123,7 @@ pub enum PrimDef {
/// Associated details of a [`PrimDef`] /// Associated details of a [`PrimDef`]
pub enum PrimDefDetails { pub enum PrimDefDetails {
PrimFunction { name: &'static str, simple_name: &'static str }, PrimFunction { name: &'static str, simple_name: &'static str },
PrimClass { name: &'static str }, PrimClass { name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type },
} }
impl PrimDef { impl PrimDef {
@ -155,15 +165,17 @@ impl PrimDef {
#[must_use] #[must_use]
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self.details() { match self.details() {
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name, PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name, .. } => {
name
}
} }
} }
/// Get the associated details of this [`PrimDef`] /// Get the associated details of this [`PrimDef`]
#[must_use] #[must_use]
pub fn details(self) -> PrimDefDetails { pub fn details(self) -> PrimDefDetails {
fn class(name: &'static str) -> PrimDefDetails { fn class(name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type) -> PrimDefDetails {
PrimDefDetails::PrimClass { name } PrimDefDetails::PrimClass { name, get_ty_fn }
} }
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails { fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
@ -171,22 +183,22 @@ impl PrimDef {
} }
match self { match self {
PrimDef::Int32 => class("int32"), PrimDef::Int32 => class("int32", |primitives| primitives.int32),
PrimDef::Int64 => class("int64"), PrimDef::Int64 => class("int64", |primitives| primitives.int64),
PrimDef::Float => class("float"), PrimDef::Float => class("float", |primitives| primitives.float),
PrimDef::Bool => class("bool"), PrimDef::Bool => class("bool", |primitives| primitives.bool),
PrimDef::None => class("none"), PrimDef::None => class("none", |primitives| primitives.none),
PrimDef::Range => class("range"), PrimDef::Range => class("range", |primitives| primitives.range),
PrimDef::Str => class("str"), PrimDef::Str => class("str", |primitives| primitives.str),
PrimDef::Exception => class("Exception"), PrimDef::Exception => class("Exception", |primitives| primitives.exception),
PrimDef::UInt32 => class("uint32"), PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32),
PrimDef::UInt64 => class("uint64"), PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64),
PrimDef::Option => class("Option"), PrimDef::Option => class("Option", |primitives| primitives.option),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")), PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")), PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")), PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::List => class("list"), PrimDef::List => class("list", |primitives| primitives.list),
PrimDef::NDArray => class("ndarray"), PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray),
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")), PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")), PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
PrimDef::FunInt32 => fun("int32", None), PrimDef::FunInt32 => fun("int32", None),
@ -261,6 +273,17 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
PrimDef::FunSome => fun("Some", None), PrimDef::FunSome => fun("Some", None),
} }
} }

View File

@ -1,8 +1,9 @@
use super::*; use super::*;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::{PrimDef, PrimDefDetails};
use crate::typecheck::typedef::VarMap; use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant; use nac3parser::ast::Constant;
use strum::IntoEnumIterator;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum TypeAnnotation { pub enum TypeAnnotation {
@ -357,6 +358,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
pub fn get_type_from_type_annotation_kinds( pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore,
ann: &TypeAnnotation, ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>, subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> { ) -> Result<Type, HashSet<String>> {
@ -379,100 +381,141 @@ pub fn get_type_from_type_annotation_kinds(
let param_ty = params let param_ty = params
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list) get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let subst = { let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) {
// check for compatible range // Primitive TopLevelDefs do not contain all fields that are present in their Type
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check // counterparts, so directly perform subst on the Type instead.
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
}
}
TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => { let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else {
let ty = range[0]; unreachable!()
let ok: bool = { };
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"), let base_ty = get_ty_fn(primitives);
let params =
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) {
params.clone()
} else {
unreachable!()
};
unifier
.subst(
get_ty_fn(primitives),
&params
.iter()
.zip(param_ty)
.map(|(obj_tv, param)| (*obj_tv.0, param))
.collect(),
)
.unwrap_or(base_ty)
} else {
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
}
}
TypeEnum::TVar {
id, range, name, loc, is_const_generic: true, ..
} => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp =
unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
}
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
} }
} }
result
ty
}; };
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
}
Ok(ty) Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
@ -490,6 +533,7 @@ pub fn get_type_from_type_annotation_kinds(
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
unifier, unifier,
primitives,
ty.as_ref(), ty.as_ref(),
subst_list, subst_list,
)?; )?;
@ -499,7 +543,13 @@ pub fn get_type_from_type_annotation_kinds(
let tys = tys let tys = tys
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list) get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys })) Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))

View File

@ -8,6 +8,7 @@ edition = "2021"
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
linalg_externfns = { path = "./linalg_externfns" }
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"

View File

@ -15,7 +15,7 @@ done
demo="$1" demo="$1"
echo -n "Checking $demo... " echo -n "Checking $demo... "
./interpret_demo.py "$demo" > interpreted.log # ./interpret_demo.py "$demo" > interpreted.log
./run_demo.sh --out run.log "${nac3args[@]}" "$demo" ./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo" ./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run.log diff -Nau interpreted.log run.log

View File

@ -107,8 +107,25 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t
__builtin_unreachable(); __builtin_unreachable();
} }
uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) { // See `struct Exception<'a>` in
printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context); // https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs
struct Exception {
uint32_t id;
struct cslice file;
uint32_t line;
uint32_t column;
struct cslice function;
struct cslice message;
int64_t param[3];
};
uint32_t __nac3_raise(struct Exception* e) {
printf("__nac3_raise called. Exception details:\n");
printf(" ID: %lld\n", e->id);
printf(" Location: %*s:%lld:%lld\n" , e->file.len, (const char*) e->file.data, e->line, e->column);
printf(" Function: %*s\n" , e->function.len, (const char*) e->function.data);
printf(" Message: \"%*s\"\n" , e->message.len, (const char*) e->message.data);
printf(" Params: {0}=%lld, {1}=%lld, {2}=%lld\n", e->param[0], e->param[1], e->param[2]);
exit(101); exit(101);
__builtin_unreachable(); __builtin_unreachable();
} }

View File

@ -5,6 +5,7 @@ import importlib.util
import importlib.machinery import importlib.machinery
import math import math
import numpy as np import numpy as np
import scipy as sp
import numpy.typing as npt import numpy.typing as npt
import pathlib import pathlib
@ -141,6 +142,26 @@ def patch(module):
else: else:
raise NotImplementedError raise NotImplementedError
def try_invert_to(x):
try:
y = np.linalg.inv(x)
x[:] = y
except np.linalg.LinAlgError:
return False
return True
def wilkinson_shift(x):
assert (len(x.flatten()) == 4) and (x[0, 1] == x[1, 0]), f"Operation Wilkinson Shift expects symmetric matrix"
tmm, tnn, tmn = x[0, 0], x[1, 1], x[0, 1]
sq_tmn = tmn * tmn
if sq_tmn != 0:
d = (tmm - tnn) * 0.5
if d > 0:
return tnn - sq_tmn / (d + np.sqrt(d*d + sq_tmn))
else:
return tnn - sq_tmn / (d - np.sqrt(d*d + sq_tmn))
return tnn
module.int32 = int32 module.int32 = int32
module.int64 = int64 module.int64 = int64
module.uint32 = uint32 module.uint32 = uint32
@ -167,7 +188,7 @@ def patch(module):
module.ceil64 = _ceil module.ceil64 = _ceil
module.np_ceil = np.ceil module.np_ceil = np.ceil
# NumPy ndarray functions # NumPy NDArray factory functions
module.ndarray = NDArray module.ndarray = NDArray
module.np_ndarray = np.ndarray module.np_ndarray = np.ndarray
module.np_empty = np.empty module.np_empty = np.empty
@ -218,7 +239,7 @@ def patch(module):
module.np_hypot = np.hypot module.np_hypot = np.hypot
module.np_nextafter = np.nextafter module.np_nextafter = np.nextafter
# SciPy Math Functions # SciPy Math functions
module.sp_spec_erf = special.erf module.sp_spec_erf = special.erf
module.sp_spec_erfc = special.erfc module.sp_spec_erfc = special.erfc
module.sp_spec_gamma = special.gamma module.sp_spec_gamma = special.gamma
@ -226,14 +247,21 @@ def patch(module):
module.sp_spec_j0 = special.j0 module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1 module.sp_spec_j1 = special.j1
# NumPy NDArray Functions # Linalg functions
module.np_ndarray = np.ndarray module.np_dot = np.dot
module.np_empty = np.empty module.np_linalg_matmul = np.matmul
module.np_zeros = np.zeros module.np_linalg_cholesky = np.linalg.cholesky
module.np_ones = np.ones module.np_linalg_qr = np.linalg.qr
module.np_full = np.full module.np_linalg_svd = np.linalg.svd
module.np_eye = np.eye module.np_linalg_inv = np.linalg.inv
module.np_identity = np.identity module.np_linalg_pinv = np.linalg.pinv
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
module.sp_linalg_schur = sp.linalg.schur
# module.sp_linalg_hessenberg = sp.linalg.hessenberg
module.sp_linalg_hessenberg = lambda x: x
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

View File

@ -0,0 +1,2 @@
Excepytiopn!! knfv 0x7fffffff9218
__nac3_personality(state: 1, exception_object: 1, context: 1381323604)

View File

@ -42,11 +42,14 @@ done
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
nac3standalone=../../target/debug/nac3standalone nac3standalone=../../target/debug/nac3standalone
externfns=../../target/debug/deps/liblinalg_externfns.so
elif [ -e ../../target/release/nac3standalone ]; then elif [ -e ../../target/release/nac3standalone ]; then
nac3standalone=../../target/release/nac3standalone nac3standalone=../../target/release/nac3standalone
externfns=../../target/release/deps/liblinalg_externfns.so
else else
# used by Nix builds # used by Nix builds
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg_externfns.so
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo
@ -54,7 +57,7 @@ if [ -z "$use_lli" ]; then
$nac3standalone "${nac3args[@]}" $nac3standalone "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -lm -o demo module.o demo.o clang -lm -o demo module.o demo.o $externfns
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then
./demo ./demo
@ -71,8 +74,8 @@ else
shopt -u nullglob shopt -u nullglob
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc
else else
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile" lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
fi fi
fi fi

View File

@ -0,0 +1,12 @@
8.000000
10.000000
12.000000
4.000000
5.000000
6.000000
1.000000
2.000000
3.000000
4.000000
5.000000
6.000000

View File

@ -531,10 +531,12 @@ def test_ndarray_ipow_broadcast_scalar():
def test_ndarray_matmul(): def test_ndarray_matmul():
x = np_identity(2) x = np_identity(2)
y = x @ np_ones([2, 2]) t: ndarray[float, 2] = np_array([[1., 2., 3.,], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.]])
y = x @ t
output_ndarray_float_2(x) y2 = np_linalg_matmul(x, t)
output_ndarray_float_2(y) output_ndarray_float_2(y)
output_ndarray_float_2(y2)
def test_ndarray_imatmul(): def test_ndarray_imatmul():
x = np_identity(2) x = np_identity(2)
@ -1429,183 +1431,289 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_dot():
x: ndarray[float, 1] = np_array([5.0, 1.0])
y: ndarray[float, 1] = np_array([5.0, 1.0])
z = np_dot(x, y)
output_ndarray_float_1(x)
output_ndarray_float_1(y)
output_float64(z)
def test_ndarray_linalg_matmul():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
z = np_linalg_matmul(x, y)
m = np_argmax(z)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
output_ndarray_float_2(z)
output_int64(m)
def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = np_linalg_cholesky(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_qr():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y, z = np_linalg_qr(x)
output_ndarray_float_2(x)
# QR Factorization in nalgebra and numpy do not give the same result
# Generating product for printing
a = np_linalg_matmul(y, z)
output_ndarray_float_2(a)
def test_ndarray_linalg_inv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_inv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pinv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
y = np_linalg_pinv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_schur():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
t, z = sp_linalg_schur(x)
output_ndarray_float_2(x)
# Same as np_linalg_qr the signs are different in nalgebra and numpy
a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z))
output_ndarray_float_2(a)
def test_ndarray_hessenberg():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
h = sp_linalg_hessenberg(x)
output_ndarray_float_2(x)
output_ndarray_float_2(h)
def test_ndarray_lu():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
l, u = sp_linalg_lu(x)
output_ndarray_float_2(x)
output_ndarray_float_2(l)
output_ndarray_float_2(u)
def test_ndarray_svd():
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
x, y, z = np_linalg_svd(w)
output_ndarray_float_2(w)
# Same as np_linalg_qr the signs are different in nalgebra and numpy
a = np_linalg_matmul(x, z)
output_ndarray_float_2(a)
output_ndarray_float_1(y)
def run() -> int32: def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
test_ndarray_zeros()
test_ndarray_ones()
test_ndarray_full()
test_ndarray_eye()
test_ndarray_array()
test_ndarray_identity()
test_ndarray_fill()
test_ndarray_copy()
test_ndarray_neg_idx()
test_ndarray_slices()
test_ndarray_nd_idx()
test_ndarray_add()
test_ndarray_add_broadcast()
test_ndarray_add_broadcast_lhs_scalar()
test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_iadd()
test_ndarray_iadd_broadcast()
test_ndarray_iadd_broadcast_scalar()
test_ndarray_sub()
test_ndarray_sub_broadcast()
test_ndarray_sub_broadcast_lhs_scalar()
test_ndarray_sub_broadcast_rhs_scalar()
test_ndarray_isub()
test_ndarray_isub_broadcast()
test_ndarray_isub_broadcast_scalar()
test_ndarray_mul()
test_ndarray_mul_broadcast()
test_ndarray_mul_broadcast_lhs_scalar()
test_ndarray_mul_broadcast_rhs_scalar()
test_ndarray_imul()
test_ndarray_imul_broadcast()
test_ndarray_imul_broadcast_scalar()
test_ndarray_truediv()
test_ndarray_truediv_broadcast()
test_ndarray_truediv_broadcast_lhs_scalar()
test_ndarray_truediv_broadcast_rhs_scalar()
test_ndarray_itruediv()
test_ndarray_itruediv_broadcast()
test_ndarray_itruediv_broadcast_scalar()
test_ndarray_floordiv()
test_ndarray_floordiv_broadcast()
test_ndarray_floordiv_broadcast_lhs_scalar()
test_ndarray_floordiv_broadcast_rhs_scalar()
test_ndarray_ifloordiv()
test_ndarray_ifloordiv_broadcast()
test_ndarray_ifloordiv_broadcast_scalar()
test_ndarray_mod()
test_ndarray_mod_broadcast()
test_ndarray_mod_broadcast_lhs_scalar()
test_ndarray_mod_broadcast_rhs_scalar()
test_ndarray_imod()
test_ndarray_imod_broadcast()
test_ndarray_imod_broadcast_scalar()
test_ndarray_pow()
test_ndarray_pow_broadcast()
test_ndarray_pow_broadcast_lhs_scalar()
test_ndarray_pow_broadcast_rhs_scalar()
test_ndarray_ipow()
test_ndarray_ipow_broadcast()
test_ndarray_ipow_broadcast_scalar()
test_ndarray_matmul() test_ndarray_matmul()
test_ndarray_imatmul() # test_ndarray_dot()
test_ndarray_pos() # test_ndarray_linalg_matmul()
test_ndarray_neg() # test_ndarray_cholesky()
test_ndarray_inv() # test_ndarray_qr()
test_ndarray_eq() # test_ndarray_svd()
test_ndarray_eq_broadcast() # test_ndarray_linalg_inv()
test_ndarray_eq_broadcast_lhs_scalar() # test_ndarray_pinv()
test_ndarray_eq_broadcast_rhs_scalar() # test_ndarray_lu()
test_ndarray_ne() # test_ndarray_schur()
test_ndarray_ne_broadcast() # test_ndarray_hessenberg()
test_ndarray_ne_broadcast_lhs_scalar()
test_ndarray_ne_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_lt_broadcast()
test_ndarray_lt_broadcast_lhs_scalar()
test_ndarray_lt_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_le_broadcast()
test_ndarray_le_broadcast_lhs_scalar()
test_ndarray_le_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_gt_broadcast()
test_ndarray_gt_broadcast_lhs_scalar()
test_ndarray_gt_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_ge_broadcast()
test_ndarray_ge_broadcast_lhs_scalar()
test_ndarray_ge_broadcast_rhs_scalar()
test_ndarray_int32() # test_ndarray_ctor()
test_ndarray_int64() # test_ndarray_empty()
test_ndarray_uint32() # test_ndarray_zeros()
test_ndarray_uint64() # test_ndarray_ones()
test_ndarray_float() # test_ndarray_full()
test_ndarray_bool() # test_ndarray_eye()
# test_ndarray_array()
# test_ndarray_identity()
# test_ndarray_fill()
# test_ndarray_copy()
test_ndarray_round() # test_ndarray_neg_idx()
test_ndarray_floor() # test_ndarray_slices()
test_ndarray_min() # test_ndarray_nd_idx()
test_ndarray_minimum()
test_ndarray_minimum_broadcast()
test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_argmin()
test_ndarray_max()
test_ndarray_maximum()
test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_argmax()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()
test_ndarray_sin() # test_ndarray_add()
test_ndarray_cos() # test_ndarray_add_broadcast()
test_ndarray_exp() # test_ndarray_add_broadcast_lhs_scalar()
test_ndarray_exp2() # test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_log() # test_ndarray_iadd()
test_ndarray_log10() # test_ndarray_iadd_broadcast()
test_ndarray_log2() # test_ndarray_iadd_broadcast_scalar()
test_ndarray_fabs() # test_ndarray_sub()
test_ndarray_sqrt() # test_ndarray_sub_broadcast()
test_ndarray_rint() # test_ndarray_sub_broadcast_lhs_scalar()
test_ndarray_tan() # test_ndarray_sub_broadcast_rhs_scalar()
test_ndarray_arcsin() # test_ndarray_isub()
test_ndarray_arccos() # test_ndarray_isub_broadcast()
test_ndarray_arctan() # test_ndarray_isub_broadcast_scalar()
test_ndarray_sinh() # test_ndarray_mul()
test_ndarray_cosh() # test_ndarray_mul_broadcast()
test_ndarray_tanh() # test_ndarray_mul_broadcast_lhs_scalar()
test_ndarray_arcsinh() # test_ndarray_mul_broadcast_rhs_scalar()
test_ndarray_arccosh() # test_ndarray_imul()
test_ndarray_arctanh() # test_ndarray_imul_broadcast()
test_ndarray_expm1() # test_ndarray_imul_broadcast_scalar()
test_ndarray_cbrt() # test_ndarray_truediv()
# test_ndarray_truediv_broadcast()
# test_ndarray_truediv_broadcast_lhs_scalar()
# test_ndarray_truediv_broadcast_rhs_scalar()
# test_ndarray_itruediv()
# test_ndarray_itruediv_broadcast()
# test_ndarray_itruediv_broadcast_scalar()
# test_ndarray_floordiv()
# test_ndarray_floordiv_broadcast()
# test_ndarray_floordiv_broadcast_lhs_scalar()
# test_ndarray_floordiv_broadcast_rhs_scalar()
# test_ndarray_ifloordiv()
# test_ndarray_ifloordiv_broadcast()
# test_ndarray_ifloordiv_broadcast_scalar()
# test_ndarray_mod()
# test_ndarray_mod_broadcast()
# test_ndarray_mod_broadcast_lhs_scalar()
# test_ndarray_mod_broadcast_rhs_scalar()
# test_ndarray_imod()
# test_ndarray_imod_broadcast()
# test_ndarray_imod_broadcast_scalar()
# test_ndarray_pow()
# test_ndarray_pow_broadcast()
# test_ndarray_pow_broadcast_lhs_scalar()
# test_ndarray_pow_broadcast_rhs_scalar()
# test_ndarray_ipow()
# test_ndarray_ipow_broadcast()
# test_ndarray_ipow_broadcast_scalar()
# test_ndarray_matmul()
# test_ndarray_imatmul()
# test_ndarray_pos()
# test_ndarray_neg()
# test_ndarray_inv()
# test_ndarray_eq()
# test_ndarray_eq_broadcast()
# test_ndarray_eq_broadcast_lhs_scalar()
# test_ndarray_eq_broadcast_rhs_scalar()
# test_ndarray_ne()
# test_ndarray_ne_broadcast()
# test_ndarray_ne_broadcast_lhs_scalar()
# test_ndarray_ne_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_lt_broadcast()
# test_ndarray_lt_broadcast_lhs_scalar()
# test_ndarray_lt_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_le_broadcast()
# test_ndarray_le_broadcast_lhs_scalar()
# test_ndarray_le_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_gt_broadcast()
# test_ndarray_gt_broadcast_lhs_scalar()
# test_ndarray_gt_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_ge_broadcast()
# test_ndarray_ge_broadcast_lhs_scalar()
# test_ndarray_ge_broadcast_rhs_scalar()
test_ndarray_erf() # test_ndarray_int32()
test_ndarray_erfc() # test_ndarray_int64()
test_ndarray_gamma() # test_ndarray_uint32()
test_ndarray_gammaln() # test_ndarray_uint64()
test_ndarray_j0() # test_ndarray_float()
test_ndarray_j1() # test_ndarray_bool()
test_ndarray_arctan2() # test_ndarray_round()
test_ndarray_arctan2_broadcast() # test_ndarray_floor()
test_ndarray_arctan2_broadcast_lhs_scalar() # test_ndarray_min()
test_ndarray_arctan2_broadcast_rhs_scalar() # test_ndarray_minimum()
test_ndarray_copysign() # test_ndarray_minimum_broadcast()
test_ndarray_copysign_broadcast() # test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_copysign_broadcast_lhs_scalar() # test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_copysign_broadcast_rhs_scalar() # test_ndarray_argmin()
test_ndarray_fmax() # test_ndarray_max()
test_ndarray_fmax_broadcast() # test_ndarray_maximum()
test_ndarray_fmax_broadcast_lhs_scalar() # test_ndarray_maximum_broadcast()
test_ndarray_fmax_broadcast_rhs_scalar() # test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_fmin() # test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_fmin_broadcast() # test_ndarray_argmax()
test_ndarray_fmin_broadcast_lhs_scalar() # test_ndarray_abs()
test_ndarray_fmin_broadcast_rhs_scalar() # test_ndarray_isnan()
test_ndarray_ldexp() # test_ndarray_isinf()
test_ndarray_ldexp_broadcast()
test_ndarray_ldexp_broadcast_lhs_scalar() # test_ndarray_sin()
test_ndarray_ldexp_broadcast_rhs_scalar() # test_ndarray_cos()
test_ndarray_hypot() # test_ndarray_exp()
test_ndarray_hypot_broadcast() # test_ndarray_exp2()
test_ndarray_hypot_broadcast_lhs_scalar() # test_ndarray_log()
test_ndarray_hypot_broadcast_rhs_scalar() # test_ndarray_log10()
test_ndarray_nextafter() # test_ndarray_log2()
test_ndarray_nextafter_broadcast() # test_ndarray_fabs()
test_ndarray_nextafter_broadcast_lhs_scalar() # test_ndarray_sqrt()
test_ndarray_nextafter_broadcast_rhs_scalar() # test_ndarray_rint()
# test_ndarray_tan()
# test_ndarray_arcsin()
# test_ndarray_arccos()
# test_ndarray_arctan()
# test_ndarray_sinh()
# test_ndarray_cosh()
# test_ndarray_tanh()
# test_ndarray_arcsinh()
# test_ndarray_arccosh()
# test_ndarray_arctanh()
# test_ndarray_expm1()
# test_ndarray_cbrt()
# test_ndarray_erf()
# test_ndarray_erfc()
# test_ndarray_gamma()
# test_ndarray_gammaln()
# test_ndarray_j0()
# test_ndarray_j1()
# test_ndarray_arctan2()
# test_ndarray_arctan2_broadcast()
# test_ndarray_arctan2_broadcast_lhs_scalar()
# test_ndarray_arctan2_broadcast_rhs_scalar()
# test_ndarray_copysign()
# test_ndarray_copysign_broadcast()
# test_ndarray_copysign_broadcast_lhs_scalar()
# test_ndarray_copysign_broadcast_rhs_scalar()
# test_ndarray_fmax()
# test_ndarray_fmax_broadcast()
# test_ndarray_fmax_broadcast_lhs_scalar()
# test_ndarray_fmax_broadcast_rhs_scalar()
# test_ndarray_fmin()
# test_ndarray_fmin_broadcast()
# test_ndarray_fmin_broadcast_lhs_scalar()
# test_ndarray_fmin_broadcast_rhs_scalar()
# test_ndarray_ldexp()
# test_ndarray_ldexp_broadcast()
# test_ndarray_ldexp_broadcast_lhs_scalar()
# test_ndarray_ldexp_broadcast_rhs_scalar()
# test_ndarray_hypot()
# test_ndarray_hypot_broadcast()
# test_ndarray_hypot_broadcast_lhs_scalar()
# test_ndarray_hypot_broadcast_rhs_scalar()
# test_ndarray_nextafter()
# test_ndarray_nextafter_broadcast()
# test_ndarray_nextafter_broadcast_lhs_scalar()
# test_ndarray_nextafter_broadcast_rhs_scalar()
# test_try_invert()
# test_wilkinson_shift()
return 0 return 0

View File

@ -0,0 +1,11 @@
[package]
name = "linalg_externfns"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib"]
[dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0"

View File

@ -0,0 +1,346 @@
mod runtime_exception;
use core::slice;
use nalgebra::{linalg, DMatrix};
macro_rules! raise_exn {
($name:expr, $message:expr, $param0:expr, $param1:expr, $param2:expr) => {{
use cslice::AsCSlice;
let name_id = $crate::runtime_exception::get_exception_id($name);
let exn = $crate::runtime_exception::Exception {
id: name_id,
file: file!().as_c_slice(),
line: line!(),
column: column!(),
// https://github.com/rust-lang/rfcs/pull/1719
function: "(Rust function)".as_c_slice(),
message: $message.as_c_slice(),
param: [$param0, $param1, $param2],
};
#[allow(unused_unsafe)]
unsafe {
$crate::runtime_exception::raise(&exn)
}
}};
($name:expr, $message:expr) => {{
raise_exn!($name, $message, 0, 0, 0)
}};
}
/// # Safety
///
/// `data` must point to an array with `dim0`x`dim1` elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
1
} else {
0
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64 {
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)])
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_dot(
dim0: usize,
dim1: usize,
x1: *mut f64,
_: usize,
_: usize,
x2: *mut f64,
) -> f64 {
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim0 * dim1) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(x2, dim0 * dim1) };
let matrix1 = DMatrix::from_row_slice(dim0, dim1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim0, dim1, data_slice2);
matrix1.dot(&matrix2)
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matmul(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
x2: *mut f64,
dim3_0: usize,
dim3_1: usize,
out: *mut f64,
) -> i8 {
// let name = unsafe {slice::from_raw_parts_mut(n, l)};
// let fne = name.as_c_slice();
raise_exn!("ZeroDivisionError", "Divide by Zero");
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(x2, dim2_0 * dim2_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim3_0 * dim3_1) };
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2_0, dim2_1, data_slice2);
let mut result = DMatrix::<f64>::zeros(dim3_0, dim3_1);
matrix1.mul_to(&matrix2, &mut result);
out_slice.copy_from_slice(result.transpose().as_slice());
// raise_exn!("ZeroDivisionError", "Divide by Zero", r, c, n);
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_cholesky(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
let res = matrix1.cholesky();
match res {
None => 0,
Some(c) => {
out_slice.copy_from_slice(c.unpack().transpose().as_slice());
1
}
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_qr(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_q: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_r: *mut f64,
) -> i8 {
let data_slice1 = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q, dim2_0 * dim2_1) };
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r, dim3_0 * dim3_1) };
// Refer to https://github.com/dimforge/nalgebra/issues/735
let matrix1 = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice1);
let res = matrix1.qr();
let (q, r) = res.unpack();
// Uses different algo need to match numpy
out_q_slice.copy_from_slice(q.transpose().as_slice());
out_r_slice.copy_from_slice(r.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_svd(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_u: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_s: *mut f64,
dim4_0: usize,
dim4_1: usize,
out_vh: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim2_0 * dim2_1) };
let out_s_slice = unsafe { slice::from_raw_parts_mut(out_s, dim3_0 * dim3_1) };
let out_vh_slice = unsafe { slice::from_raw_parts_mut(out_vh, dim4_0 * dim4_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let res = matrix.svd(true, true);
out_u_slice.copy_from_slice(res.u.unwrap().transpose().as_slice());
out_s_slice.copy_from_slice(res.singular_values.as_slice());
out_vh_slice.copy_from_slice(res.v_t.unwrap().transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_inv(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_invertible() {
// raise error
return 0;
}
let inv = matrix.try_inverse().unwrap();
out_slice.copy_from_slice(inv.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_pinv(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_slice = unsafe { slice::from_raw_parts_mut(out, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let svd = matrix.svd(true, true);
let inv = svd.pseudo_inverse(1e-15);
match inv {
Ok(m) => {
out_slice.copy_from_slice(m.transpose().as_slice());
1
}
Err(e) => {
// raise exception here
assert!(false, "{e}");
0
}
}
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_lu(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_l: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_u: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l, dim2_0 * dim2_1) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u, dim3_0 * dim3_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
let (_, l, u) = matrix.lu().unpack();
out_l_slice.copy_from_slice(l.transpose().as_slice());
out_u_slice.copy_from_slice(u.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_schur(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_t: *mut f64,
dim3_0: usize,
dim3_1: usize,
out_z: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t, dim2_0 * dim2_1) };
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z, dim3_0 * dim3_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_square() {
// Throw error here
return 0;
}
let (z, t) = matrix.schur().unpack();
out_t_slice.copy_from_slice(t.transpose().as_slice());
out_z_slice.copy_from_slice(z.transpose().as_slice());
1
}
/// # Safety
///
/// `data` must point to an array of 4 elements in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_hessenberg(
dim1_0: usize,
dim1_1: usize,
x1: *mut f64,
dim2_0: usize,
dim2_1: usize,
out_h: *mut f64,
) -> i8 {
let data_slice = unsafe { slice::from_raw_parts_mut(x1, dim1_0 * dim1_1) };
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h, dim2_0 * dim2_1) };
let matrix = DMatrix::from_row_slice(dim1_0, dim1_1, data_slice);
if !matrix.is_square() {
// Throw error here
return 0;
}
let (_, h) = matrix.hessenberg().unpack();
out_h_slice.copy_from_slice(h.transpose().as_slice());
1
}

View File

@ -0,0 +1,80 @@
#![allow(non_camel_case_types)]
#![allow(unused)]
// ARTIQ Exception struct declaration
use cslice::CSlice;
// Note: CSlice within an exception may not be actual cslice, they may be strings that exist only
// in the host. If the length == usize:MAX, the pointer is actually a string key in the host.
#[repr(C)]
#[derive(Clone)]
pub struct Exception<'a> {
pub id: u32,
pub file: CSlice<'a, u8>,
pub line: u32,
pub column: u32,
pub function: CSlice<'a, u8>,
pub message: CSlice<'a, u8>,
pub param: [i64; 3],
}
fn str_err(_: core::str::Utf8Error) -> core::fmt::Error {
core::fmt::Error
}
fn exception_str<'a>(s: &'a CSlice<'a, u8>) -> Result<&'a str, core::str::Utf8Error> {
if s.len() == usize::MAX {
Ok("<host string>")
} else {
core::str::from_utf8(s.as_ref())
}
}
impl<'a> core::fmt::Debug for Exception<'a> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"Exception {} from {} in {}:{}:{}, message: {}",
self.id,
exception_str(&self.function).map_err(str_err)?,
exception_str(&self.file).map_err(str_err)?,
self.line,
self.column,
exception_str(&self.message).map_err(str_err)?
)
}
}
pub unsafe fn raise(exception: *const Exception) -> ! {
println!("Excepytiopn!! knfv {:?}", exception);
let e = &*exception;
let f1 = exception_str(&e.function).map_err(str_err).unwrap();
let f2 = exception_str(&e.file).map_err(str_err).unwrap();
let f3 = exception_str(&e.message).map_err(str_err).unwrap();
panic!("Exception {} from {} in {}:{}:{}, message: {}", e.id, f1, f2, e.line, e.column, f3);
}
static EXCEPTION_ID_LOOKUP: [(&str, u32); 12] = [
("RuntimeError", 0),
("RTIOUnderflow", 1),
("RTIOOverflow", 2),
("RTIODestinationUnreachable", 3),
("DMAError", 4),
("I2CError", 5),
("CacheError", 6),
("SPIError", 7),
("ZeroDivisionError", 8),
("IndexError", 9),
("UnwrapNoneError", 10),
("Value", 11),
];
pub fn get_exception_id(name: &str) -> u32 {
for (n, id) in EXCEPTION_ID_LOOKUP.iter() {
if *n == name {
return *id;
}
}
unimplemented!("unallocated internal exception id")
}

View File

@ -113,7 +113,9 @@ fn handle_typevar_definition(
x, x,
HashMap::new(), HashMap::new(),
)?; )?;
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None) get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let loc = func.location; let loc = func.location;
@ -152,7 +154,7 @@ fn handle_typevar_definition(
HashMap::new(), HashMap::new(),
)?; )?;
let constraint = let constraint =
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?; get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty, &mut None)?;
let loc = func.location; let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty) Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)