From 3ac1083734ff093c77244cc675efb2b9151a56ae Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:32:18 +0800 Subject: [PATCH] [core] codegen: Reimplement np_dot() for scalars and 1D Based on 693b7f37: core/ndstrides: implement np_dot() for scalars and 1D --- nac3core/src/codegen/numpy.rs | 136 +++++++++++++++++------------- nac3core/src/toplevel/builtins.rs | 4 +- 2 files changed, 79 insertions(+), 61 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d46a61198..e5a893c9c 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -7,14 +7,18 @@ use nac3parser::ast::StrRef; use super::{ macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, - types::ndarray::NDArrayType, - values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor}, + stmt::gen_for_callback, + types::ndarray::{NDArrayType, NDIterType}, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue}, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, + toplevel::{ + helper::{arraylike_flatten_element_type, extract_ndims}, + numpy::unpack_ndarray_var_tys, + DefinitionId, + }, typecheck::typedef::{FunSignature, Type}, }; @@ -300,89 +304,101 @@ pub fn gen_ndarray_fill<'ctx>( pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); - let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); + let a = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); + let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); - let n1_sz = n1.size(generator, ctx); - let n2_sz = n2.size(generator, ctx); + // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. + assert!(a.get_type().ndims().is_some_and(|ndims| ndims == 1)); + assert!(b.get_type().ndims().is_some_and(|ndims| ndims == 1)); + let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + // Check shapes. + let a_size = a.size(generator, ctx); + let b_size = b.size(generator, ctx); + let same_shape = + ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap(); ctx.make_assert( generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), + same_shape, "0:ValueError", - "shapes ({0}), ({1}) not aligned", - [Some(n1_sz), Some(n2_sz), None], + "shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)", + [Some(a_size), Some(b_size), None], ctx.current_loc, ); - let identity = - unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); - ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); + let dtype_llvm = ctx.get_llvm_type(generator, common_dtype); - gen_for_callback_incrementing( + let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap(); + ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap(); + + // Do dot product. + gen_for_callback( generator, ctx, - None, - llvm_usize.const_zero(), - (n1_sz, false), - |generator, ctx, _, idx| { - let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; + Some("np_dot"), + |generator, ctx| { + let a_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, a); + let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); + Ok((a_iter, b_iter)) + }, + |generator, ctx, (a_iter, _b_iter)| { + // Only a_iter drives the condition, b_iter should have the same status. + Ok(a_iter.has_element(generator, ctx)) + }, + |_, ctx, _hooks, (a_iter, b_iter)| { + let a_scalar = a_iter.get_scalar(ctx); + let b_scalar = b_iter.get_scalar(ctx); - let product = match elem1 { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_mul(e1, elem2.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_mul(e1, elem2.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()), - }; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - let acc_val = match acc_val { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_add(e1, product.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_add(e1, product.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()), - }; - ctx.builder.build_store(acc, acc_val).unwrap(); + let old_result = ctx.builder.build_load(result, "").unwrap(); + let new_result: BasicValueEnum<'ctx> = match old_result { + BasicValueEnum::IntValue(old_result) => { + let a_scalar = a_scalar.into_int_value(); + let b_scalar = b_scalar.into_int_value(); + let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_int_add(old_result, x, "").unwrap().into() + } + BasicValueEnum::FloatValue(old_result) => { + let a_scalar = a_scalar.into_float_value(); + let b_scalar = b_scalar.into_float_value(); + let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_float_add(old_result, x, "").unwrap().into() + } + + _ => { + panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype)); + } + }; + + ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - Ok(acc_val) + |generator, ctx, (a_iter, b_iter)| { + a_iter.next(generator, ctx); + b_iter.next(generator, ctx); + Ok(()) + }, + ) + .unwrap(); + + Ok(ctx.builder.build_load(result, "").unwrap()) } + (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + (BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => { Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + _ => codegen_unreachable!( ctx, "{FN_NAME}() not supported for '{}'", diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 538961a6b..600276b78 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1935,10 +1935,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?; + Ok(Some(result)) }), ),