forked from M-Labs/nac3
1
0
Fork 0

type_inferencer: check defined identifiers during inference

This commit is contained in:
pca006132 2021-08-27 11:13:43 +08:00
parent 35ef0386db
commit a24e204824
5 changed files with 77 additions and 12 deletions

View File

@ -77,6 +77,7 @@ fn test_primitives() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers = vec!["a".to_string(), "b".to_string()];
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
@ -85,6 +86,7 @@ fn test_primitives() {
primitives: &primitives, primitives: &primitives,
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
defined_identifiers: identifiers.clone()
}; };
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
@ -95,7 +97,6 @@ fn test_primitives() {
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.unwrap(); .unwrap();
let mut identifiers = vec!["a".to_string(), "b".to_string()];
inferencer.check_block(&statements, &mut identifiers).unwrap(); inferencer.check_block(&statements, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext { let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))), definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
@ -235,6 +236,7 @@ fn test_simple_call() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers = vec!["a".to_string(), "foo".into()];
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
@ -243,6 +245,7 @@ fn test_simple_call() {
primitives: &primitives, primitives: &primitives,
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
defined_identifiers: identifiers.clone()
}; };
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("foo".into(), fun_ty); inferencer.variable_mapping.insert("foo".into(), fun_ty);
@ -273,7 +276,6 @@ fn test_simple_call() {
unreachable!() unreachable!()
} }
let mut identifiers = vec!["a".to_string(), "foo".into()];
inferencer.check_block(&statements_1, &mut identifiers).unwrap(); inferencer.check_block(&statements_1, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext { let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))), definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),

View File

