forked from M-Labs/nac3
core/ndstrides: refactoring builtin_fns
This commit is contained in:
parent
7afc9ff7fb
commit
bcd35544cc
File diff suppressed because it is too large
Load Diff
@ -27,14 +27,14 @@ pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ?
|
||||
|
||||
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
|
||||
/// The [`IntKind`] is automatically inferred.
|
||||
pub fn gen_for_model_auto<'ctx, 'a, G, F, I>(
|
||||
pub fn gen_for_model_auto<'ctx, 'a, G, F, I, R>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
start: Int<'ctx, I>,
|
||||
stop: Int<'ctx, I>,
|
||||
step: Int<'ctx, I>,
|
||||
body: F,
|
||||
) -> Result<(), String>
|
||||
) -> Result<R, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
F: FnOnce(
|
||||
@ -42,7 +42,7 @@ where
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
Int<'ctx, I>,
|
||||
) -> Result<(), String>,
|
||||
) -> Result<R, String>,
|
||||
I: IntKind<'ctx> + Default,
|
||||
{
|
||||
let int_model = IntModel(I::default());
|
||||
|
@ -1,4 +1,3 @@
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use itertools::Itertools;
|
||||
use util::gen_for_model_auto;
|
||||
|
||||
@ -19,7 +18,6 @@ pub fn starmap_scalars_array_like<'ctx, 'a, F, G>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
inputs: &Vec<ScalarOrNDArray<'ctx>>,
|
||||
ret_dtype: Type,
|
||||
mapping: F,
|
||||
) -> Result<ScalarOrNDArray<'ctx>, String>
|
||||
where
|
||||
@ -28,7 +26,7 @@ where
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
&Vec<ScalarObject<'ctx>>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
assert!(!inputs.is_empty());
|
||||
@ -44,9 +42,9 @@ where
|
||||
// When inputs are all scalars, return a ScalarObject back
|
||||
|
||||
let i = sizet_model.const_0(generator, ctx.ctx);
|
||||
let ret = mapping(generator, ctx, i, &scalars)?;
|
||||
|
||||
Ok(ScalarOrNDArray::Scalar(ScalarObject { value: ret, dtype: ret_dtype }))
|
||||
let scalar = mapping(generator, ctx, i, &scalars)?;
|
||||
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||
}
|
||||
None => {
|
||||
// When not all inputs are scalars, promote all non-ndarray inputs
|
||||
@ -57,22 +55,12 @@ where
|
||||
|
||||
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays);
|
||||
|
||||
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
ret_dtype,
|
||||
broadcast_result.ndims,
|
||||
"mapped_ndarray",
|
||||
);
|
||||
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
||||
mapped_ndarray.create_data(generator, ctx);
|
||||
|
||||
let start = sizet_model.const_0(generator, ctx.ctx);
|
||||
let stop = mapped_ndarray.size(generator, ctx);
|
||||
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
||||
let step = sizet_model.const_1(generator, ctx.ctx);
|
||||
|
||||
// Map element-wise and store results into `mapped_ndarray`.
|
||||
gen_for_model_auto(
|
||||
let mapped_ndarray = gen_for_model_auto(
|
||||
generator,
|
||||
ctx,
|
||||
start,
|
||||
@ -89,12 +77,26 @@ where
|
||||
.collect_vec();
|
||||
|
||||
let ret = mapping(generator, ctx, i, &elements)?;
|
||||
|
||||
// It might look weird but it is perfectly fine putting the allocation codegen
|
||||
// here within `for`.
|
||||
// The reason for doing this is to get the `dtype` out of `ret`, which is only
|
||||
// available after running `mapping`.
|
||||
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
ret.dtype,
|
||||
broadcast_result.ndims,
|
||||
"mapped_ndarray",
|
||||
);
|
||||
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
||||
mapped_ndarray.create_data(generator, ctx);
|
||||
|
||||
let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret");
|
||||
ctx.builder.build_store(pret, ret).unwrap();
|
||||
Ok(())
|
||||
ctx.builder.build_store(pret, ret.value).unwrap();
|
||||
Ok(mapped_ndarray)
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(ScalarOrNDArray::NDArray(mapped_ndarray))
|
||||
}
|
||||
}
|
||||
@ -105,7 +107,6 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ret_dtype: Type,
|
||||
mapping: F,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
@ -114,14 +115,13 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like(
|
||||
generator,
|
||||
ctx,
|
||||
&vec![ScalarOrNDArray::Scalar(*self)],
|
||||
ret_dtype,
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
)?
|
||||
else {
|
||||
@ -136,7 +136,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ret_dtype: Type,
|
||||
mapping: F,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
@ -145,14 +144,13 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like(
|
||||
generator,
|
||||
ctx,
|
||||
&vec![ScalarOrNDArray::NDArray(*self)],
|
||||
ret_dtype,
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
)?
|
||||
else {
|
||||
@ -176,15 +174,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
match self {
|
||||
ScalarOrNDArray::Scalar(scalar) => {
|
||||
scalar.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::Scalar)
|
||||
}
|
||||
ScalarOrNDArray::Scalar(scalar) => starmap_scalars_array_like(
|
||||
generator,
|
||||
ctx,
|
||||
&vec![ScalarOrNDArray::Scalar(*scalar)],
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
),
|
||||
ScalarOrNDArray::NDArray(ndarray) => {
|
||||
ndarray.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::NDArray)
|
||||
ndarray.map(generator, ctx, mapping).map(ScalarOrNDArray::NDArray)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1180,7 +1180,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
|
||||
let func = match kind {
|
||||
Kind::Ceil => builtin_fns::call_ceil,
|
||||
Kind::Floor => builtin_fns::call_floor,
|
||||
Kind::Floor => builtin_fns::call_ceil_or_floor,
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
|
||||
}),
|
||||
@ -1548,7 +1548,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpCeil => builtin_fns::call_ceil,
|
||||
PrimDef::FunNpFloor => builtin_fns::call_floor,
|
||||
PrimDef::FunNpFloor => builtin_fns::call_ceil_or_floor,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
|
||||
|
@ -336,6 +336,14 @@ impl Unifier {
|
||||
self.unification_table.unioned(a, b)
|
||||
}
|
||||
|
||||
/// Determine if a type unions with a type in `tys`.
|
||||
pub fn unioned_any<I>(&mut self, a: Type, tys: I) -> bool
|
||||
where
|
||||
I: IntoIterator<Item = Type>,
|
||||
{
|
||||
tys.into_iter().any(|ty| self.unioned(a, ty))
|
||||
}
|
||||
|
||||
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
|
||||
let lock = unifier.lock().unwrap();
|
||||
Unifier {
|
||||
|
Loading…
Reference in New Issue
Block a user