Compare commits

...

16 Commits

Author SHA1 Message Date
David Mak d2ce0679ed WIP 2024-03-22 17:16:03 +08:00
David Mak aa673fce4e core: Implement elementwise binary operators
Including immediate variants of these operators.
2024-03-22 17:16:03 +08:00
David Mak ddfd19d00c core: Add handling of ndarrays in gen_binop_expr 2024-03-22 17:16:03 +08:00
David Mak 4887cd8007 core: Implement calculations for broadcasting ndarrays 2024-03-22 17:15:49 +08:00
David Mak 876e850d71 core: Extract codegen portion of gen_binop_expr
This allows binops to be generated internally using LLVM values as
input. Required in a future change.
2024-03-22 17:11:16 +08:00
David Mak 1de2e9a4be core: Remove ArrayValue variants of accessors 2024-03-22 17:11:16 +08:00
David Mak 2b0beea8c0 core: Use more typed slices in APIs 2024-03-22 17:11:16 +08:00
David Mak 5778de02fc core: Fix index-based operations not returning i32 2024-03-22 17:11:15 +08:00
David Mak 13f06f3e29 core: Refactor VarMap to IndexMap
This is the only Map I can find that preserves insertion order while
also deduplicating elements by key.
2024-03-22 15:51:23 +08:00
David Mak f0da9c0283 core: Add ArrayLikeValue
For exposing LLVM values that can be accessed like an array.
2024-03-22 15:51:06 +08:00
David Mak 2c4bf3ce59 core: Allow unsized CodeGenerator to be passed to some codegen functions
Enables codegen_callback to call these codegen functions as well.
2024-03-22 15:07:28 +08:00
David Mak e980f19c93 core: Simplify typed value assertions 2024-03-22 15:07:28 +08:00
David Mak cfbc37c1ed core: Add gen_for_callback_incrementing
Simplifies generation of monotonically increasing for loops.
2024-03-22 15:07:28 +08:00
David Mak 50264e8750 core: Add missing unchecked accessors for NDArrayDimsProxy 2024-03-22 15:07:28 +08:00
David Mak 1b77e62901 core: Split numpy into codegen and toplevel 2024-03-22 15:07:28 +08:00
David Mak fd44ee6887 core: Apply clippy suggestions 2024-03-22 15:07:23 +08:00
28 changed files with 2653 additions and 1461 deletions

1
Cargo.lock generated
View File

