forked from M-Labs/nac3
type_inferencer: check defined identifiers during inference
This commit is contained in:
parent
35ef0386db
commit
a24e204824
@ -77,6 +77,7 @@ fn test_primitives() {
|
||||
};
|
||||
let mut virtual_checks = Vec::new();
|
||||
let mut calls = HashMap::new();
|
||||
let mut identifiers = vec!["a".to_string(), "b".to_string()];
|
||||
let mut inferencer = Inferencer {
|
||||
top_level: &top_level,
|
||||
function_data: &mut function_data,
|
||||
@ -85,6 +86,7 @@ fn test_primitives() {
|
||||
primitives: &primitives,
|
||||
virtual_checks: &mut virtual_checks,
|
||||
calls: &mut calls,
|
||||
defined_identifiers: identifiers.clone()
|
||||
};
|
||||
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
|
||||
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
|
||||
@ -95,7 +97,6 @@ fn test_primitives() {
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap();
|
||||
|
||||
let mut identifiers = vec!["a".to_string(), "b".to_string()];
|
||||
inferencer.check_block(&statements, &mut identifiers).unwrap();
|
||||
let top_level = Arc::new(TopLevelContext {
|
||||
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
|
||||
@ -235,6 +236,7 @@ fn test_simple_call() {
|
||||
};
|
||||
let mut virtual_checks = Vec::new();
|
||||
let mut calls = HashMap::new();
|
||||
let mut identifiers = vec!["a".to_string(), "foo".into()];
|
||||
let mut inferencer = Inferencer {
|
||||
top_level: &top_level,
|
||||
function_data: &mut function_data,
|
||||
@ -243,6 +245,7 @@ fn test_simple_call() {
|
||||
primitives: &primitives,
|
||||
virtual_checks: &mut virtual_checks,
|
||||
calls: &mut calls,
|
||||
defined_identifiers: identifiers.clone()
|
||||
};
|
||||
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
|
||||
inferencer.variable_mapping.insert("foo".into(), fun_ty);
|
||||
@ -273,7 +276,6 @@ fn test_simple_call() {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
let mut identifiers = vec!["a".to_string(), "foo".into()];
|
||||
inferencer.check_block(&statements_1, &mut identifiers).unwrap();
|
||||
let top_level = Arc::new(TopLevelContext {
|
||||
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
|
||||
|
@ -42,7 +42,7 @@ impl<'a> Inferencer<'a> {
|
||||
fn check_expr(
|
||||
&mut self,
|
||||
expr: &Expr<Option<Type>>,
|
||||
defined_identifiers: &[String],
|
||||
defined_identifiers: &mut Vec<String>,
|
||||
) -> Result<(), String> {
|
||||
// there are some cases where the custom field is None
|
||||
if let Some(ty) = &expr.custom {
|
||||
@ -57,12 +57,16 @@ impl<'a> Inferencer<'a> {
|
||||
match &expr.node {
|
||||
ExprKind::Name { id, .. } => {
|
||||
if !defined_identifiers.contains(id) {
|
||||
if self.function_data.resolver.get_identifier_def(id).is_some() {
|
||||
defined_identifiers.push(id.clone());
|
||||
} else {
|
||||
return Err(format!(
|
||||
"unknown identifier {} (use before def?) at {}",
|
||||
id, expr.location
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
ExprKind::List { elts, .. }
|
||||
| ExprKind::Tuple { elts, .. }
|
||||
| ExprKind::BoolOp { values: elts, .. } => {
|
||||
@ -106,7 +110,7 @@ impl<'a> Inferencer<'a> {
|
||||
defined_identifiers.push(arg.node.arg.clone());
|
||||
}
|
||||
}
|
||||
self.check_expr(body, &defined_identifiers)?;
|
||||
self.check_expr(body, &mut defined_identifiers)?;
|
||||
}
|
||||
ExprKind::ListComp { elt, generators, .. } => {
|
||||
// in our type inference stage, we already make sure that there is only 1 generator
|
||||
@ -115,7 +119,7 @@ impl<'a> Inferencer<'a> {
|
||||
let mut defined_identifiers = defined_identifiers.to_vec();
|
||||
self.check_pattern(target, &mut defined_identifiers)?;
|
||||
for term in once(elt.as_ref()).chain(ifs.iter()) {
|
||||
self.check_expr(term, &defined_identifiers)?;
|
||||
self.check_expr(term, &mut defined_identifiers)?;
|
||||
}
|
||||
}
|
||||
ExprKind::Call { func, args, keywords } => {
|
||||
|
@ -45,6 +45,7 @@ pub struct FunctionData {
|
||||
|
||||
pub struct Inferencer<'a> {
|
||||
pub top_level: &'a TopLevelContext,
|
||||
pub defined_identifiers: Vec<String>,
|
||||
pub function_data: &'a mut FunctionData,
|
||||
pub unifier: &'a mut Unifier,
|
||||
pub primitives: &'a PrimitiveStore,
|
||||
@ -74,13 +75,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||
let stmt = match node.node {
|
||||
// we don't want fold over type annotation
|
||||
ast::StmtKind::AnnAssign { target, annotation, value, simple } => {
|
||||
self.infer_pattern(&target)?;
|
||||
let target = Box::new(self.fold_expr(*target)?);
|
||||
let value = if let Some(v) = value {
|
||||
let ty = Box::new(self.fold_expr(*v)?);
|
||||
self.unify(target.custom.unwrap(), ty.custom.unwrap(), &node.location)?;
|
||||
Some(ty)
|
||||
} else {
|
||||
None
|
||||
return Err(format!("declaration without definition is not yet supported, at {}", node.location))
|
||||
};
|
||||
let top_level_defs = self.top_level.definitions.read();
|
||||
let annotation_type = self.function_data.resolver.parse_type_annotation(
|
||||
@ -97,6 +99,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||
node: ast::StmtKind::AnnAssign { target, annotation, value, simple },
|
||||
}
|
||||
}
|
||||
ast::StmtKind::For { ref target, .. } => {
|
||||
self.infer_pattern(target)?;
|
||||
fold::fold_stmt(self, node)?
|
||||
}
|
||||
ast::StmtKind::Assign { ref targets, .. } => {
|
||||
for target in targets {
|
||||
self.infer_pattern(target)?;
|
||||
}
|
||||
fold::fold_stmt(self, node)?
|
||||
}
|
||||
_ => fold::fold_stmt(self, node)?,
|
||||
};
|
||||
match &stmt.node {
|
||||
@ -146,7 +158,19 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||
};
|
||||
let custom = match &expr.node {
|
||||
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
|
||||
ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?),
|
||||
ast::ExprKind::Name { id, .. } => {
|
||||
if !self.defined_identifiers.contains(id) {
|
||||
if self.function_data.resolver.get_identifier_def(id.as_str()).is_some() {
|
||||
self.defined_identifiers.push(id.clone());
|
||||
} else {
|
||||
return Err(format!(
|
||||
"unknown identifier {} (use before def?) at {}",
|
||||
id, expr.location
|
||||
));
|
||||
}
|
||||
}
|
||||
Some(self.infer_identifier(id)?)
|
||||
}
|
||||
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
||||
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
||||
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
|
||||
@ -187,6 +211,24 @@ impl<'a> Inferencer<'a> {
|
||||
self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location))
|
||||
}
|
||||
|
||||
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> {
|
||||
match &pattern.node {
|
||||
ExprKind::Name { id, .. } => {
|
||||
if !self.defined_identifiers.contains(id) {
|
||||
self.defined_identifiers.push(id.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
ExprKind::Tuple { elts, .. } => {
|
||||
for elt in elts.iter() {
|
||||
self.infer_pattern(elt)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_method_call(
|
||||
&mut self,
|
||||
location: Location,
|
||||
@ -228,6 +270,13 @@ impl<'a> Inferencer<'a> {
|
||||
);
|
||||
}
|
||||
|
||||
let mut defined_identifiers = self.defined_identifiers.clone();
|
||||
for arg in args.args.iter() {
|
||||
let name = &arg.node.arg;
|
||||
if !defined_identifiers.contains(name) {
|
||||
defined_identifiers.push(name.clone());
|
||||
}
|
||||
}
|
||||
let fn_args: Vec<_> = args
|
||||
.args
|
||||
.iter()
|
||||
@ -236,6 +285,7 @@ impl<'a> Inferencer<'a> {
|
||||
let mut variable_mapping = self.variable_mapping.clone();
|
||||
variable_mapping.extend(fn_args.iter().cloned());
|
||||
let ret = self.unifier.get_fresh_var().0;
|
||||
|
||||
let mut new_context = Inferencer {
|
||||
function_data: self.function_data,
|
||||
unifier: self.unifier,
|
||||
@ -243,6 +293,7 @@ impl<'a> Inferencer<'a> {
|
||||
virtual_checks: self.virtual_checks,
|
||||
calls: self.calls,
|
||||
top_level: self.top_level,
|
||||
defined_identifiers,
|
||||
variable_mapping,
|
||||
};
|
||||
let fun = FunSignature {
|
||||
@ -279,6 +330,7 @@ impl<'a> Inferencer<'a> {
|
||||
);
|
||||
}
|
||||
let variable_mapping = self.variable_mapping.clone();
|
||||
let defined_identifiers = self.defined_identifiers.clone();
|
||||
let mut new_context = Inferencer {
|
||||
function_data: self.function_data,
|
||||
unifier: self.unifier,
|
||||
@ -287,12 +339,14 @@ impl<'a> Inferencer<'a> {
|
||||
variable_mapping,
|
||||
primitives: self.primitives,
|
||||
calls: self.calls,
|
||||
defined_identifiers,
|
||||
};
|
||||
let elt = new_context.fold_expr(elt)?;
|
||||
let generator = generators.pop().unwrap();
|
||||
if generator.is_async {
|
||||
return Err("Async iterator not supported.".to_string());
|
||||
}
|
||||
new_context.infer_pattern(&generator.target)?;
|
||||
let elt = new_context.fold_expr(elt)?;
|
||||
let target = new_context.fold_expr(*generator.target)?;
|
||||
let iter = new_context.fold_expr(*generator.iter)?;
|
||||
let ifs: Vec<_> = generator
|
||||
|
@ -313,6 +313,7 @@ impl TestEnvironment {
|
||||
primitives: &mut self.primitives,
|
||||
virtual_checks: &mut self.virtual_checks,
|
||||
calls: &mut self.calls,
|
||||
defined_identifiers: vec![]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -382,6 +383,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
|
||||
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
||||
defined_identifiers.push("virtual".to_string());
|
||||
let mut inferencer = env.get_inferencer();
|
||||
inferencer.defined_identifiers = defined_identifiers.clone();
|
||||
let statements = parse_program(source).unwrap();
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
@ -523,6 +525,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
|
||||
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
||||
defined_identifiers.push("virtual".to_string());
|
||||
let mut inferencer = env.get_inferencer();
|
||||
inferencer.defined_identifiers = defined_identifiers.clone();
|
||||
let statements = parse_program(source).unwrap();
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
|
@ -112,6 +112,7 @@ fn main() {
|
||||
primitives: &primitives,
|
||||
virtual_checks: &mut virtual_checks,
|
||||
calls: &mut calls,
|
||||
defined_identifiers: vec![]
|
||||
};
|
||||
|
||||
let statements = statements
|
||||
@ -124,6 +125,7 @@ fn main() {
|
||||
inferencer
|
||||
.check_block(&statements, &mut identifiers)
|
||||
.unwrap();
|
||||
|
||||
let top_level = Arc::new(TopLevelContext {
|
||||
definitions: Arc::new(RwLock::new(std::mem::take(
|
||||
&mut *top_level.definitions.write(),
|
||||
|
Loading…
Reference in New Issue
Block a user