From 63d2b49b09aab6eace463b20e99ce8b5c19db50b Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 1 Aug 2024 18:43:06 +0800 Subject: [PATCH] core: remove np_linalg_matmul --- nac3core/src/codegen/builtin_fns.rs | 49 --------------------------- nac3core/src/codegen/extern_fns.rs | 1 - nac3core/src/toplevel/builtins.rs | 23 ------------- nac3core/src/toplevel/helper.rs | 2 -- nac3standalone/demo/interpret_demo.py | 3 +- nac3standalone/demo/linalg/src/lib.rs | 45 ------------------------ nac3standalone/demo/src/ndarray.py | 21 +++--------- 7 files changed, 5 insertions(+), 139 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index ad15d61b9..afdb43012 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1867,55 +1867,6 @@ fn build_output_struct<'ctx>( out_ptr } -/// Invokes the `np_linalg_matmul` linalg function -pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_linalg_matmul"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); - if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) = (x1, x2) { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty); - - let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty) - else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); - }; - - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); - - let outdim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let outdim1 = unsafe { - n2.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_matmul(ctx, x1, x2, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) - } -} - /// Invokes the `np_linalg_cholesky` linalg function pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index e181f57f2..d3f252bd3 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -179,7 +179,6 @@ macro_rules! generate_linalg_extern_fn { }; } -generate_linalg_extern_fn!(call_np_linalg_matmul, "np_linalg_matmul", 3); generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2); generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3); generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4); diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 49692baf9..a14ade5d8 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -562,7 +562,6 @@ impl<'a> BuiltinBuilder<'a> { } PrimDef::FunNpDot - | PrimDef::FunNpLinalgMatmul | PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr | PrimDef::FunNpLinalgSvd @@ -1950,7 +1949,6 @@ impl<'a> BuiltinBuilder<'a> { prim, &[ PrimDef::FunNpDot, - PrimDef::FunNpLinalgMatmul, PrimDef::FunNpLinalgCholesky, PrimDef::FunNpLinalgQr, PrimDef::FunNpLinalgSvd, @@ -1981,27 +1979,6 @@ impl<'a> BuiltinBuilder<'a> { }), ), - PrimDef::FunNpLinalgMatmul => create_fn_by_codegen( - self.unifier, - &VarMap::new(), - prim.name(), - self.ndarray_float_2d, - &[(self.ndarray_float_2d, "x1"), (self.ndarray_float_2d, "x2")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_np_linalg_matmul( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }), - ), - PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgInv | PrimDef::FunNpLinalgPinv => { create_fn_by_codegen( self.unifier, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index be4671820..598a80e9d 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -104,7 +104,6 @@ pub enum PrimDef { // Linalg functions FunNpDot, - FunNpLinalgMatmul, FunNpLinalgCholesky, FunNpLinalgQr, FunNpLinalgSvd, @@ -291,7 +290,6 @@ impl PrimDef { // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), - PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None), PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None), PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None), PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None), diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 15bf08534..4f19db95c 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -5,8 +5,8 @@ import importlib.util import importlib.machinery import math import numpy as np -import scipy as sp import numpy.typing as npt +import scipy as sp import pathlib from numpy import int32, int64, uint32, uint64 @@ -231,7 +231,6 @@ def patch(module): # Linalg functions module.np_dot = np.dot - module.np_linalg_matmul = np.matmul module.np_linalg_cholesky = np.linalg.cholesky module.np_linalg_qr = np.linalg.qr module.np_linalg_svd = np.linalg.svd diff --git a/nac3standalone/demo/linalg/src/lib.rs b/nac3standalone/demo/linalg/src/lib.rs index d4ee030fa..c671e0b7f 100644 --- a/nac3standalone/demo/linalg/src/lib.rs +++ b/nac3standalone/demo/linalg/src/lib.rs @@ -34,51 +34,6 @@ impl InputMatrix { } } -/// # Safety -/// -/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order -#[no_mangle] -pub unsafe extern "C" fn np_linalg_matmul( - mat1: *mut InputMatrix, - mat2: *mut InputMatrix, - out: *mut InputMatrix, -) { - let mat1 = mat1.as_mut().unwrap(); - let mat2 = mat2.as_mut().unwrap(); - let out = out.as_mut().unwrap(); - - if !(mat1.ndims == 2 && mat2.ndims == 2) { - let err_msg = format!( - "expected 2D Vector Input, but received {}D and {}D input", - mat1.ndims, mat2.ndims - ); - report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg); - } - - let dim1 = (*mat1).get_dims(); - let dim2 = (*mat2).get_dims(); - - if dim1[1] != dim2[0] { - let err_msg = format!( - "shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)", - dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0] - ); - report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg); - } - - let outdim = out.get_dims(); - let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) }; - let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) }; - let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) }; - - let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1); - let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2); - let mut result = DMatrix::::zeros(outdim[0], outdim[1]); - - matrix1.mul_to(&matrix2, &mut result); - out_slice.copy_from_slice(result.transpose().as_slice()); -} - /// # Safety /// /// `mat1` should point to a valid 2DArray of `f64` floats in row-major order diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 0c7e3233a..362a4ac89 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1474,18 +1474,6 @@ def test_ndarray_dot(): output_float64(z5) output_bool(z6) -def test_ndarray_linalg_matmul(): - x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) - y: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) - z = np_linalg_matmul(x, y) - - m = np_argmax(z) - - output_ndarray_float_2(x) - output_ndarray_float_2(y) - output_ndarray_float_2(z) - output_int64(m) - def test_ndarray_cholesky(): x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]]) y = np_linalg_cholesky(x) @@ -1501,7 +1489,7 @@ def test_ndarray_qr(): # QR Factorization is not unique and gives different results in numpy and nalgebra # Reverting the decomposition to compare the initial arrays - a = np_linalg_matmul(y, z) + a = y @ z output_ndarray_float_2(a) def test_ndarray_linalg_inv(): @@ -1540,7 +1528,7 @@ def test_ndarray_schur(): # Schur Factorization is not unique and gives different results in scipy and nalgebra # Reverting the decomposition to compare the initial arrays - a = np_linalg_matmul(np_linalg_matmul(z, t), np_linalg_inv(z)) + a = (z @ t) @ np_linalg_inv(z) output_ndarray_float_2(a) def test_ndarray_hessenberg(): @@ -1551,7 +1539,7 @@ def test_ndarray_hessenberg(): # Hessenberg Factorization is not unique and gives different results in scipy and nalgebra # Reverting the decomposition to compare the initial arrays - a = np_linalg_matmul(np_linalg_matmul(q, h), np_linalg_inv(q)) + a = (q @ h) @ np_linalg_inv(q) output_ndarray_float_2(a) @@ -1572,7 +1560,7 @@ def test_ndarray_svd(): # SVD Factorization is not unique and gives different results in numpy and nalgebra # Reverting the decomposition to compare the initial arrays - a = np_linalg_matmul(x, z) + a = x @ z output_ndarray_float_2(a) output_ndarray_float_1(y) @@ -1759,7 +1747,6 @@ def run() -> int32: test_ndarray_reshape() test_ndarray_dot() - test_ndarray_linalg_matmul() test_ndarray_cholesky() test_ndarray_qr() test_ndarray_svd()