Compare commits

..

6 Commits

Author SHA1 Message Date
David Mak a920fe0501 core: Implement elementwise comparison operators 2024-04-03 00:07:33 +08:00
David Mak 727a1886b3 core: Implement elementwise unary operators 2024-04-03 00:07:33 +08:00
David Mak 6af13a8261 core: Implement elementwise binary operators
Including immediate variants of these operators.
2024-04-03 00:07:33 +08:00
David Mak 3540d0ab29 core/magic_methods: Add typeof_*op
Used to determine the expected type of the binary operator with
primitive operands.
2024-04-03 00:07:33 +08:00
David Mak 3a6c53d760 core/toplevel/numpy: Split ndarray type var utilities 2024-04-03 00:07:33 +08:00
David Mak 87bc34f7ec core: Implement calculations for broadcasting ndarrays 2024-04-03 00:07:31 +08:00
2 changed files with 32 additions and 12 deletions

View File

@ -946,18 +946,6 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let array_ndims = array.load_ndims(ctx);
let broadcast_size = broadcast_idx.size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
"0:ValueError",
"operands cannot be broadcast together",
[None, None, None],
ctx.current_loc,
);
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();

View File

@ -365,6 +365,27 @@ fn ndarray_fill_mapping<'ctx, G, MapFn>(
)
}
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
/// the target `ndarray`.
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target: NDArrayValue<'ctx>,
source: NDArrayValue<'ctx>,
) {
let array_ndims = source.load_ndims(ctx);
let broadcast_size = target.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
"0:ValueError",
"operands cannot be broadcast together",
[None, None, None],
ctx.current_loc,
);
}
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
/// with broadcast-compatible shapes.
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
@ -389,6 +410,17 @@ fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
lhs_val.get_type(),
rhs_val.get_type());
// Assert that all ndarray operands are broadcastable to the target size
if !lhs_scalar {
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
}
if !rhs_scalar {
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
}
ndarray_fill_indexed(
generator,
ctx,