forked from M-Labs/nac3
[core] codegen: Reimplement ndarray unary op
Based on bb992704: core/ndstrides: implement unary op
This commit is contained in:
parent
59f19e29df
commit
a2f1b25fd8
@ -1777,10 +1777,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
_ => val.into(),
|
_ => val.into(),
|
||||||
}
|
}
|
||||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, ty);
|
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||||
|
|
||||||
let val = llvm_ndarray_ty.map_value(val.into_pointer_value(), None);
|
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty)
|
||||||
|
.map_value(val.into_pointer_value(), None);
|
||||||
|
|
||||||
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||||
// passing it to the elementwise codegen function
|
// passing it to the elementwise codegen function
|
||||||
@ -1798,20 +1798,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
op
|
op
|
||||||
};
|
};
|
||||||
|
|
||||||
let res = numpy::ndarray_elementwise_unaryop_impl(
|
let mapped_ndarray = ndarray.map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ndarray_dtype,
|
NDArrayOut::NewNDArray { dtype: ndarray.get_type().element_type() },
|
||||||
None,
|
|generator, ctx, scalar| {
|
||||||
val,
|
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), scalar))?
|
||||||
|generator, ctx, val| {
|
.map(|val| val.to_basic_value_enum(ctx, generator, ndarray_dtype))
|
||||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
res.as_base_value().into()
|
mapped_ndarray.as_base_value().into()
|
||||||
} else {
|
} else {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}))
|
}))
|
||||||
|
@ -195,28 +195,6 @@ where
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
src: NDArrayValue<'ctx>,
|
|
||||||
dest: NDArrayValue<'ctx>,
|
|
||||||
map_fn: MapFn,
|
|
||||||
) -> Result<(), String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
MapFn: Fn(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
BasicValueEnum<'ctx>,
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
|
||||||
{
|
|
||||||
ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| {
|
|
||||||
let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) };
|
|
||||||
|
|
||||||
map_fn(generator, ctx, elem)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
||||||
/// the target `ndarray`.
|
/// the target `ndarray`.
|
||||||
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
@ -614,43 +592,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
elem_ty: Type,
|
|
||||||
res: Option<NDArrayValue<'ctx>>,
|
|
||||||
operand: NDArrayValue<'ctx>,
|
|
||||||
map_fn: MapFn,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
MapFn: Fn(
|
|
||||||
&mut G,
|
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
|
||||||
BasicValueEnum<'ctx>,
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
|
||||||
{
|
|
||||||
let res = res.unwrap_or_else(|| {
|
|
||||||
create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&operand,
|
|
||||||
|_, ctx, v| Ok(v.load_ndims(ctx)),
|
|
||||||
|generator, ctx, v, idx| unsafe {
|
|
||||||
Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
});
|
|
||||||
|
|
||||||
ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| {
|
|
||||||
map_fn(generator, ctx, elem)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
|
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
|
||||||
///
|
///
|
||||||
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
|
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
|
||||||
|
Loading…
Reference in New Issue
Block a user