forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: checkpoint 2

This commit is contained in:
lyken 2024-08-08 14:58:26 +08:00
parent bcd35544cc
commit 937b36dcfd
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
16 changed files with 1459 additions and 703 deletions

File diff suppressed because it is too large Load Diff

View File

@ -953,9 +953,9 @@ pub fn call_ndarray_calc_broadcast_index<
)
}
pub fn call_nac3_throw_dummy_error<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_nac3_throw_dummy_error<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
ctx: &CodeGenContext<'_, '_>,
) {
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_throw_dummy_error");
CallFunction::begin(generator, ctx, &name).returning_void();

View File

@ -3,9 +3,9 @@ use crate::codegen::{CodeGenContext, CodeGenerator};
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use]
pub fn get_sizet_dependent_function_name<'ctx, G: CodeGenerator + ?Sized>(
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned();

View File

@ -48,9 +48,7 @@ struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
field_types: Vec<BasicTypeEnum<'ctx>>,
}
impl<'ctx, 'a, 'b, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
for TypeFieldTraversal<'ctx, 'a, G>
{
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
type Out<M> = ();
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Out<M> {
@ -204,6 +202,6 @@ impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel<S>> {
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
self.gep(ctx, get_field).store(ctx, value)
self.gep(ctx, get_field).store(ctx, value);
}
}

View File

@ -27,14 +27,14 @@ pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ?
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
/// The [`IntKind`] is automatically inferred.
pub fn gen_for_model_auto<'ctx, 'a, G, F, I, R>(
pub fn gen_for_model_auto<'ctx, 'a, G, F, I>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
start: Int<'ctx, I>,
stop: Int<'ctx, I>,
step: Int<'ctx, I>,
body: F,
) -> Result<R, String>
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
@ -42,7 +42,7 @@ where
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, I>,
) -> Result<R, String>,
) -> Result<(), String>,
I: IntKind<'ctx> + Default,
{
let int_model = IntModel(I::default());
@ -60,3 +60,32 @@ where
step.value,
)
}
/// Like [`gen_if_callback`] with [`Model`] abstractions and without the `else` block.
pub fn gen_if_model<'ctx, 'a, G, ThenFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
cond: Int<'ctx, Bool>,
then: ThenFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
{
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then");
let end_bb = ctx.ctx.insert_basic_block_after(then_bb, "if.end");
// Inserting into `current_bb`.
ctx.builder.build_conditional_branch(cond.value, then_bb, end_bb).unwrap();
// Inserting into `then_bb`
ctx.builder.position_at_end(then_bb);
then(generator, ctx)?;
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition to `end_bb` for continuation.
ctx.builder.position_at_end(end_bb);
Ok(())
}

View File

