core: Implement ndarray constructor and numpy.empty

David Mak 2023-11-17 17:30:27 +08:00
parent afa7d9b100
commit 6ba4ef8961
9 changed files with 595 additions and 3 deletions

View File

@ -196,4 +196,48 @@ double __nac3_j0(double x) {
}
return j0(x);
}
uint32_t __nac3_ndarray_calc_size(
const int32_t *list_data,
uint32_t list_len
) {
uint32_t num_elems = 1;
for (uint32_t i = 0; i < list_len; ++i) {
int32_t val = list_data[i];
__builtin_assume(val >= 0);
num_elems *= (uint32_t) list_data[i];
}
return num_elems;
}
uint64_t __nac3_ndarray_calc_size64(
const int32_t *list_data,
uint64_t list_len
) {
uint64_t num_elems = 1;
for (uint64_t i = 0; i < list_len; ++i) {
int32_t val = list_data[i];
__builtin_assume(val >= 0);
num_elems *= (uint64_t) list_data[i];
}
return num_elems;
}
// Fills `ndarray_dims` with the first `shape_len` entries of `shape_data`.
// Source and destination elements are both 32 bits wide, so the values are
// transferred with one bulk byte copy rather than element by element.
void __nac3_ndarray_init_dims(
    uint32_t *ndarray_dims,
    const int32_t *shape_data,
    uint32_t shape_len
) {
    __builtin_memcpy(ndarray_dims, shape_data, shape_len * sizeof(*shape_data));
}
// 64-bit variant of the dimension initializer: writes each of the first
// `shape_len` entries of `shape_data` into `ndarray_dims`, converting every
// 32-bit signed entry to `uint64_t` on the way.
void __nac3_ndarray_init_dims64(
    uint64_t *ndarray_dims,
    const int32_t *shape_data,
    uint64_t shape_len
) {
    uint64_t idx = 0;
    while (idx < shape_len) {
        const int32_t dim = shape_data[idx];
        ndarray_dims[idx] = (uint64_t) dim;
        ++idx;
    }
}

View File

