Compare commits
6 Commits
847d19077e
...
a920fe0501
Author | SHA1 | Date |
---|---|---|
David Mak | a920fe0501 | |
David Mak | 727a1886b3 | |
David Mak | 6af13a8261 | |
David Mak | 3540d0ab29 | |
David Mak | 3a6c53d760 | |
David Mak | 87bc34f7ec |
|
@ -946,18 +946,6 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
|
|||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let array_ndims = array.load_ndims(ctx);
|
||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"operands cannot be broadcast together",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
||||
|
||||
|
|
|
@ -365,6 +365,27 @@ fn ndarray_fill_mapping<'ctx, G, MapFn>(
|
|||
)
|
||||
}
|
||||
|
||||
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
||||
/// the target `ndarray`.
|
||||
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
target: NDArrayValue<'ctx>,
|
||||
source: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let array_ndims = source.load_ndims(ctx);
|
||||
let broadcast_size = target.load_ndims(ctx);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"operands cannot be broadcast together",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
|
||||
/// with broadcast-compatible shapes.
|
||||
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
|
||||
|
@ -389,6 +410,17 @@ fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
|
|||
lhs_val.get_type(),
|
||||
rhs_val.get_type());
|
||||
|
||||
// Assert that all ndarray operands are broadcastable to the target size
|
||||
if !lhs_scalar {
|
||||
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
|
||||
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
|
||||
}
|
||||
|
||||
if !rhs_scalar {
|
||||
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
|
||||
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
|
||||
}
|
||||
|
||||
ndarray_fill_indexed(
|
||||
generator,
|
||||
ctx,
|
||||
|
|
Loading…
Reference in New Issue