core: add nalgebra::linalg methods

This commit is contained in:
abdul124 2024-07-22 13:19:01 +08:00 committed by =
parent 44487b76ae
commit 2dddab1fcf
4 changed files with 356 additions and 2 deletions

View File

@ -3,7 +3,7 @@ use inkwell::values::BasicValueEnum;
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
@ -1835,3 +1835,232 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
}) })
} }
/// Invokes the `linalg_try_invert_to` function
pub fn call_linalg_try_invert_to<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_try_invert_to";
let (a_ty, a) = a;
let llvm_usize = generator.get_size_type(ctx.ctx);
match a {
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
match llvm_ndarray_ty {
BasicTypeEnum::FloatType(_) => {}
_ => {
unimplemented!("Inverse Operation supported on float type NDArray Values only")
}
};
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
// The following constraints must be satisfied:
// * Input must be 2D
// * number of rows should equal number of columns (square matrix)
if cfg!(debug_assertions) {
let n_dims = n.load_ndims(ctx);
// num_dim == 2
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::EQ,
n_dims,
llvm_usize.const_int(2, false),
"",
)
.unwrap(),
"0:ValueError",
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
let dim0 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
// dim0 == dim1
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(),
"0:ValueError",
format!(
"Input matrix should have equal number of rows and columns for {FN_NAME}"
)
.as_str(),
[None, None, None],
ctx.current_loc,
);
}
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
.builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
format!("zero-size array to inverse operation {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
}
let dim0 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
Ok(extern_fns::call_linalg_try_invert_to(
ctx,
dim0,
dim1,
n.data().base_ptr(ctx, generator),
None,
)
.into())
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
}
}
/// Invokes the `linalg_wilkinson_shift` function
pub fn call_linalg_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "linalg_wilkinson_shift";
let (a_ty, a) = a;
let llvm_usize = generator.get_size_type(ctx.ctx);
let one = llvm_usize.const_int(1, false);
let two = llvm_usize.const_int(2, false);
match a {
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
match llvm_ndarray_ty {
BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {}
_ => unimplemented!(
"Wilkinson Shift Operation supported on float type NDArray Values only"
),
};
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
// The following constraints must be satisfied:
// * Input must be 2D
// * Number of rows and columns should equal 2
// * Input matrix must be symmetric
if cfg!(debug_assertions) {
let n_dims = n.load_ndims(ctx);
// num_dim == 2
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, two, "").unwrap(),
"0:ValueError",
format!("Input matrix must have two dimensions for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
let dim0 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value()
};
// dim0 == 2
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, dim0, two, "").unwrap(),
"0:ValueError",
format!("Number of rows must be 2 for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
// dim1 == 2
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, dim1, two, "").unwrap(),
"0:ValueError",
format!("Number of columns must be 2 for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
let entry_01 = unsafe {
n.data().get_unchecked(ctx, generator, &one, None).into_float_value()
};
let entry_10 = unsafe {
n.data().get_unchecked(ctx, generator, &two, None).into_float_value()
};
// symmetric matrix
ctx.make_assert(
generator,
ctx.builder
.build_float_compare(FloatPredicate::OEQ, entry_01, entry_10, "")
.unwrap(),
"0:ValueError",
format!("Input Matrix must be symmetric for {FN_NAME}").as_str(),
[None, None, None],
ctx.current_loc,
);
}
let dim0 = unsafe {
n.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 =
unsafe { n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() };
Ok(extern_fns::call_linalg_wilkinson_shift(
ctx,
dim0,
dim1,
n.data().base_ptr(ctx, generator),
None,
)
.into())
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
}
}

View File

@ -1,5 +1,5 @@
use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}; use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use itertools::Either; use itertools::Either;
use crate::codegen::CodeGenContext; use crate::codegen::CodeGenContext;
@ -130,3 +130,91 @@ pub fn call_ldexp<'ctx>(
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
} }
/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function
pub fn call_linalg_try_invert_to<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dim0: IntValue<'ctx>,
dim1: IntValue<'ctx>,
data: PointerValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
const FN_NAME: &str = "linalg_try_invert_to";
let llvm_f64 = ctx.ctx.f64_type();
let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type());
let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type());
debug_assert!(allowed_dim0);
debug_assert!(allowed_dim1);
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.i8_type().fn_type(
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
false,
);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function
pub fn call_linalg_wilkinson_shift<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dim0: IntValue<'ctx>,
dim1: IntValue<'ctx>,
data: PointerValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "linalg_wilkinson_shift";
let llvm_f64 = ctx.ctx.f64_type();
let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()];
let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type());
let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type());
debug_assert!(allowed_dim0);
debug_assert!(allowed_dim1);
debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.f64_type().fn_type(
&[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()],
false,
);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -556,6 +556,8 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpLdExp | PrimDef::FunNpLdExp
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunTryInvertTo | PrimDef::FunWilkinsonShift => self.build_linalg_methods(prim),
}; };
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
@ -1874,6 +1876,37 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build the functions `try_invert_to` and `wilkinson_shift`
fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunTryInvertTo, PrimDef::FunWilkinsonShift]);
let ret_ty = match prim {
PrimDef::FunTryInvertTo => self.primitives.bool,
PrimDef::FunWilkinsonShift => self.primitives.float,
_ => unreachable!(),
};
let var_map = self.num_or_ndarray_var_map.clone();
create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
ret_ty,
&[(self.ndarray_float_2d, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
let func = match prim {
PrimDef::FunTryInvertTo => builtin_fns::call_linalg_try_invert_to,
PrimDef::FunWilkinsonShift => builtin_fns::call_linalg_wilkinson_shift,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
}),
)
}
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
(prim.simple_name().into(), method_ty, prim.id()) (prim.simple_name().into(), method_ty, prim.id())
} }

View File

@ -105,6 +105,8 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunTryInvertTo,
FunWilkinsonShift,
// Top-Level Functions // Top-Level Functions
FunSome, FunSome,
@ -263,6 +265,8 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunTryInvertTo => fun("try_invert_to", None),
PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None),
PrimDef::FunSome => fun("Some", None), PrimDef::FunSome => fun("Some", None),
} }
} }