core/ndstrides: implement np_dot() for scalars and 1D

This commit is contained in:
lyken 2024-08-20 21:18:59 +08:00
parent b416ece921
commit 75b2e80418
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
2 changed files with 72 additions and 56 deletions

View File

@ -14,9 +14,12 @@ use crate::{
model::*,
object::{
any::AnyObject,
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
ndarray::{nditer::NDIterHandle, shape_util::parse_numpy_int_sequence, NDArrayObject},
},
stmt::{
gen_for_callback, gen_for_callback_incrementing, gen_for_range_callback,
gen_if_else_expr_callback,
},
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
@ -1704,77 +1707,88 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_dot";
let (x1_ty, x1) = x1;
let (_, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
let (x2_ty, x2) = x2;
match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
(BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) => {
let a = AnyObject { ty: x1_ty, value: x1 };
let b = AnyObject { ty: x2_ty, value: x2 };
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let a = NDArrayObject::from_object(generator, ctx, a);
let b = NDArrayObject::from_object(generator, ctx, b);
// TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html.
assert_eq!(a.ndims, 1);
assert_eq!(b.ndims, 1);
let common_dtype = a.dtype;
// Check shapes.
let a_size = a.size(generator, ctx);
let b_size = b.size(generator, ctx);
let same_shape = a_size.compare(ctx, IntPredicate::EQ, b_size);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
same_shape.value,
"0:ValueError",
"shapes ({0}), ({1}) not aligned",
[Some(n1_sz), Some(n2_sz), None],
"shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)",
[Some(a_size.value), Some(b_size.value), None],
ctx.current_loc,
);
let identity =
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
let dtype_llvm = ctx.get_llvm_type(generator, common_dtype);
gen_for_callback_incrementing(
let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap();
ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap();
// Do dot product.
gen_for_callback(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1_sz, false),
|generator, ctx, _, idx| {
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
Some("np_dot"),
|generator, ctx| {
let a_iter = NDIterHandle::new(generator, ctx, a);
let b_iter = NDIterHandle::new(generator, ctx, b);
Ok((a_iter, b_iter))
},
|generator, ctx, (a_iter, _b_iter)| {
// Only a_iter drives the condition, b_iter should have the same status.
Ok(a_iter.has_next(generator, ctx).value)
},
|generator, ctx, _hooks, (a_iter, b_iter)| {
let a_scalar = a_iter.get_scalar(generator, ctx).value;
let b_scalar = b_iter.get_scalar(generator, ctx).value;
let product = match elem1 {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_mul(e1, elem2.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_mul(e1, elem2.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
let old_result = ctx.builder.build_load(result, "").unwrap();
let new_result: BasicValueEnum<'ctx> = match old_result {
BasicValueEnum::IntValue(old_result) => {
let a_scalar = a_scalar.into_int_value();
let b_scalar = b_scalar.into_int_value();
let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap();
ctx.builder.build_int_add(old_result, x, "").unwrap().into()
}
BasicValueEnum::FloatValue(old_result) => {
let a_scalar = a_scalar.into_float_value();
let b_scalar = b_scalar.into_float_value();
let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap();
ctx.builder.build_float_add(old_result, x, "").unwrap().into()
}
_ => {
panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype));
}
};
let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_add(e1, product.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_add(e1, product.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
ctx.builder.build_store(acc, acc_val).unwrap();
ctx.builder.build_store(result, new_result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap();
Ok(acc_val)
|generator, ctx, (a_iter, b_iter)| {
a_iter.next(generator, ctx);
b_iter.next(generator, ctx);
Ok(())
},
)
.unwrap();
Ok(ctx.builder.build_load(result, "").unwrap())
}
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())

View File

@ -2078,10 +2078,12 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?;
Ok(Some(result))
}),
),