forked from M-Labs/nac3
core: Implement numpy.matmul for 2D-2D ndarrays
This commit is contained in:
parent
5dfcc63978
commit
847615fc2f
@ -384,7 +384,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||||||
rhs: BasicValueEnum<'ctx>,
|
rhs: BasicValueEnum<'ctx>,
|
||||||
) -> BasicValueEnum<'ctx> {
|
) -> BasicValueEnum<'ctx> {
|
||||||
let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else {
|
let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else {
|
||||||
unreachable!()
|
unreachable!("Expected (FloatValue, FloatValue), got ({}, {})", lhs.get_type(), rhs.get_type())
|
||||||
};
|
};
|
||||||
match op {
|
match op {
|
||||||
Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").map(Into::into).unwrap(),
|
Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").map(Into::into).unwrap(),
|
||||||
@ -589,8 +589,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||||||
// even if this assumption is violated, it does not matter as exception unwinding is
|
// even if this assumption is violated, it does not matter as exception unwinding is
|
||||||
// slow anyway...
|
// slow anyway...
|
||||||
let cond = call_expect(self, cond, i1_true, Some("expect"));
|
let cond = call_expect(self, cond, i1_true, Some("expect"));
|
||||||
let current_fun = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
let current_bb = self.builder.get_insert_block().unwrap();
|
||||||
let then_block = self.ctx.append_basic_block(current_fun, "succ");
|
let current_fun = current_bb.get_parent().unwrap();
|
||||||
|
let then_block = self.ctx.insert_basic_block_after(current_bb, "succ");
|
||||||
let exn_block = self.ctx.append_basic_block(current_fun, "fail");
|
let exn_block = self.ctx.append_basic_block(current_fun, "fail");
|
||||||
self.builder.build_conditional_branch(cond, then_block, exn_block).unwrap();
|
self.builder.build_conditional_branch(cond, then_block, exn_block).unwrap();
|
||||||
self.builder.position_at_end(exn_block);
|
self.builder.position_at_end(exn_block);
|
||||||
@ -1148,27 +1149,45 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
let left_val = NDArrayValue::from_ptr_val(
|
let left_val = NDArrayValue::from_ptr_val(
|
||||||
left_val.into_pointer_value(),
|
left_val.into_pointer_value(),
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
None
|
None,
|
||||||
);
|
);
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
let right_val = NDArrayValue::from_ptr_val(
|
||||||
generator,
|
right_val.into_pointer_value(),
|
||||||
ctx,
|
llvm_usize,
|
||||||
ndarray_dtype1,
|
None,
|
||||||
if is_aug_assign { Some(left_val) } else { None },
|
);
|
||||||
(left_val.as_ptr_value().into(), false),
|
|
||||||
(right_val, false),
|
let res = if *op == Operator::MatMult {
|
||||||
|generator, ctx, (lhs, rhs)| {
|
// MatMult is the only binop which is not an elementwise op
|
||||||
gen_binop_expr_with_values(
|
numpy::ndarray_matmul_2d(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(&Some(ndarray_dtype1), lhs),
|
ndarray_dtype1,
|
||||||
op,
|
if is_aug_assign { Some(left_val) } else { None },
|
||||||
(&Some(ndarray_dtype2), rhs),
|
left_val,
|
||||||
ctx.current_loc,
|
right_val,
|
||||||
is_aug_assign,
|
)?
|
||||||
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1)
|
} else {
|
||||||
},
|
numpy::ndarray_elementwise_binop_impl(
|
||||||
)?;
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray_dtype1,
|
||||||
|
if is_aug_assign { Some(left_val) } else { None },
|
||||||
|
(left_val.as_ptr_value().into(), false),
|
||||||
|
(right_val.as_ptr_value().into(), false),
|
||||||
|
|generator, ctx, (lhs, rhs)| {
|
||||||
|
gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(ndarray_dtype1), lhs),
|
||||||
|
op,
|
||||||
|
(&Some(ndarray_dtype2), rhs),
|
||||||
|
ctx.current_loc,
|
||||||
|
is_aug_assign,
|
||||||
|
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1)
|
||||||
|
},
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Some(res.as_ptr_value().into()))
|
Ok(Some(res.as_ptr_value().into()))
|
||||||
} else {
|
} else {
|
||||||
|
@ -1,9 +1,5 @@
|
|||||||
use inkwell::{
|
use inkwell::{IntPredicate, OptimizationLevel, types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}};
|
||||||
IntPredicate,
|
use nac3parser::ast::{Operator, StrRef};
|
||||||
types::BasicType,
|
|
||||||
values::{BasicValueEnum, IntValue, PointerValue}
|
|
||||||
};
|
|
||||||
use nac3parser::ast::StrRef;
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
classes::{
|
classes::{
|
||||||
@ -14,17 +10,20 @@ use crate::{
|
|||||||
TypedArrayLikeAccessor,
|
TypedArrayLikeAccessor,
|
||||||
TypedArrayLikeAdapter,
|
TypedArrayLikeAdapter,
|
||||||
UntypedArrayLikeAccessor,
|
UntypedArrayLikeAccessor,
|
||||||
|
UntypedArrayLikeMutator,
|
||||||
},
|
},
|
||||||
CodeGenContext,
|
CodeGenContext,
|
||||||
CodeGenerator,
|
CodeGenerator,
|
||||||
|
expr::gen_binop_expr_with_values,
|
||||||
irrt::{
|
irrt::{
|
||||||
call_ndarray_calc_broadcast,
|
call_ndarray_calc_broadcast,
|
||||||
call_ndarray_calc_broadcast_index,
|
call_ndarray_calc_broadcast_index,
|
||||||
call_ndarray_calc_nd_indices,
|
call_ndarray_calc_nd_indices,
|
||||||
call_ndarray_calc_size,
|
call_ndarray_calc_size,
|
||||||
},
|
},
|
||||||
llvm_intrinsics::call_memcpy_generic,
|
llvm_intrinsics,
|
||||||
stmt::gen_for_callback_incrementing,
|
llvm_intrinsics::{call_memcpy_generic},
|
||||||
|
stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback},
|
||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
@ -86,6 +85,8 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// TODO: Disallow dim_sz > u32_MAX
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
llvm_usize.const_int(1, false),
|
llvm_usize.const_int(1, false),
|
||||||
@ -171,6 +172,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
[None, None, None],
|
[None, None, None],
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// TODO: Disallow dim_sz > u32_MAX
|
||||||
}
|
}
|
||||||
|
|
||||||
let ndarray = generator.gen_var_alloc(
|
let ndarray = generator.gen_var_alloc(
|
||||||
@ -824,6 +827,319 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
|||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
||||||
|
/// written to a new `ndarray`.
|
||||||
|
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
res: Option<NDArrayValue<'ctx>>,
|
||||||
|
lhs: NDArrayValue<'ctx>,
|
||||||
|
rhs: NDArrayValue<'ctx>,
|
||||||
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
let lhs_ndims = lhs.load_ndims(ctx);
|
||||||
|
let rhs_ndims = rhs.load_ndims(ctx);
|
||||||
|
|
||||||
|
// lhs.ndims == 2
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
lhs_ndims,
|
||||||
|
llvm_usize.const_int(2, false),
|
||||||
|
"",
|
||||||
|
).unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
// rhs.ndims == 2
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
rhs_ndims,
|
||||||
|
llvm_usize.const_int(2, false),
|
||||||
|
"",
|
||||||
|
).unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(res) = res {
|
||||||
|
let res_ndims = res.load_ndims(ctx);
|
||||||
|
let res_dim0 = unsafe {
|
||||||
|
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
};
|
||||||
|
let res_dim1 = unsafe {
|
||||||
|
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
};
|
||||||
|
let lhs_dim0 = unsafe {
|
||||||
|
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
};
|
||||||
|
let rhs_dim1 = unsafe {
|
||||||
|
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
};
|
||||||
|
|
||||||
|
// res.ndims == 2
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
res_ndims,
|
||||||
|
llvm_usize.const_int(2, false),
|
||||||
|
"",
|
||||||
|
).unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
// res.dims[0] == lhs.dims[0]
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
lhs_dim0,
|
||||||
|
res_dim0,
|
||||||
|
"",
|
||||||
|
).unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
// res.dims[1] == rhs.dims[0]
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
rhs_dim1,
|
||||||
|
res_dim1,
|
||||||
|
"",
|
||||||
|
).unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
|
let lhs_dim1 = unsafe {
|
||||||
|
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
};
|
||||||
|
let rhs_dim0 = unsafe {
|
||||||
|
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||||
|
};
|
||||||
|
|
||||||
|
// lhs.dims[1] == rhs.dims[0]
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
lhs_dim1,
|
||||||
|
rhs_dim0,
|
||||||
|
"",
|
||||||
|
).unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let lhs = if res.is_some_and(|res| res.as_ptr_value() == lhs.as_ptr_value()) {
|
||||||
|
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
|
||||||
|
} else {
|
||||||
|
lhs
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndarray = res.unwrap_or_else(|| {
|
||||||
|
create_ndarray_dyn_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&(lhs, rhs),
|
||||||
|
|_, _, _| {
|
||||||
|
Ok(llvm_usize.const_int(2, false))
|
||||||
|
},
|
||||||
|
|generator, ctx, (lhs, rhs), idx| {
|
||||||
|
gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
idx,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
).unwrap())
|
||||||
|
},
|
||||||
|
|generator, ctx| {
|
||||||
|
Ok(Some(unsafe {
|
||||||
|
lhs.dim_sizes().get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_zero(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
|generator, ctx| {
|
||||||
|
Ok(Some(unsafe {
|
||||||
|
rhs.dim_sizes().get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(1, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
).map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
|
||||||
|
},
|
||||||
|
).unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
|
ndarray_fill_indexed(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray,
|
||||||
|
|generator, ctx, idx| {
|
||||||
|
llvm_intrinsics::call_expect(
|
||||||
|
ctx,
|
||||||
|
idx.size(ctx, generator).get_type().const_int(2, false),
|
||||||
|
idx.size(ctx, generator),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
let common_dim = {
|
||||||
|
let lhs_idx1 = unsafe {
|
||||||
|
lhs.dim_sizes().get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(1, false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let rhs_idx0 = unsafe {
|
||||||
|
rhs.dim_sizes().get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_zero(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
||||||
|
|
||||||
|
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
let idx0 = unsafe {
|
||||||
|
let idx0 = idx.get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_zero(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
|
||||||
|
};
|
||||||
|
let idx1 = unsafe {
|
||||||
|
let idx1 = idx.get_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(1, false),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||||
|
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
|
||||||
|
ctx.builder.build_store(result_addr, result_identity).unwrap();
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_i32.const_zero(),
|
||||||
|
(common_dim, false),
|
||||||
|
|generator, ctx, i| {
|
||||||
|
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
||||||
|
|
||||||
|
let ab_idx = generator.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_i32.into(),
|
||||||
|
llvm_usize.const_int(2, false),
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let a = unsafe {
|
||||||
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
|
||||||
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
|
||||||
|
|
||||||
|
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
||||||
|
};
|
||||||
|
let b = unsafe {
|
||||||
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
|
||||||
|
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), idx1.into());
|
||||||
|
|
||||||
|
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
let a_mul_b = gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(elem_ty), a),
|
||||||
|
&Operator::Mult,
|
||||||
|
(&Some(elem_ty), b),
|
||||||
|
ctx.current_loc,
|
||||||
|
false,
|
||||||
|
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||||
|
|
||||||
|
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
||||||
|
let result = gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(elem_ty), result),
|
||||||
|
&Operator::Add,
|
||||||
|
(&Some(elem_ty), a_mul_b),
|
||||||
|
ctx.current_loc,
|
||||||
|
false,
|
||||||
|
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||||
|
ctx.builder.build_store(result_addr, result).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.empty`.
|
/// Generates LLVM IR for `ndarray.empty`.
|
||||||
pub fn gen_ndarray_empty<'ctx>(
|
pub fn gen_ndarray_empty<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, '_>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -495,14 +495,14 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
|||||||
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||||
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||||
{
|
{
|
||||||
let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
|
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||||
let init_bb = ctx.ctx.append_basic_block(current, "for.init");
|
let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init");
|
||||||
// The BB containing the loop condition check
|
// The BB containing the loop condition check
|
||||||
let cond_bb = ctx.ctx.append_basic_block(current, "for.cond");
|
let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, "for.cond");
|
||||||
let body_bb = ctx.ctx.append_basic_block(current, "for.body");
|
let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, "for.body");
|
||||||
// The BB containing the increment expression
|
// The BB containing the increment expression
|
||||||
let update_bb = ctx.ctx.append_basic_block(current, "for.update");
|
let update_bb = ctx.ctx.insert_basic_block_after(body_bb, "for.update");
|
||||||
let cont_bb = ctx.ctx.append_basic_block(current, "for.end");
|
let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, "for.end");
|
||||||
|
|
||||||
// store loop bb information and restore it later
|
// store loop bb information and restore it later
|
||||||
let loop_bb = ctx.loop_target.replace((update_bb, cont_bb));
|
let loop_bb = ctx.loop_target.replace((update_bb, cont_bb));
|
||||||
@ -719,12 +719,10 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
|
|||||||
R: BasicValue<'ctx>,
|
R: BasicValue<'ctx>,
|
||||||
{
|
{
|
||||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||||
let current_fn = current_bb.get_parent().unwrap();
|
|
||||||
|
|
||||||
let end_bb = ctx.ctx.append_basic_block(current_fn, "if.end");
|
|
||||||
|
|
||||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then");
|
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then");
|
||||||
let else_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.else");
|
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "if.else");
|
||||||
|
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "if.end");
|
||||||
|
|
||||||
let cond = cond_fn(generator, ctx)?;
|
let cond = cond_fn(generator, ctx)?;
|
||||||
assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
|
assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
|
||||||
@ -742,6 +740,7 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
|
|||||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(end_bb);
|
||||||
let phi = match (then_val, else_val) {
|
let phi = match (then_val, else_val) {
|
||||||
(Some(tv), Some(ev)) => {
|
(Some(tv), Some(ev)) => {
|
||||||
let tv_ty = tv.as_basic_value_enum().get_type();
|
let tv_ty = tv.as_basic_value_enum().get_type();
|
||||||
|
@ -291,6 +291,17 @@ pub fn impl_mod(
|
|||||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
|
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// [Operator::MatMult]
|
||||||
|
pub fn impl_matmul(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
|
ty: Type,
|
||||||
|
other_ty: &[Type],
|
||||||
|
ret_ty: Option<Type>,
|
||||||
|
) {
|
||||||
|
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult])
|
||||||
|
}
|
||||||
|
|
||||||
/// `UAdd`, `USub`
|
/// `UAdd`, `USub`
|
||||||
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
|
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
|
||||||
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
|
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
|
||||||
@ -431,7 +442,38 @@ pub fn typeof_binop(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator::MatMult => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
Operator::MatMult => {
|
||||||
|
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||||
|
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
||||||
|
TypeEnum::TLiteral { values, .. } => {
|
||||||
|
assert_eq!(values.len(), 1);
|
||||||
|
u64::try_from(values[0].clone()).unwrap()
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
||||||
|
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
||||||
|
TypeEnum::TLiteral { values, .. } => {
|
||||||
|
assert_eq!(values.len(), 1);
|
||||||
|
u64::try_from(values[0].clone()).unwrap()
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
match (lhs_ndims, rhs_ndims) {
|
||||||
|
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
||||||
|
(lhs, rhs) if lhs == 0 || rhs == 0 => {
|
||||||
|
return Err(format!(
|
||||||
|
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
|
||||||
|
(rhs == 0) as u8
|
||||||
|
))
|
||||||
|
}
|
||||||
|
(lhs, rhs) => {
|
||||||
|
return Err(format!("ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Operator::Div => {
|
Operator::Div => {
|
||||||
if is_left_ndarray || is_right_ndarray {
|
if is_left_ndarray || is_right_ndarray {
|
||||||
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||||
@ -610,6 +652,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||||||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
|
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
|
||||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||||
|
@ -429,6 +429,19 @@ def test_ndarray_ipow_broadcast_scalar():
|
|||||||
|
|
||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
|
|
||||||
|
def test_ndarray_matmul():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = x @ np_ones([2, 2])
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_float_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_imatmul():
|
||||||
|
x = np_identity(2)
|
||||||
|
x @= np_ones([2, 2])
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
|
||||||
def test_ndarray_pos():
|
def test_ndarray_pos():
|
||||||
x_int32 = np_full([2, 2], -2)
|
x_int32 = np_full([2, 2], -2)
|
||||||
y_int32 = +x_int32
|
y_int32 = +x_int32
|
||||||
@ -696,6 +709,8 @@ def run() -> int32:
|
|||||||
test_ndarray_ipow()
|
test_ndarray_ipow()
|
||||||
test_ndarray_ipow_broadcast()
|
test_ndarray_ipow_broadcast()
|
||||||
test_ndarray_ipow_broadcast_scalar()
|
test_ndarray_ipow_broadcast_scalar()
|
||||||
|
test_ndarray_matmul()
|
||||||
|
test_ndarray_imatmul()
|
||||||
test_ndarray_pos()
|
test_ndarray_pos()
|
||||||
test_ndarray_neg()
|
test_ndarray_neg()
|
||||||
test_ndarray_inv()
|
test_ndarray_inv()
|
||||||
|
Loading…
Reference in New Issue
Block a user