forked from M-Labs/nac3
WIP
This commit is contained in:
parent
6c10e3d056
commit
6ad597e592
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
/target
|
/target
|
||||||
nix/windows/msys2
|
nix/windows/msys2
|
||||||
|
nac3standalone/demo/externfns/target
|
||||||
|
99
Cargo.lock
generated
99
Cargo.lock
generated
@ -73,6 +73,15 @@ dependencies = [
|
|||||||
"windows-sys",
|
"windows-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "approx"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ascii-canvas"
|
name = "ascii-canvas"
|
||||||
version = "3.0.0"
|
version = "3.0.0"
|
||||||
@ -305,6 +314,13 @@ dependencies = [
|
|||||||
"windows-sys",
|
"windows-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "externfns"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"nalgebra",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastrand"
|
name = "fastrand"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
@ -521,6 +537,12 @@ dependencies = [
|
|||||||
"windows-targets",
|
"windows-targets",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libm"
|
||||||
|
version = "0.2.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libredox"
|
name = "libredox"
|
||||||
version = "0.1.3"
|
version = "0.1.3"
|
||||||
@ -658,18 +680,71 @@ name = "nac3standalone"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"clap",
|
"clap",
|
||||||
|
"externfns",
|
||||||
"inkwell",
|
"inkwell",
|
||||||
"nac3core",
|
"nac3core",
|
||||||
"nac3parser",
|
"nac3parser",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nalgebra"
|
||||||
|
version = "0.32.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
|
||||||
|
dependencies = [
|
||||||
|
"approx",
|
||||||
|
"num-complex",
|
||||||
|
"num-rational",
|
||||||
|
"num-traits",
|
||||||
|
"simba",
|
||||||
|
"typenum",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "new_debug_unreachable"
|
name = "new_debug_unreachable"
|
||||||
version = "1.0.6"
|
version = "1.0.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
|
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.4.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-integer"
|
||||||
|
version = "0.1.46"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-rational"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||||
|
dependencies = [
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.2.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"libm",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "once_cell"
|
name = "once_cell"
|
||||||
version = "1.19.0"
|
version = "1.19.0"
|
||||||
@ -699,6 +774,12 @@ dependencies = [
|
|||||||
"windows-targets",
|
"windows-targets",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "paste"
|
||||||
|
version = "1.0.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "petgraph"
|
name = "petgraph"
|
||||||
version = "0.6.5"
|
version = "0.6.5"
|
||||||
@ -1070,6 +1151,18 @@ dependencies = [
|
|||||||
"yaml-rust",
|
"yaml-rust",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "simba"
|
||||||
|
version = "0.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
|
||||||
|
dependencies = [
|
||||||
|
"approx",
|
||||||
|
"num-complex",
|
||||||
|
"num-traits",
|
||||||
|
"paste",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "similar"
|
name = "similar"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
@ -1230,6 +1323,12 @@ dependencies = [
|
|||||||
"crunchy",
|
"crunchy",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typenum"
|
||||||
|
version = "1.17.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unic-char-property"
|
name = "unic-char-property"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
|
@ -5,6 +5,7 @@ members = [
|
|||||||
"nac3parser",
|
"nac3parser",
|
||||||
"nac3core",
|
"nac3core",
|
||||||
"nac3standalone",
|
"nac3standalone",
|
||||||
|
"nac3standalone/demo/externfns",
|
||||||
"nac3artiq",
|
"nac3artiq",
|
||||||
"runkernel",
|
"runkernel",
|
||||||
]
|
]
|
||||||
|
@ -161,7 +161,9 @@
|
|||||||
clippy
|
clippy
|
||||||
pre-commit
|
pre-commit
|
||||||
rustfmt
|
rustfmt
|
||||||
|
rust-analyzer
|
||||||
];
|
];
|
||||||
|
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||||
};
|
};
|
||||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||||
name = "nac3-dev-shell-msys2";
|
name = "nac3-dev-shell-msys2";
|
||||||
|
@ -14,12 +14,23 @@ class Demo:
|
|||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run(self):
|
def run(self):
|
||||||
self.core.reset()
|
a = np_array([[1., 2.], [3., 4.]])
|
||||||
while True:
|
b = try_invert_to(a)
|
||||||
with parallel:
|
if b:
|
||||||
self.led0.pulse(100.*ms)
|
# self.core.reset()
|
||||||
self.led1.pulse(100.*ms)
|
# while True:
|
||||||
self.core.delay(100.*ms)
|
# with parallel:
|
||||||
|
# self.led0.pulse(100.*ms)
|
||||||
|
# self.led1.pulse(100.*ms)
|
||||||
|
# self.core.delay(100.*ms)
|
||||||
|
v = try_invert_to(np_identity(2))
|
||||||
|
if v:
|
||||||
|
while True:
|
||||||
|
with parallel:
|
||||||
|
self.led0.pulse(100.*ms)
|
||||||
|
self.led1.pulse(100.*ms)
|
||||||
|
self.core.delay(100.*ms)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
BIN
nac3artiq/demo/module.elf
Normal file
BIN
nac3artiq/demo/module.elf
Normal file
Binary file not shown.
@ -3,7 +3,7 @@ use inkwell::values::BasicValueEnum;
|
|||||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor};
|
||||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||||
@ -922,6 +922,122 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_try_invert_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "linalg_try_invert_to";
|
||||||
|
let (a_ty, a) = a;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match a {
|
||||||
|
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
|
{
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
match llvm_ndarray_ty {
|
||||||
|
BasicTypeEnum::FloatType(_) => {},
|
||||||
|
_ => unreachable!("Inverse Operation supported on float type NDArray Values only")
|
||||||
|
};
|
||||||
|
|
||||||
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
// Add asserts for dims
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
let n_dims = n.load_ndims(ctx);
|
||||||
|
|
||||||
|
// num_dim == 2
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
|
||||||
|
"0:ValueError", format!("Inverse only supported on 2D lists").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
|
||||||
|
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||||
|
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||||
|
|
||||||
|
// dim0 == dim1
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||||
|
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
|
let n_sz_eqz = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
n_sz_eqz,
|
||||||
|
"0:ValueError",
|
||||||
|
format!("zero-size array to reduction operation {FN_NAME}").as_str(),
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a call to linalg_try_invert_to
|
||||||
|
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||||
|
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||||
|
|
||||||
|
Ok(extern_fns::call_linalg_try_invert_to(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
|
||||||
|
}
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||||
|
let (a_ty, a) = a;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match a {
|
||||||
|
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
|
{
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
match llvm_ndarray_ty {
|
||||||
|
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {},
|
||||||
|
_ => unreachable!("Wilkinson Shift Operation supported on float/integer type NDArray Values only")
|
||||||
|
};
|
||||||
|
|
||||||
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
|
||||||
|
// Add asserts for dims
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
let n_dims = n.load_ndims(ctx);
|
||||||
|
|
||||||
|
// num_dim == 2
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
|
||||||
|
"0:ValueError", format!("Wilkinson Shift supported only on 2D lists").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
|
||||||
|
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||||
|
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||||
|
|
||||||
|
// dim0 == dim1
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||||
|
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
|
||||||
|
// dimesions should be 2x2
|
||||||
|
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim0.get_type().const_int(2, false), "").unwrap(),
|
||||||
|
"0:ValueError", format!("Wilkinson Shift supported only on 2x2 matrices").as_str(), [None, None, None], ctx.current_loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a call to linalg_try_invert_to
|
||||||
|
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||||
|
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||||
|
|
||||||
|
Ok(extern_fns::call_linalg_wilkinson_shift(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
|
||||||
|
}
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Invokes the `np_maximum` builtin function.
|
/// Invokes the `np_maximum` builtin function.
|
||||||
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use inkwell::attributes::{Attribute, AttributeLoc};
|
use inkwell::attributes::{Attribute, AttributeLoc};
|
||||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
|
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
|
|
||||||
use crate::codegen::CodeGenContext;
|
use crate::codegen::CodeGenContext;
|
||||||
@ -130,3 +130,82 @@ pub fn call_ldexp<'ctx>(
|
|||||||
.map(Either::unwrap_left)
|
.map(Either::unwrap_left)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
|
||||||
|
pub fn call_linalg_try_invert_to<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
dim0: IntValue<'ctx>,
|
||||||
|
dim1: IntValue<'ctx>,
|
||||||
|
data: PointerValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
const FN_NAME: &str = "linalg_try_invert_to";
|
||||||
|
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||||
|
|
||||||
|
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
|
||||||
|
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
|
||||||
|
|
||||||
|
debug_assert!(allowed_dim0);
|
||||||
|
debug_assert!(allowed_dim1);
|
||||||
|
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||||
|
|
||||||
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||||
|
let fn_type = ctx.ctx.i8_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
|
||||||
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||||
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||||
|
func.add_attribute(
|
||||||
|
AttributeLoc::Function,
|
||||||
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
func
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap().into()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
|
||||||
|
pub fn call_linalg_wilkinson_shift<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
dim0: IntValue<'ctx>,
|
||||||
|
dim1: IntValue<'ctx>,
|
||||||
|
data: PointerValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> FloatValue<'ctx> {
|
||||||
|
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||||
|
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||||
|
|
||||||
|
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
|
||||||
|
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
|
||||||
|
|
||||||
|
debug_assert!(allowed_dim0);
|
||||||
|
debug_assert!(allowed_dim1);
|
||||||
|
|
||||||
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||||
|
let fn_type = ctx.ctx.f64_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
|
||||||
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||||
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||||
|
func.add_attribute(
|
||||||
|
AttributeLoc::Function,
|
||||||
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
func
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap().into()
|
||||||
|
}
|
||||||
|
@ -556,6 +556,9 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
| PrimDef::FunNpLdExp
|
| PrimDef::FunNpLdExp
|
||||||
| PrimDef::FunNpHypot
|
| PrimDef::FunNpHypot
|
||||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunTryInvertTo => self.build_linalg_try_invert_to(prim), // Inplace invert
|
||||||
|
PrimDef::FunWilkinsonShift => self.build_linalg_wilkinson_shift(prim),
|
||||||
};
|
};
|
||||||
|
|
||||||
if cfg!(debug_assertions) {
|
if cfg!(debug_assertions) {
|
||||||
@ -1874,6 +1877,64 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_linalg_try_invert_to(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(
|
||||||
|
prim,
|
||||||
|
&[
|
||||||
|
PrimDef::FunTryInvertTo,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
|
let var_map = self.num_or_ndarray_var_map.clone();
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&var_map,
|
||||||
|
prim.name(),
|
||||||
|
self.primitives.bool,
|
||||||
|
&[(self.ndarray_float_2d, "x")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x_ty = fun.0.args[0].ty;
|
||||||
|
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||||
|
|
||||||
|
let func = match prim {
|
||||||
|
PrimDef::FunTryInvertTo => builtin_fns::call_try_invert_to,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_linalg_wilkinson_shift(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(
|
||||||
|
prim,
|
||||||
|
&[
|
||||||
|
PrimDef::FunWilkinsonShift,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
|
let var_map = self.num_or_ndarray_var_map.clone();
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&var_map,
|
||||||
|
prim.name(),
|
||||||
|
self.primitives.float,
|
||||||
|
&[(self.ndarray_float_2d, "x")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x_ty = fun.0.args[0].ty;
|
||||||
|
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||||
|
|
||||||
|
let func = match prim {
|
||||||
|
PrimDef::FunWilkinsonShift => builtin_fns::call_wilkinson_shift,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||||
(prim.simple_name().into(), method_ty, prim.id())
|
(prim.simple_name().into(), method_ty, prim.id())
|
||||||
}
|
}
|
||||||
|
@ -105,6 +105,8 @@ pub enum PrimDef {
|
|||||||
FunNpLdExp,
|
FunNpLdExp,
|
||||||
FunNpHypot,
|
FunNpHypot,
|
||||||
FunNpNextAfter,
|
FunNpNextAfter,
|
||||||
|
FunTryInvertTo,
|
||||||
|
FunWilkinsonShift,
|
||||||
|
|
||||||
// Top-Level Functions
|
// Top-Level Functions
|
||||||
FunSome,
|
FunSome,
|
||||||
@ -261,6 +263,8 @@ impl PrimDef {
|
|||||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||||
|
PrimDef::FunTryInvertTo => fun("try_invert_to", None),
|
||||||
|
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None),
|
||||||
PrimDef::FunSome => fun("Some", None),
|
PrimDef::FunSome => fun("Some", None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ edition = "2021"
|
|||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
nac3parser = { path = "../nac3parser" }
|
nac3parser = { path = "../nac3parser" }
|
||||||
nac3core = { path = "../nac3core" }
|
nac3core = { path = "../nac3core" }
|
||||||
|
externfns = { path = "./demo/externfns" }
|
||||||
|
|
||||||
[dependencies.clap]
|
[dependencies.clap]
|
||||||
version = "4.5"
|
version = "4.5"
|
||||||
|
10
nac3standalone/demo/externfns/Cargo.toml
Normal file
10
nac3standalone/demo/externfns/Cargo.toml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
[package]
|
||||||
|
name = "externfns"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
|
41
nac3standalone/demo/externfns/src/lib.rs
Normal file
41
nac3standalone/demo/externfns/src/lib.rs
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
#![deny(
|
||||||
|
future_incompatible,
|
||||||
|
let_underscore,
|
||||||
|
nonstandard_style,
|
||||||
|
rust_2024_compatibility,
|
||||||
|
clippy::all
|
||||||
|
)]
|
||||||
|
#![warn(clippy::pedantic)]
|
||||||
|
#![allow(clippy::semicolon_if_nothing_returned, clippy::uninlined_format_args)]
|
||||||
|
|
||||||
|
use core::slice;
|
||||||
|
use nalgebra::{DMatrix, linalg};
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
/// Provide an interface to `nalgebra::linalg::try_invert_to`
|
||||||
|
pub extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8
|
||||||
|
{
|
||||||
|
|
||||||
|
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
||||||
|
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||||
|
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
|
||||||
|
|
||||||
|
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
|
||||||
|
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
|
||||||
|
1
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
/// Provide an interface to `nalgebra::linalg::wilkinson_shift`
|
||||||
|
pub extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64
|
||||||
|
{
|
||||||
|
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
||||||
|
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||||
|
|
||||||
|
// Check if matrix is symmetric
|
||||||
|
assert!(matrix[(0, 1)] == matrix[(1, 0)], "Operation Wilkinson Shift expects symmetric matrix");
|
||||||
|
return linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)]);
|
||||||
|
}
|
@ -141,6 +141,26 @@ def patch(module):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def try_invert_to(x):
|
||||||
|
try:
|
||||||
|
y = np.linalg.inv(x)
|
||||||
|
x[:] = y
|
||||||
|
except np.linalg.LinAlgError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def wilkinson_shift(x):
|
||||||
|
assert (len(x.flatten()) == 4) and (x[0, 1] == x[1, 0]), f"Operation Wilkinson Shift expects symmetric matrix"
|
||||||
|
tmm, tnn, tmn = x[0, 0], x[1, 1], x[0, 1]
|
||||||
|
sq_tmn = tmn * tmn
|
||||||
|
if sq_tmn != 0:
|
||||||
|
d = (tmm - tnn) * 0.5
|
||||||
|
if d > 0:
|
||||||
|
return tnn - sq_tmn / (d + np.sqrt(d*d + sq_tmn))
|
||||||
|
else:
|
||||||
|
return tnn - sq_tmn / (d - np.sqrt(d*d + sq_tmn))
|
||||||
|
return tnn
|
||||||
|
|
||||||
module.int32 = int32
|
module.int32 = int32
|
||||||
module.int64 = int64
|
module.int64 = int64
|
||||||
module.uint32 = uint32
|
module.uint32 = uint32
|
||||||
@ -234,6 +254,8 @@ def patch(module):
|
|||||||
module.np_full = np.full
|
module.np_full = np.full
|
||||||
module.np_eye = np.eye
|
module.np_eye = np.eye
|
||||||
module.np_identity = np.identity
|
module.np_identity = np.identity
|
||||||
|
module.try_invert_to = try_invert_to
|
||||||
|
module.wilkinson_shift = wilkinson_shift
|
||||||
|
|
||||||
def file_import(filename, prefix="file_import_"):
|
def file_import(filename, prefix="file_import_"):
|
||||||
filename = pathlib.Path(filename)
|
filename = pathlib.Path(filename)
|
||||||
|
@ -42,11 +42,14 @@ done
|
|||||||
|
|
||||||
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
||||||
nac3standalone=../../target/debug/nac3standalone
|
nac3standalone=../../target/debug/nac3standalone
|
||||||
|
externfns=../../target/debug/deps/libexternfns.so
|
||||||
elif [ -e ../../target/release/nac3standalone ]; then
|
elif [ -e ../../target/release/nac3standalone ]; then
|
||||||
nac3standalone=../../target/release/nac3standalone
|
nac3standalone=../../target/release/nac3standalone
|
||||||
|
externfns=../../target/release/deps/libexternfns.so
|
||||||
else
|
else
|
||||||
# used by Nix builds
|
# used by Nix builds
|
||||||
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
||||||
|
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so
|
||||||
fi
|
fi
|
||||||
|
|
||||||
rm -f ./*.o ./*.bc demo
|
rm -f ./*.o ./*.bc demo
|
||||||
@ -54,7 +57,7 @@ if [ -z "$use_lli" ]; then
|
|||||||
$nac3standalone "${nac3args[@]}"
|
$nac3standalone "${nac3args[@]}"
|
||||||
|
|
||||||
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
|
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
|
||||||
clang -lm -o demo module.o demo.o
|
clang -lm -o demo module.o demo.o $externfns
|
||||||
|
|
||||||
if [ -z "$outfile" ]; then
|
if [ -z "$outfile" ]; then
|
||||||
./demo
|
./demo
|
||||||
@ -71,8 +74,8 @@ else
|
|||||||
shopt -u nullglob
|
shopt -u nullglob
|
||||||
|
|
||||||
if [ -z "$outfile" ]; then
|
if [ -z "$outfile" ]; then
|
||||||
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc
|
lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc
|
||||||
else
|
else
|
||||||
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
|
lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -1429,7 +1429,25 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
|||||||
output_ndarray_float_2(nextafter_x_zeros)
|
output_ndarray_float_2(nextafter_x_zeros)
|
||||||
output_ndarray_float_2(nextafter_x_ones)
|
output_ndarray_float_2(nextafter_x_ones)
|
||||||
|
|
||||||
|
def test_try_invert():
|
||||||
|
x: ndarray[float, 2] = np_array([[1.0, 1.0], [1.0, 5.0]])
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
|
||||||
|
y = try_invert_to(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_bool(y)
|
||||||
|
|
||||||
|
def test_wilkinson_shift():
|
||||||
|
x: ndarray[float, 2] = np_array([[5., 1.], [1., 4.]])
|
||||||
|
y = wilkinson_shift(x)
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_float64(y)
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
|
test_try_invert()
|
||||||
|
test_wilkinson_shift()
|
||||||
test_ndarray_ctor()
|
test_ndarray_ctor()
|
||||||
test_ndarray_empty()
|
test_ndarray_empty()
|
||||||
test_ndarray_zeros()
|
test_ndarray_zeros()
|
||||||
|
BIN
pyo3_output/nac3artiq.so
Executable file
BIN
pyo3_output/nac3artiq.so
Executable file
Binary file not shown.
Loading…
Reference in New Issue
Block a user