forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: refactoring builtin_fns

This commit is contained in:
lyken 2024-08-07 17:31:47 +08:00
parent 7afc9ff7fb
commit bcd35544cc
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
5 changed files with 417 additions and 639 deletions

File diff suppressed because it is too large Load Diff

View File

@ -27,14 +27,14 @@ pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ?
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions. /// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
/// The [`IntKind`] is automatically inferred. /// The [`IntKind`] is automatically inferred.
pub fn gen_for_model_auto<'ctx, 'a, G, F, I>( pub fn gen_for_model_auto<'ctx, 'a, G, F, I, R>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
start: Int<'ctx, I>, start: Int<'ctx, I>,
stop: Int<'ctx, I>, stop: Int<'ctx, I>,
step: Int<'ctx, I>, step: Int<'ctx, I>,
body: F, body: F,
) -> Result<(), String> ) -> Result<R, String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
F: FnOnce( F: FnOnce(
@ -42,7 +42,7 @@ where
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>, BreakContinueHooks<'ctx>,
Int<'ctx, I>, Int<'ctx, I>,
) -> Result<(), String>, ) -> Result<R, String>,
I: IntKind<'ctx> + Default, I: IntKind<'ctx> + Default,
{ {
let int_model = IntModel(I::default()); let int_model = IntModel(I::default());

View File

@ -1,4 +1,3 @@
use inkwell::values::BasicValueEnum;
use itertools::Itertools; use itertools::Itertools;
use util::gen_for_model_auto; use util::gen_for_model_auto;
@ -19,7 +18,6 @@ pub fn starmap_scalars_array_like<'ctx, 'a, F, G>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: &Vec<ScalarOrNDArray<'ctx>>, inputs: &Vec<ScalarOrNDArray<'ctx>>,
ret_dtype: Type,
mapping: F, mapping: F,
) -> Result<ScalarOrNDArray<'ctx>, String> ) -> Result<ScalarOrNDArray<'ctx>, String>
where where
@ -28,7 +26,7 @@ where
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>, Int<'ctx, SizeT>,
&Vec<ScalarObject<'ctx>>, &Vec<ScalarObject<'ctx>>,
) -> Result<BasicValueEnum<'ctx>, String>, ) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
{ {
assert!(!inputs.is_empty()); assert!(!inputs.is_empty());
@ -44,9 +42,9 @@ where
// When inputs are all scalars, return a ScalarObject back // When inputs are all scalars, return a ScalarObject back
let i = sizet_model.const_0(generator, ctx.ctx); let i = sizet_model.const_0(generator, ctx.ctx);
let ret = mapping(generator, ctx, i, &scalars)?;
Ok(ScalarOrNDArray::Scalar(ScalarObject { value: ret, dtype: ret_dtype })) let scalar = mapping(generator, ctx, i, &scalars)?;
Ok(ScalarOrNDArray::Scalar(scalar))
} }
None => { None => {
// When not all inputs are scalars, promote all non-ndarray inputs // When not all inputs are scalars, promote all non-ndarray inputs
@ -57,22 +55,12 @@ where
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays); let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays);
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
generator,
ctx,
ret_dtype,
broadcast_result.ndims,
"mapped_ndarray",
);
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
mapped_ndarray.create_data(generator, ctx);
let start = sizet_model.const_0(generator, ctx.ctx); let start = sizet_model.const_0(generator, ctx.ctx);
let stop = mapped_ndarray.size(generator, ctx); let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
let step = sizet_model.const_1(generator, ctx.ctx); let step = sizet_model.const_1(generator, ctx.ctx);
// Map element-wise and store results into `mapped_ndarray`. // Map element-wise and store results into `mapped_ndarray`.
gen_for_model_auto( let mapped_ndarray = gen_for_model_auto(
generator, generator,
ctx, ctx,
start, start,
@ -89,12 +77,26 @@ where
.collect_vec(); .collect_vec();
let ret = mapping(generator, ctx, i, &elements)?; let ret = mapping(generator, ctx, i, &elements)?;
// It might look weird but it is perfectly fine putting the allocation codegen
// here within `for`.
// The reason for doing this is to get the `dtype` out of `ret`, which is only
// available after running `mapping`.
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
generator,
ctx,
ret.dtype,
broadcast_result.ndims,
"mapped_ndarray",
);
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
mapped_ndarray.create_data(generator, ctx);
let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret"); let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret");
ctx.builder.build_store(pret, ret).unwrap(); ctx.builder.build_store(pret, ret.value).unwrap();
Ok(()) Ok(mapped_ndarray)
}, },
)?; )?;
Ok(ScalarOrNDArray::NDArray(mapped_ndarray)) Ok(ScalarOrNDArray::NDArray(mapped_ndarray))
} }
} }
@ -105,7 +107,6 @@ impl<'ctx> ScalarObject<'ctx> {
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type,
mapping: F, mapping: F,
) -> Result<Self, String> ) -> Result<Self, String>
where where
@ -114,14 +115,13 @@ impl<'ctx> ScalarObject<'ctx> {
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>, Int<'ctx, SizeT>,
ScalarObject<'ctx>, ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>, ) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
{ {
let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like( let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like(
generator, generator,
ctx, ctx,
&vec![ScalarOrNDArray::Scalar(*self)], &vec![ScalarOrNDArray::Scalar(*self)],
ret_dtype,
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
)? )?
else { else {
@ -136,7 +136,6 @@ impl<'ctx> NDArrayObject<'ctx> {
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type,
mapping: F, mapping: F,
) -> Result<Self, String> ) -> Result<Self, String>
where where
@ -145,14 +144,13 @@ impl<'ctx> NDArrayObject<'ctx> {
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>, Int<'ctx, SizeT>,
ScalarObject<'ctx>, ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>, ) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
{ {
let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like( let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like(
generator, generator,
ctx, ctx,
&vec![ScalarOrNDArray::NDArray(*self)], &vec![ScalarOrNDArray::NDArray(*self)],
ret_dtype,
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), |generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
)? )?
else { else {
@ -176,15 +174,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>, Int<'ctx, SizeT>,
ScalarObject<'ctx>, ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>, ) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
{ {
match self { match self {
ScalarOrNDArray::Scalar(scalar) => { ScalarOrNDArray::Scalar(scalar) => starmap_scalars_array_like(
scalar.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::Scalar) generator,
} ctx,
&vec![ScalarOrNDArray::Scalar(*scalar)],
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
),
ScalarOrNDArray::NDArray(ndarray) => { ScalarOrNDArray::NDArray(ndarray) => {
ndarray.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::NDArray) ndarray.map(generator, ctx, mapping).map(ScalarOrNDArray::NDArray)
} }
} }
} }

View File

@ -1180,7 +1180,7 @@ impl<'a> BuiltinBuilder<'a> {
let ret_elem_ty = size_variant.of_int(&ctx.primitives); let ret_elem_ty = size_variant.of_int(&ctx.primitives);
let func = match kind { let func = match kind {
Kind::Ceil => builtin_fns::call_ceil, Kind::Ceil => builtin_fns::call_ceil,
Kind::Floor => builtin_fns::call_floor, Kind::Floor => builtin_fns::call_ceil_or_floor,
}; };
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
}), }),
@ -1548,7 +1548,7 @@ impl<'a> BuiltinBuilder<'a> {
let func = match prim { let func = match prim {
PrimDef::FunNpCeil => builtin_fns::call_ceil, PrimDef::FunNpCeil => builtin_fns::call_ceil,
PrimDef::FunNpFloor => builtin_fns::call_floor, PrimDef::FunNpFloor => builtin_fns::call_ceil_or_floor,
_ => unreachable!(), _ => unreachable!(),
}; };
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))

View File

@ -336,6 +336,14 @@ impl Unifier {
self.unification_table.unioned(a, b) self.unification_table.unioned(a, b)
} }
/// Determine if a type unions with a type in `tys`.
pub fn unioned_any<I>(&mut self, a: Type, tys: I) -> bool
where
I: IntoIterator<Item = Type>,
{
tys.into_iter().any(|ty| self.unioned(a, ty))
}
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier { pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
let lock = unifier.lock().unwrap(); let lock = unifier.lock().unwrap();
Unifier { Unifier {