1
0
forked from M-Labs/nac3

core: Implement most ndarray-creation functions

This commit is contained in:
David Mak 2023-11-27 13:25:53 +08:00
parent 27fcf8926e
commit 140f8f8a08
10 changed files with 1130 additions and 105 deletions

View File

@ -92,6 +92,18 @@ pub trait CodeGenerator {
gen_var(ctx, ty, name) gen_var(ctx, ty, name)
} }
/// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function.
fn gen_array_var_alloc<'ctx, 'a>(
&mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
name: Option<&str>,
) -> Result<PointerValue<'ctx>, String> {
gen_array_var(ctx, ty, size, name)
}
/// Return a pointer pointing to the target of the expression. /// Return a pointer pointing to the target of the expression.
fn gen_store_target<'ctx>( fn gen_store_target<'ctx>(
&mut self, &mut self,

View File

@ -199,27 +199,27 @@ double __nac3_j0(double x) {
} }
uint32_t __nac3_ndarray_calc_size( uint32_t __nac3_ndarray_calc_size(
const int32_t *list_data, const uint64_t *list_data,
uint32_t list_len uint32_t list_len
) { ) {
uint32_t num_elems = 1; uint32_t num_elems = 1;
for (uint32_t i = 0; i < list_len; ++i) { for (uint32_t i = 0; i < list_len; ++i) {
int32_t val = list_data[i]; uint64_t val = list_data[i];
__builtin_assume(val >= 0); __builtin_assume(val >= 0);
num_elems *= (uint32_t) list_data[i]; num_elems *= list_data[i];
} }
return num_elems; return num_elems;
} }
uint64_t __nac3_ndarray_calc_size64( uint64_t __nac3_ndarray_calc_size64(
const int32_t *list_data, const uint64_t *list_data,
uint64_t list_len uint64_t list_len
) { ) {
uint64_t num_elems = 1; uint64_t num_elems = 1;
for (uint64_t i = 0; i < list_len; ++i) { for (uint64_t i = 0; i < list_len; ++i) {
int32_t val = list_data[i]; uint64_t val = list_data[i];
__builtin_assume(val >= 0); __builtin_assume(val >= 0);
num_elems *= (uint64_t) list_data[i]; num_elems *= list_data[i];
} }
return num_elems; return num_elems;
} }
@ -240,4 +240,32 @@ void __nac3_ndarray_init_dims64(
for (uint64_t i = 0; i < shape_len; ++i) { for (uint64_t i = 0; i < shape_len; ++i) {
ndarray_dims[i] = (uint64_t) shape_data[i]; ndarray_dims[i] = (uint64_t) shape_data[i];
} }
}
void __nac3_ndarray_calc_nd_indices(
uint32_t index,
const uint32_t* dims,
uint32_t num_dims,
uint32_t* idxs
) {
uint32_t stride = 1;
for (uint32_t dim = 0; dim < num_dims; dim++) {
uint32_t i = num_dims - dim - 1;
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
void __nac3_ndarray_calc_nd_indices64(
uint64_t index,
const uint64_t* dims,
uint64_t num_dims,
uint64_t* idxs
) {
uint64_t stride = 1;
for (uint64_t dim = 0; dim < num_dims; dim++) {
uint64_t i = num_dims - dim - 1;
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
} }

View File

@ -1,6 +1,6 @@
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::Type;
use super::{CodeGenContext, CodeGenerator}; use super::{assert_is_list, assert_is_ndarray, CodeGenContext, CodeGenerator};
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
@ -12,9 +12,6 @@ use inkwell::{
}; };
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
#[cfg(debug_assertions)]
use inkwell::types::AnyTypeEnum;
#[must_use] #[must_use]
pub fn load_irrt(ctx: &Context) -> Module { pub fn load_irrt(ctx: &Context) -> Module {
let bitcode_buf = MemoryBuffer::create_from_memory_range( let bitcode_buf = MemoryBuffer::create_from_memory_range(
@ -550,62 +547,21 @@ pub fn call_j0<'ctx>(
.into_float_value() .into_float_value()
} }
/// Checks whether the pointer `value` refers to a `list` in LLVM.
fn assert_is_list(value: PointerValue) -> PointerValue {
#[cfg(debug_assertions)]
{
let llvm_shape_ty = value.get_type().get_element_type();
let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else {
panic!("Expected struct type for `list` type, but got {llvm_shape_ty}")
};
assert_eq!(llvm_shape_ty.count_fields(), 2);
assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..))));
assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..))));
}
value
}
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
fn assert_is_ndarray(value: PointerValue) -> PointerValue {
#[cfg(debug_assertions)]
{
let llvm_ndarray_ty = value.get_type().get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}")
};
assert_eq!(llvm_ndarray_ty.count_fields(), 3);
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..))));
let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else {
unreachable!()
};
let BasicTypeEnum::PointerType(dims) = ndarray_dims else {
panic!("Expected pointer type for `list.1`, but got {ndarray_dims}")
};
assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..)));
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..))));
}
value
}
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the
/// calculated total size. /// calculated total size.
/// ///
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM /// * `num_dims` - An [IntValue] containing the number of dimensions.
/// representation of a `list`. /// * `dims` - A [PointerValue] to an array containing the size of each dimensions.
pub fn call_ndarray_calc_size<'ctx, 'a>( pub fn call_ndarray_calc_size<'ctx, 'a>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
shape: PointerValue<'ctx>, num_dims: IntValue<'ctx>,
dims: PointerValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
assert_is_list(shape); let llvm_i64 = ctx.ctx.i64_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => "__nac3_ndarray_calc_size", 32 => "__nac3_ndarray_calc_size",
@ -614,7 +570,7 @@ pub fn call_ndarray_calc_size<'ctx, 'a>(
}; };
let ndarray_calc_size_fn_t = llvm_usize.fn_type( let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[ &[
llvm_pi32.into(), llvm_pi64.into(),
llvm_usize.into(), llvm_usize.into(),
], ],
false, false,
@ -624,30 +580,12 @@ pub fn call_ndarray_calc_size<'ctx, 'a>(
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
}); });
let (
shape_data,
shape_len,
) = unsafe {
(
ctx.builder.build_in_bounds_gep(
shape,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
""
),
ctx.builder.build_in_bounds_gep(
shape,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
""
),
)
};
ctx.builder ctx.builder
.build_call( .build_call(
ndarray_calc_size_fn, ndarray_calc_size_fn,
&[ &[
ctx.builder.build_load(shape_data, "").into(), dims.into(),
ctx.builder.build_load(shape_len, "").into(), num_dims.into(),
], ],
"", "",
) )
@ -721,4 +659,68 @@ pub fn call_ndarray_init_dims<'ctx, 'a>(
], ],
"", "",
); );
}
pub fn call_ndarray_calc_nd_indices<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
index: IntValue<'ctx>,
ndarray: PointerValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> {
assert_is_ndarray(ndarray);
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_dn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_dn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_dn_name, fn_type, None)
});
let ndarray_num_dims = ctx.build_gep_and_load(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_int_value();
let ndarray_dims = ctx.build_gep_and_load(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
None,
).into_pointer_value();
let indices = ctx.builder.build_array_alloca(
llvm_usize,
ndarray_num_dims,
"",
);
ctx.builder.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
);
Ok(indices)
} }

