core/ndstrides: implement 1D np_dot()
This commit is contained in:
parent
660928bfba
commit
8312ec2278
|
@ -5,7 +5,6 @@ use crate::{
|
||||||
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||||
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
},
|
},
|
||||||
expr::gen_binop_expr_with_values,
|
|
||||||
irrt::{
|
irrt::{
|
||||||
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
|
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
|
||||||
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
|
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
|
||||||
|
@ -26,21 +25,15 @@ use crate::{
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::typedef::{FunSignature, Type},
|
||||||
magic_methods::Binop,
|
|
||||||
typedef::{FunSignature, Type},
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::BasicType,
|
types::BasicType,
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use nac3parser::ast::StrRef;
|
||||||
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
|
|
||||||
values::BasicValue,
|
|
||||||
};
|
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
|
||||||
|
|
||||||
/// Creates an uninitialized `NDArray` instance.
|
/// Creates an uninitialized `NDArray` instance.
|
||||||
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
@ -1692,102 +1685,3 @@ pub fn gen_ndarray_fill<'ctx>(
|
||||||
this.fill(generator, context, value_arg);
|
this.fill(generator, context, value_arg);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.dot`.
|
|
||||||
/// Calculate inner product of two vectors or literals
|
|
||||||
/// For matrix multiplication use `np_matmul`
|
|
||||||
///
|
|
||||||
/// The input `NDArray` are flattened and treated as 1D
|
|
||||||
/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())`
|
|
||||||
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
x2: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "ndarray_dot";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let (_, x2) = x2;
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
match (x1, x2) {
|
|
||||||
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
|
||||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
|
||||||
|
|
||||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
|
||||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"shapes ({0}), ({1}) not aligned",
|
|
||||||
[Some(n1_sz), Some(n2_sz), None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
let identity =
|
|
||||||
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
|
||||||
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
|
|
||||||
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
None,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
(n1_sz, false),
|
|
||||||
|generator, ctx, _, idx| {
|
|
||||||
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
|
||||||
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
|
|
||||||
|
|
||||||
let product = match elem1 {
|
|
||||||
BasicValueEnum::IntValue(e1) => ctx
|
|
||||||
.builder
|
|
||||||
.build_int_mul(e1, elem2.into_int_value(), "")
|
|
||||||
.unwrap()
|
|
||||||
.as_basic_value_enum(),
|
|
||||||
BasicValueEnum::FloatValue(e1) => ctx
|
|
||||||
.builder
|
|
||||||
.build_float_mul(e1, elem2.into_float_value(), "")
|
|
||||||
.unwrap()
|
|
||||||
.as_basic_value_enum(),
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
|
||||||
let acc_val = match acc_val {
|
|
||||||
BasicValueEnum::IntValue(e1) => ctx
|
|
||||||
.builder
|
|
||||||
.build_int_add(e1, product.into_int_value(), "")
|
|
||||||
.unwrap()
|
|
||||||
.as_basic_value_enum(),
|
|
||||||
BasicValueEnum::FloatValue(e1) => ctx
|
|
||||||
.builder
|
|
||||||
.build_float_add(e1, product.into_float_value(), "")
|
|
||||||
.unwrap()
|
|
||||||
.as_basic_value_enum(),
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
ctx.builder.build_store(acc, acc_val).unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
llvm_usize.const_int(1, false),
|
|
||||||
)?;
|
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
|
||||||
Ok(acc_val)
|
|
||||||
}
|
|
||||||
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
|
||||||
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
|
||||||
}
|
|
||||||
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
|
|
||||||
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
|
||||||
}
|
|
||||||
_ => unreachable!(
|
|
||||||
"{FN_NAME}() not supported for '{}'",
|
|
||||||
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
use inkwell::{values::BasicValueEnum, IntPredicate};
|
||||||
|
|
||||||
|
use crate::codegen::{
|
||||||
|
object::ndarray::nditer::NDIterHandle, stmt::gen_for_callback, CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::NDArrayObject;
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Perform `np.dot()`.
|
||||||
|
///
|
||||||
|
/// Both ndarrays must be 1D and have the same type.
|
||||||
|
pub fn dot<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
a: NDArrayObject<'ctx>,
|
||||||
|
b: NDArrayObject<'ctx>,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
// TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html.
|
||||||
|
assert_eq!(a.ndims, 1);
|
||||||
|
assert_eq!(b.ndims, 1);
|
||||||
|
assert!(ctx.unifier.unioned(a.dtype, b.dtype));
|
||||||
|
let common_dtype = a.dtype;
|
||||||
|
|
||||||
|
// Check shapes.
|
||||||
|
let a_size = a.size(generator, ctx);
|
||||||
|
let b_size = b.size(generator, ctx);
|
||||||
|
let same_shape = a_size.compare(ctx, IntPredicate::EQ, b_size);
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
same_shape.value,
|
||||||
|
"0:ValueError",
|
||||||
|
"shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)",
|
||||||
|
[Some(a_size.value), Some(b_size.value), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let dtype_llvm = ctx.get_llvm_type(generator, common_dtype);
|
||||||
|
|
||||||
|
let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap();
|
||||||
|
ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap();
|
||||||
|
|
||||||
|
// Do dot product.
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
Some("np_dot"),
|
||||||
|
|generator, ctx| {
|
||||||
|
let a_iter = NDIterHandle::new(generator, ctx, a);
|
||||||
|
let b_iter = NDIterHandle::new(generator, ctx, b);
|
||||||
|
Ok((a_iter, b_iter))
|
||||||
|
},
|
||||||
|
|generator, ctx, (a_iter, _b_iter)| {
|
||||||
|
// Only a_iter drives the condition, b_iter should have the same status.
|
||||||
|
Ok(a_iter.has_next(generator, ctx).value)
|
||||||
|
},
|
||||||
|
|generator, ctx, _hooks, (a_iter, b_iter)| {
|
||||||
|
let a_scalar = a_iter.get_scalar(generator, ctx).value;
|
||||||
|
let b_scalar = b_iter.get_scalar(generator, ctx).value;
|
||||||
|
|
||||||
|
let old_result = ctx.builder.build_load(result, "").unwrap();
|
||||||
|
let new_result: BasicValueEnum<'ctx> = match old_result {
|
||||||
|
BasicValueEnum::IntValue(old_result) => {
|
||||||
|
let a_scalar = a_scalar.into_int_value();
|
||||||
|
let b_scalar = b_scalar.into_int_value();
|
||||||
|
let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap();
|
||||||
|
ctx.builder.build_int_add(old_result, x, "").unwrap().into()
|
||||||
|
}
|
||||||
|
BasicValueEnum::FloatValue(old_result) => {
|
||||||
|
let a_scalar = a_scalar.into_float_value();
|
||||||
|
let b_scalar = b_scalar.into_float_value();
|
||||||
|
let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap();
|
||||||
|
ctx.builder.build_float_add(old_result, x, "").unwrap().into()
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ctx.builder.build_store(result, new_result).unwrap();
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|generator, ctx, (a_iter, b_iter)| {
|
||||||
|
a_iter.next(generator, ctx);
|
||||||
|
b_iter.next(generator, ctx);
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.builder.build_load(result, "").unwrap()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod array;
|
pub mod array;
|
||||||
pub mod broadcast;
|
pub mod broadcast;
|
||||||
|
pub mod dot;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
pub mod map;
|
pub mod map;
|
||||||
|
|
|
@ -2076,10 +2076,17 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let x1_ty = fun.0.args[0].ty;
|
let x1_ty = fun.0.args[0].ty;
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
|
||||||
let x2_ty = fun.0.args[1].ty;
|
let x2_ty = fun.0.args[1].ty;
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
let x1 = AnyObject { ty: x1_ty, value: x1_val };
|
||||||
|
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||||
|
let x2 = AnyObject { ty: x2_ty, value: x2_val };
|
||||||
|
let x2 = NDArrayObject::from_object(generator, ctx, x2);
|
||||||
|
|
||||||
|
let result = NDArrayObject::dot(generator, ctx, x1, x2);
|
||||||
|
Ok(Some(result))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue