core: Extract special method handling in type inferencer

To prepare for more special handling with methods.
pull/371/head
David Mak 2023-12-15 16:57:23 +08:00
parent e435b25756
commit 03870f222d
1 changed files with 126 additions and 106 deletions

View File

@ -10,7 +10,13 @@ use itertools::izip;
use nac3parser::ast::{
self,
fold::{self, Fold},
Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef,
Arguments,
Comprehension,
ExprContext,
ExprKind,
Located,
Location,
StrRef
};
#[cfg(test)]
@ -773,6 +779,120 @@ impl<'a> Inferencer<'a> {
})
}
/// Tries to fold a special call. Returns [`Some`] if the call expression `func` is a special call, otherwise
/// returns [`None`].
fn try_fold_special_call(
&mut self,
location: Location,
func: &ast::Expr<()>,
args: &mut Vec<ast::Expr<()>>,
keywords: &Vec<Located<ast::KeywordData>>,
) -> Result<Option<ast::Expr<Option<Type>>>, HashSet<String>> {
let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
return Ok(None)
};
// handle special functions that cannot be typed in the usual way...
if id == &"virtual".into() {
if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
return report_error(
"`virtual` can only accept 1/2 positional arguments",
*func_location,
)
}
let arg0 = self.fold_expr(args.remove(0))?;
let ty = if let Some(arg) = args.pop() {
let top_level_defs = self.top_level.definitions.read();
self.function_data.resolver.parse_type_annotation(
top_level_defs.as_slice(),
self.unifier,
self.primitives,
&arg,
)?
} else {
self.unifier.get_dummy_var().0
};
self.virtual_checks.push((arg0.custom.unwrap(), ty, *func_location));
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
return Ok(Some(Located {
location,
custom,
node: ExprKind::Call {
func: Box::new(Located {
custom: None,
location: func.location,
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}),
args: vec![arg0],
keywords: vec![],
},
}))
}
// int64 is special because its argument can be a constant larger than int32
if id == &"int64".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node
{
let custom = Some(self.primitives.int64);
let v: Result<i64, _> = (*val).try_into();
return if v.is_ok() {
Ok(Some(Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(*val),
kind: kind.clone(),
},
}))
} else {
report_error("Integer out of bound", args[0].location)
}
}
}
if id == &"uint32".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node
{
let custom = Some(self.primitives.uint32);
let v: Result<u32, _> = (*val).try_into();
return if v.is_ok() {
Ok(Some(Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(*val),
kind: kind.clone(),
},
}))
} else {
report_error("Integer out of bound", args[0].location)
}
}
}
if id == &"uint64".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node
{
let custom = Some(self.primitives.uint64);
let v: Result<u64, _> = (*val).try_into();
return if v.is_ok() {
Ok(Some(Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(*val),
kind: kind.clone(),
},
}))
} else {
report_error("Integer out of bound", args[0].location)
}
}
}
Ok(None)
}
fn fold_call(
&mut self,
location: Location,
@ -780,111 +900,11 @@ impl<'a> Inferencer<'a> {
mut args: Vec<ast::Expr<()>>,
keywords: Vec<Located<ast::KeywordData>>,
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
let func =
if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } =
func
{
// handle special functions that cannot be typed in the usual way...
if id == "virtual".into() {
if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
return report_error(
"`virtual` can only accept 1/2 positional arguments",
func_location,
);
}
let arg0 = self.fold_expr(args.remove(0))?;
let ty = if let Some(arg) = args.pop() {
let top_level_defs = self.top_level.definitions.read();
self.function_data.resolver.parse_type_annotation(
top_level_defs.as_slice(),
self.unifier,
self.primitives,
&arg,
)?
} else {
self.unifier.get_dummy_var().0
};
self.virtual_checks.push((arg0.custom.unwrap(), ty, func_location));
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
return Ok(Located {
location,
custom,
node: ExprKind::Call {
func: Box::new(Located {
custom: None,
location: func.location,
node: ExprKind::Name { id, ctx },
}),
args: vec![arg0],
keywords: vec![],
},
});
}
// int64 is special because its argument can be a constant larger than int32
if id == "int64".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node
{
let custom = Some(self.primitives.int64);
let v: Result<i64, _> = (*val).try_into();
return if v.is_ok() {
Ok(Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(*val),
kind: kind.clone(),
},
})
} else {
report_error("Integer out of bound", args[0].location)
}
}
}
if id == "uint32".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node
{
let custom = Some(self.primitives.uint32);
let v: Result<u32, _> = (*val).try_into();
return if v.is_ok() {
Ok(Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(*val),
kind: kind.clone(),
},
})
} else {
report_error("Integer out of bound", args[0].location)
}
}
}
if id == "uint64".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node
{
let custom = Some(self.primitives.uint64);
let v: Result<u64, _> = (*val).try_into();
return if v.is_ok() {
Ok(Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(*val),
kind: kind.clone(),
},
})
} else {
report_error("Integer out of bound", args[0].location)
}
}
}
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
} else {
func
};
let func = if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? {
return Ok(spec_call_func)
} else {
func
};
let func = Box::new(self.fold_expr(func)?);
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
let keywords = keywords