forked from M-Labs/nac3
[core] codegen/ndarray: Reimplement np_{zeros,ones,full,empty}
Based on 792374fa: core/ndstrides: implement np_{zeros,ones,full,empty}.
This commit is contained in:
parent
7d02f5833d
commit
5880f964bb
@ -3,7 +3,6 @@ use inkwell::{
|
||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3parser::ast::{Operator, StrRef};
|
||||
|
||||
@ -19,17 +18,28 @@ use super::{
|
||||
llvm_intrinsics::{self, call_memcpy_generic},
|
||||
macros::codegen_unreachable,
|
||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||
types::{ndarray::NDArrayType, ListType, ProxyType},
|
||||
types::{
|
||||
ndarray::{
|
||||
factory::{ndarray_one_value, ndarray_zero_value},
|
||||
NDArrayType,
|
||||
},
|
||||
ListType, ProxyType,
|
||||
},
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
|
||||
ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
||||
ndarray::{shape::parse_numpy_int_sequence, NDArrayValue},
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue,
|
||||
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::{
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId},
|
||||
toplevel::{
|
||||
helper::{extract_ndims, PrimDef},
|
||||
numpy::unpack_ndarray_var_tys,
|
||||
DefinitionId,
|
||||
},
|
||||
typecheck::{
|
||||
magic_methods::Binop,
|
||||
typedef::{FunSignature, Type, TypeEnum},
|
||||
@ -174,132 +184,6 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
{
|
||||
ctx.ctx.i32_type().const_zero().into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
{
|
||||
ctx.ctx.i64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "").into()
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(elem_ty, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_float(1.0).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_int(1, false).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "1").into()
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||
///
|
||||
/// ### Notes on `shape`
|
||||
///
|
||||
/// Just like numpy, the `shape` argument can be:
|
||||
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||
///
|
||||
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
|
||||
/// learn how `shape` gets from being a Python user expression to here.
|
||||
fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
shape: BasicValueEnum<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
match shape {
|
||||
BasicValueEnum::PointerValue(shape_list_ptr)
|
||||
if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
|
||||
{
|
||||
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
|
||||
|
||||
let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&shape_list,
|
||||
|_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)),
|
||||
|generator, ctx, shape_list, idx| {
|
||||
Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value())
|
||||
},
|
||||
)
|
||||
}
|
||||
BasicValueEnum::StructValue(shape_tuple) => {
|
||||
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
|
||||
|
||||
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
|
||||
let ndims = shape_tuple.get_type().count_fields();
|
||||
|
||||
let shape = (0..ndims)
|
||||
.map(|dim_i| {
|
||||
ctx.builder
|
||||
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.map(|v| {
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(v, llvm_usize, "").unwrap()
|
||||
})
|
||||
.unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||
}
|
||||
BasicValueEnum::IntValue(shape_int) => {
|
||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||
let shape_int =
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(shape_int, llvm_usize, "").unwrap();
|
||||
|
||||
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||
}
|
||||
_ => codegen_unreachable!(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||
/// its input.
|
||||
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
|
||||
@ -529,107 +413,6 @@ where
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||
fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
shape: BasicValueEnum<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let supported_types = [
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.str,
|
||||
];
|
||||
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
||||
|
||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
|
||||
let value = ndarray_zero_value(generator, ctx, elem_ty);
|
||||
|
||||
Ok(value)
|
||||
})?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||
fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
shape: BasicValueEnum<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let supported_types = [
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.str,
|
||||
];
|
||||
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
||||
|
||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
|
||||
let value = ndarray_one_value(generator, ctx, elem_ty);
|
||||
|
||||
Ok(value)
|
||||
})?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.full`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
|
||||
fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
shape: BasicValueEnum<'ctx>,
|
||||
fill_value: BasicValueEnum<'ctx>,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||
ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| {
|
||||
let value = if fill_value.is_pointer_value() {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
|
||||
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
copy,
|
||||
fill_value.into_pointer_value(),
|
||||
fill_value.get_type().size_of().map(Into::into).unwrap(),
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
copy.into()
|
||||
} else if fill_value.is_int_value() || fill_value.is_float_value() {
|
||||
fill_value
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
})?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Returns the number of dimensions for a multidimensional list as an [`IntValue`].
|
||||
fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
@ -1752,8 +1535,15 @@ pub fn gen_ndarray_empty<'ctx>(
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
|
||||
.map(NDArrayValue::into)
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||
let ndims = extract_ndims(&context.unifier, ndims);
|
||||
|
||||
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||
|
||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||
.construct_numpy_empty(generator, context, &shape, None);
|
||||
Ok(ndarray.as_base_value())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.zeros`.
|
||||
@ -1770,8 +1560,15 @@ pub fn gen_ndarray_zeros<'ctx>(
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
|
||||
.map(NDArrayValue::into)
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||
let ndims = extract_ndims(&context.unifier, ndims);
|
||||
|
||||
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||
|
||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||
.construct_numpy_zeros(generator, context, dtype, &shape, None);
|
||||
Ok(ndarray.as_base_value())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.ones`.
|
||||
@ -1788,8 +1585,15 @@ pub fn gen_ndarray_ones<'ctx>(
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
|
||||
.map(NDArrayValue::into)
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||
let ndims = extract_ndims(&context.unifier, ndims);
|
||||
|
||||
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||
|
||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||
.construct_numpy_ones(generator, context, dtype, &shape, None);
|
||||
Ok(ndarray.as_base_value())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.full`.
|
||||
@ -1809,8 +1613,15 @@ pub fn gen_ndarray_full<'ctx>(
|
||||
let fill_value_arg =
|
||||
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
|
||||
|
||||
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
|
||||
.map(NDArrayValue::into)
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
|
||||
let llvm_dtype = context.get_llvm_type(generator, dtype);
|
||||
let ndims = extract_ndims(&context.unifier, ndims);
|
||||
|
||||
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
|
||||
|
||||
let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, Some(ndims))
|
||||
.construct_numpy_full(generator, context, &shape, fill_value_arg, None);
|
||||
Ok(ndarray.as_base_value())
|
||||
}
|
||||
|
||||
pub fn gen_ndarray_array<'ctx>(
|
||||
|
146
nac3core/src/codegen/types/ndarray/factory.rs
Normal file
146
nac3core/src/codegen/types/ndarray/factory.rs
Normal file
@ -0,0 +1,146 @@
|
||||
use inkwell::values::{BasicValueEnum, IntValue};
|
||||
|
||||
use super::NDArrayType;
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt, types::ProxyType, values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
/// Get the zero value in `np.zeros()` of a `dtype`.
|
||||
pub fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
ctx.ctx.i32_type().const_zero().into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
ctx.ctx.i64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "").into()
|
||||
} else {
|
||||
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the one value in `np.ones()` of a `dtype`.
|
||||
pub fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
|
||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
|
||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_float(1.0).into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_int(1, false).into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "1").into()
|
||||
} else {
|
||||
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayType<'ctx> {
|
||||
/// Create an ndarray like
|
||||
/// [`np.empty`](https://numpy.org/doc/stable/reference/generated/numpy.empty.html).
|
||||
pub fn construct_numpy_empty<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let ndarray = self.construct_uninitialized(generator, ctx, name);
|
||||
|
||||
// Validate `shape`
|
||||
irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, shape);
|
||||
|
||||
ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
||||
unsafe { ndarray.create_data(generator, ctx) };
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.full`](https://numpy.org/doc/stable/reference/generated/numpy.full.html).
|
||||
pub fn construct_numpy_full<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
fill_value: BasicValueEnum<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let ndarray = self.construct_numpy_empty(generator, ctx, shape, name);
|
||||
ndarray.fill(generator, ctx, fill_value);
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.zero`](https://numpy.org/doc/stable/reference/generated/numpy.zeros.html).
|
||||
pub fn construct_numpy_zeros<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(
|
||||
ctx.get_llvm_type(generator, dtype),
|
||||
self.dtype,
|
||||
"Expected LLVM dtype={} but got {}",
|
||||
self.dtype.print_to_string(),
|
||||
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||
);
|
||||
|
||||
let fill_value = ndarray_zero_value(generator, ctx, dtype);
|
||||
self.construct_numpy_full(generator, ctx, shape, fill_value, name)
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.ones`](https://numpy.org/doc/stable/reference/generated/numpy.ones.html).
|
||||
pub fn construct_numpy_ones<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(
|
||||
ctx.get_llvm_type(generator, dtype),
|
||||
self.dtype,
|
||||
"Expected LLVM dtype={} but got {}",
|
||||
self.dtype.print_to_string(),
|
||||
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||
);
|
||||
|
||||
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||
self.construct_numpy_full(generator, ctx, shape, fill_value, name)
|
||||
}
|
||||
}
|
@ -25,6 +25,7 @@ pub use indexing::*;
|
||||
pub use nditer::*;
|
||||
|
||||
mod contiguous;
|
||||
pub mod factory;
|
||||
mod indexing;
|
||||
mod nditer;
|
||||
|
||||
|
@ -163,8 +163,13 @@ impl<'ctx> NDIterType<'ctx> {
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert!(
|
||||
ndarray.get_type().ndims().is_some(),
|
||||
"NDIter requires ndims of NDArray to be known."
|
||||
);
|
||||
|
||||
let nditer = self.raw_alloca_var(generator, ctx, None);
|
||||
let ndims = ndarray.load_ndims(ctx);
|
||||
let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims().unwrap(), false);
|
||||
|
||||
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
||||
let indices =
|
||||
|
@ -23,6 +23,7 @@ pub use nditer::*;
|
||||
mod contiguous;
|
||||
mod indexing;
|
||||
mod nditer;
|
||||
pub mod shape;
|
||||
mod view;
|
||||
|
||||
/// Proxy type for accessing an `NDArray` value in LLVM.
|
||||
@ -397,6 +398,23 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
|
||||
}
|
||||
|
||||
/// Fill the ndarray with a scalar.
|
||||
///
|
||||
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
|
||||
pub fn fill<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
self.foreach(generator, ctx, |_, ctx, _, nditer| {
|
||||
let p = nditer.get_pointer(ctx);
|
||||
ctx.builder.build_store(p, value).unwrap();
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
|
||||
#[must_use]
|
||||
pub fn is_unsized(&self) -> Option<bool> {
|
||||
|
152
nac3core/src/codegen/values/ndarray/shape.rs
Normal file
152
nac3core/src/codegen/values/ndarray/shape.rs
Normal file
@ -0,0 +1,152 @@
|
||||
use inkwell::values::{BasicValueEnum, IntValue};
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
stmt::gen_for_callback_incrementing,
|
||||
types::{ListType, TupleType},
|
||||
values::{
|
||||
ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||
TypedArrayLikeMutator, UntypedArrayLikeAccessor,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
|
||||
///
|
||||
/// * `sequence` - The `sequence` parameter.
|
||||
/// * `sequence_ty` - The typechecker type of `sequence`
|
||||
///
|
||||
/// The `sequence` argument type may only be one of the following:
|
||||
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
|
||||
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to
|
||||
/// `np.empty([3])`
|
||||
///
|
||||
/// All `int32` values will be sign-extended to `SizeT`.
|
||||
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>),
|
||||
) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let zero = llvm_usize.const_zero();
|
||||
let one = llvm_usize.const_int(1, false);
|
||||
|
||||
// The result `list` to return.
|
||||
match &*ctx.unifier.get_ty_immutable(input_seq_ty) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||
|
||||
let input_seq = ListType::from_unifier_type(generator, ctx, input_seq_ty)
|
||||
.map_value(input_seq.into_pointer_value(), None);
|
||||
|
||||
let len = input_seq.load_size(ctx, None);
|
||||
// TODO: Find a way to remove this mid-BB allocation
|
||||
let result = ctx.builder.build_array_alloca(llvm_usize, len, "").unwrap();
|
||||
let result = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(result, len, None),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
|
||||
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
zero,
|
||||
(len, false),
|
||||
|generator, ctx, _, i| {
|
||||
// Load the i-th int32 in the input sequence
|
||||
let int = unsafe {
|
||||
input_seq.data().get_unchecked(ctx, generator, &i, None).into_int_value()
|
||||
};
|
||||
|
||||
// Cast to SizeT
|
||||
let int =
|
||||
ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap();
|
||||
|
||||
// Store
|
||||
unsafe { result.set_typed_unchecked(ctx, generator, &i, int) };
|
||||
|
||||
Ok(())
|
||||
},
|
||||
one,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
TypeEnum::TTuple { .. } => {
|
||||
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||
|
||||
let input_seq = TupleType::from_unifier_type(generator, ctx, input_seq_ty)
|
||||
.map_value(input_seq.into_struct_value(), None);
|
||||
|
||||
let len = input_seq.get_type().num_elements();
|
||||
|
||||
let result = generator
|
||||
.gen_array_var_alloc(
|
||||
ctx,
|
||||
llvm_usize.into(),
|
||||
llvm_usize.const_int(u64::from(len), false),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
let result = TypedArrayLikeAdapter::from(
|
||||
result,
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
|
||||
for i in 0..input_seq.get_type().num_elements() {
|
||||
// Get the i-th element off of the tuple and load it into `result`.
|
||||
let int = input_seq.load_element(ctx, i).into_int_value();
|
||||
let int = ctx.builder.build_int_s_extend_or_bit_cast(int, llvm_usize, "").unwrap();
|
||||
|
||||
unsafe {
|
||||
result.set_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(u64::from(i), false),
|
||||
int,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||
|
||||
let input_int = input_seq.into_int_value();
|
||||
|
||||
let len = one;
|
||||
let result = generator.gen_array_var_alloc(ctx, llvm_usize.into(), len, None).unwrap();
|
||||
let result = TypedArrayLikeAdapter::from(
|
||||
result,
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
let int =
|
||||
ctx.builder.build_int_s_extend_or_bit_cast(input_int, llvm_usize, "").unwrap();
|
||||
|
||||
// Storing into result[0]
|
||||
unsafe {
|
||||
result.set_typed_unchecked(ctx, generator, &zero, int);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
_ => panic!("encountered unknown sequence type: {}", ctx.unifier.stringify(input_seq_ty)),
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user