core: DO NOT MERGE - Complete assertion for calc_broadcast
This commit is contained in:
parent
7a7ca65f9d
commit
b43f94c477
|
@ -104,9 +104,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
index
|
||||
}
|
||||
|
||||
pub fn gen_symbol_val(
|
||||
pub fn gen_symbol_val<G: CodeGenerator + ?Sized>(
|
||||
&mut self,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
generator: &mut G,
|
||||
val: &SymbolValue,
|
||||
ty: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
|
@ -175,9 +175,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
}
|
||||
|
||||
/// See [`get_llvm_type`].
|
||||
pub fn get_llvm_type(
|
||||
pub fn get_llvm_type<G: CodeGenerator + ?Sized>(
|
||||
&mut self,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
generator: &mut G,
|
||||
ty: Type,
|
||||
) -> BasicTypeEnum<'ctx> {
|
||||
get_llvm_type(
|
||||
|
@ -210,9 +210,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
}
|
||||
|
||||
/// Generates an LLVM variable for a [constant value][value] with a given [type][ty].
|
||||
pub fn gen_const(
|
||||
pub fn gen_const<G: CodeGenerator + ?Sized>(
|
||||
&mut self,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
generator: &mut G,
|
||||
value: &Constant,
|
||||
ty: Type,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
|
@ -493,17 +493,21 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
}
|
||||
|
||||
/// Helper function for generating a LLVM variable storing a [String].
|
||||
pub fn gen_string<S: Into<String>>(
|
||||
pub fn gen_string<G, S>(
|
||||
&mut self,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
generator: &mut G,
|
||||
s: S,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
) -> BasicValueEnum<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
S: Into<String>,
|
||||
{
|
||||
self.gen_const(generator, &Constant::Str(s.into()), self.primitives.str).unwrap()
|
||||
}
|
||||
|
||||
pub fn raise_exn(
|
||||
pub fn raise_exn<G: CodeGenerator + ?Sized>(
|
||||
&mut self,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
generator: &mut G,
|
||||
name: &str,
|
||||
msg: BasicValueEnum<'ctx>,
|
||||
params: [Option<IntValue<'ctx>>; 3],
|
||||
|
@ -547,9 +551,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
gen_raise(generator, self, Some(&zelf.into()), loc);
|
||||
}
|
||||
|
||||
pub fn make_assert(
|
||||
pub fn make_assert<G: CodeGenerator + ?Sized>(
|
||||
&mut self,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
generator: &mut G,
|
||||
cond: IntValue<'ctx>,
|
||||
err_name: &str,
|
||||
err_msg: &str,
|
||||
|
@ -560,9 +564,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);
|
||||
}
|
||||
|
||||
pub fn make_assert_impl(
|
||||
pub fn make_assert_impl<G: CodeGenerator + ?Sized>(
|
||||
&mut self,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
generator: &mut G,
|
||||
cond: IntValue<'ctx>,
|
||||
err_name: &str,
|
||||
err_msg: BasicValueEnum<'ctx>,
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
use crate::typecheck::typedef::Type;
|
||||
|
||||
use super::{classes::{ListValue, NDArrayValue}, CodeGenContext, CodeGenerator, llvm_intrinsics};
|
||||
use super::{
|
||||
classes::{ListValue, NDArrayValue},
|
||||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
llvm_intrinsics,
|
||||
stmt::{gen_for_callback_incrementing, gen_if_callback},
|
||||
};
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
context::Context,
|
||||
|
@ -840,24 +846,77 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
|||
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
|
||||
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
||||
|
||||
// TODO: Generate assertion checks for whether each dimension is compatible
|
||||
// gen_for_callback_incrementing(
|
||||
// generator,
|
||||
// ctx,
|
||||
// llvm_usize.const_zero(),
|
||||
// (max_ndims, false),
|
||||
// |generator, ctx, idx| {
|
||||
// let lhs_dim_sz =
|
||||
//
|
||||
// let lhs_elem = lhs.get_dims().get(ctx, generator, idx, None);
|
||||
// let rhs_elem = rhs.get_dims().get(ctx, generator, idx, None);
|
||||
//
|
||||
//
|
||||
// },
|
||||
// llvm_usize.const_int(1, false),
|
||||
// ).unwrap();
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_usize.const_zero(),
|
||||
(max_ndims, false),
|
||||
|generator, ctx, idx| {
|
||||
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
// Only compare the dimensions of the array with fewer dimensions, since any
|
||||
// additional dimensions are implicitly broadcasted
|
||||
let lhs_idx_geq_dim = ctx.builder
|
||||
.build_int_compare(IntPredicate::UGE, idx, min_ndims, "")
|
||||
.unwrap();
|
||||
let rhs_idx_geq_dim = ctx.builder
|
||||
.build_int_compare(IntPredicate::UGE, idx, min_ndims, "")
|
||||
.unwrap();
|
||||
|
||||
Ok(ctx.builder.build_and(lhs_idx_geq_dim, rhs_idx_geq_dim, "").unwrap())
|
||||
},
|
||||
|generator, ctx| {
|
||||
let ri = ctx.builder
|
||||
.build_int_sub(min_ndims, idx, "")
|
||||
.unwrap();
|
||||
let (lhs_dim, rhs_dim) = unsafe {
|
||||
(
|
||||
lhs.dim_sizes().get_unchecked(ctx, ri, None),
|
||||
rhs.dim_sizes().get_unchecked(ctx, ri, None),
|
||||
)
|
||||
};
|
||||
|
||||
let lhs_dim_ne_1 = ctx.builder
|
||||
.build_int_compare(IntPredicate::NE, lhs_dim, lhs_dim.get_type().const_int(1, false), "")
|
||||
.unwrap();
|
||||
let rhs_dim_ne_1 = ctx.builder
|
||||
.build_int_compare(IntPredicate::NE, rhs_dim, rhs_dim.get_type().const_int(1, false), "")
|
||||
.unwrap();
|
||||
let lhs_ne_rhs = ctx.builder
|
||||
.build_int_compare(IntPredicate::NE, lhs_dim, rhs_dim, "")
|
||||
.unwrap();
|
||||
|
||||
let both_dims_ne_1 = ctx.builder
|
||||
.build_and(lhs_dim_ne_1, rhs_dim_ne_1, "")
|
||||
.unwrap();
|
||||
let dims_not_broadcastable = ctx.builder
|
||||
.build_and(both_dims_ne_1, lhs_ne_rhs, "")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
dims_not_broadcastable,
|
||||
"0:ValueError",
|
||||
"operands cannot be broadcast together",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
None,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
).unwrap();
|
||||
|
||||
let lhs_dims = lhs.dim_sizes().as_ptr_value(ctx);
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
|
|
|
@ -416,10 +416,10 @@ pub struct CodeGenTask {
|
|||
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
|
||||
/// would be represented by an `i8`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn get_llvm_type<'ctx>(
|
||||
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx: &'ctx Context,
|
||||
module: &Module<'ctx>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
generator: &mut G,
|
||||
unifier: &mut Unifier,
|
||||
top_level: &TopLevelContext,
|
||||
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||
|
|
|
@ -900,8 +900,8 @@ pub fn final_proxy<'ctx>(
|
|||
|
||||
/// Inserts the declaration of the builtin function with the specified `symbol` name, and returns
|
||||
/// the function.
|
||||
pub fn get_builtins<'ctx>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
pub fn get_builtins<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
symbol: &str,
|
||||
) -> FunctionValue<'ctx> {
|
||||
|
@ -994,8 +994,8 @@ pub fn exn_constructor<'ctx>(
|
|||
///
|
||||
/// * `exception` - The exception thrown by the `raise` statement.
|
||||
/// * `loc` - The location where the exception is raised from.
|
||||
pub fn gen_raise<'ctx>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
exception: Option<&BasicValueEnum<'ctx>>,
|
||||
loc: Location,
|
||||
|
|
Loading…
Reference in New Issue