forked from M-Labs/nac3
core: Implement most ndarray-creation functions
This commit is contained in:
parent
27fcf8926e
commit
140f8f8a08
|
@ -92,6 +92,18 @@ pub trait CodeGenerator {
|
||||||
gen_var(ctx, ty, name)
|
gen_var(ctx, ty, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Allocate memory for a variable and return a pointer pointing to it.
|
||||||
|
/// The default implementation places the allocations at the start of the function.
|
||||||
|
fn gen_array_var_alloc<'ctx, 'a>(
|
||||||
|
&mut self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ty: BasicTypeEnum<'ctx>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
gen_array_var(ctx, ty, size, name)
|
||||||
|
}
|
||||||
|
|
||||||
/// Return a pointer pointing to the target of the expression.
|
/// Return a pointer pointing to the target of the expression.
|
||||||
fn gen_store_target<'ctx>(
|
fn gen_store_target<'ctx>(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
|
|
@ -199,27 +199,27 @@ double __nac3_j0(double x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_calc_size(
|
uint32_t __nac3_ndarray_calc_size(
|
||||||
const int32_t *list_data,
|
const uint64_t *list_data,
|
||||||
uint32_t list_len
|
uint32_t list_len
|
||||||
) {
|
) {
|
||||||
uint32_t num_elems = 1;
|
uint32_t num_elems = 1;
|
||||||
for (uint32_t i = 0; i < list_len; ++i) {
|
for (uint32_t i = 0; i < list_len; ++i) {
|
||||||
int32_t val = list_data[i];
|
uint64_t val = list_data[i];
|
||||||
__builtin_assume(val >= 0);
|
__builtin_assume(val >= 0);
|
||||||
num_elems *= (uint32_t) list_data[i];
|
num_elems *= list_data[i];
|
||||||
}
|
}
|
||||||
return num_elems;
|
return num_elems;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_calc_size64(
|
uint64_t __nac3_ndarray_calc_size64(
|
||||||
const int32_t *list_data,
|
const uint64_t *list_data,
|
||||||
uint64_t list_len
|
uint64_t list_len
|
||||||
) {
|
) {
|
||||||
uint64_t num_elems = 1;
|
uint64_t num_elems = 1;
|
||||||
for (uint64_t i = 0; i < list_len; ++i) {
|
for (uint64_t i = 0; i < list_len; ++i) {
|
||||||
int32_t val = list_data[i];
|
uint64_t val = list_data[i];
|
||||||
__builtin_assume(val >= 0);
|
__builtin_assume(val >= 0);
|
||||||
num_elems *= (uint64_t) list_data[i];
|
num_elems *= list_data[i];
|
||||||
}
|
}
|
||||||
return num_elems;
|
return num_elems;
|
||||||
}
|
}
|
||||||
|
@ -240,4 +240,32 @@ void __nac3_ndarray_init_dims64(
|
||||||
for (uint64_t i = 0; i < shape_len; ++i) {
|
for (uint64_t i = 0; i < shape_len; ++i) {
|
||||||
ndarray_dims[i] = (uint64_t) shape_data[i];
|
ndarray_dims[i] = (uint64_t) shape_data[i];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices(
|
||||||
|
uint32_t index,
|
||||||
|
const uint32_t* dims,
|
||||||
|
uint32_t num_dims,
|
||||||
|
uint32_t* idxs
|
||||||
|
) {
|
||||||
|
uint32_t stride = 1;
|
||||||
|
for (uint32_t dim = 0; dim < num_dims; dim++) {
|
||||||
|
uint32_t i = num_dims - dim - 1;
|
||||||
|
idxs[i] = (index / stride) % dims[i];
|
||||||
|
stride *= dims[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices64(
|
||||||
|
uint64_t index,
|
||||||
|
const uint64_t* dims,
|
||||||
|
uint64_t num_dims,
|
||||||
|
uint64_t* idxs
|
||||||
|
) {
|
||||||
|
uint64_t stride = 1;
|
||||||
|
for (uint64_t dim = 0; dim < num_dims; dim++) {
|
||||||
|
uint64_t i = num_dims - dim - 1;
|
||||||
|
idxs[i] = (index / stride) % dims[i];
|
||||||
|
stride *= dims[i];
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::typecheck::typedef::Type;
|
use crate::typecheck::typedef::Type;
|
||||||
|
|
||||||
use super::{CodeGenContext, CodeGenerator};
|
use super::{assert_is_list, assert_is_ndarray, CodeGenContext, CodeGenerator};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
context::Context,
|
context::Context,
|
||||||
|
@ -12,9 +12,6 @@ use inkwell::{
|
||||||
};
|
};
|
||||||
use nac3parser::ast::Expr;
|
use nac3parser::ast::Expr;
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
use inkwell::types::AnyTypeEnum;
|
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt(ctx: &Context) -> Module {
|
pub fn load_irrt(ctx: &Context) -> Module {
|
||||||
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
||||||
|
@ -550,62 +547,21 @@ pub fn call_j0<'ctx>(
|
||||||
.into_float_value()
|
.into_float_value()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Checks whether the pointer `value` refers to a `list` in LLVM.
|
|
||||||
fn assert_is_list(value: PointerValue) -> PointerValue {
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
let llvm_shape_ty = value.get_type().get_element_type();
|
|
||||||
let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else {
|
|
||||||
panic!("Expected struct type for `list` type, but got {llvm_shape_ty}")
|
|
||||||
};
|
|
||||||
assert_eq!(llvm_shape_ty.count_fields(), 2);
|
|
||||||
assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..))));
|
|
||||||
assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..))));
|
|
||||||
}
|
|
||||||
|
|
||||||
value
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
|
|
||||||
fn assert_is_ndarray(value: PointerValue) -> PointerValue {
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
let llvm_ndarray_ty = value.get_type().get_element_type();
|
|
||||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
|
||||||
panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}")
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(llvm_ndarray_ty.count_fields(), 3);
|
|
||||||
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..))));
|
|
||||||
let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
let BasicTypeEnum::PointerType(dims) = ndarray_dims else {
|
|
||||||
panic!("Expected pointer type for `list.1`, but got {ndarray_dims}")
|
|
||||||
};
|
|
||||||
assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..)));
|
|
||||||
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..))));
|
|
||||||
}
|
|
||||||
|
|
||||||
value
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the
|
||||||
/// calculated total size.
|
/// calculated total size.
|
||||||
///
|
///
|
||||||
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM
|
/// * `num_dims` - An [IntValue] containing the number of dimensions.
|
||||||
/// representation of a `list`.
|
/// * `dims` - A [PointerValue] to an array containing the size of each dimensions.
|
||||||
pub fn call_ndarray_calc_size<'ctx, 'a>(
|
pub fn call_ndarray_calc_size<'ctx, 'a>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
shape: PointerValue<'ctx>,
|
num_dims: IntValue<'ctx>,
|
||||||
|
dims: PointerValue<'ctx>,
|
||||||
) -> IntValue<'ctx> {
|
) -> IntValue<'ctx> {
|
||||||
assert_is_list(shape);
|
let llvm_i64 = ctx.ctx.i64_type();
|
||||||
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
|
let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||||
32 => "__nac3_ndarray_calc_size",
|
32 => "__nac3_ndarray_calc_size",
|
||||||
|
@ -614,7 +570,7 @@ pub fn call_ndarray_calc_size<'ctx, 'a>(
|
||||||
};
|
};
|
||||||
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
||||||
&[
|
&[
|
||||||
llvm_pi32.into(),
|
llvm_pi64.into(),
|
||||||
llvm_usize.into(),
|
llvm_usize.into(),
|
||||||
],
|
],
|
||||||
false,
|
false,
|
||||||
|
@ -624,30 +580,12 @@ pub fn call_ndarray_calc_size<'ctx, 'a>(
|
||||||
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
||||||
});
|
});
|
||||||
|
|
||||||
let (
|
|
||||||
shape_data,
|
|
||||||
shape_len,
|
|
||||||
) = unsafe {
|
|
||||||
(
|
|
||||||
ctx.builder.build_in_bounds_gep(
|
|
||||||
shape,
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
|
||||||
""
|
|
||||||
),
|
|
||||||
ctx.builder.build_in_bounds_gep(
|
|
||||||
shape,
|
|
||||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
|
||||||
""
|
|
||||||
),
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_call(
|
.build_call(
|
||||||
ndarray_calc_size_fn,
|
ndarray_calc_size_fn,
|
||||||
&[
|
&[
|
||||||
ctx.builder.build_load(shape_data, "").into(),
|
dims.into(),
|
||||||
ctx.builder.build_load(shape_len, "").into(),
|
num_dims.into(),
|
||||||
],
|
],
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
@ -721,4 +659,68 @@ pub fn call_ndarray_init_dims<'ctx, 'a>(
|
||||||
],
|
],
|
||||||
"",
|
"",
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_ndarray_calc_nd_indices<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
index: IntValue<'ctx>,
|
||||||
|
ndarray: PointerValue<'ctx>,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert_is_ndarray(ndarray);
|
||||||
|
|
||||||
|
let llvm_void = ctx.ctx.void_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_nd_indices_dn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_nd_indices",
|
||||||
|
64 => "__nac3_ndarray_calc_nd_indices64",
|
||||||
|
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||||
|
};
|
||||||
|
let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_dn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_void.fn_type(
|
||||||
|
&[
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_pusize.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_pusize.into(),
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_calc_nd_indices_dn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
None,
|
||||||
|
).into_int_value();
|
||||||
|
let ndarray_dims = ctx.build_gep_and_load(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
None,
|
||||||
|
).into_pointer_value();
|
||||||
|
|
||||||
|
let indices = ctx.builder.build_array_alloca(
|
||||||
|
llvm_usize,
|
||||||
|
ndarray_num_dims,
|
||||||
|
"",
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.builder.build_call(
|
||||||
|
ndarray_calc_nd_indices_fn,
|
||||||
|
&[
|
||||||
|
index.into(),
|
||||||
|
ndarray_dims.into(),
|
||||||
|
ndarray_num_dims.into(),
|
||||||
|
indices.into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(indices)
|
||||||
}
|
}
|
|
@ -34,6 +34,9 @@ use std::sync::{
|
||||||
};
|
};
|
||||||
use std::thread;
|
use std::thread;
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
use inkwell::types::AnyTypeEnum;
|
||||||
|
|
||||||
pub mod concrete_type;
|
pub mod concrete_type;
|
||||||
pub mod expr;
|
pub mod expr;
|
||||||
mod generator;
|
mod generator;
|
||||||
|
@ -236,7 +239,7 @@ pub struct WorkerRegistry {
|
||||||
static_value_store: Arc<Mutex<StaticValueStore>>,
|
static_value_store: Arc<Mutex<StaticValueStore>>,
|
||||||
|
|
||||||
/// LLVM-related options for code generation.
|
/// LLVM-related options for code generation.
|
||||||
llvm_options: CodeGenLLVMOptions,
|
pub llvm_options: CodeGenLLVMOptions,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WorkerRegistry {
|
impl WorkerRegistry {
|
||||||
|
@ -995,3 +998,43 @@ fn gen_in_range_check<'ctx>(
|
||||||
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
|
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks whether the pointer `value` refers to a `list` in LLVM.
|
||||||
|
fn assert_is_list(value: PointerValue) -> PointerValue {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
let llvm_shape_ty = value.get_type().get_element_type();
|
||||||
|
let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else {
|
||||||
|
panic!("Expected struct type for `list` type, but got {llvm_shape_ty}")
|
||||||
|
};
|
||||||
|
assert_eq!(llvm_shape_ty.count_fields(), 2);
|
||||||
|
assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..))));
|
||||||
|
assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..))));
|
||||||
|
}
|
||||||
|
|
||||||
|
value
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
|
||||||
|
fn assert_is_ndarray(value: PointerValue) -> PointerValue {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
let llvm_ndarray_ty = value.get_type().get_element_type();
|
||||||
|
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||||
|
panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}")
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(llvm_ndarray_ty.count_fields(), 3);
|
||||||
|
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..))));
|
||||||
|
let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let BasicTypeEnum::PointerType(dims) = ndarray_dims else {
|
||||||
|
panic!("Expected pointer type for `list.1`, but got {ndarray_dims}")
|
||||||
|
};
|
||||||
|
assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..)));
|
||||||
|
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..))));
|
||||||
|
}
|
||||||
|
|
||||||
|
value
|
||||||
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ use crate::{
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
basic_block::BasicBlock,
|
basic_block::BasicBlock,
|
||||||
types::BasicTypeEnum,
|
types::{BasicType, BasicTypeEnum},
|
||||||
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
||||||
IntPredicate,
|
IntPredicate,
|
||||||
};
|
};
|
||||||
|
@ -54,6 +54,37 @@ pub fn gen_var<'ctx>(
|
||||||
Ok(ptr)
|
Ok(ptr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// See [CodeGenerator::gen_array_var_alloc].
|
||||||
|
pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>(
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ty: T,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&str>,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
// Restore debug location
|
||||||
|
let di_loc = ctx.debug_info.0.create_debug_location(
|
||||||
|
ctx.ctx,
|
||||||
|
ctx.current_loc.row as u32,
|
||||||
|
ctx.current_loc.column as u32,
|
||||||
|
ctx.debug_info.2,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
// put the alloca in init block
|
||||||
|
let current = ctx.builder.get_insert_block().unwrap();
|
||||||
|
|
||||||
|
// position before the last branching instruction...
|
||||||
|
ctx.builder.position_before(&ctx.init_bb.get_last_instruction().unwrap());
|
||||||
|
ctx.builder.set_current_debug_location(di_loc);
|
||||||
|
|
||||||
|
let ptr = ctx.builder.build_array_alloca(ty, size, name.unwrap_or(""));
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(current);
|
||||||
|
ctx.builder.set_current_debug_location(di_loc);
|
||||||
|
|
||||||
|
Ok(ptr)
|
||||||
|
}
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_store_target`].
|
/// See [`CodeGenerator::gen_store_target`].
|
||||||
pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
|
|
@ -13,7 +13,13 @@ use crate::{
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
toplevel::numpy::gen_ndarray_empty,
|
toplevel::numpy::{
|
||||||
|
gen_ndarray_empty,
|
||||||
|
gen_ndarray_eye,
|
||||||
|
gen_ndarray_full,
|
||||||
|
gen_ndarray_ones,
|
||||||
|
gen_ndarray_zeros,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
|
@ -22,6 +28,7 @@ use inkwell::{
|
||||||
FloatPredicate,
|
FloatPredicate,
|
||||||
IntPredicate
|
IntPredicate
|
||||||
};
|
};
|
||||||
|
use crate::toplevel::numpy::gen_ndarray_identity;
|
||||||
|
|
||||||
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
||||||
|
|
||||||
|
@ -279,10 +286,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
let boolean = primitives.0.bool;
|
let boolean = primitives.0.bool;
|
||||||
let range = primitives.0.range;
|
let range = primitives.0.range;
|
||||||
let string = primitives.0.str;
|
let string = primitives.0.str;
|
||||||
|
let ndarray = {
|
||||||
|
let ndarray_ty = TypeEnum::ndarray(&mut primitives.1, None, None, &primitives.0);
|
||||||
|
primitives.1.add_ty(ndarray_ty)
|
||||||
|
};
|
||||||
let ndarray_float = {
|
let ndarray_float = {
|
||||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0);
|
let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0);
|
||||||
primitives.1.add_ty(ndarray_ty_enum)
|
primitives.1.add_ty(ndarray_ty_enum)
|
||||||
};
|
};
|
||||||
|
let ndarray_float_2d = {
|
||||||
|
let value = match primitives.0.size_t {
|
||||||
|
64 => SymbolValue::U64(2u64),
|
||||||
|
32 => SymbolValue::U32(2u32),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let ndims = primitives.1.add_ty(TypeEnum::TLiteral {
|
||||||
|
values: vec![value],
|
||||||
|
loc: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
primitives.1.add_ty(TypeEnum::TNDArray {
|
||||||
|
ty: float,
|
||||||
|
ndims,
|
||||||
|
})
|
||||||
|
};
|
||||||
let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 });
|
let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 });
|
||||||
let num_ty = primitives.1.get_fresh_var_with_range(
|
let num_ty = primitives.1.get_fresh_var_with_range(
|
||||||
&[int32, int64, float, boolean, uint32, uint64],
|
&[int32, int64, float, boolean, uint32, uint64],
|
||||||
|
@ -869,6 +896,89 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
.map(|val| Some(val.as_basic_value_enum()))
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"np_zeros",
|
||||||
|
ndarray_float,
|
||||||
|
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||||
|
// type variable
|
||||||
|
&[(list_int32, "shape")],
|
||||||
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
|
gen_ndarray_zeros(ctx, obj, fun, args, generator)
|
||||||
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"np_ones",
|
||||||
|
ndarray_float,
|
||||||
|
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||||
|
// type variable
|
||||||
|
&[(list_int32, "shape")],
|
||||||
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
|
gen_ndarray_ones(ctx, obj, fun, args, generator)
|
||||||
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
{
|
||||||
|
let tv = primitives.1.get_fresh_var(Some("T".into()), None).0;
|
||||||
|
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"np_full",
|
||||||
|
ndarray,
|
||||||
|
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||||
|
// type variable
|
||||||
|
&[(list_int32, "shape"), (tv, "fill_value")],
|
||||||
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
|
gen_ndarray_full(ctx, obj, fun, args, generator)
|
||||||
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
|
name: "np_eye".into(),
|
||||||
|
simple_name: "np_eye".into(),
|
||||||
|
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg { name: "N".into(), ty: int32, default_value: None },
|
||||||
|
// TODO(Derppening): Default values current do not work?
|
||||||
|
FuncArg {
|
||||||
|
name: "M".into(),
|
||||||
|
ty: int32,
|
||||||
|
default_value: Some(SymbolValue::OptionNone)
|
||||||
|
},
|
||||||
|
FuncArg { name: "k".into(), ty: int32, default_value: Some(SymbolValue::I32(0)) },
|
||||||
|
],
|
||||||
|
ret: ndarray_float_2d,
|
||||||
|
vars: var_map.clone(),
|
||||||
|
})),
|
||||||
|
var_id: Default::default(),
|
||||||
|
instance_to_symbol: Default::default(),
|
||||||
|
instance_to_stmt: Default::default(),
|
||||||
|
resolver: None,
|
||||||
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|
|ctx, obj, fun, args, generator| {
|
||||||
|
gen_ndarray_eye(ctx, obj, fun, args, generator)
|
||||||
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
},
|
||||||
|
)))),
|
||||||
|
loc: None,
|
||||||
|
})),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"np_identity",
|
||||||
|
ndarray_float_2d,
|
||||||
|
&[(int32, "n")],
|
||||||
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
|
gen_ndarray_identity(ctx, obj, fun, args, generator)
|
||||||
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
primitives,
|
primitives,
|
||||||
&var_map,
|
&var_map,
|
||||||
|
@ -1364,7 +1474,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
|
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TNDArray { .. } => todo!(),
|
TypeEnum::TNDArray { .. } => {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let i32_zero = llvm_i32.const_zero();
|
||||||
|
|
||||||
|
let len = ctx.build_gep_and_load(
|
||||||
|
arg.into_pointer_value(),
|
||||||
|
&[i32_zero, i32_zero],
|
||||||
|
None,
|
||||||
|
).into_int_value();
|
||||||
|
|
||||||
|
if len.get_type().get_bit_width() != 32 {
|
||||||
|
Some(ctx.builder.build_int_truncate(len, llvm_i32, "len").into())
|
||||||
|
} else {
|
||||||
|
Some(len.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
use inkwell::{
|
use inkwell::{AddressSpace, IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}};
|
||||||
IntPredicate,
|
use inkwell::values::{ArrayValue, IntValue};
|
||||||
types::BasicType,
|
|
||||||
values::PointerValue,
|
|
||||||
};
|
|
||||||
use nac3parser::ast::StrRef;
|
use nac3parser::ast::StrRef;
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
CodeGenContext,
|
CodeGenContext,
|
||||||
CodeGenerator,
|
CodeGenerator,
|
||||||
irrt::{call_ndarray_calc_size, call_ndarray_init_dims},
|
irrt::{
|
||||||
|
call_ndarray_calc_nd_indices,
|
||||||
|
call_ndarray_calc_size,
|
||||||
|
call_ndarray_init_dims,
|
||||||
|
},
|
||||||
stmt::gen_for_callback
|
stmt::gen_for_callback
|
||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
|
@ -16,16 +17,201 @@ use crate::{
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
/// Creates an `NDArray` instance from a constant shape.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the NDArray.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
/// * `var_name` - The variable name of the NDArray.
|
/// * `shape` - The shape of the `NDArray`, represented as an LLVM [ArrayValue].
|
||||||
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
fn create_ndarray_const_shape<'ctx, 'a>(
|
||||||
fn call_ndarray_impl<'ctx, 'a>(
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: ArrayValue<'ctx>
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
||||||
|
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
||||||
|
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||||
|
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||||
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||||
|
assert!(llvm_ndarray_data_t.is_sized());
|
||||||
|
|
||||||
|
for i in 0..shape.get_type().len() {
|
||||||
|
let shape_dim = ctx.builder.build_extract_value(
|
||||||
|
shape,
|
||||||
|
i,
|
||||||
|
"",
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
let shape_dim_gez = ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::SGE,
|
||||||
|
shape_dim.into_int_value(),
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
""
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
shape_dim_gez,
|
||||||
|
"0:ValueError",
|
||||||
|
"negative dimensions not supported",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let ndarray = generator.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
llvm_ndarray_t.into(),
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false);
|
||||||
|
|
||||||
|
let ndarray_num_dims = unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(ndarray_num_dims, num_dims);
|
||||||
|
|
||||||
|
let ndarray_dims = unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
None,
|
||||||
|
).into_int_value();
|
||||||
|
|
||||||
|
ctx.builder.build_store(
|
||||||
|
ndarray_dims,
|
||||||
|
ctx.builder.build_array_alloca(
|
||||||
|
llvm_usize,
|
||||||
|
ndarray_num_dims,
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
for i in 0..shape.get_type().len() {
|
||||||
|
let ndarray_dim = ctx.build_gep_and_load(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
None,
|
||||||
|
).into_pointer_value();
|
||||||
|
let ndarray_dim = unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray_dim,
|
||||||
|
&[llvm_i32.const_int(i as u64, true)],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let shape_dim = ctx.builder.build_extract_value(shape, i, "")
|
||||||
|
.map(|val| val.into_int_value())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ctx.builder.build_store(ndarray_dim, shape_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (ndarray_num_dims, ndarray_dims) = unsafe {
|
||||||
|
(
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
""
|
||||||
|
),
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ctx.builder.build_load(ndarray_num_dims, "").into_int_value(),
|
||||||
|
ctx.builder.build_load(ndarray_dims, "").into_pointer_value(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let ndarray_data = unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(
|
||||||
|
ndarray_data,
|
||||||
|
ctx.builder.build_array_alloca(
|
||||||
|
llvm_ndarray_data_t,
|
||||||
|
ndarray_num_elems,
|
||||||
|
""
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ndarray_zero_value<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
elem_ty: Type,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||||
|
ctx.ctx.i32_type().const_zero().into()
|
||||||
|
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||||
|
ctx.ctx.i64_type().const_zero().into()
|
||||||
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
||||||
|
ctx.ctx.f64_type().const_zero().into()
|
||||||
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||||
|
ctx.ctx.bool_type().const_zero().into()
|
||||||
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||||
|
ctx.gen_string(generator, "").into()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ndarray_one_value<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
elem_ty: Type,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||||
|
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
||||||
|
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||||
|
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||||
|
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
||||||
|
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||||
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
||||||
|
ctx.ctx.f64_type().const_float(1.0).into()
|
||||||
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||||
|
ctx.ctx.bool_type().const_int(1, false).into()
|
||||||
|
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||||
|
ctx.gen_string(generator, "1").into()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the NDArray.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
||||||
|
fn call_ndarray_empty_impl<'ctx, 'a>(
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
var_name: Option<&str>,
|
|
||||||
shape: PointerValue<'ctx>,
|
shape: PointerValue<'ctx>,
|
||||||
) -> Result<PointerValue<'ctx>, String> {
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
||||||
|
@ -43,8 +229,8 @@ fn call_ndarray_impl<'ctx, 'a>(
|
||||||
gen_for_callback(
|
gen_for_callback(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|_, ctx| {
|
|generator, ctx| {
|
||||||
let i = ctx.builder.build_alloca(llvm_usize, "");
|
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
ctx.builder.build_store(i, llvm_usize.const_zero());
|
ctx.builder.build_store(i, llvm_usize.const_zero());
|
||||||
|
|
||||||
Ok(i)
|
Ok(i)
|
||||||
|
@ -106,10 +292,11 @@ fn call_ndarray_impl<'ctx, 'a>(
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let ndarray = ctx.builder.build_alloca(
|
let ndarray = generator.gen_var_alloc(
|
||||||
llvm_ndarray_t,
|
ctx,
|
||||||
var_name.unwrap_or_default()
|
llvm_ndarray_t.into(),
|
||||||
);
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
let num_dims = ctx.build_gep_and_load(
|
let num_dims = ctx.build_gep_and_load(
|
||||||
shape,
|
shape,
|
||||||
|
@ -151,7 +338,26 @@ fn call_ndarray_impl<'ctx, 'a>(
|
||||||
|
|
||||||
call_ndarray_init_dims(generator, ctx, ndarray, shape);
|
call_ndarray_init_dims(generator, ctx, ndarray, shape);
|
||||||
|
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(generator, ctx, shape);
|
let (ndarray_num_dims, ndarray_dims) = unsafe {
|
||||||
|
(
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
""
|
||||||
|
),
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ctx.builder.build_load(ndarray_num_dims, "").into_int_value(),
|
||||||
|
ctx.builder.build_load(ndarray_dims, "").into_pointer_value(),
|
||||||
|
);
|
||||||
|
|
||||||
let ndarray_data = unsafe {
|
let ndarray_data = unsafe {
|
||||||
ctx.builder.build_in_bounds_gep(
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
@ -172,6 +378,342 @@ fn call_ndarray_impl<'ctx, 'a>(
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||||
|
/// its input.
|
||||||
|
///
|
||||||
|
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
||||||
|
/// with the given value (as opposed to all elements within the array).
|
||||||
|
fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ndarray: PointerValue<'ctx>,
|
||||||
|
value_fn: ValueFn,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let (num_dims, dims) = unsafe {
|
||||||
|
(
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
""
|
||||||
|
),
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndarray_num_elems = call_ndarray_calc_size(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ctx.builder.build_load(num_dims, "").into_int_value(),
|
||||||
|
ctx.builder.build_load(dims, "").into_pointer_value(),
|
||||||
|
);
|
||||||
|
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|generator, ctx| {
|
||||||
|
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
ctx.builder.build_store(i, llvm_usize.const_zero());
|
||||||
|
|
||||||
|
Ok(i)
|
||||||
|
},
|
||||||
|
|_, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.into_int_value();
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, ""))
|
||||||
|
},
|
||||||
|
|generator, ctx, i_addr| {
|
||||||
|
let ndarray_data = ctx.build_gep_and_load(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||||
|
None
|
||||||
|
).into_pointer_value();
|
||||||
|
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.into_int_value();
|
||||||
|
let elem = unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray_data,
|
||||||
|
&[i],
|
||||||
|
""
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let value = value_fn(generator, ctx, i)?;
|
||||||
|
ctx.builder.build_store(elem, value);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|_, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.into_int_value();
|
||||||
|
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "");
|
||||||
|
ctx.builder.build_store(i_addr, i);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
|
||||||
|
/// as its input
|
||||||
|
///
|
||||||
|
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
||||||
|
/// with the given value (as opposed to all elements within the array).
|
||||||
|
fn ndarray_fill_indexed<'ctx, 'a, ValueFn>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ndarray: PointerValue<'ctx>,
|
||||||
|
value_fn: ValueFn,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
ndarray_fill_flattened(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray,
|
||||||
|
|generator, ctx, idx| {
|
||||||
|
let indices = call_ndarray_calc_nd_indices(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
idx,
|
||||||
|
ndarray,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
value_fn(generator, ctx, indices)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the NDArray.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
||||||
|
fn call_ndarray_zeros_impl<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: PointerValue<'ctx>,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
let supported_types = [
|
||||||
|
ctx.primitives.int32,
|
||||||
|
ctx.primitives.int64,
|
||||||
|
ctx.primitives.uint32,
|
||||||
|
ctx.primitives.uint64,
|
||||||
|
ctx.primitives.float,
|
||||||
|
ctx.primitives.bool,
|
||||||
|
ctx.primitives.str,
|
||||||
|
];
|
||||||
|
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
||||||
|
|
||||||
|
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||||
|
ndarray_fill_flattened(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray,
|
||||||
|
|generator, ctx, _| {
|
||||||
|
let value = ndarray_zero_value(generator, ctx, elem_ty);
|
||||||
|
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the NDArray.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
||||||
|
fn call_ndarray_ones_impl<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: PointerValue<'ctx>,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
let supported_types = [
|
||||||
|
ctx.primitives.int32,
|
||||||
|
ctx.primitives.int64,
|
||||||
|
ctx.primitives.uint32,
|
||||||
|
ctx.primitives.uint64,
|
||||||
|
ctx.primitives.float,
|
||||||
|
ctx.primitives.bool,
|
||||||
|
ctx.primitives.str,
|
||||||
|
];
|
||||||
|
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
||||||
|
|
||||||
|
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||||
|
ndarray_fill_flattened(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray,
|
||||||
|
|generator, ctx, _| {
|
||||||
|
let value = ndarray_one_value(generator, ctx, elem_ty);
|
||||||
|
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the NDArray.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
||||||
|
fn call_ndarray_full_impl<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
elem_ty: Type,
|
||||||
|
shape: PointerValue<'ctx>,
|
||||||
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||||
|
ndarray_fill_flattened(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray,
|
||||||
|
|generator, ctx, _| {
|
||||||
|
let value = if fill_value.is_pointer_value() {
|
||||||
|
let llvm_void = ctx.ctx.void_type();
|
||||||
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
|
||||||
|
|
||||||
|
let memcpy_fn_name = format!(
|
||||||
|
"llvm.memcpy.p0i8.p0i8.i{}",
|
||||||
|
generator.get_size_type(ctx.ctx).get_bit_width(),
|
||||||
|
);
|
||||||
|
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_void.fn_type(
|
||||||
|
&[
|
||||||
|
llvm_pi8.into(),
|
||||||
|
llvm_pi8.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_i1.into(),
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.builder.build_call(
|
||||||
|
memcpy_fn,
|
||||||
|
&[
|
||||||
|
copy.into(),
|
||||||
|
fill_value.into(),
|
||||||
|
fill_value.get_type().size_of().unwrap().into(),
|
||||||
|
llvm_i1.const_zero().into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
);
|
||||||
|
|
||||||
|
copy.into()
|
||||||
|
} else if fill_value.is_int_value() || fill_value.is_float_value() {
|
||||||
|
fill_value.into()
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the NDArray.
|
||||||
|
fn call_ndarray_eye_impl<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
elem_ty: Type,
|
||||||
|
nrows: IntValue<'ctx>,
|
||||||
|
ncols: IntValue<'ctx>,
|
||||||
|
offset: IntValue<'ctx>,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let llvm_usize_2 = llvm_usize.array_type(2);
|
||||||
|
|
||||||
|
let shape_addr = generator.gen_var_alloc(ctx, llvm_usize_2.into(), None)?;
|
||||||
|
|
||||||
|
let shape = ctx.builder.build_load(shape_addr, "")
|
||||||
|
.into_array_value();
|
||||||
|
|
||||||
|
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "");
|
||||||
|
let shape = ctx.builder
|
||||||
|
.build_insert_value(shape, nrows, 0, "")
|
||||||
|
.map(|val| val.into_array_value())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "");
|
||||||
|
let shape = ctx.builder
|
||||||
|
.build_insert_value(shape, ncols, 1, "")
|
||||||
|
.map(|val| val.into_array_value())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, shape)?;
|
||||||
|
|
||||||
|
ndarray_fill_indexed(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray,
|
||||||
|
|generator, ctx, indices| {
|
||||||
|
let row = ctx.build_gep_and_load(
|
||||||
|
indices,
|
||||||
|
&[llvm_i32.const_zero()],
|
||||||
|
None,
|
||||||
|
).into_int_value();
|
||||||
|
let col = ctx.build_gep_and_load(
|
||||||
|
indices,
|
||||||
|
&[llvm_i32.const_int(1, true)],
|
||||||
|
None,
|
||||||
|
).into_int_value();
|
||||||
|
|
||||||
|
let col_with_offset = ctx.builder.build_int_add(
|
||||||
|
col,
|
||||||
|
ctx.builder.build_int_z_extend_or_bit_cast(offset, llvm_usize, ""),
|
||||||
|
""
|
||||||
|
);
|
||||||
|
let is_on_diag = ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
row,
|
||||||
|
col_with_offset,
|
||||||
|
""
|
||||||
|
);
|
||||||
|
|
||||||
|
let zero = ndarray_zero_value(generator, ctx, elem_ty);
|
||||||
|
let one = ndarray_one_value(generator, ctx, elem_ty);
|
||||||
|
|
||||||
|
let value = ctx.builder.build_select(is_on_diag, one, zero, "");
|
||||||
|
|
||||||
|
Ok(value)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.empty`.
|
/// Generates LLVM IR for `ndarray.empty`.
|
||||||
pub fn gen_ndarray_empty<'ctx, 'a>(
|
pub fn gen_ndarray_empty<'ctx, 'a>(
|
||||||
context: &mut CodeGenContext<'ctx, 'a>,
|
context: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
@ -184,15 +726,158 @@ pub fn gen_ndarray_empty<'ctx, 'a>(
|
||||||
assert_eq!(args.len(), 1);
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape_arg_name = args[0].0;
|
|
||||||
let shape_arg = args[0].1.clone()
|
let shape_arg = args[0].1.clone()
|
||||||
.to_basic_value_enum(context, generator, shape_ty)?;
|
.to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
call_ndarray_impl(
|
call_ndarray_empty_impl(
|
||||||
generator,
|
generator,
|
||||||
context,
|
context,
|
||||||
context.primitives.float,
|
context.primitives.float,
|
||||||
shape_arg_name.map(|name| name.to_string()).as_deref(),
|
|
||||||
shape_arg.into_pointer_value(),
|
shape_arg.into_pointer_value(),
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.zeros`.
|
||||||
|
pub fn gen_ndarray_zeros<'ctx, 'a>(
|
||||||
|
context: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape_arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
call_ndarray_zeros_impl(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
context.primitives.float,
|
||||||
|
shape_arg.into_pointer_value(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.ones`.
|
||||||
|
pub fn gen_ndarray_ones<'ctx, 'a>(
|
||||||
|
context: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape_arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
call_ndarray_ones_impl(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
context.primitives.float,
|
||||||
|
shape_arg.into_pointer_value(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.full`.
|
||||||
|
pub fn gen_ndarray_full<'ctx, 'a>(
|
||||||
|
context: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 2);
|
||||||
|
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape_arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
let fill_value_ty = fun.0.args[1].ty;
|
||||||
|
let fill_value_arg = args[1].1.clone()
|
||||||
|
.to_basic_value_enum(context, generator, fill_value_ty)?;
|
||||||
|
|
||||||
|
call_ndarray_full_impl(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
fill_value_ty,
|
||||||
|
shape_arg.into_pointer_value(),
|
||||||
|
fill_value_arg,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.eye`.
|
||||||
|
pub fn gen_ndarray_eye<'ctx, 'a>(
|
||||||
|
context: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert!(matches!(args.len(), 1..=3));
|
||||||
|
|
||||||
|
let nrows_ty = fun.0.args[0].ty;
|
||||||
|
let nrows_arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(context, generator, nrows_ty)?;
|
||||||
|
|
||||||
|
let ncols_ty = fun.0.args[1].ty;
|
||||||
|
let ncols_arg = args.iter()
|
||||||
|
.find(|arg| arg.0.map(|name| name == fun.0.args[1].name).unwrap_or(false))
|
||||||
|
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty))
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let offset_ty = fun.0.args[2].ty;
|
||||||
|
let offset_arg = args.iter()
|
||||||
|
.find(|arg| arg.0.map(|name| name == fun.0.args[2].name).unwrap_or(false))
|
||||||
|
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty))
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
Ok(context.gen_symbol_val(
|
||||||
|
generator,
|
||||||
|
fun.0.args[2].default_value.as_ref().unwrap(),
|
||||||
|
offset_ty
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
call_ndarray_eye_impl(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
context.primitives.float,
|
||||||
|
nrows_arg.into_int_value(),
|
||||||
|
ncols_arg.into_int_value(),
|
||||||
|
offset_arg.into_int_value(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.identity`.
|
||||||
|
pub fn gen_ndarray_identity<'ctx, 'a>(
|
||||||
|
context: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(context.ctx);
|
||||||
|
|
||||||
|
let n_ty = fun.0.args[0].ty;
|
||||||
|
let n_arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(context, generator, n_ty)?;
|
||||||
|
|
||||||
|
call_ndarray_eye_impl(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
context.primitives.float,
|
||||||
|
n_arg.into_int_value(),
|
||||||
|
n_arg.into_int_value(),
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
)
|
||||||
}
|
}
|
|
@ -898,9 +898,14 @@ impl<'a> Inferencer<'a> {
|
||||||
if [
|
if [
|
||||||
"np_ndarray".into(),
|
"np_ndarray".into(),
|
||||||
"np_empty".into(),
|
"np_empty".into(),
|
||||||
|
"np_zeros".into(),
|
||||||
|
"np_ones".into(),
|
||||||
].contains(id) && args.len() == 1 {
|
].contains(id) && args.len() == 1 {
|
||||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||||
return report_error("Expected List literal for first argument of np_ndarray", args[0].location)
|
return report_error(
|
||||||
|
format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(),
|
||||||
|
args[0].location
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let ndims = elts.len() as u64;
|
let ndims = elts.len() as u64;
|
||||||
|
@ -941,6 +946,62 @@ impl<'a> Inferencer<'a> {
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 2-argument ndarray n-dimensional creation functions
|
||||||
|
if id == &"np_full".into() && args.len() == 2 {
|
||||||
|
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||||
|
return report_error(
|
||||||
|
format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(),
|
||||||
|
args[0].location
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndims = elts.len() as u64;
|
||||||
|
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg1 = self.fold_expr(args.remove(0))?;
|
||||||
|
|
||||||
|
let ty = arg1.custom.unwrap();
|
||||||
|
let ndims = self.unifier.get_fresh_literal(
|
||||||
|
vec![SymbolValue::U64(ndims)],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
let ret = self.unifier.add_ty(TypeEnum::TNDArray {
|
||||||
|
ty,
|
||||||
|
ndims
|
||||||
|
});
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "shape".into(),
|
||||||
|
ty: arg0.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
FuncArg {
|
||||||
|
name: "fill_value".into(),
|
||||||
|
ty: arg1.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: HashMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||||
|
}),
|
||||||
|
args: vec![arg0, arg1],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -187,6 +187,11 @@ def patch(module):
|
||||||
# NumPy NDArray Functions
|
# NumPy NDArray Functions
|
||||||
module.np_ndarray = np.ndarray
|
module.np_ndarray = np.ndarray
|
||||||
module.np_empty = np.empty
|
module.np_empty = np.empty
|
||||||
|
module.np_zeros = np.zeros
|
||||||
|
module.np_ones = np.ones
|
||||||
|
module.np_full = np.full
|
||||||
|
module.np_eye = np.eye
|
||||||
|
module.np_identity = np.identity
|
||||||
|
|
||||||
def file_import(filename, prefix="file_import_"):
|
def file_import(filename, prefix="file_import_"):
|
||||||
filename = pathlib.Path(filename)
|
filename = pathlib.Path(filename)
|
||||||
|
|
|
@ -7,6 +7,12 @@ def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]):
|
||||||
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
|
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def consume_ndarray_i32_1(n: ndarray[int32, 1]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def consume_ndarray_2(n: ndarray[float, 2]):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_ndarray_ctor():
|
def test_ndarray_ctor():
|
||||||
n = np_ndarray([1])
|
n = np_ndarray([1])
|
||||||
consume_ndarray_1(n)
|
consume_ndarray_1(n)
|
||||||
|
@ -15,8 +21,35 @@ def test_ndarray_empty():
|
||||||
n = np_empty([1])
|
n = np_empty([1])
|
||||||
consume_ndarray_1(n)
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
|
def test_ndarray_zeros():
|
||||||
|
n = np_zeros([1])
|
||||||
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
|
def test_ndarray_ones():
|
||||||
|
n = np_ones([1])
|
||||||
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
|
def test_ndarray_full():
|
||||||
|
n_float = np_full([1], 2.0)
|
||||||
|
consume_ndarray_1(n_float)
|
||||||
|
n_i32 = np_full([1], 2)
|
||||||
|
consume_ndarray_i32_1(n_i32)
|
||||||
|
|
||||||
|
def test_ndarray_eye():
|
||||||
|
n = np_eye(2)
|
||||||
|
consume_ndarray_2(n)
|
||||||
|
|
||||||
|
def test_ndarray_identity():
|
||||||
|
n = np_identity(2)
|
||||||
|
consume_ndarray_2(n)
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
test_ndarray_ctor()
|
test_ndarray_ctor()
|
||||||
test_ndarray_empty()
|
test_ndarray_empty()
|
||||||
|
test_ndarray_zeros()
|
||||||
|
test_ndarray_ones()
|
||||||
|
test_ndarray_full()
|
||||||
|
test_ndarray_eye()
|
||||||
|
test_ndarray_identity()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue