core: Extract special method handling in type inferencer

To prepare for more special handling with methods.
This commit is contained in:
David Mak 2023-12-15 16:57:23 +08:00
parent e435b25756
commit 03870f222d

View File

@ -10,7 +10,13 @@ use itertools::izip;
use nac3parser::ast::{ use nac3parser::ast::{
self, self,
fold::{self, Fold}, fold::{self, Fold},
Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef, Arguments,
Comprehension,
ExprContext,
ExprKind,
Located,
Location,
StrRef
}; };
#[cfg(test)] #[cfg(test)]
@ -773,24 +779,26 @@ impl<'a> Inferencer<'a> {
}) })
} }
fn fold_call( /// Tries to fold a special call. Returns [`Some`] if the call expression `func` is a special call, otherwise
/// returns [`None`].
fn try_fold_special_call(
&mut self, &mut self,
location: Location, location: Location,
func: ast::Expr<()>, func: &ast::Expr<()>,
mut args: Vec<ast::Expr<()>>, args: &mut Vec<ast::Expr<()>>,
keywords: Vec<Located<ast::KeywordData>>, keywords: &Vec<Located<ast::KeywordData>>,
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> { ) -> Result<Option<ast::Expr<Option<Type>>>, HashSet<String>> {
let func = let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } = return Ok(None)
func };
{
// handle special functions that cannot be typed in the usual way... // handle special functions that cannot be typed in the usual way...
if id == "virtual".into() { if id == &"virtual".into() {
if args.is_empty() || args.len() > 2 || !keywords.is_empty() { if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
return report_error( return report_error(
"`virtual` can only accept 1/2 positional arguments", "`virtual` can only accept 1/2 positional arguments",
func_location, *func_location,
); )
} }
let arg0 = self.fold_expr(args.remove(0))?; let arg0 = self.fold_expr(args.remove(0))?;
let ty = if let Some(arg) = args.pop() { let ty = if let Some(arg) = args.pop() {
@ -804,84 +812,96 @@ impl<'a> Inferencer<'a> {
} else { } else {
self.unifier.get_dummy_var().0 self.unifier.get_dummy_var().0
}; };
self.virtual_checks.push((arg0.custom.unwrap(), ty, func_location)); self.virtual_checks.push((arg0.custom.unwrap(), ty, *func_location));
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
return Ok(Located { return Ok(Some(Located {
location, location,
custom, custom,
node: ExprKind::Call { node: ExprKind::Call {
func: Box::new(Located { func: Box::new(Located {
custom: None, custom: None,
location: func.location, location: func.location,
node: ExprKind::Name { id, ctx }, node: ExprKind::Name { id: *id, ctx: ctx.clone() },
}), }),
args: vec![arg0], args: vec![arg0],
keywords: vec![], keywords: vec![],
}, },
}); }))
} }
// int64 is special because its argument can be a constant larger than int32 // int64 is special because its argument can be a constant larger than int32
if id == "int64".into() && args.len() == 1 { if id == &"int64".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node &args[0].node
{ {
let custom = Some(self.primitives.int64); let custom = Some(self.primitives.int64);
let v: Result<i64, _> = (*val).try_into(); let v: Result<i64, _> = (*val).try_into();
return if v.is_ok() { return if v.is_ok() {
Ok(Located { Ok(Some(Located {
location: args[0].location, location: args[0].location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant {
value: ast::Constant::Int(*val), value: ast::Constant::Int(*val),
kind: kind.clone(), kind: kind.clone(),
}, },
}) }))
} else { } else {
report_error("Integer out of bound", args[0].location) report_error("Integer out of bound", args[0].location)
} }
} }
} }
if id == "uint32".into() && args.len() == 1 { if id == &"uint32".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node &args[0].node
{ {
let custom = Some(self.primitives.uint32); let custom = Some(self.primitives.uint32);
let v: Result<u32, _> = (*val).try_into(); let v: Result<u32, _> = (*val).try_into();
return if v.is_ok() { return if v.is_ok() {
Ok(Located { Ok(Some(Located {
location: args[0].location, location: args[0].location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant {
value: ast::Constant::Int(*val), value: ast::Constant::Int(*val),
kind: kind.clone(), kind: kind.clone(),
}, },
}) }))
} else { } else {
report_error("Integer out of bound", args[0].location) report_error("Integer out of bound", args[0].location)
} }
} }
} }
if id == "uint64".into() && args.len() == 1 { if id == &"uint64".into() && args.len() == 1 {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
&args[0].node &args[0].node
{ {
let custom = Some(self.primitives.uint64); let custom = Some(self.primitives.uint64);
let v: Result<u64, _> = (*val).try_into(); let v: Result<u64, _> = (*val).try_into();
return if v.is_ok() { return if v.is_ok() {
Ok(Located { Ok(Some(Located {
location: args[0].location, location: args[0].location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant {
value: ast::Constant::Int(*val), value: ast::Constant::Int(*val),
kind: kind.clone(), kind: kind.clone(),
}, },
}) }))
} else { } else {
report_error("Integer out of bound", args[0].location) report_error("Integer out of bound", args[0].location)
} }
} }
} }
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
Ok(None)
}
fn fold_call(
&mut self,
location: Location,
func: ast::Expr<()>,
mut args: Vec<ast::Expr<()>>,
keywords: Vec<Located<ast::KeywordData>>,
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
let func = if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? {
return Ok(spec_call_func)
} else { } else {
func func
}; };