From 14e14cb6eb6c89c69e31dbcb0c90367014175c43 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 12:32:18 +0800 Subject: [PATCH] [core] codegen: Reimplement np_dot() for scalars and 1D Based on 693b7f37: core/ndstrides: implement np_dot() for scalars and 1D --- nac3core/src/codegen/numpy.rs | 136 ++++++++++-------- nac3core/src/codegen/values/ndarray/matmul.rs | 8 +- nac3core/src/toplevel/builtins.rs | 4 +- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- 8 files changed, 87 insertions(+), 75 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d46a6119..e5a893c9 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -7,14 +7,18 @@ use nac3parser::ast::StrRef; use super::{ macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, - types::ndarray::NDArrayType, - values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor}, + stmt::gen_for_callback, + types::ndarray::{NDArrayType, NDIterType}, + values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue}, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, + toplevel::{ + helper::{arraylike_flatten_element_type, extract_ndims}, + numpy::unpack_ndarray_var_tys, + DefinitionId, + }, typecheck::typedef::{FunSignature, Type}, }; @@ -300,89 +304,101 @@ pub fn gen_ndarray_fill<'ctx>( pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); - let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); + let a = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None); + let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None); - let n1_sz = n1.size(generator, ctx); - let n2_sz = n2.size(generator, ctx); + // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. + assert!(a.get_type().ndims().is_some_and(|ndims| ndims == 1)); + assert!(b.get_type().ndims().is_some_and(|ndims| ndims == 1)); + let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + // Check shapes. + let a_size = a.size(generator, ctx); + let b_size = b.size(generator, ctx); + let same_shape = + ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap(); ctx.make_assert( generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), + same_shape, "0:ValueError", - "shapes ({0}), ({1}) not aligned", - [Some(n1_sz), Some(n2_sz), None], + "shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)", + [Some(a_size), Some(b_size), None], ctx.current_loc, ); - let identity = - unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); - ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); + let dtype_llvm = ctx.get_llvm_type(generator, common_dtype); - gen_for_callback_incrementing( + let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap(); + ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap(); + + // Do dot product. + gen_for_callback( generator, ctx, - None, - llvm_usize.const_zero(), - (n1_sz, false), - |generator, ctx, _, idx| { - let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; + Some("np_dot"), + |generator, ctx| { + let a_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, a); + let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); + Ok((a_iter, b_iter)) + }, + |generator, ctx, (a_iter, _b_iter)| { + // Only a_iter drives the condition, b_iter should have the same status. + Ok(a_iter.has_element(generator, ctx)) + }, + |_, ctx, _hooks, (a_iter, b_iter)| { + let a_scalar = a_iter.get_scalar(ctx); + let b_scalar = b_iter.get_scalar(ctx); - let product = match elem1 { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_mul(e1, elem2.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_mul(e1, elem2.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()), - }; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - let acc_val = match acc_val { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_add(e1, product.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_add(e1, product.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()), - }; - ctx.builder.build_store(acc, acc_val).unwrap(); + let old_result = ctx.builder.build_load(result, "").unwrap(); + let new_result: BasicValueEnum<'ctx> = match old_result { + BasicValueEnum::IntValue(old_result) => { + let a_scalar = a_scalar.into_int_value(); + let b_scalar = b_scalar.into_int_value(); + let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_int_add(old_result, x, "").unwrap().into() + } + BasicValueEnum::FloatValue(old_result) => { + let a_scalar = a_scalar.into_float_value(); + let b_scalar = b_scalar.into_float_value(); + let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_float_add(old_result, x, "").unwrap().into() + } + + _ => { + panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype)); + } + }; + + ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - Ok(acc_val) + |generator, ctx, (a_iter, b_iter)| { + a_iter.next(generator, ctx); + b_iter.next(generator, ctx); + Ok(()) + }, + ) + .unwrap(); + + Ok(ctx.builder.build_load(result, "").unwrap()) } + (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + (BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => { Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum()) } + _ => codegen_unreachable!( ctx, "{FN_NAME}() not supported for '{}'", diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index 6c696bd7..9d3be422 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -314,14 +314,8 @@ impl<'ctx> NDArrayValue<'ctx> { match out { NDArrayOut::NewNDArray { .. } => result, NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { - // let result_shape = result.instance.get(generator, ctx, |f| f.shape); let result_shape = result.shape(); - out_ndarray.assert_can_be_written_by_out( - generator, - ctx, - result.ndims.unwrap(), - result_shape.as_slice_value(ctx, generator), - ); + out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape); out_ndarray.copy_data_from(generator, ctx, result); out_ndarray diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 1c6b8e9e..54578b9d 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1936,10 +1936,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?; + Ok(Some(result)) }), ), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index bd5834d3..4332b474 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -8,5 +8,5 @@ expression: res_vec "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(262)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index df4517b4..60e0c194 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar246]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar246\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index b1d9bf99..46601817 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(259)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 83746579..da58d121 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar245, typevar246]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar245\", \"typevar246\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index eeef45d2..8f384fa1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(265)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(273)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n", ]