ndstrides: [10] Reimplement binops, unary ops, and cmpops. #520
|
@ -8,7 +8,10 @@ use std::{
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
types::{AnyType, BasicType, BasicTypeEnum},
|
types::{AnyType, BasicType, BasicTypeEnum},
|
||||||
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, StructValue},
|
values::{
|
||||||
|
BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue,
|
||||||
|
StructValue,
|
||||||
|
},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
use itertools::{chain, izip, Either, Itertools};
|
use itertools::{chain, izip, Either, Itertools};
|
||||||
|
@ -20,8 +23,8 @@ use nac3parser::ast::{
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
classes::{
|
classes::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ProxyValue,
|
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, ProxyType, ProxyValue, RangeValue,
|
||||||
RangeValue, UntypedArrayLikeAccessor,
|
UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||||
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
||||||
|
@ -31,10 +34,13 @@ use super::{
|
||||||
call_int_umin, call_memcpy_generic,
|
call_int_umin, call_memcpy_generic,
|
||||||
},
|
},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
need_sret, numpy,
|
need_sret,
|
||||||
object::{
|
object::{
|
||||||
any::AnyObject,
|
any::AnyObject,
|
||||||
ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject},
|
ndarray::{
|
||||||
|
indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject, NDArrayOut,
|
||||||
|
ScalarOrNDArray,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
stmt::{
|
stmt::{
|
||||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||||
|
@ -44,7 +50,7 @@ use super::{
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::{SymbolValue, ValueEnum},
|
symbol_resolver::{SymbolValue, ValueEnum},
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||||
|
@ -1549,99 +1555,71 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let left =
|
||||||
|
ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty1, value: left_val });
|
||||||
|
let right =
|
||||||
|
ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty2, value: right_val });
|
||||||
|
|
||||||
let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
// Inhomogeneous binary operations are not supported.
|
||||||
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
assert!(ctx.unifier.unioned(left.get_dtype(), right.get_dtype()));
|
||||||
|
|
||||||
if is_ndarray1 && is_ndarray2 {
|
let common_dtype = left.get_dtype();
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
|
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
let out = match op.variant {
|
||||||
|
BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: common_dtype },
|
||||||
let left_val =
|
BinopVariant::AugAssign => {
|
||||||
NDArrayValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
|
// If this is an augmented assignment.
|
||||||
let right_val =
|
// `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it.
|
||||||
NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
|
if let ScalarOrNDArray::NDArray(out_ndarray) = left {
|
||||||
|
NDArrayOut::WriteToNDArray { ndarray: out_ndarray }
|
||||||
let res = if op.base == Operator::MatMult {
|
|
||||||
// MatMult is the only binop which is not an elementwise op
|
|
||||||
numpy::ndarray_matmul_2d(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ndarray_dtype1,
|
|
||||||
match op.variant {
|
|
||||||
BinopVariant::Normal => None,
|
|
||||||
BinopVariant::AugAssign => Some(left_val),
|
|
||||||
},
|
|
||||||
left_val,
|
|
||||||
right_val,
|
|
||||||
)?
|
|
||||||
} else {
|
} else {
|
||||||
numpy::ndarray_elementwise_binop_impl(
|
panic!("left must be an ndarray")
|
||||||
generator,
|
}
|
||||||
ctx,
|
}
|
||||||
ndarray_dtype1,
|
|
||||||
match op.variant {
|
|
||||||
BinopVariant::Normal => None,
|
|
||||||
BinopVariant::AugAssign => Some(left_val),
|
|
||||||
},
|
|
||||||
(left_val.as_base_value().into(), false),
|
|
||||||
(right_val.as_base_value().into(), false),
|
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|
||||||
gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(ndarray_dtype1), lhs),
|
|
||||||
op,
|
|
||||||
(&Some(ndarray_dtype2), rhs),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ndarray_dtype1,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
if op.base == Operator::MatMult {
|
||||||
|
// Handle matrix multiplication.
|
||||||
|
todo!()
|
||||||
} else {
|
} else {
|
||||||
let (ndarray_dtype, _) =
|
// For other operations, they are all elementwise operations.
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
|
|
||||||
let ndarray_val = NDArrayValue::from_ptr_val(
|
// There are only three cases:
|
||||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
// - LHS is a scalar, RHS is an ndarray.
|
||||||
llvm_usize,
|
// - LHS is an ndarray, RHS is a scalar.
|
||||||
None,
|
// - LHS is an ndarray, RHS is an ndarray.
|
||||||
);
|
//
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
// For all cases, the scalar operand is promoted to an ndarray,
|
||||||
|
// the two are then broadcasted, and starmapped through.
|
||||||
|
|
||||||
|
let left = left.to_ndarray(generator, ctx);
|
||||||
|
let right = right.to_ndarray(generator, ctx);
|
||||||
|
|
||||||
|
let result = NDArrayObject::broadcast_starmap(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ndarray_dtype,
|
&[left, right],
|
||||||
match op.variant {
|
out,
|
||||||
BinopVariant::Normal => None,
|
|generator, ctx, scalars| {
|
||||||
BinopVariant::AugAssign => Some(ndarray_val),
|
let left_value = scalars[0];
|
||||||
},
|
let right_value = scalars[1];
|
||||||
(left_val, !is_ndarray1),
|
|
||||||
(right_val, !is_ndarray2),
|
let result = gen_binop_expr_with_values(
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|
||||||
gen_binop_expr_with_values(
|
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(&Some(ndarray_dtype), lhs),
|
(&Some(left.dtype), left_value),
|
||||||
op,
|
op,
|
||||||
(&Some(ndarray_dtype), rhs),
|
(&Some(right.dtype), right_value),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
.to_basic_value_enum(ctx, generator, common_dtype)?;
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
Ok(result)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||||
|
@ -1799,14 +1777,12 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
_ => val.into(),
|
_ => val.into(),
|
||||||
}
|
}
|
||||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let ndarray = AnyObject { value: val, ty };
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
|
|
||||||
|
|
||||||
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||||
// passing it to the elementwise codegen function
|
// passing it to the elementwise codegen function
|
||||||
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
let op = if ndarray.dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||||
if op == ast::Unaryop::Invert {
|
if op == ast::Unaryop::Invert {
|
||||||
ast::Unaryop::Not
|
ast::Unaryop::Not
|
||||||
} else {
|
} else {
|
||||||
|
@ -1820,20 +1796,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
op
|
op
|
||||||
};
|
};
|
||||||
|
|
||||||
let res = numpy::ndarray_elementwise_unaryop_impl(
|
let mapped_ndarray = ndarray.map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ndarray_dtype,
|
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|
||||||
None,
|
|generator, ctx, scalar| {
|
||||||
val,
|
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray.dtype), scalar))?
|
||||||
|generator, ctx, val| {
|
|
||||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
.to_basic_value_enum(ctx, generator, ndarray.dtype)
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
res.as_base_value().into()
|
ValueEnum::Dynamic(mapped_ndarray.instance.value.as_basic_value_enum())
|
||||||
} else {
|
} else {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}))
|
}))
|
||||||
|
@ -1876,39 +1850,33 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let (Some(left_ty), left) = left else { codegen_unreachable!(ctx) };
|
||||||
|
let (Some(right_ty), right) = comparators[0] else { codegen_unreachable!(ctx) };
|
||||||
let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) };
|
|
||||||
let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) };
|
|
||||||
let op = ops[0];
|
let op = ops[0];
|
||||||
|
|
||||||
let is_ndarray1 =
|
let left = AnyObject { value: left, ty: left_ty };
|
||||||
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
let left =
|
||||||
let is_ndarray2 =
|
ScalarOrNDArray::split_object(generator, ctx, left).to_ndarray(generator, ctx);
|
||||||
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
|
||||||
|
|
||||||
return if is_ndarray1 && is_ndarray2 {
|
let right = AnyObject { value: right, ty: right_ty };
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
let right =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
ScalarOrNDArray::split_object(generator, ctx, right).to_ndarray(generator, ctx);
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
let result_ndarray = NDArrayObject::broadcast_starmap(
|
||||||
|
|
||||||
let left_val =
|
|
||||||
NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None);
|
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
&[left, right],
|
||||||
None,
|
NDArrayOut::NewNDArray { dtype: ctx.primitives.bool },
|
||||||
(left_val.as_base_value().into(), false),
|
|generator, ctx, scalars| {
|
||||||
(rhs, false),
|
let left_scalar = scalars[0];
|
||||||
|generator, ctx, (lhs, rhs)| {
|
let right_scalar = scalars[1];
|
||||||
|
|
||||||
let val = gen_cmpop_expr_with_values(
|
let val = gen_cmpop_expr_with_values(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(Some(ndarray_dtype1), lhs),
|
(Some(left.dtype), left_scalar),
|
||||||
&[op],
|
&[op],
|
||||||
&[(Some(ndarray_dtype2), rhs)],
|
&[(Some(right.dtype), right_scalar)],
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(
|
.to_basic_value_enum(
|
||||||
|
@ -1921,40 +1889,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
return Ok(Some(result_ndarray.instance.value.into()));
|
||||||
} else {
|
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
|
||||||
&mut ctx.unifier,
|
|
||||||
if is_ndarray1 { left_ty } else { right_ty },
|
|
||||||
);
|
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
None,
|
|
||||||
(lhs, !is_ndarray1),
|
|
||||||
(rhs, !is_ndarray2),
|
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|
||||||
let val = gen_cmpop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(Some(ndarray_dtype), lhs),
|
|
||||||
&[op],
|
|
||||||
&[(Some(ndarray_dtype), rhs)],
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,219 @@
|
||||||
|
use inkwell::values::BasicValueEnum;
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
object::ndarray::{AnyObject, NDArrayObject},
|
||||||
|
stmt::gen_for_callback,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::Type,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{nditer::NDIterHandle, NDArrayOut, ScalarOrNDArray};
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` elementwise.
|
||||||
|
///
|
||||||
|
/// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when iterating through
|
||||||
|
/// the input `ndarrays` after broadcasting. The output of `mapping` is the result of the elementwise operation.
|
||||||
|
///
|
||||||
|
/// `out` specifies whether the result should be a new ndarray or to be written an existing ndarray.
|
||||||
|
pub fn broadcast_starmap<'a, G, MappingFn>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ndarrays: &[Self],
|
||||||
|
out: NDArrayOut<'ctx>,
|
||||||
|
mapping: MappingFn,
|
||||||
|
) -> Result<Self, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
MappingFn: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
&[BasicValueEnum<'ctx>],
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
// Broadcast inputs
|
||||||
|
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
|
||||||
|
|
||||||
|
let out_ndarray = match out {
|
||||||
|
NDArrayOut::NewNDArray { dtype } => {
|
||||||
|
// Create a new ndarray based on the broadcast shape.
|
||||||
|
let result_ndarray =
|
||||||
|
NDArrayObject::alloca(generator, ctx, dtype, broadcast_result.ndims);
|
||||||
|
result_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
||||||
|
result_ndarray.create_data(generator, ctx);
|
||||||
|
result_ndarray
|
||||||
|
}
|
||||||
|
NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => {
|
||||||
|
// Use an existing ndarray.
|
||||||
|
|
||||||
|
// Check that its shape is compatible with the broadcast shape.
|
||||||
|
result_ndarray.assert_can_be_written_by_out(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
broadcast_result.ndims,
|
||||||
|
broadcast_result.shape,
|
||||||
|
);
|
||||||
|
result_ndarray
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Map element-wise and store results into `mapped_ndarray`.
|
||||||
|
let nditer = NDIterHandle::new(generator, ctx, out_ndarray);
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
Some("broadcast_starmap"),
|
||||||
|
|generator, ctx| {
|
||||||
|
// Create NDIters for all broadcasted input ndarrays.
|
||||||
|
let other_nditers = broadcast_result
|
||||||
|
.ndarrays
|
||||||
|
.iter()
|
||||||
|
.map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray))
|
||||||
|
.collect_vec();
|
||||||
|
Ok((nditer, other_nditers))
|
||||||
|
},
|
||||||
|
|generator, ctx, (out_nditer, _in_nditers)| {
|
||||||
|
// We can simply use `out_nditer`'s `has_element()`.
|
||||||
|
// `in_nditers`' `has_element()`s should return the same value.
|
||||||
|
Ok(out_nditer.has_element(generator, ctx).value)
|
||||||
|
},
|
||||||
|
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
|
||||||
|
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
|
||||||
|
// and write to `out_ndarray`.
|
||||||
|
let in_scalars = in_nditers
|
||||||
|
.iter()
|
||||||
|
.map(|nditer| nditer.get_scalar(generator, ctx).value)
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
let result = mapping(generator, ctx, &in_scalars)?;
|
||||||
|
|
||||||
|
let p = out_nditer.get_pointer(generator, ctx);
|
||||||
|
ctx.builder.build_store(p, result).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|generator, ctx, (out_nditer, in_nditers)| {
|
||||||
|
// Advance all iterators
|
||||||
|
out_nditer.next(generator, ctx);
|
||||||
|
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(out_ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map through this ndarray with an elementwise function.
|
||||||
|
pub fn map<'a, G, Mapping>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
out: NDArrayOut<'ctx>,
|
||||||
|
mapping: Mapping,
|
||||||
|
) -> Result<Self, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Mapping: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
BasicValueEnum<'ctx>,
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
NDArrayObject::broadcast_starmap(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&[*self],
|
||||||
|
out,
|
||||||
|
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
|
/// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a scalar.
|
||||||
|
///
|
||||||
|
/// This function is very helpful when implementing NumPy functions that takes on either scalars or ndarrays or a mix of them
|
||||||
|
/// as their inputs and produces either an ndarray with broadcast, or a scalar if all its inputs are all scalars.
|
||||||
|
///
|
||||||
|
/// For example ,this function can be used to implement `np.add`, which has the following behaviors:
|
||||||
|
/// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar
|
||||||
|
/// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is converted into an ndarray and broadcasted.
|
||||||
|
/// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> ndarray; there is broadcasting.
|
||||||
|
///
|
||||||
|
/// ## Details:
|
||||||
|
///
|
||||||
|
/// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a [`ScalarOrNDArray::Scalar`] with type `ret_dtype`.
|
||||||
|
///
|
||||||
|
/// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be 'as-ndarray'-ed into ndarrays,
|
||||||
|
/// then all inputs (now all ndarrays) will be passed to [`NDArrayObject::broadcasting_starmap`] and **create** a new ndarray
|
||||||
|
/// with dtype `ret_dtype`.
|
||||||
|
pub fn broadcasting_starmap<'a, G, MappingFn>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
inputs: &[ScalarOrNDArray<'ctx>],
|
||||||
|
ret_dtype: Type,
|
||||||
|
mapping: MappingFn,
|
||||||
|
) -> Result<ScalarOrNDArray<'ctx>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
MappingFn: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
&[BasicValueEnum<'ctx>],
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
// Check if all inputs are Scalars
|
||||||
|
let all_scalars: Option<Vec<_>> = inputs.iter().map(AnyObject::try_from).try_collect().ok();
|
||||||
|
|
||||||
|
if let Some(scalars) = all_scalars {
|
||||||
|
let scalars = scalars.iter().map(|scalar| scalar.value).collect_vec();
|
||||||
|
let value = mapping(generator, ctx, &scalars)?;
|
||||||
|
|
||||||
|
Ok(ScalarOrNDArray::Scalar(AnyObject { ty: ret_dtype, value }))
|
||||||
|
} else {
|
||||||
|
// Promote all input to ndarrays and map through them.
|
||||||
|
let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec();
|
||||||
|
let ndarray = NDArrayObject::broadcast_starmap(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&inputs,
|
||||||
|
NDArrayOut::NewNDArray { dtype: ret_dtype },
|
||||||
|
mapping,
|
||||||
|
)?;
|
||||||
|
Ok(ScalarOrNDArray::NDArray(ndarray))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map through this [`ScalarOrNDArray`] with an elementwise function.
|
||||||
|
///
|
||||||
|
/// If this is a scalar, `mapping` will directly act on the scalar. This function will return a [`ScalarOrNDArray::Scalar`] of that result.
|
||||||
|
///
|
||||||
|
/// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new ndarray of the results will be created and
|
||||||
|
/// returned as a [`ScalarOrNDArray::NDArray`].
|
||||||
|
pub fn map<'a, G, Mapping>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ret_dtype: Type,
|
||||||
|
mapping: Mapping,
|
||||||
|
) -> Result<ScalarOrNDArray<'ctx>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
Mapping: FnOnce(
|
||||||
|
&mut G,
|
||||||
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
|
BasicValueEnum<'ctx>,
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
ScalarOrNDArray::broadcasting_starmap(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
&[*self],
|
||||||
|
ret_dtype,
|
||||||
|
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
|
@ -13,6 +13,7 @@ use crate::{
|
||||||
call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous,
|
call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous,
|
||||||
call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
|
call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
|
||||||
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size,
|
||||||
|
call_nac3_ndarray_util_assert_output_shape_same,
|
||||||
},
|
},
|
||||||
model::*,
|
model::*,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
|
@ -25,6 +26,7 @@ pub mod array;
|
||||||
pub mod broadcast;
|
pub mod broadcast;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
|
pub mod map;
|
||||||
pub mod nditer;
|
pub mod nditer;
|
||||||
pub mod shape_util;
|
pub mod shape_util;
|
||||||
pub mod view;
|
pub mod view;
|
||||||
|
@ -499,6 +501,31 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
ndarray.instance.set(ctx, |f| f.data, data);
|
ndarray.instance.set(ctx, |f| f.data, data);
|
||||||
ndarray
|
ndarray
|
||||||
}
|
}
|
||||||
|
/// Check if this `NDArray` can be used as an `out` ndarray for an operation.
|
||||||
|
///
|
||||||
|
/// Raise an exception if the shapes do not match.
|
||||||
|
pub fn assert_can_be_written_by_out<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
out_ndims: u64,
|
||||||
|
out_shape: Instance<'ctx, Ptr<Int<SizeT>>>,
|
||||||
|
) {
|
||||||
|
let ndarray_ndims = self.ndims_llvm(generator, ctx.ctx);
|
||||||
|
let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
|
||||||
|
let output_ndims = Int(SizeT).const_int(generator, ctx.ctx, out_ndims, false);
|
||||||
|
let output_shape = out_shape;
|
||||||
|
|
||||||
|
call_nac3_ndarray_util_assert_output_shape_same(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray_ndims,
|
||||||
|
ndarray_shape,
|
||||||
|
output_ndims,
|
||||||
|
output_shape,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
|
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
|
||||||
|
@ -584,3 +611,27 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// An helper enum specifying how a function should produce its output.
|
||||||
|
///
|
||||||
|
/// Many functions in NumPy has an optional `out` parameter (e.g., `matmul`). If `out` is specified
|
||||||
|
/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a function will
|
||||||
|
/// create a new ndarray and store the result in it.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum NDArrayOut<'ctx> {
|
||||||
|
/// Tell a function should create a new ndarray with the expected element type `dtype`.
|
||||||
|
NewNDArray { dtype: Type },
|
||||||
|
/// Tell a function to write the result to `ndarray`.
|
||||||
|
WriteToNDArray { ndarray: NDArrayObject<'ctx> },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayOut<'ctx> {
|
||||||
|
/// Get the dtype of this output.
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_dtype(&self) -> Type {
|
||||||
|
match self {
|
||||||
|
NDArrayOut::NewNDArray { dtype } => *dtype,
|
||||||
|
NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue