From 54f883f0a509ceefe570f8101829d938725692f3 Mon Sep 17 00:00:00 2001 From: abdul124 Date: Wed, 31 Jul 2024 15:53:51 +0800 Subject: [PATCH] core: implement np_dot using LLVM_IR --- nac3core/src/codegen/builtin_fns.rs | 28 ----- nac3core/src/codegen/extern_fns.rs | 30 ----- nac3core/src/codegen/numpy.rs | 106 +++++++++++++++++- nac3core/src/toplevel/builtins.rs | 9 +- nac3core/src/typecheck/type_inferencer/mod.rs | 38 +++++++ nac3standalone/demo/linalg/src/lib.rs | 32 ------ nac3standalone/demo/src/ndarray.py | 27 ++++- 7 files changed, 165 insertions(+), 105 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 733cb9f36..4bc7c913c 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1865,34 +1865,6 @@ fn build_output_struct<'ctx>( out_ptr } -/// Invokes the `np_dot` linalg function -pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "np_dot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - - if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) { - let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty); - let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty); - - let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty) - else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); - }; - - Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) - } -} - /// Invokes the `np_linalg_matmul` linalg function pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/codegen/extern_fns.rs b/nac3core/src/codegen/extern_fns.rs index ba8403c58..089a94f5e 100644 --- a/nac3core/src/codegen/extern_fns.rs +++ b/nac3core/src/codegen/extern_fns.rs @@ -188,33 +188,3 @@ generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2); generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3); generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3); generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3); - -/// Invokes the linalg `np_dot` function. -pub fn call_np_dot<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - mat1: BasicValueEnum<'ctx>, - mat2: BasicValueEnum<'ctx>, - name: Option<&str>, -) -> FloatValue<'ctx> { - const FN_NAME: &str = "np_dot"; - - let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| { - let fn_type = - ctx.ctx.f64_type().fn_type(&[mat1.get_type().into(), mat2.get_type().into()], false); - let func = ctx.module.add_function(FN_NAME, fn_type, None); - for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), - ); - } - func - }); - - ctx.builder - .build_call(extern_fn, &[mat1.into(), mat2.into()], name.unwrap_or_default()) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() -} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 4ab1391e2..f4299ff5a 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -26,12 +26,15 @@ use crate::{ typedef::{FunSignature, Type, TypeEnum}, }, }; -use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::{ types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, OptimizationLevel, }; +use inkwell::{ + types::{AnyTypeEnum, BasicTypeEnum, PointerType}, + values::BasicValue, +}; use nac3parser::ast::{Operator, StrRef}; /// Creates an uninitialized `NDArray` instance. @@ -2390,7 +2393,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( generator, ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), "0:ValueError", - "cannot reshape array of size {} into provided shape of size {}", + "cannot reshape array of size {0} into provided shape of size {1}", [Some(n_sz), Some(out_sz), None], ctx.current_loc, ); @@ -2417,3 +2420,102 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( ) } } + +/// Generates LLVM IR for `ndarray.dot`. +/// Calculate inner product of two vectors or literals +/// For matrix multiplication use `np_matmul` +/// +/// The input `NDArray` are flattened and treated as 1D +/// The operation is equivalent to np.dot(arr1.ravel(), arr2.ravel()) +pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + x1: (Type, BasicValueEnum<'ctx>), + x2: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "ndarray_dot"; + let (x1_ty, x1) = x1; + let (_, x2) = x2; + + let llvm_usize = generator.get_size_type(ctx.ctx); + + match (x1, x2) { + (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { + let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); + + let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + + ctx.make_assert( + generator, + ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), + "0:ValueError", + "shapes ({0}), ({1}) not aligned", + [Some(n1_sz), Some(n2_sz), None], + ctx.current_loc, + ); + + let identity = + unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; + let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); + ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); + + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (n1_sz, false), + |generator, ctx, _, idx| { + let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; + let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; + + let product = match elem1 { + BasicValueEnum::IntValue(e1) => ctx + .builder + .build_int_mul(e1, elem2.into_int_value(), "") + .unwrap() + .as_basic_value_enum(), + BasicValueEnum::FloatValue(e1) => ctx + .builder + .build_float_mul(e1, elem2.into_float_value(), "") + .unwrap() + .as_basic_value_enum(), + _ => unreachable!(), + }; + let acc_val = ctx.builder.build_load(acc, "").unwrap(); + let acc_val = match acc_val { + BasicValueEnum::IntValue(e1) => ctx + .builder + .build_int_add(e1, product.into_int_value(), "") + .unwrap() + .as_basic_value_enum(), + BasicValueEnum::FloatValue(e1) => ctx + .builder + .build_float_add(e1, product.into_float_value(), "") + .unwrap() + .as_basic_value_enum(), + _ => unreachable!(), + }; + ctx.builder.build_store(acc, acc_val).unwrap(); + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + let acc_val = ctx.builder.build_load(acc, "").unwrap(); + Ok(acc_val) + } + (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { + Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) + } + (BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => { + Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum()) + } + _ => unreachable!( + "{FN_NAME}() not supported for '{}'", + format!("'{}'", ctx.unifier.stringify(x1_ty)) + ), + } +} diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 18f0be656..746ea0520 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1965,7 +1965,7 @@ impl<'a> BuiltinBuilder<'a> { self.unifier, &self.num_or_ndarray_var_map, prim.name(), - self.primitives.float, + self.num_ty.ty, &[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")], Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; @@ -1973,12 +1973,7 @@ impl<'a> BuiltinBuilder<'a> { let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_np_dot( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) + Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) }), ), diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 8e8cba39d..f4d3a62e2 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1130,6 +1130,44 @@ impl<'a> Inferencer<'a> { })); } + if id == &"np_dot".into() { + let arg0 = self.fold_expr(args.remove(0))?; + let arg1 = self.fold_expr(args.remove(0))?; + let arg0_ty = arg0.custom.unwrap(); + + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) + { + let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + + ndarray_dtype + } else { + arg0_ty + }; + + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None }, + FuncArg { name: "x2".into(), ty: arg1.custom.unwrap(), default_value: None }, + ], + ret, + vars: VarMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![arg0, arg1], + keywords: vec![], + }, + })); + } + if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); diff --git a/nac3standalone/demo/linalg/src/lib.rs b/nac3standalone/demo/linalg/src/lib.rs index 1e6091009..76933c13c 100644 --- a/nac3standalone/demo/linalg/src/lib.rs +++ b/nac3standalone/demo/linalg/src/lib.rs @@ -34,38 +34,6 @@ impl InputMatrix { } } -/// # Safety -/// -/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order -#[no_mangle] -pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 { - let mat1 = mat1.as_mut().unwrap(); - let mat2 = mat2.as_mut().unwrap(); - - if !(mat1.ndims == 1 && mat2.ndims == 1) { - let err_msg = format!( - "expected 1D Vector Input, but received {}D and {}D input", - mat1.ndims, mat2.ndims - ); - report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg); - } - - let dim1 = (*mat1).get_dims(); - let dim2 = (*mat2).get_dims(); - - if dim1[0] != dim2[0] { - let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]); - report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg); - } - let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) }; - let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) }; - - let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1); - let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2); - - matrix1.dot(&matrix2) -} - /// # Safety /// /// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index a7ec89185..beb0c1d29 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1451,13 +1451,28 @@ def test_ndarray_reshape(): output_ndarray_float_1(z) def test_ndarray_dot(): - x: ndarray[float, 1] = np_array([5.0, 1.0]) - y: ndarray[float, 1] = np_array([5.0, 1.0]) - z = np_dot(x, y) + x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) + y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) + z1 = np_dot(x1, y1) - output_ndarray_float_1(x) - output_ndarray_float_1(y) - output_float64(z) + x2: ndarray[int32, 1] = np_array([5, 1, 4, 2]) + y2: ndarray[int32, 1] = np_array([5, 1, 6, 6]) + z2 = np_dot(x2, y2) + + x3: ndarray[bool, 1] = np_array([True, True, True, True]) + y3: ndarray[bool, 1] = np_array([True, True, True, True]) + z3 = np_dot(x3, y3) + + z4 = np_dot(2, 3) + z5 = np_dot(2., 3.) + z6 = np_dot(True, False) + + output_float64(z1) + output_int32(z2) + output_bool(z3) + output_int32(z4) + output_float64(z5) + output_bool(z6) def test_ndarray_linalg_matmul(): x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])