forked from M-Labs/nac3
WIP
This commit is contained in:
parent
6c10e3d056
commit
6ad597e592
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
__pycache__
|
||||
/target
|
||||
nix/windows/msys2
|
||||
nac3standalone/demo/externfns/target
|
||||
|
99
Cargo.lock
generated
99
Cargo.lock
generated
@ -73,6 +73,15 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ascii-canvas"
|
||||
version = "3.0.0"
|
||||
@ -305,6 +314,13 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "externfns"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"nalgebra",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.1.0"
|
||||
@ -521,6 +537,12 @@ dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||
|
||||
[[package]]
|
||||
name = "libredox"
|
||||
version = "0.1.3"
|
||||
@ -658,18 +680,71 @@ name = "nac3standalone"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"externfns",
|
||||
"inkwell",
|
||||
"nac3core",
|
||||
"nac3parser",
|
||||
"parking_lot",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra"
|
||||
version = "0.32.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "new_debug_unreachable"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.19.0"
|
||||
@ -699,6 +774,12 @@ dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
|
||||
|
||||
[[package]]
|
||||
name = "petgraph"
|
||||
version = "0.6.5"
|
||||
@ -1070,6 +1151,18 @@ dependencies = [
|
||||
"yaml-rust",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "similar"
|
||||
version = "2.5.0"
|
||||
@ -1230,6 +1323,12 @@ dependencies = [
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||
|
||||
[[package]]
|
||||
name = "unic-char-property"
|
||||
version = "0.9.0"
|
||||
|
@ -5,6 +5,7 @@ members = [
|
||||
"nac3parser",
|
||||
"nac3core",
|
||||
"nac3standalone",
|
||||
"nac3standalone/demo/externfns",
|
||||
"nac3artiq",
|
||||
"runkernel",
|
||||
]
|
||||
|
@ -161,7 +161,9 @@
|
||||
clippy
|
||||
pre-commit
|
||||
rustfmt
|
||||
rust-analyzer
|
||||
];
|
||||
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||
};
|
||||
devShells.x86_64-linux.msys2 = pkgs.mkShell {
|
||||
name = "nac3-dev-shell-msys2";
|
||||
|
@ -14,12 +14,23 @@ class Demo:
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
self.core.reset()
|
||||
while True:
|
||||
with parallel:
|
||||
self.led0.pulse(100.*ms)
|
||||
self.led1.pulse(100.*ms)
|
||||
self.core.delay(100.*ms)
|
||||
a = np_array([[1., 2.], [3., 4.]])
|
||||
b = try_invert_to(a)
|
||||
if b:
|
||||
# self.core.reset()
|
||||
# while True:
|
||||
# with parallel:
|
||||
# self.led0.pulse(100.*ms)
|
||||
# self.led1.pulse(100.*ms)
|
||||
# self.core.delay(100.*ms)
|
||||
v = try_invert_to(np_identity(2))
|
||||
if v:
|
||||
while True:
|
||||
with parallel:
|
||||
self.led0.pulse(100.*ms)
|
||||
self.led1.pulse(100.*ms)
|
||||
self.core.delay(100.*ms)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
BIN
nac3artiq/demo/module.elf
Normal file
BIN
nac3artiq/demo/module.elf
Normal file
Binary file not shown.
@ -3,7 +3,7 @@ use inkwell::values::BasicValueEnum;
|
||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
||||
use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor};
|
||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||
@ -922,6 +922,122 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn call_try_invert_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
a: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "linalg_try_invert_to";
|
||||
let (a_ty, a) = a;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
match a {
|
||||
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
match llvm_ndarray_ty {
|
||||
BasicTypeEnum::FloatType(_) => {},
|
||||
_ => unreachable!("Inverse Operation supported on float type NDArray Values only")
|
||||
};
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
||||
|
||||
// Add asserts for dims
|
||||
if cfg!(debug_assertions) {
|
||||
let n_dims = n.load_ndims(ctx);
|
||||
|
||||
// num_dim == 2
|
||||
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
|
||||
"0:ValueError", format!("Inverse only supported on 2D lists").as_str(), [None, None, None], ctx.current_loc);
|
||||
|
||||
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||
|
||||
// dim0 == dim1
|
||||
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
|
||||
|
||||
}
|
||||
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let n_sz_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
n_sz_eqz,
|
||||
"0:ValueError",
|
||||
format!("zero-size array to reduction operation {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
// Create a call to linalg_try_invert_to
|
||||
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||
|
||||
Ok(extern_fns::call_linalg_try_invert_to(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
|
||||
}
|
||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
a: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||
let (a_ty, a) = a;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
match a {
|
||||
BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
match llvm_ndarray_ty {
|
||||
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {},
|
||||
_ => unreachable!("Wilkinson Shift Operation supported on float/integer type NDArray Values only")
|
||||
};
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
|
||||
// Add asserts for dims
|
||||
if cfg!(debug_assertions) {
|
||||
let n_dims = n.load_ndims(ctx);
|
||||
|
||||
// num_dim == 2
|
||||
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, n_dims.get_type().const_int(2, false), "").unwrap(),
|
||||
"0:ValueError", format!("Wilkinson Shift supported only on 2D lists").as_str(), [None, None, None], ctx.current_loc);
|
||||
|
||||
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||
|
||||
// dim0 == dim1
|
||||
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||
"0:ValueError", format!("Dimensions do not match {dim0} is not same as {dim1}").as_str(), [None, None, None], ctx.current_loc);
|
||||
|
||||
// dimesions should be 2x2
|
||||
ctx.make_assert(generator,ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim0.get_type().const_int(2, false), "").unwrap(),
|
||||
"0:ValueError", format!("Wilkinson Shift supported only on 2x2 matrices").as_str(), [None, None, None], ctx.current_loc);
|
||||
}
|
||||
|
||||
// Create a call to linalg_try_invert_to
|
||||
let dim0 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)};
|
||||
let dim1 = unsafe {n.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)};
|
||||
|
||||
Ok(extern_fns::call_linalg_wilkinson_shift(ctx, dim0, dim1, n.data().base_ptr(ctx, generator), None).into())
|
||||
}
|
||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `np_maximum` builtin function.
|
||||
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
|
@ -1,5 +1,5 @@
|
||||
use inkwell::attributes::{Attribute, AttributeLoc};
|
||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
|
||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
|
||||
use itertools::Either;
|
||||
|
||||
use crate::codegen::CodeGenContext;
|
||||
@ -130,3 +130,82 @@ pub fn call_ldexp<'ctx>(
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
|
||||
pub fn call_linalg_try_invert_to<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dim0: IntValue<'ctx>,
|
||||
dim1: IntValue<'ctx>,
|
||||
data: PointerValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "linalg_try_invert_to";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||
|
||||
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
|
||||
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
|
||||
|
||||
debug_assert!(allowed_dim0);
|
||||
debug_assert!(allowed_dim1);
|
||||
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.i8_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap().into()
|
||||
}
|
||||
|
||||
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
|
||||
pub fn call_linalg_wilkinson_shift<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dim0: IntValue<'ctx>,
|
||||
dim1: IntValue<'ctx>,
|
||||
data: PointerValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||
|
||||
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
|
||||
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
|
||||
|
||||
debug_assert!(allowed_dim0);
|
||||
debug_assert!(allowed_dim1);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.f64_type().fn_type(&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], false);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap().into()
|
||||
}
|
||||
|
@ -556,6 +556,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
| PrimDef::FunNpLdExp
|
||||
| PrimDef::FunNpHypot
|
||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||
|
||||
PrimDef::FunTryInvertTo => self.build_linalg_try_invert_to(prim), // Inplace invert
|
||||
PrimDef::FunWilkinsonShift => self.build_linalg_wilkinson_shift(prim),
|
||||
};
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
@ -1874,6 +1877,64 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_linalg_try_invert_to(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(
|
||||
prim,
|
||||
&[
|
||||
PrimDef::FunTryInvertTo,
|
||||
],
|
||||
);
|
||||
|
||||
let var_map = self.num_or_ndarray_var_map.clone();
|
||||
create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&var_map,
|
||||
prim.name(),
|
||||
self.primitives.bool,
|
||||
&[(self.ndarray_float_2d, "x")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunTryInvertTo => builtin_fns::call_try_invert_to,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn build_linalg_wilkinson_shift(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(
|
||||
prim,
|
||||
&[
|
||||
PrimDef::FunWilkinsonShift,
|
||||
],
|
||||
);
|
||||
|
||||
let var_map = self.num_or_ndarray_var_map.clone();
|
||||
create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&var_map,
|
||||
prim.name(),
|
||||
self.primitives.float,
|
||||
&[(self.ndarray_float_2d, "x")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunWilkinsonShift => builtin_fns::call_wilkinson_shift,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||
(prim.simple_name().into(), method_ty, prim.id())
|
||||
}
|
||||
|
@ -105,6 +105,8 @@ pub enum PrimDef {
|
||||
FunNpLdExp,
|
||||
FunNpHypot,
|
||||
FunNpNextAfter,
|
||||
FunTryInvertTo,
|
||||
FunWilkinsonShift,
|
||||
|
||||
// Top-Level Functions
|
||||
FunSome,
|
||||
@ -261,6 +263,8 @@ impl PrimDef {
|
||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||
PrimDef::FunTryInvertTo => fun("try_invert_to", None),
|
||||
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None),
|
||||
PrimDef::FunSome => fun("Some", None),
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ edition = "2021"
|
||||
parking_lot = "0.12"
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
nac3core = { path = "../nac3core" }
|
||||
externfns = { path = "./demo/externfns" }
|
||||
|
||||
[dependencies.clap]
|
||||
version = "4.5"
|
||||
|
10
nac3standalone/demo/externfns/Cargo.toml
Normal file
10
nac3standalone/demo/externfns/Cargo.toml
Normal file
@ -0,0 +1,10 @@
|
||||
[package]
|
||||
name = "externfns"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
|
41
nac3standalone/demo/externfns/src/lib.rs
Normal file
41
nac3standalone/demo/externfns/src/lib.rs
Normal file
@ -0,0 +1,41 @@
|
||||
#![deny(
|
||||
future_incompatible,
|
||||
let_underscore,
|
||||
nonstandard_style,
|
||||
rust_2024_compatibility,
|
||||
clippy::all
|
||||
)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::semicolon_if_nothing_returned, clippy::uninlined_format_args)]
|
||||
|
||||
use core::slice;
|
||||
use nalgebra::{DMatrix, linalg};
|
||||
|
||||
#[no_mangle]
|
||||
/// Provide an interface to `nalgebra::linalg::try_invert_to`
|
||||
pub extern "C" fn linalg_try_invert_to(dim0: usize, dim1: usize, data: *mut f64) -> i8
|
||||
{
|
||||
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
||||
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||
let mut inverted_matrix = DMatrix::<f64>::zeros(dim0, dim1);
|
||||
|
||||
if linalg::try_invert_to(matrix, &mut inverted_matrix) {
|
||||
data_slice.copy_from_slice(inverted_matrix.transpose().as_slice());
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/// Provide an interface to `nalgebra::linalg::wilkinson_shift`
|
||||
pub extern "C" fn linalg_wilkinson_shift(dim0: usize, dim1: usize, data: *mut f64) -> f64
|
||||
{
|
||||
let data_slice = unsafe { slice::from_raw_parts_mut(data, dim0 * dim1) };
|
||||
let matrix = DMatrix::from_row_slice(dim0, dim1, data_slice);
|
||||
|
||||
// Check if matrix is symmetric
|
||||
assert!(matrix[(0, 1)] == matrix[(1, 0)], "Operation Wilkinson Shift expects symmetric matrix");
|
||||
return linalg::wilkinson_shift(matrix[(0, 0)], matrix[(1, 1)], matrix[(0, 1)]);
|
||||
}
|
@ -141,6 +141,26 @@ def patch(module):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def try_invert_to(x):
|
||||
try:
|
||||
y = np.linalg.inv(x)
|
||||
x[:] = y
|
||||
except np.linalg.LinAlgError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def wilkinson_shift(x):
|
||||
assert (len(x.flatten()) == 4) and (x[0, 1] == x[1, 0]), f"Operation Wilkinson Shift expects symmetric matrix"
|
||||
tmm, tnn, tmn = x[0, 0], x[1, 1], x[0, 1]
|
||||
sq_tmn = tmn * tmn
|
||||
if sq_tmn != 0:
|
||||
d = (tmm - tnn) * 0.5
|
||||
if d > 0:
|
||||
return tnn - sq_tmn / (d + np.sqrt(d*d + sq_tmn))
|
||||
else:
|
||||
return tnn - sq_tmn / (d - np.sqrt(d*d + sq_tmn))
|
||||
return tnn
|
||||
|
||||
module.int32 = int32
|
||||
module.int64 = int64
|
||||
module.uint32 = uint32
|
||||
@ -234,6 +254,8 @@ def patch(module):
|
||||
module.np_full = np.full
|
||||
module.np_eye = np.eye
|
||||
module.np_identity = np.identity
|
||||
module.try_invert_to = try_invert_to
|
||||
module.wilkinson_shift = wilkinson_shift
|
||||
|
||||
def file_import(filename, prefix="file_import_"):
|
||||
filename = pathlib.Path(filename)
|
||||
|
@ -42,11 +42,14 @@ done
|
||||
|
||||
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
|
||||
nac3standalone=../../target/debug/nac3standalone
|
||||
externfns=../../target/debug/deps/libexternfns.so
|
||||
elif [ -e ../../target/release/nac3standalone ]; then
|
||||
nac3standalone=../../target/release/nac3standalone
|
||||
externfns=../../target/release/deps/libexternfns.so
|
||||
else
|
||||
# used by Nix builds
|
||||
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
|
||||
externfns=../../target/x86_64-unknown-linux-gnu/release/deps/libexternfns.so
|
||||
fi
|
||||
|
||||
rm -f ./*.o ./*.bc demo
|
||||
@ -54,7 +57,7 @@ if [ -z "$use_lli" ]; then
|
||||
$nac3standalone "${nac3args[@]}"
|
||||
|
||||
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
|
||||
clang -lm -o demo module.o demo.o
|
||||
clang -lm -o demo module.o demo.o $externfns
|
||||
|
||||
if [ -z "$outfile" ]; then
|
||||
./demo
|
||||
@ -71,8 +74,8 @@ else
|
||||
shopt -u nullglob
|
||||
|
||||
if [ -z "$outfile" ]; then
|
||||
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc
|
||||
lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc
|
||||
else
|
||||
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
|
||||
lli -load=$externfns --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
|
||||
fi
|
||||
fi
|
||||
|
@ -1429,7 +1429,25 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
|
||||
output_ndarray_float_2(nextafter_x_zeros)
|
||||
output_ndarray_float_2(nextafter_x_ones)
|
||||
|
||||
def test_try_invert():
|
||||
x: ndarray[float, 2] = np_array([[1.0, 1.0], [1.0, 5.0]])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
y = try_invert_to(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_bool(y)
|
||||
|
||||
def test_wilkinson_shift():
|
||||
x: ndarray[float, 2] = np_array([[5., 1.], [1., 4.]])
|
||||
y = wilkinson_shift(x)
|
||||
output_ndarray_float_2(x)
|
||||
output_float64(y)
|
||||
|
||||
def run() -> int32:
|
||||
test_try_invert()
|
||||
test_wilkinson_shift()
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
test_ndarray_zeros()
|
||||
|
BIN
pyo3_output/nac3artiq.so
Executable file
BIN
pyo3_output/nac3artiq.so
Executable file
Binary file not shown.
Loading…
Reference in New Issue
Block a user