forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: AnyObject + TupleObject

This commit is contained in:
lyken 2024-08-14 12:48:10 +08:00
parent febe78b6a4
commit 2fbe981701
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
12 changed files with 221 additions and 122 deletions

View File

@ -1,15 +1,14 @@
use inkwell::types::BasicTypeEnum;
use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use inkwell::IntPredicate;
use itertools::Itertools;
use crate::codegen::classes::{
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
UntypedArrayLikeAccessor,
};
use crate::codegen::expr::destructure_range;
use crate::codegen::irrt::calculate_len_for_slice_range;
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::{extern_fns, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys;

View File

@ -12,7 +12,10 @@ use crate::{
call_memcpy_generic,
},
need_sret, numpy,
object::ndarray::{scalar::split_scalar_or_ndarray, NDArrayObject, NDArrayOut},
object::{
ndarray::{scalar::split_scalar_or_ndarray, NDArrayObject, NDArrayOut},
AnyObject,
},
stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var,
@ -1537,10 +1540,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let left =
split_scalar_or_ndarray(generator, ctx, left_val, ty1).as_ndarray(generator, ctx);
let right =
split_scalar_or_ndarray(generator, ctx, right_val, ty2).as_ndarray(generator, ctx);
let left = AnyObject { ty: ty1, value: left_val };
let right = AnyObject { ty: ty1, value: right_val };
let left = split_scalar_or_ndarray(generator, ctx, left).as_ndarray(generator, ctx);
let right = split_scalar_or_ndarray(generator, ctx, right).as_ndarray(generator, ctx);
debug_assert!(ctx.unifier.unioned(left.dtype, right.dtype)); // Typechecker ensures this.
@ -2860,8 +2864,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let ndarray_ty = value.custom.unwrap();
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray =
NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let indexes = gen_ndarray_subscript_ndindexes(generator, ctx, slice)?;
let result = ndarray

View File

@ -13,7 +13,10 @@ use crate::{
use super::{
irrt::call_nac3_ndarray_util_assert_shape_no_negative,
model::*,
object::ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
object::{
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
AnyObject,
},
CodeGenContext, CodeGenerator,
};
@ -84,13 +87,12 @@ fn create_empty_ndarray<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ty: Type,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
shape: AnyObject<'ctx>,
) -> NDArrayObject<'ctx>
where
G: CodeGenerator + ?Sized,
{
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
let ndarray =
NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray");
@ -120,10 +122,11 @@ pub fn gen_ndarray_empty<'ctx>(
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Implementation
let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
Ok(ndarray.instance.value.as_basic_value_enum())
}
@ -142,10 +145,11 @@ pub fn gen_ndarray_zeros<'ctx>(
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Implementation
let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
let fill_value = ndarray_zero_value(generator, ctx, ndarray.dtype);
ndarray.fill(generator, ctx, fill_value);
@ -167,10 +171,11 @@ pub fn gen_ndarray_ones<'ctx>(
// Parse arguments
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Implementation
let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
let fill_value = ndarray_one_value(generator, ctx, ndarray.dtype);
ndarray.fill(generator, ctx, fill_value);
@ -192,6 +197,7 @@ pub fn gen_ndarray_full<'ctx>(
// Parse argument #1 shape
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Parse argument #2 fill_value
let fill_value_ty = fun.0.args[1].ty;
@ -199,7 +205,7 @@ pub fn gen_ndarray_full<'ctx>(
// Implementation
let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
ndarray.fill(generator, ctx, fill_value);
@ -220,10 +226,12 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
// Parse argument #1 input
let input_ty = fun.0.args[0].ty;
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
let input = AnyObject { ty: input_ty, value: input };
// Parse argument #2 shape
let shape_ty = fun.0.args[1].ty;
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Define models
let sizet_model = IntModel(SizeT);
@ -234,11 +242,10 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
let broadcast_ndims = extract_ndims(&ctx.unifier, broadcast_ndims_ty);
// Process `input`
let in_ndarray =
split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx);
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
// Process `shape`
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape);
// NOTE: shape.size should equal to `broadcasted_ndims`.
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
call_nac3_ndarray_util_assert_shape_no_negative(
@ -269,23 +276,24 @@ pub fn gen_ndarray_reshape<'ctx>(
// Parse argument #1 input
let input_ty = fun.0.args[0].ty;
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
let input = AnyObject { ty: input_ty, value: input };
// Parse argument #2 shape
let shape_ty = fun.0.args[1].ty;
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Extract reshaped_ndims
let (_, reshaped_ndims_ty) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let reshaped_ndims = extract_ndims(&ctx.unifier, reshaped_ndims_ty);
// Process `input`
let in_ndarray =
split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx);
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
// Process the shape input from user and resolve negative indices.
// The resulting `new_shape`'s size should be equal to reshaped_ndims.
// This is ensured by the typechecker.
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape);
let reshaped_ndarray = in_ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
Ok(reshaped_ndarray.instance.value.as_basic_value_enum())
@ -354,8 +362,8 @@ pub fn gen_ndarray_size<'ctx>(
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32, "size");
Ok(size.value.as_basic_value_enum())
@ -375,14 +383,15 @@ pub fn gen_ndarray_shape<'ctx>(
// Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Define models
let sizet_model = IntModel(SizeT);
// Process ndarray
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let mut items = Vec::with_capacity(ndarray.ndims as usize);
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
for i in 0..ndarray.ndims {
let i = sizet_model.constant(generator, ctx.ctx, i);
@ -392,10 +401,11 @@ pub fn gen_ndarray_shape<'ctx>(
.ix(generator, ctx, i.value, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
objects
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
}
let shape = TupleObject::create(generator, ctx, items, "shape");
let shape = TupleObject::create(generator, ctx, objects, "shape");
Ok(shape.value.as_basic_value_enum())
}
@ -407,7 +417,7 @@ pub fn gen_ndarray_strides<'ctx>(
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
// TODO: This function looks exactly like `gen_ndarray_shapes`, code duplication?
// TODO: Code duplication: This function looks exactly like `gen_ndarray_shapes`.
assert!(obj.is_none());
assert_eq!(args.len(), 1);
@ -415,14 +425,15 @@ pub fn gen_ndarray_strides<'ctx>(
// Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Define models
let sizet_model = IntModel(SizeT);
// Process ndarray
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let mut items = Vec::with_capacity(ndarray.ndims as usize);
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
for i in 0..ndarray.ndims {
let i = sizet_model.constant(generator, ctx.ctx, i);
@ -432,10 +443,11 @@ pub fn gen_ndarray_strides<'ctx>(
.ix(generator, ctx, i.value, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
objects
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
}
let strides = TupleObject::create(generator, ctx, items, "strides");
let strides = TupleObject::create(generator, ctx, objects, "strides");
Ok(strides.value.as_basic_value_enum())
}
@ -458,20 +470,23 @@ pub fn gen_ndarray_transpose<'ctx>(
// Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Implementation
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let has_axes = args.len() >= 2;
let transposed_ndarray = if has_axes {
// Parse argument #2 axes
let in_axes_ty = fun.0.args[1].ty;
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
let in_axes = AnyObject { ty: in_axes_ty, value: in_axes };
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty);
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes);
ndarray.transpose(generator, ctx, Some(axes))
} else {
// axes is not given
ndarray.transpose(generator, ctx, None)
};
@ -488,8 +503,9 @@ pub fn gen_ndarray_array<'ctx>(
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
let obj_ty = fun.0.args[0].ty;
let obj_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, obj_ty)?;
let object_ty = fun.0.args[0].ty;
let object = args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
let object = AnyObject { ty: object_ty, value: object };
let copy_arg = if let Some(arg) =
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
@ -513,7 +529,7 @@ pub fn gen_ndarray_array<'ctx>(
let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
let ndarray = NDArrayObject::from_np_array(generator, ctx, obj_arg, obj_ty, copy);
let ndarray = NDArrayObject::from_np_array(generator, ctx, object, copy);
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);

