forked from M-Labs/nac3

core/ndstrides: checkpoint 2

lyken 2024-08-08 14:58:26 +08:00
parent bcd35544cc
commit 937b36dcfd
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
16 changed files with 1459 additions and 703 deletions

File diff suppressed because it is too large

View File

@ -953,9 +953,9 @@ pub fn call_ndarray_calc_broadcast_index<
) )
} }
pub fn call_nac3_throw_dummy_error<'ctx, G: CodeGenerator + ?Sized>( pub fn call_nac3_throw_dummy_error<G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'_, '_>,
) { ) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_throw_dummy_error"); let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_throw_dummy_error");
CallFunction::begin(generator, ctx, &name).returning_void(); CallFunction::begin(generator, ctx, &name).returning_void();

View File

@ -3,9 +3,9 @@ use crate::codegen::{CodeGenContext, CodeGenerator};
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}". // When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64". // When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use] #[must_use]
pub fn get_sizet_dependent_function_name<'ctx, G: CodeGenerator + ?Sized>( pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'_, '_>,
name: &str, name: &str,
) -> String { ) -> String {
let mut name = name.to_owned(); let mut name = name.to_owned();
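Given the naming scheme documented above, a hypothetical call site might look like the following; "__nac3_ndarray_len" is only an illustrative symbol name, not necessarily one that exists in IRRT.

    let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
    // 32-bit size_type resolves to "__nac3_ndarray_len"
    // 64-bit size_type resolves to "__nac3_ndarray_len64"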

View File

@ -48,9 +48,7 @@ struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
field_types: Vec<BasicTypeEnum<'ctx>>, field_types: Vec<BasicTypeEnum<'ctx>>,
} }
impl<'ctx, 'a, 'b, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
for TypeFieldTraversal<'ctx, 'a, G>
{
type Out<M> = (); type Out<M> = ();
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Out<M> { fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Out<M> {
@ -204,6 +202,6 @@ impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel<S>> {
M: Model<'ctx>, M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>, GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{ {
self.gep(ctx, get_field).store(ctx, value) self.gep(ctx, get_field).store(ctx, value);
} }
} }

View File

@ -27,14 +27,14 @@ pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ?
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions. /// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
/// The [`IntKind`] is automatically inferred. /// The [`IntKind`] is automatically inferred.
pub fn gen_for_model_auto<'ctx, 'a, G, F, I, R>( pub fn gen_for_model_auto<'ctx, 'a, G, F, I>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
start: Int<'ctx, I>, start: Int<'ctx, I>,
stop: Int<'ctx, I>, stop: Int<'ctx, I>,
step: Int<'ctx, I>, step: Int<'ctx, I>,
body: F, body: F,
) -> Result<R, String> ) -> Result<(), String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
F: FnOnce( F: FnOnce(
@ -42,7 +42,7 @@ where
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>, BreakContinueHooks<'ctx>,
Int<'ctx, I>, Int<'ctx, I>,
) -> Result<R, String>, ) -> Result<(), String>,
I: IntKind<'ctx> + Default, I: IntKind<'ctx> + Default,
{ {
let int_model = IntModel(I::default()); let int_model = IntModel(I::default());
@ -60,3 +60,32 @@ where
step.value, step.value,
) )
} }
/// Like [`gen_if_callback`] with [`Model`] abstractions and without the `else` block.
pub fn gen_if_model<'ctx, 'a, G, ThenFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
cond: Int<'ctx, Bool>,
then: ThenFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
{
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then");
let end_bb = ctx.ctx.insert_basic_block_after(then_bb, "if.end");
// Inserting into `current_bb`.
ctx.builder.build_conditional_branch(cond.value, then_bb, end_bb).unwrap();
// Inserting into `then_bb`
ctx.builder.position_at_end(then_bb);
then(generator, ctx)?;
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition to `end_bb` for continuation.
ctx.builder.position_at_end(end_bb);
Ok(())
}
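For reference, a minimal sketch of how `gen_if_model` is meant to be used; it mirrors the call in the new `functions.rs` further down, where an index is stored only when a comparison holds (the `update_index`, `pextremum_index`, and `i` values are assumed to come from the surrounding codegen):

    gen_if_model(generator, ctx, update_index, |_generator, ctx| {
        pextremum_index.store(ctx, i);
        Ok(())
    })
    .unwrap();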

View File

@ -99,8 +99,7 @@ fn create_empty_ndarray<'ctx, G>(
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
{ {
let shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let shape = shape.value.get(generator, ctx, |f| f.items, "shape");
let ndarray = let ndarray =
NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray"); NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray");
@ -248,8 +247,7 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx); split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx);
// Process `shape` // Process `shape`
let broadcast_shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let broadcast_shape = broadcast_shape.value.get(generator, ctx, |f| f.items, "shape");
// NOTE: shape.size should equal to `broadcasted_ndims`. // NOTE: shape.size should equal to `broadcasted_ndims`.
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims); let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
call_nac3_ndarray_util_assert_shape_no_negative( call_nac3_ndarray_util_assert_shape_no_negative(
@ -300,8 +298,7 @@ pub fn gen_ndarray_reshape<'ctx>(
// Process the shape input from user and resolve negative indices. // Process the shape input from user and resolve negative indices.
// The resulting `new_shape`'s size should be equal to reshaped_ndims. // The resulting `new_shape`'s size should be equal to reshaped_ndims.
// This is ensured by the typechecker. // This is ensured by the typechecker.
let new_shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let new_shape = new_shape.value.get(generator, ctx, |f| f.items, "new_shape");
// Resolve unknown dimensions & validate `new_shape`. // Resolve unknown dimensions & validate `new_shape`.
let new_ndims = sizet_model.constant(generator, ctx.ctx, reshaped_ndims); let new_ndims = sizet_model.constant(generator, ctx.ctx, reshaped_ndims);
@ -353,7 +350,7 @@ pub fn gen_ndarray_arange<'ctx>(
// Create data and set elements // Create data and set elements
ndarray.create_data(generator, ctx); ndarray.create_data(generator, ctx);
ndarray.foreach(generator, ctx, |_generator, ctx, _hooks, i, pelement| { ndarray.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
let val = let val =
ctx.builder.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val").unwrap(); ctx.builder.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val").unwrap();
ctx.builder.build_store(pelement, val).unwrap(); ctx.builder.build_store(pelement, val).unwrap();
@ -495,10 +492,9 @@ pub fn gen_ndarray_transpose<'ctx>(
// Parse argument #2 axes // Parse argument #2 axes
let in_axes_ty = fun.0.args[1].ty; let in_axes_ty = fun.0.args[1].ty;
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?; let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
let in_axes = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty);
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty);
let num_axes = ndarray.get_ndims(generator, ctx.ctx); let num_axes = ndarray.get_ndims(generator, ctx.ctx);
let axes = in_axes.value.get(generator, ctx, |f| f.items, "axes");
call_nac3_ndarray_transpose( call_nac3_ndarray_transpose(
generator, generator,

View File

@ -1,6 +1,5 @@
use super::model::*; use super::model::*;
use super::structure::cslice::CSlice; use super::structure::cslice::CSlice;
use super::structure::ndarray::broadcast::broadcast_all_ndarrays;
use super::{ use super::{
super::symbol_resolver::ValueEnum, super::symbol_resolver::ValueEnum,
expr::destructure_range, expr::destructure_range,
@ -438,7 +437,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
let value = let value =
split_scalar_or_ndarray(generator, ctx, value, value_ty).as_ndarray(generator, ctx); split_scalar_or_ndarray(generator, ctx, value, value_ty).as_ndarray(generator, ctx);
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &vec![target, value]); let broadcast_result = NDArrayObject::broadcast_all(generator, ctx, &[target, value]);
let target = broadcast_result.ndarrays[0]; let target = broadcast_result.ndarrays[0];
let value = broadcast_result.ndarrays[1]; let value = broadcast_result.ndarrays[1];

View File

@ -33,44 +33,38 @@ impl<'ctx, Item: Model<'ctx>, Size: IntKind<'ctx>> StructKind<'ctx> for List<Ite
} }
} }
pub struct ListObject<'ctx, Item: Model<'ctx>, Size: IntKind<'ctx>> { /// A NAC3 Python List object.
pub struct ListObject<'ctx> {
/// Typechecker type of the list items /// Typechecker type of the list items
pub item_type: Type, pub item_type: Type,
pub value: Ptr<'ctx, StructModel<List<Item, Size>>>, pub value: Ptr<'ctx, StructModel<List<AnyModel<'ctx>, SizeT>>>,
} }
impl<'ctx, Item: Model<'ctx>, Len: IntKind<'ctx>> ListObject<'ctx, Item, Len> { impl<'ctx> ListObject<'ctx> {
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`]. /// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
///
/// - The `Item` model has to be manually provided, and should match the
/// `get_llvm_type()` of `ty` and the `get_type()`. You may want to use
/// [`AnyModel`] if `ty`'s type is not knowable statically.
pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>( pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
val: V, list_val: V,
ty: Type, list_type: Type,
item_model: Item,
len_model: Len,
) -> Self { ) -> Self {
let plist_model = PtrModel(StructModel(List { item: item_model, len: len_model }));
// Check typechecker type and extract `item_type` // Check typechecker type and extract `item_type`
let item_type = match &*ctx.unifier.get_ty(ty) { let item_type = match &*ctx.unifier.get_ty(list_type) {
TypeEnum::TObj { obj_id, params, .. } TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{ {
iter_type_vars(params).next().unwrap().ty // Extract `item_type` iter_type_vars(params).next().unwrap().ty // Extract `item_type`
} }
_ => panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(ty)), _ => {
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(list_type))
}
}; };
// LLVM types of `item_model` and `ty` should match let item_model = AnyModel(ctx.get_llvm_type(generator, item_type));
let llvm_ty = ctx.get_llvm_type(generator, ty); let plist_model = PtrModel(StructModel(List { item: item_model, len: SizeT }));
item_model.check_type(generator, ctx.ctx, llvm_ty).unwrap();
// Create object // Create object
let val = plist_model.check_value(generator, ctx.ctx, val).unwrap(); let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap();
ListObject { item_type: item_type, value: val } ListObject { item_type, value }
} }
} }
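With this change the item model is derived from the list's typechecker type via `ctx.get_llvm_type()` rather than passed in by the caller. A sketch of the new call shape, assuming `list_val` and `list_ty` come from evaluating a list argument elsewhere:

    let list = ListObject::from_value_and_type(generator, ctx, list_val, list_ty);
    let len = list.value.gep(ctx, |f| f.len).load(generator, ctx, "len");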

View File

@ -71,62 +71,64 @@ pub struct BroadcastAllResult<'ctx> {
pub ndarrays: Vec<NDArrayObject<'ctx>>, pub ndarrays: Vec<NDArrayObject<'ctx>>,
} }
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently. impl<'ctx> NDArrayObject<'ctx> {
pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>( // TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
generator: &mut G, pub fn broadcast_all<G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, generator: &mut G,
ndarrays: &Vec<NDArrayObject<'ctx>>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> BroadcastAllResult<'ctx> { ndarrays: &[Self],
assert!(!ndarrays.is_empty()); ) -> BroadcastAllResult<'ctx> {
assert!(!ndarrays.is_empty());
let sizet_model = IntModel(SizeT); let sizet_model = IntModel(SizeT);
let shape_model = StructModel(ShapeEntry); let shape_model = StructModel(ShapeEntry);
let broadcast_ndims = get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims)); let broadcast_ndims = get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims));
// Prepare input shape entries // Prepare input shape entries
let num_shape_entries = let num_shape_entries =
sizet_model.constant(generator, ctx.ctx, u64::try_from(ndarrays.len()).unwrap()); sizet_model.constant(generator, ctx.ctx, u64::try_from(ndarrays.len()).unwrap());
let shape_entries = let shape_entries =
shape_model.array_alloca(generator, ctx, num_shape_entries.value, "shape_entries"); shape_model.array_alloca(generator, ctx, num_shape_entries.value, "shape_entries");
for (i, ndarray) in ndarrays.iter().enumerate() { for (i, ndarray) in ndarrays.iter().enumerate() {
let i = sizet_model.constant(generator, ctx.ctx, i as u64).value; let i = sizet_model.constant(generator, ctx.ctx, i as u64).value;
let shape_entry = shape_entries.offset(generator, ctx, i, "shape_entry"); let shape_entry = shape_entries.offset(generator, ctx, i, "shape_entry");
let this_ndims = ndarray.value.get(generator, ctx, |f| f.ndims, "this_ndims"); let this_ndims = ndarray.value.get(generator, ctx, |f| f.ndims, "this_ndims");
shape_entry.set(ctx, |f| f.ndims, this_ndims); shape_entry.set(ctx, |f| f.ndims, this_ndims);
let this_shape = ndarray.value.get(generator, ctx, |f| f.shape, "this_shape"); let this_shape = ndarray.value.get(generator, ctx, |f| f.shape, "this_shape");
shape_entry.set(ctx, |f| f.shape, this_shape); shape_entry.set(ctx, |f| f.shape, this_shape);
} }
// Prepare destination // Prepare destination
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims); let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
let broadcast_shape = let broadcast_shape =
sizet_model.array_alloca(generator, ctx, broadcast_ndims_llvm.value, "dst_shape"); sizet_model.array_alloca(generator, ctx, broadcast_ndims_llvm.value, "dst_shape");
// Compute the target broadcast shape `dst_shape` for all ndarrays. // Compute the target broadcast shape `dst_shape` for all ndarrays.
call_nac3_ndarray_broadcast_shapes( call_nac3_ndarray_broadcast_shapes(
generator, generator,
ctx, ctx,
num_shape_entries, num_shape_entries,
shape_entries, shape_entries,
broadcast_ndims_llvm, broadcast_ndims_llvm,
broadcast_shape, broadcast_shape,
); );
// Now that we know about the broadcasting shape, broadcast all the inputs. // Now that we know about the broadcasting shape, broadcast all the inputs.
// Broadcast all the inputs to shape `dst_shape`. // Broadcast all the inputs to shape `dst_shape`.
let broadcast_ndarrays: Vec<_> = ndarrays let broadcast_ndarrays: Vec<_> = ndarrays
.into_iter() .iter()
.map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape)) .map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape))
.collect_vec(); .collect_vec();
BroadcastAllResult { BroadcastAllResult {
ndims: broadcast_ndims, ndims: broadcast_ndims,
shape: broadcast_shape, shape: broadcast_shape,
ndarrays: broadcast_ndarrays, ndarrays: broadcast_ndarrays,
}
} }
} }
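A sketch of the new associated-function form, following the updated `gen_setitem` call site shown earlier; `target` and `value` are assumed to already be `NDArrayObject`s:

    let result = NDArrayObject::broadcast_all(generator, ctx, &[target, value]);
    let target = result.ndarrays[0];
    let value = result.ndarrays[1];
    // `result.ndims` and `result.shape` describe the common broadcast shape.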

View File

@ -0,0 +1,477 @@
use inkwell::{
values::{BasicValue, FloatValue, IntValue},
FloatPredicate, IntPredicate,
};
use itertools::Itertools;
use crate::{
codegen::{
llvm_intrinsics,
model::{
util::{gen_for_model_auto, gen_if_model},
*,
},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
use super::{scalar::ScalarObject, NDArrayObject};
/// Convenience function to crash the program when types of arguments are not supported.
/// Meant to be debugged with a stack trace.
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
where
I: IntoIterator<Item = Type>,
{
unreachable!(
"unsupported types found '{}'",
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
)
}
#[derive(Debug, Clone, Copy)]
pub enum FloorOrCeil {
Floor,
Ceil,
}
#[derive(Debug, Clone, Copy)]
pub enum MinOrMax {
Min,
Max,
}
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64]
}
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.uint32, ctx.primitives.uint64]
}
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
}
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
]
}
fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
scalar: ScalarObject<'ctx>,
target_int_dtype: Type,
handle_float: HandleFloatFn,
) -> ScalarObject<'ctx>
where
G: CodeGenerator + ?Sized,
HandleFloatFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
{
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
// Special handling for floats
let n = scalar.value.into_float_value();
handle_float(generator, ctx, n)
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
let n = scalar.value.into_int_value();
if n.get_type().get_bit_width() <= target_int_dtype_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, target_int_dtype_llvm, "zext").unwrap()
} else {
ctx.builder.build_int_truncate(n, target_int_dtype_llvm, "trunc").unwrap()
}
} else {
unsupported_type(ctx, [scalar.dtype]);
};
assert_eq!(target_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
ScalarObject { value: result.into(), dtype: target_int_dtype }
}
impl<'ctx> ScalarObject<'ctx> {
/// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed.
/// Panic otherwise.
pub fn compare<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: ScalarObject<'ctx>,
rhs: ScalarObject<'ctx>,
int_predicate: IntPredicate,
float_predicate: FloatPredicate,
name: &str,
) -> Int<'ctx, Bool> {
if !ctx.unifier.unioned(lhs.dtype, rhs.dtype) {
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
}
let bool_model = IntModel(Bool);
let common_ty = lhs.dtype;
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
let lhs = lhs.value.into_float_value();
let rhs = rhs.value.into_float_value();
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
let lhs = lhs.value.into_int_value();
let rhs = rhs.value.into_int_value();
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
} else {
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
};
bool_model.check_value(generator, ctx.ctx, result).unwrap()
}
/// Invoke NAC3's builtin `int32()`.
#[must_use]
pub fn cast_to_int32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.int32,
|_generator, ctx, input| {
let n =
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
},
)
}
/// Invoke NAC3's builtin `int64()`.
#[must_use]
pub fn cast_to_int64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.int64,
|_generator, ctx, input| {
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
},
)
}
/// Invoke NAC3's builtin `uint32()`.
#[must_use]
pub fn cast_to_uint32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.uint32,
|_generator, ctx, n| {
let n_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int32 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder
.build_select(
n_gez,
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
to_int32,
"conv",
)
.unwrap()
.into_int_value()
},
)
}
/// Invoke NAC3's builtin `uint64()`.
#[must_use]
pub fn cast_to_uint64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.uint64,
|_generator, ctx, n| {
let val_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int64 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder
.build_select(val_gez, to_uint64, to_int64, "conv")
.unwrap()
.into_int_value()
},
)
}
/// Invoke NAC3's builtin `bool()`.
#[must_use]
pub fn cast_to_bool<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
self.value.into_int_value()
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.value.into_int_value();
ctx.builder
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
.unwrap()
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
ctx.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
.unwrap()
} else {
unsupported_type(ctx, [self.dtype])
};
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `round()`.
#[must_use]
pub fn round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target_int_dtype: Type,
) -> Self {
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "round").unwrap()
} else {
unsupported_type(ctx, [self.dtype, target_int_dtype])
};
ScalarObject { dtype: target_int_dtype, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `np_round()`.
///
/// NOTE: `np.round()` behaves differently from `round()` in its handling of
/// "tie" cases and in its return type.
#[must_use]
pub fn np_round<G: CodeGenerator + ?Sized>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
llvm_intrinsics::call_float_rint(ctx, n, None)
} else {
unsupported_type(ctx, [self.dtype])
};
ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `min()` or `max()`.
fn min_or_max_helper(
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
a: Self,
b: Self,
) -> Self {
if !ctx.unifier.unioned(a.dtype, b.dtype) {
unsupported_type(ctx, [a.dtype, b.dtype])
}
let common_dtype = a.dtype;
if ctx.unifier.unioned(common_dtype, ctx.primitives.float) {
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
};
let result =
function(ctx, a.value.into_float_value(), b.value.into_float_value(), None);
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any(
common_dtype,
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
) {
// Treat bool as an unsigned int since that is convenient
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_int_umin,
MinOrMax::Max => llvm_intrinsics::call_int_umax,
};
let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None);
ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype }
} else {
unsupported_type(ctx, [common_dtype])
}
}
/// Invoke NAC3's builtin `floor()` or `ceil()`.
#[must_use]
pub fn floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
target_int_dtype: Type,
) -> Self {
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.value.into_float_value();
let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "").unwrap();
ScalarObject { dtype: target_int_dtype, value: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.dtype])
}
}
}
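A sketch of how these scalar helpers compose; `scalar` is assumed to be a `ScalarObject` holding a float, and the target dtype is hoisted into a local only for readability:

    let int32 = ctx.primitives.int32;
    let as_i32 = scalar.cast_to_int32(generator, ctx);
    let rounded = scalar.round(generator, ctx, int32);
    let floored = scalar.floor_or_ceil(generator, ctx, FloorOrCeil::Floor, int32);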
impl<'ctx> NDArrayObject<'ctx> {
/// Helper function for NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
///
/// Generate LLVM IR to find the extremum and index of the **first** extremum value.
///
/// Care has also been taken to make the error messages match those of NumPy.
fn min_or_max_helper<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
on_empty_err_msg: &str,
) -> (ScalarObject<'ctx>, Int<'ctx, SizeT>) {
let sizet_model = IntModel(SizeT);
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
// If the ndarray is empty, throw an error.
let is_empty = self.is_empty(generator, ctx);
ctx.make_assert(
generator,
is_empty.value,
"0:ValueError",
on_empty_err_msg,
[None, None, None],
ctx.current_loc,
);
// Setup and initialize the extremum to be the first element in the ndarray
let pextremum_index = sizet_model.alloca(generator, ctx, "extremum_index");
let pextremum = ctx.builder.build_alloca(dtype_llvm, "extremum").unwrap();
let zero = sizet_model.const_0(generator, ctx.ctx);
pextremum_index.store(ctx, zero);
let first_scalar = self.get_nth(generator, ctx, zero);
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
// Find extremum
let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1
let stop = self.size(generator, ctx);
let step = sizet_model.const_1(generator, ctx.ctx);
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, _hooks, i| {
// See the "Notes" section in <https://numpy.org/doc/stable/reference/generated/numpy.min.html#numpy.min>
// for how `NaN` values have to be handled.
let scalar = self.get_nth(generator, ctx, i);
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
let new_extremum = ScalarObject::min_or_max_helper(ctx, kind, old_extremum, scalar);
// Check if new_extremum is more extreme than old_extremum.
let update_index = ScalarObject::compare(
generator,
ctx,
new_extremum,
old_extremum,
IntPredicate::NE,
FloatPredicate::ONE,
"",
);
gen_if_model(generator, ctx, update_index, |_generator, ctx| {
pextremum_index.store(ctx, i);
Ok(())
})
.unwrap();
Ok(())
})
.unwrap();
// Finally return the extremum and extremum index.
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
let extremum = ScalarObject { dtype: self.dtype, value: extremum };
(extremum, extremum_index)
}
/// Invoke NAC3's builtin `np_min()` or `np_max()`.
pub fn min_or_max<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
) -> ScalarObject<'ctx> {
let on_empty_err_msg = format!(
"zero-size array to reduction operation {} which has no identity",
match kind {
MinOrMax::Min => "minimum",
MinOrMax::Max => "maximum",
}
);
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).0
}
/// Invoke NAC3's builtin `np_argmin()` or `np_argmax()`.
pub fn argmin_or_argmax<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
) -> Int<'ctx, SizeT> {
let on_empty_err_msg = format!(
"attempt to get {} of an empty sequence",
match kind {
MinOrMax::Min => "argmin",
MinOrMax::Max => "argmax",
}
);
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).1
}
}
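A sketch of the new reduction entry points; `ndarray` is assumed to be a non-empty `NDArrayObject` (an empty one triggers the ValueError asserted above):

    let maximum = ndarray.min_or_max(generator, ctx, MinOrMax::Max); // ScalarObject
    let argmin = ndarray.argmin_or_argmax(generator, ctx, MinOrMax::Min); // Int<'ctx, SizeT>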

View File

@ -1,7 +1,4 @@
use crate::codegen::{ use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator};
irrt::call_nac3_ndarray_index, model::*, structure::ndarray::scalar::ScalarObject,
CodeGenContext, CodeGenerator,
};
use super::{scalar::ScalarOrNDArray, NDArrayObject}; use super::{scalar::ScalarOrNDArray, NDArrayObject};
@ -130,7 +127,7 @@ impl<'ctx> RustNDIndex<'ctx> {
dst_ndindex_ptr: Ptr<'ctx, StructModel<NDIndex>>, dst_ndindex_ptr: Ptr<'ctx, StructModel<NDIndex>>,
) { ) {
let ndindex_type_model = IntModel(NDIndexType::default()); let ndindex_type_model = IntModel(NDIndexType::default());
let i32_model = IntModel(Int32::default()); let i32_model = IntModel(Int32);
let user_slice_model = StructModel(UserSlice); let user_slice_model = StructModel(UserSlice);
// Set `dst_ndindex_ptr->type` // Set `dst_ndindex_ptr->type`
@ -178,6 +175,7 @@ impl<'ctx> RustNDIndex<'ctx> {
impl<'ctx> NDArrayObject<'ctx> { impl<'ctx> NDArrayObject<'ctx> {
/// Get the ndims [`Type`] after indexing with a given slice. /// Get the ndims [`Type`] after indexing with a given slice.
#[must_use]
pub fn deduce_ndims_after_indexing_with(&self, indexes: &[RustNDIndex<'ctx>]) -> u64 { pub fn deduce_ndims_after_indexing_with(&self, indexes: &[RustNDIndex<'ctx>]) -> u64 {
let mut ndims = self.ndims; let mut ndims = self.ndims;
for index in indexes { for index in indexes {
@ -235,11 +233,8 @@ impl<'ctx> NDArrayObject<'ctx> {
let subndarray = self.index(generator, ctx, indexes, name); let subndarray = self.index(generator, ctx, indexes, name);
if subndarray.is_unsized() { if subndarray.is_unsized() {
// NOTE: `np.size(self) == 0` is impossible. // NOTE: `np.size(self) == 0` here is never possible.
let pfirst = subndarray.get_nth_pelement(generator, ctx, zero, "pfirst"); ScalarOrNDArray::Scalar(subndarray.get_nth(generator, ctx, zero))
let first = ctx.builder.build_load(pfirst, "first").unwrap();
ScalarOrNDArray::Scalar(ScalarObject { dtype: self.dtype, value: first })
} else { } else {
ScalarOrNDArray::NDArray(subndarray) ScalarOrNDArray::NDArray(subndarray)
} }
@ -319,9 +314,9 @@ pub mod util {
}) })
}; };
let start = help(&start)?; let start = help(start)?;
let stop = help(&stop)?; let stop = help(stop)?;
let step = help(&step)?; let step = help(step)?;
RustNDIndex::Slice(RustUserSlice { start, stop, step }) RustNDIndex::Slice(RustUserSlice { start, stop, step })
} else { } else {

View File

@ -1,12 +1,11 @@
use inkwell::values::BasicValueEnum;
use itertools::Itertools; use itertools::Itertools;
use util::gen_for_model_auto; use util::gen_for_model_auto;
use crate::{ use crate::{
codegen::{ codegen::{
model::*, model::*,
structure::ndarray::{ structure::ndarray::{scalar::ScalarObject, NDArrayObject},
broadcast::broadcast_all_ndarrays, scalar::ScalarObject, NDArrayObject,
},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
typecheck::typedef::Type, typecheck::typedef::Type,
@ -14,179 +13,126 @@ use crate::{
use super::scalar::ScalarOrNDArray; use super::scalar::ScalarOrNDArray;
pub fn starmap_scalars_array_like<'ctx, 'a, F, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: &Vec<ScalarOrNDArray<'ctx>>,
mapping: F,
) -> Result<ScalarOrNDArray<'ctx>, String>
where
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
&Vec<ScalarObject<'ctx>>,
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{
assert!(!inputs.is_empty());
let sizet_model = IntModel(SizeT);
// Check if all inputs are ScalarObjects
let scalars: Option<Vec<_>> =
inputs.iter().map(|input| ScalarObject::try_from(input)).try_collect().ok();
match scalars {
Some(scalars) => {
// When inputs are all scalars, return a ScalarObject back
let i = sizet_model.const_0(generator, ctx.ctx);
let scalar = mapping(generator, ctx, i, &scalars)?;
Ok(ScalarOrNDArray::Scalar(scalar))
}
None => {
// When not all inputs are scalars, promote all non-ndarray inputs
// to ndarrays, do broadcast_shapes on them, and map.
let ndarrays =
inputs.into_iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays);
let start = sizet_model.const_0(generator, ctx.ctx);
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
let step = sizet_model.const_1(generator, ctx.ctx);
// Map element-wise and store results into `mapped_ndarray`.
let mapped_ndarray = gen_for_model_auto(
generator,
ctx,
start,
stop,
step,
move |generator, ctx, _hooks, i| {
let elements = ndarrays
.iter()
.map(|ndarray| {
let pelement = ndarray.get_nth_pelement(generator, ctx, i, "pelement");
let element = ctx.builder.build_load(pelement, "element").unwrap();
ScalarObject { value: element, dtype: ndarray.dtype }
})
.collect_vec();
let ret = mapping(generator, ctx, i, &elements)?;
// It might look weird but it is perfectly fine putting the allocation codegen
// here within `for`.
// The reason for doing this is to get the `dtype` out of `ret`, which is only
// available after running `mapping`.
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
generator,
ctx,
ret.dtype,
broadcast_result.ndims,
"mapped_ndarray",
);
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
mapped_ndarray.create_data(generator, ctx);
let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret");
ctx.builder.build_store(pret, ret.value).unwrap();
Ok(mapped_ndarray)
},
)?;
Ok(ScalarOrNDArray::NDArray(mapped_ndarray))
}
}
}
impl<'ctx> ScalarObject<'ctx> {
pub fn map<'a, F, G>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
mapping: F,
) -> Result<Self, String>
where
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{
let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like(
generator,
ctx,
&vec![ScalarOrNDArray::Scalar(*self)],
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
)?
else {
unreachable!()
};
Ok(ret)
}
}
impl<'ctx> NDArrayObject<'ctx> { impl<'ctx> NDArrayObject<'ctx> {
pub fn map<'a, F, G>( /// TODO: Document me. Has complex behavior.
&self, pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
mapping: F, ndarrays: &[Self],
ret_dtype: Type,
name: &str,
mapping: MappingFn,
) -> Result<Self, String> ) -> Result<Self, String>
where where
F: FnOnce( G: CodeGenerator + ?Sized,
MappingFn: FnOnce(
&mut G, &mut G,
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>, Int<'ctx, SizeT>,
ScalarObject<'ctx>, &[ScalarObject<'ctx>],
) -> Result<ScalarObject<'ctx>, String>, ) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{ {
let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like( let sizet_model = IntModel(SizeT);
// Broadcast inputs
let broadcast_result = NDArrayObject::broadcast_all(generator, ctx, ndarrays);
// Allocate the resulting ndarray
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
generator, generator,
ctx, ctx,
&vec![ScalarOrNDArray::NDArray(*self)], ret_dtype,
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), broadcast_result.ndims,
)? name,
else { );
unreachable!() mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
}; mapped_ndarray.create_data(generator, ctx);
Ok(ret)
}
}
impl<'ctx> ScalarOrNDArray<'ctx> { // Map element-wise and store results into `mapped_ndarray`.
pub fn map<'a, F, G>( let start = sizet_model.const_0(generator, ctx.ctx);
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
let step = sizet_model.const_1(generator, ctx.ctx);
gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
let elements =
ndarrays.iter().map(|ndarray| ndarray.get_nth(generator, ctx, i)).collect_vec();
let ret = mapping(generator, ctx, i, &elements)?;
let pret = mapped_ndarray.get_nth_pointer(generator, ctx, i, "pret");
ctx.builder.build_store(pret, ret.value).unwrap();
Ok(())
})?;
Ok(mapped_ndarray)
}
pub fn map<'a, G, Mapping>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type, ret_dtype: Type,
mapping: F, name: &str,
mapping: Mapping,
) -> Result<Self, String> ) -> Result<Self, String>
where where
F: FnOnce( G: CodeGenerator + ?Sized,
Mapping: FnOnce(
&mut G, &mut G,
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>, Int<'ctx, SizeT>,
ScalarObject<'ctx>, ScalarObject<'ctx>,
) -> Result<ScalarObject<'ctx>, String>, ) -> Result<BasicValueEnum<'ctx>, String>,
G: CodeGenerator + ?Sized,
{ {
match self { NDArrayObject::broadcasting_starmap(
ScalarOrNDArray::Scalar(scalar) => starmap_scalars_array_like( generator,
generator, ctx,
ctx, &[*self],
&vec![ScalarOrNDArray::Scalar(*scalar)], ret_dtype,
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]), name,
), |generator, ctx, i, scalars| {
ScalarOrNDArray::NDArray(ndarray) => { let value = mapping(generator, ctx, i, scalars[0])?;
ndarray.map(generator, ctx, mapping).map(ScalarOrNDArray::NDArray) Ok(ScalarObject { dtype: ret_dtype, value })
} },
)
}
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// TODO: Document me. Has complex behavior.
pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: &[Self],
ret_dtype: Type,
name: &str,
mapping: MappingFn,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
&[ScalarObject<'ctx>],
) -> Result<ScalarObject<'ctx>, String>,
{
let sizet_model = IntModel(SizeT);
// Check if all inputs are ScalarObjects
let all_scalars: Option<Vec<_>> =
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
if let Some(scalars) = all_scalars {
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
let scalar = mapping(generator, ctx, i, &scalars)?;
Ok(ScalarOrNDArray::Scalar(scalar))
} else {
// Promote all input to ndarrays and map through them.
let inputs = inputs.iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
let ndarray = NDArrayObject::broadcasting_starmap(
generator, ctx, &inputs, ret_dtype, name, mapping,
)?;
Ok(ScalarOrNDArray::NDArray(ndarray))
} }
} }
} }
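A sketch of the new mapping entry point on `ScalarOrNDArray`; the example squares a float input, with `input` assumed to come from `split_scalar_or_ndarray` and the return dtype hoisted into a local for readability:

    let float_ty = ctx.primitives.float;
    let squared = ScalarOrNDArray::broadcasting_starmap(
        generator,
        ctx,
        &[input],
        float_ty,
        "squared",
        |_generator, ctx, _i, scalars| {
            let x = scalars[0].value.into_float_value();
            let x2 = ctx.builder.build_float_mul(x, x, "x2").unwrap();
            Ok(ScalarObject { dtype: float_ty, value: x2.into() })
        },
    )?;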

View File

@ -1,4 +1,5 @@
pub mod broadcast; pub mod broadcast;
pub mod functions;
pub mod indexing; pub mod indexing;
pub mod mapping; pub mod mapping;
pub mod scalar; pub mod scalar;
@ -22,8 +23,9 @@ use inkwell::{
context::Context, context::Context,
types::BasicType, types::BasicType,
values::{BasicValue, BasicValueEnum, PointerValue}, values::{BasicValue, BasicValueEnum, PointerValue},
AddressSpace, AddressSpace, IntPredicate,
}; };
use scalar::ScalarObject;
use util::{call_memcpy_model, gen_for_model_auto}; use util::{call_memcpy_model, gen_for_model_auto};
pub struct NpArrayFields<'ctx, F: FieldTraversal<'ctx>> { pub struct NpArrayFields<'ctx, F: FieldTraversal<'ctx>> {
@ -52,6 +54,7 @@ impl<'ctx> StructKind<'ctx> for NpArray {
} }
} }
/// A NAC3 Python ndarray object.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct NDArrayObject<'ctx> { pub struct NDArrayObject<'ctx> {
pub dtype: Type, pub dtype: Type,
@ -116,7 +119,7 @@ impl<'ctx> NDArrayObject<'ctx> {
/// Get the pointer to the n-th (0-based) element. /// Get the pointer to the n-th (0-based) element.
/// ///
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
pub fn get_nth_pelement<G: CodeGenerator + ?Sized>( pub fn get_nth_pointer<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -131,6 +134,18 @@ impl<'ctx> NDArrayObject<'ctx> {
.unwrap() .unwrap()
} }
/// Get the n-th (0-based) scalar.
pub fn get_nth<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
) -> ScalarObject<'ctx> {
let p = self.get_nth_pointer(generator, ctx, nth, "value");
let value = ctx.builder.build_load(p, "value").unwrap();
ScalarObject { dtype: self.dtype, value }
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
/// ///
/// Please refer to the IRRT implementation to see its purpose. /// Please refer to the IRRT implementation to see its purpose.
@ -210,7 +225,7 @@ impl<'ctx> NDArrayObject<'ctx> {
name: &str, name: &str,
) -> Self { ) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let ndims = extract_ndims(&mut ctx.unifier, ndims); let ndims = extract_ndims(&ctx.unifier, ndims);
Self::alloca_uninitialized(generator, ctx, dtype, ndims, name) Self::alloca_uninitialized(generator, ctx, dtype, ndims, name)
} }
@ -224,7 +239,22 @@ impl<'ctx> NDArrayObject<'ctx> {
sizet_model.constant(generator, ctx, self.ndims) sizet_model.constant(generator, ctx, self.ndims)
} }
/// Return true if this ndarray is unsized. /// Get whether this ndarray's `np.size` is `0`, i.e., it contains no elements.
pub fn is_empty<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
let sizet_model = IntModel(SizeT);
let size = self.size(generator, ctx);
size.compare(ctx, IntPredicate::EQ, sizet_model.const_0(generator, ctx.ctx), "is_empty")
}
/// Return true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
///
/// This is a statically known property of ndarrays, which is why this returns
/// a Rust boolean instead of a [`BasicValue`].
#[must_use] #[must_use]
pub fn is_unsized(&self) -> bool { pub fn is_unsized(&self) -> bool {
self.ndims == 0 self.ndims == 0
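The distinction between the two checks above: `is_unsized()` is a compile-time (Rust) fact derived from `ndims`, while `is_empty()` emits an LLVM comparison on the runtime size. A sketch, with `ndarray` assumed to exist in the caller:

    if ndarray.is_unsized() {
        // ndims == 0 is known statically: the ndarray wraps a single scalar.
    }
    let is_empty = ndarray.is_empty(generator, ctx); // Int<'ctx, Bool>, only known at runtime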
@ -273,7 +303,7 @@ impl<'ctx> NDArrayObject<'ctx> {
) { ) {
assert_eq!(self.ndims, src_ndarray.ndims); assert_eq!(self.ndims, src_ndarray.ndims);
let src_shape = src_ndarray.value.get(generator, ctx, |f| f.shape, "src_shape"); let src_shape = src_ndarray.value.get(generator, ctx, |f| f.shape, "src_shape");
self.copy_shape_from_array(generator, ctx, src_shape) self.copy_shape_from_array(generator, ctx, src_shape);
} }
/// Copy strides dimensions from an array. /// Copy strides dimensions from an array.
@ -298,14 +328,14 @@ impl<'ctx> NDArrayObject<'ctx> {
) { ) {
assert_eq!(self.ndims, src_ndarray.ndims); assert_eq!(self.ndims, src_ndarray.ndims);
let src_strides = src_ndarray.value.get(generator, ctx, |f| f.strides, "src_strides"); let src_strides = src_ndarray.value.get(generator, ctx, |f| f.strides, "src_strides");
self.copy_strides_from_array(generator, ctx, src_strides) self.copy_strides_from_array(generator, ctx, src_strides);
} }
/// Loop through every element pointer in the ndarray in its flatten view. /// Iterate through every element pointer in the ndarray in its flatten view.
/// ///
/// `body` also has access to [`BreakContinueHooks`] to short-circuit and to the element's /// `body` also has access to [`BreakContinueHooks`] to short-circuit and to the element's
/// index. The given element pointer has also been cast to the LLVM type of this ndarray's `dtype`. /// index. The given element pointer has also been cast to the LLVM type of this ndarray's `dtype`.
pub fn foreach<'a, G, F>( pub fn foreach_pointer<'a, G, F>(
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
@ -328,12 +358,36 @@ impl<'ctx> NDArrayObject<'ctx> {
let step = sizet_model.const_1(generator, ctx.ctx); let step = sizet_model.const_1(generator, ctx.ctx);
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, hooks, i| { gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, hooks, i| {
let pelement = self.get_nth_pelement(generator, ctx, i, "element"); let pelement = self.get_nth_pointer(generator, ctx, i, "element");
body(generator, ctx, hooks, i, pelement) body(generator, ctx, hooks, i, pelement)
}) })
} }
/// Fill the NDArray with a value. /// Iterate through every scalar in this ndarray.
pub fn foreach<'a, G, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<(), String>,
{
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
let value = ctx.builder.build_load(p, "value").unwrap();
let scalar = ScalarObject { dtype: self.dtype, value };
body(generator, ctx, hooks, i, scalar)
})
}
/// Fill the ndarray with a value.
/// ///
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
pub fn fill<G: CodeGenerator + ?Sized>( pub fn fill<G: CodeGenerator + ?Sized>(
@ -342,11 +396,11 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
fill_value: BasicValueEnum<'ctx>, fill_value: BasicValueEnum<'ctx>,
) { ) {
self.foreach(generator, ctx, |_generator, ctx, _hooks, _i, pelement| { self.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, _i, pelement| {
ctx.builder.build_store(pelement, fill_value).unwrap(); ctx.builder.build_store(pelement, fill_value).unwrap();
Ok(()) Ok(())
}) })
.unwrap() .unwrap();
} }
/// Create a reshaped view on this ndarray like `np.reshape()`. /// Create a reshaped view on this ndarray like `np.reshape()`.
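A sketch of the renamed `foreach_pointer` / new `foreach` pair and `fill` from the hunks above; `ndarray` is assumed to be a float ndarray whose data has already been created:

    let f64_zero = ctx.ctx.f64_type().const_zero();
    ndarray.fill(generator, ctx, f64_zero.into());
    ndarray.foreach(generator, ctx, |_generator, _ctx, _hooks, _i, scalar| {
        // `scalar` carries this ndarray's dtype and the element loaded at index `_i`.
        let _element = scalar.value.into_float_value();
        Ok(())
    })?;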

View File

@ -8,6 +8,10 @@ use crate::{
use super::NDArrayObject; use super::NDArrayObject;
/// An LLVM numpy scalar with its [`Type`]. /// An LLVM numpy scalar with its [`Type`].
///
/// Intended to be used with [`ScalarOrNDArray`].
///
/// A scalar does not have to be an actual number. It could be an arbitrary object.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct ScalarObject<'ctx> { pub struct ScalarObject<'ctx> {
pub dtype: Type, pub dtype: Type,
@ -55,6 +59,22 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
} }
} }
#[must_use]
pub fn into_scalar(&self) -> ScalarObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"),
ScalarOrNDArray::Scalar(scalar) => *scalar,
}
}
#[must_use]
pub fn into_ndarray(&self) -> NDArrayObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(_scalar) => panic!("Got Scalar"),
}
}
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
/// - If this is an ndarray, the ndarray is returned. /// - If this is an ndarray, the ndarray is returned.
/// - If this is a scalar, an unsized ndarray view is created on it. /// - If this is a scalar, an unsized ndarray view is created on it.
@ -68,6 +88,14 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
ScalarOrNDArray::Scalar(scalar) => scalar.as_ndarray(generator, ctx), ScalarOrNDArray::Scalar(scalar) => scalar.as_ndarray(generator, ctx),
} }
} }
#[must_use]
pub fn dtype(&self) -> Type {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.dtype,
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
}
}
} }
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> { impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> {
@ -76,7 +104,18 @@ impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> {
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> { fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value { match value {
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
ScalarOrNDArray::NDArray(_) => Err(()), ScalarOrNDArray::NDArray(_ndarray) => Err(()),
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(_scalar) => Err(()),
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
} }
} }
} }
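A sketch of the new accessors, assuming `input` / `input_ty` come from the caller and the typechecker guarantees the value is an ndarray:

    let split = split_scalar_or_ndarray(generator, ctx, input, input_ty);
    let dtype = split.dtype();
    let ndarray = split.into_ndarray(); // panics if `split` is actually a scalar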

View File

@ -2,15 +2,11 @@ use inkwell::values::BasicValueEnum;
use util::gen_for_model_auto; use util::gen_for_model_auto;
use crate::{ use crate::{
codegen::{ codegen::{model::*, structure::list::ListObject, CodeGenContext, CodeGenerator},
model::*,
structure::list::{List, ListObject},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{Type, TypeEnum}, typecheck::typedef::{Type, TypeEnum},
}; };
/// Parse a NumPy-like "int sequence" input and return the int sequence as a [`ListObject`] /// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
/// ///
/// * `sequence` - The `sequence` parameter. /// * `sequence` - The `sequence` parameter.
/// * `sequence_ty` - The typechecker type of `sequence` /// * `sequence_ty` - The typechecker type of `sequence`
@ -20,99 +16,97 @@ use crate::{
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` /// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` /// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
/// ///
/// `int32` values will be sign-extended to `SizeT` /// All `int32` values will be sign-extended to `SizeT`.
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
sequence: BasicValueEnum<'ctx>, input_sequence: BasicValueEnum<'ctx>,
sequence_ty: Type, input_sequence_ty: Type,
) -> ListObject<'ctx, IntModel<SizeT>, SizeT> { ) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel<SizeT>>) {
let sizet_model = IntModel(SizeT); let sizet_model = IntModel(SizeT);
let list_model = StructModel(List { len: SizeT, item: IntModel(SizeT) });
let zero = sizet_model.const_0(generator, ctx.ctx); let zero = sizet_model.const_0(generator, ctx.ctx);
let one = sizet_model.const_1(generator, ctx.ctx); let one = sizet_model.const_1(generator, ctx.ctx);
// The result `list` to return. // The result `list` to return.
let result = list_model.alloca(generator, ctx, "result_sequence"); match &*ctx.unifier.get_ty(input_sequence_ty) {
match &*ctx.unifier.get_ty(sequence_ty) {
TypeEnum::TObj { obj_id, .. } TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{ {
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
let in_sequence_model =
PtrModel(StructModel(List { item: IntModel(Int32), len: SizeT }));
let in_sequence = in_sequence_model.check_value(generator, ctx.ctx, sequence).unwrap();
/* // Check `input_sequence`
Reference code: let input_sequence =
``` ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty);
result.size = sequence.size;
result.data = __builtin_alloca(sizeof(SizeT) * sequence.size);
for (SizeT i = 0; i < sequence.size; i++) {
result.data[i] = (SizeT) sequence.data[i];
}
return result
```
*/
let ndims = in_sequence.get(generator, ctx, |f| f.len, "size"); let len = input_sequence.value.gep(ctx, |f| f.len).load(generator, ctx, "len");
result.set(ctx, |f| f.len, ndims); let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
let result_data = sizet_model.array_alloca(generator, ctx, ndims.value, "data"); // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
result.set(ctx, |f| f.items, result_data); gen_for_model_auto(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| {
// Load the i-th int32 in the input sequence
let int = input_sequence
.value
.get(generator, ctx, |f| f.items, "int")
.ix(generator, ctx, i.value, "int")
.value
.into_int_value();
// Cast to SizeT
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
// Store
result.offset(generator, ctx, i.value, "int").store(ctx, int);
gen_for_model_auto(generator, ctx, zero, ndims, one, |generator, ctx, _hooks, i| {
let in_dim = in_sequence
.get(generator, ctx, |f| f.items, "in_dim")
.ix(generator, ctx, i.value, "in_dim")
.s_extend_or_bit_cast(generator, ctx, SizeT, "in_dim");
result_data.offset(generator, ctx, i.value, "dim").store(ctx, in_dim);
Ok(()) Ok(())
}) })
.unwrap(); .unwrap();
(len, result)
} }
TypeEnum::TTuple { ty: tuple_types } => { TypeEnum::TTuple { ty: tuple_types } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
let ndims_int = tuple_types.len(); let input_sequence = input_sequence.into_struct_value(); // A tuple is a struct
let ndims = sizet_model.constant(generator, ctx.ctx, ndims_int as u64);
result.set(ctx, |f| f.len, ndims);
// A tuple has to be a StructValue let len_int = tuple_types.len();
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
let tuple = sequence.into_struct_value();
let data = sizet_model.array_alloca(generator, ctx, ndims.value, "sequence_data");
result.set(ctx, |f| f.items, data);
for i in 0..ndims_int { let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
// Get the i-th (0-based) element off of the tuple and load it let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
// into `result`.
let dim = ctx for i in 0..len_int {
// Get the i-th element off of the tuple and load it into `result`.
let int = ctx
.builder .builder
.build_extract_value(tuple, i as u32, format!("dim").as_str()) .build_extract_value(input_sequence, i as u32, "int")
.unwrap() .unwrap()
.into_int_value(); .into_int_value();
let dim = sizet_model.s_extend_or_bit_cast(generator, ctx, dim, "dim"); let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
let offset = sizet_model.constant(generator, ctx.ctx, i as u64); let offset = sizet_model.constant(generator, ctx.ctx, i as u64);
data.offset(generator, ctx, offset.value, "dim").store(ctx, dim); result.offset(generator, ctx, offset.value, "int").store(ctx, int);
} }
(len, result)
} }
TypeEnum::TObj { obj_id, .. } TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{ {
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let sequence_int = sizet_model.check_value(generator, ctx.ctx, sequence).unwrap(); let input_int = input_sequence.into_int_value();
// Size is 1 let len = sizet_model.const_1(generator, ctx.ctx);
result.set(ctx, |f| f.len, one); let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
// Alloca an array of length 1 and store the sole integer input into the array. let int = sizet_model.s_extend_or_bit_cast(generator, ctx, input_int, "int");
let data = sizet_model.array_alloca(generator, ctx, one.value, "data");
data.offset(generator, ctx, zero.value, "dim").store(ctx, sequence_int); // Storing into result[0]
result.store(ctx, int);
(len, result)
} }
_ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(sequence_ty)), _ => panic!(
"encountered unknown sequence type: {}",
ctx.unifier.stringify(input_sequence_ty)
),
} }
ListObject { item_type: ctx.primitives.usize(), value: result }
} }
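With the new return type, callers destructure the length and the `SizeT` array directly instead of going through a `ListObject`, matching the updated call sites earlier in this commit (e.g. `gen_ndarray_reshape`). A sketch, with `shape` / `shape_ty` assumed to be the evaluated argument and its typechecker type:

    let (len, values) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
    // `len: Int<'ctx, SizeT>` is the number of parsed integers;
    // `values: Ptr<'ctx, IntModel<SizeT>>` points to the sign-extended SizeT array.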

View File

@ -1180,7 +1180,7 @@ impl<'a> BuiltinBuilder<'a> {
let ret_elem_ty = size_variant.of_int(&ctx.primitives); let ret_elem_ty = size_variant.of_int(&ctx.primitives);
let func = match kind { let func = match kind {
Kind::Ceil => builtin_fns::call_ceil, Kind::Ceil => builtin_fns::call_ceil,
Kind::Floor => builtin_fns::call_ceil_or_floor, Kind::Floor => builtin_fns::call_floor,
}; };
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
}), }),
@ -1361,7 +1361,7 @@ impl<'a> BuiltinBuilder<'a> {
let ndims1 = create_ndims(self.unifier, 1); let ndims1 = create_ndims(self.unifier, 1);
let ndarray_float_1d = make_ndarray_ty( let ndarray_float_1d = make_ndarray_ty(
self.unifier, self.unifier,
&self.primitives, self.primitives,
Some(self.primitives.float), Some(self.primitives.float),
Some(ndims1), Some(ndims1),
); );
@ -1470,7 +1470,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunNpSize => { PrimDef::FunNpSize => {
// TODO: Make the return type usize // TODO: Make the return type usize
create_fn_by_codegen( create_fn_by_codegen(
&mut self.unifier, self.unifier,
&VarMap::new(), &VarMap::new(),
prim.name(), prim.name(),
self.primitives.int32, self.primitives.int32,
@ -1485,7 +1485,7 @@ impl<'a> BuiltinBuilder<'a> {
// of the input ndarray. // of the input ndarray.
let ret_ty = self.unifier.get_dummy_var().ty; let ret_ty = self.unifier.get_dummy_var().ty;
create_fn_by_codegen( create_fn_by_codegen(
&mut self.unifier, self.unifier,
&VarMap::new(), &VarMap::new(),
prim.name(), prim.name(),
ret_ty, ret_ty,
@ -1548,7 +1548,7 @@ impl<'a> BuiltinBuilder<'a> {
let func = match prim { let func = match prim {
PrimDef::FunNpCeil => builtin_fns::call_ceil, PrimDef::FunNpCeil => builtin_fns::call_ceil,
PrimDef::FunNpFloor => builtin_fns::call_ceil_or_floor, PrimDef::FunNpFloor => builtin_fns::call_floor,
_ => unreachable!(), _ => unreachable!(),
}; };
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))