[core] codegen: Reimplement np_dot() for scalars and 1D

Based on 693b7f37: core/ndstrides: implement np_dot() for scalars and 1D
This commit is contained in:
David Mak 2024-12-19 12:32:18 +08:00
parent a525592941
commit 14e14cb6eb
8 changed files with 87 additions and 75 deletions

View File

@ -7,14 +7,18 @@ use nac3parser::ast::StrRef;
use super::{
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
types::ndarray::NDArrayType,
values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, UntypedArrayLikeAccessor},
stmt::gen_for_callback,
types::ndarray::{NDArrayType, NDIterType},
values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue},
CodeGenContext, CodeGenerator,
};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId},
toplevel::{
helper::{arraylike_flatten_element_type, extract_ndims},
numpy::unpack_ndarray_var_tys,
DefinitionId,
},
typecheck::typedef::{FunSignature, Type},
};
@ -300,89 +304,101 @@ pub fn gen_ndarray_fill<'ctx>(
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_dot";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None);
let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
let a = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None);
let b = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
let n1_sz = n1.size(generator, ctx);
let n2_sz = n2.size(generator, ctx);
// TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html.
assert!(a.get_type().ndims().is_some_and(|ndims| ndims == 1));
assert!(b.get_type().ndims().is_some_and(|ndims| ndims == 1));
let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
// Check shapes.
let a_size = a.size(generator, ctx);
let b_size = b.size(generator, ctx);
let same_shape =
ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap();
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
same_shape,
"0:ValueError",
"shapes ({0}), ({1}) not aligned",
[Some(n1_sz), Some(n2_sz), None],
"shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)",
[Some(a_size), Some(b_size), None],
ctx.current_loc,
);
let identity =
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
let dtype_llvm = ctx.get_llvm_type(generator, common_dtype);
gen_for_callback_incrementing(
let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap();
ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap();
// Do dot product.
gen_for_callback(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1_sz, false),
|generator, ctx, _, idx| {
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
Some("np_dot"),
|generator, ctx| {
let a_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, a);
let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b);
Ok((a_iter, b_iter))
},
|generator, ctx, (a_iter, _b_iter)| {
// Only a_iter drives the condition, b_iter should have the same status.
Ok(a_iter.has_element(generator, ctx))
},
|_, ctx, _hooks, (a_iter, b_iter)| {
let a_scalar = a_iter.get_scalar(ctx);
let b_scalar = b_iter.get_scalar(ctx);
let product = match elem1 {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_mul(e1, elem2.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_mul(e1, elem2.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()),
};
let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_add(e1, product.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_add(e1, product.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()),
};
ctx.builder.build_store(acc, acc_val).unwrap();
let old_result = ctx.builder.build_load(result, "").unwrap();
let new_result: BasicValueEnum<'ctx> = match old_result {
BasicValueEnum::IntValue(old_result) => {
let a_scalar = a_scalar.into_int_value();
let b_scalar = b_scalar.into_int_value();
let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap();
ctx.builder.build_int_add(old_result, x, "").unwrap().into()
}
BasicValueEnum::FloatValue(old_result) => {
let a_scalar = a_scalar.into_float_value();
let b_scalar = b_scalar.into_float_value();
let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap();
ctx.builder.build_float_add(old_result, x, "").unwrap().into()
}
_ => {
panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype));
}
};
ctx.builder.build_store(result, new_result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap();
Ok(acc_val)
|generator, ctx, (a_iter, b_iter)| {
a_iter.next(generator, ctx);
b_iter.next(generator, ctx);
Ok(())
},
)
.unwrap();
Ok(ctx.builder.build_load(result, "").unwrap())
}
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
_ => codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",

View File

@ -314,14 +314,8 @@ impl<'ctx> NDArrayValue<'ctx> {
match out {
NDArrayOut::NewNDArray { .. } => result,
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
// let result_shape = result.instance.get(generator, ctx, |f| f.shape);
let result_shape = result.shape();
out_ndarray.assert_can_be_written_by_out(
generator,
ctx,
result.ndims.unwrap(),
result_shape.as_slice_value(ctx, generator),
);
out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape);
out_ndarray.copy_data_from(generator, ctx, result);
out_ndarray

View File

@ -1936,10 +1936,12 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?;
Ok(Some(result))
}),
),

View File

@ -8,5 +8,5 @@ expression: res_vec
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
]

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar246]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar246\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(259)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar245, typevar246]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar245\", \"typevar246\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(265)]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(273)]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n",
]