View File

@ -5,6 +5,8 @@ use crate::{
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
};
use super::AnyObject;
/// A NAC3 Python List object.
#[derive(Debug, Clone, Copy)]
pub struct ListObject<'ctx> {
@ -15,21 +17,20 @@ pub struct ListObject<'ctx> {
impl<'ctx> ListObject<'ctx> {
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list_val: V,
list_type: Type,
object: AnyObject<'ctx>,
) -> Self {
// Check typechecker type and extract `item_type`
let item_type = match &*ctx.unifier.get_ty(list_type) {
let item_type = match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
}
_ => {
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(list_type))
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(object.ty))
}
};
@ -37,7 +38,7 @@ impl<'ctx> ListObject<'ctx> {
let plist_model = PtrModel(StructModel(List { item: item_model }));
// Create object
let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap();
let value = plist_model.check_value(generator, ctx.ctx, object.value).unwrap();
ListObject { item_type, instance: value }
}

View File

@ -1,3 +1,13 @@
use inkwell::values::BasicValueEnum;
use crate::typecheck::typedef::Type;
pub mod list;
pub mod ndarray;
pub mod tuple;
#[derive(Debug, Clone, Copy)]
pub struct AnyObject<'ctx> {
pub ty: Type,
pub value: BasicValueEnum<'ctx>,
}