@ -12,6 +12,9 @@ use inkwell::{
};
use nac3parser::ast::Expr;
#[cfg(debug_assertions)]
use inkwell::types::AnyTypeEnum;
#[must_use]
pub fn load_irrt(ctx: &Context) -> Module {
let bitcode_buf = MemoryBuffer::create_from_memory_range(
@ -546,3 +549,176 @@ pub fn call_j0<'ctx>(
.unwrap_left()
.into_float_value()
}
/// Debug-build sanity check that `value` points to the LLVM representation of a `list`.
///
/// When compiled with `debug_assertions`, this verifies that the pointee is a struct with exactly
/// two fields: a pointer at index 0 and an integer at index 1. Without `debug_assertions` no
/// check is performed.
///
/// Returns `value` unchanged so the call can be used inline.
///
/// # Panics
///
/// Panics (debug builds only) if the pointee type does not have the layout described above.
fn assert_is_list(value: PointerValue) -> PointerValue {
    #[cfg(debug_assertions)]
    {
        let pointee_ty = value.get_type().get_element_type();
        let list_struct_ty = match pointee_ty {
            AnyTypeEnum::StructType(struct_ty) => struct_ty,
            _ => panic!("Expected struct type for `list` type, but got {pointee_ty}"),
        };

        assert_eq!(list_struct_ty.count_fields(), 2);

        let data_field_ty = list_struct_ty.get_field_type_at_index(0);
        assert!(matches!(data_field_ty, Some(BasicTypeEnum::PointerType(..))));

        let len_field_ty = list_struct_ty.get_field_type_at_index(1);
        assert!(matches!(len_field_ty, Some(BasicTypeEnum::IntType(..))));
    }

    value
}
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
///
/// When compiled with `debug_assertions`, this verifies that the pointee is a struct with exactly
/// three fields:
///
/// * field 0 - an integer,
/// * field 1 - a pointer to an integer type,
/// * field 2 - a pointer.
///
/// Without `debug_assertions` no check is performed.
///
/// Returns `value` unchanged so the call can be used inline.
///
/// # Panics
///
/// Panics (debug builds only) if the pointee type does not have the layout described above.
fn assert_is_ndarray(value: PointerValue) -> PointerValue {
    #[cfg(debug_assertions)]
    {
        let llvm_ndarray_ty = value.get_type().get_element_type();
        let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
            panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}")
        };

        assert_eq!(llvm_ndarray_ty.count_fields(), 3);
        assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..))));
        // The field count was asserted to be 3 above, so index 1 always exists.
        let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else {
            unreachable!()
        };
        // The diagnostic names the `NDArray` field being inspected (it previously said `list.1`).
        let BasicTypeEnum::PointerType(dims) = ndarray_dims else {
            panic!("Expected pointer type for `NDArray.1`, but got {ndarray_dims}")
        };
        assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..)));
        assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..))));
    }

    value
}
/// Emits a call to the IRRT routine that computes the total element count of an `NDArray` from
/// its shape, and returns the resulting [`IntValue`].
///
/// The callee is `__nac3_ndarray_calc_size` when the generator's size type is 32 bits wide and
/// `__nac3_ndarray_calc_size64` when it is 64 bits wide. The function is declared in the current
/// module if it is not already present.
///
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM
///   representation of a `list`.
///
/// # Panics
///
/// Panics if the size type has a bit width other than 32 or 64.
pub fn call_ndarray_calc_size<'ctx, 'a>(
    generator: &mut dyn CodeGenerator,
    ctx: &mut CodeGenContext<'ctx, 'a>,
    shape: PointerValue<'ctx>,
) -> IntValue<'ctx> {
    assert_is_list(shape);

    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);
    let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());

    let fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_calc_size",
        64 => "__nac3_ndarray_calc_size64",
        bw => unreachable!("Unsupported size type bit width: {}", bw),
    };
    let callee = match ctx.module.get_function(fn_name) {
        Some(existing) => existing,
        None => {
            // Signature: (i32* data, usize len) -> usize
            let fn_type = llvm_usize.fn_type(&[llvm_pi32.into(), llvm_usize.into()], false);
            ctx.module.add_function(fn_name, fn_type, None)
        }
    };

    let idx_zero = llvm_i32.const_zero();
    let idx_one = llvm_i32.const_int(1, true);

    // SAFETY: `shape` is expected to point to a two-field `list` struct (verified by
    // `assert_is_list` in debug builds), so field indices 0 and 1 are in bounds.
    let data_field = unsafe { ctx.builder.build_in_bounds_gep(shape, &[idx_zero, idx_zero], "") };
    // SAFETY: see above.
    let len_field = unsafe { ctx.builder.build_in_bounds_gep(shape, &[idx_zero, idx_one], "") };

    let data = ctx.builder.build_load(data_field, "");
    let len = ctx.builder.build_load(len_field, "");

    ctx.builder
        .build_call(callee, &[data.into(), len.into()], "")
        .try_as_basic_value()
        .unwrap_left()
        .into_int_value()
}
/// Emits a call to the IRRT routine that copies the entries of `shape` into the dimension array
/// of `ndarray`.
///
/// The callee is `__nac3_ndarray_init_dims` when the generator's size type is 32 bits wide and
/// `__nac3_ndarray_init_dims64` when it is 64 bits wide. The function is declared in the current
/// module if it is not already present. The number of entries copied is the value loaded from
/// field 0 of `ndarray`.
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
///   `NDArray`.
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM
///   representation of a `list`.
///
/// # Panics
///
/// Panics if the size type has a bit width other than 32 or 64.
pub fn call_ndarray_init_dims<'ctx, 'a>(
    generator: &mut dyn CodeGenerator,
    ctx: &mut CodeGenContext<'ctx, 'a>,
    ndarray: PointerValue<'ctx>,
    shape: PointerValue<'ctx>,
) {
    assert_is_ndarray(ndarray);
    assert_is_list(shape);

    let llvm_void = ctx.ctx.void_type();
    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);
    let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
    let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());

    let fn_name = match llvm_usize.get_bit_width() {
        32 => "__nac3_ndarray_init_dims",
        64 => "__nac3_ndarray_init_dims64",
        bw => unreachable!("Unsupported size type bit width: {}", bw),
    };
    let callee = match ctx.module.get_function(fn_name) {
        Some(existing) => existing,
        None => {
            // Signature: (usize* dims, i32* shape_data, usize shape_len) -> void
            let fn_type = llvm_void.fn_type(
                &[llvm_pusize.into(), llvm_pi32.into(), llvm_usize.into()],
                false,
            );
            ctx.module.add_function(fn_name, fn_type, None)
        }
    };

    let idx_zero = llvm_i32.const_zero();
    let idx_one = llvm_i32.const_int(1, true);

    // Field 1 of the NDArray: pointer to the dimension array to be filled.
    let dims = ctx.build_gep_and_load(ndarray, &[idx_zero, idx_one], None);
    // Field 0 of the list: pointer to the shape entries to copy from.
    let shape_entries = ctx.build_gep_and_load(shape, &[idx_zero, idx_zero], None);
    // Field 0 of the NDArray: number of dimensions, used as the copy length.
    let num_dims = ctx.build_gep_and_load(ndarray, &[idx_zero, idx_zero], None).into_int_value();

    ctx.builder.build_call(
        callee,
        &[dims.into(), shape_entries.into(), num_dims.into()],
        "",
    );
}