@ -99,8 +99,7 @@ fn create_empty_ndarray<'ctx, G>(
where
G: CodeGenerator + ?Sized,
{
let shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let shape = shape.value.get(generator, ctx, |f| f.items, "shape");
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let ndarray =
NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray");
@ -248,8 +247,7 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx);
// Process `shape`
let broadcast_shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let broadcast_shape = broadcast_shape.value.get(generator, ctx, |f| f.items, "shape");
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
// NOTE: shape.size should equal to `broadcasted_ndims`.
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
call_nac3_ndarray_util_assert_shape_no_negative(
@ -300,8 +298,7 @@ pub fn gen_ndarray_reshape<'ctx>(
// Process the shape input from user and resolve negative indices.
// The resulting `new_shape`'s size should be equal to reshaped_ndims.
// This is ensured by the typechecker.
let new_shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let new_shape = new_shape.value.get(generator, ctx, |f| f.items, "new_shape");
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
// Resolve unknown dimensions & validate `new_shape`.
let new_ndims = sizet_model.constant(generator, ctx.ctx, reshaped_ndims);
@ -353,7 +350,7 @@ pub fn gen_ndarray_arange<'ctx>(
// Create data and set elements
ndarray.create_data(generator, ctx);
ndarray.foreach(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
ndarray.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
let val =
ctx.builder.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val").unwrap();
ctx.builder.build_store(pelement, val).unwrap();
@ -495,10 +492,9 @@ pub fn gen_ndarray_transpose<'ctx>(
// Parse argument #2 axes
let in_axes_ty = fun.0.args[1].ty;
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
let in_axes = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty);
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty);
let num_axes = ndarray.get_ndims(generator, ctx.ctx);
let axes = in_axes.value.get(generator, ctx, |f| f.items, "axes");
call_nac3_ndarray_transpose(
generator,

View File

@ -1,6 +1,5 @@
use super::model::*;
use super::structure::cslice::CSlice;
use super::structure::ndarray::broadcast::broadcast_all_ndarrays;
use super::{
super::symbol_resolver::ValueEnum,
expr::destructure_range,
@ -438,7 +437,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
let value =
split_scalar_or_ndarray(generator, ctx, value, value_ty).as_ndarray(generator, ctx);
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &vec![target, value]);
let broadcast_result = NDArrayObject::broadcast_all(generator, ctx, &[target, value]);
let target = broadcast_result.ndarrays[0];
let value = broadcast_result.ndarrays[1];

View File

@ -33,44 +33,38 @@ impl<'ctx, Item: Model<'ctx>, Size: IntKind<'ctx>> StructKind<'ctx> for List<Ite
}
}
pub struct ListObject<'ctx, Item: Model<'ctx>, Size: IntKind<'ctx>> {
/// A NAC3 Python List object.
pub struct ListObject<'ctx> {
/// Typechecker type of the list items
pub item_type: Type,
pub value: Ptr<'ctx, StructModel<List<Item, Size>>>,
pub value: Ptr<'ctx, StructModel<List<AnyModel<'ctx>, SizeT>>>,
}
impl<'ctx, Item: Model<'ctx>, Len: IntKind<'ctx>> ListObject<'ctx, Item, Len> {
impl<'ctx> ListObject<'ctx> {
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
///
/// - The `Item` model has to be manually provided, and should match the
/// `get_llvm_type()` of `ty` and the `get_type()`. You may want to use
/// [`AnyModel`] if `ty`'s type is not knowable statically.
pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
val: V,
ty: Type,
item_model: Item,
len_model: Len,
list_val: V,
list_type: Type,
) -> Self {
let plist_model = PtrModel(StructModel(List { item: item_model, len: len_model }));
// Check typechecker type and extract `item_type`
let item_type = match &*ctx.unifier.get_ty(ty) {
let item_type = match &*ctx.unifier.get_ty(list_type) {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
}
_ => panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(ty)),
_ => {
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(list_type))
}
};
// LLVM types of `item_model` and `ty` should match
let llvm_ty = ctx.get_llvm_type(generator, ty);
item_model.check_type(generator, ctx.ctx, llvm_ty).unwrap();
let item_model = AnyModel(ctx.get_llvm_type(generator, item_type));
let plist_model = PtrModel(StructModel(List { item: item_model, len: SizeT }));
// Create object
let val = plist_model.check_value(generator, ctx.ctx, val).unwrap();
ListObject { item_type: item_type, value: val }
let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap();
ListObject { item_type, value }
}
}

View File

@ -71,12 +71,13 @@ pub struct BroadcastAllResult<'ctx> {
pub ndarrays: Vec<NDArrayObject<'ctx>>,
}
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>(
impl<'ctx> NDArrayObject<'ctx> {
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
pub fn broadcast_all<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarrays: &Vec<NDArrayObject<'ctx>>,
) -> BroadcastAllResult<'ctx> {
ndarrays: &[Self],
) -> BroadcastAllResult<'ctx> {
assert!(!ndarrays.is_empty());
let sizet_model = IntModel(SizeT);
@ -120,7 +121,7 @@ pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>(
// Broadcast all the inputs to shape `dst_shape`.
let broadcast_ndarrays: Vec<_> = ndarrays
.into_iter()
.iter()
.map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape))
.collect_vec();
@ -129,4 +130,5 @@ pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>(
shape: broadcast_shape,
ndarrays: broadcast_ndarrays,
}
}
}

View File

