From b43f94c477669c281960a1fdfde725d3cdf809e9 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 18 Mar 2024 16:25:53 +0800 Subject: [PATCH] core: DO NOT MERGE - Complete assertion for calc_broadcast --- nac3core/src/codegen/expr.rs | 34 ++++++------ nac3core/src/codegen/irrt/mod.rs | 93 ++++++++++++++++++++++++++------ nac3core/src/codegen/mod.rs | 4 +- nac3core/src/codegen/stmt.rs | 8 +-- 4 files changed, 101 insertions(+), 38 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 4a173ea..eeeaa22 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -104,9 +104,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { index } - pub fn gen_symbol_val( + pub fn gen_symbol_val( &mut self, - generator: &mut dyn CodeGenerator, + generator: &mut G, val: &SymbolValue, ty: Type, ) -> BasicValueEnum<'ctx> { @@ -175,9 +175,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } /// See [`get_llvm_type`]. - pub fn get_llvm_type( + pub fn get_llvm_type( &mut self, - generator: &mut dyn CodeGenerator, + generator: &mut G, ty: Type, ) -> BasicTypeEnum<'ctx> { get_llvm_type( @@ -210,9 +210,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } /// Generates an LLVM variable for a [constant value][value] with a given [type][ty]. - pub fn gen_const( + pub fn gen_const( &mut self, - generator: &mut dyn CodeGenerator, + generator: &mut G, value: &Constant, ty: Type, ) -> Option> { @@ -493,17 +493,21 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } /// Helper function for generating a LLVM variable storing a [String]. - pub fn gen_string>( + pub fn gen_string( &mut self, - generator: &mut dyn CodeGenerator, + generator: &mut G, s: S, - ) -> BasicValueEnum<'ctx> { + ) -> BasicValueEnum<'ctx> + where + G: CodeGenerator + ?Sized, + S: Into, + { self.gen_const(generator, &Constant::Str(s.into()), self.primitives.str).unwrap() } - pub fn raise_exn( + pub fn raise_exn( &mut self, - generator: &mut dyn CodeGenerator, + generator: &mut G, name: &str, msg: BasicValueEnum<'ctx>, params: [Option>; 3], @@ -547,9 +551,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { gen_raise(generator, self, Some(&zelf.into()), loc); } - pub fn make_assert( + pub fn make_assert( &mut self, - generator: &mut dyn CodeGenerator, + generator: &mut G, cond: IntValue<'ctx>, err_name: &str, err_msg: &str, @@ -560,9 +564,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { self.make_assert_impl(generator, cond, err_name, err_msg, params, loc); } - pub fn make_assert_impl( + pub fn make_assert_impl( &mut self, - generator: &mut dyn CodeGenerator, + generator: &mut G, cond: IntValue<'ctx>, err_name: &str, err_msg: BasicValueEnum<'ctx>, diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 16c23df..543b574 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,6 +1,12 @@ use crate::typecheck::typedef::Type; -use super::{classes::{ListValue, NDArrayValue}, CodeGenContext, CodeGenerator, llvm_intrinsics}; +use super::{ + classes::{ListValue, NDArrayValue}, + CodeGenContext, + CodeGenerator, + llvm_intrinsics, + stmt::{gen_for_callback_incrementing, gen_if_callback}, +}; use inkwell::{ attributes::{Attribute, AttributeLoc}, context::Context, @@ -840,24 +846,77 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( let lhs_ndims = lhs.load_ndims(ctx); let rhs_ndims = rhs.load_ndims(ctx); + let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - // TODO: Generate assertion checks for whether each dimension is compatible - // gen_for_callback_incrementing( - // generator, - // ctx, - // llvm_usize.const_zero(), - // (max_ndims, false), - // |generator, ctx, idx| { - // let lhs_dim_sz = - // - // let lhs_elem = lhs.get_dims().get(ctx, generator, idx, None); - // let rhs_elem = rhs.get_dims().get(ctx, generator, idx, None); - // - // - // }, - // llvm_usize.const_int(1, false), - // ).unwrap(); + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_zero(), + (max_ndims, false), + |generator, ctx, idx| { + + gen_if_callback( + generator, + ctx, + |_, ctx| { + // Only compare the dimensions of the array with fewer dimensions, since any + // additional dimensions are implicitly broadcasted + let lhs_idx_geq_dim = ctx.builder + .build_int_compare(IntPredicate::UGE, idx, min_ndims, "") + .unwrap(); + let rhs_idx_geq_dim = ctx.builder + .build_int_compare(IntPredicate::UGE, idx, min_ndims, "") + .unwrap(); + + Ok(ctx.builder.build_and(lhs_idx_geq_dim, rhs_idx_geq_dim, "").unwrap()) + }, + |generator, ctx| { + let ri = ctx.builder + .build_int_sub(min_ndims, idx, "") + .unwrap(); + let (lhs_dim, rhs_dim) = unsafe { + ( + lhs.dim_sizes().get_unchecked(ctx, ri, None), + rhs.dim_sizes().get_unchecked(ctx, ri, None), + ) + }; + + let lhs_dim_ne_1 = ctx.builder + .build_int_compare(IntPredicate::NE, lhs_dim, lhs_dim.get_type().const_int(1, false), "") + .unwrap(); + let rhs_dim_ne_1 = ctx.builder + .build_int_compare(IntPredicate::NE, rhs_dim, rhs_dim.get_type().const_int(1, false), "") + .unwrap(); + let lhs_ne_rhs = ctx.builder + .build_int_compare(IntPredicate::NE, lhs_dim, rhs_dim, "") + .unwrap(); + + let both_dims_ne_1 = ctx.builder + .build_and(lhs_dim_ne_1, rhs_dim_ne_1, "") + .unwrap(); + let dims_not_broadcastable = ctx.builder + .build_and(both_dims_ne_1, lhs_ne_rhs, "") + .unwrap(); + + ctx.make_assert( + generator, + dims_not_broadcastable, + "0:ValueError", + "operands cannot be broadcast together", + [None, None, None], + ctx.current_loc, + ); + + Ok(()) + }, + None, + )?; + + Ok(()) + }, + llvm_usize.const_int(1, false), + ).unwrap(); let lhs_dims = lhs.dim_sizes().as_ptr_value(ctx); let lhs_ndims = lhs.load_ndims(ctx); diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index e075900..8919ae3 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -416,10 +416,10 @@ pub struct CodeGenTask { /// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable /// would be represented by an `i8`. #[allow(clippy::too_many_arguments)] -fn get_llvm_type<'ctx>( +fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx: &'ctx Context, module: &Module<'ctx>, - generator: &mut dyn CodeGenerator, + generator: &mut G, unifier: &mut Unifier, top_level: &TopLevelContext, type_cache: &mut HashMap>, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index b43508d..318dcd0 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -900,8 +900,8 @@ pub fn final_proxy<'ctx>( /// Inserts the declaration of the builtin function with the specified `symbol` name, and returns /// the function. -pub fn get_builtins<'ctx>( - generator: &mut dyn CodeGenerator, +pub fn get_builtins<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, symbol: &str, ) -> FunctionValue<'ctx> { @@ -994,8 +994,8 @@ pub fn exn_constructor<'ctx>( /// /// * `exception` - The exception thrown by the `raise` statement. /// * `loc` - The location where the exception is raised from. -pub fn gen_raise<'ctx>( - generator: &mut dyn CodeGenerator, +pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, exception: Option<&BasicValueEnum<'ctx>>, loc: Location,