Compare commits

...

7 Commits

Author SHA1 Message Date
a763ea3b61 Fix common issue WIP(assignment) 2024-08-27 17:53:29 +08:00
5b2b6db7ed core: improve error messages 2024-08-26 18:37:55 +08:00
15e62f467e standalone: add tests for polymorphism 2024-08-26 18:37:55 +08:00
2c88924ff7 core: add support for simple polymorphism 2024-08-26 18:37:55 +08:00
a744b139ba core: allow Call and AnnAssign in init block 2024-08-26 18:37:55 +08:00
2b2b2dbf8f [core] Fix resolution of exception names in raise short form
Previous implementation fails as `resolver.get_identifier_def` in ARTIQ
would return the exception __init__ function rather than the class.

We fix this by limiting the exception class resolution to only include
raise statements, and to force the exception name to always be treated
as a class.

Fixes #501.
2024-08-26 18:35:02 +08:00
d9f96dab33 [core] Add codegen_unreachable 2024-08-23 13:10:55 +08:00
11 changed files with 483 additions and 155 deletions

View File

@ -9,6 +9,7 @@ use crate::codegen::classes::{
};
use crate::codegen::expr::destructure_range;
use crate::codegen::irrt::calculate_len_for_slice_range;
use crate::codegen::macros::codegen_unreachable;
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
@ -20,7 +21,8 @@ use crate::typecheck::typedef::{Type, TypeEnum};
///
/// The generated message will contain the function name and the name of the unsupported type.
fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -> ! {
unreachable!(
codegen_unreachable!(
ctx,
"{fn_name}() not supported for '{}'",
tys.iter().map(|ty| format!("'{}'", ctx.unifier.stringify(*ty))).join(", "),
)
@ -82,7 +84,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
}
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
}
})
}
@ -784,7 +786,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -888,7 +890,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
match fn_name {
"np_argmin" | "np_argmax" => llvm_int64.const_zero().into(),
"np_max" | "np_min" => a,
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
}
}
BasicValueEnum::PointerValue(n)
@ -943,7 +945,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
"np_argmax" | "np_max" => {
call_max(ctx, (elem_ty, accumulator), (elem_ty, elem))
}
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
};
let updated_idx = match (accumulator, result) {
@ -980,7 +982,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
match fn_name {
"np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
"np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
}
}
@ -1046,7 +1048,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1486,7 +1488,7 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1553,7 +1555,7 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1620,7 +1622,7 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1687,7 +1689,7 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1810,7 +1812,7 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1877,7 +1879,7 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };

View File