View File

@ -1,11 +1,9 @@
use inkwell::values::BasicValueEnum;
use super::NDArrayObject;
use crate::{
codegen::{
irrt::{call_nac3_array_set_and_validate_list_shape, call_nac3_array_write_list_to_array},
model::*,
object::list::ListObject,
object::{list::ListObject, AnyObject},
stmt::gen_if_else_expr_callback,
CodeGenContext, CodeGenerator,
},
@ -26,7 +24,7 @@ fn get_list_object_dtype_and_ndims<'ctx>(
}
impl<'ctx> NDArrayObject<'ctx> {
fn from_np_array_list_copy<G: CodeGenerator + ?Sized>(
fn from_np_array_list_copy_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
@ -59,7 +57,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ndarray
}
fn from_np_array_list_try_no_copy<G: CodeGenerator + ?Sized>(
fn from_np_array_list_try_no_copy_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
@ -101,11 +99,11 @@ impl<'ctx> NDArrayObject<'ctx> {
ndarray
} else {
// `list` is nested, it is impossible to not copy.
NDArrayObject::from_np_array_list_copy(generator, ctx, list)
NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list)
}
}
fn from_np_array_list<G: CodeGenerator + ?Sized>(
fn from_np_array_list_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
@ -118,11 +116,12 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = NDArrayObject::from_np_array_list_copy(generator, ctx, list);
let ndarray = NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
|generator, ctx| {
let ndarray = NDArrayObject::from_np_array_list_try_no_copy(generator, ctx, list);
let ndarray =
NDArrayObject::from_np_array_list_try_no_copy_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
)
@ -132,7 +131,7 @@ impl<'ctx> NDArrayObject<'ctx> {
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
}
fn from_np_array_ndarray<G: CodeGenerator + ?Sized>(
pub fn from_np_array_ndarray_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
@ -166,24 +165,23 @@ impl<'ctx> NDArrayObject<'ctx> {
pub fn from_np_array<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: BasicValueEnum<'ctx>,
object_ty: Type,
object: AnyObject<'ctx>,
copy: Int<'ctx, Bool>,
) -> Self {
match &*ctx.unifier.get_ty(object_ty) {
match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_value_and_type(generator, ctx, object, object_ty);
NDArrayObject::from_np_array_list(generator, ctx, list, copy)
let list = ListObject::from_object(generator, ctx, object);
NDArrayObject::from_np_array_list_impl(generator, ctx, list, copy)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, object, object_ty);
NDArrayObject::from_np_array_ndarray(generator, ctx, ndarray, copy)
let ndarray = NDArrayObject::from_object(generator, ctx, object);
NDArrayObject::from_np_array_ndarray_impl(generator, ctx, ndarray, copy)
}
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object_ty)), // Typechecker ensures this
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this
}
}
}

View File

