core: Implement elementwise unary operators

David Mak 2024-03-27 17:06:58 +08:00
parent 0537e816a5
commit 5450147007
5 changed files with 192 additions and 16 deletions

View File

@ -1292,7 +1292,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
}
pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
_generator: &mut G,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
op: &ast::Unaryop,
operand: (&Option<Type>, BasicValueEnum<'ctx>),
@ -1332,6 +1332,33 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
.unwrap(),
_ => val.into(),
}
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let val = NDArrayValue::from_ptr_val(
val.into_pointer_value(),
llvm_usize,
None,
);
let res = numpy::ndarray_elementwise_unaryop_impl(
generator,
ctx,
ndarray_dtype,
None,
val,
|generator, ctx, elem_ty, val| {
gen_unaryop_expr_with_values(
generator,
ctx,
op,
(&Some(elem_ty), val)
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
},
)?;
res.as_ptr_value().into()
} else {
unimplemented!()
}))

View File

@ -346,6 +346,31 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
)
}
fn ndarray_fill_mapping<'ctx, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src: NDArrayValue<'ctx>,
dest: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(
generator,
ctx,
dest,
|generator, ctx, i| {
let elem = unsafe {
src.data().get_unchecked(ctx, generator, i, None)
};
map_fn(generator, ctx, elem)
},
)
}
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
/// with broadcast-compatible shapes.
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
@ -642,6 +667,48 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
Ok(ndarray)
}
pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
operand: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator,
MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, BasicValueEnum<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
let res = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&operand,
|_, ctx, v| {
Ok(v.load_ndims(ctx))
},
|generator, ctx, v, idx| {
unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None))
}
},
).unwrap()
});
ndarray_fill_mapping(
generator,
ctx,
operand,
res,
|generator, ctx, elem| {
map_fn(generator, ctx, elem_ty, elem)
}
)?;
Ok(res)
}
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
///
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output

View File

@ -148,8 +148,10 @@ pub fn impl_binop(
});
}
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryop]) {
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
with_fields(unifier, ty, |unifier, fields| {
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops {
fields.insert(
unaryop_name(op).into(),
@ -274,18 +276,18 @@ pub fn impl_mod(
}
/// `UAdd`, `USub`
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, ty, &[Unaryop::UAdd, Unaryop::USub]);
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
}
/// `Invert`
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, ty, &[Unaryop::Invert]);
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]);
}
/// `Not`
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, store.bool, &[Unaryop::Not]);
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
}
/// `Lt`, `LtE`, `Gt`, `GtE`
@ -439,6 +441,23 @@ pub fn typeof_binop(
}))
}
pub fn typeof_unaryop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: &Unaryop,
operand: Type,
) -> Result<Option<Type>, String> {
if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
return Err("The truth value of an array with more than one element is ambiguous".to_string())
}
Ok(if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
Some(operand)
} else {
None
})
}
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
let PrimitiveStore {
int32: int32_t,
@ -462,13 +481,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
impl_invert(unifier, store, t);
impl_not(unifier, store, t);
impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, t);
impl_eq(unifier, store, t);
}
for t in [int32_t, int64_t] {
impl_sign(unifier, store, t);
impl_sign(unifier, store, t, Some(t));
}
/* float ======== */
@ -479,13 +498,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_sign(unifier, store, float_t);
impl_not(unifier, store, float_t);
impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, float_t);
impl_eq(unifier, store, float_t);
/* bool ======== */
impl_not(unifier, store, bool_t);
impl_not(unifier, store, bool_t, Some(bool_t));
impl_eq(unifier, store, bool_t);
/* ndarray ===== */
@ -498,4 +517,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
}

View File

@ -549,7 +549,9 @@ impl<'a> Fold<()> for Inferencer<'a> {
ExprKind::BinOp { left, op, right } => {
Some(self.infer_bin_ops(expr.location, left, op, right, false)?)
}
ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?),
ExprKind::UnaryOp { op, operand } => {
Some(self.infer_unary_ops(expr.location, op, operand)?)
}
ExprKind::Compare { left, ops, comparators } => {
Some(self.infer_compare(left, ops, comparators)?)
}
@ -1247,11 +1249,20 @@ impl<'a> Inferencer<'a> {
fn infer_unary_ops(
&mut self,
location: Location,
op: &ast::Unaryop,
operand: &ast::Expr<Option<Type>>,
) -> InferenceResult {
let method = unaryop_name(op).into();
self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], None)
let ret = typeof_unaryop(
self.unifier,
self.primitives,
op,
operand.custom.unwrap(),
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], ret)
}
fn infer_compare(

View File

@ -1,3 +1,7 @@
@extern
def output_bool(x: bool):
...
@extern
def output_int32(x: int32):
...
@ -6,10 +10,20 @@ def output_int32(x: int32):
def output_float64(x: float):
...
def output_ndarray_bool_2(n: ndarray[bool, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_bool(n[r][c])
def output_ndarray_int32_1(n: ndarray[int32, Literal[1]]):
for i in range(len(n)):
output_int32(n[i])
def output_ndarray_int32_2(n: ndarray[int32, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_int32(n[r][c])
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
for i in range(len(n)):
output_float64(n[i])
@ -408,6 +422,39 @@ def test_ndarray_ipow_broadcast_scalar():
output_ndarray_float_2(x)
def test_ndarray_pos():
x_int32 = np_full([2, 2], -2)
y_int32 = +x_int32
output_ndarray_int32_2(x_int32)
output_ndarray_int32_2(y_int32)
x_float = np_full([2, 2], -2.0)
y_float = +x_float
output_ndarray_float_2(x_float)
output_ndarray_float_2(y_float)
def test_ndarray_neg():
x_int32 = np_full([2, 2], -2)
y_int32 = -x_int32
output_ndarray_int32_2(x_int32)
output_ndarray_int32_2(y_int32)
x_float = np_full([2, 2], 2.0)
y_float = -x_float
output_ndarray_float_2(x_float)
output_ndarray_float_2(y_float)
def test_ndarray_inv():
x_int32 = np_full([2, 2], -2)
y_int32 = ~x_int32
output_ndarray_int32_2(x_int32)
output_ndarray_int32_2(y_int32)
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
@ -467,5 +514,8 @@ def run() -> int32:
test_ndarray_ipow()
test_ndarray_ipow_broadcast()
test_ndarray_ipow_broadcast_scalar()
test_ndarray_pos()
test_ndarray_neg()
test_ndarray_inv()
return 0