@ -0,0 +1,477 @@
use inkwell::{
values::{BasicValue, FloatValue, IntValue},
FloatPredicate, IntPredicate,
};
use itertools::Itertools;
use crate::{
codegen::{
llvm_intrinsics,
model::{
util::{gen_for_model_auto, gen_if_model},
*,
},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
use super::{scalar::ScalarObject, NDArrayObject};
/// Convenience function to crash the program when types of arguments are not supported.
/// Used to be debugged with a stacktrace.
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
where
I: IntoIterator<Item = Type>,
{
unreachable!(
"unsupported types found '{}'",
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
)
}
#[derive(Debug, Clone, Copy)]
pub enum FloorOrCeil {
Floor,
Ceil,
}
#[derive(Debug, Clone, Copy)]
pub enum MinOrMax {
Min,
Max,
}
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64]
}
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.uint32, ctx.primitives.uint64]
}
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
}
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
]
}
fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
scalar: ScalarObject<'ctx>,
target_int_dtype: Type,
handle_float: HandleFloatFn,
) -> ScalarObject<'ctx>
where
G: CodeGenerator + ?Sized,
HandleFloatFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
{
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
// Special handling for floats
let n = scalar.value.into_float_value();
handle_float(generator, ctx, n)
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
let n = scalar.value.into_int_value();
if n.get_type().get_bit_width() <= target_int_dtype_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, target_int_dtype_llvm, "zext").unwrap()
} else {
ctx.builder.build_int_truncate(n, target_int_dtype_llvm, "trunc").unwrap()
}
} else {
unsupported_type(ctx, [scalar.dtype]);
};
assert_eq!(target_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
ScalarObject { value: result.into(), dtype: target_int_dtype }
}
impl<'ctx> ScalarObject<'ctx> {
/// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed.
/// Panic otherwise.
pub fn compare<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: ScalarObject<'ctx>,
rhs: ScalarObject<'ctx>,
int_predicate: IntPredicate,
float_predicate: FloatPredicate,
name: &str,
) -> Int<'ctx, Bool> {
if !ctx.unifier.unioned(lhs.dtype, rhs.dtype) {
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
}
let bool_model = IntModel(Bool);
let common_ty = lhs.dtype;
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
let lhs = lhs.value.into_float_value();
let rhs = rhs.value.into_float_value();
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
let lhs = lhs.value.into_int_value();
let rhs = rhs.value.into_int_value();
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
} else {
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
};
bool_model.check_value(generator, ctx.ctx, result).unwrap()
}
/// Invoke NAC3's builtin `int32()`.
#[must_use]
pub fn cast_to_int32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.int32,
|_generator, ctx, input| {
let n =
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
},
)
}
/// Invoke NAC3's builtin `int64()`.
#[must_use]
pub fn cast_to_int64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.int64,
|_generator, ctx, input| {
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
},
)
}
/// Invoke NAC3's builtin `uint32()`.
#[must_use]
pub fn cast_to_uint32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.uint32,
|_generator, ctx, n| {
let n_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int32 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder
.build_select(
n_gez,
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
to_int32,
"conv",
)
.unwrap()
.into_int_value()
},
)
}
/// Invoke NAC3's builtin `uint64()`.
#[must_use]
pub fn cast_to_uint64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.uint64,
|_generator, ctx, n| {
let val_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int64 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder
.build_select(val_gez, to_uint64, to_int64, "conv")
.unwrap()
.into_int_value()
},
)
}
/// Invoke NAC3's builtin `bool()`.
#[must_use]
pub fn cast_to_bool<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
self.value.into_int_value()
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.value.into_int_value();
ctx.builder
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
.unwrap()
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
ctx.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
.unwrap()
} else {
unsupported_type(ctx, [self.dtype])
};
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `round()`.
#[must_use]
pub fn round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target_int_dtype: Type,
) -> Self {
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "round").unwrap()
} else {
unsupported_type(ctx, [self.dtype, target_int_dtype])
};
ScalarObject { dtype: target_int_dtype, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `np_round()`.
///
/// NOTE: `np.round()` has different behaviors than `round()` in terms of their result
/// on "tie" cases and return type.
#[must_use]
pub fn np_round<G: CodeGenerator + ?Sized>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
llvm_intrinsics::call_float_rint(ctx, n, None)
} else {
unsupported_type(ctx, [self.dtype])
};
ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `min()` or `max()`.
fn min_or_max_helper(
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
a: Self,
b: Self,
) -> Self {
if !ctx.unifier.unioned(a.dtype, b.dtype) {
unsupported_type(ctx, [a.dtype, b.dtype])
}
let common_dtype = a.dtype;
if ctx.unifier.unioned(common_dtype, ctx.primitives.float) {
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
};
let result =
function(ctx, a.value.into_float_value(), b.value.into_float_value(), None);
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any(
common_dtype,
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
) {
// Treating bool has an unsigned int since that is convenient
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_int_umin,
MinOrMax::Max => llvm_intrinsics::call_int_umax,
};
let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None);
ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype }
} else {
unsupported_type(ctx, [common_dtype])
}
}
/// Invoke NAC3's builtin `floor()` or `ceil()`.
#[must_use]
pub fn floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
target_int_dtype: Type,
) -> Self {
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.value.into_float_value();
let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "").unwrap();
ScalarObject { dtype: target_int_dtype, value: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.dtype])
}
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Helper function for NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
///
/// Generate LLVM IR to find the extremum and index of the **first** extremum value.
///
/// Care has also been taken to make the error messages match that of NumPy.
fn min_or_max_helper<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
on_empty_err_msg: &str,
) -> (ScalarObject<'ctx>, Int<'ctx, SizeT>) {
let sizet_model = IntModel(SizeT);
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
// If the ndarray is empty, throw an error.
let is_empty = self.is_empty(generator, ctx);
ctx.make_assert(
generator,
is_empty.value,
"0:ValueError",
on_empty_err_msg,
[None, None, None],
ctx.current_loc,
);
// Setup and initialize the extremum to be the first element in the ndarray
let pextremum_index = sizet_model.alloca(generator, ctx, "extremum_index");
let pextremum = ctx.builder.build_alloca(dtype_llvm, "extremum").unwrap();
let zero = sizet_model.const_0(generator, ctx.ctx);
pextremum_index.store(ctx, zero);
let first_scalar = self.get_nth(generator, ctx, zero);
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
// Find extremum
let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1
let stop = self.size(generator, ctx);
let step = sizet_model.const_1(generator, ctx.ctx);
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, _hooks, i| {
// Worth reading on "Notes" in <https://numpy.org/doc/stable/reference/generated/numpy.min.html#numpy.min>
// on how `NaN` values have to be handled.
let scalar = self.get_nth(generator, ctx, i);
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
let new_extremum = ScalarObject::min_or_max_helper(ctx, kind, old_extremum, scalar);
// Check if new_extremum is more extreme than old_extremum.
let update_index = ScalarObject::compare(
generator,
ctx,
new_extremum,
old_extremum,
IntPredicate::NE,
FloatPredicate::ONE,
"",
);
gen_if_model(generator, ctx, update_index, |_generator, ctx| {
pextremum_index.store(ctx, i);
Ok(())
})
.unwrap();
Ok(())
})
.unwrap();
// Finally return the extremum and extremum index.
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
let extremum = ScalarObject { dtype: self.dtype, value: extremum };
(extremum, extremum_index)
}
/// Invoke NAC3's builtin `np_min()` or `np_max()`.
pub fn min_or_max<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
) -> ScalarObject<'ctx> {
let on_empty_err_msg = format!(
"zero-size array to reduction operation {} which has no identity",
match kind {
MinOrMax::Min => "minimum",
MinOrMax::Max => "maximum",
}
);
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).0
}
/// Invoke NAC3's builtin `np_argmin()` or `np_argmax()`.
pub fn argmin_or_argmax<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
) -> Int<'ctx, SizeT> {
let on_empty_err_msg = format!(
"attempt to get {} of an empty sequence",
match kind {
MinOrMax::Min => "argmin",
MinOrMax::Max => "argmax",
}
);
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).1
}
}

View File

@ -1,7 +1,4 @@
use crate::codegen::{
irrt::call_nac3_ndarray_index, model::*, structure::ndarray::scalar::ScalarObject,
CodeGenContext, CodeGenerator,
};
use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator};
use super::{scalar::ScalarOrNDArray, NDArrayObject};
@ -130,7 +127,7 @@ impl<'ctx> RustNDIndex<'ctx> {
dst_ndindex_ptr: Ptr<'ctx, StructModel<NDIndex>>,
) {
let ndindex_type_model = IntModel(NDIndexType::default());
let i32_model = IntModel(Int32::default());
let i32_model = IntModel(Int32);
let user_slice_model = StructModel(UserSlice);
// Set `dst_ndindex_ptr->type`
@ -178,6 +175,7 @@ impl<'ctx> RustNDIndex<'ctx> {
impl<'ctx> NDArrayObject<'ctx> {
/// Get the ndims [`Type`] after indexing with a given slice.
#[must_use]
pub fn deduce_ndims_after_indexing_with(&self, indexes: &[RustNDIndex<'ctx>]) -> u64 {
let mut ndims = self.ndims;
for index in indexes {
@ -235,11 +233,8 @@ impl<'ctx> NDArrayObject<'ctx> {
let subndarray = self.index(generator, ctx, indexes, name);
if subndarray.is_unsized() {
// NOTE: `np.size(self) == 0` is impossible.
let pfirst = subndarray.get_nth_pelement(generator, ctx, zero, "pfirst");
let first = ctx.builder.build_load(pfirst, "first").unwrap();
ScalarOrNDArray::Scalar(ScalarObject { dtype: self.dtype, value: first })
// NOTE: `np.size(self) == 0` here is never possible.
ScalarOrNDArray::Scalar(subndarray.get_nth(generator, ctx, zero))
} else {
ScalarOrNDArray::NDArray(subndarray)
}
@ -319,9 +314,9 @@ pub mod util {
})
};
let start = help(&start)?;
let stop = help(&stop)?;
let step = help(&step)?;
let start = help(start)?;
let stop = help(stop)?;
let step = help(step)?;
RustNDIndex::Slice(RustUserSlice { start, stop, step })
} else {

View File

@ -1,12 +1,11 @@
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use util::gen_for_model_auto;
use crate::{
codegen::{
model::*,
structure::ndarray::{
broadcast::broadcast_all_ndarrays, scalar::ScalarObject, NDArrayObject,
},
structure::ndarray::{scalar::ScalarObject, NDArrayObject},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
@ -14,179 +13,126 @@ use crate::{
use super::scalar::ScalarOrNDArray;
pub fn starmap_scalars_array_like<'ctx, 'a, F, G>(
impl<'ctx> NDArrayObject<'ctx> {
/// TODO: Document me. Has complex behavior.
pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: &Vec<ScalarOrNDArray<'ctx>>,
mapping: F,
) -> Result<ScalarOrNDArray<'ctx>, String>
where
F: FnOnce(
ndarrays: &[Self],
ret_dtype: Type,
name: &str,
mapping: MappingFn,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
&Vec<ScalarObject<'ctx>>,
&[ScalarObject<'ctx>],
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{
assert!(!inputs.is_empty());
{
let sizet_model = IntModel(SizeT);
// Check if all inputs are ScalarObjects
let scalars: Option<Vec<_>> =
inputs.iter().map(|input| ScalarObject::try_from(input)).try_collect().ok();
// Broadcast inputs
let broadcast_result = NDArrayObject::broadcast_all(generator, ctx, ndarrays);
match scalars {
Some(scalars) => {
// When inputs are all scalars, return a ScalarObject back
let i = sizet_model.const_0(generator, ctx.ctx);
let scalar = mapping(generator, ctx, i, &scalars)?;
Ok(ScalarOrNDArray::Scalar(scalar))
}
None => {
// When not all inputs are scalars, promote all non-ndarray inputs
// to ndarrays, do broadcast_shapes on them, and map.
let ndarrays =
inputs.into_iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays);
let start = sizet_model.const_0(generator, ctx.ctx);
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
let step = sizet_model.const_1(generator, ctx.ctx);
// Map element-wise and store results into `mapped_ndarray`.
let mapped_ndarray = gen_for_model_auto(
generator,
ctx,
start,
stop,
step,
move |generator, ctx, _hooks, i| {
let elements = ndarrays
.iter()
.map(|ndarray| {
let pelement = ndarray.get_nth_pelement(generator, ctx, i, "pelement");
let element = ctx.builder.build_load(pelement, "element").unwrap();
ScalarObject { value: element, dtype: ndarray.dtype }
})
.collect_vec();
let ret = mapping(generator, ctx, i, &elements)?;
// It might look weird but it is perfectly fine putting the allocation codegen
// here within `for`.
// The reason for doing this is to get the `dtype` out of `ret`, which is only
// available after running `mapping`.
// Allocate the resulting ndarray
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
generator,
ctx,
ret.dtype,
ret_dtype,
broadcast_result.ndims,
"mapped_ndarray",
name,
);
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
mapped_ndarray.create_data(generator, ctx);
let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret");
// Map element-wise and store results into `mapped_ndarray`.
let start = sizet_model.const_0(generator, ctx.ctx);
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
let step = sizet_model.const_1(generator, ctx.ctx);
gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
let elements =
ndarrays.iter().map(|ndarray| ndarray.get_nth(generator, ctx, i)).collect_vec();
let ret = mapping(generator, ctx, i, &elements)?;
let pret = mapped_ndarray.get_nth_pointer(generator, ctx, i, "pret");
ctx.builder.build_store(pret, ret.value).unwrap();
Ok(())
})?;
Ok(mapped_ndarray)
},
)?;
Ok(ScalarOrNDArray::NDArray(mapped_ndarray))
}
}
}
impl<'ctx> ScalarObject<'ctx> {
pub fn map<'a, F, G>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
mapping: F,
) -> Result<Self, String>
where
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{
let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like(
generator,
ctx,
&vec![ScalarOrNDArray::Scalar(*self)],
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
)?
else {
unreachable!()
};
Ok(ret)
}
}
impl<'ctx> NDArrayObject<'ctx> {
pub fn map<'a, F, G>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
mapping: F,
) -> Result<Self, String>
where
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
{
let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like(
generator,
ctx,
&vec![ScalarOrNDArray::NDArray(*self)],
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
)?
else {
unreachable!()
};
Ok(ret)
}
}
impl<'ctx> ScalarOrNDArray<'ctx> {
pub fn map<'a, F, G>(
pub fn map<'a, G, Mapping>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type,
mapping: F,
name: &str,
mapping: Mapping,
) -> Result<Self, String>
where
F: FnOnce(
G: CodeGenerator + ?Sized,
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<ScalarObject<'ctx>, String>,
G: CodeGenerator + ?Sized,
) -> Result<BasicValueEnum<'ctx>, String>,
{
match self {
ScalarOrNDArray::Scalar(scalar) => starmap_scalars_array_like(
NDArrayObject::broadcasting_starmap(
generator,
ctx,
&vec![ScalarOrNDArray::Scalar(*scalar)],
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
),
ScalarOrNDArray::NDArray(ndarray) => {
ndarray.map(generator, ctx, mapping).map(ScalarOrNDArray::NDArray)
&[*self],
ret_dtype,
name,
|generator, ctx, i, scalars| {
let value = mapping(generator, ctx, i, scalars[0])?;
Ok(ScalarObject { dtype: ret_dtype, value })
},
)
}
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// TODO: Document me. Has complex behavior.
pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: &[Self],
ret_dtype: Type,
name: &str,
mapping: MappingFn,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
&[ScalarObject<'ctx>],
) -> Result<ScalarObject<'ctx>, String>,
{
let sizet_model = IntModel(SizeT);
// Check if all inputs are ScalarObjects
let all_scalars: Option<Vec<_>> =
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
if let Some(scalars) = all_scalars {
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
let scalar = mapping(generator, ctx, i, &scalars)?;
Ok(ScalarOrNDArray::Scalar(scalar))
} else {
// Promote all input to ndarrays and map through them.
let inputs = inputs.iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
let ndarray = NDArrayObject::broadcasting_starmap(
generator, ctx, &inputs, ret_dtype, name, mapping,
)?;
Ok(ScalarOrNDArray::NDArray(ndarray))
}
}
}

View File

@ -1,4 +1,5 @@
pub mod broadcast;
pub mod functions;
pub mod indexing;
pub mod mapping;
pub mod scalar;
@ -22,8 +23,9 @@ use inkwell::{
context::Context,
types::BasicType,
values::{BasicValue, BasicValueEnum, PointerValue},
AddressSpace,
AddressSpace, IntPredicate,
};
use scalar::ScalarObject;
use util::{call_memcpy_model, gen_for_model_auto};
pub struct NpArrayFields<'ctx, F: FieldTraversal<'ctx>> {
@ -52,6 +54,7 @@ impl<'ctx> StructKind<'ctx> for NpArray {
}
}
/// A NAC3 Python ndarray object.
#[derive(Debug, Clone, Copy)]
pub struct NDArrayObject<'ctx> {
pub dtype: Type,
@ -116,7 +119,7 @@ impl<'ctx> NDArrayObject<'ctx> {
/// Get the pointer to the n-th (0-based) element.
///
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
pub fn get_nth_pelement<G: CodeGenerator + ?Sized>(
pub fn get_nth_pointer<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -131,6 +134,18 @@ impl<'ctx> NDArrayObject<'ctx> {
.unwrap()
}
/// Get the n-th (0-based) scalar.
pub fn get_nth<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
) -> ScalarObject<'ctx> {
let p = self.get_nth_pointer(generator, ctx, nth, "value");
let value = ctx.builder.build_load(p, "value").unwrap();
ScalarObject { dtype: self.dtype, value }
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
///
/// Please refer to the IRRT implementation to see its purpose.
@ -210,7 +225,7 @@ impl<'ctx> NDArrayObject<'ctx> {
name: &str,
) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let ndims = extract_ndims(&mut ctx.unifier, ndims);
let ndims = extract_ndims(&ctx.unifier, ndims);
Self::alloca_uninitialized(generator, ctx, dtype, ndims, name)
}
@ -224,7 +239,22 @@ impl<'ctx> NDArrayObject<'ctx> {
sizet_model.constant(generator, ctx, self.ndims)
}
/// Return true if this ndarray is unsized.
/// Get if this ndarray's `np.size` is `0` - containing no content.
pub fn is_empty<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
let sizet_model = IntModel(SizeT);
let size = self.size(generator, ctx);
size.compare(ctx, IntPredicate::EQ, sizet_model.const_0(generator, ctx.ctx), "is_empty")
}
/// Return true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
///
/// This is a staticially known property of ndarrays. This is why it is returning
/// a Rust boolean instead of a [`BasicValue`].
#[must_use]
pub fn is_unsized(&self) -> bool {
self.ndims == 0
@ -273,7 +303,7 @@ impl<'ctx> NDArrayObject<'ctx> {
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_shape = src_ndarray.value.get(generator, ctx, |f| f.shape, "src_shape");
self.copy_shape_from_array(generator, ctx, src_shape)
self.copy_shape_from_array(generator, ctx, src_shape);
}
/// Copy strides dimensions from an array.
@ -298,14 +328,14 @@ impl<'ctx> NDArrayObject<'ctx> {
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_strides = src_ndarray.value.get(generator, ctx, |f| f.strides, "src_strides");
self.copy_strides_from_array(generator, ctx, src_strides)
self.copy_strides_from_array(generator, ctx, src_strides);
}
/// Loop through every element pointer in the ndarray in its flatten view.
/// Iterate through every element pointer in the ndarray in its flatten view.
///
/// `body` also access to [`BreakContinueHooks`] to short-circuit and an element's
/// index. The given element pointer also has been casted to the LLVM type of this ndarray's `dtype`.
pub fn foreach<'a, G, F>(
pub fn foreach_pointer<'a, G, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
@ -328,12 +358,36 @@ impl<'ctx> NDArrayObject<'ctx> {
let step = sizet_model.const_1(generator, ctx.ctx);
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, hooks, i| {
let pelement = self.get_nth_pelement(generator, ctx, i, "element");
let pelement = self.get_nth_pointer(generator, ctx, i, "element");
body(generator, ctx, hooks, i, pelement)
})
}
/// Fill the NDArray with a value.
/// Iterate through every scalar in this ndarray.
pub fn foreach<'a, G, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<(), String>,
{
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
let value = ctx.builder.build_load(p, "value").unwrap();
let scalar = ScalarObject { dtype: self.dtype, value };
body(generator, ctx, hooks, i, scalar)
})
}
/// Fill the ndarray with a value.
///
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
pub fn fill<G: CodeGenerator + ?Sized>(
@ -342,11 +396,11 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
fill_value: BasicValueEnum<'ctx>,
) {
self.foreach(generator, ctx, |_generator, ctx, _hooks, _i, pelement| {
self.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, _i, pelement| {
ctx.builder.build_store(pelement, fill_value).unwrap();
Ok(())
})
.unwrap()
.unwrap();
}
/// Create a reshaped view on this ndarray like `np.reshape()`.

View File

@ -8,6 +8,10 @@ use crate::{
use super::NDArrayObject;
/// An LLVM numpy scalar with its [`Type`].
///
/// Intended to be used with [`ScalarOrNDArray`].
///
/// A scalar does not have to be an actual number. It could be arbitrary objects.
#[derive(Debug, Clone, Copy)]
pub struct ScalarObject<'ctx> {
pub dtype: Type,
@ -55,6 +59,22 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
}
}
#[must_use]
pub fn into_scalar(&self) -> ScalarObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"),
ScalarOrNDArray::Scalar(scalar) => *scalar,
}
}
#[must_use]
pub fn into_ndarray(&self) -> NDArrayObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(_scalar) => panic!("Got Scalar"),
}
}
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
/// - If this is an ndarray, the ndarray is returned.
/// - If this is a scalar, an unsized ndarray view is created on it.
@ -68,6 +88,14 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
ScalarOrNDArray::Scalar(scalar) => scalar.as_ndarray(generator, ctx),
}
}
#[must_use]
pub fn dtype(&self) -> Type {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.dtype,
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> {
@ -76,7 +104,18 @@ impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> {
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
ScalarOrNDArray::NDArray(_) => Err(()),
ScalarOrNDArray::NDArray(_ndarray) => Err(()),
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(_scalar) => Err(()),
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
}
}
}

View File

@ -2,15 +2,11 @@ use inkwell::values::BasicValueEnum;
use util::gen_for_model_auto;
use crate::{
codegen::{
model::*,
structure::list::{List, ListObject},
CodeGenContext, CodeGenerator,
},
codegen::{model::*, structure::list::ListObject, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
/// Parse a NumPy-like "int sequence" input and return the int sequence as a [`ListObject`]
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
///
/// * `sequence` - The `sequence` parameter.
/// * `sequence_ty` - The typechecker type of `sequence`
@ -20,99 +16,97 @@ use crate::{
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// `int32` values will be sign-extended to `SizeT`
/// All `int32` values will be sign-extended to `SizeT`.
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
sequence: BasicValueEnum<'ctx>,
sequence_ty: Type,
) -> ListObject<'ctx, IntModel<SizeT>, SizeT> {
input_sequence: BasicValueEnum<'ctx>,
input_sequence_ty: Type,
) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel<SizeT>>) {
let sizet_model = IntModel(SizeT);
let list_model = StructModel(List { len: SizeT, item: IntModel(SizeT) });
let zero = sizet_model.const_0(generator, ctx.ctx);
let one = sizet_model.const_1(generator, ctx.ctx);
// The result `list` to return.
let result = list_model.alloca(generator, ctx, "result_sequence");
match &*ctx.unifier.get_ty(sequence_ty) {
match &*ctx.unifier.get_ty(input_sequence_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
let in_sequence_model =
PtrModel(StructModel(List { item: IntModel(Int32), len: SizeT }));
let in_sequence = in_sequence_model.check_value(generator, ctx.ctx, sequence).unwrap();
/*
Reference code:
```
result.size = sequence.size;
result.data = __builtin_alloca(sizeof(SizeT) * sequence.size);
for (SizeT i = 0; i < sequence.size; i++) {
result.data[i] = (SizeT) sequence.data[i];
}
return result
```
*/
// Check `input_sequence`
let input_sequence =
ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty);
let ndims = in_sequence.get(generator, ctx, |f| f.len, "size");
result.set(ctx, |f| f.len, ndims);
let len = input_sequence.value.gep(ctx, |f| f.len).load(generator, ctx, "len");
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
let result_data = sizet_model.array_alloca(generator, ctx, ndims.value, "data");
result.set(ctx, |f| f.items, result_data);
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
gen_for_model_auto(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| {
// Load the i-th int32 in the input sequence
let int = input_sequence
.value
.get(generator, ctx, |f| f.items, "int")
.ix(generator, ctx, i.value, "int")
.value
.into_int_value();
// Cast to SizeT
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
// Store
result.offset(generator, ctx, i.value, "int").store(ctx, int);
gen_for_model_auto(generator, ctx, zero, ndims, one, |generator, ctx, _hooks, i| {
let in_dim = in_sequence
.get(generator, ctx, |f| f.items, "in_dim")
.ix(generator, ctx, i.value, "in_dim")
.s_extend_or_bit_cast(generator, ctx, SizeT, "in_dim");
result_data.offset(generator, ctx, i.value, "dim").store(ctx, in_dim);
Ok(())
})
.unwrap();
(len, result)
}
TypeEnum::TTuple { ty: tuple_types } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
let ndims_int = tuple_types.len();
let ndims = sizet_model.constant(generator, ctx.ctx, ndims_int as u64);
result.set(ctx, |f| f.len, ndims);
let input_sequence = input_sequence.into_struct_value(); // A tuple is a struct
// A tuple has to be a StructValue
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
let tuple = sequence.into_struct_value();
let data = sizet_model.array_alloca(generator, ctx, ndims.value, "sequence_data");
result.set(ctx, |f| f.items, data);
let len_int = tuple_types.len();
for i in 0..ndims_int {
// Get the i-th (0-based) element off of the tuple and load it
// into `result`.
let dim = ctx
let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
for i in 0..len_int {
// Get the i-th element off of the tuple and load it into `result`.
let int = ctx
.builder
.build_extract_value(tuple, i as u32, format!("dim").as_str())
.build_extract_value(input_sequence, i as u32, "int")
.unwrap()
.into_int_value();
let dim = sizet_model.s_extend_or_bit_cast(generator, ctx, dim, "dim");
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
let offset = sizet_model.constant(generator, ctx.ctx, i as u64);
data.offset(generator, ctx, offset.value, "dim").store(ctx, dim);
result.offset(generator, ctx, offset.value, "int").store(ctx, int);
}
(len, result)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let sequence_int = sizet_model.check_value(generator, ctx.ctx, sequence).unwrap();
let input_int = input_sequence.into_int_value();
// Size is 1
result.set(ctx, |f| f.len, one);
let len = sizet_model.const_1(generator, ctx.ctx);
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
// Alloca an array of length 1 and store the sole integer input into the array.
let data = sizet_model.array_alloca(generator, ctx, one.value, "data");
data.offset(generator, ctx, zero.value, "dim").store(ctx, sequence_int);
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, input_int, "int");
// Storing into result[0]
result.store(ctx, int);
(len, result)
}
_ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(sequence_ty)),
_ => panic!(
"encountered unknown sequence type: {}",
ctx.unifier.stringify(input_sequence_ty)
),
}
ListObject { item_type: ctx.primitives.usize(), value: result }
}

