forked from M-Labs/nac3
core/ndstrides: checkpoint 2
This commit is contained in:
parent
bcd35544cc
commit
937b36dcfd
File diff suppressed because it is too large
Load Diff
|
@ -953,9 +953,9 @@ pub fn call_ndarray_calc_broadcast_index<
|
|||
)
|
||||
}
|
||||
|
||||
pub fn call_nac3_throw_dummy_error<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_nac3_throw_dummy_error<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ctx: &CodeGenContext<'_, '_>,
|
||||
) {
|
||||
let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_throw_dummy_error");
|
||||
CallFunction::begin(generator, ctx, &name).returning_void();
|
||||
|
|
|
@ -3,9 +3,9 @@ use crate::codegen::{CodeGenContext, CodeGenerator};
|
|||
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
|
||||
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
|
||||
#[must_use]
|
||||
pub fn get_sizet_dependent_function_name<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ctx: &CodeGenContext<'_, '_>,
|
||||
name: &str,
|
||||
) -> String {
|
||||
let mut name = name.to_owned();
|
||||
|
|
|
@ -48,9 +48,7 @@ struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
|
|||
field_types: Vec<BasicTypeEnum<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx, 'a, 'b, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
|
||||
for TypeFieldTraversal<'ctx, 'a, G>
|
||||
{
|
||||
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
|
||||
type Out<M> = ();
|
||||
|
||||
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Out<M> {
|
||||
|
@ -204,6 +202,6 @@ impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel<S>> {
|
|||
M: Model<'ctx>,
|
||||
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
|
||||
{
|
||||
self.gep(ctx, get_field).store(ctx, value)
|
||||
self.gep(ctx, get_field).store(ctx, value);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,14 +27,14 @@ pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ?
|
|||
|
||||
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
|
||||
/// The [`IntKind`] is automatically inferred.
|
||||
pub fn gen_for_model_auto<'ctx, 'a, G, F, I, R>(
|
||||
pub fn gen_for_model_auto<'ctx, 'a, G, F, I>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
start: Int<'ctx, I>,
|
||||
stop: Int<'ctx, I>,
|
||||
step: Int<'ctx, I>,
|
||||
body: F,
|
||||
) -> Result<R, String>
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
F: FnOnce(
|
||||
|
@ -42,7 +42,7 @@ where
|
|||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
Int<'ctx, I>,
|
||||
) -> Result<R, String>,
|
||||
) -> Result<(), String>,
|
||||
I: IntKind<'ctx> + Default,
|
||||
{
|
||||
let int_model = IntModel(I::default());
|
||||
|
@ -60,3 +60,32 @@ where
|
|||
step.value,
|
||||
)
|
||||
}
|
||||
|
||||
/// Like [`gen_if_callback`] with [`Model`] abstractions and without the `else` block.
|
||||
pub fn gen_if_model<'ctx, 'a, G, ThenFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
cond: Int<'ctx, Bool>,
|
||||
then: ThenFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
|
||||
{
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then");
|
||||
let end_bb = ctx.ctx.insert_basic_block_after(then_bb, "if.end");
|
||||
|
||||
// Inserting into `current_bb`.
|
||||
ctx.builder.build_conditional_branch(cond.value, then_bb, end_bb).unwrap();
|
||||
|
||||
// Inserting into `then_bb`
|
||||
ctx.builder.position_at_end(then_bb);
|
||||
then(generator, ctx)?;
|
||||
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||
|
||||
// Reposition to `end_bb` for continuation.
|
||||
ctx.builder.position_at_end(end_bb);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -99,8 +99,7 @@ fn create_empty_ndarray<'ctx, G>(
|
|||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
let shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
||||
let shape = shape.value.get(generator, ctx, |f| f.items, "shape");
|
||||
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
||||
|
||||
let ndarray =
|
||||
NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray");
|
||||
|
@ -248,8 +247,7 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
|
|||
split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx);
|
||||
|
||||
// Process `shape`
|
||||
let broadcast_shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
||||
let broadcast_shape = broadcast_shape.value.get(generator, ctx, |f| f.items, "shape");
|
||||
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
||||
// NOTE: shape.size should equal to `broadcasted_ndims`.
|
||||
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
|
||||
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||
|
@ -300,8 +298,7 @@ pub fn gen_ndarray_reshape<'ctx>(
|
|||
// Process the shape input from user and resolve negative indices.
|
||||
// The resulting `new_shape`'s size should be equal to reshaped_ndims.
|
||||
// This is ensured by the typechecker.
|
||||
let new_shape = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
||||
let new_shape = new_shape.value.get(generator, ctx, |f| f.items, "new_shape");
|
||||
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
||||
|
||||
// Resolve unknown dimensions & validate `new_shape`.
|
||||
let new_ndims = sizet_model.constant(generator, ctx.ctx, reshaped_ndims);
|
||||
|
@ -353,7 +350,7 @@ pub fn gen_ndarray_arange<'ctx>(
|
|||
|
||||
// Create data and set elements
|
||||
ndarray.create_data(generator, ctx);
|
||||
ndarray.foreach(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
|
||||
ndarray.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, i, pelement| {
|
||||
let val =
|
||||
ctx.builder.build_unsigned_int_to_float(i.value, ctx.ctx.f64_type(), "val").unwrap();
|
||||
ctx.builder.build_store(pelement, val).unwrap();
|
||||
|
@ -495,10 +492,9 @@ pub fn gen_ndarray_transpose<'ctx>(
|
|||
// Parse argument #2 axes
|
||||
let in_axes_ty = fun.0.args[1].ty;
|
||||
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
|
||||
let in_axes = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty);
|
||||
|
||||
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty);
|
||||
let num_axes = ndarray.get_ndims(generator, ctx.ctx);
|
||||
let axes = in_axes.value.get(generator, ctx, |f| f.items, "axes");
|
||||
|
||||
call_nac3_ndarray_transpose(
|
||||
generator,
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use super::model::*;
|
||||
use super::structure::cslice::CSlice;
|
||||
use super::structure::ndarray::broadcast::broadcast_all_ndarrays;
|
||||
use super::{
|
||||
super::symbol_resolver::ValueEnum,
|
||||
expr::destructure_range,
|
||||
|
@ -438,7 +437,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
|||
let value =
|
||||
split_scalar_or_ndarray(generator, ctx, value, value_ty).as_ndarray(generator, ctx);
|
||||
|
||||
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &vec![target, value]);
|
||||
let broadcast_result = NDArrayObject::broadcast_all(generator, ctx, &[target, value]);
|
||||
|
||||
let target = broadcast_result.ndarrays[0];
|
||||
let value = broadcast_result.ndarrays[1];
|
||||
|
|
|
@ -33,44 +33,38 @@ impl<'ctx, Item: Model<'ctx>, Size: IntKind<'ctx>> StructKind<'ctx> for List<Ite
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ListObject<'ctx, Item: Model<'ctx>, Size: IntKind<'ctx>> {
|
||||
/// A NAC3 Python List object.
|
||||
pub struct ListObject<'ctx> {
|
||||
/// Typechecker type of the list items
|
||||
pub item_type: Type,
|
||||
pub value: Ptr<'ctx, StructModel<List<Item, Size>>>,
|
||||
pub value: Ptr<'ctx, StructModel<List<AnyModel<'ctx>, SizeT>>>,
|
||||
}
|
||||
|
||||
impl<'ctx, Item: Model<'ctx>, Len: IntKind<'ctx>> ListObject<'ctx, Item, Len> {
|
||||
impl<'ctx> ListObject<'ctx> {
|
||||
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
|
||||
///
|
||||
/// - The `Item` model has to be manually provided, and should match the
|
||||
/// `get_llvm_type()` of `ty` and the `get_type()`. You may want to use
|
||||
/// [`AnyModel`] if `ty`'s type is not knowable statically.
|
||||
pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
val: V,
|
||||
ty: Type,
|
||||
item_model: Item,
|
||||
len_model: Len,
|
||||
list_val: V,
|
||||
list_type: Type,
|
||||
) -> Self {
|
||||
let plist_model = PtrModel(StructModel(List { item: item_model, len: len_model }));
|
||||
|
||||
// Check typechecker type and extract `item_type`
|
||||
let item_type = match &*ctx.unifier.get_ty(ty) {
|
||||
let item_type = match &*ctx.unifier.get_ty(list_type) {
|
||||
TypeEnum::TObj { obj_id, params, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
|
||||
}
|
||||
_ => panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(ty)),
|
||||
_ => {
|
||||
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(list_type))
|
||||
}
|
||||
};
|
||||
|
||||
// LLVM types of `item_model` and `ty` should match
|
||||
let llvm_ty = ctx.get_llvm_type(generator, ty);
|
||||
item_model.check_type(generator, ctx.ctx, llvm_ty).unwrap();
|
||||
let item_model = AnyModel(ctx.get_llvm_type(generator, item_type));
|
||||
let plist_model = PtrModel(StructModel(List { item: item_model, len: SizeT }));
|
||||
|
||||
// Create object
|
||||
let val = plist_model.check_value(generator, ctx.ctx, val).unwrap();
|
||||
ListObject { item_type: item_type, value: val }
|
||||
let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap();
|
||||
ListObject { item_type, value }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,62 +71,64 @@ pub struct BroadcastAllResult<'ctx> {
|
|||
pub ndarrays: Vec<NDArrayObject<'ctx>>,
|
||||
}
|
||||
|
||||
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
|
||||
pub fn broadcast_all_ndarrays<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarrays: &Vec<NDArrayObject<'ctx>>,
|
||||
) -> BroadcastAllResult<'ctx> {
|
||||
assert!(!ndarrays.is_empty());
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
|
||||
pub fn broadcast_all<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarrays: &[Self],
|
||||
) -> BroadcastAllResult<'ctx> {
|
||||
assert!(!ndarrays.is_empty());
|
||||
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let shape_model = StructModel(ShapeEntry);
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let shape_model = StructModel(ShapeEntry);
|
||||
|
||||
let broadcast_ndims = get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims));
|
||||
let broadcast_ndims = get_broadcast_all_ndims(ndarrays.iter().map(|ndarray| ndarray.ndims));
|
||||
|
||||
// Prepare input shape entries
|
||||
let num_shape_entries =
|
||||
sizet_model.constant(generator, ctx.ctx, u64::try_from(ndarrays.len()).unwrap());
|
||||
let shape_entries =
|
||||
shape_model.array_alloca(generator, ctx, num_shape_entries.value, "shape_entries");
|
||||
for (i, ndarray) in ndarrays.iter().enumerate() {
|
||||
let i = sizet_model.constant(generator, ctx.ctx, i as u64).value;
|
||||
// Prepare input shape entries
|
||||
let num_shape_entries =
|
||||
sizet_model.constant(generator, ctx.ctx, u64::try_from(ndarrays.len()).unwrap());
|
||||
let shape_entries =
|
||||
shape_model.array_alloca(generator, ctx, num_shape_entries.value, "shape_entries");
|
||||
for (i, ndarray) in ndarrays.iter().enumerate() {
|
||||
let i = sizet_model.constant(generator, ctx.ctx, i as u64).value;
|
||||
|
||||
let shape_entry = shape_entries.offset(generator, ctx, i, "shape_entry");
|
||||
let shape_entry = shape_entries.offset(generator, ctx, i, "shape_entry");
|
||||
|
||||
let this_ndims = ndarray.value.get(generator, ctx, |f| f.ndims, "this_ndims");
|
||||
shape_entry.set(ctx, |f| f.ndims, this_ndims);
|
||||
let this_ndims = ndarray.value.get(generator, ctx, |f| f.ndims, "this_ndims");
|
||||
shape_entry.set(ctx, |f| f.ndims, this_ndims);
|
||||
|
||||
let this_shape = ndarray.value.get(generator, ctx, |f| f.shape, "this_shape");
|
||||
shape_entry.set(ctx, |f| f.shape, this_shape);
|
||||
}
|
||||
let this_shape = ndarray.value.get(generator, ctx, |f| f.shape, "this_shape");
|
||||
shape_entry.set(ctx, |f| f.shape, this_shape);
|
||||
}
|
||||
|
||||
// Prepare destination
|
||||
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
|
||||
let broadcast_shape =
|
||||
sizet_model.array_alloca(generator, ctx, broadcast_ndims_llvm.value, "dst_shape");
|
||||
// Prepare destination
|
||||
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
|
||||
let broadcast_shape =
|
||||
sizet_model.array_alloca(generator, ctx, broadcast_ndims_llvm.value, "dst_shape");
|
||||
|
||||
// Compute the target broadcast shape `dst_shape` for all ndarrays.
|
||||
call_nac3_ndarray_broadcast_shapes(
|
||||
generator,
|
||||
ctx,
|
||||
num_shape_entries,
|
||||
shape_entries,
|
||||
broadcast_ndims_llvm,
|
||||
broadcast_shape,
|
||||
);
|
||||
// Compute the target broadcast shape `dst_shape` for all ndarrays.
|
||||
call_nac3_ndarray_broadcast_shapes(
|
||||
generator,
|
||||
ctx,
|
||||
num_shape_entries,
|
||||
shape_entries,
|
||||
broadcast_ndims_llvm,
|
||||
broadcast_shape,
|
||||
);
|
||||
|
||||
// Now that we know about the broadcasting shape, broadcast all the inputs.
|
||||
// Now that we know about the broadcasting shape, broadcast all the inputs.
|
||||
|
||||
// Broadcast all the inputs to shape `dst_shape`.
|
||||
let broadcast_ndarrays: Vec<_> = ndarrays
|
||||
.into_iter()
|
||||
.map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape))
|
||||
.collect_vec();
|
||||
// Broadcast all the inputs to shape `dst_shape`.
|
||||
let broadcast_ndarrays: Vec<_> = ndarrays
|
||||
.iter()
|
||||
.map(|ndarray| ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape))
|
||||
.collect_vec();
|
||||
|
||||
BroadcastAllResult {
|
||||
ndims: broadcast_ndims,
|
||||
shape: broadcast_shape,
|
||||
ndarrays: broadcast_ndarrays,
|
||||
BroadcastAllResult {
|
||||
ndims: broadcast_ndims,
|
||||
shape: broadcast_shape,
|
||||
ndarrays: broadcast_ndarrays,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,477 @@
|
|||
use inkwell::{
|
||||
values::{BasicValue, FloatValue, IntValue},
|
||||
FloatPredicate, IntPredicate,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
llvm_intrinsics,
|
||||
model::{
|
||||
util::{gen_for_model_auto, gen_if_model},
|
||||
*,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
use super::{scalar::ScalarObject, NDArrayObject};
|
||||
|
||||
/// Convenience function to crash the program when types of arguments are not supported.
|
||||
/// Used to be debugged with a stacktrace.
|
||||
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
|
||||
where
|
||||
I: IntoIterator<Item = Type>,
|
||||
{
|
||||
unreachable!(
|
||||
"unsupported types found '{}'",
|
||||
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum FloorOrCeil {
|
||||
Floor,
|
||||
Ceil,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum MinOrMax {
|
||||
Min,
|
||||
Max,
|
||||
}
|
||||
|
||||
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![ctx.primitives.int32, ctx.primitives.int64]
|
||||
}
|
||||
|
||||
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![ctx.primitives.uint32, ctx.primitives.uint64]
|
||||
}
|
||||
|
||||
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
|
||||
}
|
||||
|
||||
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.uint64,
|
||||
]
|
||||
}
|
||||
|
||||
fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
scalar: ScalarObject<'ctx>,
|
||||
target_int_dtype: Type,
|
||||
handle_float: HandleFloatFn,
|
||||
) -> ScalarObject<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
HandleFloatFn:
|
||||
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
|
||||
{
|
||||
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
|
||||
|
||||
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
|
||||
// Special handling for floats
|
||||
let n = scalar.value.into_float_value();
|
||||
handle_float(generator, ctx, n)
|
||||
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
|
||||
let n = scalar.value.into_int_value();
|
||||
|
||||
if n.get_type().get_bit_width() <= target_int_dtype_llvm.get_bit_width() {
|
||||
ctx.builder.build_int_z_extend(n, target_int_dtype_llvm, "zext").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_truncate(n, target_int_dtype_llvm, "trunc").unwrap()
|
||||
}
|
||||
} else {
|
||||
unsupported_type(ctx, [scalar.dtype]);
|
||||
};
|
||||
|
||||
assert_eq!(target_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
|
||||
ScalarObject { value: result.into(), dtype: target_int_dtype }
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarObject<'ctx> {
|
||||
/// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed.
|
||||
/// Panic otherwise.
|
||||
pub fn compare<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
lhs: ScalarObject<'ctx>,
|
||||
rhs: ScalarObject<'ctx>,
|
||||
int_predicate: IntPredicate,
|
||||
float_predicate: FloatPredicate,
|
||||
name: &str,
|
||||
) -> Int<'ctx, Bool> {
|
||||
if !ctx.unifier.unioned(lhs.dtype, rhs.dtype) {
|
||||
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
|
||||
}
|
||||
|
||||
let bool_model = IntModel(Bool);
|
||||
|
||||
let common_ty = lhs.dtype;
|
||||
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
|
||||
let lhs = lhs.value.into_float_value();
|
||||
let rhs = rhs.value.into_float_value();
|
||||
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
|
||||
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
|
||||
let lhs = lhs.value.into_int_value();
|
||||
let rhs = rhs.value.into_int_value();
|
||||
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
|
||||
};
|
||||
|
||||
bool_model.check_value(generator, ctx.ctx, result).unwrap()
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `int32()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_int32<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
*self,
|
||||
ctx.primitives.int32,
|
||||
|_generator, ctx, input| {
|
||||
let n =
|
||||
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
|
||||
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `int64()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_int64<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
*self,
|
||||
ctx.primitives.int64,
|
||||
|_generator, ctx, input| {
|
||||
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `uint32()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_uint32<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
*self,
|
||||
ctx.primitives.uint32,
|
||||
|_generator, ctx, n| {
|
||||
let n_gez = ctx
|
||||
.builder
|
||||
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
let to_int32 =
|
||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
|
||||
let to_uint64 =
|
||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_select(
|
||||
n_gez,
|
||||
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
|
||||
to_int32,
|
||||
"conv",
|
||||
)
|
||||
.unwrap()
|
||||
.into_int_value()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `uint64()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_uint64<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
*self,
|
||||
ctx.primitives.uint64,
|
||||
|_generator, ctx, n| {
|
||||
let val_gez = ctx
|
||||
.builder
|
||||
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
let to_int64 =
|
||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
let to_uint64 =
|
||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
ctx.builder
|
||||
.build_select(val_gez, to_uint64, to_int64, "conv")
|
||||
.unwrap()
|
||||
.into_int_value()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `bool()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_bool<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
|
||||
self.value.into_int_value()
|
||||
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
|
||||
let n = self.value.into_int_value();
|
||||
ctx.builder
|
||||
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
|
||||
.unwrap()
|
||||
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
ctx.builder
|
||||
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
|
||||
.unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
};
|
||||
|
||||
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `round()`.
|
||||
#[must_use]
|
||||
pub fn round<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
target_int_dtype: Type,
|
||||
) -> Self {
|
||||
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
|
||||
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
let n = llvm_intrinsics::call_float_round(ctx, n, None);
|
||||
ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "round").unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype, target_int_dtype])
|
||||
};
|
||||
ScalarObject { dtype: target_int_dtype, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `np_round()`.
|
||||
///
|
||||
/// NOTE: `np.round()` has different behaviors than `round()` in terms of their result
|
||||
/// on "tie" cases and return type.
|
||||
#[must_use]
|
||||
pub fn np_round<G: CodeGenerator + ?Sized>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
llvm_intrinsics::call_float_rint(ctx, n, None)
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
};
|
||||
ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `min()` or `max()`.
|
||||
fn min_or_max_helper(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: MinOrMax,
|
||||
a: Self,
|
||||
b: Self,
|
||||
) -> Self {
|
||||
if !ctx.unifier.unioned(a.dtype, b.dtype) {
|
||||
unsupported_type(ctx, [a.dtype, b.dtype])
|
||||
}
|
||||
|
||||
let common_dtype = a.dtype;
|
||||
|
||||
if ctx.unifier.unioned(common_dtype, ctx.primitives.float) {
|
||||
let function = match kind {
|
||||
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
|
||||
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
|
||||
};
|
||||
let result =
|
||||
function(ctx, a.value.into_float_value(), b.value.into_float_value(), None);
|
||||
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
||||
} else if ctx.unifier.unioned_any(
|
||||
common_dtype,
|
||||
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
|
||||
) {
|
||||
// Treating bool has an unsigned int since that is convenient
|
||||
let function = match kind {
|
||||
MinOrMax::Min => llvm_intrinsics::call_int_umin,
|
||||
MinOrMax::Max => llvm_intrinsics::call_int_umax,
|
||||
};
|
||||
let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None);
|
||||
ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype }
|
||||
} else {
|
||||
unsupported_type(ctx, [common_dtype])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `floor()` or `ceil()`.
|
||||
#[must_use]
|
||||
pub fn floor_or_ceil<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: FloorOrCeil,
|
||||
target_int_dtype: Type,
|
||||
) -> Self {
|
||||
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
|
||||
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let function = match kind {
|
||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||
};
|
||||
let n = self.value.into_float_value();
|
||||
let n = function(ctx, n, None);
|
||||
let n = ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "").unwrap();
|
||||
ScalarObject { dtype: target_int_dtype, value: n.as_basic_value_enum() }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Helper function for NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
|
||||
///
|
||||
/// Generate LLVM IR to find the extremum and index of the **first** extremum value.
|
||||
///
|
||||
/// Care has also been taken to make the error messages match that of NumPy.
|
||||
fn min_or_max_helper<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: MinOrMax,
|
||||
on_empty_err_msg: &str,
|
||||
) -> (ScalarObject<'ctx>, Int<'ctx, SizeT>) {
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
|
||||
|
||||
// If the ndarray is empty, throw an error.
|
||||
let is_empty = self.is_empty(generator, ctx);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
is_empty.value,
|
||||
"0:ValueError",
|
||||
on_empty_err_msg,
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// Setup and initialize the extremum to be the first element in the ndarray
|
||||
let pextremum_index = sizet_model.alloca(generator, ctx, "extremum_index");
|
||||
let pextremum = ctx.builder.build_alloca(dtype_llvm, "extremum").unwrap();
|
||||
|
||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||
pextremum_index.store(ctx, zero);
|
||||
|
||||
let first_scalar = self.get_nth(generator, ctx, zero);
|
||||
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
|
||||
|
||||
// Find extremum
|
||||
let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1
|
||||
let stop = self.size(generator, ctx);
|
||||
let step = sizet_model.const_1(generator, ctx.ctx);
|
||||
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, _hooks, i| {
|
||||
// Worth reading on "Notes" in <https://numpy.org/doc/stable/reference/generated/numpy.min.html#numpy.min>
|
||||
// on how `NaN` values have to be handled.
|
||||
|
||||
let scalar = self.get_nth(generator, ctx, i);
|
||||
|
||||
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
|
||||
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
|
||||
|
||||
let new_extremum = ScalarObject::min_or_max_helper(ctx, kind, old_extremum, scalar);
|
||||
|
||||
// Check if new_extremum is more extreme than old_extremum.
|
||||
let update_index = ScalarObject::compare(
|
||||
generator,
|
||||
ctx,
|
||||
new_extremum,
|
||||
old_extremum,
|
||||
IntPredicate::NE,
|
||||
FloatPredicate::ONE,
|
||||
"",
|
||||
);
|
||||
|
||||
gen_if_model(generator, ctx, update_index, |_generator, ctx| {
|
||||
pextremum_index.store(ctx, i);
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// Finally return the extremum and extremum index.
|
||||
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
|
||||
|
||||
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
|
||||
let extremum = ScalarObject { dtype: self.dtype, value: extremum };
|
||||
|
||||
(extremum, extremum_index)
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `np_min()` or `np_max()`.
|
||||
pub fn min_or_max<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: MinOrMax,
|
||||
) -> ScalarObject<'ctx> {
|
||||
let on_empty_err_msg = format!(
|
||||
"zero-size array to reduction operation {} which has no identity",
|
||||
match kind {
|
||||
MinOrMax::Min => "minimum",
|
||||
MinOrMax::Max => "maximum",
|
||||
}
|
||||
);
|
||||
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).0
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `np_argmin()` or `np_argmax()`.
|
||||
pub fn argmin_or_argmax<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: MinOrMax,
|
||||
) -> Int<'ctx, SizeT> {
|
||||
let on_empty_err_msg = format!(
|
||||
"attempt to get {} of an empty sequence",
|
||||
match kind {
|
||||
MinOrMax::Min => "argmin",
|
||||
MinOrMax::Max => "argmax",
|
||||
}
|
||||
);
|
||||
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).1
|
||||
}
|
||||
}
|
|
@ -1,7 +1,4 @@
|
|||
use crate::codegen::{
|
||||
irrt::call_nac3_ndarray_index, model::*, structure::ndarray::scalar::ScalarObject,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator};
|
||||
|
||||
use super::{scalar::ScalarOrNDArray, NDArrayObject};
|
||||
|
||||
|
@ -130,7 +127,7 @@ impl<'ctx> RustNDIndex<'ctx> {
|
|||
dst_ndindex_ptr: Ptr<'ctx, StructModel<NDIndex>>,
|
||||
) {
|
||||
let ndindex_type_model = IntModel(NDIndexType::default());
|
||||
let i32_model = IntModel(Int32::default());
|
||||
let i32_model = IntModel(Int32);
|
||||
let user_slice_model = StructModel(UserSlice);
|
||||
|
||||
// Set `dst_ndindex_ptr->type`
|
||||
|
@ -178,6 +175,7 @@ impl<'ctx> RustNDIndex<'ctx> {
|
|||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Get the ndims [`Type`] after indexing with a given slice.
|
||||
#[must_use]
|
||||
pub fn deduce_ndims_after_indexing_with(&self, indexes: &[RustNDIndex<'ctx>]) -> u64 {
|
||||
let mut ndims = self.ndims;
|
||||
for index in indexes {
|
||||
|
@ -235,11 +233,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
|
||||
let subndarray = self.index(generator, ctx, indexes, name);
|
||||
if subndarray.is_unsized() {
|
||||
// NOTE: `np.size(self) == 0` is impossible.
|
||||
let pfirst = subndarray.get_nth_pelement(generator, ctx, zero, "pfirst");
|
||||
let first = ctx.builder.build_load(pfirst, "first").unwrap();
|
||||
|
||||
ScalarOrNDArray::Scalar(ScalarObject { dtype: self.dtype, value: first })
|
||||
// NOTE: `np.size(self) == 0` here is never possible.
|
||||
ScalarOrNDArray::Scalar(subndarray.get_nth(generator, ctx, zero))
|
||||
} else {
|
||||
ScalarOrNDArray::NDArray(subndarray)
|
||||
}
|
||||
|
@ -319,9 +314,9 @@ pub mod util {
|
|||
})
|
||||
};
|
||||
|
||||
let start = help(&start)?;
|
||||
let stop = help(&stop)?;
|
||||
let step = help(&step)?;
|
||||
let start = help(start)?;
|
||||
let stop = help(stop)?;
|
||||
let step = help(step)?;
|
||||
|
||||
RustNDIndex::Slice(RustUserSlice { start, stop, step })
|
||||
} else {
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
use inkwell::values::BasicValueEnum;
|
||||
use itertools::Itertools;
|
||||
use util::gen_for_model_auto;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
model::*,
|
||||
structure::ndarray::{
|
||||
broadcast::broadcast_all_ndarrays, scalar::ScalarObject, NDArrayObject,
|
||||
},
|
||||
structure::ndarray::{scalar::ScalarObject, NDArrayObject},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
|
@ -14,179 +13,126 @@ use crate::{
|
|||
|
||||
use super::scalar::ScalarOrNDArray;
|
||||
|
||||
pub fn starmap_scalars_array_like<'ctx, 'a, F, G>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
inputs: &Vec<ScalarOrNDArray<'ctx>>,
|
||||
mapping: F,
|
||||
) -> Result<ScalarOrNDArray<'ctx>, String>
|
||||
where
|
||||
F: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
&Vec<ScalarObject<'ctx>>,
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
assert!(!inputs.is_empty());
|
||||
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Check if all inputs are ScalarObjects
|
||||
let scalars: Option<Vec<_>> =
|
||||
inputs.iter().map(|input| ScalarObject::try_from(input)).try_collect().ok();
|
||||
|
||||
match scalars {
|
||||
Some(scalars) => {
|
||||
// When inputs are all scalars, return a ScalarObject back
|
||||
|
||||
let i = sizet_model.const_0(generator, ctx.ctx);
|
||||
|
||||
let scalar = mapping(generator, ctx, i, &scalars)?;
|
||||
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||
}
|
||||
None => {
|
||||
// When not all inputs are scalars, promote all non-ndarray inputs
|
||||
// to ndarrays, do broadcast_shapes on them, and map.
|
||||
|
||||
let ndarrays =
|
||||
inputs.into_iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
|
||||
|
||||
let broadcast_result = broadcast_all_ndarrays(generator, ctx, &ndarrays);
|
||||
|
||||
let start = sizet_model.const_0(generator, ctx.ctx);
|
||||
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
||||
let step = sizet_model.const_1(generator, ctx.ctx);
|
||||
|
||||
// Map element-wise and store results into `mapped_ndarray`.
|
||||
let mapped_ndarray = gen_for_model_auto(
|
||||
generator,
|
||||
ctx,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
move |generator, ctx, _hooks, i| {
|
||||
let elements = ndarrays
|
||||
.iter()
|
||||
.map(|ndarray| {
|
||||
let pelement = ndarray.get_nth_pelement(generator, ctx, i, "pelement");
|
||||
let element = ctx.builder.build_load(pelement, "element").unwrap();
|
||||
ScalarObject { value: element, dtype: ndarray.dtype }
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let ret = mapping(generator, ctx, i, &elements)?;
|
||||
|
||||
// It might look weird but it is perfectly fine putting the allocation codegen
|
||||
// here within `for`.
|
||||
// The reason for doing this is to get the `dtype` out of `ret`, which is only
|
||||
// available after running `mapping`.
|
||||
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
ret.dtype,
|
||||
broadcast_result.ndims,
|
||||
"mapped_ndarray",
|
||||
);
|
||||
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
||||
mapped_ndarray.create_data(generator, ctx);
|
||||
|
||||
let pret = mapped_ndarray.get_nth_pelement(generator, ctx, i, "pret");
|
||||
ctx.builder.build_store(pret, ret.value).unwrap();
|
||||
Ok(mapped_ndarray)
|
||||
},
|
||||
)?;
|
||||
Ok(ScalarOrNDArray::NDArray(mapped_ndarray))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarObject<'ctx> {
|
||||
pub fn map<'a, F, G>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
mapping: F,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
F: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
let ScalarOrNDArray::Scalar(ret) = starmap_scalars_array_like(
|
||||
generator,
|
||||
ctx,
|
||||
&vec![ScalarOrNDArray::Scalar(*self)],
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
)?
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
pub fn map<'a, F, G>(
|
||||
&self,
|
||||
/// TODO: Document me. Has complex behavior.
|
||||
pub fn broadcasting_starmap<'a, G, MappingFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
mapping: F,
|
||||
ndarrays: &[Self],
|
||||
ret_dtype: Type,
|
||||
name: &str,
|
||||
mapping: MappingFn,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
F: FnOnce(
|
||||
G: CodeGenerator + ?Sized,
|
||||
MappingFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
&[ScalarObject<'ctx>],
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
let ScalarOrNDArray::NDArray(ret) = starmap_scalars_array_like(
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Broadcast inputs
|
||||
let broadcast_result = NDArrayObject::broadcast_all(generator, ctx, ndarrays);
|
||||
|
||||
// Allocate the resulting ndarray
|
||||
let mapped_ndarray = NDArrayObject::alloca_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
&vec![ScalarOrNDArray::NDArray(*self)],
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
)?
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
ret_dtype,
|
||||
broadcast_result.ndims,
|
||||
name,
|
||||
);
|
||||
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
||||
mapped_ndarray.create_data(generator, ctx);
|
||||
|
||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
pub fn map<'a, F, G>(
|
||||
// Map element-wise and store results into `mapped_ndarray`.
|
||||
let start = sizet_model.const_0(generator, ctx.ctx);
|
||||
let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
||||
let step = sizet_model.const_1(generator, ctx.ctx);
|
||||
gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
|
||||
let elements =
|
||||
ndarrays.iter().map(|ndarray| ndarray.get_nth(generator, ctx, i)).collect_vec();
|
||||
|
||||
let ret = mapping(generator, ctx, i, &elements)?;
|
||||
|
||||
let pret = mapped_ndarray.get_nth_pointer(generator, ctx, i, "pret");
|
||||
ctx.builder.build_store(pret, ret.value).unwrap();
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
Ok(mapped_ndarray)
|
||||
}
|
||||
|
||||
pub fn map<'a, G, Mapping>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ret_dtype: Type,
|
||||
mapping: F,
|
||||
name: &str,
|
||||
mapping: Mapping,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
F: FnOnce(
|
||||
G: CodeGenerator + ?Sized,
|
||||
Mapping: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
G: CodeGenerator + ?Sized,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
match self {
|
||||
ScalarOrNDArray::Scalar(scalar) => starmap_scalars_array_like(
|
||||
generator,
|
||||
ctx,
|
||||
&vec![ScalarOrNDArray::Scalar(*scalar)],
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
),
|
||||
ScalarOrNDArray::NDArray(ndarray) => {
|
||||
ndarray.map(generator, ctx, mapping).map(ScalarOrNDArray::NDArray)
|
||||
}
|
||||
NDArrayObject::broadcasting_starmap(
|
||||
generator,
|
||||
ctx,
|
||||
&[*self],
|
||||
ret_dtype,
|
||||
name,
|
||||
|generator, ctx, i, scalars| {
|
||||
let value = mapping(generator, ctx, i, scalars[0])?;
|
||||
Ok(ScalarObject { dtype: ret_dtype, value })
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
/// TODO: Document me. Has complex behavior.
|
||||
pub fn broadcasting_starmap<'a, G, MappingFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
inputs: &[Self],
|
||||
ret_dtype: Type,
|
||||
name: &str,
|
||||
mapping: MappingFn,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
MappingFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
&[ScalarObject<'ctx>],
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
{
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Check if all inputs are ScalarObjects
|
||||
let all_scalars: Option<Vec<_>> =
|
||||
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
|
||||
|
||||
if let Some(scalars) = all_scalars {
|
||||
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
|
||||
let scalar = mapping(generator, ctx, i, &scalars)?;
|
||||
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||
} else {
|
||||
// Promote all input to ndarrays and map through them.
|
||||
let inputs = inputs.iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
|
||||
let ndarray = NDArrayObject::broadcasting_starmap(
|
||||
generator, ctx, &inputs, ret_dtype, name, mapping,
|
||||
)?;
|
||||
Ok(ScalarOrNDArray::NDArray(ndarray))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
pub mod broadcast;
|
||||
pub mod functions;
|
||||
pub mod indexing;
|
||||
pub mod mapping;
|
||||
pub mod scalar;
|
||||
|
@ -22,8 +23,9 @@ use inkwell::{
|
|||
context::Context,
|
||||
types::BasicType,
|
||||
values::{BasicValue, BasicValueEnum, PointerValue},
|
||||
AddressSpace,
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
use scalar::ScalarObject;
|
||||
use util::{call_memcpy_model, gen_for_model_auto};
|
||||
|
||||
pub struct NpArrayFields<'ctx, F: FieldTraversal<'ctx>> {
|
||||
|
@ -52,6 +54,7 @@ impl<'ctx> StructKind<'ctx> for NpArray {
|
|||
}
|
||||
}
|
||||
|
||||
/// A NAC3 Python ndarray object.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NDArrayObject<'ctx> {
|
||||
pub dtype: Type,
|
||||
|
@ -116,7 +119,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
/// Get the pointer to the n-th (0-based) element.
|
||||
///
|
||||
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
|
||||
pub fn get_nth_pelement<G: CodeGenerator + ?Sized>(
|
||||
pub fn get_nth_pointer<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
|
@ -131,6 +134,18 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
.unwrap()
|
||||
}
|
||||
|
||||
/// Get the n-th (0-based) scalar.
|
||||
pub fn get_nth<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
nth: Int<'ctx, SizeT>,
|
||||
) -> ScalarObject<'ctx> {
|
||||
let p = self.get_nth_pointer(generator, ctx, nth, "value");
|
||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||
ScalarObject { dtype: self.dtype, value }
|
||||
}
|
||||
|
||||
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
|
||||
///
|
||||
/// Please refer to the IRRT implementation to see its purpose.
|
||||
|
@ -210,7 +225,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
name: &str,
|
||||
) -> Self {
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||
let ndims = extract_ndims(&mut ctx.unifier, ndims);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
Self::alloca_uninitialized(generator, ctx, dtype, ndims, name)
|
||||
}
|
||||
|
||||
|
@ -224,7 +239,22 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
sizet_model.constant(generator, ctx, self.ndims)
|
||||
}
|
||||
|
||||
/// Return true if this ndarray is unsized.
|
||||
/// Get if this ndarray's `np.size` is `0` - containing no content.
|
||||
pub fn is_empty<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Int<'ctx, Bool> {
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
let size = self.size(generator, ctx);
|
||||
size.compare(ctx, IntPredicate::EQ, sizet_model.const_0(generator, ctx.ctx), "is_empty")
|
||||
}
|
||||
|
||||
/// Return true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
||||
///
|
||||
/// This is a staticially known property of ndarrays. This is why it is returning
|
||||
/// a Rust boolean instead of a [`BasicValue`].
|
||||
#[must_use]
|
||||
pub fn is_unsized(&self) -> bool {
|
||||
self.ndims == 0
|
||||
|
@ -273,7 +303,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
) {
|
||||
assert_eq!(self.ndims, src_ndarray.ndims);
|
||||
let src_shape = src_ndarray.value.get(generator, ctx, |f| f.shape, "src_shape");
|
||||
self.copy_shape_from_array(generator, ctx, src_shape)
|
||||
self.copy_shape_from_array(generator, ctx, src_shape);
|
||||
}
|
||||
|
||||
/// Copy strides dimensions from an array.
|
||||
|
@ -298,14 +328,14 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
) {
|
||||
assert_eq!(self.ndims, src_ndarray.ndims);
|
||||
let src_strides = src_ndarray.value.get(generator, ctx, |f| f.strides, "src_strides");
|
||||
self.copy_strides_from_array(generator, ctx, src_strides)
|
||||
self.copy_strides_from_array(generator, ctx, src_strides);
|
||||
}
|
||||
|
||||
/// Loop through every element pointer in the ndarray in its flatten view.
|
||||
/// Iterate through every element pointer in the ndarray in its flatten view.
|
||||
///
|
||||
/// `body` also access to [`BreakContinueHooks`] to short-circuit and an element's
|
||||
/// index. The given element pointer also has been casted to the LLVM type of this ndarray's `dtype`.
|
||||
pub fn foreach<'a, G, F>(
|
||||
pub fn foreach_pointer<'a, G, F>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
|
@ -328,12 +358,36 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
let step = sizet_model.const_1(generator, ctx.ctx);
|
||||
|
||||
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, hooks, i| {
|
||||
let pelement = self.get_nth_pelement(generator, ctx, i, "element");
|
||||
let pelement = self.get_nth_pointer(generator, ctx, i, "element");
|
||||
body(generator, ctx, hooks, i, pelement)
|
||||
})
|
||||
}
|
||||
|
||||
/// Fill the NDArray with a value.
|
||||
/// Iterate through every scalar in this ndarray.
|
||||
pub fn foreach<'a, G, F>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
body: F,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
F: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<(), String>,
|
||||
{
|
||||
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
|
||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||
let scalar = ScalarObject { dtype: self.dtype, value };
|
||||
body(generator, ctx, hooks, i, scalar)
|
||||
})
|
||||
}
|
||||
|
||||
/// Fill the ndarray with a value.
|
||||
///
|
||||
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
|
||||
pub fn fill<G: CodeGenerator + ?Sized>(
|
||||
|
@ -342,11 +396,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
fill_value: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
self.foreach(generator, ctx, |_generator, ctx, _hooks, _i, pelement| {
|
||||
self.foreach_pointer(generator, ctx, |_generator, ctx, _hooks, _i, pelement| {
|
||||
ctx.builder.build_store(pelement, fill_value).unwrap();
|
||||
Ok(())
|
||||
})
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Create a reshaped view on this ndarray like `np.reshape()`.
|
||||
|
|
|
@ -8,6 +8,10 @@ use crate::{
|
|||
use super::NDArrayObject;
|
||||
|
||||
/// An LLVM numpy scalar with its [`Type`].
|
||||
///
|
||||
/// Intended to be used with [`ScalarOrNDArray`].
|
||||
///
|
||||
/// A scalar does not have to be an actual number. It could be arbitrary objects.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ScalarObject<'ctx> {
|
||||
pub dtype: Type,
|
||||
|
@ -55,6 +59,22 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
|||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn into_scalar(&self) -> ScalarObject<'ctx> {
|
||||
match self {
|
||||
ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"),
|
||||
ScalarOrNDArray::Scalar(scalar) => *scalar,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn into_ndarray(&self) -> NDArrayObject<'ctx> {
|
||||
match self {
|
||||
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
|
||||
ScalarOrNDArray::Scalar(_scalar) => panic!("Got Scalar"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
|
||||
/// - If this is an ndarray, the ndarray is returned.
|
||||
/// - If this is a scalar, an unsized ndarray view is created on it.
|
||||
|
@ -68,6 +88,14 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
|||
ScalarOrNDArray::Scalar(scalar) => scalar.as_ndarray(generator, ctx),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn dtype(&self) -> Type {
|
||||
match self {
|
||||
ScalarOrNDArray::Scalar(scalar) => scalar.dtype,
|
||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> {
|
||||
|
@ -76,7 +104,18 @@ impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> {
|
|||
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
|
||||
ScalarOrNDArray::NDArray(_) => Err(()),
|
||||
ScalarOrNDArray::NDArray(_ndarray) => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
ScalarOrNDArray::Scalar(_scalar) => Err(()),
|
||||
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,15 +2,11 @@ use inkwell::values::BasicValueEnum;
|
|||
use util::gen_for_model_auto;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
model::*,
|
||||
structure::list::{List, ListObject},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
codegen::{model::*, structure::list::ListObject, CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
/// Parse a NumPy-like "int sequence" input and return the int sequence as a [`ListObject`]
|
||||
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
|
||||
///
|
||||
/// * `sequence` - The `sequence` parameter.
|
||||
/// * `sequence_ty` - The typechecker type of `sequence`
|
||||
|
@ -20,99 +16,97 @@ use crate::{
|
|||
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||
///
|
||||
/// `int32` values will be sign-extended to `SizeT`
|
||||
/// All `int32` values will be sign-extended to `SizeT`.
|
||||
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
sequence: BasicValueEnum<'ctx>,
|
||||
sequence_ty: Type,
|
||||
) -> ListObject<'ctx, IntModel<SizeT>, SizeT> {
|
||||
input_sequence: BasicValueEnum<'ctx>,
|
||||
input_sequence_ty: Type,
|
||||
) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel<SizeT>>) {
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let list_model = StructModel(List { len: SizeT, item: IntModel(SizeT) });
|
||||
|
||||
let zero = sizet_model.const_0(generator, ctx.ctx);
|
||||
let one = sizet_model.const_1(generator, ctx.ctx);
|
||||
|
||||
// The result `list` to return.
|
||||
let result = list_model.alloca(generator, ctx, "result_sequence");
|
||||
match &*ctx.unifier.get_ty(sequence_ty) {
|
||||
match &*ctx.unifier.get_ty(input_sequence_ty) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||
let in_sequence_model =
|
||||
PtrModel(StructModel(List { item: IntModel(Int32), len: SizeT }));
|
||||
let in_sequence = in_sequence_model.check_value(generator, ctx.ctx, sequence).unwrap();
|
||||
|
||||
/*
|
||||
Reference code:
|
||||
```
|
||||
result.size = sequence.size;
|
||||
result.data = __builtin_alloca(sizeof(SizeT) * sequence.size);
|
||||
for (SizeT i = 0; i < sequence.size; i++) {
|
||||
result.data[i] = (SizeT) sequence.data[i];
|
||||
}
|
||||
return result
|
||||
```
|
||||
*/
|
||||
// Check `input_sequence`
|
||||
let input_sequence =
|
||||
ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty);
|
||||
|
||||
let ndims = in_sequence.get(generator, ctx, |f| f.len, "size");
|
||||
result.set(ctx, |f| f.len, ndims);
|
||||
let len = input_sequence.value.gep(ctx, |f| f.len).load(generator, ctx, "len");
|
||||
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
||||
|
||||
let result_data = sizet_model.array_alloca(generator, ctx, ndims.value, "data");
|
||||
result.set(ctx, |f| f.items, result_data);
|
||||
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
|
||||
gen_for_model_auto(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| {
|
||||
// Load the i-th int32 in the input sequence
|
||||
let int = input_sequence
|
||||
.value
|
||||
.get(generator, ctx, |f| f.items, "int")
|
||||
.ix(generator, ctx, i.value, "int")
|
||||
.value
|
||||
.into_int_value();
|
||||
|
||||
// Cast to SizeT
|
||||
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
|
||||
|
||||
// Store
|
||||
result.offset(generator, ctx, i.value, "int").store(ctx, int);
|
||||
|
||||
gen_for_model_auto(generator, ctx, zero, ndims, one, |generator, ctx, _hooks, i| {
|
||||
let in_dim = in_sequence
|
||||
.get(generator, ctx, |f| f.items, "in_dim")
|
||||
.ix(generator, ctx, i.value, "in_dim")
|
||||
.s_extend_or_bit_cast(generator, ctx, SizeT, "in_dim");
|
||||
result_data.offset(generator, ctx, i.value, "dim").store(ctx, in_dim);
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
(len, result)
|
||||
}
|
||||
TypeEnum::TTuple { ty: tuple_types } => {
|
||||
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||
let ndims_int = tuple_types.len();
|
||||
let ndims = sizet_model.constant(generator, ctx.ctx, ndims_int as u64);
|
||||
result.set(ctx, |f| f.len, ndims);
|
||||
let input_sequence = input_sequence.into_struct_value(); // A tuple is a struct
|
||||
|
||||
// A tuple has to be a StructValue
|
||||
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
||||
let tuple = sequence.into_struct_value();
|
||||
let data = sizet_model.array_alloca(generator, ctx, ndims.value, "sequence_data");
|
||||
result.set(ctx, |f| f.items, data);
|
||||
let len_int = tuple_types.len();
|
||||
|
||||
for i in 0..ndims_int {
|
||||
// Get the i-th (0-based) element off of the tuple and load it
|
||||
// into `result`.
|
||||
let dim = ctx
|
||||
let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
|
||||
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
||||
|
||||
for i in 0..len_int {
|
||||
// Get the i-th element off of the tuple and load it into `result`.
|
||||
let int = ctx
|
||||
.builder
|
||||
.build_extract_value(tuple, i as u32, format!("dim").as_str())
|
||||
.build_extract_value(input_sequence, i as u32, "int")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let dim = sizet_model.s_extend_or_bit_cast(generator, ctx, dim, "dim");
|
||||
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
|
||||
|
||||
let offset = sizet_model.constant(generator, ctx.ctx, i as u64);
|
||||
data.offset(generator, ctx, offset.value, "dim").store(ctx, dim);
|
||||
result.offset(generator, ctx, offset.value, "int").store(ctx, int);
|
||||
}
|
||||
|
||||
(len, result)
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||
let sequence_int = sizet_model.check_value(generator, ctx.ctx, sequence).unwrap();
|
||||
let input_int = input_sequence.into_int_value();
|
||||
|
||||
// Size is 1
|
||||
result.set(ctx, |f| f.len, one);
|
||||
let len = sizet_model.const_1(generator, ctx.ctx);
|
||||
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
||||
|
||||
// Alloca an array of length 1 and store the sole integer input into the array.
|
||||
let data = sizet_model.array_alloca(generator, ctx, one.value, "data");
|
||||
data.offset(generator, ctx, zero.value, "dim").store(ctx, sequence_int);
|
||||
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, input_int, "int");
|
||||
|
||||
// Storing into result[0]
|
||||
result.store(ctx, int);
|
||||
|
||||
(len, result)
|
||||
}
|
||||
_ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(sequence_ty)),
|
||||
_ => panic!(
|
||||
"encountered unknown sequence type: {}",
|
||||
ctx.unifier.stringify(input_sequence_ty)
|
||||
),
|
||||
}
|
||||
|
||||
ListObject { item_type: ctx.primitives.usize(), value: result }
|
||||
}
|
||||
|
|
|
@ -1180,7 +1180,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
|
||||
let func = match kind {
|
||||
Kind::Ceil => builtin_fns::call_ceil,
|
||||
Kind::Floor => builtin_fns::call_ceil_or_floor,
|
||||
Kind::Floor => builtin_fns::call_floor,
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
|
||||
}),
|
||||
|
@ -1361,7 +1361,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let ndims1 = create_ndims(self.unifier, 1);
|
||||
let ndarray_float_1d = make_ndarray_ty(
|
||||
self.unifier,
|
||||
&self.primitives,
|
||||
self.primitives,
|
||||
Some(self.primitives.float),
|
||||
Some(ndims1),
|
||||
);
|
||||
|
@ -1470,7 +1470,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
PrimDef::FunNpSize => {
|
||||
// TODO: Make the return type usize
|
||||
create_fn_by_codegen(
|
||||
&mut self.unifier,
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.primitives.int32,
|
||||
|
@ -1485,7 +1485,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
// of the input ndarray.
|
||||
let ret_ty = self.unifier.get_dummy_var().ty;
|
||||
create_fn_by_codegen(
|
||||
&mut self.unifier,
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
ret_ty,
|
||||
|
@ -1548,7 +1548,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpCeil => builtin_fns::call_ceil,
|
||||
PrimDef::FunNpFloor => builtin_fns::call_ceil_or_floor,
|
||||
PrimDef::FunNpFloor => builtin_fns::call_floor,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
|
||||
|
|
Loading…
Reference in New Issue