forked from M-Labs/nac3
core: Implement ndarray constructor and numpy.empty
This commit is contained in:
parent
afa7d9b100
commit
27fcf8926e
|
@ -196,4 +196,48 @@ double __nac3_j0(double x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return j0(x);
|
return j0(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t __nac3_ndarray_calc_size(
|
||||||
|
const int32_t *list_data,
|
||||||
|
uint32_t list_len
|
||||||
|
) {
|
||||||
|
uint32_t num_elems = 1;
|
||||||
|
for (uint32_t i = 0; i < list_len; ++i) {
|
||||||
|
int32_t val = list_data[i];
|
||||||
|
__builtin_assume(val >= 0);
|
||||||
|
num_elems *= (uint32_t) list_data[i];
|
||||||
|
}
|
||||||
|
return num_elems;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t __nac3_ndarray_calc_size64(
|
||||||
|
const int32_t *list_data,
|
||||||
|
uint64_t list_len
|
||||||
|
) {
|
||||||
|
uint64_t num_elems = 1;
|
||||||
|
for (uint64_t i = 0; i < list_len; ++i) {
|
||||||
|
int32_t val = list_data[i];
|
||||||
|
__builtin_assume(val >= 0);
|
||||||
|
num_elems *= (uint64_t) list_data[i];
|
||||||
|
}
|
||||||
|
return num_elems;
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_init_dims(
|
||||||
|
uint32_t *ndarray_dims,
|
||||||
|
const int32_t *shape_data,
|
||||||
|
uint32_t shape_len
|
||||||
|
) {
|
||||||
|
__builtin_memcpy(ndarray_dims, shape_data, shape_len * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_init_dims64(
|
||||||
|
uint64_t *ndarray_dims,
|
||||||
|
const int32_t *shape_data,
|
||||||
|
uint64_t shape_len
|
||||||
|
) {
|
||||||
|
for (uint64_t i = 0; i < shape_len; ++i) {
|
||||||
|
ndarray_dims[i] = (uint64_t) shape_data[i];
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -12,6 +12,9 @@ use inkwell::{
|
||||||
};
|
};
|
||||||
use nac3parser::ast::Expr;
|
use nac3parser::ast::Expr;
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
use inkwell::types::AnyTypeEnum;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt(ctx: &Context) -> Module {
|
pub fn load_irrt(ctx: &Context) -> Module {
|
||||||
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
||||||
|
@ -546,3 +549,176 @@ pub fn call_j0<'ctx>(
|
||||||
.unwrap_left()
|
.unwrap_left()
|
||||||
.into_float_value()
|
.into_float_value()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks whether the pointer `value` refers to a `list` in LLVM.
|
||||||
|
fn assert_is_list(value: PointerValue) -> PointerValue {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
let llvm_shape_ty = value.get_type().get_element_type();
|
||||||
|
let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else {
|
||||||
|
panic!("Expected struct type for `list` type, but got {llvm_shape_ty}")
|
||||||
|
};
|
||||||
|
assert_eq!(llvm_shape_ty.count_fields(), 2);
|
||||||
|
assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..))));
|
||||||
|
assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..))));
|
||||||
|
}
|
||||||
|
|
||||||
|
value
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
|
||||||
|
fn assert_is_ndarray(value: PointerValue) -> PointerValue {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
let llvm_ndarray_ty = value.get_type().get_element_type();
|
||||||
|
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||||
|
panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}")
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(llvm_ndarray_ty.count_fields(), 3);
|
||||||
|
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..))));
|
||||||
|
let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let BasicTypeEnum::PointerType(dims) = ndarray_dims else {
|
||||||
|
panic!("Expected pointer type for `list.1`, but got {ndarray_dims}")
|
||||||
|
};
|
||||||
|
assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..)));
|
||||||
|
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..))));
|
||||||
|
}
|
||||||
|
|
||||||
|
value
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the
|
||||||
|
/// calculated total size.
|
||||||
|
///
|
||||||
|
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM
|
||||||
|
/// representation of a `list`.
|
||||||
|
pub fn call_ndarray_calc_size<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
shape: PointerValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
assert_is_list(shape);
|
||||||
|
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_calc_size",
|
||||||
|
64 => "__nac3_ndarray_calc_size64",
|
||||||
|
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||||
|
};
|
||||||
|
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
||||||
|
&[
|
||||||
|
llvm_pi32.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name)
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let (
|
||||||
|
shape_data,
|
||||||
|
shape_len,
|
||||||
|
) = unsafe {
|
||||||
|
(
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
shape,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
""
|
||||||
|
),
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
shape,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
ndarray_calc_size_fn,
|
||||||
|
&[
|
||||||
|
ctx.builder.build_load(shape_data, "").into(),
|
||||||
|
ctx.builder.build_load(shape_len, "").into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.try_as_basic_value()
|
||||||
|
.unwrap_left()
|
||||||
|
.into_int_value()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a call to `__nac3_ndarray_init_dims`.
|
||||||
|
///
|
||||||
|
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||||
|
/// `NDArray`.
|
||||||
|
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM
|
||||||
|
/// representation of a `list`.
|
||||||
|
pub fn call_ndarray_init_dims<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ndarray: PointerValue<'ctx>,
|
||||||
|
shape: PointerValue<'ctx>,
|
||||||
|
) {
|
||||||
|
assert_is_ndarray(ndarray);
|
||||||
|
assert_is_list(shape);
|
||||||
|
|
||||||
|
let llvm_void = ctx.ctx.void_type();
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||||
|
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||||
|
|
||||||
|
let ndarray_init_dims_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||||
|
32 => "__nac3_ndarray_init_dims",
|
||||||
|
64 => "__nac3_ndarray_init_dims64",
|
||||||
|
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||||
|
};
|
||||||
|
let ndarray_init_dims_fn = ctx.module.get_function(ndarray_init_dims_fn_name).unwrap_or_else(|| {
|
||||||
|
let fn_type = llvm_void.fn_type(
|
||||||
|
&[
|
||||||
|
llvm_pusize.into(),
|
||||||
|
llvm_pi32.into(),
|
||||||
|
llvm_usize.into(),
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None)
|
||||||
|
});
|
||||||
|
|
||||||
|
let ndarray_dims = ctx.build_gep_and_load(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let shape_data = ctx.build_gep_and_load(
|
||||||
|
shape,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
None
|
||||||
|
);
|
||||||
|
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
None,
|
||||||
|
).into_int_value();
|
||||||
|
|
||||||
|
ctx.builder.build_call(
|
||||||
|
ndarray_init_dims_fn,
|
||||||
|
&[
|
||||||
|
ndarray_dims.into(),
|
||||||
|
shape_data.into(),
|
||||||
|
ndarray_num_dims.into(),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
);
|
||||||
|
}
|
|
@ -16,7 +16,7 @@ use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
basic_block::BasicBlock,
|
basic_block::BasicBlock,
|
||||||
types::BasicTypeEnum,
|
types::BasicTypeEnum,
|
||||||
values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue},
|
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
||||||
IntPredicate,
|
IntPredicate,
|
||||||
};
|
};
|
||||||
use nac3parser::ast::{
|
use nac3parser::ast::{
|
||||||
|
@ -405,6 +405,80 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates a C-style `for` construct using lambdas, similar to the following C code:
|
||||||
|
///
|
||||||
|
/// ```c
|
||||||
|
/// for (x... = init(); cond(x...); update(x...)) {
|
||||||
|
/// body(x...);
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// * `init` - A lambda containing IR statements declaring and initializing loop variables. The
|
||||||
|
/// return value is a [Clone] value which will be passed to the other lambdas.
|
||||||
|
/// * `cond` - A lambda containing IR statements checking whether the loop should continue
|
||||||
|
/// executing. The result value must be an `i1` indicating if the loop should continue.
|
||||||
|
/// * `body` - A lambda containing IR statements within the loop body.
|
||||||
|
/// * `update` - A lambda containing IR statements updating loop variables.
|
||||||
|
pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
init: InitFn,
|
||||||
|
cond: CondFn,
|
||||||
|
body: BodyFn,
|
||||||
|
update: UpdateFn,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
I: Clone,
|
||||||
|
InitFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
|
||||||
|
CondFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
|
||||||
|
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||||
|
UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||||
|
{
|
||||||
|
let current = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap();
|
||||||
|
let init_bb = ctx.ctx.append_basic_block(current, "for.init");
|
||||||
|
// The BB containing the loop condition check
|
||||||
|
let cond_bb = ctx.ctx.append_basic_block(current, "for.cond");
|
||||||
|
let body_bb = ctx.ctx.append_basic_block(current, "for.body");
|
||||||
|
// The BB containing the increment expression
|
||||||
|
let update_bb = ctx.ctx.append_basic_block(current, "for.update");
|
||||||
|
let cont_bb = ctx.ctx.append_basic_block(current, "for.end");
|
||||||
|
|
||||||
|
// store loop bb information and restore it later
|
||||||
|
let loop_bb = ctx.loop_target.replace((update_bb, cont_bb));
|
||||||
|
|
||||||
|
ctx.builder.build_unconditional_branch(init_bb);
|
||||||
|
|
||||||
|
let loop_var = {
|
||||||
|
ctx.builder.position_at_end(init_bb);
|
||||||
|
let result = init(generator, ctx)?;
|
||||||
|
ctx.builder.build_unconditional_branch(cond_bb);
|
||||||
|
|
||||||
|
result
|
||||||
|
};
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(cond_bb);
|
||||||
|
let cond = cond(generator, ctx, loop_var.clone())?;
|
||||||
|
assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
|
||||||
|
ctx.builder.build_conditional_branch(
|
||||||
|
cond,
|
||||||
|
body_bb,
|
||||||
|
cont_bb
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(body_bb);
|
||||||
|
body(generator, ctx, loop_var.clone())?;
|
||||||
|
ctx.builder.build_unconditional_branch(update_bb);
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(update_bb);
|
||||||
|
update(generator, ctx, loop_var)?;
|
||||||
|
ctx.builder.build_unconditional_branch(cond_bb);
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(cont_bb);
|
||||||
|
ctx.loop_target = loop_bb;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// See [`CodeGenerator::gen_while`].
|
/// See [`CodeGenerator::gen_while`].
|
||||||
pub fn gen_while<G: CodeGenerator>(
|
pub fn gen_while<G: CodeGenerator>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
|
|
@ -13,11 +13,12 @@ use crate::{
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
|
toplevel::numpy::gen_ndarray_empty,
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
types::{BasicType, BasicMetadataTypeEnum},
|
types::{BasicType, BasicMetadataTypeEnum},
|
||||||
values::BasicMetadataValueEnum,
|
values::{BasicValue, BasicMetadataValueEnum},
|
||||||
FloatPredicate,
|
FloatPredicate,
|
||||||
IntPredicate
|
IntPredicate
|
||||||
};
|
};
|
||||||
|
@ -278,6 +279,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
let boolean = primitives.0.bool;
|
let boolean = primitives.0.bool;
|
||||||
let range = primitives.0.range;
|
let range = primitives.0.range;
|
||||||
let string = primitives.0.str;
|
let string = primitives.0.str;
|
||||||
|
let ndarray_float = {
|
||||||
|
let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0);
|
||||||
|
primitives.1.add_ty(ndarray_ty_enum)
|
||||||
|
};
|
||||||
|
let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 });
|
||||||
let num_ty = primitives.1.get_fresh_var_with_range(
|
let num_ty = primitives.1.get_fresh_var_with_range(
|
||||||
&[int32, int64, float, boolean, uint32, uint64],
|
&[int32, int64, float, boolean, uint32, uint64],
|
||||||
Some("N".into()),
|
Some("N".into()),
|
||||||
|
@ -837,6 +843,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
})),
|
})),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"np_ndarray",
|
||||||
|
ndarray_float,
|
||||||
|
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||||
|
// type variable
|
||||||
|
&[(list_int32, "shape")],
|
||||||
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
|
gen_ndarray_empty(ctx, obj, fun, args, generator)
|
||||||
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
create_fn_by_codegen(
|
||||||
|
primitives,
|
||||||
|
&var_map,
|
||||||
|
"np_empty",
|
||||||
|
ndarray_float,
|
||||||
|
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||||
|
// type variable
|
||||||
|
&[(list_int32, "shape")],
|
||||||
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
|
gen_ndarray_empty(ctx, obj, fun, args, generator)
|
||||||
|
.map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
}),
|
||||||
|
),
|
||||||
create_fn_by_codegen(
|
create_fn_by_codegen(
|
||||||
primitives,
|
primitives,
|
||||||
&var_map,
|
&var_map,
|
||||||
|
|
|
@ -25,6 +25,7 @@ pub struct DefinitionId(pub usize);
|
||||||
pub mod builtins;
|
pub mod builtins;
|
||||||
pub mod composer;
|
pub mod composer;
|
||||||
pub mod helper;
|
pub mod helper;
|
||||||
|
pub mod numpy;
|
||||||
pub mod type_annotation;
|
pub mod type_annotation;
|
||||||
use composer::*;
|
use composer::*;
|
||||||
use type_annotation::*;
|
use type_annotation::*;
|
||||||
|
|
|
@ -0,0 +1,198 @@
|
||||||
|
use inkwell::{
|
||||||
|
IntPredicate,
|
||||||
|
types::BasicType,
|
||||||
|
values::PointerValue,
|
||||||
|
};
|
||||||
|
use nac3parser::ast::StrRef;
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
CodeGenContext,
|
||||||
|
CodeGenerator,
|
||||||
|
irrt::{call_ndarray_calc_size, call_ndarray_init_dims},
|
||||||
|
stmt::gen_for_callback
|
||||||
|
},
|
||||||
|
symbol_resolver::ValueEnum,
|
||||||
|
toplevel::DefinitionId,
|
||||||
|
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the NDArray.
|
||||||
|
/// * `var_name` - The variable name of the NDArray.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
||||||
|
fn call_ndarray_impl<'ctx, 'a>(
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
elem_ty: Type,
|
||||||
|
var_name: Option<&str>,
|
||||||
|
shape: PointerValue<'ctx>,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
||||||
|
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
||||||
|
|
||||||
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||||
|
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||||
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||||
|
assert!(llvm_ndarray_data_t.is_sized());
|
||||||
|
|
||||||
|
// Assert that all dimensions are non-negative
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
let i = ctx.builder.build_alloca(llvm_usize, "");
|
||||||
|
ctx.builder.build_store(i, llvm_usize.const_zero());
|
||||||
|
|
||||||
|
Ok(i)
|
||||||
|
},
|
||||||
|
|_, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.into_int_value();
|
||||||
|
let shape_len = ctx.build_gep_and_load(
|
||||||
|
shape,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
None,
|
||||||
|
).into_int_value();
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_int_compare(IntPredicate::ULE, i, shape_len, ""))
|
||||||
|
},
|
||||||
|
|generator, ctx, i_addr| {
|
||||||
|
let shape_elems = ctx.build_gep_and_load(
|
||||||
|
shape,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
None
|
||||||
|
).into_pointer_value();
|
||||||
|
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.into_int_value();
|
||||||
|
let shape_dim = ctx.build_gep_and_load(
|
||||||
|
shape_elems,
|
||||||
|
&[i],
|
||||||
|
None
|
||||||
|
).into_int_value();
|
||||||
|
|
||||||
|
let shape_dim_gez = ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::SGE,
|
||||||
|
shape_dim,
|
||||||
|
llvm_i32.const_zero(),
|
||||||
|
""
|
||||||
|
);
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
shape_dim_gez,
|
||||||
|
"0:ValueError",
|
||||||
|
"negative dimensions not supported",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|_, ctx, i_addr| {
|
||||||
|
let i = ctx.builder
|
||||||
|
.build_load(i_addr, "")
|
||||||
|
.into_int_value();
|
||||||
|
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "");
|
||||||
|
ctx.builder.build_store(i_addr, i);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let ndarray = ctx.builder.build_alloca(
|
||||||
|
llvm_ndarray_t,
|
||||||
|
var_name.unwrap_or_default()
|
||||||
|
);
|
||||||
|
|
||||||
|
let num_dims = ctx.build_gep_and_load(
|
||||||
|
shape,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
None
|
||||||
|
).into_int_value();
|
||||||
|
|
||||||
|
let ndarray_num_dims = unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(ndarray_num_dims, num_dims);
|
||||||
|
|
||||||
|
let ndarray_dims = unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
|
None,
|
||||||
|
).into_int_value();
|
||||||
|
|
||||||
|
ctx.builder.build_store(
|
||||||
|
ndarray_dims,
|
||||||
|
ctx.builder.build_array_alloca(
|
||||||
|
llvm_usize,
|
||||||
|
ndarray_num_dims,
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
call_ndarray_init_dims(generator, ctx, ndarray, shape);
|
||||||
|
|
||||||
|
let ndarray_num_elems = call_ndarray_calc_size(generator, ctx, shape);
|
||||||
|
|
||||||
|
let ndarray_data = unsafe {
|
||||||
|
ctx.builder.build_in_bounds_gep(
|
||||||
|
ndarray,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(
|
||||||
|
ndarray_data,
|
||||||
|
ctx.builder.build_array_alloca(
|
||||||
|
llvm_ndarray_data_t,
|
||||||
|
ndarray_num_elems,
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.empty`.
|
||||||
|
pub fn gen_ndarray_empty<'ctx, 'a>(
|
||||||
|
context: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||||
|
generator: &mut dyn CodeGenerator,
|
||||||
|
) -> Result<PointerValue<'ctx>, String> {
|
||||||
|
assert!(obj.is_none());
|
||||||
|
assert_eq!(args.len(), 1);
|
||||||
|
|
||||||
|
let shape_ty = fun.0.args[0].ty;
|
||||||
|
let shape_arg_name = args[0].0;
|
||||||
|
let shape_arg = args[0].1.clone()
|
||||||
|
.to_basic_value_enum(context, generator, shape_ty)?;
|
||||||
|
|
||||||
|
call_ndarray_impl(
|
||||||
|
generator,
|
||||||
|
context,
|
||||||
|
context.primitives.float,
|
||||||
|
shape_arg_name.map(|name| name.to_string()).as_deref(),
|
||||||
|
shape_arg.into_pointer_value(),
|
||||||
|
)
|
||||||
|
}
|
|
@ -5,7 +5,7 @@ use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier};
|
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier};
|
||||||
use super::{magic_methods::*, typedef::CallId};
|
use super::{magic_methods::*, typedef::CallId};
|
||||||
use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext};
|
use crate::{symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::TopLevelContext};
|
||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use nac3parser::ast::{
|
use nac3parser::ast::{
|
||||||
self,
|
self,
|
||||||
|
@ -894,6 +894,53 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 1-argument ndarray n-dimensional creation functions
|
||||||
|
if [
|
||||||
|
"np_ndarray".into(),
|
||||||
|
"np_empty".into(),
|
||||||
|
].contains(id) && args.len() == 1 {
|
||||||
|
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||||
|
return report_error("Expected List literal for first argument of np_ndarray", args[0].location)
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndims = elts.len() as u64;
|
||||||
|
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
let ndims = self.unifier.get_fresh_literal(
|
||||||
|
vec![SymbolValue::U64(ndims)],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let ret = self.unifier.add_ty(TypeEnum::TNDArray {
|
||||||
|
ty: self.primitives.float,
|
||||||
|
ndims
|
||||||
|
});
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "shape".into(),
|
||||||
|
ty: arg0.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: HashMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||||
|
}),
|
||||||
|
args: vec![arg0],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,11 +5,12 @@ import importlib.util
|
||||||
import importlib.machinery
|
import importlib.machinery
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from numpy import int32, int64, uint32, uint64
|
from numpy import int32, int64, uint32, uint64
|
||||||
from scipy import special
|
from scipy import special
|
||||||
from typing import TypeVar, Generic, Literal
|
from typing import TypeVar, Generic, Literal, Union
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
class Option(Generic[T]):
|
class Option(Generic[T]):
|
||||||
|
@ -50,6 +51,13 @@ class _ConstGenericMarker:
|
||||||
def ConstGeneric(name, constraint):
|
def ConstGeneric(name, constraint):
|
||||||
return TypeVar(name, _ConstGenericMarker, constraint)
|
return TypeVar(name, _ConstGenericMarker, constraint)
|
||||||
|
|
||||||
|
N = TypeVar("N", bound=np.uint64)
|
||||||
|
class _NDArrayDummy(Generic[T, N]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
|
||||||
|
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
|
||||||
|
|
||||||
def round_away_zero(x):
|
def round_away_zero(x):
|
||||||
if x >= 0.0:
|
if x >= 0.0:
|
||||||
return math.floor(x + 0.5)
|
return math.floor(x + 0.5)
|
||||||
|
@ -124,6 +132,16 @@ def patch(module):
|
||||||
module.ceil64 = math.ceil
|
module.ceil64 = math.ceil
|
||||||
module.np_ceil = np.ceil
|
module.np_ceil = np.ceil
|
||||||
|
|
||||||
|
# NumPy ndarray functions
|
||||||
|
module.ndarray = NDArray
|
||||||
|
module.np_ndarray = np.ndarray
|
||||||
|
module.np_empty = np.empty
|
||||||
|
module.np_zeros = np.zeros
|
||||||
|
module.np_ones = np.ones
|
||||||
|
module.np_full = np.full
|
||||||
|
module.np_eye = np.eye
|
||||||
|
module.np_identity = np.identity
|
||||||
|
|
||||||
# NumPy Math functions
|
# NumPy Math functions
|
||||||
module.np_isnan = np.isnan
|
module.np_isnan = np.isnan
|
||||||
module.np_isinf = np.isinf
|
module.np_isinf = np.isinf
|
||||||
|
@ -166,6 +184,9 @@ def patch(module):
|
||||||
module.sp_spec_j0 = special.j0
|
module.sp_spec_j0 = special.j0
|
||||||
module.sp_spec_j1 = special.j1
|
module.sp_spec_j1 = special.j1
|
||||||
|
|
||||||
|
# NumPy NDArray Functions
|
||||||
|
module.np_ndarray = np.ndarray
|
||||||
|
module.np_empty = np.empty
|
||||||
|
|
||||||
def file_import(filename, prefix="file_import_"):
|
def file_import(filename, prefix="file_import_"):
|
||||||
filename = pathlib.Path(filename)
|
filename = pathlib.Path(filename)
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_ndarray_ctor():
|
||||||
|
n = np_ndarray([1])
|
||||||
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
|
def test_ndarray_empty():
|
||||||
|
n = np_empty([1])
|
||||||
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
|
def run() -> int32:
|
||||||
|
test_ndarray_ctor()
|
||||||
|
test_ndarray_empty()
|
||||||
|
|
||||||
|
return 0
|
Loading…
Reference in New Issue