diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 64df45e..7524df1 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -16,6 +16,7 @@ use crate::{ get_llvm_abi_type, irrt::*, llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi}, + numpy, stmt::{gen_raise, gen_var}, CodeGenContext, CodeGenTask, }, @@ -23,7 +24,7 @@ use crate::{ toplevel::{ DefinitionId, helper::PRIMITIVE_DEF_IDS, - numpy::make_ndarray_ty, + numpy::{make_ndarray_ty, unpack_ndarray_tvars}, TopLevelDef, }, typecheck::{ @@ -1129,6 +1130,41 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Some("f_pow_i") ); Ok(Some(res.into())) + } else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) { + let llvm_usize = generator.get_size_type(ctx.ctx); + let (ndarray_dtype, _) = unpack_ndarray_tvars(&mut ctx.unifier, ty1); + + let left_val = NDArrayValue::from_ptr_val( + left_val.into_pointer_value(), + llvm_usize, + None + ); + let right_val = NDArrayValue::from_ptr_val( + right_val.into_pointer_value(), + llvm_usize, + None + ); + let res = numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + ndarray_dtype, + if is_aug_assign { Some(left_val) } else { None }, + left_val, + right_val, + |generator, ctx, elem_ty, (lhs, rhs)| { + gen_binop_expr_with_values( + generator, + ctx, + (&Some(elem_ty), lhs), + op, + (&Some(elem_ty), rhs), + ctx.current_loc, + is_aug_assign, + )?.unwrap().to_basic_value_enum(ctx, generator, elem_ty) + }, + )?; + + Ok(Some(res.as_ptr_value().into())) } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let TypeEnum::TObj { fields, obj_id, .. 
} = left_ty_enum.as_ref() else { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 083bc3a..2ce9e8d 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -344,6 +344,42 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>( ) } +/// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed +/// element from two other `NDArray` as its input. +fn ndarray_fill_zip_map_flattened<'ctx, G, ValueFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + res: NDArrayValue<'ctx>, + lhs: NDArrayValue<'ctx>, + rhs: NDArrayValue<'ctx>, + value_fn: ValueFn, +) -> Result<NDArrayValue<'ctx>, String> + where + G: CodeGenerator + ?Sized, + ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>, +{ + ndarray_fill_flattened( + generator, + ctx, + res, + |generator, ctx, idx| { + let elem = unsafe { + ( + lhs.data().get_unchecked(ctx, generator, idx, None), + rhs.data().get_unchecked(ctx, generator, idx, None), + ) + }; + + debug_assert_eq!(elem.0.get_type(), elem.1.get_type()); + + value_fn(generator, ctx, elem_ty, elem) + }, + )?; + + Ok(res) +} + /// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. /// /// * `elem_ty` - The element type of the `NDArray`. diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 61018a3..2e6e8a6 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -546,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( /// body(x); /// } /// ``` -/// +/// /// * `init_val` - The initial value of the loop variable. The type of this value will also be used /// as the type of the loop variable. /// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum