From 4baf1c64ed181213aa39e396cd89958a6f2dc621 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 21:18:59 +0800 Subject: [PATCH] core/ndstrides: implement np_dot() for scalars and 1D --- nac3core/src/codegen/numpy.rs | 124 +++++++++++++++++------------- nac3core/src/toplevel/builtins.rs | 4 +- 2 files changed, 72 insertions(+), 56 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 9ab0ec43..31b658f4 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -15,9 +15,12 @@ use crate::{ model::*, object::{ any::AnyObject, - ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject}, + ndarray::{nditer::NDIterHandle, shape_util::parse_numpy_int_sequence, NDArrayObject}, + }, + stmt::{ + gen_for_callback, gen_for_callback_incrementing, gen_for_range_callback, + gen_if_else_expr_callback, }, - stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, }, symbol_resolver::ValueEnum, @@ -1705,77 +1708,88 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; let (x1_ty, x1) = x1; - let (_, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); + let (x2_ty, x2) = x2; match (x1, x2) { - (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); + (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) => { + let a = AnyObject { ty: x1_ty, value: x1 }; + let b = AnyObject { ty: x2_ty, value: x2 }; - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let a = NDArrayObject::from_object(generator, ctx, a); + let b = NDArrayObject::from_object(generator, ctx, b); + // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. + assert_eq!(a.ndims, 1); + assert_eq!(b.ndims, 1); + let common_dtype = a.dtype; + + // Check shapes. + let a_size = a.size(generator, ctx); + let b_size = b.size(generator, ctx); + let same_shape = a_size.compare(ctx, IntPredicate::EQ, b_size); ctx.make_assert( generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), + same_shape.value, "0:ValueError", - "shapes ({0}), ({1}) not aligned", - [Some(n1_sz), Some(n2_sz), None], + "shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)", + [Some(a_size.value), Some(b_size.value), None], ctx.current_loc, ); - let identity = - unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); - ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); + let dtype_llvm = ctx.get_llvm_type(generator, common_dtype); - gen_for_callback_incrementing( + let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap(); + ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap(); + + // Do dot product. + gen_for_callback( generator, ctx, - None, - llvm_usize.const_zero(), - (n1_sz, false), - |generator, ctx, _, idx| { - let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; + Some("np_dot"), + |generator, ctx| { + let a_iter = NDIterHandle::new(generator, ctx, a); + let b_iter = NDIterHandle::new(generator, ctx, b); + Ok((a_iter, b_iter)) + }, + |generator, ctx, (a_iter, _b_iter)| { + // Only a_iter drives the condition, b_iter should have the same status. + Ok(a_iter.has_element(generator, ctx).value) + }, + |generator, ctx, _hooks, (a_iter, b_iter)| { + let a_scalar = a_iter.get_scalar(generator, ctx).value; + let b_scalar = b_iter.get_scalar(generator, ctx).value; - let product = match elem1 { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_mul(e1, elem2.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_mul(e1, elem2.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx), + let old_result = ctx.builder.build_load(result, "").unwrap(); + let new_result: BasicValueEnum<'ctx> = match old_result { + BasicValueEnum::IntValue(old_result) => { + let a_scalar = a_scalar.into_int_value(); + let b_scalar = b_scalar.into_int_value(); + let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_int_add(old_result, x, "").unwrap().into() + } + BasicValueEnum::FloatValue(old_result) => { + let a_scalar = a_scalar.into_float_value(); + let b_scalar = b_scalar.into_float_value(); + let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_float_add(old_result, x, "").unwrap().into() + } + _ => { + panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype)); + } }; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - let acc_val = match acc_val { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_add(e1, product.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_add(e1, product.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx), - }; - ctx.builder.build_store(acc, acc_val).unwrap(); + ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - Ok(acc_val) + |generator, ctx, (a_iter, b_iter)| { + a_iter.next(generator, ctx); + b_iter.next(generator, ctx); + Ok(()) + }, + ) + .unwrap(); + + Ok(ctx.builder.build_load(result, "").unwrap()) } (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 7eae6a2c..1a810f28 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -2078,10 +2078,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?; + Ok(Some(result)) }), ),