View File

@ -34,6 +34,9 @@ use std::sync::{
}; };
use std::thread; use std::thread;
#[cfg(debug_assertions)]
use inkwell::types::AnyTypeEnum;
pub mod concrete_type; pub mod concrete_type;
pub mod expr; pub mod expr;
mod generator; mod generator;
@ -236,7 +239,7 @@ pub struct WorkerRegistry {
static_value_store: Arc<Mutex<StaticValueStore>>, static_value_store: Arc<Mutex<StaticValueStore>>,
/// LLVM-related options for code generation. /// LLVM-related options for code generation.
llvm_options: CodeGenLLVMOptions, pub llvm_options: CodeGenLLVMOptions,
} }
impl WorkerRegistry { impl WorkerRegistry {
@ -995,3 +998,43 @@ fn gen_in_range_check<'ctx>(
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
} }
/// Checks whether the pointer `value` refers to a `list` in LLVM.
fn assert_is_list(value: PointerValue) -> PointerValue {
#[cfg(debug_assertions)]
{
let llvm_shape_ty = value.get_type().get_element_type();
let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else {
panic!("Expected struct type for `list` type, but got {llvm_shape_ty}")
};
assert_eq!(llvm_shape_ty.count_fields(), 2);
assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..))));
assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..))));
}
value
}
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
fn assert_is_ndarray(value: PointerValue) -> PointerValue {
#[cfg(debug_assertions)]
{
let llvm_ndarray_ty = value.get_type().get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}")
};
assert_eq!(llvm_ndarray_ty.count_fields(), 3);
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..))));
let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else {
unreachable!()
};
let BasicTypeEnum::PointerType(dims) = ndarray_dims else {
panic!("Expected pointer type for `list.1`, but got {ndarray_dims}")
};
assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..)));
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..))));
}
value
}

