forked from M-Labs/nac3
core: add nalgebra::linalg methods
This commit is contained in:
parent
44487b76ae
commit
2dddab1fcf
@ -3,7 +3,7 @@ use inkwell::values::BasicValueEnum;
|
|||||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
|
||||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||||
@ -1835,3 +1835,232 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the `linalg_try_invert_to` function
|
||||||
|
pub fn call_linalg_try_invert_to<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "linalg_try_invert_to";
|
||||||
|
let (a_ty, a) = a;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match a {
|
||||||
|
BasicValueEnum::PointerValue(n)
|
||||||
|
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
|
{
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
match llvm_ndarray_ty {
|
||||||
|
BasicTypeEnum::FloatType(_) => {}
|
||||||
|
_ => {
|
||||||
|
unimplemented!("Inverse Operation supported on float type NDArray Values only")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
// The following constraints must be satisfied:
|
||||||
|
// * Input must be 2D
|
||||||
|
// * number of rows should equal number of columns (square matrix)
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
let n_dims = n.load_ndims(ctx);
|
||||||
|
|
||||||
|
// num_dim == 2
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
n_dims,
|
||||||
|
llvm_usize.const_int(2, false),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(),
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 = unsafe {
|
||||||
|
n.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
// dim0 == dim1
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
format!(
|
||||||
|
"Input matrix should have equal number of rows and columns for {FN_NAME}"
|
||||||
|
)
|
||||||
|
.as_str(),
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
|
let n_sz_eqz = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
n_sz_eqz,
|
||||||
|
"0:ValueError",
|
||||||
|
format!("zero-size array to inverse operation {FN_NAME}").as_str(),
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 = unsafe {
|
||||||
|
n.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(extern_fns::call_linalg_try_invert_to(
|
||||||
|
ctx,
|
||||||
|
dim0,
|
||||||
|
dim1,
|
||||||
|
n.data().base_ptr(ctx, generator),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.into())
|
||||||
|
}
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the `linalg_wilkinson_shift` function
|
||||||
|
pub fn call_linalg_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||||
|
let (a_ty, a) = a;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let one = llvm_usize.const_int(1, false);
|
||||||
|
let two = llvm_usize.const_int(2, false);
|
||||||
|
|
||||||
|
match a {
|
||||||
|
BasicValueEnum::PointerValue(n)
|
||||||
|
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
|
{
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||||
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
match llvm_ndarray_ty {
|
||||||
|
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {}
|
||||||
|
_ => unimplemented!(
|
||||||
|
"Wilkinson Shift Operation supported on float type NDArray Values only"
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
|
||||||
|
// The following constraints must be satisfied:
|
||||||
|
// * Input must be 2D
|
||||||
|
// * Number of rows and columns should equal 2
|
||||||
|
// * Input matrix must be symmetric
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
let n_dims = n.load_ndims(ctx);
|
||||||
|
|
||||||
|
// num_dim == 2
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, two, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(),
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 = unsafe {
|
||||||
|
n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
// dim0 == 2
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, two, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
format!("Number of rows must be 2 for {FN_NAME}").as_str(),
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
// dim1 == 2
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::EQ, dim1, two, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
format!("Number of columns must be 2 for {FN_NAME}").as_str(),
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let entry_01 = unsafe {
|
||||||
|
n.data().get_unchecked(ctx, generator, &one, None).into_float_value()
|
||||||
|
};
|
||||||
|
let entry_10 = unsafe {
|
||||||
|
n.data().get_unchecked(ctx, generator, &two, None).into_float_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
// symmetric matrix
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder
|
||||||
|
.build_float_compare(FloatPredicate::OEQ, entry_01, entry_10, "")
|
||||||
|
.unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
format!("Input Matrix must be symmetric for {FN_NAME}").as_str(),
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let dim0 = unsafe {
|
||||||
|
n.dim_sizes()
|
||||||
|
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
.into_int_value()
|
||||||
|
};
|
||||||
|
let dim1 =
|
||||||
|
unsafe { n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
|
||||||
|
|
||||||
|
Ok(extern_fns::call_linalg_wilkinson_shift(
|
||||||
|
ctx,
|
||||||
|
dim0,
|
||||||
|
dim1,
|
||||||
|
n.data().base_ptr(ctx, generator),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.into())
|
||||||
|
}
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use inkwell::attributes::{Attribute, AttributeLoc};
|
use inkwell::attributes::{Attribute, AttributeLoc};
|
||||||
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
|
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
|
|
||||||
use crate::codegen::CodeGenContext;
|
use crate::codegen::CodeGenContext;
|
||||||
@ -130,3 +130,91 @@ pub fn call_ldexp<'ctx>(
|
|||||||
.map(Either::unwrap_left)
|
.map(Either::unwrap_left)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
|
||||||
|
pub fn call_linalg_try_invert_to<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
dim0: IntValue<'ctx>,
|
||||||
|
dim1: IntValue<'ctx>,
|
||||||
|
data: PointerValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
const FN_NAME: &str = "linalg_try_invert_to";
|
||||||
|
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||||
|
|
||||||
|
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
|
||||||
|
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
|
||||||
|
|
||||||
|
debug_assert!(allowed_dim0);
|
||||||
|
debug_assert!(allowed_dim1);
|
||||||
|
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||||
|
|
||||||
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||||
|
let fn_type = ctx.ctx.i8_type().fn_type(
|
||||||
|
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||||
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||||
|
func.add_attribute(
|
||||||
|
AttributeLoc::Function,
|
||||||
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
func
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
|
||||||
|
pub fn call_linalg_wilkinson_shift<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
dim0: IntValue<'ctx>,
|
||||||
|
dim1: IntValue<'ctx>,
|
||||||
|
data: PointerValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> FloatValue<'ctx> {
|
||||||
|
const FN_NAME: &str = "linalg_wilkinson_shift";
|
||||||
|
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
|
||||||
|
|
||||||
|
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
|
||||||
|
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
|
||||||
|
|
||||||
|
debug_assert!(allowed_dim0);
|
||||||
|
debug_assert!(allowed_dim1);
|
||||||
|
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
|
||||||
|
|
||||||
|
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
||||||
|
let fn_type = ctx.ctx.f64_type().fn_type(
|
||||||
|
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
||||||
|
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
|
||||||
|
func.add_attribute(
|
||||||
|
AttributeLoc::Function,
|
||||||
|
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
func
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
|
||||||
|
.map(CallSiteValue::try_as_basic_value)
|
||||||
|
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||||
|
.map(Either::unwrap_left)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
@ -556,6 +556,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
| PrimDef::FunNpLdExp
|
| PrimDef::FunNpLdExp
|
||||||
| PrimDef::FunNpHypot
|
| PrimDef::FunNpHypot
|
||||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunTryInvertTo | PrimDef::FunWilkinsonShift => self.build_linalg_methods(prim),
|
||||||
};
|
};
|
||||||
|
|
||||||
if cfg!(debug_assertions) {
|
if cfg!(debug_assertions) {
|
||||||
@ -1874,6 +1876,37 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build the functions `try_invert_to` and `wilkinson_shift`
|
||||||
|
fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunTryInvertTo, PrimDef::FunWilkinsonShift]);
|
||||||
|
|
||||||
|
let ret_ty = match prim {
|
||||||
|
PrimDef::FunTryInvertTo => self.primitives.bool,
|
||||||
|
PrimDef::FunWilkinsonShift => self.primitives.float,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let var_map = self.num_or_ndarray_var_map.clone();
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&var_map,
|
||||||
|
prim.name(),
|
||||||
|
ret_ty,
|
||||||
|
&[(self.ndarray_float_2d, "x")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x_ty = fun.0.args[0].ty;
|
||||||
|
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||||
|
|
||||||
|
let func = match prim {
|
||||||
|
PrimDef::FunTryInvertTo => builtin_fns::call_linalg_try_invert_to,
|
||||||
|
PrimDef::FunWilkinsonShift => builtin_fns::call_linalg_wilkinson_shift,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
|
||||||
(prim.simple_name().into(), method_ty, prim.id())
|
(prim.simple_name().into(), method_ty, prim.id())
|
||||||
}
|
}
|
||||||
|
@ -105,6 +105,8 @@ pub enum PrimDef {
|
|||||||
FunNpLdExp,
|
FunNpLdExp,
|
||||||
FunNpHypot,
|
FunNpHypot,
|
||||||
FunNpNextAfter,
|
FunNpNextAfter,
|
||||||
|
FunTryInvertTo,
|
||||||
|
FunWilkinsonShift,
|
||||||
|
|
||||||
// Top-Level Functions
|
// Top-Level Functions
|
||||||
FunSome,
|
FunSome,
|
||||||
@ -263,6 +265,8 @@ impl PrimDef {
|
|||||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||||
|
PrimDef::FunTryInvertTo => fun("try_invert_to", None),
|
||||||
|
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None),
|
||||||
PrimDef::FunSome => fun("Some", None),
|
PrimDef::FunSome => fun("Some", None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user