forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: refactoring builtin_fns

This commit is contained in:
lyken 2024-08-07 17:31:47 +08:00
parent 7afc9ff7fb
commit bcd35544cc
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
5 changed files with 417 additions and 639 deletions

File diff suppressed because it is too large Load Diff

View File

@ -27,14 +27,14 @@ pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ?
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
/// The [`IntKind`] is automatically inferred.
pub fn gen_for_model_auto<'ctx, 'a, G, F, I>(
pub fn gen_for_model_auto<'ctx, 'a, G, F, I, R>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
start: Int<'ctx, I>,
stop: Int<'ctx, I>,
step: Int<'ctx, I>,
body: F,
) -> Result<(), String>
) -> Result<R, String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
@ -42,7 +42,7 @@ where
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, I>,
) -> Result<(), String>,
) -> Result<R, String>,
I: IntKind<'ctx> + Default,
{
let int_model = IntModel(I::default());

View File

@ -1,4 +1,3 @@
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use util::gen_for_model_auto;
@ -19,7 +18,6 @@ pub fn starmap_scalars_array_like<'ctx, 'a, F, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: &Vec<ScalarOrNDArray<'ctx>>,
ret_dtype: Type,
mapping: F,
) -> Result<ScalarOrNDArray<'ctx>, String>
where
@ -28,7 +26,7 @@ where
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
&Vec<ScalarObject<'ctx>>,
) -> Result<BasicValueEnum<'ctx>, String>,
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{
assert!(!inputs.is_empty());
@ -44,9 +42,9 @@ where
// When inputs are all scalars, return a ScalarObject back
let i = sizet_model.const_0(generator, ctx.ctx);
let ret = mapping(generator, ctx, i, &scalars)?;
Ok(ScalarOrNDArray::Scalar(ScalarObject { value: ret, dtype: ret_dtype }))
let scalar = mapping(generator, ctx, i, &scalars)?;
Ok(ScalarOrNDArray::Scalar(scalar))
}
None => {
// When not all inputs are scalars, promote all non-ndarray inputs
@ -57,22 +55,12 @@ where
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays);
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
generator,
ctx,
ret_dtype,
broadcast_result.ndims,
"mapped_ndarray",
);
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
mapped_ndarray.create_data(generator, ctx);
let start = sizet_model.const_0(generator, ctx.ctx);
let stop = mapped_ndarray.size(generator, ctx);
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
let step = sizet_model.const_1(generator, ctx.ctx);
// Map element-wise and store results into `mapped_ndarray`.
gen_for_model_auto(
let mapped_ndarray = gen_for_model_auto(
generator,
ctx,
start,
@ -89,12 +77,26 @@ where
.collect_vec();
let ret = mapping(generator, ctx, i, &elements)?;
// It might look weird but it is perfectly fine putting the allocation codegen
// here within `for`.
// The reason for doing this is to get the `dtype` out of `ret`, which is only
// available after running `mapping`.
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
generator,
ctx,
ret.dtype,
broadcast_result.ndims,
"mapped_ndarray",
);
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
mapped_ndarray.create_data(generator, ctx);
let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret");
ctx.builder.build_store(pret, ret).unwrap();
Ok(())
ctx.builder.build_store(pret, ret.value).unwrap();
Ok(mapped_ndarray)
},
)?;
Ok(ScalarOrNDArray::NDArray(mapped_ndarray))
}
}
@ -105,7 +107,6 @@ impl<'ctx> ScalarObject<'ctx> {
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type,
mapping: F,
) -> Result<Self, String>
where
@ -114,14 +115,13 @@ impl<'ctx> ScalarObject<'ctx> {
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{
let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like(
generator,
ctx,
&vec![ScalarOrNDArray::Scalar(*self)],
ret_dtype,
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
)?
else {
@ -136,7 +136,6 @@ impl<'ctx> NDArrayObject<'ctx> {
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type,
mapping: F,
) -> Result<Self, String>
where
@ -145,14 +144,13 @@ impl<'ctx> NDArrayObject<'ctx> {
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{
let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like(
generator,
ctx,
&vec![ScalarOrNDArray::NDArray(*self)],
ret_dtype,
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
)?
else {
@ -176,15 +174,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{
match self {
ScalarOrNDArray::Scalar(scalar) => {
scalar.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::Scalar)
}
ScalarOrNDArray::Scalar(scalar) => starmap_scalars_array_like(
generator,
ctx,
&vec![ScalarOrNDArray::Scalar(*scalar)],
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
),
ScalarOrNDArray::NDArray(ndarray) => {
ndarray.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::NDArray)
ndarray.map(generator, ctx, mapping).map(ScalarOrNDArray::NDArray)
}
}
}

View File

@ -1180,7 +1180,7 @@ impl<'a> BuiltinBuilder<'a> {
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
let func = match kind {
Kind::Ceil => builtin_fns::call_ceil,
Kind::Floor => builtin_fns::call_floor,
Kind::Floor => builtin_fns::call_ceil_or_floor,
};
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
}),
@ -1548,7 +1548,7 @@ impl<'a> BuiltinBuilder<'a> {
let func = match prim {
PrimDef::FunNpCeil => builtin_fns::call_ceil,
PrimDef::FunNpFloor => builtin_fns::call_floor,
PrimDef::FunNpFloor => builtin_fns::call_ceil_or_floor,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))

View File

@ -336,6 +336,14 @@ impl Unifier {
self.unification_table.unioned(a, b)
}
/// Determine if a type unions with a type in `tys`.
pub fn unioned_any<I>(&mut self, a: Type, tys: I) -> bool
where
I: IntoIterator<Item = Type>,
{
tys.into_iter().any(|ty| self.unioned(a, ty))
}
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
let lock = unifier.lock().unwrap();
Unifier {