View File

@ -15,7 +15,7 @@ use crate::{
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock, basic_block::BasicBlock,
types::BasicTypeEnum, types::{BasicType, BasicTypeEnum},
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate, IntPredicate,
}; };
@ -54,6 +54,37 @@ pub fn gen_var<'ctx>(
Ok(ptr) Ok(ptr)
} }
/// See [CodeGenerator::gen_array_var_alloc].
pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>(
ctx: &mut CodeGenContext<'ctx, 'a>,
ty: T,
size: IntValue<'ctx>,
name: Option<&str>,
) -> Result<PointerValue<'ctx>, String> {
// Restore debug location
let di_loc = ctx.debug_info.0.create_debug_location(
ctx.ctx,
ctx.current_loc.row as u32,
ctx.current_loc.column as u32,
ctx.debug_info.2,
None,
);
// put the alloca in init block
let current = ctx.builder.get_insert_block().unwrap();
// position before the last branching instruction...
ctx.builder.position_before(&ctx.init_bb.get_last_instruction().unwrap());
ctx.builder.set_current_debug_location(di_loc);
let ptr = ctx.builder.build_array_alloca(ty, size, name.unwrap_or(""));
ctx.builder.position_at_end(current);
ctx.builder.set_current_debug_location(di_loc);
Ok(ptr)
}
/// See [`CodeGenerator::gen_store_target`]. /// See [`CodeGenerator::gen_store_target`].
pub fn gen_store_target<'ctx, G: CodeGenerator>( pub fn gen_store_target<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,

View File

@ -13,7 +13,13 @@ use crate::{
stmt::exn_constructor, stmt::exn_constructor,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::numpy::gen_ndarray_empty, toplevel::numpy::{
gen_ndarray_empty,
gen_ndarray_eye,
gen_ndarray_full,
gen_ndarray_ones,
gen_ndarray_zeros,
},
}; };
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -22,6 +28,7 @@ use inkwell::{
FloatPredicate, FloatPredicate,
IntPredicate IntPredicate
}; };
use crate::toplevel::numpy::gen_ndarray_identity;
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>; type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
@ -279,10 +286,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let boolean = primitives.0.bool; let boolean = primitives.0.bool;
let range = primitives.0.range; let range = primitives.0.range;
let string = primitives.0.str; let string = primitives.0.str;
let ndarray = {
let ndarray_ty = TypeEnum::ndarray(&mut primitives.1, None, None, &primitives.0);
primitives.1.add_ty(ndarray_ty)
};
let ndarray_float = { let ndarray_float = {
let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0); let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0);
primitives.1.add_ty(ndarray_ty_enum) primitives.1.add_ty(ndarray_ty_enum)
}; };
let ndarray_float_2d = {
let value = match primitives.0.size_t {
64 => SymbolValue::U64(2u64),
32 => SymbolValue::U32(2u32),
_ => unreachable!(),
};
let ndims = primitives.1.add_ty(TypeEnum::TLiteral {
values: vec![value],
loc: None,
});
primitives.1.add_ty(TypeEnum::TNDArray {
ty: float,
ndims,
})
};
let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 }); let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 });
let num_ty = primitives.1.get_fresh_var_with_range( let num_ty = primitives.1.get_fresh_var_with_range(
&[int32, int64, float, boolean, uint32, uint64], &[int32, int64, float, boolean, uint32, uint64],
@ -869,6 +896,89 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
}), }),
), ),
create_fn_by_codegen(
primitives,
&var_map,
"np_zeros",
ndarray_float,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_zeros(ctx, obj, fun, args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"np_ones",
ndarray_float,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_ones(ctx, obj, fun, args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
),
{
let tv = primitives.1.get_fresh_var(Some("T".into()), None).0;
create_fn_by_codegen(
primitives,
&var_map,
"np_full",
ndarray,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(list_int32, "shape"), (tv, "fill_value")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_full(ctx, obj, fun, args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
)
},
Arc::new(RwLock::new(TopLevelDef::Function {
name: "np_eye".into(),
simple_name: "np_eye".into(),
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "N".into(), ty: int32, default_value: None },
// TODO(Derppening): Default values current do not work?
FuncArg {
name: "M".into(),
ty: int32,
default_value: Some(SymbolValue::OptionNone)
},
FuncArg { name: "k".into(), ty: int32, default_value: Some(SymbolValue::I32(0)) },
],
ret: ndarray_float_2d,
vars: var_map.clone(),
})),
var_id: Default::default(),
instance_to_symbol: Default::default(),
instance_to_stmt: Default::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, fun, args, generator| {
gen_ndarray_eye(ctx, obj, fun, args, generator)
.map(|val| Some(val.as_basic_value_enum()))
},
)))),
loc: None,
})),
create_fn_by_codegen(
primitives,
&var_map,
"np_identity",
ndarray_float_2d,
&[(int32, "n")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_identity(ctx, obj, fun, args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
),
create_fn_by_codegen( create_fn_by_codegen(
primitives, primitives,
&var_map, &var_map,
@ -1364,7 +1474,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into()) Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
} }
} }
TypeEnum::TNDArray { .. } => todo!(), TypeEnum::TNDArray { .. } => {
let llvm_i32 = ctx.ctx.i32_type();
let i32_zero = llvm_i32.const_zero();
let len = ctx.build_gep_and_load(
arg.into_pointer_value(),
&[i32_zero, i32_zero],
None,
).into_int_value();
if len.get_type().get_bit_width() != 32 {
Some(ctx.builder.build_int_truncate(len, llvm_i32, "len").into())
} else {
Some(len.into())
}
}
_ => unreachable!(), _ => unreachable!(),
} }
}) })

View File