View File

@ -16,7 +16,7 @@ use inkwell::{
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue},
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate,
};
use nac3parser::ast::{
@ -405,6 +405,80 @@ pub fn gen_for<G: CodeGenerator>(
Ok(())
}
/// Emits the control-flow skeleton of a C-style `for` loop, delegating each part of the loop to
/// a caller-supplied closure. The generated structure corresponds to:
///
/// ```c
/// for (x... = init(); cond(x...); update(x...)) {
///     body(x...);
/// }
/// ```
///
/// Five basic blocks (`for.init`, `for.cond`, `for.body`, `for.update`, `for.end`) are appended
/// to the function that currently contains the builder's insertion point. While the closures
/// run, `ctx.loop_target` is set to `(for.update, for.end)`; the previous value is put back once
/// the loop has been emitted. On return the builder is positioned at the end of `for.end`.
///
/// * `init` - Emits IR that declares and initializes the loop variables. Its return value is
///   cloned and handed to each of the other closures.
/// * `cond` - Emits IR that decides whether to run another iteration. Must yield an `i1`.
/// * `body` - Emits the IR of the loop body.
/// * `update` - Emits IR that advances the loop variables.
///
/// # Errors
///
/// The first `Err` returned by any closure is propagated immediately. In that case the previous
/// `ctx.loop_target` is not restored.
///
/// # Panics
///
/// Panics if the builder has no insertion block with a parent function, or if the value
/// returned by `cond` does not have the bit width of the boolean type.
pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
    generator: &mut dyn CodeGenerator,
    ctx: &mut CodeGenContext<'ctx, 'a>,
    init: InitFn,
    cond: CondFn,
    body: BodyFn,
    update: UpdateFn,
) -> Result<(), String>
    where
        I: Clone,
        InitFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
        CondFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
        BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
        UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{
    let parent_fn = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap();

    // Blocks are appended in this exact order: init, condition check, body, update, exit.
    let [init_bb, cond_bb, body_bb, update_bb, cont_bb] =
        ["for.init", "for.cond", "for.body", "for.update", "for.end"]
            .map(|label| ctx.ctx.append_basic_block(parent_fn, label));

    // Remember the enclosing loop's targets so they can be reinstated afterwards.
    let saved_loop_target = ctx.loop_target.replace((update_bb, cont_bb));

    // current block -> for.init -> for.cond
    ctx.builder.build_unconditional_branch(init_bb);
    ctx.builder.position_at_end(init_bb);
    let loop_var = init(generator, ctx)?;
    ctx.builder.build_unconditional_branch(cond_bb);

    // for.cond -> for.body (condition holds) | for.end (otherwise)
    ctx.builder.position_at_end(cond_bb);
    let keep_going = cond(generator, ctx, loop_var.clone())?;
    assert_eq!(keep_going.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
    ctx.builder.build_conditional_branch(keep_going, body_bb, cont_bb);

    // for.body -> for.update
    ctx.builder.position_at_end(body_bb);
    body(generator, ctx, loop_var.clone())?;
    ctx.builder.build_unconditional_branch(update_bb);

    // for.update -> for.cond
    ctx.builder.position_at_end(update_bb);
    update(generator, ctx, loop_var)?;
    ctx.builder.build_unconditional_branch(cond_bb);

    ctx.builder.position_at_end(cont_bb);
    ctx.loop_target = saved_loop_target;

    Ok(())
}
/// See [`CodeGenerator::gen_while`].
pub fn gen_while<G: CodeGenerator>(
generator: &mut G,

View File

@ -13,11 +13,12 @@ use crate::{
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
toplevel::numpy::gen_ndarray_empty,
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{BasicType, BasicMetadataTypeEnum},
values::BasicMetadataValueEnum,
values::{BasicValue, BasicMetadataValueEnum},
FloatPredicate,
IntPredicate
};
@ -278,6 +279,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let boolean = primitives.0.bool;
let range = primitives.0.range;
let string = primitives.0.str;
// TODO(Derppening): Could we make this generic?
let ndarray_float = {
let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0);
primitives.1.add_ty(ndarray_ty_enum)
};
let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 });
let num_ty = primitives.1.get_fresh_var_with_range(
&[int32, int64, float, boolean, uint32, uint64],
Some("N".into()),
@ -837,6 +844,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))),
loc: None,
})),
create_fn_by_codegen(
primitives,
&var_map,
"np_ndarray",
ndarray_float,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_empty(ctx, obj, fun, args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"np_empty",
ndarray_float,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_empty(ctx, obj, fun, args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,

View File

@ -25,6 +25,7 @@ pub struct DefinitionId(pub usize);
pub mod builtins;
pub mod composer;
pub mod helper;
mod numpy;
pub mod type_annotation;
use composer::*;
use type_annotation::*;

View File

@ -0,0 +1,198 @@
use inkwell::{
IntPredicate,
types::BasicType,
values::PointerValue,
};
use nac3parser::ast::StrRef;
use crate::{
codegen::{
CodeGenContext,
CodeGenerator,
irrt::{call_ndarray_calc_size, call_ndarray_init_dims},
stmt::gen_for_callback
},
symbol_resolver::ValueEnum,
toplevel::DefinitionId,
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
///
/// The emitted IR performs the following steps, in order:
///
/// 1. Loops over every entry of `shape` and asserts that it is non-negative, reporting
///    `0:ValueError` ("negative dimensions not supported") otherwise.
/// 2. Allocates the `NDArray` struct with `alloca` and stores the length of `shape` into field 0
///    (the number of dimensions).
/// 3. Allocates an array of that many size-typed integers with `alloca`, stores its address into
///    field 1, and fills it from `shape` via [`call_ndarray_init_dims`].
/// 4. Allocates the element buffer with `alloca`, sized by [`call_ndarray_calc_size`], and
///    stores its address into field 2. The buffer contents are not initialized.
///
/// NOTE(review): every allocation above is an `alloca`, so the returned pointer is presumably
/// only valid within the frame of the function being generated - confirm against callers.
///
/// * `elem_ty` - The element type of the NDArray.
/// * `var_name` - The variable name of the NDArray.
/// * `shape` - The `shape` parameter used to construct the NDArray.
///
/// Returns the LLVM pointer to the newly allocated `NDArray` struct.
///
/// # Errors
///
/// Returns `Err` if emitting the dimension-check loop fails.
fn call_ndarray_impl<'ctx, 'a>(
    generator: &mut dyn CodeGenerator,
    ctx: &mut CodeGenContext<'ctx, 'a>,
    elem_ty: Type,
    var_name: Option<&str>,
    shape: PointerValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> {
    let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
    let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);

    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);

    let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
    let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
    let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
    assert!(llvm_ndarray_data_t.is_sized());

    // Assert that all dimensions are non-negative
    gen_for_callback(
        generator,
        ctx,
        // init: `usize i = 0`
        |_, ctx| {
            let i = ctx.builder.build_alloca(llvm_usize, "");
            ctx.builder.build_store(i, llvm_usize.const_zero());

            Ok(i)
        },
        // cond: `i < len(shape)`
        |_, ctx, i_addr| {
            let i = ctx.builder
                .build_load(i_addr, "")
                .into_int_value();
            let shape_len = ctx.build_gep_and_load(
                shape,
                &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
                None,
            ).into_int_value();

            // Strictly less-than: valid indices are `0..shape_len`. Using `ULE` here would run
            // the body once more with `i == shape_len` and read one element past the end of the
            // list.
            Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, ""))
        },
        // body: `assert shape[i] >= 0`
        |generator, ctx, i_addr| {
            let shape_elems = ctx.build_gep_and_load(
                shape,
                &[llvm_i32.const_zero(), llvm_i32.const_zero()],
                None
            ).into_pointer_value();

            let i = ctx.builder
                .build_load(i_addr, "")
                .into_int_value();
            let shape_dim = ctx.build_gep_and_load(
                shape_elems,
                &[i],
                None
            ).into_int_value();

            let shape_dim_gez = ctx.builder.build_int_compare(
                IntPredicate::SGE,
                shape_dim,
                llvm_i32.const_zero(),
                ""
            );

            ctx.make_assert(
                generator,
                shape_dim_gez,
                "0:ValueError",
                "negative dimensions not supported",
                [None, None, None],
                ctx.current_loc,
            );

            Ok(())
        },
        // update: `i += 1`
        |_, ctx, i_addr| {
            let i = ctx.builder
                .build_load(i_addr, "")
                .into_int_value();
            let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "");
            ctx.builder.build_store(i_addr, i);

            Ok(())
        },
    )?;

    let ndarray = ctx.builder.build_alloca(
        llvm_ndarray_t,
        var_name.unwrap_or_default()
    );

    // Field 0 of the NDArray: the number of dimensions, i.e. the length of `shape`.
    let num_dims = ctx.build_gep_and_load(
        shape,
        &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
        None
    ).into_int_value();

    // SAFETY: `ndarray` was just allocated with the `NDArray` struct type, and field 0 is
    // accessed below through the same struct, so the field index is in bounds.
    let ndarray_num_dims = unsafe {
        ctx.builder.build_in_bounds_gep(
            ndarray,
            &[llvm_i32.const_zero(), llvm_i32.const_zero()],
            "",
        )
    };
    ctx.builder.build_store(ndarray_num_dims, num_dims);

    // Field 1 of the NDArray: pointer to the per-dimension sizes.
    // SAFETY: `ndarray` points to the freshly allocated `NDArray` struct; field 1 is the
    // dimension-array pointer that `call_ndarray_init_dims` loads from the same struct.
    let ndarray_dims = unsafe {
        ctx.builder.build_in_bounds_gep(
            ndarray,
            &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
            "",
        )
    };

    // Reads back the dimension count that was stored into field 0 above.
    let ndarray_num_dims = ctx.build_gep_and_load(
        ndarray,
        &[llvm_i32.const_zero(), llvm_i32.const_zero()],
        None,
    ).into_int_value();

    ctx.builder.build_store(
        ndarray_dims,
        ctx.builder.build_array_alloca(
            llvm_usize,
            ndarray_num_dims,
            "",
        ),
    );

    call_ndarray_init_dims(generator, ctx, ndarray, shape);

    let ndarray_num_elems = call_ndarray_calc_size(generator, ctx, shape);

    // Field 2 of the NDArray: pointer to the element storage.
    // SAFETY: `ndarray` points to the freshly allocated `NDArray` struct, into whose field 2
    // the element buffer pointer is stored below.
    let ndarray_data = unsafe {
        ctx.builder.build_in_bounds_gep(
            ndarray,
            &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
            "",
        )
    };
    ctx.builder.build_store(
        ndarray_data,
        ctx.builder.build_array_alloca(
            llvm_ndarray_data_t,
            ndarray_num_elems,
            "",
        ),
    );

    Ok(ndarray)
}
/// Code generation entry point for creating an uninitialized `NDArray` of floats from a single
/// `shape` argument.
///
/// The sole argument is lowered to an LLVM value using the type of the first parameter in the
/// function signature, and the construction itself is delegated to [`call_ndarray_impl`] with
/// the primitive `float` type as the element type. The name attached to the argument, if any,
/// is forwarded as the variable name of the resulting NDArray.
///
/// # Errors
///
/// Returns `Err` if lowering the argument or generating the NDArray fails.
///
/// # Panics
///
/// Panics if `obj` is present or if `args` does not contain exactly one argument.
pub fn gen_ndarray_empty<'ctx, 'a>(
    context: &mut CodeGenContext<'ctx, 'a>,
    obj: Option<(Type, ValueEnum<'ctx>)>,
    fun: (&FunSignature, DefinitionId),
    args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
    generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
    assert!(obj.is_none());
    assert_eq!(args.len(), 1);

    let (signature, _) = fun;
    let shape_param_ty = signature.args[0].ty;

    let shape_value = args[0].1.clone();
    let shape_ptr = shape_value
        .to_basic_value_enum(context, generator, shape_param_ty)?
        .into_pointer_value();

    let ndarray_name = args[0].0.map(|name| name.to_string());
    let float_ty = context.primitives.float;

    call_ndarray_impl(generator, context, float_ty, ndarray_name.as_deref(), shape_ptr)
}

View File

@ -5,7 +5,7 @@ use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier};
use super::{magic_methods::*, typedef::CallId};
use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext};
use crate::{symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::TopLevelContext};
use itertools::izip;
use nac3parser::ast::{
self,
@ -894,6 +894,53 @@ impl<'a> Inferencer<'a> {
}
}
// 1-argument ndarray n-dimensional creation functions
if [
"np_ndarray".into(),
"np_empty".into(),
].contains(id) && args.len() == 1 {
let ExprKind::List { elts, .. } = &args[0].node else {
return report_error("Expected List literal for first argument of np_ndarray", args[0].location)
};
let ndims = elts.len() as u64;
let arg0 = self.fold_expr(args.remove(0))?;
let ndims = self.unifier.get_fresh_literal(
vec![SymbolValue::U64(ndims)],
None,
);
let ret = self.unifier.add_ty(TypeEnum::TNDArray {
ty: self.primitives.float,
ndims
});
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "shape".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
ret,
vars: HashMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0],
keywords: vec![],
},
}))
}
Ok(None)
}

View File

@ -166,6 +166,9 @@ def patch(module):
module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1
# NumPy NDArray Functions
module.np_ndarray = np.ndarray
module.np_empty = np.empty
def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename)

View File

@ -0,0 +1,16 @@
# Sink for a value annotated as `ndarray[float, 1]`; the body does nothing. Calling it with the
# result of a constructor exercises the annotated parameter type.
def consume_ndarray_1(n: ndarray[float, 1]):
pass
# Builds an ndarray through `np_ndarray` from a one-element shape list and passes it to the sink.
def test_ndarray_ctor():
n = np_ndarray([1])
consume_ndarray_1(n)
# Same as above, but constructs the ndarray through `np_empty`.
def test_ndarray_empty():
n = np_empty([1])
consume_ndarray_1(n)
# Entry point: runs both ndarray construction tests and returns 0.
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
return 0