@ -11,6 +11,7 @@ use crate::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_int_umin, call_memcpy_generic,
},
macros::codegen_unreachable,
need_sret, numpy,
stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
@ -112,7 +113,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id,
// we cannot have other types, virtual type should be handled by function calls
_ => unreachable!(),
_ => codegen_unreachable!(self),
};
let def = &self.top_level.definitions.read()[obj_id.0];
let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() {
@ -123,7 +124,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
(attribute_index.0, Some(attribute_index.1 .2.clone()))
}
} else {
unreachable!()
codegen_unreachable!(self)
};
(index, value)
}
@ -133,7 +134,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
TypeEnum::TObj { fields, .. } => {
fields.iter().find_position(|x| *x.0 == attr).unwrap().0
}
_ => unreachable!(),
_ => codegen_unreachable!(self),
}
}
@ -188,7 +189,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type"),
_ => codegen_unreachable!(self, "must be option type"),
};
let val = self.gen_symbol_val(generator, v, ty);
let ptr = generator
@ -204,7 +205,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type"),
_ => codegen_unreachable!(self, "must be option type"),
};
let actual_ptr_type =
self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default());
@ -271,7 +272,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
{
self.ctx.i64_type()
} else {
unreachable!()
codegen_unreachable!(self)
};
Some(ty.const_int(*val as u64, false).into())
}
@ -285,7 +286,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let (types, is_vararg_ctx) = if let TypeEnum::TTuple { ty, is_vararg_ctx } = &*ty {
(ty.clone(), *is_vararg_ctx)
} else {
unreachable!()
codegen_unreachable!(self)
};
let values = zip(types, v.iter())
.map_while(|(ty, v)| self.gen_const(generator, v, ty))
@ -330,7 +331,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
None
}
_ => unreachable!(),
_ => codegen_unreachable!(self),
}
}
@ -344,7 +345,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
signed: bool,
) -> BasicValueEnum<'ctx> {
let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) else {
unreachable!()
codegen_unreachable!(self)
};
let float = self.ctx.f64_type();
match (op, signed) {
@ -419,7 +420,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.build_right_shift(lhs, rhs, signed, "rshift")
.map(Into::into)
.unwrap(),
_ => unreachable!(),
_ => codegen_unreachable!(self),
}
}
@ -431,7 +432,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
(Operator::Pow, s) => integer_power(generator, self, lhs, rhs, s).into(),
// special implementation?
(Operator::MatMult, _) => unreachable!(),
(Operator::MatMult, _) => codegen_unreachable!(self),
}
}
@ -443,7 +444,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
rhs: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else {
unreachable!(
codegen_unreachable!(
self,
"Expected (FloatValue, FloatValue), got ({}, {})",
lhs.get_type(),
rhs.get_type()
@ -687,7 +689,7 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>(
def: &TopLevelDef,
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<BasicValueEnum<'ctx>, String> {
let TopLevelDef::Class { methods, .. } = def else { unreachable!() };
let TopLevelDef::Class { methods, .. } = def else { codegen_unreachable!(ctx) };
// TODO: what about other fields that require alloca?
let fun_id = methods.iter().find(|method| method.0 == "__init__".into()).map(|method| method.2);
@ -719,7 +721,7 @@ pub fn gen_func_instance<'ctx>(
key,
) = fun
else {
unreachable!()
codegen_unreachable!(ctx)
};
if let Some(sym) = instance_to_symbol.get(&key) {
@ -751,7 +753,7 @@ pub fn gen_func_instance<'ctx>(
.collect();
let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache);
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { codegen_unreachable!(ctx) };
if let Some(obj) = &obj {
let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
@ -1117,7 +1119,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>,
expr: &Expr<Option<Type>>,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let ExprKind::ListComp { elt, generators } = &expr.node else { unreachable!() };
let ExprKind::ListComp { elt, generators } = &expr.node else { codegen_unreachable!(ctx) };
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
@ -1376,13 +1378,13 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) {
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let elem_ty2 =
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) {
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
} else {
unreachable!()
codegen_unreachable!(ctx)
};
debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2));
@ -1455,7 +1457,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
{
*params.iter().next().unwrap().1
} else {
unreachable!()
codegen_unreachable!(ctx)
};
(elem_ty, left_val, right_val)
@ -1465,12 +1467,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
{
*params.iter().next().unwrap().1
} else {
unreachable!()
codegen_unreachable!(ctx)
};
(elem_ty, right_val, left_val)
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let list_val =
ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
@ -1637,7 +1639,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
unreachable!("must be tobj")
codegen_unreachable!(ctx, "must be tobj")
};
let (op_name, id) = {
let normal_method_name = Binop::normal(op.base).op_info().method_name;
@ -1658,19 +1660,19 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else {
let left_enum_ty = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else {
unreachable!("must be tobj")
codegen_unreachable!(ctx, "must be tobj")
};
let fn_ty = fields.get(&op_name).unwrap().0;
let fn_ty_enum = ctx.unifier.get_ty_immutable(fn_ty);
let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { unreachable!() };
let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { codegen_unreachable!(ctx) };
sig.clone()
};
let fun_id = {
let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read();
let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() };
let TopLevelDef::Class { methods, .. } = &*obj_def else { codegen_unreachable!(ctx) };
methods.iter().find(|method| method.0 == op_name).unwrap().2
};
@ -1801,7 +1803,8 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
if op == ast::Unaryop::Invert {
ast::Unaryop::Not
} else {
unreachable!(
codegen_unreachable!(
ctx,
"ufunc {} not supported for ndarray[bool, N]",
op.op_info().method_name,
)
@ -1868,8 +1871,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let (Some(left_ty), lhs) = left else { unreachable!() };
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) };
let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) };
let op = ops[0];
let is_ndarray1 =
@ -1976,7 +1979,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ,
ast::Cmpop::NotEq => IntPredicate::NE,
_ if left_ty == ctx.primitives.bool => unreachable!(),
_ if left_ty == ctx.primitives.bool => codegen_unreachable!(ctx),
ast::Cmpop::Lt => {
if use_unsigned_ops {
IntPredicate::ULT
@ -2005,7 +2008,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
IntPredicate::SGE
}
}
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap()
@ -2022,7 +2025,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap()
} else if left_ty == ctx.primitives.str {
@ -2154,7 +2157,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
match (op, val) {
(Cmpop::Eq, true) | (Cmpop::NotEq, false) => llvm_i1.const_all_ones(),
(Cmpop::Eq, false) | (Cmpop::NotEq, true) => llvm_i1.const_zero(),
(_, _) => unreachable!(),
(_, _) => codegen_unreachable!(ctx),
}
};
@ -2167,14 +2170,14 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
{
*params.iter().next().unwrap().1
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let right_elem_ty = if let TypeEnum::TObj { params, .. } =
&*ctx.unifier.get_ty_immutable(right_ty)
{
*params.iter().next().unwrap().1
} else {
unreachable!()
codegen_unreachable!(ctx)
};
if !ctx.unifier.unioned(left_elem_ty, right_elem_ty) {
@ -2511,7 +2514,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let llvm_usize = generator.get_size_type(ctx.ctx);
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
unreachable!()
codegen_unreachable!(ctx)
};
let ndims = values
@ -2863,7 +2866,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.const_null()
.into()
}
_ => unreachable!("must be option type"),
_ => codegen_unreachable!(ctx, "must be option type"),
}
}
ExprKind::Name { id, .. } => match ctx.var_assignment.get(id) {
@ -2873,29 +2876,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
None => {
let resolver = ctx.resolver.clone();
if let Some(res) = resolver.get_symbol_value(*id, ctx) {
res
} else {
// Allow "raise Exception" short form
let def_id = resolver.get_identifier_def(*id).map_err(|e| {
format!("{} (at {})", e.iter().next().unwrap(), expr.location)
})?;
let def = ctx.top_level.definitions.read();
if let TopLevelDef::Class { constructor, .. } = *def[def_id.0].read() {
let TypeEnum::TFunc(signature) =
ctx.unifier.get_ty(constructor.unwrap()).as_ref().clone()
else {
return Err(format!(
"Failed to resolve symbol {} (at {})",
id, expr.location
));
};
return Ok(generator
.gen_call(ctx, None, (&signature, def_id), Vec::default())?
.map(Into::into));
}
return Err(format!("Failed to resolve symbol {} (at {})", id, expr.location));
}
resolver.get_symbol_value(*id, ctx).unwrap()
}
},
ExprKind::List { elts, .. } => {
@ -2924,7 +2905,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
*params.iter().next().unwrap().1
} else {
unreachable!()
codegen_unreachable!(ctx)
};
if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) {
@ -3018,7 +2999,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
return generator.gen_expr(ctx, &modified_expr);
}
None => unreachable!("Function Type should not have attributes"),
None => {
codegen_unreachable!(ctx, "Function Type should not have attributes")
}
}
} else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) {
if fields.is_empty() && params.is_empty() {
@ -3040,7 +3023,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
return generator.gen_expr(ctx, &modified_expr);
}
None => unreachable!(),
None => codegen_unreachable!(ctx),
}
}
}
@ -3142,7 +3125,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}
(Some(a), None) => a.into(),
(None, Some(b)) => b.into(),
(None, None) => unreachable!(),
(None, None) => codegen_unreachable!(ctx),
}
}
ExprKind::BinOp { op, left, right } => {
@ -3232,7 +3215,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ctx.unifier.get_call_signature(*call).unwrap()
} else {
let ty = func.custom.unwrap();
let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else { unreachable!() };
let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else {
codegen_unreachable!(ctx)
};
sign.clone()
};
@ -3251,17 +3236,26 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) };
// Handle Class Method calls
// The attribute will be `DefinitionId` of the method if the call is to one of the parent methods
let func_id = attr.to_string().parse::<usize>();
let id = if let TypeEnum::TObj { obj_id, .. } =
&*ctx.unifier.get_ty(value.custom.unwrap())
{
*obj_id
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let fun_id = {
// Use the `DefinitionID` from attribute if it is available
let fun_id = if let Ok(func_id) = func_id {
DefinitionId(func_id)
} else {
let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read();
let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() };
let TopLevelDef::Class { methods, .. } = &*obj_def else {
codegen_unreachable!(ctx)
};
methods.iter().find(|method| method.0 == *attr).unwrap().2
};
@ -3332,7 +3326,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.unwrap(),
));
}
ValueEnum::Dynamic(_) => unreachable!("option must be static or ptr"),
ValueEnum::Dynamic(_) => {
codegen_unreachable!(ctx, "option must be static or ptr")
}
}
}
@ -3481,7 +3477,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node {
(*v).try_into().unwrap()
} else {
unreachable!("tuple subscript must be const int after type check");
codegen_unreachable!(
ctx,
"tuple subscript must be const int after type check"
);
};
match generator.gen_expr(ctx, value)? {
Some(ValueEnum::Dynamic(v)) => {
@ -3504,7 +3503,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
None => return Ok(None),
}
}
_ => unreachable!("should not be other subscriptable types after type check"),
_ => codegen_unreachable!(
ctx,
"should not be other subscriptable types after type check"
),
}
}
ExprKind::ListComp { .. } => {

View File

@ -3,12 +3,13 @@ use crate::typecheck::typedef::Type;
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics, CodeGenContext, CodeGenerator,
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
use inkwell::{
attributes::{Attribute, AttributeLoc},
context::Context,
@ -55,7 +56,7 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
(64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t",
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
};
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
@ -441,7 +442,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
}
@ -586,7 +587,7 @@ where
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
@ -637,7 +638,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
@ -706,7 +707,7 @@ where
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
@ -774,7 +775,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
@ -894,7 +895,7 @@ pub fn call_ndarray_calc_broadcast_index<
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {

View File

@ -50,6 +50,22 @@ mod test;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
mod macros {
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
/// its first argument to provide Python source information to indicate the codegen location
/// causing the assertion.
macro_rules! codegen_unreachable {
($ctx:expr $(,)?) => {
std::unreachable!("unreachable code while processing {}", &$ctx.current_loc)
};
($ctx:expr, $($arg:tt)*) => {
std::unreachable!("unreachable code while processing {}: {}", &$ctx.current_loc, std::format!("{}", std::format_args!($($arg)+)))
};
}
pub(crate) use codegen_unreachable;
}
#[derive(Default)]
pub struct StaticValueStore {
pub lookup: HashMap<Vec<(usize, u64)>, usize>,

View File

@ -12,6 +12,7 @@ use crate::{
call_ndarray_calc_size,
},
llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable,
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator,
},
@ -259,7 +260,7 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "").into()
} else {
unreachable!()
codegen_unreachable!(ctx)
}
}
@ -287,7 +288,7 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1").into()
} else {
unreachable!()
codegen_unreachable!(ctx)
}
}
@ -355,7 +356,7 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
}
}
@ -626,7 +627,7 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
} else if fill_value.is_int_value() || fill_value.is_float_value() {
fill_value
} else {
unreachable!()
codegen_unreachable!(ctx)
};
Ok(value)
@ -2020,7 +2021,7 @@ pub fn gen_ndarray_fill<'ctx>(
} else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg
} else {
unreachable!()
codegen_unreachable!(ctx)
};
Ok(value)
@ -2129,7 +2130,8 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
Ok(out.as_base_value().into())
} else {
unreachable!(
codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
@ -2371,7 +2373,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
.into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
}
.unwrap();
@ -2415,7 +2417,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
Ok(out.as_base_value().into())
} else {
unreachable!(
codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
@ -2483,7 +2486,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
.build_float_mul(e1, elem2.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
};
let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val {
@ -2497,7 +2500,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
.build_float_add(e1, product.into_float_value(), "")
.unwrap()
.as_basic_value_enum(),
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_store(acc, acc_val).unwrap();
@ -2514,7 +2517,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
}
_ => unreachable!(
_ => codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
),

View File

@ -1,15 +1,13 @@
use super::{
super::symbol_resolver::ValueEnum,
expr::destructure_range,
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::{destructure_range, gen_binop_expr},
gen_in_range_check,
irrt::{handle_slice_indices, list_slice_assignment},
macros::codegen_unreachable,
CodeGenContext, CodeGenerator,
};
use crate::{
codegen::{
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::gen_binop_expr,
gen_in_range_check,
},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::{
magic_methods::Binop,
@ -121,7 +119,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
return Ok(None);
};
let BasicValueEnum::PointerValue(ptr) = val else {
unreachable!();
codegen_unreachable!(ctx);
};
unsafe {
ctx.builder.build_in_bounds_gep(
@ -135,7 +133,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
}
.unwrap()
}
_ => unreachable!(),
_ => codegen_unreachable!(ctx),
}))
}
@ -193,12 +191,12 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
// Deconstruct the tuple `value`
let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)?
else {
unreachable!()
codegen_unreachable!(ctx)
};
// NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer.
let TypeEnum::TTuple { ty: tuple_tys, .. } = &*ctx.unifier.get_ty(value_ty) else {
unreachable!();
codegen_unreachable!(ctx);
};
assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len());
@ -258,7 +256,7 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
// Now assign with that sub-tuple to the starred target.
generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?;
} else {
unreachable!() // The typechecker ensures this
codegen_unreachable!(ctx) // The typechecker ensures this
}
// Handle assignment after the starred target
@ -306,7 +304,9 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
if let ExprKind::Slice { .. } = &key.node {
// Handle assigning to a slice
let ExprKind::Slice { lower, upper, step } = &key.node else { unreachable!() };
let ExprKind::Slice { lower, upper, step } = &key.node else {
codegen_unreachable!(ctx)
};
let Some((start, end, step)) = handle_slice_indices(
lower,
upper,
@ -416,7 +416,9 @@ pub fn gen_for<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { unreachable!() };
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else {
codegen_unreachable!(ctx)
};
// var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch
@ -458,7 +460,7 @@ pub fn gen_for<G: CodeGenerator>(
let Some(target_i) =
generator.gen_store_target(ctx, target, Some("for.target.addr"))?
else {
unreachable!()
codegen_unreachable!(ctx)
};
let (start, stop, step) = destructure_range(ctx, iter_val);
@ -901,7 +903,7 @@ pub fn gen_while<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
let StmtKind::While { test, body, orelse, .. } = &stmt.node else { unreachable!() };
let StmtKind::While { test, body, orelse, .. } = &stmt.node else { codegen_unreachable!(ctx) };
// var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch
@ -931,7 +933,7 @@ pub fn gen_while<G: CodeGenerator>(
return Ok(());
};
let BasicValueEnum::IntValue(test) = test else { unreachable!() };
let BasicValueEnum::IntValue(test) = test else { codegen_unreachable!(ctx) };
ctx.builder
.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb)
@ -1079,7 +1081,7 @@ pub fn gen_if<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
let StmtKind::If { test, body, orelse, .. } = &stmt.node else { unreachable!() };
let StmtKind::If { test, body, orelse, .. } = &stmt.node else { codegen_unreachable!(ctx) };
// var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch
@ -1202,11 +1204,11 @@ pub fn exn_constructor<'ctx>(
let zelf_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) {
obj_id.0
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let defs = ctx.top_level.definitions.read();
let def = defs[zelf_id].read();
let TopLevelDef::Class { name: zelf_name, .. } = &*def else { unreachable!() };
let TopLevelDef::Class { name: zelf_name, .. } = &*def else { codegen_unreachable!(ctx) };
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name);
unsafe {
let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap();
@ -1314,7 +1316,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
target: &Stmt<Option<Type>>,
) -> Result<(), String> {
let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node else {
unreachable!()
codegen_unreachable!(ctx)
};
// if we need to generate anything related to exception, we must have personality defined
@ -1391,7 +1393,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) {
*obj_id
} else {
unreachable!()
codegen_unreachable!(ctx)
};
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name);
let exn_id = ctx.resolver.get_string_id(&exception_name);
@ -1760,7 +1762,30 @@ pub fn gen_stmt<G: CodeGenerator>(
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
StmtKind::Raise { exc, .. } => {
if let Some(exc) = exc {
let exc = if let Some(v) = generator.gen_expr(ctx, exc)? {
let exn = if let ExprKind::Name { id, .. } = &exc.node {
// Handle "raise Exception" short form
let def_id = ctx.resolver.get_identifier_def(*id).map_err(|e| {
format!("{} (at {})", e.iter().next().unwrap(), exc.location)
})?;
let def = ctx.top_level.definitions.read();
let TopLevelDef::Class { constructor, .. } = *def[def_id.0].read() else {
return Err(format!("Failed to resolve symbol {id} (at {})", exc.location));
};
let TypeEnum::TFunc(signature) =
ctx.unifier.get_ty(constructor.unwrap()).as_ref().clone()
else {
return Err(format!("Failed to resolve symbol {id} (at {})", exc.location));
};
generator
.gen_call(ctx, None, (&signature, def_id), Vec::default())?
.map(Into::into)
} else {
generator.gen_expr(ctx, exc)?
};
let exc = if let Some(v) = exn {
v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())?
} else {
return Ok(());

View File

@ -23,7 +23,7 @@ impl Default for ComposerConfig {
}
}
type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
pub type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
pub struct TopLevelComposer {
// list of top level definitions, same as top level context
pub definition_ast_list: Vec<DefAst>,
@ -1822,7 +1822,12 @@ impl TopLevelComposer {
if *name != init_str_id {
unreachable!("must be init function here")
}
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
let all_inited = Self::get_all_assigned_field(
object_id.0,
definition_ast_list,
body.as_slice(),
)?;
for (f, _, _) in fields {
if !all_inited.contains(f) {
return Err(HashSet::from([

View File

@ -3,6 +3,7 @@ use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
use ast::ExprKind;
use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
@ -733,7 +734,16 @@ impl TopLevelComposer {
)
}
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> {
/// This function returns the fields that have been initialized in the `__init__` function of a class
/// The function takes as input:
/// * `class_id`: The `object_id` of the class whose function is being evaluated (check `TopLevelDef::Class`)
/// * `definition_ast_list`: A list of ast definitions and statements defined in `TopLevelComposer`
/// * `stmts`: The body of function being parsed. Each statment is analyzed to check varaible initialization statements
pub fn get_all_assigned_field(
class_id: usize,
definition_ast_list: &Vec<DefAst>,
stmts: &[Stmt<()>],
) -> Result<HashSet<StrRef>, HashSet<String>> {
let mut result = HashSet::new();
for s in stmts {
match &s.node {
@ -769,30 +779,138 @@ impl TopLevelComposer {
// TODO: do not check for For and While?
ast::StmtKind::For { body, orelse, .. }
| ast::StmtKind::While { body, orelse, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
body.as_slice(),
)?);
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?);
}
ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.copied()
.collect::<HashSet<_>>();
let inited_for_sure = Self::get_all_assigned_field(
class_id,
definition_ast_list,
body.as_slice(),
)?
.intersection(&Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure);
}
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.copied()
.collect::<HashSet<_>>();
let inited_for_sure = Self::get_all_assigned_field(
class_id,
definition_ast_list,
body.as_slice(),
)?
.intersection(&Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
finalbody.as_slice(),
)?);
}
ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?);
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
body.as_slice(),
)?);
}
// Variables Initialized in function calls
ast::StmtKind::Expr { value, .. } => {
let ExprKind::Call { func, .. } = &value.node else {
continue;
};
let ExprKind::Attribute { value, attr, .. } = &func.node else {
continue;
};
let ExprKind::Name { id, .. } = &value.node else {
continue;
};
// Need to consider the two cases:
// Case 1) Call to class function i.e. id = `self`
// Case 2) Call to class ancestor function i.e. id = ancestor_name
// We leave checking whether function in case 2 belonged to class ancestor or not to type checker
//
// According to current handling of `self`, function definition are fixed and do not change regardless
// of which object is passed as `self` i.e. virtual polymorphism is not supported
// Therefore, we change class id for case 2 to reflect behavior of our compiler
let class_name = if *id == "self".into() {
let ast::StmtKind::ClassDef { name, .. } =
&definition_ast_list[class_id].1.as_ref().unwrap().node
else {
unreachable!()
};
name
} else {
id
};
let parent_method = definition_ast_list.iter().find_map(|def| {
let (
class_def,
Some(ast::Located {
node: ast::StmtKind::ClassDef { name, body, .. },
..
}),
) = &def
else {
return None;
};
let TopLevelDef::Class { object_id: class_id, .. } = &*class_def.read()
else {
unreachable!()
};
if name == class_name {
body.iter().find_map(|m| {
let ast::StmtKind::FunctionDef { name, body, .. } = &m.node else {
return None;
};
if *name == *attr {
return Some((body.clone(), class_id.0));
}
None
})
} else {
None
}
});
// If method body is none then method does not exist
if let Some((method_body, class_id)) = parent_method {
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
method_body.as_slice(),
)?);
} else {
return Err(HashSet::from([format!(
"{}.{} not found in class {class_name} at {}",
*id, *attr, value.location
)]));
}
}
ast::StmtKind::Pass { .. }
| ast::StmtKind::Assert { .. }
| ast::StmtKind::Expr { .. } => {}
| ast::StmtKind::AnnAssign { .. } => {}
_ => {
unimplemented!()

View File

@ -12,6 +12,7 @@ use super::{
RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap,
},
};
use crate::toplevel::type_annotation::TypeAnnotation;
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
@ -102,6 +103,7 @@ pub struct Inferencer<'a> {
}
type InferenceError = HashSet<String>;
type OverrideResult = Result<Option<ast::Expr<Option<Type>>>, InferenceError>;
struct NaiveFolder();
impl Fold<()> for NaiveFolder {
@ -1672,6 +1674,86 @@ impl<'a> Inferencer<'a> {
Ok(None)
}
/// Checks whether a class method is calling parent function
/// Returns [`None`] if its not a call to parent method, otherwise
/// returns a new `func` with class name replaced by `self` and method resolved to its `DefinitionID`
///
/// e.g. A.f1(self, ...) returns Some(self.{DefintionID(f1)})
fn check_overriding(&mut self, func: &ast::Expr<()>, args: &[ast::Expr<()>]) -> OverrideResult {
// `self` must be first argument for call to parent method
if let Some(Located { node: ExprKind::Name { id, .. }, .. }) = &args.first() {
if *id != "self".into() {
return Ok(None);
}
} else {
return Ok(None);
}
let Located {
node: ExprKind::Attribute { value, attr: method_name, ctx }, location, ..
} = func
else {
return Ok(None);
};
let ExprKind::Name { id: class_name, ctx: class_ctx } = &value.node else {
return Ok(None);
};
let zelf = &self.fold_expr(args[0].clone())?;
// Check whether the method belongs to class ancestors
let def_id = self.unifier.get_ty(zelf.custom.unwrap());
let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() };
let defs = self.top_level.definitions.read();
let res = {
if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() {
let res = ancestors.iter().find_map(|f| {
let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() };
let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() else {
unreachable!()
};
// Class names are stored as `__module__.class`
let name = name.to_string();
let (_, name) = name.rsplit_once('.').unwrap();
if name == class_name.to_string() {
return methods.iter().find_map(|f| {
if f.0 == *method_name {
return Some(*f);
}
None
});
}
None
});
res
} else {
None
}
};
match res {
Some(r) => {
let mut new_func = func.clone();
let mut new_value = value.clone();
new_value.node = ExprKind::Name { id: "self".into(), ctx: *class_ctx };
new_func.node =
ExprKind::Attribute { value: new_value.clone(), attr: *method_name, ctx: *ctx };
let mut new_func = self.fold_expr(new_func)?;
let ExprKind::Attribute { value, .. } = new_func.node else { unreachable!() };
new_func.node =
ExprKind::Attribute { value, attr: r.2 .0.to_string().into(), ctx: *ctx };
new_func.custom = Some(r.1);
Ok(Some(new_func))
}
None => report_error(
format!("Ancestor method [{class_name}.{method_name}] should be defined with same decorator as its overridden version").as_str(),
*location,
),
}
}
fn fold_call(
&mut self,
location: Location,
@ -1685,8 +1767,20 @@ impl<'a> Inferencer<'a> {
return Ok(spec_call_func);
}
let func = Box::new(self.fold_expr(func)?);
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
// Check for call to parent method
let override_res = self.check_overriding(&func, &args)?;
let is_override = override_res.is_some();
let func = if is_override { override_res.unwrap() } else { self.fold_expr(func)? };
let func = Box::new(func);
let mut args =
args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
// TODO: Handle passing of self to functions to allow runtime lookup of functions to be called
// Currently removing `self` and using compile time function definitions
if is_override {
args.remove(0);
}
let keywords = keywords
.into_iter()
.map(|v| fold::fold_keyword(self, v))

View File

@ -0,0 +1,26 @@
1
2
3
4
5
1
2
3
4
5
1
2
3
4
5
1
0
2
1
2
False
4
5
6
7
8

View File

@ -10,23 +10,58 @@ class A:
def __init__(self, a: int32):
self.a = a
def f1(self):
self.f2()
def f2(self):
def output_all_fields(self):
output_int32(self.a)
def set_a(self, a: int32):
self.a = a
class B(A):
b: int32
def __init__(self, b: int32):
self.a = b + 1
A.__init__(self, b + 1)
self.set_b(b)
def output_parent_fields(self):
A.output_all_fields(self)
def output_all_fields(self):
A.output_all_fields(self)
output_int32(self.b)
def set_b(self, b: int32):
self.b = b
class C(B):
c: int32
def __init__(self, c: int32):
B.__init__(self, c + 1)
self.c = c
def output_parent_fields(self):
B.output_all_fields(self)
def output_all_fields(self):
B.output_all_fields(self)
output_int32(self.c)
def set_c(self, c: int32):
self.c = c
def run() -> int32:
aaa = A(5)
bbb = B(2)
aaa.f1()
bbb.f1()
ccc = C(10)
ccc.output_all_fields()
ccc.set_a(1)
ccc.set_b(2)
ccc.set_c(3)
ccc.output_all_fields()
bbb = B(10)
bbb.set_a(9)
bbb.set_b(8)
bbb.output_all_fields()
ccc.output_all_fields()
return 0