@ -35,6 +35,8 @@ use inkwell::{
use scalar::{ScalarObject, ScalarOrNDArray};
use util::{call_memcpy_model, gen_for_model_auto};
use super::AnyObject;
/// A NAC3 Python ndarray object.
#[derive(Debug, Clone, Copy)]
pub struct NDArrayObject<'ctx> {
@ -45,18 +47,17 @@ pub struct NDArrayObject<'ctx> {
impl<'ctx> NDArrayObject<'ctx> {
/// Create an [`NDArrayObject`] from an LLVM value and its typechecker [`Type`].
pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
value: V,
ty: Type,
object: AnyObject<'ctx>,
) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
Self::from_value_and_unpacked_types(generator, ctx, value, dtype, ndims)
Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims)
}
/// Like [`NDArrayObject::from_value_and_type`] but you directly supply the ndarray's
/// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's
/// `dtype` and `ndims`.
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
generator: &mut G,

View File

@ -1,7 +1,7 @@
use inkwell::values::{BasicValue, BasicValueEnum};
use crate::{
codegen::{model::*, CodeGenContext, CodeGenerator},
codegen::{model::*, object::AnyObject, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
@ -120,23 +120,22 @@ impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
}
}
/// Split an [`BasicValueEnum<'ctx>`] into a [`ScalarOrNDArray`] depending
/// Split an [`AnyObject`] into a [`ScalarOrNDArray`] depending
/// on its [`Type`].
pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
input: BasicValueEnum<'ctx>,
input_ty: Type,
object: AnyObject<'ctx>,
) -> ScalarOrNDArray<'ctx> {
match &*ctx.unifier.get_ty(input_ty) {
match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, input, input_ty);
let ndarray = NDArrayObject::from_object(generator, ctx, object);
ScalarOrNDArray::NDArray(ndarray)
}
_ => {
let scalar = ScalarObject { dtype: input_ty, instance: input };
let scalar = ScalarObject { dtype: object.ty, instance: object.value };
ScalarOrNDArray::Scalar(scalar)
}
}

View File

@ -1,9 +1,12 @@
use inkwell::values::BasicValueEnum;
use util::gen_for_model_auto;
use crate::{
codegen::{model::*, object::list::ListObject, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
codegen::{
model::*,
object::{list::ListObject, tuple::TupleObject, AnyObject},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::TypeEnum,
};
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
@ -20,8 +23,7 @@ use crate::{
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
input_sequence: BasicValueEnum<'ctx>,
input_sequence_ty: Type,
input_sequence: AnyObject<'ctx>,
) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel<SizeT>>) {
let sizet_model = IntModel(SizeT);
@ -29,15 +31,14 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
let one = sizet_model.const_1(generator, ctx.ctx);
// The result `list` to return.
match &*ctx.unifier.get_ty(input_sequence_ty) {
match &*ctx.unifier.get_ty(input_sequence.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
// Check `input_sequence`
let input_sequence =
ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty);
let input_sequence = ListObject::from_object(generator, ctx, input_sequence);
let len = input_sequence.instance.gep(ctx, |f| f.len).load(generator, ctx, "len");
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
@ -66,20 +67,17 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
}
TypeEnum::TTuple { ty: tuple_types, .. } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
let input_sequence = input_sequence.into_struct_value(); // A tuple is a struct
let len_int = tuple_types.len();
let input_sequence = TupleObject::from_object(ctx, input_sequence);
let len_int = input_sequence.len();
let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
for i in 0..len_int {
// Get the i-th element off of the tuple and load it into `result`.
let int = ctx
.builder
.build_extract_value(input_sequence, i as u32, "int")
.unwrap()
.into_int_value();
let int = input_sequence.get(ctx, i, "dim").value.into_int_value();
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
let offset = sizet_model.constant(generator, ctx.ctx, i as u64);
@ -92,7 +90,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let input_int = input_sequence.into_int_value();
let input_int = input_sequence.value.into_int_value();
let len = sizet_model.const_1(generator, ctx.ctx);
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
@ -106,7 +104,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
}
_ => panic!(
"encountered unknown sequence type: {}",
ctx.unifier.stringify(input_sequence_ty)
ctx.unifier.stringify(input_sequence.ty)
),
}
}

