forked from M-Labs/nac3
[core] codegen/ndarray: Reimplement np_{zeros,ones,full,empty}
Based on 792374fa: core/ndstrides: implement np_{zeros,ones,full,empty}.
This commit is contained in:
parent
7d02f5833d
commit
5880f964bb
@ -3,7 +3,6 @@ use inkwell::{
|
|||||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::{Operator, StrRef};
|
||||||
|
|
||||||
@ -19,17 +18,28 @@ use super::{
|
|||||||
llvm_intrinsics::{self, call_memcpy_generic},
|
llvm_intrinsics::{self, call_memcpy_generic},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||||
types::{ndarray::NDArrayType, ListType, ProxyType},
|
types::{
|
||||||
|
ndarray::{
|
||||||
|
factory::{ndarray_one_value, ndarray_zero_value},
|
||||||
|
NDArrayType,
|
||||||
|
},
|
||||||
|
ListType, ProxyType,
|
||||||
|
},
|
||||||
values::{
|
values::{
|
||||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
|
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
||||||
ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue,
|
||||||
|
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
||||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId},
|
toplevel::{
|
||||||
|
helper::{extract_ndims, PrimDef},
|
||||||
|
numpy::unpack_ndarray_var_tys,
|
||||||
|
DefinitionId,
|
||||||
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::Binop,
|
magic_methods::Binop,
|
||||||
typedef::{FunSignature, Type, TypeEnum},
|
typedef::{FunSignature, Type, TypeEnum},
|
||||||
@ -174,132 +184,6 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
||||||
{
|
|
||||||
ctx.ctx.i32_type().const_zero().into()
|
|
||||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
||||||
{
|
|
||||||
ctx.ctx.i64_type().const_zero().into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
|
||||||
ctx.ctx.f64_type().const_zero().into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
|
||||||
ctx.ctx.bool_type().const_zero().into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
|
||||||
ctx.gen_string(generator, "").into()
|
|
||||||
} else {
|
|
||||||
codegen_unreachable!(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
) -> BasicValueEnum<'ctx> {
|
|
||||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
||||||
{
|
|
||||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
|
||||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
|
||||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
|
||||||
.iter()
|
|
||||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
|
||||||
{
|
|
||||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
|
||||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
|
||||||
ctx.ctx.f64_type().const_float(1.0).into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
|
||||||
ctx.ctx.bool_type().const_int(1, false).into()
|
|
||||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
|
||||||
ctx.gen_string(generator, "1").into()
|
|
||||||
} else {
|
|
||||||
codegen_unreachable!(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
|
||||||
///
|
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
|
||||||
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
|
||||||
///
|
|
||||||
/// ### Notes on `shape`
|
|
||||||
///
|
|
||||||
/// Just like numpy, the `shape` argument can be:
|
|
||||||
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
|
||||||
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
|
||||||
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
|
||||||
///
|
|
||||||
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
|
|
||||||
/// learn how `shape` gets from being a Python user expression to here.
|
|
||||||
fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
shape: BasicValueEnum<'ctx>,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
match shape {
|
|
||||||
BasicValueEnum::PointerValue(shape_list_ptr)
|
|
||||||
if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
|
|
||||||
{
|
|
||||||
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
|
|
||||||
|
|
||||||
let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
|
|
||||||
create_ndarray_dyn_shape(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
elem_ty,
|
|
||||||
&shape_list,
|
|
||||||
|_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)),
|
|
||||||
|generator, ctx, shape_list, idx| {
|
|
||||||
Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value())
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
BasicValueEnum::StructValue(shape_tuple) => {
|
|
||||||
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
|
||||||
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
|
||||||
|
|
||||||
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
|
||||||
let ndims = shape_tuple.get_type().count_fields();
|
|
||||||
|
|
||||||
let shape = (0..ndims)
|
|
||||||
.map(|dim_i| {
|
|
||||||
ctx.builder
|
|
||||||
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.map(|v| {
|
|
||||||
ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
})
|
|
||||||
.collect_vec();
|
|
||||||
|
|
||||||
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
|
||||||
}
|
|
||||||
BasicValueEnum::IntValue(shape_int) => {
|
|
||||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
|
||||||
let shape_int =
|
|
||||||
ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap();
|
|
||||||
|
|
||||||
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
|
||||||
}
|
|
||||||
_ => codegen_unreachable!(ctx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||||
/// its input.
|
/// its input.
|
||||||
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
|
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
|
||||||
@ -529,107 +413,6 @@ where
|
|||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
|
||||||
///
|
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
|
||||||
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
|
||||||
fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
shape: BasicValueEnum<'ctx>,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
|
||||||
let supported_types = [
|
|
||||||
ctx.primitives.int32,
|
|
||||||
ctx.primitives.int64,
|
|
||||||
ctx.primitives.uint32,
|
|
||||||
ctx.primitives.uint64,
|
|
||||||
ctx.primitives.float,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
ctx.primitives.str,
|
|
||||||
];
|
|
||||||
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
|
||||||
|
|
||||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
|
||||||
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
|
|
||||||
let value = ndarray_zero_value(generator, ctx, elem_ty);
|
|
||||||
|
|
||||||
Ok(value)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
|
||||||
///
|
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
|
||||||
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
|
||||||
fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
shape: BasicValueEnum<'ctx>,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
|
||||||
let supported_types = [
|
|
||||||
ctx.primitives.int32,
|
|
||||||
ctx.primitives.int64,
|
|
||||||
ctx.primitives.uint32,
|
|
||||||
ctx.primitives.uint64,
|
|
||||||
ctx.primitives.float,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
ctx.primitives.str,
|
|
||||||
];
|
|
||||||
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
|
||||||
|
|
||||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
|
||||||
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
|
|
||||||
let value = ndarray_one_value(generator, ctx, elem_ty);
|
|
||||||
|
|
||||||
Ok(value)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.full`.
|
|
||||||
///
|
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
|
||||||
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
|
||||||
fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
elem_ty: Type,
|
|
||||||
shape: BasicValueEnum<'ctx>,
|
|
||||||
fill_value: BasicValueEnum<'ctx>,
|
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
|
||||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
|
||||||
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
|
|
||||||
let value = if fill_value.is_pointer_value() {
|
|
||||||
let llvm_i1 = ctx.ctx.bool_type();
|
|
||||||
|
|
||||||
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
|
|
||||||
|
|
||||||
call_memcpy_generic(
|
|
||||||
ctx,
|
|
||||||
copy,
|
|
||||||
fill_value.into_pointer_value(),
|
|
||||||
fill_value.get_type().size_of().map(Into::into).unwrap(),
|
|
||||||
llvm_i1.const_zero(),
|
|
||||||
);
|
|
||||||
|
|
||||||
copy.into()
|
|
||||||
} else if fill_value.is_int_value() || fill_value.is_float_value() {
|
|
||||||
fill_value
|
|
||||||
} else {
|
|
||||||
codegen_unreachable!(ctx)
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(value)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the number of dimensions for a multidimensional list as an [`IntValue`].
|
/// Returns the number of dimensions for a multidimensional list as an [`IntValue`].
|
||||||
fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &G,
|
generator: &G,
|
||||||
@ -1752,8 +1535,15 @@ pub fn gen_ndarray_empty<'ctx>(
|
|||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
.map(NDArrayValue::into)
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
|
let ndims = extract_ndims(&context.unifier, ndims);
|
||||||
|
|
||||||
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||||
|
.construct_numpy_empty(generator, context, &shape, None);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.zeros`.
|
/// Generates LLVM IR for `ndarray.zeros`.
|
||||||
@ -1770,8 +1560,15 @@ pub fn gen_ndarray_zeros<'ctx>(
|
|||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
.map(NDArrayValue::into)
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
|
let ndims = extract_ndims(&context.unifier, ndims);
|
||||||
|
|
||||||
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||||
|
.construct_numpy_zeros(generator, context, dtype, &shape, None);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.ones`.
|
/// Generates LLVM IR for `ndarray.ones`.
|
||||||
@ -1788,8 +1585,15 @@ pub fn gen_ndarray_ones<'ctx>(
|
|||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
.map(NDArrayValue::into)
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
|
let ndims = extract_ndims(&context.unifier, ndims);
|
||||||
|
|
||||||
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||||
|
.construct_numpy_ones(generator, context, dtype, &shape, None);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.full`.
|
/// Generates LLVM IR for `ndarray.full`.
|
||||||
@ -1809,8 +1613,15 @@ pub fn gen_ndarray_full<'ctx>(
|
|||||||
let fill_value_arg =
|
let fill_value_arg =
|
||||||
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
||||||
|
|
||||||
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||||
.map(NDArrayValue::into)
|
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||||
|
let ndims = extract_ndims(&context.unifier, ndims);
|
||||||
|
|
||||||
|
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||||
|
|
||||||
|
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||||
|
.construct_numpy_full(generator, context, &shape, fill_value_arg, None);
|
||||||
|
Ok(ndarray.as_base_value())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn gen_ndarray_array<'ctx>(
|
pub fn gen_ndarray_array<'ctx>(
|
||||||
|
146
nac3core/src/codegen/types/ndarray/factory.rs
Normal file
146
nac3core/src/codegen/types/ndarray/factory.rs
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
use inkwell::values::{BasicValueEnum, IntValue};
|
||||||
|
|
||||||
|
use super::NDArrayType;
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
irrt, types::ProxyType, values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::Type,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Get the zero value in `np.zeros()` of a `dtype`.
|
||||||
|
pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||||
|
.iter()
|
||||||
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
|
{
|
||||||
|
ctx.ctx.i32_type().const_zero().into()
|
||||||
|
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||||
|
.iter()
|
||||||
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
|
{
|
||||||
|
ctx.ctx.i64_type().const_zero().into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||||
|
ctx.ctx.f64_type().const_zero().into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||||
|
ctx.ctx.bool_type().const_zero().into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||||
|
ctx.gen_string(generator, "").into()
|
||||||
|
} else {
|
||||||
|
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the one value in `np.ones()` of a `dtype`.
|
||||||
|
pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||||
|
.iter()
|
||||||
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
|
{
|
||||||
|
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
|
||||||
|
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||||
|
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||||
|
.iter()
|
||||||
|
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||||
|
{
|
||||||
|
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
|
||||||
|
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||||
|
ctx.ctx.f64_type().const_float(1.0).into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||||
|
ctx.ctx.bool_type().const_int(1, false).into()
|
||||||
|
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||||
|
ctx.gen_string(generator, "1").into()
|
||||||
|
} else {
|
||||||
|
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayType<'ctx> {
|
||||||
|
/// Create an ndarray like
|
||||||
|
/// [`np.empty`](https://numpy.org/doc/stable/reference/generated/numpy.empty.html).
|
||||||
|
pub fn construct_numpy_empty<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
let ndarray = self.construct_uninitialized(generator, ctx, name);
|
||||||
|
|
||||||
|
// Validate `shape`
|
||||||
|
irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, shape);
|
||||||
|
|
||||||
|
ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
||||||
|
unsafe { ndarray.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like
|
||||||
|
/// [`np.full`](https://numpy.org/doc/stable/reference/generated/numpy.full.html).
|
||||||
|
pub fn construct_numpy_full<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||||
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
let ndarray = self.construct_numpy_empty(generator, ctx, shape, name);
|
||||||
|
ndarray.fill(generator, ctx, fill_value);
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like
|
||||||
|
/// [`np.zero`](https://numpy.org/doc/stable/reference/generated/numpy.zeros.html).
|
||||||
|
pub fn construct_numpy_zeros<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
assert_eq!(
|
||||||
|
ctx.get_llvm_type(generator, dtype),
|
||||||
|
self.dtype,
|
||||||
|
"Expected LLVM dtype={} but got {}",
|
||||||
|
self.dtype.print_to_string(),
|
||||||
|
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let fill_value = ndarray_zero_value(generator, ctx, dtype);
|
||||||
|
self.construct_numpy_full(generator, ctx, shape, fill_value, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an ndarray like
|
||||||
|
/// [`np.ones`](https://numpy.org/doc/stable/reference/generated/numpy.ones.html).
|
||||||
|
pub fn construct_numpy_ones<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
dtype: Type,
|
||||||
|
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
assert_eq!(
|
||||||
|
ctx.get_llvm_type(generator, dtype),
|
||||||
|
self.dtype,
|
||||||
|
"Expected LLVM dtype={} but got {}",
|
||||||
|
self.dtype.print_to_string(),
|
||||||
|
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||||
|
self.construct_numpy_full(generator, ctx, shape, fill_value, name)
|
||||||
|
}
|
||||||
|
}
|
@ -25,6 +25,7 @@ pub use indexing::*;
|
|||||||
pub use nditer::*;
|
pub use nditer::*;
|
||||||
|
|
||||||
mod contiguous;
|
mod contiguous;
|
||||||
|
pub mod factory;
|
||||||
mod indexing;
|
mod indexing;
|
||||||
mod nditer;
|
mod nditer;
|
||||||
|
|
||||||
|
@ -163,8 +163,13 @@ impl<'ctx> NDIterType<'ctx> {
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ndarray: NDArrayValue<'ctx>,
|
ndarray: NDArrayValue<'ctx>,
|
||||||
) -> <Self as ProxyType<'ctx>>::Value {
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
assert!(
|
||||||
|
ndarray.get_type().ndims().is_some(),
|
||||||
|
"NDIter requires ndims of NDArray to be known."
|
||||||
|
);
|
||||||
|
|
||||||
let nditer = self.raw_alloca_var(generator, ctx, None);
|
let nditer = self.raw_alloca_var(generator, ctx, None);
|
||||||
let ndims = ndarray.load_ndims(ctx);
|
let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false);
|
||||||
|
|
||||||
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
||||||
let indices =
|
let indices =
|
||||||
|
@ -23,6 +23,7 @@ pub use nditer::*;
|
|||||||
mod contiguous;
|
mod contiguous;
|
||||||
mod indexing;
|
mod indexing;
|
||||||
mod nditer;
|
mod nditer;
|
||||||
|
pub mod shape;
|
||||||
mod view;
|
mod view;
|
||||||
|
|
||||||
/// Proxy type for accessing an `NDArray` value in LLVM.
|
/// Proxy type for accessing an `NDArray` value in LLVM.
|
||||||
@ -397,6 +398,23 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fill the ndarray with a scalar.
|
||||||
|
///
|
||||||
|
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
|
||||||
|
pub fn fill<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
value: BasicValueEnum<'ctx>,
|
||||||
|
) {
|
||||||
|
self.foreach(generator, ctx, |_, ctx, _, nditer| {
|
||||||
|
let p = nditer.get_pointer(ctx);
|
||||||
|
ctx.builder.build_store(p, value).unwrap();
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn is_unsized(&self) -> Option<bool> {
|
pub fn is_unsized(&self) -> Option<bool> {
|
||||||
|
152
nac3core/src/codegen/values/ndarray/shape.rs
Normal file
152
nac3core/src/codegen/values/ndarray/shape.rs
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
use inkwell::values::{BasicValueEnum, IntValue};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
stmt::gen_for_callback_incrementing,
|
||||||
|
types::{ListType, TupleType},
|
||||||
|
values::{
|
||||||
|
ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||||
|
TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
||||||
|
},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
|
||||||
|
///
|
||||||
|
/// * `sequence` - The `sequence` parameter.
|
||||||
|
/// * `sequence_ty` - The typechecker type of `sequence`
|
||||||
|
///
|
||||||
|
/// The `sequence` argument type may only be one of the following:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to
|
||||||
|
/// `np.empty([3])`
|
||||||
|
///
|
||||||
|
/// All `int32` values will be sign-extended to `SizeT`.
|
||||||
|
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
(input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let zero = llvm_usize.const_zero();
|
||||||
|
let one = llvm_usize.const_int(1, false);
|
||||||
|
|
||||||
|
// The result `list` to return.
|
||||||
|
match &*ctx.unifier.get_ty_immutable(input_seq_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
|
||||||
|
let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty)
|
||||||
|
.map_value(input_seq.into_pointer_value(), None);
|
||||||
|
|
||||||
|
let len = input_seq.load_size(ctx, None);
|
||||||
|
// TODO: Find a way to remove this mid-BB allocation
|
||||||
|
let result = ctx.builder.build_array_alloca(llvm_usize, len, "").unwrap();
|
||||||
|
let result = TypedArrayLikeAdapter::from(
|
||||||
|
ArraySliceValue::from_ptr_val(result, len, None),
|
||||||
|
|_, _, val| val.into_int_value(),
|
||||||
|
|_, _, val| val.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
zero,
|
||||||
|
(len, false),
|
||||||
|
|generator, ctx, _, i| {
|
||||||
|
// Load the i-th int32 in the input sequence
|
||||||
|
let int = unsafe {
|
||||||
|
input_seq.data().get_unchecked(ctx, generator, &i, None).into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Cast to SizeT
|
||||||
|
let int =
|
||||||
|
ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
// Store
|
||||||
|
unsafe { result.set_typed_unchecked(ctx, generator, &i, int) };
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
one,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeEnum::TTuple { .. } => {
|
||||||
|
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||||
|
|
||||||
|
let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty)
|
||||||
|
.map_value(input_seq.into_struct_value(), None);
|
||||||
|
|
||||||
|
let len = input_seq.get_type().num_elements();
|
||||||
|
|
||||||
|
let result = generator
|
||||||
|
.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_usize.const_int(u64::from(len), false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let result = TypedArrayLikeAdapter::from(
|
||||||
|
result,
|
||||||
|
|_, _, val| val.into_int_value(),
|
||||||
|
|_, _, val| val.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
for i in 0..input_seq.get_type().num_elements() {
|
||||||
|
// Get the i-th element off of the tuple and load it into `result`.
|
||||||
|
let int = input_seq.load_element(ctx, i).into_int_value();
|
||||||
|
let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
result.set_typed_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&llvm_usize.const_int(u64::from(i), false),
|
||||||
|
int,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
|
|
||||||
|
let input_int = input_seq.into_int_value();
|
||||||
|
|
||||||
|
let len = one;
|
||||||
|
let result = generator.gen_array_var_alloc(ctx, llvm_usize.into(), len, None).unwrap();
|
||||||
|
let result = TypedArrayLikeAdapter::from(
|
||||||
|
result,
|
||||||
|
|_, _, val| val.into_int_value(),
|
||||||
|
|_, _, val| val.into(),
|
||||||
|
);
|
||||||
|
let int =
|
||||||
|
ctx.builder.build_int_s_extend_or_bit_cast(input_int, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
// Storing into result[0]
|
||||||
|
unsafe {
|
||||||
|
result.set_typed_unchecked(ctx, generator, &zero, int);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(input_seq_ty)),
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user