Compare commits
6 Commits
847d19077e
...
a920fe0501
Author | SHA1 | Date |
---|---|---|
David Mak | a920fe0501 | |
David Mak | 727a1886b3 | |
David Mak | 6af13a8261 | |
David Mak | 3540d0ab29 | |
David Mak | 3a6c53d760 | |
David Mak | 87bc34f7ec |
|
@ -946,18 +946,6 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
|
||||||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||||
});
|
});
|
||||||
|
|
||||||
let array_ndims = array.load_ndims(ctx);
|
|
||||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
|
||||||
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
|
|
||||||
"0:ValueError",
|
|
||||||
"operands cannot be broadcast together",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||||
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
||||||
|
|
||||||
|
|
|
@ -365,6 +365,27 @@ fn ndarray_fill_mapping<'ctx, G, MapFn>(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
||||||
|
/// the target `ndarray`.
|
||||||
|
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
target: NDArrayValue<'ctx>,
|
||||||
|
source: NDArrayValue<'ctx>,
|
||||||
|
) {
|
||||||
|
let array_ndims = source.load_ndims(ctx);
|
||||||
|
let broadcast_size = target.load_ndims(ctx);
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"operands cannot be broadcast together",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
|
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
|
||||||
/// with broadcast-compatible shapes.
|
/// with broadcast-compatible shapes.
|
||||||
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
|
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
|
||||||
|
@ -389,6 +410,17 @@ fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
|
||||||
lhs_val.get_type(),
|
lhs_val.get_type(),
|
||||||
rhs_val.get_type());
|
rhs_val.get_type());
|
||||||
|
|
||||||
|
// Assert that all ndarray operands are broadcastable to the target size
|
||||||
|
if !lhs_scalar {
|
||||||
|
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
|
||||||
|
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rhs_scalar {
|
||||||
|
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
|
||||||
|
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
|
||||||
|
}
|
||||||
|
|
||||||
ndarray_fill_indexed(
|
ndarray_fill_indexed(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|
|
Loading…
Reference in New Issue