From 2dddab1fcf845a7b45f8bb13ea574bc302e119bc Mon Sep 17 00:00:00 2001 From: abdul124 Date: Mon, 22 Jul 2024 13:19:01 +0800 Subject: [PATCH] core: add nalgebra::linalg methods --- nac3core/src/codegen/builtin_fns.rs | 231 +++++++++++++++++++++++++++- nac3core/src/codegen/extern_fns.rs | 90 ++++++++++- nac3core/src/toplevel/builtins.rs | 33 ++++ nac3core/src/toplevel/helper.rs | 4 + 4 files changed, 356 insertions(+), 2 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 63078107..16a970bb 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -3,7 +3,7 @@ use inkwell::values::BasicValueEnum; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; -use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; +use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; @@ -1835,3 +1835,232 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } + +/// Invokes the `linalg_try_invert_to` function +pub fn call_linalg_try_invert_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "linalg_try_invert_to"; + let (a_ty, a) = a; + let llvm_usize = generator.get_size_type(ctx.ctx); + + match a { + BasicValueEnum::PointerValue(n) + if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + match llvm_ndarray_ty { + BasicTypeEnum::FloatType(_) => {} + _ => { + unimplemented!("Inverse Operation supported on float type NDArray Values only") + } + }; + + let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); + + // The following constraints must be satisfied: + // * Input must be 2D + // * number of rows should equal number of columns (square matrix) + if cfg!(debug_assertions) { + let n_dims = n.load_ndims(ctx); + + // num_dim == 2 + ctx.make_assert( + generator, + ctx.builder + .build_int_compare( + IntPredicate::EQ, + n_dims, + llvm_usize.const_int(2, false), + "", + ) + .unwrap(), + "0:ValueError", + format!("Input matrix must have two dimensions for {FN_NAME}").as_str(), + [None, None, None], + ctx.current_loc, + ); + + let dim0 = unsafe { + n.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + // dim0 == dim1 + ctx.make_assert( + generator, + ctx.builder.build_int_compare(IntPredicate::EQ, dim0, dim1, "").unwrap(), + "0:ValueError", + format!( + "Input matrix should have equal number of rows and columns for {FN_NAME}" + ) + .as_str(), + [None, None, None], + ctx.current_loc, + ); + } + + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + let n_sz_eqz = ctx + .builder + .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") + .unwrap(); + + ctx.make_assert( + generator, + n_sz_eqz, + "0:ValueError", + format!("zero-size array to inverse operation {FN_NAME}").as_str(), + [None, None, None], + ctx.current_loc, + ); + } + + let dim0 = unsafe { + n.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + Ok(extern_fns::call_linalg_try_invert_to( + ctx, + dim0, + dim1, + n.data().base_ptr(ctx, generator), + None, + ) + .into()) + } + _ => unsupported_type(ctx, FN_NAME, &[a_ty]), + } +} + +/// Invokes the `linalg_wilkinson_shift` function +pub fn call_linalg_wilkinson_shift<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "linalg_wilkinson_shift"; + let (a_ty, a) = a; + let llvm_usize = generator.get_size_type(ctx.ctx); + + let one = llvm_usize.const_int(1, false); + let two = llvm_usize.const_int(2, false); + + match a { + BasicValueEnum::PointerValue(n) + if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => + { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + match llvm_ndarray_ty { + BasicTypeEnum::FloatType(_) | BasicTypeEnum::IntType(_) => {} + _ => unimplemented!( + "Wilkinson Shift Operation supported on float type NDArray Values only" + ), + }; + + let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + + // The following constraints must be satisfied: + // * Input must be 2D + // * Number of rows and columns should equal 2 + // * Input matrix must be symmetric + if cfg!(debug_assertions) { + let n_dims = n.load_ndims(ctx); + + // num_dim == 2 + ctx.make_assert( + generator, + ctx.builder.build_int_compare(IntPredicate::EQ, n_dims, two, "").unwrap(), + "0:ValueError", + format!("Input matrix must have two dimensions for {FN_NAME}").as_str(), + [None, None, None], + ctx.current_loc, + ); + + let dim0 = unsafe { + n.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() + }; + + // dim0 == 2 + ctx.make_assert( + generator, + ctx.builder.build_int_compare(IntPredicate::EQ, dim0, two, "").unwrap(), + "0:ValueError", + format!("Number of rows must be 2 for {FN_NAME}").as_str(), + [None, None, None], + ctx.current_loc, + ); + + // dim1 == 2 + ctx.make_assert( + generator, + ctx.builder.build_int_compare(IntPredicate::EQ, dim1, two, "").unwrap(), + "0:ValueError", + format!("Number of columns must be 2 for {FN_NAME}").as_str(), + [None, None, None], + ctx.current_loc, + ); + + let entry_01 = unsafe { + n.data().get_unchecked(ctx, generator, &one, None).into_float_value() + }; + let entry_10 = unsafe { + n.data().get_unchecked(ctx, generator, &two, None).into_float_value() + }; + + // symmetric matrix + ctx.make_assert( + generator, + ctx.builder + .build_float_compare(FloatPredicate::OEQ, entry_01, entry_10, "") + .unwrap(), + "0:ValueError", + format!("Input Matrix must be symmetric for {FN_NAME}").as_str(), + [None, None, None], + ctx.current_loc, + ); + } + + let dim0 = unsafe { + n.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = + unsafe { n.dim_sizes().get_unchecked(ctx, generator, &one, None).into_int_value() }; + + Ok(extern_fns::call_linalg_wilkinson_shift( + ctx, + dim0, + dim1, + n.data().base_ptr(ctx, generator), + None, + ) + .into()) + } + _ => unsupported_type(ctx, FN_NAME, &[a_ty]), + } +} diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index 8b510ed9..09e97c5a 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -1,5 +1,5 @@ use inkwell::attributes::{Attribute, AttributeLoc}; -use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}; +use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}; use itertools::Either; use crate::codegen::CodeGenContext; @@ -130,3 +130,91 @@ pub fn call_ldexp<'ctx>( .map(Either::unwrap_left) .unwrap() } + +/// Invokes the [`try_invert_to`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.try_invert_to.html) function +pub fn call_linalg_try_invert_to<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + dim0: IntValue<'ctx>, + dim1: IntValue<'ctx>, + data: PointerValue<'ctx>, + name: Option<&str>, +) -> IntValue<'ctx> { + const FN_NAME: &str = "linalg_try_invert_to"; + + let llvm_f64 = ctx.ctx.f64_type(); + let allowed_indices = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; + + let allowed_dim0 = allowed_indices.iter().any(|p| *p == dim0.get_type()); + let allowed_dim1 = allowed_indices.iter().any(|p| *p == dim1.get_type()); + + debug_assert!(allowed_dim0); + debug_assert!(allowed_dim1); + debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64); + + let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + let fn_type = ctx.ctx.i8_type().fn_type( + &[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], + false, + ); + let func = ctx.module.add_function(FN_NAME, fn_type, None); + for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + + func + }); + + ctx.builder + .build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Invokes the [`wilkinson_shift`](https://docs.rs/nalgebra/latest/nalgebra/linalg/fn.wilkinson_shift.html) function +pub fn call_linalg_wilkinson_shift<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + dim0: IntValue<'ctx>, + dim1: IntValue<'ctx>, + data: PointerValue<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + const FN_NAME: &str = "linalg_wilkinson_shift"; + + let llvm_f64 = ctx.ctx.f64_type(); + let allowed_index_types = [ctx.ctx.i32_type(), ctx.ctx.i64_type()]; + + let allowed_dim0 = allowed_index_types.iter().any(|p| *p == dim0.get_type()); + let allowed_dim1 = allowed_index_types.iter().any(|p| *p == dim1.get_type()); + + debug_assert!(allowed_dim0); + debug_assert!(allowed_dim1); + debug_assert_eq!(data.get_type().get_element_type().into_float_type(), llvm_f64); + + let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + let fn_type = ctx.ctx.f64_type().fn_type( + &[dim0.get_type().into(), dim0.get_type().into(), data.get_type().into()], + false, + ); + let func = ctx.module.add_function(FN_NAME, fn_type, None); + for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + + func + }); + + ctx.builder + .build_call(extern_fn, &[dim0.into(), dim1.into(), data.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 83256b14..783aa5fb 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -556,6 +556,8 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpLdExp | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), + + PrimDef::FunTryInvertTo | PrimDef::FunWilkinsonShift => self.build_linalg_methods(prim), }; if cfg!(debug_assertions) { @@ -1874,6 +1876,37 @@ impl<'a> BuiltinBuilder<'a> { } } + /// Build the functions `try_invert_to` and `wilkinson_shift` + fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunTryInvertTo, PrimDef::FunWilkinsonShift]); + + let ret_ty = match prim { + PrimDef::FunTryInvertTo => self.primitives.bool, + PrimDef::FunWilkinsonShift => self.primitives.float, + _ => unreachable!(), + }; + let var_map = self.num_or_ndarray_var_map.clone(); + create_fn_by_codegen( + self.unifier, + &var_map, + prim.name(), + ret_ty, + &[(self.ndarray_float_2d, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let x_ty = fun.0.args[0].ty; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + + let func = match prim { + PrimDef::FunTryInvertTo => builtin_fns::call_linalg_try_invert_to, + PrimDef::FunWilkinsonShift => builtin_fns::call_linalg_wilkinson_shift, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (x_ty, x_val))?)) + }), + ) + } + fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { (prim.simple_name().into(), method_ty, prim.id()) } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 538e653e..cd6c4975 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -105,6 +105,8 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, + FunTryInvertTo, + FunWilkinsonShift, // Top-Level Functions FunSome, @@ -263,6 +265,8 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), + PrimDef::FunTryInvertTo => fun("try_invert_to", None), + PrimDef::FunWilkinsonShift => fun("wilkinson_shift", None), PrimDef::FunSome => fun("Some", None), } }