ndarray: Implement 2D-2D matrix multiplication #398
|
@ -58,7 +58,7 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> {
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx>;
|
||||
|
||||
|
@ -67,7 +67,7 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> {
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx>;
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndex
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_offset_unchecked(ctx, generator, idx, name);
|
||||
|
@ -93,7 +93,7 @@ pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndex
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_offset(ctx, generator, idx, name);
|
||||
|
@ -110,7 +110,7 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndexe
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
let ptr = self.ptr_offset_unchecked(ctx, generator, idx, None);
|
||||
|
@ -122,7 +122,7 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndexe
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
let ptr = self.ptr_offset(ctx, generator, idx, None);
|
||||
|
@ -142,7 +142,7 @@ pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayL
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> T {
|
||||
let value = self.get_unchecked(ctx, generator, idx, name);
|
||||
|
@ -154,7 +154,7 @@ pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayL
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> T {
|
||||
let value = self.get(ctx, generator, idx, name);
|
||||
|
@ -174,7 +174,7 @@ pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayLi
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
value: T,
|
||||
) {
|
||||
let value = self.upcast_from_type(ctx, value);
|
||||
|
@ -186,7 +186,7 @@ pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayLi
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
value: T,
|
||||
) {
|
||||
let value = self.upcast_from_type(ctx, value);
|
||||
|
@ -255,7 +255,7 @@ impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> for TypedArrayLikeAd
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
self.adapted.ptr_offset_unchecked(ctx, generator, idx, name)
|
||||
|
@ -265,7 +265,7 @@ impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> for TypedArrayLikeAd
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: Index,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
self.adapted.ptr_offset(ctx, generator, idx, name)
|
||||
|
@ -345,7 +345,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: IntValue<'ctx>,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = name
|
||||
|
@ -354,7 +354,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
|
|||
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.base_ptr(ctx, generator),
|
||||
&[idx],
|
||||
&[*idx],
|
||||
var_name.as_str(),
|
||||
).unwrap()
|
||||
}
|
||||
|
@ -363,13 +363,13 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: IntValue<'ctx>,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
|
||||
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, idx, size, "").unwrap();
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
in_range,
|
||||
|
@ -573,7 +573,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: IntValue<'ctx>,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = name
|
||||
|
@ -582,7 +582,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
|
|||
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.base_ptr(ctx, generator),
|
||||
&[idx],
|
||||
&[*idx],
|
||||
var_name.as_str(),
|
||||
).unwrap()
|
||||
}
|
||||
|
@ -591,13 +591,13 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: IntValue<'ctx>,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
|
||||
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, idx, size, "").unwrap();
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
in_range,
|
||||
|
@ -1015,7 +1015,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: IntValue<'ctx>,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = name
|
||||
|
@ -1024,7 +1024,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
|
|||
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.base_ptr(ctx, generator),
|
||||
&[idx],
|
||||
&[*idx],
|
||||
var_name.as_str(),
|
||||
).unwrap()
|
||||
}
|
||||
|
@ -1033,13 +1033,13 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: IntValue<'ctx>,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(
|
||||
IntPredicate::ULT,
|
||||
idx,
|
||||
*idx,
|
||||
size,
|
||||
""
|
||||
).unwrap();
|
||||
|
@ -1048,7 +1048,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
|
|||
in_range,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds for axis 0 with size {1}",
|
||||
[Some(idx), Some(self.0.load_ndims(ctx)), None],
|
||||
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
|
@ -1120,12 +1120,12 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: IntValue<'ctx>,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.base_ptr(ctx, generator),
|
||||
&[idx],
|
||||
&[*idx],
|
||||
name.unwrap_or_default(),
|
||||
).unwrap()
|
||||
}
|
||||
|
@ -1134,13 +1134,13 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: IntValue<'ctx>,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let data_sz = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(
|
||||
IntPredicate::ULT,
|
||||
idx,
|
||||
*idx,
|
||||
data_sz,
|
||||
""
|
||||
).unwrap();
|
||||
|
@ -1149,7 +1149,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
|||
in_range,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds with size {1}",
|
||||
[Some(idx), Some(self.0.load_ndims(ctx)), None],
|
||||
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
|
@ -1167,12 +1167,15 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
indices: Index,
|
||||
indices: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let indices_elem_ty = indices.ptr_offset(ctx, generator, llvm_usize.const_zero(), None).get_type().get_element_type();
|
||||
let indices_elem_ty = indices
|
||||
.ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.get_type()
|
||||
.get_element_type();
|
||||
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||
panic!("Expected list[int32] but got {indices_elem_ty}")
|
||||
};
|
||||
|
@ -1182,7 +1185,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
|||
generator,
|
||||
ctx,
|
||||
*self.0,
|
||||
&indices,
|
||||
indices,
|
||||
);
|
||||
|
||||
unsafe {
|
||||
|
@ -1198,7 +1201,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
|||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
indices: Index,
|
||||
indices: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
@ -1230,8 +1233,8 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
|||
|generator, ctx, i| {
|
||||
let (dim_idx, dim_sz) = unsafe {
|
||||
(
|
||||
indices.get_unchecked(ctx, generator, i, None).into_int_value(),
|
||||
self.0.dim_sizes().get_typed_unchecked(ctx, generator, i, None),
|
||||
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
|
||||
self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None),
|
||||
)
|
||||
};
|
||||
let dim_idx = ctx.builder
|
||||
|
|
|
@ -384,7 +384,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
rhs: BasicValueEnum<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else {
|
||||
unreachable!()
|
||||
unreachable!("Expected (FloatValue, FloatValue), got ({}, {})", lhs.get_type(), rhs.get_type())
|
||||
};
|
||||
match op {
|
||||
Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").map(Into::into).unwrap(),
|
||||
|
@ -589,8 +589,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
// even if this assumption is violated, it does not matter as exception unwinding is
|
||||
// slow anyway...
|
||||
let cond = call_expect(self, cond, i1_true, Some("expect"));
|
||||
let current_fun = self.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let then_block = self.ctx.append_basic_block(current_fun, "succ");
|
||||
let current_bb = self.builder.get_insert_block().unwrap();
|
||||
let current_fun = current_bb.get_parent().unwrap();
|
||||
let then_block = self.ctx.insert_basic_block_after(current_bb, "succ");
|
||||
let exn_block = self.ctx.append_basic_block(current_fun, "fail");
|
||||
self.builder.build_conditional_branch(cond, then_block, exn_block).unwrap();
|
||||
self.builder.position_at_end(exn_block);
|
||||
|
@ -1148,27 +1149,45 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
let left_val = NDArrayValue::from_ptr_val(
|
||||
left_val.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None
|
||||
None,
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype1,
|
||||
if is_aug_assign { Some(left_val) } else { None },
|
||||
(left_val.as_ptr_value().into(), false),
|
||||
(right_val, false),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(ndarray_dtype1), lhs),
|
||||
op,
|
||||
(&Some(ndarray_dtype2), rhs),
|
||||
ctx.current_loc,
|
||||
is_aug_assign,
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1)
|
||||
},
|
||||
)?;
|
||||
let right_val = NDArrayValue::from_ptr_val(
|
||||
right_val.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
||||
let res = if *op == Operator::MatMult {
|
||||
// MatMult is the only binop which is not an elementwise op
|
||||
numpy::ndarray_matmul_2d(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype1,
|
||||
if is_aug_assign { Some(left_val) } else { None },
|
||||
left_val,
|
||||
right_val,
|
||||
)?
|
||||
} else {
|
||||
numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype1,
|
||||
if is_aug_assign { Some(left_val) } else { None },
|
||||
(left_val.as_ptr_value().into(), false),
|
||||
(right_val.as_ptr_value().into(), false),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(ndarray_dtype1), lhs),
|
||||
op,
|
||||
(&Some(ndarray_dtype2), rhs),
|
||||
ctx.current_loc,
|
||||
is_aug_assign,
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1)
|
||||
},
|
||||
)?
|
||||
};
|
||||
|
||||
Ok(Some(res.as_ptr_value().into()))
|
||||
} else {
|
||||
|
@ -1719,7 +1738,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||
v.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
llvm_usize.const_zero(),
|
||||
&llvm_usize.const_zero(),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
@ -1746,7 +1765,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||
.get(
|
||||
ctx,
|
||||
generator,
|
||||
ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
|
||||
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
|
||||
None,
|
||||
)
|
||||
.into()))
|
||||
|
@ -1781,7 +1800,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||
v.dim_sizes().ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
llvm_usize.const_int(1, false),
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
@ -1806,7 +1825,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||
let v_data_src_ptr = v.data().ptr_offset(
|
||||
ctx,
|
||||
generator,
|
||||
ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
|
||||
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
|
||||
None
|
||||
);
|
||||
call_memcpy_generic(
|
||||
|
@ -1906,7 +1925,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
let arr_ptr = arr_str_ptr.data();
|
||||
for (i, v) in elements.iter().enumerate() {
|
||||
let elem_ptr = arr_ptr
|
||||
.ptr_offset(ctx, generator, usize.const_int(i as u64, false), Some("elem_ptr"));
|
||||
.ptr_offset(ctx, generator, &usize.const_int(i as u64, false), Some("elem_ptr"));
|
||||
ctx.builder.build_store(elem_ptr, *v).unwrap();
|
||||
}
|
||||
arr_str_ptr.as_ptr_value().into()
|
||||
|
@ -2324,7 +2343,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
[Some(raw_index), Some(len), None],
|
||||
expr.location,
|
||||
);
|
||||
v.data().get(ctx, generator, index, None).into()
|
||||
v.data().get(ctx, generator, &index, None).into()
|
||||
}
|
||||
}
|
||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||
|
|
|
@ -833,8 +833,8 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
|||
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
||||
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
||||
(
|
||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, idx, None),
|
||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, idx, None),
|
||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
||||
)
|
||||
};
|
||||
|
||||
|
@ -955,7 +955,7 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
|
|||
broadcast_idx.ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
llvm_usize.const_zero(),
|
||||
&llvm_usize.const_zero(),
|
||||
None
|
||||
)
|
||||
};
|
||||
|
|
|
@ -1,9 +1,5 @@
|
|||
use inkwell::{
|
||||
IntPredicate,
|
||||
types::BasicType,
|
||||
values::{BasicValueEnum, IntValue, PointerValue}
|
||||
};
|
||||
use nac3parser::ast::StrRef;
|
||||
use inkwell::{IntPredicate, OptimizationLevel, types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}};
|
||||
use nac3parser::ast::{Operator, StrRef};
|
||||
use crate::{
|
||||
codegen::{
|
||||
classes::{
|
||||
|
@ -14,17 +10,20 @@ use crate::{
|
|||
TypedArrayLikeAccessor,
|
||||
TypedArrayLikeAdapter,
|
||||
UntypedArrayLikeAccessor,
|
||||
UntypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
expr::gen_binop_expr_with_values,
|
||||
irrt::{
|
||||
call_ndarray_calc_broadcast,
|
||||
call_ndarray_calc_broadcast_index,
|
||||
call_ndarray_calc_nd_indices,
|
||||
call_ndarray_calc_size,
|
||||
},
|
||||
llvm_intrinsics::call_memcpy_generic,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
llvm_intrinsics,
|
||||
llvm_intrinsics::{call_memcpy_generic},
|
||||
stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback},
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{
|
||||
|
@ -85,6 +84,8 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// TODO: Disallow dim_sz > u32_MAX
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|
@ -119,7 +120,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
|
|||
.unwrap();
|
||||
|
||||
let ndarray_pdim = unsafe {
|
||||
ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, i, None)
|
||||
ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None)
|
||||
};
|
||||
|
||||
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
|
||||
|
@ -171,6 +172,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// TODO: Disallow dim_sz > u32_MAX
|
||||
}
|
||||
|
||||
let ndarray = generator.gen_var_alloc(
|
||||
|
@ -190,7 +193,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||
let ndarray_dim = unsafe {
|
||||
ndarray
|
||||
.dim_sizes()
|
||||
.ptr_offset_unchecked(ctx, generator, llvm_usize.const_int(i as u64, true), None)
|
||||
.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_int(i as u64, true), None)
|
||||
};
|
||||
|
||||
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
|
||||
|
@ -267,7 +270,7 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||
Ok(shape.load_size(ctx, None))
|
||||
},
|
||||
|generator, ctx, shape, idx| {
|
||||
Ok(shape.data().get(ctx, generator, idx, None).into_int_value())
|
||||
Ok(shape.data().get(ctx, generator, &idx, None).into_int_value())
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -299,7 +302,7 @@ fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
|
|||
(ndarray_num_elems, false),
|
||||
|generator, ctx, i| {
|
||||
let elem = unsafe {
|
||||
ndarray.data().ptr_offset_unchecked(ctx, generator, i, None)
|
||||
ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None)
|
||||
};
|
||||
|
||||
let value = value_fn(generator, ctx, i)?;
|
||||
|
@ -321,7 +324,7 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
|
|||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, &TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
|
@ -335,7 +338,7 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
|
|||
ndarray,
|
||||
);
|
||||
|
||||
value_fn(generator, ctx, indices)
|
||||
value_fn(generator, ctx, &indices)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
@ -357,7 +360,7 @@ fn ndarray_fill_mapping<'ctx, G, MapFn>(
|
|||
dest,
|
||||
|generator, ctx, i| {
|
||||
let elem = unsafe {
|
||||
src.data().get_unchecked(ctx, generator, i, None)
|
||||
src.data().get_unchecked(ctx, generator, &i, None)
|
||||
};
|
||||
|
||||
map_fn(generator, ctx, elem)
|
||||
|
@ -430,10 +433,10 @@ fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
|
|||
lhs_val
|
||||
} else {
|
||||
let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
|
||||
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, &idx);
|
||||
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
||||
|
||||
unsafe {
|
||||
lhs.data().get_unchecked(ctx, generator, lhs_idx, None)
|
||||
lhs.data().get_unchecked(ctx, generator, &lhs_idx, None)
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -441,10 +444,10 @@ fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
|
|||
rhs_val
|
||||
} else {
|
||||
let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
|
||||
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, &idx);
|
||||
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
||||
|
||||
unsafe {
|
||||
rhs.data().get_unchecked(ctx, generator, rhs_idx, None)
|
||||
rhs.data().get_unchecked(ctx, generator, &rhs_idx, None)
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -604,8 +607,8 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||
|generator, ctx, indices| {
|
||||
let (row, col) = unsafe {
|
||||
(
|
||||
indices.get_typed_unchecked(ctx, generator, llvm_usize.const_zero(), None),
|
||||
indices.get_typed_unchecked(ctx, generator, llvm_usize.const_int(1, false), None),
|
||||
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None),
|
||||
indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None),
|
||||
)
|
||||
};
|
||||
|
||||
|
@ -652,7 +655,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||
Ok(shape.load_ndims(ctx))
|
||||
},
|
||||
|generator, ctx, shape, idx| {
|
||||
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, idx, None)) }
|
||||
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }
|
||||
},
|
||||
)?;
|
||||
|
||||
|
@ -704,7 +707,7 @@ pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>(
|
|||
},
|
||||
|generator, ctx, v, idx| {
|
||||
unsafe {
|
||||
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None))
|
||||
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None))
|
||||
}
|
||||
},
|
||||
).unwrap()
|
||||
|
@ -782,7 +785,7 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
|||
},
|
||||
|generator, ctx, v, idx| {
|
||||
unsafe {
|
||||
Ok(v.get_typed_unchecked(ctx, generator, idx, None))
|
||||
Ok(v.get_typed_unchecked(ctx, generator, &idx, None))
|
||||
}
|
||||
},
|
||||
).unwrap()
|
||||
|
@ -803,7 +806,7 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
|||
},
|
||||
|generator, ctx, v, idx| {
|
||||
unsafe {
|
||||
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None))
|
||||
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None))
|
||||
}
|
||||
},
|
||||
).unwrap()
|
||||
|
@ -824,6 +827,319 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
|||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
||||
/// written to a new `ndarray`.
|
||||
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
res: Option<NDArrayValue<'ctx>>,
|
||||
lhs: NDArrayValue<'ctx>,
|
||||
rhs: NDArrayValue<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
|
||||
// lhs.ndims == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
lhs_ndims,
|
||||
llvm_usize.const_int(2, false),
|
||||
"",
|
||||
).unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// rhs.ndims == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
rhs_ndims,
|
||||
llvm_usize.const_int(2, false),
|
||||
"",
|
||||
).unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
if let Some(res) = res {
|
||||
let res_ndims = res.load_ndims(ctx);
|
||||
let res_dim0 = unsafe {
|
||||
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
let res_dim1 = unsafe {
|
||||
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
};
|
||||
let lhs_dim0 = unsafe {
|
||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
let rhs_dim1 = unsafe {
|
||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
};
|
||||
|
||||
// res.ndims == 2
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
res_ndims,
|
||||
llvm_usize.const_int(2, false),
|
||||
"",
|
||||
).unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// res.dims[0] == lhs.dims[0]
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
lhs_dim0,
|
||||
res_dim0,
|
||||
"",
|
||||
).unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// res.dims[1] == rhs.dims[0]
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
rhs_dim1,
|
||||
res_dim1,
|
||||
"",
|
||||
).unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let lhs_dim1 = unsafe {
|
||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
};
|
||||
let rhs_dim0 = unsafe {
|
||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
|
||||
// lhs.dims[1] == rhs.dims[0]
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
lhs_dim1,
|
||||
rhs_dim0,
|
||||
"",
|
||||
).unwrap(),
|
||||
"0:ValueError",
|
||||
"",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let lhs = if res.is_some_and(|res| res.as_ptr_value() == lhs.as_ptr_value()) {
|
||||
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
|
||||
} else {
|
||||
lhs
|
||||
};
|
||||
|
||||
let ndarray = res.unwrap_or_else(|| {
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&(lhs, rhs),
|
||||
|_, _, _| {
|
||||
Ok(llvm_usize.const_int(2, false))
|
||||
},
|
||||
|generator, ctx, (lhs, rhs), idx| {
|
||||
gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
idx,
|
||||
llvm_usize.const_zero(),
|
||||
"",
|
||||
).unwrap())
|
||||
},
|
||||
|generator, ctx| {
|
||||
Ok(Some(unsafe {
|
||||
lhs.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
None,
|
||||
)
|
||||
}))
|
||||
},
|
||||
|generator, ctx| {
|
||||
Ok(Some(unsafe {
|
||||
rhs.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
}))
|
||||
},
|
||||
).map(|v| v.map(BasicValueEnum::into_int_value).unwrap())
|
||||
},
|
||||
).unwrap()
|
||||
});
|
||||
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
ndarray_fill_indexed(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
|generator, ctx, idx| {
|
||||
llvm_intrinsics::call_expect(
|
||||
ctx,
|
||||
idx.size(ctx, generator).get_type().const_int(2, false),
|
||||
idx.size(ctx, generator),
|
||||
None,
|
||||
);
|
||||
|
||||
let common_dim = {
|
||||
let lhs_idx1 = unsafe {
|
||||
lhs.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
let rhs_idx0 = unsafe {
|
||||
rhs.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
||||
|
||||
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
|
||||
};
|
||||
|
||||
let idx0 = unsafe {
|
||||
let idx0 = idx.get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
None,
|
||||
);
|
||||
|
||||
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
|
||||
};
|
||||
let idx1 = unsafe {
|
||||
let idx1 = idx.get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
);
|
||||
|
||||
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
|
||||
};
|
||||
|
||||
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||
let result_identity = ndarray_zero_value(generator, ctx, elem_ty);
|
||||
ctx.builder.build_store(result_addr, result_identity).unwrap();
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_i32.const_zero(),
|
||||
(common_dim, false),
|
||||
|generator, ctx, i| {
|
||||
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
||||
|
||||
let ab_idx = generator.gen_array_var_alloc(
|
||||
ctx,
|
||||
llvm_i32.into(),
|
||||
llvm_usize.const_int(2, false),
|
||||
None,
|
||||
)?;
|
||||
|
||||
let a = unsafe {
|
||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into());
|
||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into());
|
||||
|
||||
lhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
||||
};
|
||||
let b = unsafe {
|
||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into());
|
||||
ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), idx1.into());
|
||||
|
||||
rhs.data().get_unchecked(ctx, generator, &ab_idx, None)
|
||||
};
|
||||
|
||||
let a_mul_b = gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(elem_ty), a),
|
||||
&Operator::Mult,
|
||||
(&Some(elem_ty), b),
|
||||
ctx.current_loc,
|
||||
false,
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||
|
||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
||||
let result = gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(elem_ty), result),
|
||||
&Operator::Add,
|
||||
(&Some(elem_ty), a_mul_b),
|
||||
ctx.current_loc,
|
||||
false,
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||
ctx.builder.build_store(result_addr, result).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let result = ctx.builder.build_load(result_addr, "").unwrap();
|
||||
Ok(result)
|
||||
}
|
||||
)?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.empty`.
|
||||
pub fn gen_ndarray_empty<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
|
|
|
@ -29,7 +29,6 @@ use nac3parser::ast::{
|
|||
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
|
||||
};
|
||||
use std::convert::TryFrom;
|
||||
use itertools::Itertools;
|
||||
|
||||
/// See [`CodeGenerator::gen_var_alloc`].
|
||||
pub fn gen_var<'ctx>(
|
||||
|
@ -190,7 +189,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
|||
[Some(raw_index), Some(len), None],
|
||||
slice.location,
|
||||
);
|
||||
v.data().ptr_offset(ctx, generator, index, name)
|
||||
v.data().ptr_offset(ctx, generator, &index, name)
|
||||
}
|
||||
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||
|
@ -496,14 +495,14 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
|||
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||
{
|
||||
let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
|
||||
let init_bb = ctx.ctx.append_basic_block(current, "for.init");
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init");
|
||||
// The BB containing the loop condition check
|
||||
let cond_bb = ctx.ctx.append_basic_block(current, "for.cond");
|
||||
let body_bb = ctx.ctx.append_basic_block(current, "for.body");
|
||||
let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, "for.cond");
|
||||
let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, "for.body");
|
||||
// The BB containing the increment expression
|
||||
let update_bb = ctx.ctx.append_basic_block(current, "for.update");
|
||||
let cont_bb = ctx.ctx.append_basic_block(current, "for.end");
|
||||
let update_bb = ctx.ctx.insert_basic_block_after(body_bb, "for.update");
|
||||
let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, "for.end");
|
||||
|
||||
// store loop bb information and restore it later
|
||||
let loop_bb = ctx.loop_target.replace((update_bb, cont_bb));
|
||||
|
@ -695,173 +694,6 @@ pub fn gen_while<G: CodeGenerator>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates a C-style chained-`if` construct using lambdas, similar to the following C code:
|
||||
///
|
||||
/// ```c
|
||||
/// T val;
|
||||
/// if (ifts[0].cond()) {
|
||||
/// val = ifts[0].then();
|
||||
/// } else if (ifts[1].cond()) {
|
||||
/// val = ifts[1].then();
|
||||
/// } else if /* ... */
|
||||
/// else {
|
||||
/// if (else_fn) {
|
||||
/// val = else_fn();
|
||||
/// } else {
|
||||
/// __builtin_unreachable();
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// - `ifts` - A slice of tuples containing the condition and body of a branch respectively. The
|
||||
/// branches will be generated in the order as appears in the slice.
|
||||
/// - `else_fn` - The body to generate if no other branches evaluates to `true`. If [`None`], a call
|
||||
/// to `__builtin_unreachable` will be generated instead.
|
||||
pub fn gen_chained_if_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ifts: &[(CondFn, ThenFn)],
|
||||
else_fn: Option<ElseFn>,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
CondFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
|
||||
ThenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
|
||||
ElseFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
|
||||
R: BasicValue<'ctx>,
|
||||
{
|
||||
assert!(!ifts.is_empty());
|
||||
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let current_fn = current_bb.get_parent().unwrap();
|
||||
|
||||
let end_bb = ctx.ctx.append_basic_block(current_fn, "if.end");
|
||||
|
||||
let vals = {
|
||||
let mut vals = ifts.iter()
|
||||
.map(|(cond, then)| -> Result<_, String> {
|
||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then");
|
||||
let else_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.else");
|
||||
|
||||
let cond = cond(generator, ctx)?;
|
||||
assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
|
||||
ctx.builder.build_conditional_branch(cond, then_bb, else_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(then_bb);
|
||||
let val = then(generator, ctx)?;
|
||||
|
||||
if !ctx.is_terminated() {
|
||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||
}
|
||||
|
||||
ctx.builder.position_at_end(else_bb);
|
||||
|
||||
Ok((val, then_bb))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
if let Some(else_fn) = else_fn {
|
||||
let else_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let else_val = else_fn(generator, ctx)?;
|
||||
vals.push((else_val, else_bb));
|
||||
|
||||
if !ctx.is_terminated() {
|
||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||
}
|
||||
} else {
|
||||
ctx.builder.build_unreachable().unwrap();
|
||||
}
|
||||
|
||||
vals
|
||||
};
|
||||
|
||||
ctx.builder.position_at_end(end_bb);
|
||||
let phi = if vals.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let llvm_val_ty = vals.iter()
|
||||
.filter_map(|(val, _)| val.as_ref().map(|v| v.as_basic_value_enum().get_type()))
|
||||
.reduce(|acc, ty| {
|
||||
assert_eq!(acc, ty);
|
||||
acc
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let phi = ctx.builder.build_phi(llvm_val_ty, "").unwrap();
|
||||
vals.into_iter()
|
||||
.filter_map(|(val, bb)| val.map(|v| (v, bb)))
|
||||
.for_each(|(val, bb)| phi.add_incoming(&[(&val.as_basic_value_enum(), bb)]));
|
||||
|
||||
Some(phi.as_basic_value())
|
||||
};
|
||||
|
||||
Ok(phi)
|
||||
}
|
||||
|
||||
/// Generates a C-style chained-`if` construct using lambdas, similar to the following C code:
|
||||
///
|
||||
/// ```c
|
||||
/// if (ifts[0].cond()) {
|
||||
/// ifts[0].then();
|
||||
/// } else if (ifts[1].cond()) {
|
||||
/// ifts[1].then();
|
||||
/// } else if /* ... */
|
||||
/// else {
|
||||
/// if (else_fn) {
|
||||
/// else_fn();
|
||||
/// } else {
|
||||
/// __builtin_unreachable();
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// This function mainly serves as an abstraction over [`gen_chained_if_expr_callback`] when a value
|
||||
/// does not need to be returned from the `if` construct.
|
||||
///
|
||||
/// - `ifts` - A slice of tuples containing the condition and body of a branch respectively. The
|
||||
/// branches will be generated in the order as appears in the slice.
|
||||
/// - `else_fn` - The body to generate if no other branches evaluates to `true`. If [`None`], a call
|
||||
/// to `__builtin_unreachable` will be generated instead.
|
||||
pub fn gen_chained_if_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ifts: &[(CondFn, ThenFn)],
|
||||
else_fn: &Option<ElseFn>,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
CondFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
|
||||
ThenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
|
||||
ElseFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
|
||||
{
|
||||
let res = gen_chained_if_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
ifts.iter()
|
||||
.map(|(cond, then)| {
|
||||
(
|
||||
cond,
|
||||
|generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>| {
|
||||
then(generator, ctx)?;
|
||||
Ok(None::<BasicValueEnum<'ctx>>)
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect_vec()
|
||||
.as_slice(),
|
||||
else_fn
|
||||
.as_ref()
|
||||
.map(|else_fn| |generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>| {
|
||||
else_fn(generator, ctx)?;
|
||||
Ok(None)
|
||||
}),
|
||||
)?;
|
||||
|
||||
assert!(res.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates a C-style chained-`if` construct using lambdas, similar to the following C code:
|
||||
///
|
||||
/// ```c
|
||||
|
@ -872,9 +704,6 @@ pub fn gen_chained_if_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn>(
|
|||
/// val = else_fn();
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// This function mainly serves as an abstraction over [`gen_chained_if_expr_callback`] for a basic
|
||||
/// `if`-`else` construct that returns a value.
|
||||
pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
|
@ -884,17 +713,50 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
|
|||
) -> Result<Option<BasicValueEnum<'ctx>>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
CondFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
|
||||
ThenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
|
||||
ElseFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
|
||||
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
|
||||
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
|
||||
ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
|
||||
R: BasicValue<'ctx>,
|
||||
{
|
||||
gen_chained_if_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
&[(cond_fn, then_fn)],
|
||||
Some(else_fn),
|
||||
)
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
|
||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then");
|
||||
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "if.else");
|
||||
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "if.end");
|
||||
|
||||
let cond = cond_fn(generator, ctx)?;
|
||||
assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
|
||||
ctx.builder.build_conditional_branch(cond, then_bb, else_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(then_bb);
|
||||
let then_val = then_fn(generator, ctx)?;
|
||||
if !ctx.is_terminated() {
|
||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||
}
|
||||
|
||||
ctx.builder.position_at_end(else_bb);
|
||||
let else_val = else_fn(generator, ctx)?;
|
||||
if !ctx.is_terminated() {
|
||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||
}
|
||||
|
||||
ctx.builder.position_at_end(end_bb);
|
||||
let phi = match (then_val, else_val) {
|
||||
(Some(tv), Some(ev)) => {
|
||||
let tv_ty = tv.as_basic_value_enum().get_type();
|
||||
assert_eq!(tv_ty, ev.as_basic_value_enum().get_type());
|
||||
|
||||
let phi = ctx.builder.build_phi(tv_ty, "").unwrap();
|
||||
phi.add_incoming(&[(&tv, then_bb), (&ev, else_bb)]);
|
||||
|
||||
Some(phi.as_basic_value())
|
||||
},
|
||||
(Some(tv), None) => Some(tv.as_basic_value_enum()),
|
||||
(None, Some(ev)) => Some(ev.as_basic_value_enum()),
|
||||
(None, None) => None,
|
||||
};
|
||||
|
||||
Ok(phi)
|
||||
}
|
||||
|
||||
/// Generates a C-style chained-`if` construct using lambdas, similar to the following C code:
|
||||
|
@ -903,33 +765,37 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
|
|||
/// if (cond_fn()) {
|
||||
/// then_fn();
|
||||
/// } else {
|
||||
/// if (else_fn) {
|
||||
/// else_fn();
|
||||
/// }
|
||||
/// else_fn();
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// This function mainly serves as an abstraction over [`gen_chained_if_expr_callback`] for a basic
|
||||
/// `if`-`else` construct that does not return a value.
|
||||
pub fn gen_if_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
cond_fn: CondFn,
|
||||
then_fn: ThenFn,
|
||||
else_fn: &Option<ElseFn>,
|
||||
else_fn: ElseFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
CondFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
|
||||
ThenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
|
||||
ElseFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
|
||||
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
|
||||
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
|
||||
ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
|
||||
{
|
||||
gen_chained_if_callback(
|
||||
gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
&[(cond_fn, then_fn)],
|
||||
else_fn,
|
||||
)
|
||||
cond_fn,
|
||||
|generator, ctx| {
|
||||
then_fn(generator, ctx)?;
|
||||
Ok(None::<BasicValueEnum<'ctx>>)
|
||||
},
|
||||
|generator, ctx| {
|
||||
else_fn(generator, ctx)?;
|
||||
Ok(None)
|
||||
}
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_if`].
|
||||
|
|
|
@ -1730,7 +1730,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
arg.dim_sizes().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
llvm_usize.const_zero(),
|
||||
&llvm_usize.const_zero(),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
|
|
@ -291,6 +291,17 @@ pub fn impl_mod(
|
|||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
|
||||
}
|
||||
|
||||
/// [Operator::MatMult]
|
||||
pub fn impl_matmul(
|
||||
unifier: &mut Unifier,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult])
|
||||
}
|
||||
|
||||
/// `UAdd`, `USub`
|
||||
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
|
||||
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
|
||||
|
@ -431,7 +442,38 @@ pub fn typeof_binop(
|
|||
}
|
||||
}
|
||||
|
||||
Operator::MatMult => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
||||
Operator::MatMult => {
|
||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
assert_eq!(values.len(), 1);
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
||||
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
assert_eq!(values.len(), 1);
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
match (lhs_ndims, rhs_ndims) {
|
||||
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
|
||||
(lhs, rhs) if lhs == 0 || rhs == 0 => {
|
||||
return Err(format!(
|
||||
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
|
||||
(rhs == 0) as u8
|
||||
))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
return Err(format!("ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Operator::Div => {
|
||||
if is_left_ndarray || is_right_ndarray {
|
||||
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
|
||||
|
@ -610,6 +652,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
|
||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
|
|
|
@ -429,6 +429,19 @@ def test_ndarray_ipow_broadcast_scalar():
|
|||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_matmul():
|
||||
x = np_identity(2)
|
||||
y = x @ np_ones([2, 2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_imatmul():
|
||||
x = np_identity(2)
|
||||
x @= np_ones([2, 2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_pos():
|
||||
x_int32 = np_full([2, 2], -2)
|
||||
y_int32 = +x_int32
|
||||
|
@ -696,6 +709,8 @@ def run() -> int32:
|
|||
test_ndarray_ipow()
|
||||
test_ndarray_ipow_broadcast()
|
||||
test_ndarray_ipow_broadcast_scalar()
|
||||
test_ndarray_matmul()
|
||||
test_ndarray_imatmul()
|
||||
test_ndarray_pos()
|
||||
test_ndarray_neg()
|
||||
test_ndarray_inv()
|
||||
|
|
Loading…
Reference in New Issue