@ -616,6 +616,7 @@ name = "nac3core"
version = "0.1.0"
dependencies = [
"crossbeam",
"indexmap 2.2.5",
"indoc",
"inkwell",
"insta",

View File

@ -5,7 +5,7 @@ use nac3core::{
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelDef,
},
typecheck::{
@ -654,7 +654,7 @@ impl InnerResolver {
}
}
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty);
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
if len == 0 {
assert!(matches!(

View File

@ -7,6 +7,7 @@ edition = "2021"
[dependencies]
itertools = "0.12"
crossbeam = "0.8"
indexmap = "2.2"
parking_lot = "0.12"
rayon = "1.8"
nac3parser = { path = "../nac3parser" }

View File

@ -21,7 +21,7 @@ fn main() {
match env::var("PROFILE").as_deref() {
Ok("debug") => "-O0",
Ok("release") => "-O3",
flavor => panic!("Unknown or missing build flavor {:?}", flavor),
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
},
"-emit-llvm",
"-S",

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,7 @@ use crate::{
use nac3parser::ast::StrRef;
use std::collections::HashMap;
use indexmap::IndexMap;
pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>,
@ -50,7 +51,7 @@ pub enum ConcreteTypeEnum {
TObj {
obj_id: DefinitionId,
fields: HashMap<StrRef, (ConcreteType, bool)>,
params: HashMap<u32, ConcreteType>,
params: IndexMap<u32, ConcreteType>,
},
TVirtual {
ty: ConcreteType,

View File

@ -2,13 +2,21 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{
codegen::{
classes::{ListValue, NDArrayValue, RangeValue},
classes::{
ArrayLikeIndexer,
ArrayLikeValue,
ListValue,
NDArrayValue,
RangeValue,
UntypedArrayLikeAccessor,
},
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check,
get_llvm_type,
get_llvm_abi_type,
irrt::*,
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
numpy,
stmt::{gen_raise, gen_var},
CodeGenContext, CodeGenTask,
},
@ -16,7 +24,7 @@ use crate::{
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::make_ndarray_ty,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelDef,
},
typecheck::{
@ -35,6 +43,7 @@ use itertools::{chain, izip, Itertools, Either};
use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
};
use crate::codegen::classes::ArraySliceValue;
use super::{CodeGenerator, llvm_intrinsics::call_memcpy_generic, need_sret};
@ -52,7 +61,7 @@ pub fn get_subst_key(
params.clone()
})
.unwrap_or_default();
vars.extend(fun_vars.iter());
vars.extend(fun_vars);
let sorted = vars.keys().filter(|id| filter.map_or(true, |v| v.contains(id))).sorted();
sorted
.map(|id| {
@ -103,9 +112,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
index
}
pub fn gen_symbol_val(
pub fn gen_symbol_val<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
val: &SymbolValue,
ty: Type,
) -> BasicValueEnum<'ctx> {
@ -174,9 +183,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
/// See [`get_llvm_type`].
pub fn get_llvm_type(
pub fn get_llvm_type<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
ty: Type,
) -> BasicTypeEnum<'ctx> {
get_llvm_type(
@ -191,9 +200,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
/// See [`get_llvm_abi_type`].
pub fn get_llvm_abi_type(
pub fn get_llvm_abi_type<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
ty: Type,
) -> BasicTypeEnum<'ctx> {
get_llvm_abi_type(
@ -209,9 +218,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
/// Generates an LLVM variable for a [constant value][value] with a given [type][ty].
pub fn gen_const(
pub fn gen_const<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
value: &Constant,
ty: Type,
) -> Option<BasicValueEnum<'ctx>> {
@ -291,9 +300,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
/// Generates a binary operation `op` between two integral operands `lhs` and `rhs`.
pub fn gen_int_ops(
pub fn gen_int_ops<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
op: &Operator,
lhs: BasicValueEnum<'ctx>,
rhs: BasicValueEnum<'ctx>,
@ -492,17 +501,21 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
/// Helper function for generating a LLVM variable storing a [String].
pub fn gen_string<S: Into<String>>(
pub fn gen_string<G, S>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
s: S,
) -> BasicValueEnum<'ctx> {
) -> BasicValueEnum<'ctx>
where
G: CodeGenerator + ?Sized,
S: Into<String>,
{
self.gen_const(generator, &Constant::Str(s.into()), self.primitives.str).unwrap()
}
pub fn raise_exn(
pub fn raise_exn<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
name: &str,
msg: BasicValueEnum<'ctx>,
params: [Option<IntValue<'ctx>>; 3],
@ -546,9 +559,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
gen_raise(generator, self, Some(&zelf.into()), loc);
}
pub fn make_assert(
pub fn make_assert<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
cond: IntValue<'ctx>,
err_name: &str,
err_msg: &str,
@ -559,9 +572,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
self.make_assert_impl(generator, cond, err_name, err_msg, params, loc);
}
pub fn make_assert_impl(
pub fn make_assert_impl<G: CodeGenerator + ?Sized>(
&mut self,
generator: &mut dyn CodeGenerator,
generator: &mut G,
cond: IntValue<'ctx>,
err_name: &str,
err_msg: BasicValueEnum<'ctx>,
@ -878,7 +891,7 @@ pub fn destructure_range<'ctx>(
/// Returns an instance of [`PointerValue`] pointing to the List structure. The List structure is
/// defined as `type { ty*, size_t }` in LLVM, where the first element stores the pointer to the
/// data, and the second element stores the size of the List.
pub fn allocate_list<'ctx, G: CodeGenerator>(
pub fn allocate_list<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
@ -978,7 +991,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
list_alloc_size.into_int_value(),
Some("listcomp.addr")
);
list_content = list.data().as_ptr_value(ctx);
list_content = list.data().base_ptr(ctx, generator);
let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap();
ctx.builder
@ -1011,7 +1024,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
)
.into_int_value();
list = allocate_list(generator, ctx, elem_ty, length, Some("listcomp"));
list_content = list.data().as_ptr_value(ctx);
list_content = list.data().base_ptr(ctx, generator);
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?;
// counter = -1
ctx.builder.build_store(counter, size_t.const_int(u64::MAX, true)).unwrap();
@ -1078,34 +1091,22 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
Ok(Some(list.as_ptr_value().into()))
}
/// Generates LLVM IR for a [binary operator expression][expr].
///
/// * `left` - The left-hand side of the binary operator.
/// * `op` - The operator applied on the operands.
/// * `right` - The right-hand side of the binary operator.
/// * `loc` - The location of the full expression.
/// * `is_aug_assign` - Whether the binary operator expression is also an assignment operator.
pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
/// Generates LLVM IR for a binary operator expression using the [`Type`] and
/// [LLVM value][`BasicValueEnum`] of the operands.
pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
left: &Expr<Option<Type>>,
left: (&Option<Type>, BasicValueEnum<'ctx>),
op: &Operator,
right: &Expr<Option<Type>>,
right: (&Option<Type>, BasicValueEnum<'ctx>),
loc: Location,
is_aug_assign: bool,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let ty1 = ctx.unifier.get_representative(left.custom.unwrap());
let ty2 = ctx.unifier.get_representative(right.custom.unwrap());
let left_val = if let Some(v) = generator.gen_expr(ctx, left)? {
v.to_basic_value_enum(ctx, generator, left.custom.unwrap())?
} else {
return Ok(None)
};
let right_val = if let Some(v) = generator.gen_expr(ctx, right)? {
v.to_basic_value_enum(ctx, generator, right.custom.unwrap())?
} else {
return Ok(None)
};
let (left_ty, left_val) = left;
let (right_ty, right_val) = right;
let ty1 = ctx.unifier.get_representative(left_ty.unwrap());
let ty2 = ctx.unifier.get_representative(right_ty.unwrap());
// we can directly compare the types, because we've got their representatives
// which would be unchanged until further unification, which we would never do
@ -1129,8 +1130,46 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
Some("f_pow_i")
);
Ok(Some(res.into()))
} else if matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) && matches!(&*ctx.unifier.get_ty(ty2), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
let left_val = NDArrayValue::from_ptr_val(
left_val.into_pointer_value(),
llvm_usize,
None
);
let right_val = NDArrayValue::from_ptr_val(
right_val.into_pointer_value(),
llvm_usize,
None
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None },
left_val,
right_val,
|generator, ctx, elem_ty, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), lhs),
op,
(&Some(elem_ty), rhs),
ctx.current_loc,
is_aug_assign,
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
},
)?;
Ok(Some(res.as_ptr_value().into()))
} else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left.custom.unwrap());
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
unreachable!("must be tobj")
};
@ -1150,7 +1189,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
let signature = if let Some(call) = ctx.calls.get(&loc.into()) {
ctx.unifier.get_call_signature(*call).unwrap()
} else {
let left_enum_ty = ctx.unifier.get_ty_immutable(left.custom.unwrap());
let left_enum_ty = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else {
unreachable!("must be tobj")
};
@ -1175,13 +1214,51 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
generator
.gen_call(
ctx,
Some((left.custom.unwrap(), left_val.into())),
Some((left_ty.unwrap(), left_val.into())),
(&signature, fun_id),
vec![(None, right_val.into())],
).map(|f| f.map(Into::into))
}
}
/// Generates LLVM IR for a [binary operator expression][expr].
///
/// * `left` - The left-hand side of the binary operator.
/// * `op` - The operator applied on the operands.
/// * `right` - The right-hand side of the binary operator.
/// * `loc` - The location of the full expression.
/// * `is_aug_assign` - Whether the binary operator expression is also an assignment operator.
pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
left: &Expr<Option<Type>>,
op: &Operator,
right: &Expr<Option<Type>>,
loc: Location,
is_aug_assign: bool,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let left_val = if let Some(v) = generator.gen_expr(ctx, left)? {
v.to_basic_value_enum(ctx, generator, left.custom.unwrap())?
} else {
return Ok(None)
};
let right_val = if let Some(v) = generator.gen_expr(ctx, right)? {
v.to_basic_value_enum(ctx, generator, right.custom.unwrap())?
} else {
return Ok(None)
};
gen_binop_expr_with_values(
generator,
ctx,
(&left.custom, left_val),
op,
(&right.custom, right_val),
loc,
is_aug_assign,
)
}
/// Generates code for a subscript expression on an `ndarray`.
///
/// * `ty` - The `Type` of the `NDArray` elements.
@ -1254,12 +1331,14 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
} else {
return Ok(None)
};
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
ctx.builder.build_store(index_addr, index).unwrap();
Ok(Some(v.data()
.get_const(
.get(
ctx,
generator,
ctx.ctx.i32_type().const_array(&[index]),
ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
None,
)
.into()))
@ -1275,6 +1354,8 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
} else {
return Ok(None)
};
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
ctx.builder.build_store(index_addr, index).unwrap();
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
@ -1300,15 +1381,17 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = v.dim_sizes().ptr_offset(
ctx,
generator,
llvm_usize.const_int(1, false),
None,
);
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
ctx,
generator,
llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().as_ptr_value(ctx),
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
@ -1320,20 +1403,19 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ndarray.load_ndims(ctx),
ndarray.dim_sizes().as_ptr_value(ctx),
&ndarray.dim_sizes().as_slice_value(ctx, generator),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset_const(
let v_data_src_ptr = v.data().ptr_offset(
ctx,
generator,
ctx.ctx.i32_type().const_array(&[index]),
ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
None
);
call_memcpy_generic(
ctx,
ndarray.data().as_ptr_value(ctx),
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "")
@ -1971,7 +2053,6 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (ty, ndims) = params.iter()
.sorted_by_key(|(var_id, _)| *var_id)
.map(|(_, ty)| ty)
.collect_tuple()
.unwrap();

View File

@ -1,5 +1,5 @@
use crate::{
codegen::{expr::*, stmt::*, bool_to_i1, bool_to_i8, CodeGenContext},
codegen::{classes::ArraySliceValue, expr::*, stmt::*, bool_to_i1, bool_to_i8, CodeGenContext},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
@ -99,8 +99,8 @@ pub trait CodeGenerator {
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
name: Option<&str>,
) -> Result<PointerValue<'ctx>, String> {
name: Option<&'ctx str>,
) -> Result<ArraySliceValue<'ctx>, String> {
gen_array_var(ctx, ty, size, name)
}

View File

@ -8,6 +8,8 @@ typedef unsigned _BitInt(64) uint64_t;
# define MAX(a, b) (a > b ? a : b)
# define MIN(a, b) (a > b ? b : a)
# define NULL ((void *) 0)
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
@ -243,13 +245,13 @@ void __nac3_ndarray_calc_nd_indices64(
uint64_t index,
const uint64_t* dims,
uint64_t num_dims,
uint64_t* idxs
uint32_t* idxs
) {
uint64_t stride = 1;
for (uint64_t dim = 0; dim < num_dims; dim++) {
uint64_t i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
idxs[i] = (uint32_t) ((index / stride) % dims[i]);
stride *= dims[i];
}
}
@ -293,3 +295,87 @@ uint64_t __nac3_ndarray_flatten_index64(
}
return idx;
}
void __nac3_ndarray_calc_broadcast(
const uint32_t *lhs_dims,
uint32_t lhs_ndims,
const uint32_t *rhs_dims,
uint32_t rhs_ndims,
uint32_t *out_dims
) {
uint32_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (uint32_t i = 0; i < max_ndims; ++i) {
uint32_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
uint32_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
uint32_t *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == NULL) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == NULL) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
void __nac3_ndarray_calc_broadcast64(
const uint64_t *lhs_dims,
uint64_t lhs_ndims,
const uint64_t *rhs_dims,
uint64_t rhs_ndims,
uint64_t *out_dims
) {
uint64_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (uint64_t i = 0; i < max_ndims; ++i) {
uint64_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
uint64_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
uint64_t *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == NULL) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == NULL) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
void __nac3_ndarray_calc_broadcast_idx(
const uint32_t *src_dims,
uint32_t src_ndims,
const uint32_t *in_idx,
uint32_t *out_idx
) {
for (uint32_t i = 0; i < src_ndims; ++i) {
uint32_t src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
void __nac3_ndarray_calc_broadcast_idx64(
const uint64_t *src_dims,
uint64_t src_ndims,
const uint32_t *in_idx,
uint32_t *out_idx
) {
for (uint64_t i = 0; i < src_ndims; ++i) {
uint64_t src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : (uint32_t) in_idx[src_i];
}
}

View File

@ -1,9 +1,18 @@
use crate::typecheck::typedef::Type;
use super::{
classes::{ListValue, NDArrayValue},
classes::{
ArrayLikeIndexer,
ArrayLikeValue,
ArraySliceValue,
ListValue,
NDArrayValue,
TypedArrayLikeAdapter,
UntypedArrayLikeAccessor,
},
CodeGenContext,
CodeGenerator,
llvm_intrinsics,
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
@ -11,7 +20,7 @@ use inkwell::{
memory_buffer::MemoryBuffer,
module::Module,
types::{BasicTypeEnum, IntType},
values::{ArrayValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
@ -39,8 +48,8 @@ pub fn load_irrt(ctx: &Context) -> Module {
// repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>,
exp: IntValue<'ctx>,
@ -81,8 +90,8 @@ pub fn integer_power<'ctx>(
.unwrap()
}
pub fn calculate_len_for_slice_range<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>,
end: IntValue<'ctx>,
@ -303,8 +312,8 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
/// This function handles 'end' **inclusively**.
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
dest_arr: ListValue<'ctx>,
@ -338,7 +347,7 @@ pub fn list_slice_assignment<'ctx>(
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().as_ptr_value(ctx);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr = ctx.builder.build_pointer_cast(
dest_arr_ptr,
elem_ptr_type,
@ -346,7 +355,7 @@ pub fn list_slice_assignment<'ctx>(
).unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().as_ptr_value(ctx);
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let src_arr_ptr = ctx.builder.build_pointer_cast(
src_arr_ptr,
elem_ptr_type,
@ -468,8 +477,8 @@ pub fn list_slice_assignment<'ctx>(
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
@ -489,8 +498,8 @@ pub fn call_isinf<'ctx>(
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
@ -574,12 +583,14 @@ pub fn call_j0<'ctx>(
///
/// * `num_dims` - An [`IntValue`] containing the number of dimensions.
/// * `dims` - A [`PointerValue`] to an array containing the size of each dimension.
pub fn call_ndarray_calc_size<'ctx>(
generator: &dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
num_dims: IntValue<'ctx>,
dims: PointerValue<'ctx>,
) -> IntValue<'ctx> {
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>, {
let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
@ -606,8 +617,8 @@ pub fn call_ndarray_calc_size<'ctx>(
.build_call(
ndarray_calc_size_fn,
&[
dims.into(),
num_dims.into(),
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
],
"",
)
@ -617,20 +628,22 @@ pub fn call_ndarray_calc_size<'ctx>(
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`.
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx>(
generator: &dyn CodeGenerator,
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> PointerValue<'ctx> {
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
@ -644,7 +657,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_pi32.into(),
],
false,
);
@ -656,7 +669,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca(
llvm_usize,
llvm_i32,
ndarray_num_dims,
"",
).unwrap();
@ -666,7 +679,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.as_ptr_value(ctx).into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
@ -674,16 +687,22 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
)
.unwrap();
indices
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx>(
generator: &dyn CodeGenerator,
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: PointerValue<'ctx>,
indices_size: IntValue<'ctx>,
) -> IntValue<'ctx> {
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>, {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
@ -691,14 +710,14 @@ fn call_ndarray_flatten_index_impl<'ctx>(
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.get_type().get_element_type())
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices_size.get_type().get_bit_width(),
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
@ -729,10 +748,10 @@ fn call_ndarray_flatten_index_impl<'ctx>(
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.as_ptr_value(ctx).into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
indices_size.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
@ -750,63 +769,169 @@ fn call_ndarray_flatten_index_impl<'ctx>(
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx>(
generator: &dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: ListValue<'ctx>,
) -> IntValue<'ctx> {
let indices_size = indices.load_size(ctx, None);
let indices_data = indices.data();
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices_data.as_ptr_value(ctx),
indices_size,
)
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index_const<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: ArrayValue<'ctx>,
) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.get_type().len();
let indices_alloca = generator.gen_array_var_alloc(
ctx,
indices.get_type().get_element_type(),
llvm_usize.const_int(indices_size as u64, false),
None
).unwrap();
for i in 0..indices_size {
let v = ctx.builder.build_extract_value(indices, i, "")
.unwrap()
.into_int_value();
let elem_ptr = unsafe {
ctx.builder.build_in_bounds_gep(
indices_alloca,
&[ctx.ctx.i32_type().const_int(i as u64, false)],
""
)
}.unwrap();
ctx.builder.build_store(elem_ptr, v).unwrap();
}
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>, {
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices_alloca,
llvm_usize.const_int(indices_size as u64, false),
indices,
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
// TODO: Generate assertion checks for whether each dimension is compatible
// gen_for_callback_incrementing(
// generator,
// ctx,
// llvm_usize.const_zero(),
// (max_ndims, false),
// |generator, ctx, idx| {
// let lhs_dim_sz =
//
// let lhs_elem = lhs.get_dims().get(ctx, generator, idx, None);
// let rhs_elem = rhs.get_dims().get(ctx, generator, idx, None);
//
//
// },
// llvm_usize.const_int(1, false),
// ).unwrap();
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pi32.into(),
llvm_pi32.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
// TODO: Assertions
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(
ctx,
generator,
llvm_usize.const_zero(),
None
)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
array_dims.into(),
array_ndims.into(),
broadcast_idx_ptr.into(),
out_idx.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}

View File

@ -2,7 +2,7 @@ use crate::{
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{
helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_tvars,
numpy::unpack_ndarray_var_tys,
TopLevelContext,
TopLevelDef,
},
@ -45,6 +45,7 @@ pub mod expr;
mod generator;
pub mod irrt;
pub mod llvm_intrinsics;
pub mod numpy;
pub mod stmt;
#[cfg(test)]
@ -415,10 +416,10 @@ pub struct CodeGenTask {
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
/// would be represented by an `i8`.
#[allow(clippy::too_many_arguments)]
fn get_llvm_type<'ctx>(
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context,
module: &Module<'ctx>,
generator: &mut dyn CodeGenerator,
generator: &mut G,
unifier: &mut Unifier,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -450,7 +451,7 @@ fn get_llvm_type<'ctx>(
TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let llvm_usize = generator.get_size_type(ctx);
let (dtype, _) = unpack_ndarray_tvars(unifier, ty);
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let element_type = get_llvm_type(
ctx,
module,
@ -553,10 +554,10 @@ fn get_llvm_type<'ctx>(
/// be byte-aligned for the variable to be addressable in memory, whereas there is no such
/// restriction for ABI representations.
#[allow(clippy::too_many_arguments)]
fn get_llvm_abi_type<'ctx>(
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context,
module: &Module<'ctx>,
generator: &mut dyn CodeGenerator,
generator: &mut G,
unifier: &mut Unifier,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,

View File

@ -0,0 +1,928 @@
use inkwell::{
IntPredicate,
types::BasicType,
values::{AggregateValueEnum, ArrayValue, BasicValueEnum, IntValue, PointerValue}
};
use nac3parser::ast::StrRef;
use crate::{
codegen::{
classes::{
ArrayLikeIndexer,
ArrayLikeValue,
ListValue,
NDArrayValue,
TypedArrayLikeAccessor,
TypedArrayLikeAdapter,
UntypedArrayLikeAccessor,
},
CodeGenContext,
CodeGenerator,
irrt::{
call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index,
call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
},
llvm_intrinsics::call_memcpy_generic,
stmt::gen_for_callback_incrementing,
},
symbol_resolver::ValueEnum,
toplevel::{
DefinitionId,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
typecheck::typedef::{FunSignature, Type},
};
/// Creates an `NDArray` instance from a dynamic shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`.
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
shape: &V,
shape_len_fn: LenFn,
shape_data_fn: DataFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
DataFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
{
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
// Assert that all dimensions are non-negative
let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(shape_len, false),
|generator, ctx, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
let shape_dim_gez = ctx.builder
.build_int_compare(IntPredicate::SGE, shape_dim, shape_dim.get_type().const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None,
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let num_dims = shape_len_fn(generator, ctx, shape)?;
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
// Copy the dimension sizes from shape to ndarray.dims
let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(shape_len, false),
|generator, ctx, i| {
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
let shape_dim = ctx.builder
.build_int_z_extend(shape_dim, llvm_usize, "")
.unwrap();
let ndarray_pdim = unsafe {
ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, i, None)
};
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
Ok(ndarray)
}
/// Creates an `NDArray` instance from a constant shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented as an LLVM [`ArrayValue`].
fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ArrayValue<'ctx>
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
for i in 0..shape.get_type().len() {
let shape_dim = ctx.builder
.build_extract_value(shape, i, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_dim_gez = ctx.builder
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
}
let ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None,
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false);
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
for i in 0..shape.get_type().len() {
let ndarray_dim = ndarray
.dim_sizes()
.ptr_offset(ctx, generator, llvm_usize.const_int(i as u64, true), None);
let shape_dim = ctx.builder.build_extract_value(shape, i, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
}
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
Ok(ndarray)
}
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
ctx.ctx.i32_type().const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
ctx.ctx.i64_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "")
} else {
unreachable!()
}
}
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1")
} else {
unreachable!()
}
}
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape,
|_, ctx, shape| {
Ok(shape.load_size(ctx, None))
},
|generator, ctx, shape, idx| {
Ok(shape.data().get(ctx, generator, idx, None).into_int_value())
},
)
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
/// its input.
fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarray: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
);
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(ndarray_num_elems, false),
|generator, ctx, i| {
let elem = unsafe {
ndarray.data().ptr_offset_unchecked(ctx, generator, i, None)
};
let value = value_fn(generator, ctx, i)?;
ctx.builder.build_store(elem, value).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
/// as its input.
fn ndarray_fill_indexed<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, idx| {
let indices = call_ndarray_calc_nd_indices(
generator,
ctx,
idx,
ndarray,
);
value_fn(generator, ctx, indices)
}
)
}
/// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed
/// element from two other `NDArray` as its input.
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: NDArrayValue<'ctx>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_indexed(
generator,
ctx,
res,
|generator, ctx, idx| {
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, &idx);
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, &idx);
let elem = unsafe {
(
lhs.data().get_unchecked(ctx, generator, lhs_idx, None),
rhs.data().get_unchecked(ctx, generator, rhs_idx, None),
)
};
debug_assert_eq!(elem.0.get_type(), elem.1.get_type());
value_fn(generator, ctx, elem_ty, elem)
},
)?;
Ok(res)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
ctx.primitives.float,
ctx.primitives.bool,
ctx.primitives.str,
];
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, _| {
let value = ndarray_zero_value(generator, ctx, elem_ty);
Ok(value)
}
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
ctx.primitives.float,
ctx.primitives.bool,
ctx.primitives.str,
];
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, _| {
let value = ndarray_one_value(generator, ctx, elem_ty);
Ok(value)
}
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.full`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
fill_value: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, _| {
let value = if fill_value.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
fill_value.into_pointer_value(),
fill_value.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if fill_value.is_int_value() || fill_value.is_float_value() {
fill_value
} else {
unreachable!()
};
Ok(value)
}
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
///
/// * `elem_ty` - The element type of the `NDArray`.
fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
nrows: IntValue<'ctx>,
ncols: IntValue<'ctx>,
offset: IntValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_usize_2 = llvm_usize.array_type(2);
let shape_addr = generator.gen_var_alloc(ctx, llvm_usize_2.into(), None)?;
let shape = ctx.builder.build_load(shape_addr, "")
.map(BasicValueEnum::into_array_value)
.unwrap();
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap();
let shape = ctx.builder
.build_insert_value(shape, nrows, 0, "")
.map(AggregateValueEnum::into_array_value)
.unwrap();
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap();
let shape = ctx.builder
.build_insert_value(shape, ncols, 1, "")
.map(AggregateValueEnum::into_array_value)
.unwrap();
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, shape)?;
ndarray_fill_indexed(
generator,
ctx,
ndarray,
|generator, ctx, indices| {
let (row, col) = unsafe {
(
indices.get_typed_unchecked(ctx, generator, llvm_usize.const_zero(), None),
indices.get_typed_unchecked(ctx, generator, llvm_usize.const_int(1, false), None),
)
};
let col_with_offset = ctx.builder
.build_int_add(
col,
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(),
"",
)
.unwrap();
let is_on_diag = ctx.builder
.build_int_compare(IntPredicate::EQ, row, col_with_offset, "")
.unwrap();
let zero = ndarray_zero_value(generator, ctx, elem_ty);
let one = ndarray_one_value(generator, ctx, elem_ty);
let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap();
Ok(value)
},
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
///
/// * `elem_ty` - The element type of the `NDArray`.
fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
this: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i1 = ctx.ctx.bool_type();
let ndarray = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&this,
|_, ctx, shape| {
Ok(shape.load_ndims(ctx))
},
|generator, ctx, shape, idx| {
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, idx, None)) }
},
)?;
let len = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
);
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
let len_bytes = ctx.builder
.build_int_mul(
len,
sizeof_ty.size_of().unwrap(),
"",
)
.unwrap();
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
this.data().base_ptr(ctx, generator),
len_bytes,
llvm_i1.const_zero(),
);
Ok(ndarray)
}
/// LLVM-typed implementation for computing elementwise binary operations.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`.
/// * `value_fn` - Function mapping the two input elements into the result.
pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
this: NDArrayValue<'ctx>,
other: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, this, other);
let ndarray = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray_dims,
|generator, ctx, v| {
Ok(v.size(ctx, generator))
},
|generator, ctx, v, idx| {
unsafe {
Ok(v.get_typed_unchecked(ctx, generator, idx, None))
}
},
).unwrap()
});
ndarray_broadcast_fill(
generator,
ctx,
elem_ty,
ndarray,
this,
other,
|generator, ctx, elem_ty, elems| {
value_fn(generator, ctx, elem_ty, elems)
},
)?;
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_empty_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.zeros`.
pub fn gen_ndarray_zeros<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_zeros_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.ones`.
pub fn gen_ndarray_ones<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_ones_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.full`.
pub fn gen_ndarray_full<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
let fill_value_ty = fun.0.args[1].ty;
let fill_value_arg = args[1].1.clone()
.to_basic_value_enum(context, generator, fill_value_ty)?;
call_ndarray_full_impl(
generator,
context,
fill_value_ty,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
fill_value_arg,
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.eye`.
pub fn gen_ndarray_eye<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
let nrows_ty = fun.0.args[0].ty;
let nrows_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, nrows_ty)?;
let ncols_ty = fun.0.args[1].ty;
let ncols_arg = args.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty))
.unwrap_or_else(|| {
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
})?;
let offset_ty = fun.0.args[2].ty;
let offset_arg = args.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty))
.unwrap_or_else(|| {
Ok(context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
offset_ty
))
})?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
nrows_arg.into_int_value(),
ncols_arg.into_int_value(),
offset_arg.into_int_value(),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.identity`.
pub fn gen_ndarray_identity<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let n_ty = fun.0.args[0].ty;
let n_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, n_ty)?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
n_arg.into_int_value(),
n_arg.into_int_value(),
llvm_usize.const_zero(),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.copy`.
pub fn gen_ndarray_copy<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
_fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_some());
assert!(args.is_empty());
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_arg = obj
.as_ref()
.unwrap()
.1
.clone()
.to_basic_value_enum(context, generator, this_ty)?;
ndarray_copy_impl(
generator,
context,
this_elem_ty,
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.fill`.
pub fn gen_ndarray_fill<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<(), String> {
assert!(obj.is_some());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj.as_ref().unwrap().1.clone()
.to_basic_value_enum(context, generator, this_ty)?
.into_pointer_value();
let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, value_ty)?;
ndarray_fill_flattened(
generator,
context,
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
value_arg.into_pointer_value(),
value_arg.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
unreachable!()
};
Ok(value)
}
)?;
Ok(())
}

View File

@ -6,14 +6,14 @@ use super::{
};
use crate::{
codegen::{
classes::{ListValue, RangeValue},
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::gen_binop_expr,
gen_in_range_check,
},
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_tvars,
numpy::unpack_ndarray_var_tys,
TopLevelDef,
},
typecheck::typedef::{FunSignature, Type, TypeEnum},
@ -65,8 +65,8 @@ pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>(
ctx: &mut CodeGenContext<'ctx, 'a>,
ty: T,
size: IntValue<'ctx>,
name: Option<&str>,
) -> Result<PointerValue<'ctx>, String> {
name: Option<&'ctx str>,
) -> Result<ArraySliceValue<'ctx>, String> {
// Restore debug location
let di_loc = ctx.debug_info.0.create_debug_location(
ctx.ctx,
@ -84,6 +84,7 @@ pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>(
ctx.builder.set_current_debug_location(di_loc);
let ptr = ctx.builder.build_array_alloca(ty, size, name.unwrap_or("")).unwrap();
let ptr = ArraySliceValue::from_ptr_val(ptr, size, name);
ctx.builder.position_at_end(current);
ctx.builder.set_current_debug_location(di_loc);
@ -250,7 +251,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
unpack_ndarray_tvars(&mut ctx.unifier, target.custom.unwrap()).0
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
}
_ => unreachable!(),
};
@ -478,8 +479,8 @@ pub fn gen_for<G: CodeGenerator>(
/// executing. The result value must be an `i1` indicating if the loop should continue.
/// * `body` - A lambda containing IR statements within the loop body.
/// * `update` - A lambda containing IR statements updating loop variables.
pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
generator: &mut dyn CodeGenerator,
pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
init: InitFn,
cond: CondFn,
@ -487,11 +488,12 @@ pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
update: UpdateFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
I: Clone,
InitFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{
let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
let init_bb = ctx.ctx.append_basic_block(current, "for.init");
@ -536,6 +538,85 @@ pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
Ok(())
}
/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the
/// following C code:
///
/// ```c
/// for (int x = init_val; x /* < or <= ; see `max_val` */ max_val; x += incr_val) {
/// body(x);
/// }
/// ```
///
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// as the type of the loop variable.
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
/// value should be treated as inclusive (as opposed to exclusive).
/// * `body` - A lambda containing IR statements within the loop body.
/// * `incr_val` - The value to increment the loop variable on each iteration.
pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
init_val: IntValue<'ctx>,
max_val: (IntValue<'ctx>, bool),
body: BodyFn,
incr_val: IntValue<'ctx>,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
{
let init_val_t = init_val.get_type();
gen_for_callback(
generator,
ctx,
|generator, ctx| {
let i_addr = generator.gen_var_alloc(ctx, init_val_t.into(), None)?;
ctx.builder.build_store(i_addr, init_val).unwrap();
Ok(i_addr)
},
|_, ctx, i_addr| {
let cmp_op = if max_val.1 {
IntPredicate::ULE
} else {
IntPredicate::ULT
};
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let max_val = ctx.builder
.build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "")
.unwrap();
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
body(generator, ctx, i)
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let incr_val = ctx.builder
.build_int_z_extend_or_bit_cast(incr_val, init_val_t, "")
.unwrap();
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
)
}
/// See [`CodeGenerator::gen_while`].
pub fn gen_while<G: CodeGenerator>(
generator: &mut G,
@ -701,8 +782,8 @@ pub fn final_proxy<'ctx>(
/// Inserts the declaration of the builtin function with the specified `symbol` name, and returns
/// the function.
pub fn get_builtins<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn get_builtins<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
symbol: &str,
) -> FunctionValue<'ctx> {
@ -795,8 +876,8 @@ pub fn exn_constructor<'ctx>(
///
/// * `exception` - The exception thrown by the `raise` statement.
/// * `loc` - The location where the exception is raised from.
pub fn gen_raise<'ctx>(
generator: &mut dyn CodeGenerator,
pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
exception: Option<&BasicValueEnum<'ctx>>,
loc: Location,

View File

@ -10,7 +10,7 @@ use crate::{
},
typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
};
use indoc::indoc;
@ -25,7 +25,6 @@ use nac3parser::{
use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::typecheck::typedef::VarMap;
struct Resolver {
id_to_type: HashMap<StrRef, Type>,

View File

@ -5,11 +5,14 @@ use crate::{
expr::destructure_range,
irrt::*,
llvm_intrinsics::*,
numpy::*,
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
toplevel::helper::PRIMITIVE_DEF_IDS,
toplevel::numpy::*,
toplevel::{
helper::PRIMITIVE_DEF_IDS,
numpy::make_ndarray_ty,
},
typecheck::typedef::VarMap,
};
use inkwell::{
@ -296,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Some("N".into()),
None,
);
let size_t = primitives.0.usize();
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
let exception_fields = vec![
("__name__".into(), int32, true),
@ -342,8 +347,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.nth(1)
.map(|(var_id, ty)| (*ty, *var_id))
.unwrap();
let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var(
size_t,
Some("ndarray_ndims".into()),
None,
);
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap();
let ndarray_sub_ty = *ndarray_fields.get(&"__sub__".into()).unwrap();
let ndarray_mul_ty = *ndarray_fields.get(&"__mul__".into()).unwrap();
let ndarray_truediv_ty = *ndarray_fields.get(&"__truediv__".into()).unwrap();
let ndarray_floordiv_ty = *ndarray_fields.get(&"__floordiv__".into()).unwrap();
let ndarray_mod_ty = *ndarray_fields.get(&"__mod__".into()).unwrap();
let ndarray_pow_ty = *ndarray_fields.get(&"__pow__".into()).unwrap();
let ndarray_iadd_ty = *ndarray_fields.get(&"__iadd__".into()).unwrap();
let ndarray_isub_ty = *ndarray_fields.get(&"__isub__".into()).unwrap();
let ndarray_imul_ty = *ndarray_fields.get(&"__imul__".into()).unwrap();
let ndarray_itruediv_ty = *ndarray_fields.get(&"__itruediv__".into()).unwrap();
let ndarray_ifloordiv_ty = *ndarray_fields.get(&"__ifloordiv__".into()).unwrap();
let ndarray_imod_ty = *ndarray_fields.get(&"__imod__".into()).unwrap();
let ndarray_ipow_ty = *ndarray_fields.get(&"__ipow__".into()).unwrap();
let top_level_def_list = vec![
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
@ -521,6 +545,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
methods: vec![
("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)),
("__add__".into(), ndarray_add_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 3)),
("__sub__".into(), ndarray_sub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 4)),
("__mul__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 5)),
("__truediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 6)),
("__floordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 7)),
("__mod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 8)),
("__pow__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 9)),
("__iadd__".into(), ndarray_iadd_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 10)),
("__isub__".into(), ndarray_isub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 11)),
("__imul__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 12)),
("__itruediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 13)),
("__ifloordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 14)),
("__imod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 15)),
("__ipow__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 16)),
],
ancestors: Vec::default(),
constructor: None,
@ -559,6 +597,216 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__add__".into(),
simple_name: "__add__".into(),
signature: ndarray_add_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__sub__".into(),
simple_name: "__sub__".into(),
signature: ndarray_sub_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__mul__".into(),
simple_name: "__mul__".into(),
signature: ndarray_mul_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__truediv__".into(),
simple_name: "__truediv__".into(),
signature: ndarray_truediv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__floordiv__".into(),
simple_name: "__floordiv__".into(),
signature: ndarray_floordiv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__mod__".into(),
simple_name: "__mod__".into(),
signature: ndarray_mod_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__pow__".into(),
simple_name: "__pow__".into(),
signature: ndarray_pow_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__iadd__".into(),
simple_name: "__iadd__".into(),
signature: ndarray_iadd_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__isub__".into(),
simple_name: "__isub__".into(),
signature: ndarray_isub_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__imul__".into(),
simple_name: "__imul__".into(),
signature: ndarray_imul_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__itruediv__".into(),
simple_name: "__itruediv__".into(),
signature: ndarray_itruediv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__ifloordiv__".into(),
simple_name: "__ifloordiv__".into(),
signature: ndarray_ifloordiv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__imod__".into(),
simple_name: "__imod__".into(),
signature: ndarray_imod_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__ipow__".into(),
simple_name: "__ipow__".into(),
signature: ndarray_ipow_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "int32".into(),
simple_name: "int32".into(),

View File

@ -1926,9 +1926,8 @@ impl TopLevelComposer {
ret_str,
name,
ast.as_ref().unwrap().location
),
]))
}
),]))
}
instance_to_stmt.insert(
get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())),

View File

@ -1,6 +1,7 @@
use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::subst_ndarray_tvars;
use crate::typecheck::typedef::{Mapping, VarMap};
use nac3parser::ast::{Constant, Location};
@ -226,11 +227,57 @@ impl TopLevelComposer {
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray_binop_fun_other_ty = unifier.get_fresh_var(None, None);
let ndarray_binop_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_binop_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "other".into(),
ty: ndarray_binop_fun_other_ty.0,
default_value: None,
},
],
ret: ndarray_binop_fun_ret_ty.0,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray_truediv_fun_other_ty = unifier.get_fresh_var(None, None);
let ndarray_truediv_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_truediv_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "other".into(),
ty: ndarray_truediv_fun_other_ty.0,
default_value: None,
},
],
ret: ndarray_truediv_fun_ret_ty.0,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: Mapping::from([
("copy".into(), (ndarray_copy_fun_ty, true)),
("fill".into(), (ndarray_fill_fun_ty, true)),
("__add__".into(), (ndarray_binop_fun_ty, true)),
("__sub__".into(), (ndarray_binop_fun_ty, true)),
("__mul__".into(), (ndarray_binop_fun_ty, true)),
("__truediv__".into(), (ndarray_truediv_fun_ty, true)),
("__floordiv__".into(), (ndarray_binop_fun_ty, true)),
("__mod__".into(), (ndarray_binop_fun_ty, true)),
("__pow__".into(), (ndarray_binop_fun_ty, true)),
("__iadd__".into(), (ndarray_binop_fun_ty, true)),
("__isub__".into(), (ndarray_binop_fun_ty, true)),
("__imul__".into(), (ndarray_binop_fun_ty, true)),
("__itruediv__".into(), (ndarray_truediv_fun_ty, true)),
("__ifloordiv__".into(), (ndarray_binop_fun_ty, true)),
("__imod__".into(), (ndarray_binop_fun_ty, true)),
("__ipow__".into(), (ndarray_binop_fun_ty, true)),
]),
params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
@ -238,7 +285,16 @@ impl TopLevelComposer {
]),
});
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None);
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap();
unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap();
let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None);
unifier.unify(ndarray_truediv_fun_other_ty.0, ndarray).unwrap();
unifier.unify(ndarray_truediv_fun_ret_ty.0, ndarray_float).unwrap();
let primitives = PrimitiveStore {
int32,

View File

@ -1,24 +1,9 @@
use inkwell::{IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}};
use inkwell::values::{AggregateValueEnum, ArrayValue, IntValue};
use itertools::Itertools;
use nac3parser::ast::StrRef;
use crate::{
codegen::{
classes::{ListValue, NDArrayValue},
CodeGenContext,
CodeGenerator,
irrt::{
call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
},
llvm_intrinsics::call_memcpy_generic,
stmt::gen_for_callback
},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, helper::PRIMITIVE_DEF_IDS},
toplevel::helper::PRIMITIVE_DEF_IDS,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{FunSignature, Type, TypeEnum, Unifier, VarMap},
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
@ -34,16 +19,32 @@ pub fn make_ndarray_ty(
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let ndarray = primitives.ndarray;
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
if dtype.is_none() && ndims.is_none() {
return ndarray
}
let tvar_ids = params.iter()
.map(|(obj_id, _)| *obj_id)
.sorted()
.collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
@ -58,12 +59,10 @@ pub fn make_ndarray_ty(
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_tvars(
fn unpack_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
) -> (Type, Type) {
) -> Vec<(u32, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
@ -72,889 +71,33 @@ pub fn unpack_ndarray_tvars(
params.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(_, ty)| *ty)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(
unifier: &mut Unifier,
ndarray: Type,
) -> (u32, u32) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.0)
.collect_tuple()
.unwrap()
}
/// Creates an `NDArray` instance from a dynamic shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`.
/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`.
/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`.
fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type,
shape: &V,
shape_len_fn: LenFn,
shape_data_fn: DataFn,
) -> Result<NDArrayValue<'ctx>, String>
where
LenFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
DataFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
{
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
// Assert that all dimensions are non-negative
gen_for_callback(
generator,
ctx,
|generator, ctx| {
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
Ok(i)
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_len = shape_len_fn(generator, ctx, shape)?;
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
let shape_dim_gez = ctx.builder
.build_int_compare(IntPredicate::SGE, shape_dim, shape_dim.get_type().const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
)?;
let ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None,
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let num_dims = shape_len_fn(generator, ctx, shape)?;
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
// Copy the dimension sizes from shape to ndarray.dims
gen_for_callback(
generator,
ctx,
|generator, ctx| {
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
Ok(i)
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_len = shape_len_fn(generator, ctx, shape)?;
debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width());
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_dim = shape_data_fn(generator, ctx, shape, i)?;
debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width());
let shape_dim = ctx.builder
.build_int_z_extend(shape_dim, llvm_usize, "")
.unwrap();
let ndarray_pdim = ndarray.dim_sizes().ptr_offset(ctx, generator, i, None);
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
)?;
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ndarray.load_ndims(ctx),
ndarray.dim_sizes().as_ptr_value(ctx),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
Ok(ndarray)
}
/// Creates an `NDArray` instance from a constant shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented as an LLVM [`ArrayValue`].
fn create_ndarray_const_shape<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ArrayValue<'ctx>
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
for i in 0..shape.get_type().len() {
let shape_dim = ctx.builder
.build_extract_value(shape, i, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let shape_dim_gez = ctx.builder
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
shape_dim_gez,
"0:ValueError",
"negative dimensions not supported",
[None, None, None],
ctx.current_loc,
);
}
let ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None,
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false);
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
for i in 0..shape.get_type().len() {
let ndarray_dim = ndarray
.dim_sizes()
.ptr_offset(ctx, generator, llvm_usize.const_int(i as u64, true), None);
let shape_dim = ctx.builder.build_extract_value(shape, i, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
}
let ndarray_dims = ndarray.dim_sizes().as_ptr_value(ctx);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ndarray.load_ndims(ctx),
ndarray_dims,
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
Ok(ndarray)
}
fn ndarray_zero_value<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
ctx.ctx.i32_type().const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
ctx.ctx.i64_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "")
} else {
unreachable!()
}
}
fn ndarray_one_value<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1")
} else {
unreachable!()
}
}
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_empty_impl<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape,
|_, ctx, shape| {
Ok(shape.load_size(ctx, None))
},
|generator, ctx, shape, idx| {
Ok(shape.data().get(ctx, generator, idx, None).into_int_value())
},
)
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
/// its input.
fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarray: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
where
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
ndarray.load_ndims(ctx),
ndarray.dim_sizes().as_ptr_value(ctx),
);
gen_for_callback(
generator,
ctx,
|generator, ctx| {
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap();
Ok(i)
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let elem = unsafe {
ndarray.data().ptr_to_data_flattened_unchecked(ctx, i, None)
};
let value = value_fn(generator, ctx, i)?;
ctx.builder.build_store(elem, value).unwrap();
Ok(())
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
Ok(())
},
)
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
/// as its input.
fn ndarray_fill_indexed<'ctx, ValueFn>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<(), String>
where
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, '_>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, idx| {
let indices = call_ndarray_calc_nd_indices(
generator,
ctx,
idx,
ndarray,
);
value_fn(generator, ctx, indices)
}
)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_zeros_impl<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
ctx.primitives.float,
ctx.primitives.bool,
ctx.primitives.str,
];
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, _| {
let value = ndarray_zero_value(generator, ctx, elem_ty);
Ok(value)
}
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_ones_impl<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
ctx.primitives.float,
ctx.primitives.bool,
ctx.primitives.str,
];
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, _| {
let value = ndarray_one_value(generator, ctx, elem_ty);
Ok(value)
}
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.full`.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
fn call_ndarray_full_impl<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
fill_value: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
ndarray_fill_flattened(
generator,
ctx,
ndarray,
|generator, ctx, _| {
let value = if fill_value.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
fill_value.into_pointer_value(),
fill_value.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if fill_value.is_int_value() || fill_value.is_float_value() {
fill_value
} else {
unreachable!()
};
Ok(value)
}
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
///
/// * `elem_ty` - The element type of the `NDArray`.
fn call_ndarray_eye_impl<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
nrows: IntValue<'ctx>,
ncols: IntValue<'ctx>,
offset: IntValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_usize_2 = llvm_usize.array_type(2);
let shape_addr = generator.gen_var_alloc(ctx, llvm_usize_2.into(), None)?;
let shape = ctx.builder.build_load(shape_addr, "")
.map(BasicValueEnum::into_array_value)
.unwrap();
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap();
let shape = ctx.builder
.build_insert_value(shape, nrows, 0, "")
.map(AggregateValueEnum::into_array_value)
.unwrap();
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap();
let shape = ctx.builder
.build_insert_value(shape, ncols, 1, "")
.map(AggregateValueEnum::into_array_value)
.unwrap();
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, shape)?;
ndarray_fill_indexed(
generator,
ctx,
ndarray,
|generator, ctx, indices| {
let row = ctx.build_gep_and_load(
indices,
&[llvm_usize.const_int(0, false)],
None,
).into_int_value();
let col = ctx.build_gep_and_load(
indices,
&[llvm_usize.const_int(1, false)],
None,
).into_int_value();
let col_with_offset = ctx.builder
.build_int_add(
col,
ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_usize, "").unwrap(),
"",
)
.unwrap();
let is_on_diag = ctx.builder
.build_int_compare(IntPredicate::EQ, row, col_with_offset, "")
.unwrap();
let zero = ndarray_zero_value(generator, ctx, elem_ty);
let one = ndarray_one_value(generator, ctx, elem_ty);
let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap();
Ok(value)
},
)?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
///
/// * `elem_ty` - The element type of the `NDArray`.
fn ndarray_copy_impl<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
this: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i1 = ctx.ctx.bool_type();
let ndarray = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&this,
|_, ctx, shape| {
Ok(shape.load_ndims(ctx))
},
|generator, ctx, shape, idx| {
Ok(shape.dim_sizes().get(ctx, generator, idx, None))
},
)?;
let len = call_ndarray_calc_size(
generator,
ctx,
ndarray.load_ndims(ctx),
ndarray.dim_sizes().as_ptr_value(ctx),
);
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
let len_bytes = ctx.builder
.build_int_mul(
len,
sizeof_ty.size_of().unwrap(),
"",
)
.unwrap();
call_memcpy_generic(
ctx,
ndarray.data().as_ptr_value(ctx),
this.data().as_ptr_value(ctx),
len_bytes,
llvm_i1.const_zero(),
);
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_empty_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.zeros`.
pub fn gen_ndarray_zeros<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_zeros_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.ones`.
pub fn gen_ndarray_ones<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_ones_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.full`.
pub fn gen_ndarray_full<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, shape_ty)?;
let fill_value_ty = fun.0.args[1].ty;
let fill_value_arg = args[1].1.clone()
.to_basic_value_enum(context, generator, fill_value_ty)?;
call_ndarray_full_impl(
generator,
context,
fill_value_ty,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
fill_value_arg,
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.eye`.
pub fn gen_ndarray_eye<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert!(matches!(args.len(), 1..=3));
let nrows_ty = fun.0.args[0].ty;
let nrows_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, nrows_ty)?;
let ncols_ty = fun.0.args[1].ty;
let ncols_arg = args.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty))
.unwrap_or_else(|| {
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
})?;
let offset_ty = fun.0.args[2].ty;
let offset_arg = args.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty))
.unwrap_or_else(|| {
Ok(context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
offset_ty
))
})?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
nrows_arg.into_int_value(),
ncols_arg.into_int_value(),
offset_arg.into_int_value(),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.identity`.
pub fn gen_ndarray_identity<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let n_ty = fun.0.args[0].ty;
let n_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, n_ty)?;
call_ndarray_eye_impl(
generator,
context,
context.primitives.float,
n_arg.into_int_value(),
n_arg.into_int_value(),
llvm_usize.const_zero(),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.copy`.
pub fn gen_ndarray_copy<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
_fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_some());
assert!(args.is_empty());
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_tvars(&mut context.unifier, this_ty);
let this_arg = obj
.as_ref()
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(
unifier: &mut Unifier,
ndarray: Type,
) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.1)
.collect_tuple()
.unwrap()
.1
.clone()
.to_basic_value_enum(context, generator, this_ty)?;
ndarray_copy_impl(
generator,
context,
this_elem_ty,
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None),
).map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.fill`.
pub fn gen_ndarray_fill<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<(), String> {
assert!(obj.is_some());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let this_arg = obj.as_ref().unwrap().1.clone()
.to_basic_value_enum(context, generator, this_ty)?
.into_pointer_value();
let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone()
.to_basic_value_enum(context, generator, value_ty)?;
ndarray_fill_flattened(
generator,
context,
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?;
call_memcpy_generic(
ctx,
copy,
value_arg.into_pointer_value(),
value_arg.get_type().size_of().map(Into::into).unwrap(),
llvm_i1.const_zero(),
);
copy.into()
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
unreachable!()
};
Ok(value)
}
)?;
Ok(())
}

View File

@ -7,7 +7,7 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
]

View File

@ -3,11 +3,11 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",
]

View File

@ -1,5 +1,6 @@
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::typecheck::typedef::VarMap;
use super::*;
use nac3parser::ast::Constant;

View File

@ -1,3 +1,4 @@
use crate::toplevel::numpy::make_ndarray_ty;
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
@ -234,8 +235,14 @@ pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Typ
}
/// `Div`
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]);
pub fn impl_div(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]);
}
/// `FloorDiv`
@ -299,8 +306,10 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
bool: bool_t,
uint32: uint32_t,
uint64: uint64_t,
ndarray: ndarray_t,
..
} = *store;
let size_t = store.usize();
/* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] {
@ -308,7 +317,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_pow(unifier, store, t, &[t], t);
impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t]);
impl_div(unifier, store, t, &[t], float_t);
impl_floordiv(unifier, store, t, &[t], t);
impl_mod(unifier, store, t, &[t], t);
impl_invert(unifier, store, t);
@ -323,7 +332,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
impl_div(unifier, store, float_t, &[float_t]);
impl_div(unifier, store, float_t, &[float_t], float_t);
impl_floordiv(unifier, store, float_t, &[float_t], float_t);
impl_mod(unifier, store, float_t, &[float_t], float_t);
impl_sign(unifier, store, float_t);
@ -334,4 +343,14 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* bool ======== */
impl_not(unifier, store, bool_t);
impl_eq(unifier, store, bool_t);
/* ndarray ===== */
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
impl_div(unifier, store, ndarray_t, &[ndarray_t], ndarray_float_t);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
}

