core/ndstrides: implement 1D np_dot()

This commit is contained in:
lyken 2024-08-20 21:18:59 +08:00
parent 660928bfba
commit 8312ec2278
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
4 changed files with 106 additions and 111 deletions

View File

@ -5,7 +5,6 @@ use crate::{
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
}, },
expr::gen_binop_expr_with_values,
irrt::{ irrt::{
calculate_len_for_slice_range, call_ndarray_calc_broadcast, calculate_len_for_slice_range, call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
@ -26,21 +25,15 @@ use crate::{
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, DefinitionId,
}, },
typecheck::{ typecheck::typedef::{FunSignature, Type},
magic_methods::Binop,
typedef::{FunSignature, Type},
},
}; };
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{ use inkwell::{
types::BasicType, types::BasicType,
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate,
}; };
use inkwell::{ use nac3parser::ast::StrRef;
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
values::BasicValue,
};
use nac3parser::ast::{Operator, StrRef};
/// Creates an uninitialized `NDArray` instance. /// Creates an uninitialized `NDArray` instance.
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
@ -1692,102 +1685,3 @@ pub fn gen_ndarray_fill<'ctx>(
this.fill(generator, context, value_arg); this.fill(generator, context, value_arg);
Ok(()) Ok(())
} }
/// Generates LLVM IR for `ndarray.dot`.
/// Calculate inner product of two vectors or literals
/// For matrix multiplication use `np_matmul`
///
/// The input `NDArray` are flattened and treated as 1D
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_dot";
let (x1_ty, x1) = x1;
let (_, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx);
match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
"0:ValueError",
"shapes ({0}), ({1}) not aligned",
[Some(n1_sz), Some(n2_sz), None],
ctx.current_loc,
);
let identity =
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1_sz, false),
|generator, ctx, _, idx| {
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
let product = match elem1 {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_mul(e1, elem2.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_mul(e1, elem2.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val {
BasicValueEnum::IntValue(e1) => ctx
.builder
.build_int_add(e1, product.into_int_value(), "")
.unwrap()
.as_basic_value_enum(),
BasicValueEnum::FloatValue(e1) => ctx
.builder
.build_float_add(e1, product.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
};
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap();
Ok(acc_val)
}
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
_ => unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
),
}
}

View File

@ -0,0 +1,93 @@
use inkwell::{values::BasicValueEnum, IntPredicate};
use crate::codegen::{
object::ndarray::nditer::NDIterHandle, stmt::gen_for_callback, CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
impl<'ctx> NDArrayObject<'ctx> {
/// Perform `np.dot()`.
///
/// Both ndarrays must be 1D and have the same type.
pub fn dot<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: NDArrayObject<'ctx>,
b: NDArrayObject<'ctx>,
) -> BasicValueEnum<'ctx> {
// TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html.
assert_eq!(a.ndims, 1);
assert_eq!(b.ndims, 1);
assert!(ctx.unifier.unioned(a.dtype, b.dtype));
let common_dtype = a.dtype;
// Check shapes.
let a_size = a.size(generator, ctx);
let b_size = b.size(generator, ctx);
let same_shape = a_size.compare(ctx, IntPredicate::EQ, b_size);
ctx.make_assert(
generator,
same_shape.value,
"0:ValueError",
"shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)",
[Some(a_size.value), Some(b_size.value), None],
ctx.current_loc,
);
let dtype_llvm = ctx.get_llvm_type(generator, common_dtype);
let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap();
ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap();
// Do dot product.
gen_for_callback(
generator,
ctx,
Some("np_dot"),
|generator, ctx| {
let a_iter = NDIterHandle::new(generator, ctx, a);
let b_iter = NDIterHandle::new(generator, ctx, b);
Ok((a_iter, b_iter))
},
|generator, ctx, (a_iter, _b_iter)| {
// Only a_iter drives the condition, b_iter should have the same status.
Ok(a_iter.has_next(generator, ctx).value)
},
|generator, ctx, _hooks, (a_iter, b_iter)| {
let a_scalar = a_iter.get_scalar(generator, ctx).value;
let b_scalar = b_iter.get_scalar(generator, ctx).value;
let old_result = ctx.builder.build_load(result, "").unwrap();
let new_result: BasicValueEnum<'ctx> = match old_result {
BasicValueEnum::IntValue(old_result) => {
let a_scalar = a_scalar.into_int_value();
let b_scalar = b_scalar.into_int_value();
let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap();
ctx.builder.build_int_add(old_result, x, "").unwrap().into()
}
BasicValueEnum::FloatValue(old_result) => {
let a_scalar = a_scalar.into_float_value();
let b_scalar = b_scalar.into_float_value();
let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap();
ctx.builder.build_float_add(old_result, x, "").unwrap().into()
}
_ => {
panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype));
}
};
ctx.builder.build_store(result, new_result).unwrap();
Ok(())
},
|generator, ctx, (a_iter, b_iter)| {
a_iter.next(generator, ctx);
b_iter.next(generator, ctx);
Ok(())
},
)
.unwrap();
ctx.builder.build_load(result, "").unwrap()
}
}

View File

@ -1,5 +1,6 @@
pub mod array; pub mod array;
pub mod broadcast; pub mod broadcast;
pub mod dot;
pub mod factory; pub mod factory;
pub mod indexing; pub mod indexing;
pub mod map; pub mod map;

View File

@ -2076,10 +2076,17 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty; let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty; let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) let x1 = AnyObject { ty: x1_ty, value: x1_val };
let x1 = NDArrayObject::from_object(generator, ctx, x1);
let x2 = AnyObject { ty: x2_ty, value: x2_val };
let x2 = NDArrayObject::from_object(generator, ctx, x2);
let result = NDArrayObject::dot(generator, ctx, x1, x2);
Ok(Some(result))
}), }),
), ),