Compare commits

...

4 Commits

14 changed files with 727 additions and 294 deletions

View File

@ -21,7 +21,7 @@ use crate::{
DefinitionId, TopLevelDef,
},
typecheck::{
magic_methods::{binop_assign_name, binop_name, unaryop_name},
magic_methods::{BinOpVariant, OpInfo},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
};
@ -1167,7 +1167,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
op: Operator,
right: (&Option<Type>, BasicValueEnum<'ctx>),
loc: Location,
is_aug_assign: bool,
variant: BinOpVariant,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let (left_ty, left_val) = left;
let (right_ty, right_val) = right;
@ -1222,7 +1222,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
generator,
ctx,
ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None },
match variant {
BinOpVariant::Normal => None,
BinOpVariant::AugAssign => Some(left_val),
},
left_val,
right_val,
)?
@ -1231,7 +1234,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
generator,
ctx,
ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None },
match variant {
BinOpVariant::Normal => None,
BinOpVariant::AugAssign => Some(left_val),
},
(left_val.as_base_value().into(), false),
(right_val.as_base_value().into(), false),
|generator, ctx, (lhs, rhs)| {
@ -1242,7 +1248,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
op,
(&Some(ndarray_dtype2), rhs),
ctx.current_loc,
is_aug_assign,
variant,
)?
.unwrap()
.to_basic_value_enum(
@ -1267,7 +1273,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
generator,
ctx,
ndarray_dtype,
if is_aug_assign { Some(ndarray_val) } else { None },
match variant {
BinOpVariant::Normal => None,
BinOpVariant::AugAssign => Some(ndarray_val),
},
(left_val, !is_ndarray1),
(right_val, !is_ndarray2),
|generator, ctx, (lhs, rhs)| {
@ -1278,7 +1287,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
op,
(&Some(ndarray_dtype), rhs),
ctx.current_loc,
is_aug_assign,
variant,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, ndarray_dtype)
@ -1293,13 +1302,15 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
unreachable!("must be tobj")
};
let (op_name, id) = {
let (binop_name, binop_assign_name) =
(binop_name(op).into(), binop_assign_name(op).into());
let normal_method_name = OpInfo::from_binop(op, BinOpVariant::Normal).method_name;
let assign_method_name = OpInfo::from_binop(op, BinOpVariant::AugAssign).method_name;
// if is aug_assign, try aug_assign operator first
if is_aug_assign && fields.contains_key(&binop_assign_name) {
(binop_assign_name, *obj_id)
if variant == BinOpVariant::AugAssign && fields.contains_key(&assign_method_name.into())
{
(assign_method_name.into(), *obj_id)
} else {
(binop_name, *obj_id)
(normal_method_name.into(), *obj_id)
}
};
@ -1349,7 +1360,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
op: Operator,
right: &Expr<Option<Type>>,
loc: Location,
is_aug_assign: bool,
variant: BinOpVariant,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let left_val = if let Some(v) = generator.gen_expr(ctx, left)? {
v.to_basic_value_enum(ctx, generator, left.custom.unwrap())?
@ -1369,7 +1380,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
op,
(&right.custom, right_val),
loc,
is_aug_assign,
variant,
)
}
@ -1453,7 +1464,10 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
if op == ast::Unaryop::Invert {
ast::Unaryop::Not
} else {
unreachable!("ufunc {} not supported for ndarray[bool, N]", unaryop_name(op))
unreachable!(
"ufunc {} not supported for ndarray[bool, N]",
OpInfo::from_unaryop(op).method_name
)
}
} else {
op
@ -2343,7 +2357,15 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}
}
ExprKind::BinOp { op, left, right } => {
return gen_binop_expr(generator, ctx, left, *op, right, expr.location, false);
return gen_binop_expr(
generator,
ctx,
left,
*op,
right,
expr.location,
BinOpVariant::Normal,
);
}
ExprKind::UnaryOp { op, operand } => return gen_unaryop_expr(generator, ctx, *op, operand),
ExprKind::Compare { left, ops, comparators } => {

View File

@ -11,8 +11,7 @@ use crate::{
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
},
llvm_intrinsics,
llvm_intrinsics::call_memcpy_generic,
llvm_intrinsics::{self, call_memcpy_generic},
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator,
},
@ -22,7 +21,10 @@ use crate::{
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId,
},
typecheck::typedef::{FunSignature, Type, TypeEnum},
typecheck::{
magic_methods::BinOpVariant,
typedef::{FunSignature, Type, TypeEnum},
},
};
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{
@ -163,10 +165,11 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
for shape_dim in shape {
for &shape_dim in shape {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let shape_dim_gez = ctx
.builder
.build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "")
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "")
.unwrap();
ctx.make_assert(
@ -189,7 +192,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
for (i, shape_dim) in shape.iter().enumerate() {
for (i, &shape_dim) in shape.iter().enumerate() {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_dim = unsafe {
ndarray.dim_sizes().ptr_offset_unchecked(
ctx,
@ -199,7 +203,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
)
};
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
}
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
@ -286,22 +290,68 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
///
/// ### Notes on `shape`
///
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, which is functionally equivalent to `np.empty([3])`
///
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] for
/// how a user-written Python `shape` expression is folded before it reaches this function.
fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
shape: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape,
|_, ctx, shape| Ok(shape.load_size(ctx, None)),
|generator, ctx, shape, idx| {
Ok(shape.data().get(ctx, generator, &idx, None).into_int_value())
},
)
let llvm_usize = generator.get_size_type(ctx.ctx);
match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
let ndims = shape_tuple.get_type().count_fields();
let mut shape = Vec::with_capacity(ndims as usize);
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => unreachable!(),
}
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
@ -486,7 +536,7 @@ fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
shape: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
@ -517,7 +567,7 @@ fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
shape: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
@ -548,7 +598,7 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
shape: BasicValueEnum<'ctx>,
fill_value: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
@ -1632,7 +1682,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
Operator::Mult,
(&Some(elem_ty), b),
ctx.current_loc,
false,
BinOpVariant::Normal,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
@ -1645,7 +1695,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
Operator::Add,
(&Some(elem_ty), a_mul_b),
ctx.current_loc,
false,
BinOpVariant::Normal,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, elem_ty)?;
@ -1674,17 +1724,11 @@ pub fn gen_ndarray_empty<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_empty_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.zeros`.
@ -1698,17 +1742,11 @@ pub fn gen_ndarray_zeros<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_zeros_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.ones`.
@ -1722,17 +1760,11 @@ pub fn gen_ndarray_ones<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_ones_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.full`.
@ -1746,21 +1778,14 @@ pub fn gen_ndarray_full<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 2);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
let fill_value_ty = fun.0.args[1].ty;
let fill_value_arg =
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
call_ndarray_full_impl(
generator,
context,
fill_value_ty,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
fill_value_arg,
)
.map(NDArrayValue::into)
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
.map(NDArrayValue::into)
}
pub fn gen_ndarray_array<'ctx>(

View File

@ -11,7 +11,10 @@ use crate::{
gen_in_range_check,
},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum},
typecheck::{
magic_methods::BinOpVariant,
typedef::{FunSignature, Type, TypeEnum},
},
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
@ -1574,7 +1577,15 @@ pub fn gen_stmt<G: CodeGenerator>(
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
StmtKind::AugAssign { target, op, value, .. } => {
let value = gen_binop_expr(generator, ctx, target, *op, value, stmt.location, true)?;
let value = gen_binop_expr(
generator,
ctx,
target,
*op,
value,
stmt.location,
BinOpVariant::AugAssign,
)?;
generator.gen_assign(ctx, target, value.unwrap())?;
}
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,

View File

@ -324,6 +324,9 @@ struct BuiltinBuilder<'a> {
num_or_ndarray_ty: TypeVar,
num_or_ndarray_var_map: VarMap,
/// See [`BuiltinBuilder::build_ndarray_from_shape_factory_function`]
ndarray_factory_fn_shape_arg_tvar: TypeVar,
}
impl<'a> BuiltinBuilder<'a> {
@ -394,6 +397,8 @@ impl<'a> BuiltinBuilder<'a> {
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
let ndarray_factory_fn_shape_arg_tvar = unifier.get_fresh_var(Some("Shape".into()), None);
BuiltinBuilder {
unifier,
primitives,
@ -421,6 +426,8 @@ impl<'a> BuiltinBuilder<'a> {
num_or_ndarray_ty,
num_or_ndarray_var_map,
ndarray_factory_fn_shape_arg_tvar,
}
}
@ -959,21 +966,46 @@ impl<'a> BuiltinBuilder<'a> {
)
}
/// Build ndarray factory functions that only take in an argument `shape` of type `list[int32]` and return an ndarray.
/// Build ndarray factory functions that take a single argument `shape`.
///
/// `shape` can be a tuple of int32s, a list of int32s, or a scalar int32.
fn build_ndarray_from_shape_factory_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpNDArray, PrimDef::FunNpEmpty, PrimDef::FunNpZeros, PrimDef::FunNpOnes],
);
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Ideally, we should have created a [`TypeVar`] to define all possible input
// types for the parameter "shape" like so:
// ```rust
// self.unifier.get_fresh_var_with_range(
// &[int32, list_int32, /* and more... */],
// Some("T".into()), None)
// )
// ```
//
// However, there is (currently) no way to type a tuple of arbitrary length in `nac3core`.
//
// And this is the best we could do:
// ```rust
// &[ int32, list_int32, tuple_1_int32, tuple_2_int32, tuple_3_int32, ... ],
// ```
//
// But this is not ideal.
//
// Instead, we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(self.list_int32, "shape")],
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim {
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,

View File

@ -5,7 +5,7 @@ expression: res_vec
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(239)]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(240)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar228]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar228\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar227, typevar228]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar227\", \"typevar228\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n",
]

View File

@ -5,7 +5,7 @@ use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
};
use itertools::Itertools;
use itertools::{iproduct, Itertools};
use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::cmp::max;
@ -13,64 +13,93 @@ use std::collections::HashMap;
use std::rc::Rc;
use strum::IntoEnumIterator;
#[must_use]
pub fn binop_name(op: Operator) -> &'static str {
match op {
Operator::Add => "__add__",
Operator::Sub => "__sub__",
Operator::Div => "__truediv__",
Operator::Mod => "__mod__",
Operator::Mult => "__mul__",
Operator::Pow => "__pow__",
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
}
/// Details about an operator (unary, binary, etc.) in Python.
#[derive(Debug, Clone, Copy)]
pub struct OpInfo {
/// The method name of the operator.
/// For addition, this would be `__add__`, or `__iadd__` for
/// the augmented assignment variant.
pub method_name: &'static str,
/// The symbol of the operator.
/// For addition, this would be `+`, or `+=` for
/// the augmented assignment variant.
pub symbol: &'static str,
}
#[must_use]
pub fn binop_assign_name(op: Operator) -> &'static str {
match op {
Operator::Add => "__iadd__",
Operator::Sub => "__isub__",
Operator::Div => "__itruediv__",
Operator::Mod => "__imod__",
Operator::Mult => "__imul__",
Operator::Pow => "__ipow__",
Operator::BitOr => "__ior__",
Operator::BitXor => "__ixor__",
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
/// Helper macro to conveniently build an [`OpInfo`].
///
/// Example usage: `make_info!("add", "+")` generates `OpInfo { method_name: "__add__", symbol: "+" }`
macro_rules! make_info {
($name:expr, $symbol:expr) => {
OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol }
};
}
#[must_use]
pub fn unaryop_name(op: Unaryop) -> &'static str {
match op {
Unaryop::UAdd => "__pos__",
Unaryop::USub => "__neg__",
Unaryop::Not => "__not__",
Unaryop::Invert => "__inv__",
}
/// The variant of a binary operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOpVariant {
/// The normal variant.
/// For addition, it would be `+`.
Normal,
/// The augmented assignment variant.
/// For addition, it would be `+=`.
AugAssign,
}
#[must_use]
pub fn comparison_name(op: Cmpop) -> Option<&'static str> {
match op {
Cmpop::Lt => Some("__lt__"),
Cmpop::LtE => Some("__le__"),
Cmpop::Gt => Some("__gt__"),
Cmpop::GtE => Some("__ge__"),
Cmpop::Eq => Some("__eq__"),
Cmpop::NotEq => Some("__ne__"),
_ => None,
impl OpInfo {
#[must_use]
pub fn from_binop(op: Operator, variant: BinOpVariant) -> Self {
// Helper macro to generate both the normal variant [`OpInfo`] and the
// augmented assignment variant [`OpInfo`] for a binary operator conveniently.
macro_rules! info {
($name:literal, $symbol:literal) => {
(make_info!($name, $symbol), make_info!(concat!("i", $name), concat!($symbol, "=")))
};
}
let (normal_variant, aug_assign_variant) = match op {
Operator::Add => info!("add", "+"),
Operator::Sub => info!("sub", "-"),
Operator::Div => info!("truediv", "/"),
Operator::Mod => info!("mod", "%"),
Operator::Mult => info!("mul", "*"),
Operator::Pow => info!("pow", "**"),
Operator::BitOr => info!("or", "|"),
Operator::BitXor => info!("xor", "^"),
Operator::BitAnd => info!("and", "&"),
Operator::LShift => info!("lshift", "<<"),
Operator::RShift => info!("rshift", ">>"),
Operator::FloorDiv => info!("floordiv", "//"),
Operator::MatMult => info!("matmul", "@"),
};
match variant {
BinOpVariant::Normal => normal_variant,
BinOpVariant::AugAssign => aug_assign_variant,
}
}
#[must_use]
pub fn from_unaryop(op: Unaryop) -> Self {
match op {
Unaryop::UAdd => make_info!("pos", "+"),
Unaryop::USub => make_info!("neg", "-"),
Unaryop::Not => make_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`.
Unaryop::Invert => make_info!("inv", "~"),
}
}
#[must_use]
pub fn from_cmpop(op: Cmpop) -> Option<Self> {
match op {
Cmpop::Lt => Some(make_info!("lt", "<")),
Cmpop::LtE => Some(make_info!("le", "<=")),
Cmpop::Gt => Some(make_info!("gt", ">")),
Cmpop::GtE => Some(make_info!("ge", ">=")),
Cmpop::Eq => Some(make_info!("eq", "==")),
Cmpop::NotEq => Some(make_info!("ne", "!=")),
_ => None,
}
}
}
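
As a usage sketch (not part of the diff, and assuming `OpInfo` and `BinOpVariant` are publicly reachable under `nac3core::typecheck::magic_methods`, as the imports in the other changed files suggest), the consolidated API replaces the old `binop_name`/`binop_assign_name`/`unaryop_name` lookups like this:

use nac3parser::ast::{Operator, Unaryop};
use nac3core::typecheck::magic_methods::{BinOpVariant, OpInfo};

fn opinfo_usage_sketch() {
    // `+` and its augmented-assignment form `+=`.
    let add = OpInfo::from_binop(Operator::Add, BinOpVariant::Normal);
    assert_eq!((add.method_name, add.symbol), ("__add__", "+"));

    let iadd = OpInfo::from_binop(Operator::Add, BinOpVariant::AugAssign);
    assert_eq!((iadd.method_name, iadd.symbol), ("__iadd__", "+="));

    // Unary negation maps to `__neg__`.
    let neg = OpInfo::from_unaryop(Unaryop::USub);
    assert_eq!((neg.method_name, neg.symbol), ("__neg__", "-"));
}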
@ -115,23 +144,8 @@ pub fn impl_binop(
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops {
fields.insert(binop_name(*op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
})),
false,
)
});
fields.insert(binop_assign_name(*op).into(), {
for (op, variant) in iproduct!(ops, [BinOpVariant::Normal, BinOpVariant::AugAssign]) {
fields.insert(OpInfo::from_binop(*op, variant).method_name.into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
@ -155,7 +169,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
for op in ops {
fields.insert(
unaryop_name(*op).into(),
OpInfo::from_unaryop(*op).method_name.into(),
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
@ -195,7 +209,7 @@ pub fn impl_cmpop(
for op in ops {
fields.insert(
comparison_name(*op).unwrap().into(),
OpInfo::from_cmpop(*op).unwrap().method_name.into(),
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,

View File

@ -1,11 +1,14 @@
use std::collections::HashMap;
use std::fmt::Display;
use crate::typecheck::typedef::TypeEnum;
use crate::typecheck::{magic_methods::OpInfo, typedef::TypeEnum};
use super::typedef::{RecordKey, Type, Unifier};
use super::{
magic_methods::BinOpVariant,
typedef::{RecordKey, Type, Unifier},
};
use itertools::Itertools;
use nac3parser::ast::{Location, StrRef};
use nac3parser::ast::{Cmpop, Location, Operator, StrRef};
#[derive(Debug, Clone)]
pub enum TypeErrorKind {
@ -26,6 +29,19 @@ pub enum TypeErrorKind {
expected: Type,
got: Type,
},
UnsupportedBinaryOpTypes {
operator: Operator,
variant: BinOpVariant,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
UnsupportedComparisonOpTypes {
operator: Cmpop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
FieldUnificationError {
field: RecordKey,
types: (Type, Type),
@ -101,6 +117,32 @@ impl<'a> Display for DisplayTypeError<'a> {
let args = missing_arg_names.iter().join(", ");
write!(f, "Missing arguments: {args}")
}
UnsupportedBinaryOpTypes {
operator,
variant,
lhs_type,
rhs_type,
expected_rhs_type,
} => {
let op_symbol = OpInfo::from_binop(*operator, *variant).symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnsupportedComparisonOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = OpInfo::from_cmpop(*operator).unwrap().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnknownArgName(name) => {
write!(f, "Unknown argument name: {name}")
}
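
For illustration only, a minimal self-contained sketch of the message shapes produced by the two new arms above; the operand type names are placeholders for whatever `Unifier::stringify_with_notes` returns:

fn main() {
    // Placeholder type names, purely illustrative.
    let (lhs, rhs, expected) = ("int32", "str", "int32");

    // Shape of the `UnsupportedBinaryOpTypes` message, e.g. for an ill-typed `a + b`.
    println!("Unsupported operand type(s) for +: '{lhs}' and '{rhs}' (right operand should have type {expected})");

    // Shape of the `UnsupportedComparisonOpTypes` message, e.g. for an ill-typed `a < b`.
    println!("'<' not supported between instances of '{lhs}' and '{rhs}' (right operand should have type {expected})");
}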

View File

@ -4,7 +4,9 @@ use std::iter::once;
use std::ops::Not;
use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::typedef::{
Call, CallInfo, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap,
};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::toplevel::TopLevelDef;
use crate::{
@ -466,7 +468,8 @@ impl<'a> Fold<()> for Inferencer<'a> {
(None, None) => {}
},
ast::StmtKind::AugAssign { target, op, value, .. } => {
let res_ty = self.infer_bin_ops(stmt.location, target, *op, value, true)?;
let res_ty =
self.infer_bin_ops(stmt.location, target, *op, value, BinOpVariant::AugAssign)?;
self.unify(res_ty, target.custom.unwrap(), &stmt.location)?;
}
ast::StmtKind::Assert { test, msg, .. } => {
@ -548,7 +551,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
}
ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
ExprKind::BinOp { left, op, right } => {
Some(self.infer_bin_ops(expr.location, left, *op, right, false)?)
Some(self.infer_bin_ops(expr.location, left, *op, right, BinOpVariant::Normal)?)
}
ExprKind::UnaryOp { op, operand } => {
Some(self.infer_unary_ops(expr.location, *op, operand)?)
@ -615,6 +618,7 @@ impl<'a> Inferencer<'a> {
obj: Type,
params: Vec<Type>,
ret: Option<Type>,
call_info: CallInfo,
) -> InferenceResult {
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
if class_params.is_empty() {
@ -628,6 +632,7 @@ impl<'a> Inferencer<'a> {
ret: sign.ret,
fun: RefCell::new(None),
loc: Some(location),
info: call_info,
};
if let Some(ret) = ret {
self.unifier
@ -662,6 +667,7 @@ impl<'a> Inferencer<'a> {
ret,
fun: RefCell::new(None),
loc: Some(location),
info: call_info,
});
self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
@ -814,6 +820,150 @@ impl<'a> Inferencer<'a> {
})
}
/// Fold an ndarray `shape` argument. This function aims to fold `shape` arguments like that of
/// <https://numpy.org/doc/stable/reference/generated/numpy.zeros.html> (for `np_zeros`).
///
/// Arguments:
/// * `id` - The name of the function being called that this `shape` argument belongs to. Used for error reporting.
/// * `arg_index` - The position (0-based) of this argument in the function call. Used for error reporting.
/// * `shape_expr` - [`Located<ExprKind>`] of the input argument.
///
/// On success, it returns a tuple of
/// 1) the `ndims` value inferred from the input `shape`, and
/// 2) the elaborated expression, as other fold functions of [`Inferencer`] normally return.
fn fold_numpy_function_call_shape_argument(
&mut self,
id: StrRef,
arg_index: usize,
shape_expr: Located<ExprKind>,
) -> Result<(u64, ast::Expr<Option<Type>>), HashSet<String>> {
/*
### Further explanation
As said, this function aims to fold `shape` arguments, but this is *not* trivial.
The root of the issue is that `nac3core` has to deduce the `ndims`
of the created ndarray (e.g., in the case of `np_zeros`) statically - i.e., at inference time.
There are three types of valid input to `shape`:
1. A python `List` (all `int32s`); e.g., `np_zeros([600, 800, 3])`
2. A python `Tuple` (all `int32s`); e.g., `np_zeros((600, 800, 3))`
3. An `int32`; e.g., `np_zeros(256)` - this is functionally equivalent to `np_zeros([256])`
For 2. and 3., `ndims` can be deduced immediately from the inferred type of the input:
- For 2. `ndims` is simply the number of elements found in [`TypeEnum::TTuple`] after typechecking the `shape` argument.
- For 3. `ndims` is simply 1.
For 1., `ndims` is supposedly the length of the input list. However, the length of the input list
is a runtime property. Therefore (as a hack) we resort to analyzing the argument expression [`ExprKind::List`]
itself to extract the input list length statically.
This implies that the user could only write:
```python
my_rgba_image = np_zeros([600, 800, 4])
# the shape argument is directly written as a list literal.
# and `nac3core` could therefore tell that ndims is `3` by
# looking at the raw AST expression itself.
```
But not:
```python
my_image_dimension = [600, 800, 4]
mystery_function_that_mutates_my_list(my_image_dimension)
my_image = np_zeros(my_image_dimension)
# what is the length now? what is `ndims`?
# it is *basically impossible* to generally determine the
# length of `my_image_dimension` statically for `ndims`!!
```
*/
// Fold `shape`
let shape = self.fold_expr(shape_expr)?;
let shape_ty = shape.custom.unwrap(); // The inferred type of `shape`
// Check `shape_ty` to see if it's a list of int32s, a tuple of int32s, or just an int32.
// Otherwise throw an error as that would mean the user wrote an ill-typed `shape_expr`.
//
// Here, we also take the opportunity to deduce `ndims` statically.
let shape_ty_enum = &*self.unifier.get_ty(shape_ty);
let ndims = match shape_ty_enum {
TypeEnum::TList { ty } => {
// Handle 1. A list of int32s
// Typecheck
self.unifier.unify(*ty, self.primitives.int32).map_err(|err| {
HashSet::from([err
.at(Some(shape.location))
.to_display(self.unifier)
.to_string()])
})?;
// Special handling for case 1 (a Python `List` of `int32`s).
// Read the doc above this function to see what is going on here.
if let ExprKind::List { elts, .. } = &shape.node {
// The user wrote a List literal as the input argument
elts.len() as u64
} else {
// This means the user is passing an expression of type `List`,
// but does so indirectly (e.g., through a variable that references a `List`)
// rather than writing a List literal. We need to report an error.
return Err(HashSet::from([
format!(
"Expected list literal, tuple, or int32 for argument {arg_num} of {id} at {location}. Input argument is of type list but not a list literal.",
arg_num = arg_index + 1,
location = shape.location
)
]));
}
}
TypeEnum::TTuple { ty: tuple_element_types } => {
// Handle 2. A tuple of int32s
// Typecheck
// The expected type is just the tuple but with all its elements being int32.
let expected_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: tuple_element_types.iter().map(|_| self.primitives.int32).collect_vec(),
});
self.unifier.unify(shape_ty, expected_ty).map_err(|err| {
HashSet::from([err
.at(Some(shape.location))
.to_display(self.unifier)
.to_string()])
})?;
// `ndims` can be deduced statically from the inferred Tuple type.
tuple_element_types.len() as u64
}
TypeEnum::TObj { .. } => {
// Handle 3. An integer (generalized as [`TypeEnum::TObj`])
// Typecheck
self.unify(self.primitives.int32, shape_ty, &shape.location)?;
// Deduce `ndims`
1
}
_ => {
// The user wrote an ill-typed `shape_expr`,
// so throw an error.
let shape_ty_str = self.unifier.stringify(shape_ty);
return report_error(
format!(
"Expected list literal, tuple, or int32 for argument {arg_num} of {id}, got {shape_expr_name} of type {shape_ty_str}",
arg_num = arg_index + 1,
shape_expr_name = shape.node.name(),
)
.as_str(),
shape.location,
);
}
};
Ok((ndims, shape))
}
/// Tries to fold a special call. Returns [`Some`] if the call expression `func` is a special call, otherwise
/// returns [`None`].
fn try_fold_special_call(
@ -1141,25 +1291,15 @@ impl<'a> Inferencer<'a> {
}));
}
// 1-argument ndarray n-dimensional creation functions
// 1-argument ndarray n-dimensional factory functions
if ["np_ndarray".into(), "np_empty".into(), "np_zeros".into(), "np_ones".into()]
.contains(id)
&& args.len() == 1
{
let ExprKind::List { elts, .. } = &args[0].node else {
return report_error(
format!(
"Expected List literal for first argument of {id}, got {}",
args[0].node.name()
)
.as_str(),
args[0].location,
);
};
let shape_expr = args.remove(0);
let (ndims, shape) =
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
let ndims = elts.len() as u64;
let arg0 = self.fold_expr(args.remove(0))?;
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(
self.unifier,
@ -1170,7 +1310,7 @@ impl<'a> Inferencer<'a> {
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "shape".into(),
ty: arg0.custom.unwrap(),
ty: shape.custom.unwrap(),
default_value: None,
}],
ret,
@ -1186,7 +1326,7 @@ impl<'a> Inferencer<'a> {
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![arg0],
args: vec![shape],
keywords: vec![],
},
}));
@ -1339,6 +1479,7 @@ impl<'a> Inferencer<'a> {
fun: RefCell::new(None),
ret: sign.ret,
loc: Some(location),
info: CallInfo::IsNormalFunctionCall,
};
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
@ -1361,6 +1502,7 @@ impl<'a> Inferencer<'a> {
fun: RefCell::new(None),
ret,
loc: Some(location),
info: CallInfo::IsNormalFunctionCall,
});
self.calls.insert(location.into(), call);
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
@ -1536,7 +1678,7 @@ impl<'a> Inferencer<'a> {
left: &ast::Expr<Option<Type>>,
op: ast::Operator,
right: &ast::Expr<Option<Type>>,
is_aug_assign: bool,
variant: BinOpVariant,
) -> InferenceResult {
let left_ty = left.custom.unwrap();
let right_ty = right.custom.unwrap();
@ -1544,27 +1686,39 @@ impl<'a> Inferencer<'a> {
let method = if let TypeEnum::TObj { fields, .. } =
self.unifier.get_ty_immutable(left_ty).as_ref()
{
let (binop_name, binop_assign_name) =
(binop_name(op).into(), binop_assign_name(op).into());
let normal_method_name = OpInfo::from_binop(op, BinOpVariant::Normal).method_name;
let assign_method_name = OpInfo::from_binop(op, BinOpVariant::AugAssign).method_name;
// if is aug_assign, try aug_assign operator first
if is_aug_assign && fields.contains_key(&binop_assign_name) {
binop_assign_name
if variant == BinOpVariant::AugAssign && fields.contains_key(&assign_method_name.into())
{
assign_method_name
} else {
binop_name
normal_method_name
}
} else {
binop_name(op).into()
OpInfo::from_binop(op, variant).method_name
};
let ret = if is_aug_assign {
// The type of augmented assignment operator should never change
Some(left_ty)
} else {
typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty)
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
let ret = match variant {
BinOpVariant::Normal => {
typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty)
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
}
BinOpVariant::AugAssign => {
// The type of augmented assignment operator should never change
Some(left_ty)
}
};
self.build_method_call(location, method, left_ty, vec![right_ty], ret)
self.build_method_call(
location,
method.into(),
left_ty,
vec![right_ty],
ret,
CallInfo::IsBinaryOp { self_type: left.custom.unwrap(), operator: op, variant },
)
}
fn infer_unary_ops(
@ -1573,12 +1727,19 @@ impl<'a> Inferencer<'a> {
op: ast::Unaryop,
operand: &ast::Expr<Option<Type>>,
) -> InferenceResult {
let method = unaryop_name(op).into();
let method = OpInfo::from_unaryop(op).method_name.into();
let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap())
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
self.build_method_call(location, method, operand.custom.unwrap(), vec![], ret)
self.build_method_call(
location,
method,
operand.custom.unwrap(),
vec![],
ret,
CallInfo::IsUnaryOp { self_type: operand.custom.unwrap(), operator: op },
)
}
fn infer_compare(
@ -1603,8 +1764,9 @@ impl<'a> Inferencer<'a> {
let mut res = None;
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method = comparison_name(*c)
let method = OpInfo::from_cmpop(*c)
.ok_or_else(|| HashSet::from(["unsupported comparator".to_string()]))?
.method_name
.into();
let ret = typeof_cmpop(
@ -1622,6 +1784,7 @@ impl<'a> Inferencer<'a> {
a.custom.unwrap(),
vec![b.custom.unwrap()],
ret,
CallInfo::IsComparisonOp { self_type: left.custom.unwrap(), operator: *c },
)?);
}
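
The `ndims` deduction rules implemented above can be condensed into a small sketch (hypothetical helper types, not part of the diff; the real logic inspects `TypeEnum` and the argument AST):

// Condensed model of fold_numpy_function_call_shape_argument's ndims rules.
enum ShapeArg {
    ListLiteral { len: usize }, // np_zeros([600, 800, 3]) -> ndims = 3 (literal length)
    Tuple { len: usize },       // np_zeros((600, 800, 3)) -> ndims = 3 (tuple arity)
    Scalar,                     // np_zeros(256)           -> ndims = 1
}

fn deduce_ndims(shape: &ShapeArg) -> u64 {
    match shape {
        ShapeArg::ListLiteral { len } | ShapeArg::Tuple { len } => *len as u64,
        ShapeArg::Scalar => 1,
    }
}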

View File

@ -8,12 +8,14 @@ use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet};
use nac3parser::ast::{Location, StrRef};
use nac3parser::ast::{Cmpop, Location, Operator, StrRef, Unaryop};
use super::magic_methods::BinOpVariant;
use super::type_error::{TypeError, TypeErrorKind};
use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
use crate::typecheck::magic_methods::OpInfo;
use crate::typecheck::type_inferencer::PrimitiveStore;
#[cfg(test)]
@ -73,6 +75,32 @@ pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ {
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
}
/// Extra details about how a [`Call`] was written by the user.
#[derive(Debug, Clone)]
pub enum CallInfo {
/// The call was written as a unary operation, e.g., `~a` or `not a`.
IsUnaryOp {
/// The [`Type`] of the `self` object
self_type: Type,
operator: Unaryop,
},
/// The call was written as a binary operation, e.g., `a + b` or `a += b`.
IsBinaryOp {
/// The [`Type`] of the `self` object
self_type: Type,
operator: Operator,
variant: BinOpVariant,
},
/// The call was written as a binary comparison operation, e.g., `a < b`.
IsComparisonOp {
/// The [`Type`] of the `self` object
self_type: Type,
operator: Cmpop,
},
/// "Normal" function calls that look like `func(1, 2, 3)`.
IsNormalFunctionCall,
}
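// Sketch (not part of the diff): how the inferencer is expected to tag the call
// produced by a binary expression such as `a + b`. The parameters stand in for
// already-inferred types and the source location, and this assumes the imports
// already present in this file (`HashMap`, `RefCell`, `Operator`, `Location`).
#[allow(dead_code)]
fn tag_binop_call_sketch(lhs_ty: Type, rhs_ty: Type, ret_ty: Type, loc: Location) -> Call {
    Call {
        posargs: vec![rhs_ty],
        kwargs: HashMap::new(),
        ret: ret_ty,
        fun: RefCell::new(None),
        loc: Some(loc),
        info: CallInfo::IsBinaryOp {
            self_type: lhs_ty,
            operator: Operator::Add,
            variant: BinOpVariant::Normal,
        },
    }
}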
#[derive(Clone)]
pub struct Call {
pub posargs: Vec<Type>,
@ -80,6 +108,7 @@ pub struct Call {
pub ret: Type,
pub fun: RefCell<Option<Type>>,
pub loc: Option<Location>,
pub info: CallInfo,
}
#[derive(Debug, Clone)]
@ -618,111 +647,179 @@ impl Unifier {
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
// Get details about the input arguments
let Call { posargs, kwargs, ret, fun, loc } = call;
let Call { posargs, kwargs, ret, fun, loc, info: call_info } = call;
let num_args = posargs.len() + kwargs.len();
// Now we check the arguments against the parameters
// Now we check the arguments against the parameters,
// and depending on what `call_info` is, we might change the behavior of `unify_call()`
// in hopes of improving user error messages when type checking fails.
match call_info {
CallInfo::IsBinaryOp { self_type, operator, variant } => {
// The call is written in the form of (say) `a + b`.
// Technically, it is `a.__add__(b)`, and it has the following constraints:
assert_eq!(posargs.len(), 1);
assert_eq!(kwargs.len(), 0);
assert_eq!(num_params, 1);
// Helper lambdas
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
if ok {
Ok(())
} else {
// Typecheck failed, throw an error.
self.restore_snapshot();
Err(TypeError::new(
TypeErrorKind::IncorrectArgType {
name: param_name,
expected: expected_arg_ty,
got: arg_ty,
},
*loc,
))
let other_type = posargs[0]; // the second operand
let expected_other_type = signature.args[0].ty;
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
if !ok {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::UnsupportedBinaryOpTypes {
operator: *operator,
variant: *variant,
lhs_type: *self_type,
rhs_type: other_type,
expected_rhs_type: expected_other_type,
},
*loc,
));
}
}
};
CallInfo::IsComparisonOp { self_type, operator }
if OpInfo::from_cmpop(*operator).is_some() // Otherwise that comparison operator is not supported.
=>
{
// The call is written in the form of (say) `a <= b`.
// Technically, it is `a.__le__(b)`, and it has the following constraints:
assert_eq!(posargs.len(), 1);
assert_eq!(kwargs.len(), 0);
assert_eq!(num_params, 1);
// Check for "too many arguments"
if num_params < posargs.len() {
let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params;
let other_type = posargs[0]; // the second operand
let expected_other_type = signature.args[0].ty;
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::TooManyArguments {
expected_min_count,
expected_max_count,
got_count: num_args,
},
*loc,
));
}
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
.args
.iter()
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
.collect();
// Now consume all positional arguments and typecheck them.
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
let param_info = param_info_by_name.get_mut(&param.name).unwrap();
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
// Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal".
let Some(param_info) = param_info_by_name.get_mut(&param_name) else {
self.restore_snapshot();
return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc));
};
if param_info.has_been_supplied {
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
// is IMPOSSIBLE as the parser would have already failed.
// We only have to care about "got multiple values for XYZ"
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::GotMultipleValues { name: param_name },
*loc,
));
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
if !ok {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::UnsupportedComparisonOpTypes {
operator: *operator,
lhs_type: *self_type,
rhs_type: other_type,
expected_rhs_type: expected_other_type,
},
*loc,
));
}
}
_ => {
// Handle [`CallInfo::IsNormalFunctionCall`] and other uninteresting variants
// of [`CallInfo`] (e.g., `CallInfo::IsUnaryOp` and unsupported comparison operators)
param_info.has_been_supplied = true;
// Helper lambdas
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
if ok {
Ok(())
} else {
// Typecheck failed, throw an error.
self.restore_snapshot();
Err(TypeError::new(
TypeErrorKind::IncorrectArgType {
name: param_name,
expected: expected_arg_ty,
got: arg_ty,
},
*loc,
))
}
};
// Typecheck
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
}
// Check for "too many arguments"
if num_params < posargs.len() {
let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params;
// After checking posargs and kwargs, check if there are any
// unsupplied required parameters, and throw an error if they exist.
let missing_arg_names = param_info_by_name
.values()
.filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied)
.map(|param_info| param_info.param.name)
.collect_vec();
if !missing_arg_names.is_empty() {
self.restore_snapshot();
return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc));
}
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::TooManyArguments {
expected_min_count,
expected_max_count,
got_count: num_args,
},
*loc,
));
}
// Finally, check the Call's return type
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
self.restore_snapshot();
if err.loc.is_none() {
err.loc = *loc;
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
.args
.iter()
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
.collect();
// Now consume all positional arguments and typecheck them.
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
let param_info = param_info_by_name.get_mut(&param.name).unwrap();
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
// Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal".
let Some(param_info) = param_info_by_name.get_mut(&param_name) else {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::UnknownArgName(param_name),
*loc,
));
};
if param_info.has_been_supplied {
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
// is IMPOSSIBLE as the parser would have already failed.
// We only have to care about "got multiple values for XYZ"
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::GotMultipleValues { name: param_name },
*loc,
));
}
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
}
// After checking posargs and kwargs, check if there are any
// unsupplied required parameters, and throw an error if they exist.
let missing_arg_names = param_info_by_name
.values()
.filter(|param_info| {
param_info.param.is_required() && !param_info.has_been_supplied
})
.map(|param_info| param_info.param.name)
.collect_vec();
if !missing_arg_names.is_empty() {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::MissingArgs { missing_arg_names },
*loc,
));
}
// Finally, check the Call's return type
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
self.restore_snapshot();
if err.loc.is_none() {
err.loc = *loc;
}
err
})?;
}
err
})?;
}
*fun.borrow_mut() = Some(b);

View File

@ -71,17 +71,44 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
pass
def test_ndarray_ctor():
n: ndarray[float, Literal[1]] = np_ndarray([1])
consume_ndarray_1(n)
def test_ndarray_empty():
n: ndarray[float, 1] = np_empty([1])
consume_ndarray_1(n)
n1: ndarray[float, 1] = np_empty([1])
consume_ndarray_1(n1)
n2: ndarray[float, 1] = np_empty(10)
consume_ndarray_1(n2)
n3: ndarray[float, 1] = np_empty((2,))
consume_ndarray_1(n3)
n4: ndarray[float, 2] = np_empty((4, 4))
consume_ndarray_2(n4)
dim4 = (5, 2)
n5: ndarray[float, 2] = np_empty(dim4)
consume_ndarray_2(n5)
def test_ndarray_zeros():
n: ndarray[float, 1] = np_zeros([1])
output_ndarray_float_1(n)
n1: ndarray[float, 1] = np_zeros([1])
output_ndarray_float_1(n1)
k = 3 + int32(n1[0]) # to test variable shape inputs
n2: ndarray[float, 1] = np_zeros(k * k)
output_ndarray_float_1(n2)
n3: ndarray[float, 1] = np_zeros((k * 2,))
output_ndarray_float_1(n3)
dim4 = (3, 2 * k)
n4: ndarray[float, 2] = np_zeros(dim4)
output_ndarray_float_2(n4)
def test_ndarray_ones():
n: ndarray[float, 1] = np_ones([1])