View File

@ -9,7 +9,7 @@ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
helper::PRIMITIVE_DEF_IDS,
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext,
},
};
@ -1334,7 +1334,7 @@ impl<'a> Inferencer<'a> {
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
}
@ -1347,7 +1347,7 @@ impl<'a> Inferencer<'a> {
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
self.infer_subscript_ndarray(value, ty, ndims)
}
_ => {
@ -1379,7 +1379,7 @@ impl<'a> Inferencer<'a> {
Ok(ty)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
self.infer_subscript_ndarray(value, ty, ndims)

View File

@ -1,10 +1,11 @@
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};
use std::collections::HashMap;
use std::fmt::Display;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet};
use std::iter::zip;
use indexmap::IndexMap;
use itertools::Itertools;
use nac3parser::ast::{Location, StrRef};
@ -25,14 +26,10 @@ pub type Type = UnificationKey;
pub struct CallId(pub(super) usize);
pub type Mapping<K, V = Type> = HashMap<K, V>;
pub type IndexMapping<K, V = Type> = IndexMap<K, V>;
/// A [`Mapping`] sorted by its key.
///
/// This type is recommended for mappings that should be stored and/or iterated by its sorted key.
pub type SortedMapping<K, V = Type> = BTreeMap<K, V>;
/// A [`BTreeMap`] storing the mapping between type variable ID and [unifier type][`Type`].
pub type VarMap = SortedMapping<u32>;
/// The mapping between type variable ID and [unifier type][`Type`].
pub type VarMap = IndexMapping<u32>;
#[derive(Clone)]
pub struct Call {
@ -699,7 +696,7 @@ impl Unifier {
self.set_a_to_b(a, x);
}
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
let len = ty.len() as i32;
let len = i32::try_from(ty.len()).unwrap();
for (k, v) in fields {
match *k {
RecordKey::Int(i) => {
@ -920,8 +917,8 @@ impl Unifier {
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
// all K-V pairs "in arbitrary order"
let (tv1, tv2) = (
params1.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
params2.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
params1.iter().map(|(_, v)| v).collect_vec(),
params2.iter().map(|(_, v)| v).collect_vec(),
);
for (x, y) in zip(tv1, tv2) {
if self.unify_impl(*x, *y, false).is_err() {
@ -1097,11 +1094,9 @@ impl Unifier {
if params.is_empty() {
name
} else {
let params = params
let mut params = params
.iter()
.map(|(_, v)| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
// sort to preserve order
let mut params = params.sorted();
format!("{}[{}]", name, params.join(", "))
}
}
@ -1283,12 +1278,12 @@ impl Unifier {
fn subst_map<K>(
&mut self,
map: &SortedMapping<K>,
map: &IndexMapping<K>,
mapping: &VarMap,
cache: &mut HashMap<Type, Option<Type>>,
) -> Option<SortedMapping<K>>
where
K: Ord + Eq + Clone,
) -> Option<IndexMapping<K>>
where
K: std::hash::Hash + Eq + Clone,
{
let mut map2 = None;
for (k, v) in map {

View File

@ -45,9 +45,9 @@ impl Unifier {
}
}
fn map_eq<K>(&mut self, map1: &SortedMapping<K>, map2: &SortedMapping<K>) -> bool
where
K: Ord + Eq + Clone,
fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
where
K: std::hash::Hash + Eq + Clone
{
if map1.len() != map2.len() {
return false;

View File

@ -67,6 +67,181 @@ def test_ndarray_copy():
output_float64(y[1][0])
output_float64(y[1][1])
def test_ndarray_add():
x = np_identity(2)
y = x + np_ones([2, 2])
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
output_float64(y[0][0])
output_float64(y[0][1])
output_float64(y[1][0])
output_float64(y[1][1])
# def test_ndarray_add_broadcast():
# x = np_identity(2)
# y: ndarray[float, 2] = x + np_ones([2])
#
# output_float64(x[0][0])
# output_float64(x[0][1])
# output_float64(x[1][0])
# output_float64(x[1][1])
#
# output_float64(y[0][0])
# output_float64(y[0][1])
# output_float64(y[1][0])
# output_float64(y[1][1])
def test_ndarray_iadd():
x = np_identity(2)
x += np_ones([2, 2])
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
def test_ndarray_sub():
x = np_ones([2, 2])
y = x - np_identity(2)
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
output_float64(y[0][0])
output_float64(y[0][1])
output_float64(y[1][0])
output_float64(y[1][1])
def test_ndarray_isub():
x = np_ones([2, 2])
x -= np_identity(2)
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
def test_ndarray_mul():
x = np_ones([2, 2])
y = x * np_identity(2)
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
output_float64(y[0][0])
output_float64(y[0][1])
output_float64(y[1][0])
output_float64(y[1][1])
def test_ndarray_imul():
x = np_ones([2, 2])
x *= np_identity(2)
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
def test_ndarray_truediv():
x = np_identity(2)
y = x / np_ones([2, 2])
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
output_float64(y[0][0])
output_float64(y[0][1])
output_float64(y[1][0])
output_float64(y[1][1])
def test_ndarray_itruediv():
x = np_identity(2)
x /= np_ones([2, 2])
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
def test_ndarray_floordiv():
x = np_identity(2)
y = x // np_ones([2, 2])
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
output_float64(y[0][0])
output_float64(y[0][1])
output_float64(y[1][0])
output_float64(y[1][1])
def test_ndarray_ifloordiv():
x = np_identity(2)
x //= np_ones([2, 2])
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
def test_ndarray_mod():
x = np_identity(2)
y = x % np_full([2, 2], 2.0)
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
output_float64(y[0][0])
output_float64(y[0][1])
output_float64(y[1][0])
output_float64(y[1][1])
def test_ndarray_imod():
x = np_identity(2)
x %= np_full([2, 2], 2.0)
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
def test_ndarray_pow():
x = np_identity(2)
y = x ** np_full([2, 2], 2.0)
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
output_float64(y[0][0])
output_float64(y[0][1])
output_float64(y[1][0])
output_float64(y[1][1])
def test_ndarray_ipow():
x = np_identity(2)
x **= np_full([2, 2], 2.0)
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
@ -77,5 +252,17 @@ def run() -> int32:
test_ndarray_identity()
test_ndarray_fill()
test_ndarray_copy()
test_ndarray_add()
test_ndarray_iadd()
test_ndarray_sub()
test_ndarray_isub()
test_ndarray_mul()
test_ndarray_imul()
test_ndarray_truediv()
test_ndarray_itruediv()
test_ndarray_floordiv()
test_ndarray_ifloordiv()
test_ndarray_mod()
test_ndarray_imod()
return 0

View File

@ -74,7 +74,8 @@ impl SymbolResolver for Resolver {
if let Some(id) = str_store.get(s) {
*id
} else {
let id = str_store.len() as i32;
let id = i32::try_from(str_store.len())
.expect("Symbol resolver string store size exceeds max capacity (i32::MAX)");
str_store.insert(s.to_string(), id);
id
}

View File

@ -247,6 +247,8 @@ fn handle_assignment_pattern(
}
fn main() {
const SIZE_T: u32 = usize::BITS;
let cli = CommandLineArgs::parse();
let CommandLineArgs {
file_name,
@ -287,7 +289,6 @@ fn main() {
// The default behavior for -O<n> where n>3 defaults to O3 for both Clang and GCC
_ => OptimizationLevel::Aggressive,
};
const SIZE_T: u32 = 64;
let program = match fs::read_to_string(file_name.clone()) {
Ok(program) => program,