1
0
forked from M-Labs/nac3

Compare commits

...

2 Commits

9 changed files with 278 additions and 18 deletions

View File

@ -725,7 +725,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
(n_sz, false), (n_sz, false),
|generator, ctx, idx| { |generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
@ -941,7 +941,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
(n_sz, false), (n_sz, false),
|generator, ctx, idx| { |generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();

View File

@ -1706,7 +1706,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(len, false), (len, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let (dim_idx, dim_sz) = unsafe { let (dim_idx, dim_sz) = unsafe {
( (
indices.get_unchecked(ctx, generator, &i, None).into_int_value(), indices.get_unchecked(ctx, generator, &i, None).into_int_value(),

View File

@ -802,7 +802,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(min_ndims, false), (min_ndims, false),
|generator, ctx, idx| { |generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe { let (lhs_dim_sz, rhs_dim_sz) = unsafe {
( (

View File

@ -7,12 +7,11 @@ use crate::{
}, },
expr::gen_binop_expr_with_values, expr::gen_binop_expr_with_values,
irrt::{ irrt::{
calculate_len_for_slice_range, call_ndarray_calc_broadcast, self, calculate_len_for_slice_range, call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
call_ndarray_calc_size, call_ndarray_calc_size,
}, },
llvm_intrinsics, llvm_intrinsics::{self, call_memcpy_generic},
llvm_intrinsics::call_memcpy_generic,
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
@ -32,6 +31,8 @@ use inkwell::{
}; };
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::{Operator, StrRef};
use super::{builtin_fns::call_bool, stmt::BreakContinueHooks};
/// Creates an uninitialized `NDArray` instance. /// Creates an uninitialized `NDArray` instance.
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
@ -86,7 +87,7 @@ where
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?; let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
@ -131,7 +132,7 @@ where
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(shape_len, false), (shape_len, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?; let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
@ -334,7 +335,7 @@ where
ctx, ctx,
llvm_usize.const_zero(), llvm_usize.const_zero(),
(ndarray_num_elems, false), (ndarray_num_elems, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) };
let value = value_fn(generator, ctx, i)?; let value = value_fn(generator, ctx, i)?;
@ -1193,7 +1194,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
llvm_usize.const_int(slices.len() as u64, false), llvm_usize.const_int(slices.len() as u64, false),
(this.load_ndims(ctx), false), (this.load_ndims(ctx), false),
|generator, ctx, idx| { |generator, ctx, _, idx| {
unsafe { unsafe {
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
@ -1366,6 +1367,46 @@ where
Ok(ndarray) Ok(ndarray)
} }
/// LLVM-typed implementation for iterating through all elements within an `ndarray`.
///
/// * `ndarray`: The input [`NDArrayValue`] to iterate through.
/// * `body`: A lambda containing IR statements that acts on every element within `ndarray`.
/// It may also implement short-circuiting logic by branching with [`BreakContinueHooks`].
pub fn ndarray_iter_elem_impl<'ctx, G, BodyFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
body: BodyFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
BodyFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, '_>,
BreakContinueHooks,
BasicValueEnum<'ctx>,
) -> Result<(), String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_size =
irrt::call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(0, false),
(ndarray_size, false),
|generator, ctx, hooks, idx| {
let scalar = unsafe { ndarray.data().get_unchecked(ctx, generator, &idx, None) };
body(generator, ctx, hooks, scalar)
},
llvm_usize.const_int(1, false),
)?;
Ok(())
}
/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. /// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s.
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
@ -1597,7 +1638,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
ctx, ctx,
llvm_i32.const_zero(), llvm_i32.const_zero(),
(common_dim, false), (common_dim, false),
|generator, ctx, i| { |generator, ctx, _, i| {
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
let ab_idx = generator.gen_array_var_alloc( let ab_idx = generator.gen_array_var_alloc(
@ -1984,3 +2025,136 @@ pub fn gen_ndarray_fill<'ctx>(
Ok(()) Ok(())
} }
/// Used by [`call_ndarray_any_all_impl`]
#[derive(Debug, Clone, Copy)]
enum AnyOrAll {
/// The numpy function `np.any()`
IsAny,
/// The numpy function `np.all()`
IsAll,
}
/// Helper function to create `np.any()` and `np.all()`.
///
/// Returns a boolean result in the form of an `i8` [`IntValue`].
///
/// They are mixed together since they are extremely similar in terms of implementation.
fn call_ndarray_any_all_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: AnyOrAll,
elem_ty: Type,
ndarray: NDArrayValue<'ctx>,
) -> Result<IntValue<'ctx>, String> {
/*
NOTE: `np.any(<empty ndarray>)` returns false.
NOTE: `np.all(<empty ndarray>)` returns true.
Here is the reference C code of what the implemented LLVM of `np.any()` is essentially doing.
```c
// np.any(ndarray)
const int8 neutral = 0;
const int8 on_hit = 1;
int8 has_true = neutral;
for (size_t index = 0; index < ndarray.size; index++) {
Element *elem = ndarray.get(index);
bool elem_is_truthy = call_bool(*elem);
if (elem_is_truthy) {
has_true = on_hit;
break; // Short-circuiting for performance
}
}
```
*/
// Name of the function. Used here for creating LLVM labels and names.
let fn_name = match kind {
AnyOrAll::IsAny => "np_any",
AnyOrAll::IsAll => "np_all",
};
let llvm_i8 = ctx.ctx.i8_type();
let (neutral, on_hit) = match kind {
AnyOrAll::IsAny => (0, 1),
AnyOrAll::IsAll => (1, 0),
};
// The result of `np.any()`/`np.all()`
let result_ptr =
ctx.builder.build_alloca(llvm_i8, format!("{fn_name}.result").as_str()).unwrap();
ctx.builder.build_store(result_ptr, llvm_i8.const_int(neutral, false)).unwrap();
ndarray_iter_elem_impl(generator, ctx, ndarray, |generator, ctx, hooks, elem| {
// The basic block to go to when...
// - np.any() sees a `true`, then `result` is set from `false` to `true` and short-circuit.
// - np.all() sees a `false`, then `result` is set from `true` to `false` and short-circuit.
let on_hit_bb =
ctx.ctx.insert_basic_block_after(hooks.break_bb, format!("{fn_name}.on_hit").as_str());
let elem_is_truthy = call_bool(generator, ctx, (elem_ty, elem)).unwrap().into_int_value();
let (on_true, on_false) = match kind {
AnyOrAll::IsAny => (on_hit_bb, hooks.continue_bb),
AnyOrAll::IsAll => (hooks.continue_bb, on_hit_bb),
};
ctx.builder.build_conditional_branch(elem_is_truthy, on_true, on_false).unwrap();
// Begin inserting into `on_hit_bb`
ctx.builder.position_at_end(on_hit_bb);
ctx.builder.build_store(result_ptr, llvm_i8.const_int(on_hit, false)).unwrap();
ctx.builder.build_unconditional_branch(hooks.break_bb).unwrap();
Ok(())
})?;
// Load `result` and return it
let result = ctx.builder.build_load(result_ptr, "result").unwrap();
Ok(result.into_int_value())
}
/// Helper function to generate LLVM IR for `np.any()` and `np.all()`.
fn gen_ndarray_any_all_helper<'ctx>(
kind: AnyOrAll,
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<IntValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let in_ty = fun.0.args[0].ty;
let (elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, in_ty);
let in_arg = args[0].1.clone().to_basic_value_enum(context, generator, elem_ty)?;
let ndarray = NDArrayValue::from_ptr_val(in_arg.into_pointer_value(), llvm_usize, None);
call_ndarray_any_all_impl(generator, context, kind, elem_ty, ndarray)
}
/// Generates LLVM IR for `np.any()`.
pub fn gen_ndarray_any<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<IntValue<'ctx>, String> {
gen_ndarray_any_all_helper(AnyOrAll::IsAny, context, obj, fun, args, generator)
}
/// Generates LLVM IR for `np.all()`.
pub fn gen_ndarray_all<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<IntValue<'ctx>, String> {
gen_ndarray_any_all_helper(AnyOrAll::IsAll, context, obj, fun, args, generator)
}

View File

@ -462,6 +462,14 @@ pub fn gen_for<G: CodeGenerator>(
Ok(()) Ok(())
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BreakContinueHooks<'ctx> {
/// [`BasicBlock`] to branch to for `break`-ing out of the loop.
pub break_bb: BasicBlock<'ctx>,
/// [`BasicBlock`] to branch to for `continue`-ing to the next iteration in the loop.
pub continue_bb: BasicBlock<'ctx>,
}
/// Generates a C-style `for` construct using lambdas, similar to the following C code: /// Generates a C-style `for` construct using lambdas, similar to the following C code:
/// ///
/// ```c /// ```c
@ -489,7 +497,8 @@ where
I: Clone, I: Clone,
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>, InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, BodyFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{ {
let current_bb = ctx.builder.get_insert_block().unwrap(); let current_bb = ctx.builder.get_insert_block().unwrap();
@ -520,7 +529,8 @@ where
} }
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
body(generator, ctx, loop_var.clone())?; let hooks = BreakContinueHooks { break_bb: cont_bb, continue_bb: update_bb };
body(generator, ctx, hooks, loop_var.clone())?;
if !ctx.is_terminated() { if !ctx.is_terminated() {
ctx.builder.build_unconditional_branch(update_bb).unwrap(); ctx.builder.build_unconditional_branch(update_bb).unwrap();
} }
@ -562,7 +572,12 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
) -> Result<(), String> ) -> Result<(), String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, BodyFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks,
IntValue<'ctx>,
) -> Result<(), String>,
{ {
let init_val_t = init_val.get_type(); let init_val_t = init_val.get_type();
@ -584,10 +599,10 @@ where
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap()) Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
}, },
|generator, ctx, i_addr| { |generator, ctx, hooks, i_addr| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
body(generator, ctx, i) body(generator, ctx, hooks, i)
}, },
|_, ctx, i_addr| { |_, ctx, i_addr| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
@ -698,7 +713,7 @@ where
Ok(cond) Ok(cond)
}, },
|generator, ctx, (i_addr, _)| { |generator, ctx, _, (i_addr, _)| {
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
body_fn(generator, ctx, i) body_fn(generator, ctx, i)

View File

@ -497,6 +497,8 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim), PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim),
PrimDef::FunNpSin PrimDef::FunNpSin
| PrimDef::FunNpCos | PrimDef::FunNpCos
| PrimDef::FunNpTan | PrimDef::FunNpTan
@ -1757,6 +1759,24 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]);
let param_ty = &[(self.ndarray_num_ty, "a")];
let ret_ty = self.primitives.bool;
let var_map = &into_var_map([]);
let codegen_callback: Box<GenCallCallback> =
Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim {
PrimDef::FunNpAny => gen_ndarray_any,
PrimDef::FunNpAll => gen_ndarray_all,
_ => unreachable!(),
};
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
});
create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback)
}
fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) {
(prim.simple_name().into(), method_ty, prim.id()) (prim.simple_name().into(), method_ty, prim.id())
} }

View File

@ -100,6 +100,8 @@ pub enum PrimDef {
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunSome, FunSome,
FunNpAny,
FunNpAll,
} }
/// Associated details of a [`PrimDef`] /// Associated details of a [`PrimDef`]
@ -251,6 +253,8 @@ impl PrimDef {
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunSome => fun("Some", None), PrimDef::FunSome => fun("Some", None),
PrimDef::FunNpAny => fun("np_any", None),
PrimDef::FunNpAll => fun("np_all", None),
} }
} }
} }

View File

@ -226,6 +226,8 @@ def patch(module):
module.np_full = np.full module.np_full = np.full
module.np_eye = np.eye module.np_eye = np.eye
module.np_identity = np.identity module.np_identity = np.identity
module.np_any = np.any
module.np_all = np.all
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

@ -1388,6 +1388,48 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_any():
x1 = np_identity(5)
y1 = np_any(x1)
output_ndarray_float_2(x1)
output_bool(y1)
x2 = np_identity(1)
y2 = np_any(x2)
output_ndarray_float_2(x2)
output_bool(y2)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_any(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x4 = np_zeros([3, 5])
y4 = np_any(x4)
output_ndarray_float_2(x4)
output_bool(y4)
def test_ndarray_all():
x1 = np_identity(5)
y1 = np_all(x1)
output_ndarray_float_2(x1)
output_bool(y1)
x2 = np_identity(1)
y2 = np_all(x2)
output_ndarray_float_2(x2)
output_bool(y2)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_all(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x4 = np_zeros([3, 5])
y4 = np_all(x4)
output_ndarray_float_2(x4)
output_bool(y4)
def run() -> int32: def run() -> int32:
test_ndarray_ctor() test_ndarray_ctor()
test_ndarray_empty() test_ndarray_empty()
@ -1565,4 +1607,7 @@ def run() -> int32:
test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_any()
test_ndarray_all()
return 0 return 0