View File

@ -1,34 +1,68 @@
use inkwell::values::{BasicValueEnum, StructValue};
use core::panic;
use inkwell::values::StructValue;
use itertools::Itertools;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
typecheck::typedef::{Type, TypeEnum},
};
use super::AnyObject;
/// A NAC3 tuple object.
///
/// NOTE: This struct has no copy trait.
#[derive(Debug, Clone)]
pub struct TupleObject<'ctx> {
/// The type of the tuple.
pub tys: Vec<Type>,
/// The underlying LLVM value of this tuple.
pub value: StructValue<'ctx>,
}
impl<'ctx> TupleObject<'ctx> {
// NOTE: There is no Model abstraction for Tuples. Everything has to be done raw with Inkwell.
pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self {
// TODO: Keep `is_vararg_ctx` from TTuple?
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else {
panic!(
"Expected type to be a TypeEnum::TTuple, got {}",
ctx.unifier.stringify(object.ty)
);
};
let value = object.value.into_struct_value();
if value.get_type().count_fields() as usize != tys.len() {
panic!(
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
tys.len(),
value.get_type().count_fields()
);
}
TupleObject { tys: tys.clone(), value }
}
/// Convenience function. Create a [`TupleObject`] from an iterator of objects.
pub fn create<I, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
items: I,
objects: I,
name: &str,
) -> Self
where
I: IntoIterator<Item = (BasicValueEnum<'ctx>, Type)>,
I: IntoIterator<Item = AnyObject<'ctx>>,
{
let (vals, tys): (Vec<_>, Vec<_>) = items.into_iter().unzip();
let (values, tys): (Vec<_>, Vec<_>) =
objects.into_iter().map(|object| (object.value, object.ty)).unzip();
// let tuple_ty = ctx.unifier.add_ty(TypeEnum::TTuple { ty: tys });
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false);
let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap();
for (i, val) in vals.into_iter().enumerate() {
// Store the dim value into the tuple
for (i, val) in values.into_iter().enumerate() {
let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap();
ctx.builder.build_store(pval, val).unwrap();
}
@ -36,4 +70,22 @@ impl<'ctx> TupleObject<'ctx> {
let value = ctx.builder.build_load(pllvm_tuple, name).unwrap().into_struct_value();
TupleObject { tys, value }
}
/// Get the `len()` of this tuple.
///
/// We statically know the lengths of tuples in NAC3.
pub fn len(&self) -> usize {
self.tys.len()
}
/// Get the `i`-th (0-based) object in this tuple.
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
if i >= self.len() {
panic!("Tuple object with length {} have index {i}", self.len());
}
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
let ty = self.tys[i];
AnyObject { value, ty }
}
}

View File

@ -2,6 +2,7 @@ use super::model::*;
use super::object::ndarray::indexing::util::gen_ndarray_subscript_ndindexes;
use super::object::ndarray::scalar::split_scalar_or_ndarray;
use super::object::ndarray::NDArrayObject;
use super::object::AnyObject;
use super::{
super::symbol_resolver::ValueEnum,
expr::destructure_range,
@ -411,13 +412,14 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
.gen_expr(ctx, target)?
.unwrap()
.to_basic_value_enum(ctx, generator, target_ty)?;
let target = NDArrayObject::from_value_and_type(generator, ctx, target, target_ty);
let target = AnyObject { value: target, ty: target_ty };
// Process key
let key = gen_ndarray_subscript_ndindexes(generator, ctx, key)?;
// Process value
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
let value = AnyObject { value, ty: value_ty };
/*
Reference code:
@ -433,9 +435,10 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
...and finally copy 1-1 from value to target.
```
*/
let target = NDArrayObject::from_object(generator, ctx, target);
let target = target.index(generator, ctx, &key, "assign_target_ndarray");
let value =
split_scalar_or_ndarray(generator, ctx, value, value_ty).as_ndarray(generator, ctx);
let value = split_scalar_or_ndarray(generator, ctx, value).as_ndarray(generator, ctx);
let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]);

View File

