forked from M-Labs/nac3
1
0
Fork 0

core: Implement ndarray constructor and numpy.empty

This commit is contained in:
David Mak 2023-11-17 17:30:27 +08:00
parent afa7d9b100
commit 27fcf8926e
9 changed files with 619 additions and 4 deletions

View File

@ -196,4 +196,48 @@ double __nac3_j0(double x) {
}
return j0(x);
}
uint32_t __nac3_ndarray_calc_size(
const int32_t *list_data,
uint32_t list_len
) {
uint32_t num_elems = 1;
for (uint32_t i = 0; i < list_len; ++i) {
int32_t val = list_data[i];
__builtin_assume(val >= 0);
num_elems *= (uint32_t) list_data[i];
}
return num_elems;
}
uint64_t __nac3_ndarray_calc_size64(
const int32_t *list_data,
uint64_t list_len
) {
uint64_t num_elems = 1;
for (uint64_t i = 0; i < list_len; ++i) {
int32_t val = list_data[i];
__builtin_assume(val >= 0);
num_elems *= (uint64_t) list_data[i];
}
return num_elems;
}
void __nac3_ndarray_init_dims(
uint32_t *ndarray_dims,
const int32_t *shape_data,
uint32_t shape_len
) {
__builtin_memcpy(ndarray_dims, shape_data, shape_len * sizeof(int32_t));
}
void __nac3_ndarray_init_dims64(
uint64_t *ndarray_dims,
const int32_t *shape_data,
uint64_t shape_len
) {
for (uint64_t i = 0; i < shape_len; ++i) {
ndarray_dims[i] = (uint64_t) shape_data[i];
}
}

View File

@ -12,6 +12,9 @@ use inkwell::{
};
use nac3parser::ast::Expr;
#[cfg(debug_assertions)]
use inkwell::types::AnyTypeEnum;
#[must_use]
pub fn load_irrt(ctx: &Context) -> Module {
let bitcode_buf = MemoryBuffer::create_from_memory_range(
@ -546,3 +549,176 @@ pub fn call_j0<'ctx>(
.unwrap_left()
.into_float_value()
}
/// Checks whether the pointer `value` refers to a `list` in LLVM.
fn assert_is_list(value: PointerValue) -> PointerValue {
#[cfg(debug_assertions)]
{
let llvm_shape_ty = value.get_type().get_element_type();
let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else {
panic!("Expected struct type for `list` type, but got {llvm_shape_ty}")
};
assert_eq!(llvm_shape_ty.count_fields(), 2);
assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..))));
assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..))));
}
value
}
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
fn assert_is_ndarray(value: PointerValue) -> PointerValue {
#[cfg(debug_assertions)]
{
let llvm_ndarray_ty = value.get_type().get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}")
};
assert_eq!(llvm_ndarray_ty.count_fields(), 3);
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..))));
let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else {
unreachable!()
};
let BasicTypeEnum::PointerType(dims) = ndarray_dims else {
panic!("Expected pointer type for `list.1`, but got {ndarray_dims}")
};
assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..)));
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..))));
}
value
}
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the
/// calculated total size.
///
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM
/// representation of a `list`.
pub fn call_ndarray_calc_size<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
shape: PointerValue<'ctx>,
) -> IntValue<'ctx> {
assert_is_list(shape);
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[
llvm_pi32.into(),
llvm_usize.into(),
],
false,
);
let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name)
.unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let (
shape_data,
shape_len,
) = unsafe {
(
ctx.builder.build_in_bounds_gep(
shape,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
""
),
ctx.builder.build_in_bounds_gep(
shape,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
""
),
)
};
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
ctx.builder.build_load(shape_data, "").into(),
ctx.builder.build_load(shape_len, "").into(),
],
"",
)
.try_as_basic_value()
.unwrap_left()
.into_int_value()
}
/// Generates a call to `__nac3_ndarray_init_dims`.
///
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
/// `NDArray`.
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM
/// representation of a `list`.
pub fn call_ndarray_init_dims<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarray: PointerValue<'ctx>,
shape: PointerValue<'ctx>,
) {
assert_is_ndarray(ndarray);
assert_is_list(shape);
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_init_dims_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => "__nac3_ndarray_init_dims",
64 => "__nac3_ndarray_init_dims64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_init_dims_fn = ctx.module.get_function(ndarray_init_dims_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[
llvm_pusize.into(),
llvm_pi32.into(),
llvm_usize.into(),
],
false,
);
ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None)
});
let ndarray_dims = ctx.build_gep_and_load(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
None,
);
let shape_data = ctx.build_gep_and_load(
shape,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None
);
let ndarray_num_dims = ctx.build_gep_and_load(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_int_value();
ctx.builder.build_call(
ndarray_init_dims_fn,
&[
ndarray_dims.into(),
shape_data.into(),
ndarray_num_dims.into(),
],
"",
);
}

