forked from M-Labs/nac3
core/ndstrides: refactoring builtin_fns
This commit is contained in:
parent
7afc9ff7fb
commit
bcd35544cc
File diff suppressed because it is too large
Load Diff
|
@ -27,14 +27,14 @@ pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ?
|
||||||
|
|
||||||
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
|
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
|
||||||
/// The [`IntKind`] is automatically inferred.
|
/// The [`IntKind`] is automatically inferred.
|
||||||
pub fn gen_for_model_auto<'ctx, 'a, G, F, I>(
|
pub fn gen_for_model_auto<'ctx, 'a, G, F, I, R>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
start: Int<'ctx, I>,
|
start: Int<'ctx, I>,
|
||||||
stop: Int<'ctx, I>,
|
stop: Int<'ctx, I>,
|
||||||
step: Int<'ctx, I>,
|
step: Int<'ctx, I>,
|
||||||
body: F,
|
body: F,
|
||||||
) -> Result<(), String>
|
) -> Result<R, String>
|
||||||
where
|
where
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
F: FnOnce(
|
F: FnOnce(
|
||||||
|
@ -42,7 +42,7 @@ where
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
BreakContinueHooks<'ctx>,
|
BreakContinueHooks<'ctx>,
|
||||||
Int<'ctx, I>,
|
Int<'ctx, I>,
|
||||||
) -> Result<(), String>,
|
) -> Result<R, String>,
|
||||||
I: IntKind<'ctx> + Default,
|
I: IntKind<'ctx> + Default,
|
||||||
{
|
{
|
||||||
let int_model = IntModel(I::default());
|
let int_model = IntModel(I::default());
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use util::gen_for_model_auto;
|
use util::gen_for_model_auto;
|
||||||
|
|
||||||
|
@ -19,7 +18,6 @@ pub fn starmap_scalars_array_like<'ctx, 'a, F, G>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
inputs: &Vec<ScalarOrNDArray<'ctx>>,
|
inputs: &Vec<ScalarOrNDArray<'ctx>>,
|
||||||
ret_dtype: Type,
|
|
||||||
mapping: F,
|
mapping: F,
|
||||||
) -> Result<ScalarOrNDArray<'ctx>, String>
|
) -> Result<ScalarOrNDArray<'ctx>, String>
|
||||||
where
|
where
|
||||||
|
@ -28,7 +26,7 @@ where
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
Int<'ctx, SizeT>,
|
Int<'ctx, SizeT>,
|
||||||
&Vec<ScalarObject<'ctx>>,
|
&Vec<ScalarObject<'ctx>>,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<ScalarObject<'ctx>, String>,
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
{
|
{
|
||||||
assert!(!inputs.is_empty());
|
assert!(!inputs.is_empty());
|
||||||
|
@ -44,9 +42,9 @@ where
|
||||||
// When inputs are all scalars, return a ScalarObject back
|
// When inputs are all scalars, return a ScalarObject back
|
||||||
|
|
||||||
let i = sizet_model.const_0(generator, ctx.ctx);
|
let i = sizet_model.const_0(generator, ctx.ctx);
|
||||||
let ret = mapping(generator, ctx, i, &scalars)?;
|
|
||||||
|
|
||||||
Ok(ScalarOrNDArray::Scalar(ScalarObject { value: ret, dtype: ret_dtype }))
|
let scalar = mapping(generator, ctx, i, &scalars)?;
|
||||||
|
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
// When not all inputs are scalars, promote all non-ndarray inputs
|
// When not all inputs are scalars, promote all non-ndarray inputs
|
||||||
|
@ -57,22 +55,12 @@ where
|
||||||
|
|
||||||
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays);
|
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays);
|
||||||
|
|
||||||
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ret_dtype,
|
|
||||||
broadcast_result.ndims,
|
|
||||||
"mapped_ndarray",
|
|
||||||
);
|
|
||||||
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
|
||||||
mapped_ndarray.create_data(generator, ctx);
|
|
||||||
|
|
||||||
let start = sizet_model.const_0(generator, ctx.ctx);
|
let start = sizet_model.const_0(generator, ctx.ctx);
|
||||||
let stop = mapped_ndarray.size(generator, ctx);
|
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
||||||
let step = sizet_model.const_1(generator, ctx.ctx);
|
let step = sizet_model.const_1(generator, ctx.ctx);
|
||||||
|
|
||||||
// Map element-wise and store results into `mapped_ndarray`.
|
// Map element-wise and store results into `mapped_ndarray`.
|
||||||
gen_for_model_auto(
|
let mapped_ndarray = gen_for_model_auto(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
start,
|
start,
|
||||||
|
@ -89,12 +77,26 @@ where
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
|
|
||||||
let ret = mapping(generator, ctx, i, &elements)?;
|
let ret = mapping(generator, ctx, i, &elements)?;
|
||||||
|
|
||||||
|
// It might look weird but it is perfectly fine putting the allocation codegen
|
||||||
|
// here within `for`.
|
||||||
|
// The reason for doing this is to get the `dtype` out of `ret`, which is only
|
||||||
|
// available after running `mapping`.
|
||||||
|
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ret.dtype,
|
||||||
|
broadcast_result.ndims,
|
||||||
|
"mapped_ndarray",
|
||||||
|
);
|
||||||
|
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
||||||
|
mapped_ndarray.create_data(generator, ctx);
|
||||||
|
|
||||||
let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret");
|
let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret");
|
||||||
ctx.builder.build_store(pret, ret).unwrap();
|
ctx.builder.build_store(pret, ret.value).unwrap();
|
||||||
Ok(())
|
Ok(mapped_ndarray)
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(ScalarOrNDArray::NDArray(mapped_ndarray))
|
Ok(ScalarOrNDArray::NDArray(mapped_ndarray))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,7 +107,6 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
ret_dtype: Type,
|
|
||||||
mapping: F,
|
mapping: F,
|
||||||
) -> Result<Self, String>
|
) -> Result<Self, String>
|
||||||
where
|
where
|
||||||
|
@ -114,14 +115,13 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
Int<'ctx, SizeT>,
|
Int<'ctx, SizeT>,
|
||||||
ScalarObject<'ctx>,
|
ScalarObject<'ctx>,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<ScalarObject<'ctx>, String>,
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
{
|
{
|
||||||
let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like(
|
let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&vec![ScalarOrNDArray::Scalar(*self)],
|
&vec![ScalarOrNDArray::Scalar(*self)],
|
||||||
ret_dtype,
|
|
||||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||||
)?
|
)?
|
||||||
else {
|
else {
|
||||||
|
@ -136,7 +136,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
ret_dtype: Type,
|
|
||||||
mapping: F,
|
mapping: F,
|
||||||
) -> Result<Self, String>
|
) -> Result<Self, String>
|
||||||
where
|
where
|
||||||
|
@ -145,14 +144,13 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
Int<'ctx, SizeT>,
|
Int<'ctx, SizeT>,
|
||||||
ScalarObject<'ctx>,
|
ScalarObject<'ctx>,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<ScalarObject<'ctx>, String>,
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
{
|
{
|
||||||
let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like(
|
let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
&vec![ScalarOrNDArray::NDArray(*self)],
|
&vec![ScalarOrNDArray::NDArray(*self)],
|
||||||
ret_dtype,
|
|
||||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||||
)?
|
)?
|
||||||
else {
|
else {
|
||||||
|
@ -176,15 +174,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
Int<'ctx, SizeT>,
|
Int<'ctx, SizeT>,
|
||||||
ScalarObject<'ctx>,
|
ScalarObject<'ctx>,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<ScalarObject<'ctx>, String>,
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
{
|
{
|
||||||
match self {
|
match self {
|
||||||
ScalarOrNDArray::Scalar(scalar) => {
|
ScalarOrNDArray::Scalar(scalar) => starmap_scalars_array_like(
|
||||||
scalar.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::Scalar)
|
generator,
|
||||||
}
|
ctx,
|
||||||
|
&vec![ScalarOrNDArray::Scalar(*scalar)],
|
||||||
|
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||||
|
),
|
||||||
ScalarOrNDArray::NDArray(ndarray) => {
|
ScalarOrNDArray::NDArray(ndarray) => {
|
||||||
ndarray.map(generator, ctx, ret_dtype, mapping).map(ScalarOrNDArray::NDArray)
|
ndarray.map(generator, ctx, mapping).map(ScalarOrNDArray::NDArray)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1180,7 +1180,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
|
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
|
||||||
let func = match kind {
|
let func = match kind {
|
||||||
Kind::Ceil => builtin_fns::call_ceil,
|
Kind::Ceil => builtin_fns::call_ceil,
|
||||||
Kind::Floor => builtin_fns::call_floor,
|
Kind::Floor => builtin_fns::call_ceil_or_floor,
|
||||||
};
|
};
|
||||||
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
|
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
|
||||||
}),
|
}),
|
||||||
|
@ -1548,7 +1548,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|
|
||||||
let func = match prim {
|
let func = match prim {
|
||||||
PrimDef::FunNpCeil => builtin_fns::call_ceil,
|
PrimDef::FunNpCeil => builtin_fns::call_ceil,
|
||||||
PrimDef::FunNpFloor => builtin_fns::call_floor,
|
PrimDef::FunNpFloor => builtin_fns::call_ceil_or_floor,
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
|
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
|
||||||
|
|
|
@ -336,6 +336,14 @@ impl Unifier {
|
||||||
self.unification_table.unioned(a, b)
|
self.unification_table.unioned(a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Determine if a type unions with a type in `tys`.
|
||||||
|
pub fn unioned_any<I>(&mut self, a: Type, tys: I) -> bool
|
||||||
|
where
|
||||||
|
I: IntoIterator<Item = Type>,
|
||||||
|
{
|
||||||
|
tys.into_iter().any(|ty| self.unioned(a, ty))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
|
pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier {
|
||||||
let lock = unifier.lock().unwrap();
|
let lock = unifier.lock().unwrap();
|
||||||
Unifier {
|
Unifier {
|
||||||
|
|
Loading…
Reference in New Issue