View File

@ -1180,7 +1180,7 @@ impl<'a> BuiltinBuilder<'a> {
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
let func = match kind {
Kind::Ceil => builtin_fns::call_ceil,
Kind::Floor => builtin_fns::call_ceil_or_floor,
Kind::Floor => builtin_fns::call_floor,
};
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
}),
@ -1361,7 +1361,7 @@ impl<'a> BuiltinBuilder<'a> {
let ndims1 = create_ndims(self.unifier, 1);
let ndarray_float_1d = make_ndarray_ty(
self.unifier,
&self.primitives,
self.primitives,
Some(self.primitives.float),
Some(ndims1),
);
@ -1470,7 +1470,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunNpSize => {
// TODO: Make the return type usize
create_fn_by_codegen(
&mut self.unifier,
self.unifier,
&VarMap::new(),
prim.name(),
self.primitives.int32,
@ -1485,7 +1485,7 @@ impl<'a> BuiltinBuilder<'a> {
// of the input ndarray.
let ret_ty = self.unifier.get_dummy_var().ty;
create_fn_by_codegen(
&mut self.unifier,
self.unifier,
&VarMap::new(),
prim.name(),
ret_ty,
@ -1548,7 +1548,7 @@ impl<'a> BuiltinBuilder<'a> {
let func = match prim {
PrimDef::FunNpCeil => builtin_fns::call_ceil,
PrimDef::FunNpFloor => builtin_fns::call_ceil_or_floor,
PrimDef::FunNpFloor => builtin_fns::call_floor,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))