forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: builtin_fns deleted

This commit is contained in:
lyken 2024-08-15 11:01:56 +08:00
parent 5dce27e87d
commit 0df2f26c98
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
16 changed files with 315 additions and 277 deletions

View File

@ -1,81 +0,0 @@
use inkwell::values::{BasicValueEnum, IntValue};
use inkwell::IntPredicate;
use itertools::Itertools;
use crate::codegen::classes::{ArrayLikeValue, NDArrayValue, RangeValue, TypedArrayLikeAccessor};
use crate::codegen::expr::destructure_range;
use crate::codegen::irrt::calculate_len_for_slice_range;
use crate::codegen::{CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::{Type, TypeEnum};
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
///
/// The generated message will contain the function name and the name of the unsupported type.
fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -> ! {
unreachable!(
"{fn_name}() not supported for '{}'",
tys.iter().map(|ty| format!("'{}'", ctx.unifier.stringify(*ty))).join(", "),
)
}
/// Invokes the `len` builtin function.
pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let range_ty = ctx.primitives.range;
let (arg_ty, arg) = n;
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
let (start, end, step) = destructure_range(ctx, arg);
calculate_len_for_slice_range(generator, ctx, start, end, step)
} else {
match &*ctx.unifier.get_ty_immutable(arg_ty) {
TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
let zero = llvm_i32.const_zero();
let len = ctx
.build_gep_and_load(
arg.into_pointer_value(),
&[zero, llvm_i32.const_int(1, false)],
None,
)
.into_int_value();
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_usize = generator.get_size_type(ctx.ctx);
let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "")
.unwrap(),
"0:TypeError",
"len() of unsized object",
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
}
_ => unreachable!(),
}
})
}

View File

@ -35,7 +35,6 @@ use std::sync::{
use std::thread; use std::thread;
use structure::{CSlice, Exception, NDArray}; use structure::{CSlice, Exception, NDArray};
pub mod builtin_fns;
pub mod classes; pub mod classes;
pub mod concrete_type; pub mod concrete_type;
pub mod expr; pub mod expr;

View File

@ -0,0 +1,122 @@
use inkwell::{
context::Context,
types::{ArrayType, BasicType, BasicTypeEnum},
values::ArrayValue,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
/// A Model for an [`ArrayType`].
#[derive(Debug, Clone, Copy)]
pub struct ArrayModel<Element> {
pub len: u32,
pub element: Element,
}
pub type Array<'ctx, Element> = Instance<'ctx, ArrayModel<Element>>;
impl<'ctx, Element: Model<'ctx>> Model<'ctx> for ArrayModel<Element> {
type Value = ArrayValue<'ctx>;
type Type = ArrayType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.element.get_type(generator, ctx).array_type(self.len)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let BasicTypeEnum::ArrayType(ty) = ty else {
return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}")));
};
if ty.len() != self.len {
return Err(ModelError(format!(
"Expecting ArrayType with size {}, but got an ArrayType with size {}",
ty.len(),
self.len
)));
}
self.element
.check_type(generator, ctx, ty.get_element_type())
.map_err(|err| err.under_context("an ArrayType"))?;
Ok(())
}
}
impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, ArrayModel<Element>> {
/// Get the pointer to the `i`-th (0-based) array element.
pub fn at<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
i: u32,
name: &str,
) -> Ptr<'ctx, Element> {
assert!(i < self.model.0.len);
let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(i as u64, false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0.element).check_value(generator, ctx.ctx, ptr).unwrap()
}
}
/// Like [`ArrayModel`] but length is strongly-typed.
#[derive(Debug, Clone, Copy, Default)]
pub struct NArrayModel<const LEN: u32, Element>(pub Element);
pub type NArray<'ctx, const LEN: u32, Element> = Instance<'ctx, NArrayModel<LEN, Element>>;
impl<'ctx, const LEN: u32, Element: Model<'ctx>> NArrayModel<LEN, Element> {
/// Forget the `LEN` constant generic and get an [`ArrayModel`] with the same length.
pub fn forget_len(&self) -> ArrayModel<Element> {
ArrayModel { element: self.0, len: LEN }
}
}
impl<'ctx, const LEN: u32, Element: Model<'ctx>> Model<'ctx> for NArrayModel<LEN, Element> {
type Value = ArrayValue<'ctx>;
type Type = ArrayType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
// Convenient implementation
self.forget_len().get_type(generator, ctx)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
// Convenient implementation
self.forget_len().check_type(generator, ctx, ty)
}
}
impl<'ctx, const LEN: u32, Element: Model<'ctx>> Ptr<'ctx, NArrayModel<LEN, Element>> {
/// Get the pointer to the `i`-th (0-based) array element.
pub fn at_const<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
i: u32,
name: &str,
) -> Ptr<'ctx, Element> {
assert!(i < LEN);
let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(i as u64, false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap()
}
}

