From b2994ff90a21eb8d430dc6b068b1528f54e9377d Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 19 Mar 2024 17:38:09 +0800 Subject: [PATCH] WIP --- nac3core/src/codegen/classes.rs | 2 +- nac3core/src/codegen/expr.rs | 9 ++- nac3core/src/codegen/irrt/irrt.c | 24 ++++++++ nac3core/src/codegen/irrt/mod.rs | 77 ++++++++++++++++++++++++- nac3core/src/codegen/numpy.rs | 23 ++++---- nac3core/src/toplevel/builtins.rs | 9 ++- nac3core/src/toplevel/helper.rs | 5 +- nac3core/src/typecheck/magic_methods.rs | 11 ++-- nac3standalone/demo/src/ndarray.py | 14 +++++ 9 files changed, 148 insertions(+), 26 deletions(-) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 39c8880..09612f7 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -11,7 +11,7 @@ use crate::codegen::{ stmt::gen_for_callback_incrementing, }; -/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of +/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of /// elements. pub trait ArrayLikeValue<'ctx> { /// Returns the element type of this array-like value. diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index e537d06..0a1f278 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1129,9 +1129,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Some("f_pow_i") ); Ok(Some(res.into())) - } else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) { + } else if matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) && matches!(&*ctx.unifier.get_ty(ty2), TypeEnum::TObj { obj_id, .. 
} if obj_id == &PRIMITIVE_DEF_IDS.ndarray) { let llvm_usize = generator.get_size_type(ctx.ctx); - let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); + + assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); let left_val = NDArrayValue::from_ptr_val( left_val.into_pointer_value(), @@ -1146,7 +1149,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let res = numpy::ndarray_elementwise_binop_impl( generator, ctx, - ndarray_dtype, + ndarray_dtype1, if is_aug_assign { Some(left_val) } else { None }, left_val, right_val, diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index 363c3c2..cf656b8 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -355,3 +355,27 @@ void __nac3_ndarray_calc_broadcast64( } } } + +void __nac3_ndarray_calc_broadcast_idx( + const uint32_t *src_dims, + uint32_t src_ndims, + const uint32_t *in_idx, + uint32_t *out_idx +) { + for (uint32_t i = 0; i < src_ndims; ++i) { + uint32_t src_i = src_ndims - i - 1; + out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; + } +} + +void __nac3_ndarray_calc_broadcast_idx64( + const uint64_t *src_dims, + uint64_t src_ndims, + const uint64_t *in_idx, + uint64_t *out_idx +) { + for (uint64_t i = 0; i < src_ndims; ++i) { + uint64_t src_i = src_ndims - i - 1; + out_idx[src_i] = src_dims[src_i] == 1 ? 
0 : in_idx[src_i]; + } +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index d2c9248..4ec94d1 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,7 +1,15 @@ use crate::typecheck::typedef::Type; use super::{ - classes::{ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, UntypedArrayLikeMutator}, + classes::{ + ArrayLikeIndexer, + ArraySliceValue, + ArrayLikeValue, + ListValue, + NDArrayValue, + UntypedArrayLikeAccessor, + UntypedArrayLikeMutator, + }, CodeGenContext, CodeGenerator, llvm_intrinsics, @@ -630,7 +638,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, index: IntValue<'ctx>, ndarray: NDArrayValue<'ctx>, -) -> PointerValue<'ctx> { +) -> ArraySliceValue<'ctx> { let llvm_void = ctx.ctx.void_type(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -677,7 +685,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( ) .unwrap(); - indices + ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None) } fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( @@ -889,4 +897,67 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( .unwrap(); (max_ndims, out_dims) +} + +/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArraySliceValue`] +/// containing the indices used for accessing `array` corresponding to the `broadcast_idx`. 
+pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + array: NDArrayValue<'ctx>, + broadcast_idx: &BroadcastIdx, +) -> ArraySliceValue<'ctx> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_broadcast_idx", + 64 => "__nac3_ndarray_calc_broadcast_idx64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = ctx.ctx.void_type().fn_type( + &[ + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + llvm_pusize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); + + // TODO: Assertions + + let broadcast_size = broadcast_idx.size(ctx, generator); + let out_idx = ctx.builder.build_array_alloca(llvm_usize, broadcast_size, "").unwrap(); + let out_idx = ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None); + + let array_dims = array.dim_sizes().base_ptr(ctx, generator); + let array_ndims = array.load_ndims(ctx); + let broadcast_idx_ptr = unsafe { + broadcast_idx.ptr_offset_unchecked( + ctx, + generator, + llvm_usize.const_zero(), + None + ) + }; + + ctx.builder + .build_call( + ndarray_calc_broadcast_fn, + &[ + array_dims.into(), + array_ndims.into(), + broadcast_idx_ptr.into(), + out_idx.base_ptr(ctx, generator).into(), + ], + "", + ) + .unwrap(); + + out_idx } \ No newline at end of file diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index e2c648f..8a7914c 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -8,6 +8,7 @@ use crate::{ codegen::{ classes::{ ArrayLikeIndexer, + ArraySliceValue, ArrayLikeValue, 
ListValue, NDArrayValue, @@ -325,7 +326,7 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>( ) -> Result<(), String> where G: CodeGenerator + ?Sized, - ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, PointerValue<'ctx>) -> Result, String>, + ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, ArraySliceValue<'ctx>) -> Result, String>, { ndarray_fill_flattened( generator, @@ -346,7 +347,7 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>( /// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed /// element from two other `NDArray` as its input. -fn ndarray_fill_zip_map_flattened<'ctx, G, ValueFn>( +fn ndarray_broadcast_fill_flattened<'ctx, G, ValueFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, @@ -535,16 +536,12 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( ctx, ndarray, |generator, ctx, indices| { - let row = ctx.build_gep_and_load( - indices, - &[llvm_usize.const_int(0, false)], - None, - ).into_int_value(); - let col = ctx.build_gep_and_load( - indices, - &[llvm_usize.const_int(1, false)], - None, - ).into_int_value(); + let (row, col) = unsafe { + ( + indices.get_unchecked(ctx, generator, llvm_usize.const_int(0, false), None).into_int_value(), + indices.get_unchecked(ctx, generator, llvm_usize.const_int(1, false), None).into_int_value(), + ) + }; let col_with_offset = ctx.builder .build_int_add( @@ -660,7 +657,7 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>( ).unwrap() }); - ndarray_fill_zip_map_flattened( + ndarray_broadcast_fill_flattened( generator, ctx, elem_ty, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index bd8e358..f446f43 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -299,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Some("N".into()), None, ); + let size_t = primitives.0.usize(); + let var_map: VarMap = vec![(num_ty.1, 
num_ty.0)].into_iter().collect(); let exception_fields = vec![ ("__name__".into(), int32, true), @@ -345,6 +347,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .nth(1) .map(|(var_id, ty)| (*ty, *var_id)) .unwrap(); + let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var( + size_t, + Some("ndarray_ndims".into()), + None, + ); let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap(); @@ -699,7 +706,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { name: "ndarray.__iadd__".into(), simple_name: "__iadd__".into(), signature: ndarray_iadd_ty.0, - var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index ea42b92..f00d754 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -285,8 +285,11 @@ impl TopLevelComposer { ]), }); + let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); + let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None); + unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap(); - unifier.unify(ndarray_binop_fun_other_ty.0, ndarray).unwrap(); + unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap(); unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap(); let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None); diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 53be568..c6ece9f 100644 --- 
a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -309,6 +309,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie ndarray: ndarray_t, .. } = *store; + let size_t = store.usize(); /* int ======== */ for t in [int32_t, int64_t, uint32_t, uint64_t] { @@ -345,9 +346,11 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* ndarray ===== */ let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None); - impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); - impl_pow(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); + let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); + let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0)); + impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t); + impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t); impl_div(unifier, store, ndarray_t, &[ndarray_t], ndarray_float_t); - impl_floordiv(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); - impl_mod(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); + impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t); + impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t); } diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 5d91541..65d47de 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -81,6 +81,20 @@ def test_ndarray_add(): output_float64(y[1][0]) output_float64(y[1][1]) +# def test_ndarray_add_broadcast(): +# x = np_identity(2) +# y: ndarray[float, 2] = x + np_ones([2]) +# +# output_float64(x[0][0]) +# output_float64(x[0][1]) +# output_float64(x[1][0]) +# output_float64(x[1][1]) +# +# output_float64(y[0][0]) +# output_float64(y[0][1]) +# output_float64(y[1][0]) +# output_float64(y[1][1]) + 
def test_ndarray_iadd(): x = np_identity(2) x += np_ones([2, 2])