forked from M-Labs/nac3
core: add nalgebra::linalg methods
This commit is contained in:
parent
44487b76ae
commit
2dddab1fcf
@ -3,7 +3,7 @@ use inkwell::values::BasicValueEnum;
|
||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
||||
use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||
@ -1835,3 +1835,232 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `linalg_try_invert_to` function
|
||||
pub fn call_linalg_try_invert_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
a: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "linalg_try_invert_to";
|
||||
let (a_ty, a) = a;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
match a {
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
match llvm_ndarray_ty {
|
||||
BasicTypeEnum::FloatType(_) => {}
|
||||
_ => {
|
||||
unimplemented!("Inverse Operation supported on float type NDArray Values only")
|
||||
}
|
||||
};
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
||||
|
||||
// The following constraints must be satisfied:
|
||||
// * Input must be 2D
|
||||
// * number of rows should equal number of columns (square matrix)
|
||||
if cfg!(debug_assertions) {
|
||||
let n_dims = n.load_ndims(ctx);
|
||||
|
||||
// num_dim == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
n_dims,
|
||||
llvm_usize.const_int(2, false),
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
// dim0 == dim1
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||
"0:ValueError",
|
||||
format!(
|
||||
"Input matrix should have equal number of rows and columns for {FN_NAME}"
|
||||
)
|
||||
.as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let n_sz_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
n_sz_eqz,
|
||||
"0:ValueError",
|
||||
format!("zero-size array to inverse operation {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let dim0 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
Ok(extern_fns::call_linalg_try_invert_to(
|
||||
ctx,
|
||||
dim0,
|
||||
dim1,
|
||||
n.data().base_ptr(ctx, generator),
|
||||
None,
|
||||
)
|
||||
.into())
|
||||
}
|
||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the `linalg_wilkinson_shift` function
|
||||
pub fn call_linalg_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
a: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||
let (a_ty, a) = a;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let one = llvm_usize.const_int(1, false);
|
||||
let two = llvm_usize.const_int(2, false);
|
||||
|
||||
match a {
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
match llvm_ndarray_ty {
|
||||
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {}
|
||||
_ => unimplemented!(
|
||||
"Wilkinson Shift Operation supported on float type NDArray Values only"
|
||||
),
|
||||
};
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
|
||||
// The following constraints must be satisfied:
|
||||
// * Input must be 2D
|
||||
// * Number of rows and columns should equal 2
|
||||
// * Input matrix must be symmetric
|
||||
if cfg!(debug_assertions) {
|
||||
let n_dims = n.load_ndims(ctx);
|
||||
|
||||
// num_dim == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, two, "").unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
|
||||
};
|
||||
|
||||
// dim0 == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, two, "").unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Number of rows must be 2 for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// dim1 == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, dim1, two, "").unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Number of columns must be 2 for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let entry_01 = unsafe {
|
||||
n.data().get_unchecked(ctx, generator, &one, None).into_float_value()
|
||||
};
|
||||
let entry_10 = unsafe {
|
||||
n.data().get_unchecked(ctx, generator, &two, None).into_float_value()
|
||||
};
|
||||
|
||||
// symmetric matrix
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_float_compare(FloatPredicate::OEQ, entry_01, entry_10, "")
|
||||
.unwrap(),
|
||||
"0:ValueError",
|
||||
format!("Input Matrix must be symmetric for {FN_NAME}").as_str(),
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let dim0 = unsafe {
|
||||
n.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 =
|
||||
unsafe { n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
|
||||
|
||||
Ok(extern_fns::call_linalg_wilkinson_shift(
|
||||
ctx,
|
||||
dim0,
|
||||
dim1,
|
||||
n.data().base_ptr(ctx, generator),
|
||||
None,
|
||||
)
|
||||
.into())
|
||||
}
|
||||
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use inkwell::attributes::{Attribute, AttributeLoc};
|
||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
|
||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
|
||||
use itertools::Either;
|
||||
|
||||
use crate::codegen::CodeGenContext;
|
||||
@ -130,3 +130,91 @@ pub fn call_ldexp<'ctx>(
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
|
||||
pub fn call_linalg_try_invert_to<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dim0: IntValue<'ctx>,
|
||||
dim1: IntValue<'ctx>,
|
||||
data: PointerValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "linalg_try_invert_to";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||
|
||||
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
|
||||
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
|
||||
|
||||
debug_assert!(allowed_dim0);
|
||||
debug_assert!(allowed_dim1);
|
||||
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.i8_type().fn_type(
|
||||
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
|
||||
false,
|
||||
);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
|
||||
pub fn call_linalg_wilkinson_shift<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dim0: IntValue<'ctx>,
|
||||
dim1: IntValue<'ctx>,
|
||||
data: PointerValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> FloatValue<'ctx> {
|
||||
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||
|
||||
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
|
||||
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
|
||||
|
||||
debug_assert!(allowed_dim0);
|
||||
debug_assert!(allowed_dim1);
|
||||
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||
|
||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.f64_type().fn_type(
|
||||
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
|
||||
false,
|
||||
);
|
||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
@ -556,6 +556,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
| PrimDef::FunNpLdExp
|
||||
| PrimDef::FunNpHypot
|
||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||
|
||||
PrimDef::FunTryInvertTo | PrimDef::FunWilkinsonShift => self.build_linalg_methods(prim),
|
||||
};
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
@ -1874,6 +1876,37 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the functions `try_invert_to` and `wilkinson_shift`
|
||||
fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunTryInvertTo, PrimDef::FunWilkinsonShift]);
|
||||
|
||||
let ret_ty = match prim {
|
||||
PrimDef::FunTryInvertTo => self.primitives.bool,
|
||||
PrimDef::FunWilkinsonShift => self.primitives.float,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let var_map = self.num_or_ndarray_var_map.clone();
|
||||
create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&var_map,
|
||||
prim.name(),
|
||||
ret_ty,
|
||||
&[(self.ndarray_float_2d, "x")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunTryInvertTo => builtin_fns::call_linalg_try_invert_to,
|
||||
PrimDef::FunWilkinsonShift => builtin_fns::call_linalg_wilkinson_shift,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||
(prim.simple_name().into(), method_ty, prim.id())
|
||||
}
|
||||
|
@ -105,6 +105,8 @@ pub enum PrimDef {
|
||||
FunNpLdExp,
|
||||
FunNpHypot,
|
||||
FunNpNextAfter,
|
||||
FunTryInvertTo,
|
||||
FunWilkinsonShift,
|
||||
|
||||
// Top-Level Functions
|
||||
FunSome,
|
||||
@ -263,6 +265,8 @@ impl PrimDef {
|
||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||
PrimDef::FunTryInvertTo => fun("try_invert_to", None),
|
||||
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None),
|
||||
PrimDef::FunSome => fun("Some", None),
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user