View File

@ -1,4 +1,5 @@
mod any; mod any;
mod array;
mod core; mod core;
mod float; mod float;
pub mod function; pub mod function;
@ -8,6 +9,7 @@ mod structure;
pub mod util; pub mod util;
pub use any::*; pub use any::*;
pub use array::*;
pub use core::*; pub use core::*;
pub use float::*; pub use float::*;
pub use int::*; pub use int::*;

View File

@ -60,32 +60,3 @@ where
step.value, step.value,
) )
} }
/// Like [`gen_if_callback`] with [`Model`] abstractions and without the `else` block.
pub fn gen_if_model<'ctx, 'a, G, ThenFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
cond: Int<'ctx, Bool>,
then: ThenFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
{
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then");
let end_bb = ctx.ctx.insert_basic_block_after(then_bb, "if.end");
// Inserting into `current_bb`.
ctx.builder.build_conditional_branch(cond.value, then_bb, end_bb).unwrap();
// Inserting into `then_bb`
ctx.builder.position_at_end(then_bb);
then(generator, ctx)?;
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition to `end_bb` for continuation.
ctx.builder.position_at_end(end_bb);
Ok(())
}

View File

@ -20,113 +20,6 @@ use super::{
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
/// Helper function to create an ndarray with uninitialized values.
///
/// * `ndarray_ty` - The [`Type`] of the ndarray
/// * `shape` - The user input shape argument
/// * `shape_ty` - The [`Type`] of the shape argument
///
/// This function does data validation the `shape` input.
fn create_empty_ndarray<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ty: Type,
shape: AnyObject<'ctx>,
) -> NDArrayObject<'ctx>
where
G: CodeGenerator + ?Sized,
{
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
let ndarray = NDArrayObject::alloca_ndarray_type(generator, ctx, ndarray_ty, "ndarray");
// Validate `shape`
let ndims = ndarray.get_ndims(generator, ctx.ctx);
call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims, shape);
// Setup `ndarray` with `shape`
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx); // `shape` has to be set
ndarray
}
/// Generates LLVM IR for `np.empty`.
pub fn gen_ndarray_empty<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Implementation
let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.zero`.
pub fn gen_ndarray_zeros<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Implementation
let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
let fill_value = ndarray_zero_value(generator, ctx, ndarray.dtype);
ndarray.fill(generator, ctx, fill_value);
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.ones`.
pub fn gen_ndarray_ones<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Implementation
let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
let fill_value = ndarray_one_value(generator, ctx, ndarray.dtype);
ndarray.fill(generator, ctx, fill_value);
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.broadcast_to`. /// Generates LLVM IR for `np.broadcast_to`.
pub fn gen_ndarray_broadcast_to<'ctx>( pub fn gen_ndarray_broadcast_to<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -315,4 +208,3 @@ pub fn gen_ndarray_strides<'ctx>(
let strides = TupleObject::create(generator, ctx, objects, "strides"); let strides = TupleObject::create(generator, ctx, objects, "strides");
Ok(strides.value.as_basic_value_enum()) Ok(strides.value.as_basic_value_enum())
} }

View File

@ -75,4 +75,13 @@ impl<'ctx> ListObject<'ctx> {
opaque_list_ptr opaque_list_ptr
} }
/// Get the `len()` of this list.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
self.instance.get(generator, ctx, |f| f.len, "list_len")
}
} }

View File

@ -1,9 +1,16 @@
use inkwell::values::BasicValueEnum; use inkwell::values::BasicValueEnum;
use list::ListObject;
use ndarray::NDArrayObject;
use range::RangeObject;
use tuple::TupleObject;
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::{Type, TypeEnum};
use super::{model::*, CodeGenContext, CodeGenerator};
pub mod list; pub mod list;
pub mod ndarray; pub mod ndarray;
pub mod range;
pub mod tuple; pub mod tuple;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@ -11,3 +18,38 @@ pub struct AnyObject<'ctx> {
pub ty: Type, pub ty: Type,
pub value: BasicValueEnum<'ctx>, pub value: BasicValueEnum<'ctx>,
} }
impl<'ctx> AnyObject<'ctx> {
// Get the `len()` of this object.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Int32> {
match &*ctx.unifier.get_ty_immutable(self.ty) {
TypeEnum::TTuple { .. } => {
let tuple = TupleObject::from_object(ctx, *self);
tuple.len(generator, ctx).truncate(generator, ctx, Int32, "tuple_len_32")
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{
let range = RangeObject::from_object(generator, ctx, *self);
range.len(generator, ctx)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_object(generator, ctx, *self);
list.len(generator, ctx).truncate(generator, ctx, Int32, "list_len_i32")
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, *self);
ndarray.len(generator, ctx).truncate(generator, ctx, Int32, "ndarray_len_i32")
}
_ => unreachable!(),
}
}
}

View File

@ -5,14 +5,7 @@ use inkwell::{
use itertools::Itertools; use itertools::Itertools;
use crate::{ use crate::{
codegen::{ codegen::{llvm_intrinsics, model::*, stmt::gen_if_callback, CodeGenContext, CodeGenerator},
llvm_intrinsics,
model::{
util::{gen_for_model_auto, gen_if_model},
*,
},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type, typecheck::typedef::Type,
}; };
@ -483,38 +476,40 @@ impl<'ctx> NDArrayObject<'ctx> {
let first_scalar = self.get_nth_scalar(generator, ctx, zero); let first_scalar = self.get_nth_scalar(generator, ctx, zero);
ctx.builder.build_store(pextremum, first_scalar.value).unwrap(); ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
// Find extremum self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1
let stop = self.size(generator, ctx);
let step = sizet_model.const_1(generator, ctx.ctx);
gen_for_model_auto(generator, ctx, start, stop, step, |generator, ctx, _hooks, i| {
// Worth reading on "Notes" in <https://numpy.org/doc/stable/reference/generated/numpy.min.html#numpy.min>
// on how `NaN` values have to be handled.
let scalar = self.get_nth_scalar(generator, ctx, i);
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap(); let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum }; let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
let scalar = nditer.get_scalar(generator, ctx);
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar); let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
// Check if new_extremum is more extreme than old_extremum. gen_if_callback(
let update_index = ScalarObject::compare(
generator, generator,
ctx, ctx,
new_extremum, |generator, ctx| {
old_extremum, // Is new_extremum is more extreme than old_extremum?
IntPredicate::NE, let cmp = ScalarObject::compare(
FloatPredicate::ONE, generator,
"", ctx,
); new_extremum,
old_extremum,
gen_if_model(generator, ctx, update_index, |_generator, ctx| { IntPredicate::NE,
pextremum_index.store(ctx, i); FloatPredicate::ONE,
Ok(()) "",
}) );
.unwrap(); Ok(cmp.value)
Ok(()) },
|generator, ctx| {
// Yes, update the extremum index
let index = nditer.get_index(generator, ctx);
pextremum_index.store(ctx, index);
Ok(())
},
|_generator, _ctx| {
// No, do nothing
Ok(())
},
)
}) })
.unwrap(); .unwrap();

View File

@ -3,7 +3,6 @@ use itertools::Itertools;
use crate::{ use crate::{
codegen::{ codegen::{
model::*,
object::ndarray::{NDArrayObject, ScalarObject}, object::ndarray::{NDArrayObject, ScalarObject},
stmt::gen_for_callback, stmt::gen_for_callback,
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,

View File

@ -17,8 +17,7 @@ use crate::{
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes, call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
call_nac3_ndarray_resolve_and_check_new_shape, call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_resolve_and_check_new_shape, call_nac3_ndarray_set_strides_by_shape,
call_nac3_ndarray_size, call_nac3_ndarray_transpose, call_nac3_ndarray_size, call_nac3_ndarray_transpose,
call_nac3_ndarray_util_assert_output_shape_same, call_nac3_nditer_has_next, call_nac3_ndarray_util_assert_output_shape_same,
call_nac3_nditer_initialize, call_nac3_nditer_next,
}, },
model::*, model::*,
stmt::{gen_for_callback, BreakContinueHooks}, stmt::{gen_for_callback, BreakContinueHooks},
@ -40,7 +39,7 @@ use inkwell::{
}; };
use nditer::NDIterHandle; use nditer::NDIterHandle;
use scalar::{ScalarObject, ScalarOrNDArray}; use scalar::{ScalarObject, ScalarOrNDArray};
use util::{call_memcpy_model, gen_for_model_auto}; use util::call_memcpy_model;
use super::{tuple::TupleObject, AnyObject}; use super::{tuple::TupleObject, AnyObject};

View File

@ -71,7 +71,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
let input_sequence = TupleObject::from_object(ctx, input_sequence); let input_sequence = TupleObject::from_object(ctx, input_sequence);
let len_int = input_sequence.len(); let len_int = input_sequence.len_static();
let len = sizet_model.constant(generator, ctx.ctx, len_int as u64); let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence"); let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");

View File

@ -0,0 +1,41 @@
use crate::codegen::{
irrt::calculate_len_for_slice_range, model::*, structure::RangeModel, CodeGenContext,
CodeGenerator,
};
use super::AnyObject;
/// A `range` in NAC3
pub struct RangeObject<'ctx> {
pub instance: Ptr<'ctx, RangeModel>,
}
impl<'ctx> RangeObject<'ctx> {
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> Self {
assert!(ctx.unifier.unioned(ctx.primitives.range, object.ty)); // Sanity check on type.
let model = PtrModel(RangeModel::default());
let instance = model.check_value(generator, ctx.ctx, object.value).unwrap();
RangeObject { instance }
}
/// Get the `len()` of this range.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Int32> {
let start = self.instance.gep_start(generator, ctx, "").load(generator, ctx, "start");
let stop = self.instance.gep_stop(generator, ctx, "").load(generator, ctx, "stop");
let step = self.instance.gep_step(generator, ctx, "").load(generator, ctx, "step");
// TODO: Refactor this
let len =
calculate_len_for_slice_range(generator, ctx, start.value, stop.value, step.value);
IntModel(Int32).check_value(generator, ctx.ctx, len).unwrap()
}
}

View File

@ -4,7 +4,7 @@ use inkwell::values::StructValue;
use itertools::Itertools; use itertools::Itertools;
use crate::{ use crate::{
codegen::{CodeGenContext, CodeGenerator}, codegen::{model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum}, typecheck::typedef::{Type, TypeEnum},
}; };
@ -22,10 +22,13 @@ pub struct TupleObject<'ctx> {
} }
impl<'ctx> TupleObject<'ctx> { impl<'ctx> TupleObject<'ctx> {
// NOTE: There is no Model abstraction for Tuples. Everything has to be done raw with Inkwell. // NOTE: There is no Model abstraction for Tuples with arbitrary lengths.
// Everything has to be done raw with Inkwell.
pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self { pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self {
// TODO: Keep `is_vararg_ctx` from TTuple? // TODO: Keep `is_vararg_ctx` from TTuple?
// Sanity check on object type.
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else { let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else {
panic!( panic!(
"Expected type to be a TypeEnum::TTuple, got {}", "Expected type to be a TypeEnum::TTuple, got {}",
@ -71,23 +74,37 @@ impl<'ctx> TupleObject<'ctx> {
TupleObject { tys, value } TupleObject { tys, value }
} }
/// Get the `len()` of this tuple. /// Get the `len()` of this tuple statically.
/// ///
/// We statically know the lengths of tuples in NAC3. /// We statically know the lengths of tuples in NAC3 when compiling.
#[must_use] #[must_use]
pub fn len(&self) -> usize { pub fn len_static(&self) -> usize {
self.tys.len() self.tys.len()
} }
/// Get the `len()` of this tuple.
#[must_use]
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
IntModel(SizeT).constant(generator, ctx.ctx, self.len_static() as u64)
}
/// Check if this tuple is an empty/unit tuple. /// Check if this tuple is an empty/unit tuple.
#[must_use] #[must_use]
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
self.len() == 0 self.len_static() == 0
} }
/// Get the `i`-th (0-based) object in this tuple. /// Get the `i`-th (0-based) object in this tuple.
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> { pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
assert!(i < self.len(), "Tuple object with length {} have index {i}", self.len()); assert!(
i < self.len_static(),
"Tuple object with length {} have index {i}",
self.len_static()
);
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap(); let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
let ty = self.tys[i]; let ty = self.tys[i];

View File

@ -2,7 +2,7 @@ use inkwell::context::Context;
use crate::codegen::model::*; use crate::codegen::model::*;
use super::CodeGenerator; use super::{CodeGenContext, CodeGenerator};
/// Fields of [`CSlice`] /// Fields of [`CSlice`]
pub struct CSliceFields<'ctx, F: FieldTraversal<'ctx>> { pub struct CSliceFields<'ctx, F: FieldTraversal<'ctx>> {
@ -186,7 +186,6 @@ impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for SimpleNDArray<Item> {
} }
} }
/// An IRRT helper structure used when iterating through an ndarray.
/// Fields of [`NDIter`] /// Fields of [`NDIter`]
pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> { pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Out<IntModel<SizeT>>, pub ndims: F::Out<IntModel<SizeT>>,
@ -200,6 +199,7 @@ pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> {
pub size: F::Out<IntModel<SizeT>>, pub size: F::Out<IntModel<SizeT>>,
} }
/// An IRRT helper structure used when iterating through an ndarray.
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct NDIter; pub struct NDIter;
@ -220,3 +220,37 @@ impl<'ctx> StructKind<'ctx> for NDIter {
} }
} }
} }
/// A NAC3 `range`. It is an array of 3 int32s.
// TODO: Use `pub type RangeModel<N> = NArrayModel<3, IntModel<N>>` in the future when
// `range` type is type dependent.
pub type RangeModel = NArrayModel<3, IntModel<Int32>>;
impl<'ctx> Ptr<'ctx, RangeModel> {
pub fn gep_start<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Ptr<'ctx, IntModel<Int32>> {
self.at_const(generator, ctx, 0, name)
}
pub fn gep_stop<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Ptr<'ctx, IntModel<Int32>> {
self.at_const(generator, ctx, 1, name)
}
pub fn gep_step<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Ptr<'ctx, IntModel<Int32>> {
self.at_const(generator, ctx, 2, name)
}
}

View File

@ -1,4 +1,4 @@
use std::iter::once; use std::{any::Any, iter::once};
use helper::{ use helper::{
create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields,
@ -17,10 +17,9 @@ use strum::IntoEnumIterator;
use crate::{ use crate::{
codegen::{ codegen::{
builtin_fns,
classes::{ProxyValue, RangeValue}, classes::{ProxyValue, RangeValue},
extern_fns::{self, call_np_linalg_det, call_np_linalg_matrix_power}, extern_fns::{self, call_np_linalg_det, call_np_linalg_matrix_power},
irrt::{self, call_nac3_ndarray_util_assert_shape_no_negative}, irrt::{self},
llvm_intrinsics, llvm_intrinsics,
model::*, model::*,
numpy::*, numpy::*,
@ -1294,7 +1293,7 @@ impl<'a> BuiltinBuilder<'a> {
prim.name(), prim.name(),
self.ndarray_float, self.ndarray_float,
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, _obj, fun, args, generator| {
// Parse argument `shape`. // Parse argument `shape`.
let shape_ty = fun.0.args[0].ty; let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?; let shape_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
@ -1559,7 +1558,7 @@ impl<'a> BuiltinBuilder<'a> {
(array_tvar.ty, "array"), (array_tvar.ty, "array"),
(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"),
], ],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, _obj, fun, args, generator| {
// Parse argument #1 ndarray // Parse argument #1 ndarray
let input_ty = fun.0.args[0].ty; let input_ty = fun.0.args[0].ty;
let input = let input =
@ -1631,10 +1630,7 @@ impl<'a> BuiltinBuilder<'a> {
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new( codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| { |ctx, _obj, fun, args, generator| {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse argument #1 ndarray // Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty; let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0] let ndarray = args[0]
@ -1843,8 +1839,9 @@ impl<'a> BuiltinBuilder<'a> {
move |ctx, _, fun, args, generator| { move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { value: arg, ty: arg_ty };
builtin_fns::call_len(generator, ctx, (arg_ty, arg)).map(|ret| Some(ret.into())) Ok(Some(arg.len(generator, ctx).value))
}, },
)))), )))),
loc: None, loc: None,