Compare commits

..

16 Commits

Author SHA1 Message Date
David Mak e69fab866c core: Implement elementwise comparison operators 2024-04-01 17:34:04 +08:00
David Mak 03a12f24d6 core: Implement elementwise unary operators 2024-04-01 17:34:04 +08:00
David Mak eaf4af21cb core: Implement elementwise binary operators
Including immediate variants of these operators.
2024-04-01 17:34:04 +08:00
David Mak 46aec66340 core/magic_methods: Add typeof_*op
Used to determine the expected type of the binary operator with
primitive operands.
2024-04-01 17:34:04 +08:00
David Mak ad3ddcddeb core/toplevel/numpy: Split ndarray type var utilities 2024-04-01 17:34:04 +08:00
David Mak 949339b0da core: Implement calculations for broadcasting ndarrays 2024-04-01 17:34:01 +08:00
David Mak 60be5b650d core/type_inferencer: Allow both int32 and isize when indexing ndarray 2024-04-01 17:06:33 +08:00
David Mak d6ed13d9da core/magic_methods: Allow unknown return types
These types can be later inferred by the type inferencer.
2024-04-01 17:06:31 +08:00
David Mak 1d1954e90e core/helper: Add PrimitiveDefinitionIds::iter 2024-04-01 16:48:25 +08:00
David Mak 630c93204a core/typedef: Add Type::obj_id to replace get_obj_id 2024-04-01 16:48:25 +08:00
David Mak 9a98cde595 core: Extract codegen portion of gen_*op_expr
This allows *ops to be generated internally using LLVM values as
input. Required in a future change.
2024-04-01 16:48:25 +08:00
David Mak 5ba8601b39 core: Remove ArrayValue variants of functions
These will be lowered and optimized away later anyways, and we have
ArrayLikeAccessor now.
2024-04-01 16:48:25 +08:00
David Mak 26a01b14d5 core: Use more typed slices in APIs 2024-04-01 16:48:25 +08:00
David Mak d5f4817134 core/builtins: Fix len() on ndarrays 2024-04-01 16:48:24 +08:00
David Mak 789bfb5a26 core: Fix index-based operations not returning i32 2024-04-01 16:46:45 +08:00
David Mak 4bb0e60981 core: Apply clippy suggestions 2024-04-01 16:46:41 +08:00
16 changed files with 245 additions and 153 deletions

1
.clippy.toml Normal file
View File

@ -0,0 +1 @@
doc-valid-idents = ["NumPy", ".."]

View File

@ -624,7 +624,7 @@ pub fn attributes_writeback(
let ty = ty.unwrap();
match &*ctx.unifier.get_ty(ty) {
TypeEnum::TObj { fields, obj_id, .. }
if *obj_id != ctx.primitives.option.get_obj_id(&ctx.unifier) =>
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
{
// we only care about primitive attributes
// for non-primitive attributes, they should be in another global

View File

@ -38,10 +38,15 @@ pub enum PrimitiveValue {
Bool(bool),
}
/// An entry in the [`DeferredEvaluationStore`], containing the deferred types, a [`PyObject`]
/// representing the `__constraints__` of the type variables, and the name of the type to be
/// instantiated.
type DeferredEvaluationEntry = (Vec<Type>, PyObject, String);
#[derive(Clone)]
pub struct DeferredEvaluationStore {
needs_defer: Arc<AtomicBool>,
store: Arc<RwLock<Vec<(Vec<Type>, PyObject, String)>>>,
store: Arc<RwLock<Vec<DeferredEvaluationEntry>>>,
}
impl DeferredEvaluationStore {
@ -53,12 +58,18 @@ impl DeferredEvaluationStore {
}
}
/// A class field as stored in the [`InnerResolver`], represented by the ID and name of the
/// associated [`PythonValue`].
type ResolverField = (u64, StrRef);
/// A class field as stored in Python, represented by the `id()` and [`PyObject`] of the field.
type PyFieldHandle = (u64, PyObject);
pub struct InnerResolver {
pub id_to_type: RwLock<HashMap<StrRef, Type>>,
pub id_to_def: RwLock<HashMap<StrRef, DefinitionId>>,
pub id_to_pyval: RwLock<HashMap<StrRef, (u64, PyObject)>>,
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
pub field_to_val: RwLock<HashMap<(u64, StrRef), Option<(u64, PyObject)>>>,
pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>,
pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
pub class_names: Mutex<HashMap<StrRef, Type>>,
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
@ -699,7 +710,7 @@ impl InnerResolver {
// special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below
(TypeEnum::TObj { obj_id, params, .. }, false)
if *obj_id == primitives.option.get_obj_id(unifier) =>
if *obj_id == primitives.option.obj_id(unifier).unwrap() =>
{
let Ok(field_data) = obj.getattr("_nac3_option") else {
unreachable!("cannot be None")
@ -982,7 +993,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.option {
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) =>
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
}

View File

@ -5,6 +5,7 @@ use crate::{
classes::{
ArrayLikeIndexer,
ArrayLikeValue,
ArraySliceValue,
ListValue,
NDArrayValue,
RangeValue,
@ -39,12 +40,10 @@ use inkwell::{
types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}
};
use inkwell::values::BasicValue;
use itertools::{chain, izip, Itertools, Either};
use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
};
use crate::codegen::classes::ArraySliceValue;
use super::{CodeGenerator, llvm_intrinsics::call_memcpy_generic, need_sret};
@ -156,7 +155,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
SymbolValue::OptionSome(v) => {
let ty = match self.unifier.get_ty_immutable(ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.get_obj_id(&self.unifier) =>
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
}
@ -170,7 +169,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
SymbolValue::OptionNone => {
let ty = match self.unifier.get_ty_immutable(ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.get_obj_id(&self.unifier) =>
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
}
@ -1131,15 +1130,17 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Some("f_pow_i")
);
Ok(Some(res.into()))
} else if ty1.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray || ty2.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray {
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let is_ndarray1 = ty1.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
let is_ndarray2 = ty2.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
let is_ndarray1 = ty1.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_ndarray2 = ty2.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1155,16 +1156,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
if is_aug_assign { Some(left_val) } else { None },
(left_val.as_ptr_value().into(), false),
(right_val, false),
|generator, ctx, elem_ty, (lhs, rhs)| {
|generator, ctx, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), lhs),
(&Some(ndarray_dtype1), lhs),
op,
(&Some(elem_ty), rhs),
(&Some(ndarray_dtype2), rhs),
ctx.current_loc,
is_aug_assign,
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1)
},
)?;
@ -1186,16 +1187,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
if is_aug_assign { Some(ndarray_val) } else { None },
(left_val, !is_ndarray1),
(right_val, !is_ndarray2),
|generator, ctx, elem_ty, (lhs, rhs)| {
|generator, ctx, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), lhs),
(&Some(ndarray_dtype), lhs),
op,
(&Some(elem_ty), rhs),
(&Some(ndarray_dtype), rhs),
ctx.current_loc,
is_aug_assign,
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype)
},
)?;
@ -1254,7 +1255,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
}
}
/// Generates LLVM IR for a [binary operator expression][expr].
/// Generates LLVM IR for a binary operator expression.
///
/// * `left` - The left-hand side of the binary operator.
/// * `op` - The operator applied on the operands.
@ -1292,6 +1293,8 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
)
}
/// Generates LLVM IR for a unary operator expression using the [`Type`] and
/// [LLVM value][`BasicValueEnum`] of the operands.
pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -1349,13 +1352,13 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
ndarray_dtype,
None,
val,
|generator, ctx, elem_ty, val| {
|generator, ctx, val| {
gen_unaryop_expr_with_values(
generator,
ctx,
op,
(&Some(elem_ty), val)
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
(&Some(ndarray_dtype), val)
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype)
},
)?;
@ -1365,6 +1368,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
}))
}
/// Generates LLVM IR for a unary operator expression.
///
/// * `op` - The operator applied on the operand.
/// * `operand` - The unary operand.
pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -1380,6 +1387,8 @@ pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val))
}
/// Generates LLVM IR for a comparison operator expression using the [`Type`] and
/// [LLVM value][`BasicValueEnum`] of the operands.
pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -1399,9 +1408,11 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
let (Some(left_ty), lhs) = left else { unreachable!() };
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
let op = ops[0].clone();
let is_ndarray1 = left_ty.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
let is_ndarray2 = right_ty.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
let is_ndarray1 = left_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_ndarray2 = right_ty.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
return if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
@ -1421,14 +1432,14 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
None,
(left_val.as_ptr_value().into(), false),
(rhs, false),
|generator, ctx, elem_ty, (lhs, rhs)| {
|generator, ctx, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values(
generator,
ctx,
(Some(ndarray_dtype1), lhs),
&[op.clone()],
&[(Some(ndarray_dtype2), rhs)],
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
)?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
@ -1447,14 +1458,14 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
None,
(lhs, !is_ndarray1),
(rhs, !is_ndarray2),
|generator, ctx, elem_ty, (lhs, rhs)| {
|generator, ctx, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values(
generator,
ctx,
(Some(ndarray_dtype), lhs),
&[op.clone()],
&[(Some(ndarray_dtype), rhs)],
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
)?.unwrap().to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
@ -1544,6 +1555,11 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
}))
}
/// Generates LLVM IR for a comparison operator expression.
///
/// * `left` - The left-hand side of the comparison operator.
/// * `ops` - The (possibly chained) operators applied on the operands.
/// * `comparators` - The right-hand side of the binary operator.
pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -1556,19 +1572,27 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
} else {
return Ok(None)
};
let comparators = {
let mut new_comparators = Vec::new();
new_comparators.reserve(comparators.len());
for cmptor in comparators {
if let Some(v) = generator.gen_expr(ctx, cmptor)? {
new_comparators.push((cmptor.custom, v.to_basic_value_enum(ctx, generator, cmptor.custom.unwrap())?))
let comparator_vals = comparators.iter()
.map(|cmptor| {
Ok(if let Some(v) = generator.gen_expr(ctx, cmptor)? {
Some((cmptor.custom, v.to_basic_value_enum(ctx, generator, cmptor.custom.unwrap())?))
} else {
return Ok(None)
}
}
new_comparators
None
})
})
.take_while(|v| if let Ok(v) = v {
v.is_some()
} else {
true
})
.collect::<Result<Vec<_>, String>>()?;
let comparator_vals = if comparator_vals.len() == comparators.len() {
comparator_vals
.into_iter()
.map(Option::unwrap)
.collect_vec()
} else {
return Ok(None)
};
gen_cmpop_expr_with_values(
@ -1576,7 +1600,7 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
ctx,
(left.custom, left_val),
ops,
comparators.as_slice(),
comparator_vals.as_slice(),
)
}
@ -1959,7 +1983,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
return gen_unaryop_expr(generator, ctx, op, operand)
}
ExprKind::Compare { left, ops, comparators } => {
return gen_cmpop_expr(generator, ctx, left, &ops, &comparators)
return gen_cmpop_expr(generator, ctx, left, ops, comparators)
}
ExprKind::IfExp { test, body, orelse } => {
let test = match generator.gen_expr(ctx, test)? {
@ -2081,7 +2105,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
// directly generate code for option.unwrap
// since it needs to return static value to optimize for kernel invariant
if attr == &"unwrap".into()
&& id == ctx.primitives.option.get_obj_id(&ctx.unifier)
&& id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap()
{
match val {
ValueEnum::Static(v) => return match v.get_field("_nac3_option".into(), ctx) {

View File

@ -25,6 +25,8 @@ use inkwell::{
};
use itertools::Either;
use nac3parser::ast::Expr;
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
#[must_use]
pub fn load_irrt(ctx: &Context) -> Module {
@ -820,25 +822,69 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.dim_sizes().get_typed_unchecked(ctx, generator, idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx.builder.build_int_compare(
IntPredicate::EQ,
lhs_dim_sz,
llvm_usize_const_one,
"",
).unwrap();
let rhs_eqz = ctx.builder.build_int_compare(
IntPredicate::EQ,
rhs_dim_sz,
llvm_usize_const_one,
"",
).unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(
lhs_eqz,
rhs_eqz,
""
).unwrap();
let lhs_eq_rhs = ctx.builder.build_int_compare(
IntPredicate::EQ,
lhs_dim_sz,
rhs_dim_sz,
""
).unwrap();
let is_compatible = ctx.builder.build_or(
lhs_or_rhs_eqz,
lhs_eq_rhs,
""
).unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
).unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
// TODO: Generate assertion checks for whether each dimension is compatible
// gen_for_callback_incrementing(
// generator,
// ctx,
// llvm_usize.const_zero(),
// (max_ndims, false),
// |generator, ctx, idx| {
// let lhs_dim_sz =
//
// let lhs_elem = lhs.get_dims().get(ctx, generator, idx, None);
// let rhs_elem = rhs.get_dims().get(ctx, generator, idx, None);
//
//
// },
// llvm_usize.const_int(1, false),
// ).unwrap();
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
@ -900,7 +946,17 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
// TODO: Assertions
let array_ndims = array.load_ndims(ctx);
let broadcast_size = broadcast_idx.size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
"0:ValueError",
"operands cannot be broadcast together",
[None, None, None],
ctx.current_loc,
);
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();

View File

@ -1,7 +1,7 @@
use inkwell::{
IntPredicate,
types::BasicType,
values::{AggregateValueEnum, ArrayValue, BasicValueEnum, IntValue, PointerValue}
values::{BasicValueEnum, IntValue, PointerValue}
};
use nac3parser::ast::StrRef;
use crate::{
@ -142,12 +142,12 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
/// Creates an `NDArray` instance from a constant shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The shape of the `NDArray`, represented as an LLVM [`ArrayValue`].
/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s.
fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ArrayValue<'ctx>
shape: &[IntValue<'ctx>],
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
@ -158,14 +158,9 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
for i in 0..shape.get_type().len() {
let shape_dim = ctx.builder
.build_extract_value(shape, i, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
for shape_dim in shape {
let shape_dim_gez = ctx.builder
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "")
.build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "")
.unwrap();
ctx.make_assert(
@ -185,21 +180,20 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false);
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
for i in 0..shape.get_type().len() {
let ndarray_dim = ndarray
.dim_sizes()
.ptr_offset(ctx, generator, llvm_usize.const_int(i as u64, true), None);
let shape_dim = ctx.builder.build_extract_value(shape, i, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
for (i, shape_dim) in shape.iter().enumerate() {
let ndarray_dim = unsafe {
ndarray
.dim_sizes()
.ptr_offset_unchecked(ctx, generator, llvm_usize.const_int(i as u64, true), None)
};
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
}
let ndarray_num_elems = call_ndarray_calc_size(
@ -376,7 +370,6 @@ fn ndarray_fill_mapping<'ctx, G, MapFn>(
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: NDArrayValue<'ctx>,
lhs: (BasicValueEnum<'ctx>, bool),
rhs: (BasicValueEnum<'ctx>, bool),
@ -384,7 +377,7 @@ fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
@ -425,7 +418,7 @@ fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
debug_assert_eq!(lhs_elem.get_type(), rhs_elem.get_type());
value_fn(generator, ctx, elem_ty, (lhs_elem, rhs_elem))
value_fn(generator, ctx, (lhs_elem, rhs_elem))
},
)?;
@ -561,27 +554,16 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_usize_2 = llvm_usize.array_type(2);
let shape_addr = generator.gen_var_alloc(ctx, llvm_usize_2.into(), None)?;
let shape = ctx.builder.build_load(shape_addr, "")
.map(BasicValueEnum::into_array_value)
.unwrap();
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap();
let shape = ctx.builder
.build_insert_value(shape, nrows, 0, "")
.map(AggregateValueEnum::into_array_value)
.unwrap();
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap();
let shape = ctx.builder
.build_insert_value(shape, ncols, 1, "")
.map(AggregateValueEnum::into_array_value)
.unwrap();
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, shape)?;
let ndarray = create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[nrows, ncols],
)?;
ndarray_fill_indexed(
generator,
@ -677,7 +659,7 @@ pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>(
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator,
MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, BasicValueEnum<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
let res = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
@ -702,7 +684,7 @@ pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>(
operand,
res,
|generator, ctx, elem| {
map_fn(generator, ctx, elem_ty, elem)
map_fn(generator, ctx, elem)
}
)?;
@ -728,7 +710,6 @@ pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>(
/// # Panic
///
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
// TODO: Remove elem_ty from value_fn
pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -740,7 +721,7 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
@ -800,12 +781,11 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
ndarray_broadcast_fill(
generator,
ctx,
elem_ty,
ndarray,
lhs,
rhs,
|generator, ctx, elem_ty, elems| {
value_fn(generator, ctx, elem_ty, elems)
|generator, ctx, elems| {
value_fn(generator, ctx, elems)
},
)?;
@ -928,24 +908,24 @@ pub fn gen_ndarray_eye<'ctx>(
.to_basic_value_enum(context, generator, nrows_ty)?;
let ncols_ty = fun.0.args[1].ty;
let ncols_arg = args.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty))
.unwrap_or_else(|| {
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
})?;
let ncols_arg = if let Some(arg) =
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name)) {
arg.1.clone().to_basic_value_enum(context, generator, ncols_ty)
} else {
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
}?;
let offset_ty = fun.0.args[2].ty;
let offset_arg = args.iter()
.find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name))
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty))
.unwrap_or_else(|| {
Ok(context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
offset_ty
))
})?;
let offset_arg = if let Some(arg) =
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) {
arg.1.clone().to_basic_value_enum(context, generator, offset_ty)
} else {
Ok(context.gen_symbol_val(
generator,
fun.0.args[2].default_value.as_ref().unwrap(),
offset_ty
))
}?;
call_ndarray_eye_impl(
generator,

View File

@ -185,13 +185,13 @@ impl SymbolValue {
TypeAnnotation::Tuple(vs_tys)
}
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.get_obj_id(unifier),
id: primitives.option.obj_id(unifier).unwrap(),
params: Vec::default(),
},
SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass {
id: primitives.option.get_obj_id(unifier),
id: primitives.option.obj_id(unifier).unwrap(),
params: vec![ty],
}
}

View File

@ -1727,8 +1727,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
);
let len = unsafe {
arg.dim_sizes()
.get_typed_unchecked(ctx, generator, llvm_usize.const_zero(), None)
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
llvm_usize.const_zero(),
None,
)
};
if len.get_type().get_bit_width() == 32 {

View File

@ -46,7 +46,7 @@ impl PrimitiveDefinitionIds {
]
}
/// Returns an iterator over all [`DefinitionId`]s of this instance.
/// Returns an iterator over all [`DefinitionId`]s of this instance in indeterminate order.
pub fn iter(&self) -> impl Iterator<Item=DefinitionId> {
self.as_vec().into_iter()
}
@ -610,7 +610,7 @@ impl TopLevelComposer {
TypeAnnotation::CustomClass { id: e_id, params: e_param },
) => {
*f_id == *e_id
&& *f_id == primitive.option.get_obj_id(unifier)
&& *f_id == primitive.option.obj_id(unifier).unwrap()
&& (f_param.is_empty()
|| (f_param.len() == 1
&& e_param.len() == 1

View File

@ -5,7 +5,7 @@ expression: res_vec
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [30]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [156]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar19]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar19\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar145]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar145\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [158]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [163]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar144, typevar145]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar144\", \"typevar145\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [38]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [164]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [46]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [172]\n}\n",
]