@ -1,14 +1,15 @@
use inkwell::{ use inkwell::{AddressSpace, IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}};
IntPredicate, use inkwell::values::{ArrayValue, IntValue};
types::BasicType,
values::PointerValue,
};
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
use crate::{ use crate::{
codegen::{ codegen::{
CodeGenContext, CodeGenContext,
CodeGenerator, CodeGenerator,
irrt::{call_ndarray_calc_size, call_ndarray_init_dims}, irrt::{
call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
call_ndarray_init_dims,
},
stmt::gen_for_callback stmt::gen_for_callback
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
@ -16,16 +17,201 @@ use crate::{
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::typedef::{FunSignature, Type, TypeEnum},
}; };
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. /// Creates an `NDArray` instance from a constant shape.
/// ///
/// * `elem_ty` - The element type of the NDArray. /// * `elem_ty` - The element type of the `NDArray`.
/// * `var_name` - The variable name of the NDArray. /// * `shape` - The shape of the `NDArray`, represented as an LLVM [ArrayValue].
/// * `shape` - The `shape` parameter used to construct the NDArray. fn create_ndarray_const_shape<'ctx, 'a>(
fn call_ndarray_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
shape: ArrayValue<'ctx>
) -> Result<PointerValue<'ctx>, String> {
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
for i in 0..shape.get_type().len() {
let shape_dim = ctx.builder.build_extract_value(
shape,
i,
"",
).unwrap();
let shape_dim_gez = ctx.builder.build_int_compare(
IntPredicate::SGE,
shape_dim.into_int_value(),
llvm_usize.const_zero(),
""
);
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
}
let ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None,
)?;
let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false);
let ndarray_num_dims = unsafe {
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
"",
)
};
ctx.builder.build_store(ndarray_num_dims, num_dims);
let ndarray_dims = unsafe {
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
"",
)
};
let ndarray_num_dims = ctx.build_gep_and_load(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_int_value();
ctx.builder.build_store(
ndarray_dims,
ctx.builder.build_array_alloca(
llvm_usize,
ndarray_num_dims,
"",
),
);
for i in 0..shape.get_type().len() {
let ndarray_dim = ctx.build_gep_and_load(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
None,
).into_pointer_value();
let ndarray_dim = unsafe {
ctx.builder.build_in_bounds_gep(
ndarray_dim,
&[llvm_i32.const_int(i as u64, true)],
"",
)
};
let shape_dim = ctx.builder.build_extract_value(shape, i, "")
.map(|val| val.into_int_value())
.unwrap();
ctx.builder.build_store(ndarray_dim, shape_dim);
}
let (ndarray_num_dims, ndarray_dims) = unsafe {
(
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
""
),
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
""
),
)
};
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ctx.builder.build_load(ndarray_num_dims, "").into_int_value(),
ctx.builder.build_load(ndarray_dims, "").into_pointer_value(),
);
let ndarray_data = unsafe {
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
"",
)
};
ctx.builder.build_store(
ndarray_data,
ctx.builder.build_array_alloca(
llvm_ndarray_data_t,
ndarray_num_elems,
""
),
);
Ok(ndarray)
}
fn ndarray_zero_value<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
ctx.ctx.i32_type().const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
ctx.ctx.i64_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "").into()
} else {
unreachable!()
}
}
fn ndarray_one_value<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1").into()
} else {
unreachable!()
}
}
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
///
/// * `elem_ty` - The element type of the NDArray.
/// * `shape` - The `shape` parameter used to construct the NDArray.
fn call_ndarray_empty_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
var_name: Option<&str>,
shape: PointerValue<'ctx>, shape: PointerValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
@ -43,8 +229,8 @@ fn call_ndarray_impl<'ctx, 'a>(
gen_for_callback( gen_for_callback(
generator, generator,
ctx, ctx,
|_, ctx| { |generator, ctx| {
let i = ctx.builder.build_alloca(llvm_usize, ""); let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(i, llvm_usize.const_zero()); ctx.builder.build_store(i, llvm_usize.const_zero());
Ok(i) Ok(i)
@ -106,10 +292,11 @@ fn call_ndarray_impl<'ctx, 'a>(
}, },
)?; )?;
let ndarray = ctx.builder.build_alloca( let ndarray = generator.gen_var_alloc(
llvm_ndarray_t, ctx,
var_name.unwrap_or_default() llvm_ndarray_t.into(),
); None,
)?;
let num_dims = ctx.build_gep_and_load( let num_dims = ctx.build_gep_and_load(
shape, shape,
@ -151,7 +338,26 @@ fn call_ndarray_impl<'ctx, 'a>(
call_ndarray_init_dims(generator, ctx, ndarray, shape); call_ndarray_init_dims(generator, ctx, ndarray, shape);
let ndarray_num_elems = call_ndarray_calc_size(generator, ctx, shape); let (ndarray_num_dims, ndarray_dims) = unsafe {
(
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
""
),
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
""
),
)
};
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ctx.builder.build_load(ndarray_num_dims, "").into_int_value(),
ctx.builder.build_load(ndarray_dims, "").into_pointer_value(),
);
let ndarray_data = unsafe { let ndarray_data = unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
@ -172,6 +378,342 @@ fn call_ndarray_impl<'ctx, 'a>(
Ok(ndarray) Ok(ndarray)
} }
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
/// its input.
///
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
/// with the given value (as opposed to all elements within the array).
fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarray: PointerValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
where
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (num_dims, dims) = unsafe {
(
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
""
),
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
""
),
)
};
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ctx.builder.build_load(num_dims, "").into_int_value(),
ctx.builder.build_load(dims, "").into_pointer_value(),
);
gen_for_callback(
generator,
ctx,
|generator, ctx| {
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(i, llvm_usize.const_zero());
Ok(i)
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.into_int_value();
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, ""))
},
|generator, ctx, i_addr| {
let ndarray_data = ctx.build_gep_and_load(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
None
).into_pointer_value();
let i = ctx.builder
.build_load(i_addr, "")
.into_int_value();
let elem = unsafe {
ctx.builder.build_in_bounds_gep(
ndarray_data,
&[i],
""
)
};
let value = value_fn(generator, ctx, i)?;
ctx.builder.build_store(elem, value);
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.into_int_value();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "");
ctx.builder.build_store(i_addr, i);
Ok(())
},
)
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
/// as its input
///
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
/// with the given value (as opposed to all elements within the array).
fn ndarray_fill_indexed<'ctx, 'a, ValueFn>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarray: PointerValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
where
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, idx| {
let indices = call_ndarray_calc_nd_indices(
generator,
ctx,
idx,
ndarray,
)?;
value_fn(generator, ctx, indices)
}
)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
///
/// * `elem_ty` - The element type of the NDArray.
/// * `shape` - The `shape` parameter used to construct the NDArray.
fn call_ndarray_zeros_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
shape: PointerValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
ctx.primitives.float,
ctx.primitives.bool,
ctx.primitives.str,
];
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, _| {
let value = ndarray_zero_value(generator, ctx, elem_ty);
Ok(value)
}
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
///
/// * `elem_ty` - The element type of the NDArray.
/// * `shape` - The `shape` parameter used to construct the NDArray.
fn call_ndarray_ones_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
shape: PointerValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
ctx.primitives.float,
ctx.primitives.bool,
ctx.primitives.str,
];
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, _| {
let value = ndarray_one_value(generator, ctx, elem_ty);
Ok(value)
}
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
///
/// * `elem_ty` - The element type of the NDArray.
/// * `shape` - The `shape` parameter used to construct the NDArray.
fn call_ndarray_full_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
shape: PointerValue<'ctx>,
fill_value: BasicValueEnum<'ctx>,
) -> Result<PointerValue<'ctx>, String> {
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, _| {
let value = if fill_value.is_pointer_value() {
let llvm_void = ctx.ctx.void_type();
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
let memcpy_fn_name = format!(
"llvm.memcpy.p0i8.p0i8.i{}",
generator.get_size_type(ctx.ctx).get_bit_width(),
);
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[
llvm_pi8.into(),
llvm_pi8.into(),
llvm_usize.into(),
llvm_i1.into(),
],
false,
);
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
});
ctx.builder.build_call(
memcpy_fn,
&[
copy.into(),
fill_value.into(),
fill_value.get_type().size_of().unwrap().into(),
llvm_i1.const_zero().into(),
],
"",
);
copy.into()
} else if fill_value.is_int_value() || fill_value.is_float_value() {
fill_value.into()
} else {
unreachable!()
};
Ok(value)
}
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
///
/// * `elem_ty` - The element type of the NDArray.
fn call_ndarray_eye_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
nrows: IntValue<'ctx>,
ncols: IntValue<'ctx>,
offset: IntValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_usize_2 = llvm_usize.array_type(2);
let shape_addr = generator.gen_var_alloc(ctx, llvm_usize_2.into(), None)?;
let shape = ctx.builder.build_load(shape_addr, "")
.into_array_value();
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "");
let shape = ctx.builder
.build_insert_value(shape, nrows, 0, "")
.map(|val| val.into_array_value())
.unwrap();
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "");
let shape = ctx.builder
.build_insert_value(shape, ncols, 1, "")
.map(|val| val.into_array_value())
.unwrap();
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, shape)?;
ndarray_fill_indexed(
generator,
ctx,
ndarray,
|generator, ctx, indices| {
let row = ctx.build_gep_and_load(
indices,
&[llvm_i32.const_zero()],
None,
).into_int_value();
let col = ctx.build_gep_and_load(
indices,
&[llvm_i32.const_int(1, true)],
None,
).into_int_value();
let col_with_offset = ctx.builder.build_int_add(
col,
ctx.builder.build_int_z_extend_or_bit_cast(offset, llvm_usize, ""),
""
);
let is_on_diag = ctx.builder.build_int_compare(
IntPredicate::EQ,
row,
col_with_offset,
""
);
let zero = ndarray_zero_value(generator, ctx, elem_ty);
let one = ndarray_one_value(generator, ctx, elem_ty);
let value = ctx.builder.build_select(is_on_diag, one, zero, "");
Ok(value)
},
)?;
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`. /// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx, 'a>( pub fn gen_ndarray_empty<'ctx, 'a>(
context: &mut CodeGenContext<'ctx, 'a>, context: &mut CodeGenContext<'ctx, 'a>,
@ -184,15 +726,158 @@ pub fn gen_ndarray_empty<'ctx, 'a>(
assert_eq!(args.len(), 1); assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty; let shape_ty = fun.0.args[0].ty;
let shape_arg_name = args[0].0;
let shape_arg = args[0].1.clone() let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?; .to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_impl( call_ndarray_empty_impl(
generator, generator,
context, context,
context.primitives.float, context.primitives.float,
shape_arg_name.map(|name| name.to_string()).as_deref(),
shape_arg.into_pointer_value(), shape_arg.into_pointer_value(),
) )
}
/// Generates LLVM IR for `ndarray.zeros`.
pub fn gen_ndarray_zeros<'ctx, 'a>(
context: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_zeros_impl(
generator,
context,
context.primitives.float,
shape_arg.into_pointer_value(),
)
}
/// Generates LLVM IR for `ndarray.ones`.
pub fn gen_ndarray_ones<'ctx, 'a>(
context: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_ones_impl(
generator,
context,
context.primitives.float,
shape_arg.into_pointer_value(),
)
}
/// Generates LLVM IR for `ndarray.full`.
pub fn gen_ndarray_full<'ctx, 'a>(
context: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
let fill_value_ty = fun.0.args[1].ty;
let fill_value_arg = args[1].1.clone()
.to_basic_value_enum(context, generator, fill_value_ty)?;
call_ndarray_full_impl(
generator,
context,
fill_value_ty,
shape_arg.into_pointer_value(),
fill_value_arg,
)
}
/// Generates LLVM IR for `ndarray.eye`.
pub fn gen_ndarray_eye<'ctx, 'a>(
context: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
let nrows_ty = fun.0.args[0].ty;
let nrows_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, nrows_ty)?;
let ncols_ty = fun.0.args[1].ty;
let ncols_arg = args.iter()
.find(|arg| arg.0.map(|name| name == fun.0.args[1].name).unwrap_or(false))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty))
.unwrap_or_else(|| {
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
})?;
let offset_ty = fun.0.args[2].ty;
let offset_arg = args.iter()
.find(|arg| arg.0.map(|name| name == fun.0.args[2].name).unwrap_or(false))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty))
.unwrap_or_else(|| {
Ok(context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
offset_ty
))
})?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
nrows_arg.into_int_value(),
ncols_arg.into_int_value(),
offset_arg.into_int_value(),
)
}
/// Generates LLVM IR for `ndarray.identity`.
pub fn gen_ndarray_identity<'ctx, 'a>(
context: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let n_ty = fun.0.args[0].ty;
let n_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, n_ty)?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
n_arg.into_int_value(),
n_arg.into_int_value(),
llvm_usize.const_zero(),
)
} }

View File

@ -898,9 +898,14 @@ impl<'a> Inferencer<'a> {
if [ if [
"np_ndarray".into(), "np_ndarray".into(),
"np_empty".into(), "np_empty".into(),
"np_zeros".into(),
"np_ones".into(),
].contains(id) && args.len() == 1 { ].contains(id) && args.len() == 1 {
let ExprKind::List { elts, .. } = &args[0].node else { let ExprKind::List { elts, .. } = &args[0].node else {
return report_error("Expected List literal for first argument of np_ndarray", args[0].location) return report_error(
format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(),
args[0].location
)
}; };
let ndims = elts.len() as u64; let ndims = elts.len() as u64;
@ -941,6 +946,62 @@ impl<'a> Inferencer<'a> {
})) }))
} }
// 2-argument ndarray n-dimensional creation functions
if id == &"np_full".into() && args.len() == 2 {
let ExprKind::List { elts, .. } = &args[0].node else {
return report_error(
format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(),
args[0].location
)
};
let ndims = elts.len() as u64;
let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?;
let ty = arg1.custom.unwrap();
let ndims = self.unifier.get_fresh_literal(
vec![SymbolValue::U64(ndims)],
None,
);
let ret = self.unifier.add_ty(TypeEnum::TNDArray {
ty,
ndims
});
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "shape".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
FuncArg {
name: "fill_value".into(),
ty: arg1.custom.unwrap(),
default_value: None,
},
],
ret,
vars: HashMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0, arg1],
keywords: vec![],
},
}))
}
Ok(None) Ok(None)
} }

View File

@ -187,6 +187,11 @@ def patch(module):
# NumPy NDArray Functions # NumPy NDArray Functions
module.np_ndarray = np.ndarray module.np_ndarray = np.ndarray
module.np_empty = np.empty module.np_empty = np.empty
module.np_zeros = np.zeros
module.np_ones = np.ones
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

@ -7,6 +7,12 @@ def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]):
def consume_ndarray_2(n: ndarray[float, Literal[2]]): def consume_ndarray_2(n: ndarray[float, Literal[2]]):
pass pass
def consume_ndarray_i32_1(n: ndarray[int32, 1]):
pass
def consume_ndarray_2(n: ndarray[float, 2]):
pass
def test_ndarray_ctor(): def test_ndarray_ctor():
n = np_ndarray([1]) n = np_ndarray([1])
consume_ndarray_1(n) consume_ndarray_1(n)
@ -15,8 +21,35 @@ def test_ndarray_empty():
n = np_empty([1]) n = np_empty([1])
consume_ndarray_1(n) consume_ndarray_1(n)
def test_ndarray_zeros():
n = np_zeros([1])
consume_ndarray_1(n)
def test_ndarray_ones():
n = np_ones([1])
consume_ndarray_1(n)
def test_ndarray_full():
n_float = np_full([1], 2.0)
consume_ndarray_1(n_float)
n_i32 = np_full([1], 2)
consume_ndarray_i32_1(n_i32)
def test_ndarray_eye():
n = np_eye(2)
consume_ndarray_2(n)
def test_ndarray_identity():
n = np_identity(2)
consume_ndarray_2(n)
def run() -> int32: def run() -> int32:
test_ndarray_ctor() test_ndarray_ctor()
test_ndarray_empty() test_ndarray_empty()
test_ndarray_zeros()
test_ndarray_ones()
test_ndarray_full()
test_ndarray_eye()
test_ndarray_identity()
return 0 return 0