@ -18,10 +18,13 @@ use crate::{
extern_fns, irrt, llvm_intrinsics,
numpy::*,
numpy_new::{self, gen_ndarray_transpose},
object::ndarray::{
object::{
ndarray::{
functions::{FloorOrCeil, MinOrMax},
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
},
AnyObject,
},
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
@ -1080,6 +1083,7 @@ impl<'a> BuiltinBuilder<'a> {
move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { ty: arg_ty, value: arg };
let ret_dtype = match prim {
PrimDef::FunInt32 => ctx.primitives.int32,
@ -1091,7 +1095,7 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(),
};
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
generator,
ctx,
ret_dtype,
@ -1158,10 +1162,11 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { ty: arg_ty, value: arg };
let ret_int_dtype = size_variant.of_int(&ctx.primitives);
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
generator,
ctx,
ret_int_dtype,
@ -1225,8 +1230,9 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { ty: arg_ty, value: arg };
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
generator,
ctx,
int_sized,
@ -1618,6 +1624,7 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { ty: arg_ty, value: arg };
let kind = match prim {
PrimDef::FunNpFloor => FloorOrCeil::Floor,
@ -1625,7 +1632,7 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(),
};
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
generator,
ctx,
ctx.primitives.float,
@ -1652,8 +1659,9 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { ty: arg_ty, value: arg };
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
generator,
ctx,
ctx.primitives.float,
@ -1790,8 +1798,9 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| {
let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
let a = AnyObject { ty: a_ty, value: a };
let a = split_scalar_or_ndarray(generator, ctx, a, a_ty).as_ndarray(generator, ctx);
let a = split_scalar_or_ndarray(generator, ctx, a).as_ndarray(generator, ctx);
let result = match prim {
PrimDef::FunNpArgmin => a
.argmin_or_argmax(generator, ctx, MinOrMax::Min)
@ -1845,9 +1854,12 @@ impl<'a> BuiltinBuilder<'a> {
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x2_ty = fun.0.args[1].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x1 = AnyObject { ty: x1_ty, value: x1_val };
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
let x2 = AnyObject { ty: x2_ty, value: x2_val };
let kind = match prim {
PrimDef::FunNpMinimum => MinOrMax::Min,
@ -1855,8 +1867,8 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(),
};
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
let x1 = split_scalar_or_ndarray(generator, ctx, x1);
let x2 = split_scalar_or_ndarray(generator, ctx, x2);
// NOTE: x1.dtype() and x2.dtype() should be the same
let common_ty = x1.dtype();
@ -1907,8 +1919,9 @@ impl<'a> BuiltinBuilder<'a> {
move |ctx, _, fun, args, generator| {
let n_ty = fun.0.args[0].ty;
let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
let n = AnyObject { ty: n_ty, value: n_val };
let result = split_scalar_or_ndarray(generator, ctx, n_val, n_ty).map(
let result = split_scalar_or_ndarray(generator, ctx, n).map(
generator,
ctx,
num_ty.ty,
@ -1936,6 +1949,7 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
let x = AnyObject { value: x_val, ty: x_ty };
let function = match prim {
PrimDef::FunNpIsInf => irrt::call_isnan,
@ -1943,7 +1957,7 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(),
};
let result = split_scalar_or_ndarray(generator, ctx, x_val, x_ty).map(
let result = split_scalar_or_ndarray(generator, ctx, x).map(
generator,
ctx,
ctx.primitives.bool,
@ -2010,8 +2024,9 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { ty: arg_ty, value: arg_val };
let result = split_scalar_or_ndarray(generator, ctx, arg_val, arg_ty).map(
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
generator,
ctx,
ctx.primitives.float,
@ -2126,12 +2141,14 @@ impl<'a> BuiltinBuilder<'a> {
move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x1 = AnyObject { ty: x1_ty, value: x1_val };
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
let x2 = AnyObject { ty: x2_ty, value: x2_val };
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
let x1 = split_scalar_or_ndarray(generator, ctx, x1);
let x2 = split_scalar_or_ndarray(generator, ctx, x2);
let result = ScalarOrNDArray::broadcasting_starmap(
generator,