View File

@ -302,12 +302,18 @@ pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret
}
/// `Not`
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
pub fn impl_not(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
}
/// `Lt`, `LtE`, `Gt`, `GtE`
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Option<Type>) {
pub fn impl_comparison(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop(
unifier,
store,
@ -319,7 +325,13 @@ pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type,
}
/// `Eq`, `NotEq`
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Option<Type>) {
pub fn impl_eq(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
}
@ -330,8 +342,8 @@ pub fn typeof_ndarray_broadcast(
left: Type,
right: Type,
) -> Result<Type, String> {
let is_left_ndarray = left.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
let is_right_ndarray = right.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
assert!(is_left_ndarray || is_right_ndarray);
@ -401,12 +413,8 @@ pub fn typeof_binop(
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
Ok(Some(match op {
Operator::Add
@ -465,7 +473,7 @@ pub fn typeof_unaryop(
operand: Type,
) -> Result<Option<Type>, String> {
if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
return Err("The truth value of an array with more than one element is ambiguous".to_string())
return Err("The truth value of an array with more than one element is ambiguous".to_string())
}
Ok(if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {

View File

@ -254,6 +254,14 @@ impl Unifier {
self.primitive_store.replace(*primitives);
}
/// Returns the [`UnificationTable`] associated with this `Unifier`.
///
/// # Safety
///
/// The use of this function is discouraged under most circumstances. Only use this function if
/// in-place manipulation of type variables and/or type fields is necessary, otherwise prefer to
/// [add a new type][`Unifier::add_ty`] and [unify the type][`Unifier::unify`] with an existing
/// type.
pub unsafe fn get_unification_table(&mut self) -> &mut UnificationTable<Rc<TypeEnum>> {
&mut self.unification_table
}