@ -42,7 +42,7 @@ impl<'a> Inferencer<'a> {
fn check_expr( fn check_expr(
&mut self, &mut self,
expr: &Expr<Option<Type>>, expr: &Expr<Option<Type>>,
defined_identifiers: &[String], defined_identifiers: &mut Vec<String>,
) -> Result<(), String> { ) -> Result<(), String> {
// there are some cases where the custom field is None // there are some cases where the custom field is None
if let Some(ty) = &expr.custom { if let Some(ty) = &expr.custom {
@ -57,12 +57,16 @@ impl<'a> Inferencer<'a> {
match &expr.node { match &expr.node {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) { if !defined_identifiers.contains(id) {
if self.function_data.resolver.get_identifier_def(id).is_some() {
defined_identifiers.push(id.clone());
} else {
return Err(format!( return Err(format!(
"unknown identifier {} (use before def?) at {}", "unknown identifier {} (use before def?) at {}",
id, expr.location id, expr.location
)); ));
} }
} }
}
ExprKind::List { elts, .. } ExprKind::List { elts, .. }
| ExprKind::Tuple { elts, .. } | ExprKind::Tuple { elts, .. }
| ExprKind::BoolOp { values: elts, .. } => { | ExprKind::BoolOp { values: elts, .. } => {
@ -106,7 +110,7 @@ impl<'a> Inferencer<'a> {
defined_identifiers.push(arg.node.arg.clone()); defined_identifiers.push(arg.node.arg.clone());
} }
} }
self.check_expr(body, &defined_identifiers)?; self.check_expr(body, &mut defined_identifiers)?;
} }
ExprKind::ListComp { elt, generators, .. } => { ExprKind::ListComp { elt, generators, .. } => {
// in our type inference stage, we already make sure that there is only 1 generator // in our type inference stage, we already make sure that there is only 1 generator
@ -115,7 +119,7 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.to_vec(); let mut defined_identifiers = defined_identifiers.to_vec();
self.check_pattern(target, &mut defined_identifiers)?; self.check_pattern(target, &mut defined_identifiers)?;
for term in once(elt.as_ref()).chain(ifs.iter()) { for term in once(elt.as_ref()).chain(ifs.iter()) {
self.check_expr(term, &defined_identifiers)?; self.check_expr(term, &mut defined_identifiers)?;
} }
} }
ExprKind::Call { func, args, keywords } => { ExprKind::Call { func, args, keywords } => {

View File

@ -45,6 +45,7 @@ pub struct FunctionData {
pub struct Inferencer<'a> { pub struct Inferencer<'a> {
pub top_level: &'a TopLevelContext, pub top_level: &'a TopLevelContext,
pub defined_identifiers: Vec<String>,
pub function_data: &'a mut FunctionData, pub function_data: &'a mut FunctionData,
pub unifier: &'a mut Unifier, pub unifier: &'a mut Unifier,
pub primitives: &'a PrimitiveStore, pub primitives: &'a PrimitiveStore,
@ -74,13 +75,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
let stmt = match node.node { let stmt = match node.node {
// we don't want fold over type annotation // we don't want fold over type annotation
ast::StmtKind::AnnAssign { target, annotation, value, simple } => { ast::StmtKind::AnnAssign { target, annotation, value, simple } => {
self.infer_pattern(&target)?;
let target = Box::new(self.fold_expr(*target)?); let target = Box::new(self.fold_expr(*target)?);
let value = if let Some(v) = value { let value = if let Some(v) = value {
let ty = Box::new(self.fold_expr(*v)?); let ty = Box::new(self.fold_expr(*v)?);
self.unify(target.custom.unwrap(), ty.custom.unwrap(), &node.location)?; self.unify(target.custom.unwrap(), ty.custom.unwrap(), &node.location)?;
Some(ty) Some(ty)
} else { } else {
None return Err(format!("declaration without definition is not yet supported, at {}", node.location))
}; };
let top_level_defs = self.top_level.definitions.read(); let top_level_defs = self.top_level.definitions.read();
let annotation_type = self.function_data.resolver.parse_type_annotation( let annotation_type = self.function_data.resolver.parse_type_annotation(
@ -97,6 +99,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
node: ast::StmtKind::AnnAssign { target, annotation, value, simple }, node: ast::StmtKind::AnnAssign { target, annotation, value, simple },
} }
} }
ast::StmtKind::For { ref target, .. } => {
self.infer_pattern(target)?;
fold::fold_stmt(self, node)?
}
ast::StmtKind::Assign { ref targets, .. } => {
for target in targets {
self.infer_pattern(target)?;
}
fold::fold_stmt(self, node)?
}
_ => fold::fold_stmt(self, node)?, _ => fold::fold_stmt(self, node)?,
}; };
match &stmt.node { match &stmt.node {
@ -146,7 +158,19 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
}; };
let custom = match &expr.node { let custom = match &expr.node {
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?), ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?), ast::ExprKind::Name { id, .. } => {
if !self.defined_identifiers.contains(id) {
if self.function_data.resolver.get_identifier_def(id.as_str()).is_some() {
self.defined_identifiers.push(id.clone());
} else {
return Err(format!(
"unknown identifier {} (use before def?) at {}",
id, expr.location
));
}
}
Some(self.infer_identifier(id)?)
}
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
ast::ExprKind::Attribute { value, attr, ctx: _ } => { ast::ExprKind::Attribute { value, attr, ctx: _ } => {
@ -187,6 +211,24 @@ impl<'a> Inferencer<'a> {
self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location)) self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location))
} }
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> {
match &pattern.node {
ExprKind::Name { id, .. } => {
if !self.defined_identifiers.contains(id) {
self.defined_identifiers.push(id.clone());
}
Ok(())
}
ExprKind::Tuple { elts, .. } => {
for elt in elts.iter() {
self.infer_pattern(elt)?;
}
Ok(())
}
_ => Ok(()),
}
}
fn build_method_call( fn build_method_call(
&mut self, &mut self,
location: Location, location: Location,
@ -228,6 +270,13 @@ impl<'a> Inferencer<'a> {
); );
} }
let mut defined_identifiers = self.defined_identifiers.clone();
for arg in args.args.iter() {
let name = &arg.node.arg;
if !defined_identifiers.contains(name) {
defined_identifiers.push(name.clone());
}
}
let fn_args: Vec<_> = args let fn_args: Vec<_> = args
.args .args
.iter() .iter()
@ -236,6 +285,7 @@ impl<'a> Inferencer<'a> {
let mut variable_mapping = self.variable_mapping.clone(); let mut variable_mapping = self.variable_mapping.clone();
variable_mapping.extend(fn_args.iter().cloned()); variable_mapping.extend(fn_args.iter().cloned());
let ret = self.unifier.get_fresh_var().0; let ret = self.unifier.get_fresh_var().0;
let mut new_context = Inferencer { let mut new_context = Inferencer {
function_data: self.function_data, function_data: self.function_data,
unifier: self.unifier, unifier: self.unifier,
@ -243,6 +293,7 @@ impl<'a> Inferencer<'a> {
virtual_checks: self.virtual_checks, virtual_checks: self.virtual_checks,
calls: self.calls, calls: self.calls,
top_level: self.top_level, top_level: self.top_level,
defined_identifiers,
variable_mapping, variable_mapping,
}; };
let fun = FunSignature { let fun = FunSignature {
@ -279,6 +330,7 @@ impl<'a> Inferencer<'a> {
); );
} }
let variable_mapping = self.variable_mapping.clone(); let variable_mapping = self.variable_mapping.clone();
let defined_identifiers = self.defined_identifiers.clone();
let mut new_context = Inferencer { let mut new_context = Inferencer {
function_data: self.function_data, function_data: self.function_data,
unifier: self.unifier, unifier: self.unifier,
@ -287,12 +339,14 @@ impl<'a> Inferencer<'a> {
variable_mapping, variable_mapping,
primitives: self.primitives, primitives: self.primitives,
calls: self.calls, calls: self.calls,
defined_identifiers,
}; };
let elt = new_context.fold_expr(elt)?;
let generator = generators.pop().unwrap(); let generator = generators.pop().unwrap();
if generator.is_async { if generator.is_async {
return Err("Async iterator not supported.".to_string()); return Err("Async iterator not supported.".to_string());
} }
new_context.infer_pattern(&generator.target)?;
let elt = new_context.fold_expr(elt)?;
let target = new_context.fold_expr(*generator.target)?; let target = new_context.fold_expr(*generator.target)?;
let iter = new_context.fold_expr(*generator.iter)?; let iter = new_context.fold_expr(*generator.iter)?;
let ifs: Vec<_> = generator let ifs: Vec<_> = generator

View File

@ -313,6 +313,7 @@ impl TestEnvironment {
primitives: &mut self.primitives, primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks, virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls, calls: &mut self.calls,
defined_identifiers: vec![]
} }
} }
} }
@ -382,6 +383,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string()); defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source).unwrap(); let statements = parse_program(source).unwrap();
let statements = statements let statements = statements
.into_iter() .into_iter()
@ -523,6 +525,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string()); defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source).unwrap(); let statements = parse_program(source).unwrap();
let statements = statements let statements = statements
.into_iter() .into_iter()

View File

@ -112,6 +112,7 @@ fn main() {
primitives: &primitives, primitives: &primitives,
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
defined_identifiers: vec![]
}; };
let statements = statements let statements = statements
@ -124,6 +125,7 @@ fn main() {
inferencer inferencer
.check_block(&statements, &mut identifiers) .check_block(&statements, &mut identifiers)
.unwrap(); .unwrap();
let top_level = Arc::new(TopLevelContext { let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take( definitions: Arc::new(RwLock::new(std::mem::take(
&mut *top_level.definitions.write(), &mut *top_level.definitions.write(),