core: DO NOT MERGE - Complete assertion for calc_broadcast

This commit is contained in:
David Mak 2024-03-18 16:25:53 +08:00
parent 7a7ca65f9d
commit b43f94c477
4 changed files with 101 additions and 38 deletions

View File

@ -104,9 +104,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
index
}
pub fn gen_symbol_val(
pub fn gen_symbol_val<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
val: &SymbolValue,
ty: Type,
) -> BasicValueEnum<'ctx> {
@ -175,9 +175,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
/// See [`get_llvm_type`].
pub fn get_llvm_type(
pub fn get_llvm_type<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
ty: Type,
) -> BasicTypeEnum<'ctx> {
get_llvm_type(
@ -210,9 +210,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
/// Generates an LLVM variable for a [constant value][value] with a given [type][ty].
pub fn gen_const(
pub fn gen_const<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
value: &Constant,
ty: Type,
) -> Option<BasicValueEnum<'ctx>> {
@ -493,17 +493,21 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
/// Helper function for generating a LLVM variable storing a [String].
pub fn gen_string<S: Into<String>>(
pub fn gen_string<G, S>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
s: S,
) -> BasicValueEnum<'ctx> {
) -> BasicValueEnum<'ctx>
where
G: CodeGenerator + ?Sized,
S: Into<String>,
{
self.gen_const(generator, &Constant::Str(s.into()), self.primitives.str).unwrap()
}
pub fn raise_exn(
pub fn raise_exn<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
name: &str,
msg: BasicValueEnum<'ctx>,
params: [Option<IntValue<'ctx>>; 3],
@ -547,9 +551,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
gen_raise(generator, self, Some(&zelf.into()), loc);
}
pub fn make_assert(
pub fn make_assert<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
cond: IntValue<'ctx>,
err_name: &str,
err_msg: &str,
@ -560,9 +564,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);
}
pub fn make_assert_impl(
pub fn make_assert_impl<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
cond: IntValue<'ctx>,
err_name: &str,
err_msg: BasicValueEnum<'ctx>,

View File

@ -1,6 +1,12 @@
use crate::typecheck::typedef::Type;
use super::{classes::{ListValue, NDArrayValue}, CodeGenContext, CodeGenerator, llvm_intrinsics};
use super::{
classes::{ListValue, NDArrayValue},
CodeGenContext,
CodeGenerator,
llvm_intrinsics,
stmt::{gen_for_callback_incrementing, gen_if_callback},
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
context::Context,
@ -840,24 +846,77 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
// TODO: Generate assertion checks for whether each dimension is compatible
// gen_for_callback_incrementing(
// generator,
// ctx,
// llvm_usize.const_zero(),
// (max_ndims, false),
// |generator, ctx, idx| {
// let lhs_dim_sz =
//
// let lhs_elem = lhs.get_dims().get(ctx, generator, idx, None);
// let rhs_elem = rhs.get_dims().get(ctx, generator, idx, None);
//
//
// },
// llvm_usize.const_int(1, false),
// ).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(max_ndims, false),
|generator, ctx, idx| {
gen_if_callback(
generator,
ctx,
|_, ctx| {
// Only compare the dimensions of the array with fewer dimensions, since any
// additional dimensions are implicitly broadcasted
let lhs_idx_geq_dim = ctx.builder
.build_int_compare(IntPredicate::UGE, idx, min_ndims, "")
.unwrap();
let rhs_idx_geq_dim = ctx.builder
.build_int_compare(IntPredicate::UGE, idx, min_ndims, "")
.unwrap();
Ok(ctx.builder.build_and(lhs_idx_geq_dim, rhs_idx_geq_dim, "").unwrap())
},
|generator, ctx| {
let ri = ctx.builder
.build_int_sub(min_ndims, idx, "")
.unwrap();
let (lhs_dim, rhs_dim) = unsafe {
(
lhs.dim_sizes().get_unchecked(ctx, ri, None),
rhs.dim_sizes().get_unchecked(ctx, ri, None),
)
};
let lhs_dim_ne_1 = ctx.builder
.build_int_compare(IntPredicate::NE, lhs_dim, lhs_dim.get_type().const_int(1, false), "")
.unwrap();
let rhs_dim_ne_1 = ctx.builder
.build_int_compare(IntPredicate::NE, rhs_dim, rhs_dim.get_type().const_int(1, false), "")
.unwrap();
let lhs_ne_rhs = ctx.builder
.build_int_compare(IntPredicate::NE, lhs_dim, rhs_dim, "")
.unwrap();
let both_dims_ne_1 = ctx.builder
.build_and(lhs_dim_ne_1, rhs_dim_ne_1, "")
.unwrap();
let dims_not_broadcastable = ctx.builder
.build_and(both_dims_ne_1, lhs_ne_rhs, "")
.unwrap();
ctx.make_assert(
generator,
dims_not_broadcastable,
"0:ValueError",
"operands cannot be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
None,
)?;
Ok(())
},
llvm_usize.const_int(1, false),
).unwrap();
let lhs_dims = lhs.dim_sizes().as_ptr_value(ctx);
let lhs_ndims = lhs.load_ndims(ctx);

View File

@ -416,10 +416,10 @@ pub struct CodeGenTask {
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
/// would be represented by an `i8`.
#[allow(clippy::too_many_arguments)]
fn get_llvm_type<'ctx>(
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context,
module: &Module<'ctx>,
generator: &mut dyn CodeGenerator,
generator: &mut G,
unifier: &mut Unifier,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,

View File

@ -900,8 +900,8 @@ pub fn final_proxy<'ctx>(
/// Inserts the declaration of the builtin function with the specified `symbol` name, and returns
/// the function.
pub fn get_builtins<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn get_builtins<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
symbol: &str,
) -> FunctionValue<'ctx> {
@ -994,8 +994,8 @@ pub fn exn_constructor<'ctx>(
///
/// * `exception` - The exception thrown by the `raise` statement.
/// * `loc` - The location where the exception is raised from.
pub fn gen_raise<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
exception: Option<&BasicValueEnum<'ctx>>,
loc: Location,