diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 95248b05..0a83e624 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -5,7 +5,6 @@ use crate::{ ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, - expr::gen_binop_expr_with_values, irrt::{ calculate_len_for_slice_range, call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, @@ -26,21 +25,15 @@ use crate::{ numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, - typecheck::{ - magic_methods::Binop, - typedef::{FunSignature, Type}, - }, + typecheck::typedef::{FunSignature, Type}, }; +use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::{ types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, OptimizationLevel, + AddressSpace, IntPredicate, }; -use inkwell::{ - types::{AnyTypeEnum, BasicTypeEnum, PointerType}, - values::BasicValue, -}; -use nac3parser::ast::{Operator, StrRef}; +use nac3parser::ast::StrRef; /// Creates an uninitialized `NDArray` instance. fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( @@ -1692,102 +1685,3 @@ pub fn gen_ndarray_fill<'ctx>( this.fill(generator, context, value_arg); Ok(()) } - -/// Generates LLVM IR for `ndarray.dot`. -/// Calculate inner product of two vectors or literals -/// For matrix multiplication use `np_matmul` -/// -/// The input `NDArray` are flattened and treated as 1D -/// The operation is equivalent to `np.dot(arr1.ravel(), arr2.ravel())` -pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_dot"; - let (x1_ty, x1) = x1; - let (_, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - match (x1, x2) { - (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); - - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), - "0:ValueError", - "shapes ({0}), ({1}) not aligned", - [Some(n1_sz), Some(n2_sz), None], - ctx.current_loc, - ); - - let identity = - unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); - ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n1_sz, false), - |generator, ctx, _, idx| { - let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; - - let product = match elem1 { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_mul(e1, elem2.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_mul(e1, elem2.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => unreachable!(), - }; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - let acc_val = match acc_val { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_add(e1, product.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_add(e1, product.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => unreachable!(), - }; - ctx.builder.build_store(acc, acc_val).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - Ok(acc_val) - } - (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { - Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) - } - (BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => { - Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum()) - } - _ => unreachable!( - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ), - } -} diff --git a/nac3core/src/codegen/object/ndarray/dot.rs b/nac3core/src/codegen/object/ndarray/dot.rs new file mode 100644 index 00000000..aa1237b6 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/dot.rs @@ -0,0 +1,93 @@ +use inkwell::{values::BasicValueEnum, IntPredicate}; + +use crate::codegen::{ + object::ndarray::nditer::NDIterHandle, stmt::gen_for_callback, CodeGenContext, CodeGenerator, +}; + +use super::NDArrayObject; + +impl<'ctx> NDArrayObject<'ctx> { + /// Perform `np.dot()`. + /// + /// Both ndarrays must be 1D and have the same type. + pub fn dot( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: NDArrayObject<'ctx>, + b: NDArrayObject<'ctx>, + ) -> BasicValueEnum<'ctx> { + // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. + assert_eq!(a.ndims, 1); + assert_eq!(b.ndims, 1); + assert!(ctx.unifier.unioned(a.dtype, b.dtype)); + let common_dtype = a.dtype; + + // Check shapes. + let a_size = a.size(generator, ctx); + let b_size = b.size(generator, ctx); + let same_shape = a_size.compare(ctx, IntPredicate::EQ, b_size); + ctx.make_assert( + generator, + same_shape.value, + "0:ValueError", + "shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)", + [Some(a_size.value), Some(b_size.value), None], + ctx.current_loc, + ); + + let dtype_llvm = ctx.get_llvm_type(generator, common_dtype); + + let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap(); + ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap(); + + // Do dot product. + gen_for_callback( + generator, + ctx, + Some("np_dot"), + |generator, ctx| { + let a_iter = NDIterHandle::new(generator, ctx, a); + let b_iter = NDIterHandle::new(generator, ctx, b); + Ok((a_iter, b_iter)) + }, + |generator, ctx, (a_iter, _b_iter)| { + // Only a_iter drives the condition, b_iter should have the same status. + Ok(a_iter.has_next(generator, ctx).value) + }, + |generator, ctx, _hooks, (a_iter, b_iter)| { + let a_scalar = a_iter.get_scalar(generator, ctx).value; + let b_scalar = b_iter.get_scalar(generator, ctx).value; + + let old_result = ctx.builder.build_load(result, "").unwrap(); + let new_result: BasicValueEnum<'ctx> = match old_result { + BasicValueEnum::IntValue(old_result) => { + let a_scalar = a_scalar.into_int_value(); + let b_scalar = b_scalar.into_int_value(); + let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_int_add(old_result, x, "").unwrap().into() + } + BasicValueEnum::FloatValue(old_result) => { + let a_scalar = a_scalar.into_float_value(); + let b_scalar = b_scalar.into_float_value(); + let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_float_add(old_result, x, "").unwrap().into() + } + _ => { + panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype)); + } + }; + + ctx.builder.build_store(result, new_result).unwrap(); + Ok(()) + }, + |generator, ctx, (a_iter, b_iter)| { + a_iter.next(generator, ctx); + b_iter.next(generator, ctx); + Ok(()) + }, + ) + .unwrap(); + + ctx.builder.build_load(result, "").unwrap() + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 9fb3d220..6272f77a 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -1,5 +1,6 @@ pub mod array; pub mod broadcast; +pub mod dot; pub mod factory; pub mod indexing; pub mod map; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 67cfc456..c83c3adf 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -2076,10 +2076,17 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let x1 = AnyObject { ty: x1_ty, value: x1_val }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2_val }; + let x2 = NDArrayObject::from_object(generator, ctx, x2); + + let result = NDArrayObject::dot(generator, ctx, x1, x2); + Ok(Some(result)) }), ),