View File

@ -16,7 +16,7 @@ use inkwell::{
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue},
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate,
};
use nac3parser::ast::{
@ -405,6 +405,80 @@ pub fn gen_for<G: CodeGenerator>(
Ok(())
}
/// Generates a C-style `for` construct using lambdas, similar to the following C code:
///
/// ```c
/// for (x... = init(); cond(x...); update(x...)) {
/// body(x...);
/// }
/// ```
///
/// * `init` - A lambda containing IR statements declaring and initializing loop variables. The
/// return value is a [Clone] value which will be passed to the other lambdas.
/// * `cond` - A lambda containing IR statements checking whether the loop should continue
/// executing. The result value must be an `i1` indicating if the loop should continue.
/// * `body` - A lambda containing IR statements within the loop body.
/// * `update` - A lambda containing IR statements updating loop variables.
pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
init: InitFn,
cond: CondFn,
body: BodyFn,
update: UpdateFn,
) -> Result<(), String>
where
I: Clone,
InitFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{
let current = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap();
let init_bb = ctx.ctx.append_basic_block(current, "for.init");
// The BB containing the loop condition check
let cond_bb = ctx.ctx.append_basic_block(current, "for.cond");
let body_bb = ctx.ctx.append_basic_block(current, "for.body");
// The BB containing the increment expression
let update_bb = ctx.ctx.append_basic_block(current, "for.update");
let cont_bb = ctx.ctx.append_basic_block(current, "for.end");
// store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((update_bb, cont_bb));
ctx.builder.build_unconditional_branch(init_bb);
let loop_var = {
ctx.builder.position_at_end(init_bb);
let result = init(generator, ctx)?;
ctx.builder.build_unconditional_branch(cond_bb);
result
};
ctx.builder.position_at_end(cond_bb);
let cond = cond(generator, ctx, loop_var.clone())?;
assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
ctx.builder.build_conditional_branch(
cond,
body_bb,
cont_bb
);
ctx.builder.position_at_end(body_bb);
body(generator, ctx, loop_var.clone())?;
ctx.builder.build_unconditional_branch(update_bb);
ctx.builder.position_at_end(update_bb);
update(generator, ctx, loop_var)?;
ctx.builder.build_unconditional_branch(cond_bb);
ctx.builder.position_at_end(cont_bb);
ctx.loop_target = loop_bb;
Ok(())
}
/// See [`CodeGenerator::gen_while`].
pub fn gen_while<G: CodeGenerator>(
generator: &mut G,

View File

@ -13,11 +13,12 @@ use crate::{
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
toplevel::numpy::gen_ndarray_empty,
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{BasicType, BasicMetadataTypeEnum},
values::BasicMetadataValueEnum,
values::{BasicValue, BasicMetadataValueEnum},
FloatPredicate,
IntPredicate
};
@ -278,6 +279,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let boolean = primitives.0.bool;
let range = primitives.0.range;
let string = primitives.0.str;
let ndarray_float = {
let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0);
primitives.1.add_ty(ndarray_ty_enum)
};
let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 });
let num_ty = primitives.1.get_fresh_var_with_range(
&[int32, int64, float, boolean, uint32, uint64],
Some("N".into()),
@ -837,6 +843,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))),
loc: None,
})),
create_fn_by_codegen(
primitives,
&var_map,
"np_ndarray",
ndarray_float,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_empty(ctx, obj, fun, args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"np_empty",
ndarray_float,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(list_int32, "shape")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_empty(ctx, obj, fun, args, generator)
.map(|val| Some(val.as_basic_value_enum()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,

View File

@ -25,6 +25,7 @@ pub struct DefinitionId(pub usize);
pub mod builtins;
pub mod composer;
pub mod helper;
pub mod numpy;
pub mod type_annotation;
use composer::*;
use type_annotation::*;

View File

@ -0,0 +1,198 @@
use inkwell::{
IntPredicate,
types::BasicType,
values::PointerValue,
};
use nac3parser::ast::StrRef;
use crate::{
codegen::{
CodeGenContext,
CodeGenerator,
irrt::{call_ndarray_calc_size, call_ndarray_init_dims},
stmt::gen_for_callback
},
symbol_resolver::ValueEnum,
toplevel::DefinitionId,
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
///
/// * `elem_ty` - The element type of the NDArray.
/// * `var_name` - The variable name of the NDArray.
/// * `shape` - The `shape` parameter used to construct the NDArray.
fn call_ndarray_impl<'ctx, 'a>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
var_name: Option<&str>,
shape: PointerValue<'ctx>,
) -> Result<PointerValue<'ctx>, String> {
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
// Assert that all dimensions are non-negative
gen_for_callback(
generator,
ctx,
|_, ctx| {
let i = ctx.builder.build_alloca(llvm_usize, "");
ctx.builder.build_store(i, llvm_usize.const_zero());
Ok(i)
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.into_int_value();
let shape_len = ctx.build_gep_and_load(
shape,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
None,
).into_int_value();
Ok(ctx.builder.build_int_compare(IntPredicate::ULE, i, shape_len, ""))
},
|generator, ctx, i_addr| {
let shape_elems = ctx.build_gep_and_load(
shape,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None
).into_pointer_value();
let i = ctx.builder
.build_load(i_addr, "")
.into_int_value();
let shape_dim = ctx.build_gep_and_load(
shape_elems,
&[i],
None
).into_int_value();
let shape_dim_gez = ctx.builder.build_int_compare(
IntPredicate::SGE,
shape_dim,
llvm_i32.const_zero(),
""
);
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.into_int_value();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "");
ctx.builder.build_store(i_addr, i);
Ok(())
},
)?;
let ndarray = ctx.builder.build_alloca(
llvm_ndarray_t,
var_name.unwrap_or_default()
);
let num_dims = ctx.build_gep_and_load(
shape,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
None
).into_int_value();
let ndarray_num_dims = unsafe {
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
"",
)
};
ctx.builder.build_store(ndarray_num_dims, num_dims);
let ndarray_dims = unsafe {
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
"",
)
};
let ndarray_num_dims = ctx.build_gep_and_load(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
None,
).into_int_value();
ctx.builder.build_store(
ndarray_dims,
ctx.builder.build_array_alloca(
llvm_usize,
ndarray_num_dims,
"",
),
);
call_ndarray_init_dims(generator, ctx, ndarray, shape);
let ndarray_num_elems = call_ndarray_calc_size(generator, ctx, shape);
let ndarray_data = unsafe {
ctx.builder.build_in_bounds_gep(
ndarray,
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
"",
)
};
ctx.builder.build_store(
ndarray_data,
ctx.builder.build_array_alloca(
llvm_ndarray_data_t,
ndarray_num_elems,
"",
),
);
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx, 'a>(
context: &mut CodeGenContext<'ctx, 'a>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
let shape_arg_name = args[0].0;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_impl(
generator,
context,
context.primitives.float,
shape_arg_name.map(|name| name.to_string()).as_deref(),
shape_arg.into_pointer_value(),
)
}

