From 2242c5af434413ca6664b3077bcd0b58d99aaf1d Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 25 Jul 2024 12:16:53 +0800 Subject: [PATCH] core: add linalg methods --- nac3core/src/codegen/builtin_fns.rs | 480 +++++++++++++++++++++++++++- nac3core/src/codegen/extern_fns.rs | 88 +++++ nac3core/src/codegen/numpy.rs | 2 +- nac3core/src/toplevel/builtins.rs | 155 +++++++++ nac3core/src/toplevel/helper.rs | 21 ++ 5 files changed, 743 insertions(+), 3 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index abe9205..50748de 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,5 +1,5 @@ use inkwell::types::BasicTypeEnum; -use inkwell::values::BasicValueEnum; +use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; @@ -31,7 +31,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; - Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); @@ -1836,3 +1835,480 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), }) } + +/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it +fn build_output_struct<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + out_matrices: Vec>, +) -> PointerValue<'ctx> { + let field_ty = + out_matrices.iter().map(BasicValueEnum::get_type).collect::>(); + let out_ty = ctx.ctx.struct_type(&field_ty, false); + let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); + + for (i, v) in out_matrices.into_iter().enumerate() { + unsafe { + let ptr = ctx + .builder + .build_in_bounds_gep( + out_ptr, + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(i as u64, false), + ], + "", + ) + .unwrap(); + ctx.builder.build_store(ptr, v).unwrap(); + } + } + out_ptr +} + +/// Invokes the `np_dot` using `nalgebra` crate +pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_dot"; + let (x1_ty, x1) = x1; + let (x2_ty, x2) = x2; + + if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) { + let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty); + let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty); + + let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty) + else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + } +} + +/// Invokes the `np_linalg_matmul` using `nalgebra` crate +pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_matmul"; + let (x1_ty, x1) = x1; + let (x2_ty, x2) = x2; + + let llvm_usize = generator.get_size_type(ctx.ctx); + if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty); + + let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty) + else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); + + let outdim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let outdim1 = unsafe { + n2.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None); + Ok(out) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) + } +} + +/// Invokes the `np_linalg_cholesky` using `nalgebra` crate +pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_cholesky"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_cholesky(ctx, x1, out, None); + Ok(out) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +/// Invokes the `np_linalg_qr` using `nalgebra` crate +pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_qr"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + + let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None); + + let out_ptr = build_output_struct(ctx, vec![out_q, out_r]); + + Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +/// Invokes the `np_linalg_svd` using `nalgebra` crate +pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_svd"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + + let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None); + + let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]); + + Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +/// Invokes the `np_linalg_inv` using `nalgebra` crate +pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_inv"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_inv(ctx, x1, out, None); + Ok(out) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +/// Invokes the `np_linalg_pinv` using `nalgebra` crate +pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_linalg_pinv"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + + let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_np_linalg_pinv(ctx, x1, out, None); + Ok(out) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +/// Invokes the `sp_linalg_lu` using `nalgebra` crate +pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_linalg_lu"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let dim1 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + .into_int_value() + }; + let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + + let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None); + + let out_ptr = build_output_struct(ctx, vec![out_l, out_u]); + Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +/// Invokes the `sp_linalg_schur` using `nalgebra` crate +pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_linalg_schur"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None); + + let out_ptr = build_output_struct(ctx, vec![out_t, out_z]); + Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} + +/// Invokes the `sp_linalg_hessenberg` using `nalgebra` crate +pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "sp_linalg_hessenberg"; + let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); + + if let BasicValueEnum::PointerValue(n1) = x1 { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); + let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + + let BasicTypeEnum::FloatType(_) = n1_elem_ty else { + unimplemented!("{FN_NAME} operates on float type NdArrays only"); + }; + + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + + let dim0 = unsafe { + n1.dim_sizes() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + .into_int_value() + }; + let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None); + + let out_ptr = build_output_struct(ctx, vec![out_h, out_q]); + Ok(ctx + .builder + .build_load(out_ptr, "Hessenberg_decomposition_result") + .map(Into::into) + .unwrap()) + } else { + unsupported_type(ctx, FN_NAME, &[x1_ty]) + } +} diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index 8b510ed..ba8403c 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -130,3 +130,91 @@ pub fn call_ldexp<'ctx>( .map(Either::unwrap_left) .unwrap() } + +/// Macro to generate `np_linalg` and `sp_linalg` functions +/// The function takes as input `NDArray` and returns () +/// +/// Arguments: +/// * `$fn_name:ident`: The identifier of the rust function to be generated +/// * `$extern_fn:literal`: Name of underlying extern function +/// * (2/3/4): Number of `NDArray` that function takes as input +/// +/// Note: +/// The operands and resulting `NDArray` are both passed as input to the funcion +/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack +/// The function changes the content of the output `NDArray` in-place +macro_rules! generate_linalg_extern_fn { + ($fn_name:ident, $extern_fn:literal, 2) => { + generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2); + }; + ($fn_name:ident, $extern_fn:literal, 3) => { + generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3); + }; + ($fn_name:ident, $extern_fn:literal, 4) => { + generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4); + }; + ($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => { + #[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )] + pub fn $fn_name<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_> + $(,$input_matrix: BasicValueEnum<'ctx>)*, + name: Option<&str>, + ){ + const FN_NAME: &str = $extern_fn; + let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false); + + let func = ctx.module.add_function(FN_NAME, fn_type, None); + for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + func + }); + + ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap(); + } + }; +} + +generate_linalg_extern_fn!(call_np_linalg_matmul, "np_linalg_matmul", 3); +generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2); +generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3); +generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4); +generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2); +generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2); +generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3); +generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3); +generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3); + +/// Invokes the linalg `np_dot` function. +pub fn call_np_dot<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + mat1: BasicValueEnum<'ctx>, + mat2: BasicValueEnum<'ctx>, + name: Option<&str>, +) -> FloatValue<'ctx> { + const FN_NAME: &str = "np_dot"; + + let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { + let fn_type = + ctx.ctx.f64_type().fn_type(&[mat1.get_type().into(), mat2.get_type().into()], false); + let func = ctx.module.add_function(FN_NAME, fn_type, None); + for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + func + }); + + ctx.builder + .build_call(extern_fn, &[mat1.into(), mat2.into()], name.unwrap_or_default()) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 92e0705..a2b3c2c 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -159,7 +159,7 @@ where /// /// * `elem_ty` - The element type of the `NDArray`. /// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. -fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( +pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 008f24d..0d65828 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -556,6 +556,17 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpLdExp | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), + + PrimDef::FunNpDot + | PrimDef::FunNpLinalgMatmul + | PrimDef::FunNpLinalgCholesky + | PrimDef::FunNpLinalgQr + | PrimDef::FunNpLinalgSvd + | PrimDef::FunNpLinalgInv + | PrimDef::FunNpLinalgPinv + | PrimDef::FunSpLinalgLu + | PrimDef::FunSpLinalgSchur + | PrimDef::FunSpLinalgHessenberg => self.build_linalg_methods(prim), }; if cfg!(debug_assertions) { @@ -1874,6 +1885,150 @@ impl<'a> BuiltinBuilder<'a> { } } + /// Build `np_linalg` and `sp_linalg` functions + /// + /// The input to these functions must be floating point `NDArray` + fn build_linalg_methods(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[ + PrimDef::FunNpDot, + PrimDef::FunNpLinalgMatmul, + PrimDef::FunNpLinalgCholesky, + PrimDef::FunNpLinalgQr, + PrimDef::FunNpLinalgSvd, + PrimDef::FunNpLinalgInv, + PrimDef::FunNpLinalgPinv, + PrimDef::FunSpLinalgLu, + PrimDef::FunSpLinalgSchur, + PrimDef::FunSpLinalgHessenberg, + ], + ); + + match prim { + PrimDef::FunNpDot => create_fn_by_codegen( + self.unifier, + &self.num_or_ndarray_var_map, + prim.name(), + self.primitives.float, + &[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_np_dot( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }), + ), + + PrimDef::FunNpLinalgMatmul => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_float_2d, + &[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + Ok(Some(builtin_fns::call_np_linalg_matmul( + generator, + ctx, + (x1_ty, x1_val), + (x2_ty, x2_val), + )?)) + }), + ), + + PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => { + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_float_2d, + &[(self.ndarray_float_2d, "x1")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + + let func = match prim { + PrimDef::FunNpLinalgCholesky => builtin_fns::call_np_linalg_cholesky, + PrimDef::FunNpLinalgInv => builtin_fns::call_np_linalg_inv, + PrimDef::FunNpLinalgPinv => builtin_fns::call_np_linalg_pinv, + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (x1_ty, x1_val))?)) + }), + ) + } + + PrimDef::FunNpLinalgQr + | PrimDef::FunSpLinalgLu + | PrimDef::FunSpLinalgSchur + | PrimDef::FunSpLinalgHessenberg => { + let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { + ty: vec![self.ndarray_float_2d, self.ndarray_float_2d], + }); + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + ret_ty, + &[(self.ndarray_float_2d, "x1")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + + let func = match prim { + PrimDef::FunNpLinalgQr => builtin_fns::call_np_linalg_qr, + PrimDef::FunSpLinalgLu => builtin_fns::call_sp_linalg_lu, + PrimDef::FunSpLinalgSchur => builtin_fns::call_sp_linalg_schur, + PrimDef::FunSpLinalgHessenberg => { + builtin_fns::call_sp_linalg_hessenberg + } + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (x1_ty, x1_val))?)) + }), + ) + } + + PrimDef::FunNpLinalgSvd => { + let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { + ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d], + }); + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + ret_ty, + &[(self.ndarray_float_2d, "x1")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + + Ok(Some(builtin_fns::call_np_linalg_svd(generator, ctx, (x1_ty, x1_val))?)) + }), + ) + } + _ => { + println!("{:?}", prim.name()); + unreachable!() + } + } + } + fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { (prim.simple_name().into(), method_ty, prim.id()) } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index c4a6963..ae50c77 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -100,6 +100,17 @@ pub enum PrimDef { FunNpHypot, FunNpNextAfter, + FunNpDot, + FunNpLinalgMatmul, + FunNpLinalgCholesky, + FunNpLinalgQr, + FunNpLinalgSvd, + FunNpLinalgInv, + FunNpLinalgPinv, + FunSpLinalgLu, + FunSpLinalgSchur, + FunSpLinalgHessenberg, + // Miscellaneous Python & NAC3 functions FunInt32, FunInt64, @@ -270,6 +281,16 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), + PrimDef::FunNpDot => fun("np_dot", None), + PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None), + PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None), + PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None), + PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None), + PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None), + PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None), + PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None), + PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None), + PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None), // Miscellaneous Python & NAC3 functions PrimDef::FunInt32 => fun("int32", None),