View File

@ -5,7 +5,7 @@ use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier};
use super::{magic_methods::*, typedef::CallId};
use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext};
use crate::{symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::TopLevelContext};
use itertools::izip;
use nac3parser::ast::{
self,
@ -894,6 +894,53 @@ impl<'a> Inferencer<'a> {
}
}
// 1-argument ndarray n-dimensional creation functions
if [
"np_ndarray".into(),
"np_empty".into(),
].contains(id) && args.len() == 1 {
let ExprKind::List { elts, .. } = &args[0].node else {
return report_error("Expected List literal for first argument of np_ndarray", args[0].location)
};
let ndims = elts.len() as u64;
let arg0 = self.fold_expr(args.remove(0))?;
let ndims = self.unifier.get_fresh_literal(
vec![SymbolValue::U64(ndims)],
None,
);
let ret = self.unifier.add_ty(TypeEnum::TNDArray {
ty: self.primitives.float,
ndims
});
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "shape".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
ret,
vars: HashMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0],
keywords: vec![],
},
}))
}
Ok(None)
}

View File

@ -5,11 +5,12 @@ import importlib.util
import importlib.machinery
import math
import numpy as np
import numpy.typing as npt
import pathlib
from numpy import int32, int64, uint32, uint64
from scipy import special
from typing import TypeVar, Generic, Literal
from typing import TypeVar, Generic, Literal, Union
T = TypeVar('T')
class Option(Generic[T]):
@ -50,6 +51,13 @@ class _ConstGenericMarker:
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericMarker, constraint)
N = TypeVar("N", bound=np.uint64)
class _NDArrayDummy(Generic[T, N]):
pass
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
def round_away_zero(x):
if x >= 0.0:
return math.floor(x + 0.5)
@ -124,6 +132,16 @@ def patch(module):
module.ceil64 = math.ceil
module.np_ceil = np.ceil
# NumPy ndarray functions
module.ndarray = NDArray
module.np_ndarray = np.ndarray
module.np_empty = np.empty
module.np_zeros = np.zeros
module.np_ones = np.ones
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
# NumPy Math functions
module.np_isnan = np.isnan
module.np_isinf = np.isinf
@ -166,6 +184,9 @@ def patch(module):
module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1
# NumPy NDArray Functions
module.np_ndarray = np.ndarray
module.np_empty = np.empty
def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename)

View File

@ -0,0 +1,22 @@
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass
def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]):
pass
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
pass
def test_ndarray_ctor():
n = np_ndarray([1])
consume_ndarray_1(n)
def test_ndarray_empty():
n = np_empty([1])
consume_ndarray_1(n)
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
return 0