forked from M-Labs/nac3
type_inferencer: check defined identifiers during inference
This commit is contained in:
parent
35ef0386db
commit
a24e204824
|
@ -77,6 +77,7 @@ fn test_primitives() {
|
||||||
};
|
};
|
||||||
let mut virtual_checks = Vec::new();
|
let mut virtual_checks = Vec::new();
|
||||||
let mut calls = HashMap::new();
|
let mut calls = HashMap::new();
|
||||||
|
let mut identifiers = vec!["a".to_string(), "b".to_string()];
|
||||||
let mut inferencer = Inferencer {
|
let mut inferencer = Inferencer {
|
||||||
top_level: &top_level,
|
top_level: &top_level,
|
||||||
function_data: &mut function_data,
|
function_data: &mut function_data,
|
||||||
|
@ -85,6 +86,7 @@ fn test_primitives() {
|
||||||
primitives: &primitives,
|
primitives: &primitives,
|
||||||
virtual_checks: &mut virtual_checks,
|
virtual_checks: &mut virtual_checks,
|
||||||
calls: &mut calls,
|
calls: &mut calls,
|
||||||
|
defined_identifiers: identifiers.clone()
|
||||||
};
|
};
|
||||||
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
|
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
|
||||||
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
|
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
|
||||||
|
@ -95,7 +97,6 @@ fn test_primitives() {
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let mut identifiers = vec!["a".to_string(), "b".to_string()];
|
|
||||||
inferencer.check_block(&statements, &mut identifiers).unwrap();
|
inferencer.check_block(&statements, &mut identifiers).unwrap();
|
||||||
let top_level = Arc::new(TopLevelContext {
|
let top_level = Arc::new(TopLevelContext {
|
||||||
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
|
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
|
||||||
|
@ -235,6 +236,7 @@ fn test_simple_call() {
|
||||||
};
|
};
|
||||||
let mut virtual_checks = Vec::new();
|
let mut virtual_checks = Vec::new();
|
||||||
let mut calls = HashMap::new();
|
let mut calls = HashMap::new();
|
||||||
|
let mut identifiers = vec!["a".to_string(), "foo".into()];
|
||||||
let mut inferencer = Inferencer {
|
let mut inferencer = Inferencer {
|
||||||
top_level: &top_level,
|
top_level: &top_level,
|
||||||
function_data: &mut function_data,
|
function_data: &mut function_data,
|
||||||
|
@ -243,6 +245,7 @@ fn test_simple_call() {
|
||||||
primitives: &primitives,
|
primitives: &primitives,
|
||||||
virtual_checks: &mut virtual_checks,
|
virtual_checks: &mut virtual_checks,
|
||||||
calls: &mut calls,
|
calls: &mut calls,
|
||||||
|
defined_identifiers: identifiers.clone()
|
||||||
};
|
};
|
||||||
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
|
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
|
||||||
inferencer.variable_mapping.insert("foo".into(), fun_ty);
|
inferencer.variable_mapping.insert("foo".into(), fun_ty);
|
||||||
|
@ -273,7 +276,6 @@ fn test_simple_call() {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut identifiers = vec!["a".to_string(), "foo".into()];
|
|
||||||
inferencer.check_block(&statements_1, &mut identifiers).unwrap();
|
inferencer.check_block(&statements_1, &mut identifiers).unwrap();
|
||||||
let top_level = Arc::new(TopLevelContext {
|
let top_level = Arc::new(TopLevelContext {
|
||||||
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
|
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
|
||||||
|
|
|
@ -42,7 +42,7 @@ impl<'a> Inferencer<'a> {
|
||||||
fn check_expr(
|
fn check_expr(
|
||||||
&mut self,
|
&mut self,
|
||||||
expr: &Expr<Option<Type>>,
|
expr: &Expr<Option<Type>>,
|
||||||
defined_identifiers: &[String],
|
defined_identifiers: &mut Vec<String>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
// there are some cases where the custom field is None
|
// there are some cases where the custom field is None
|
||||||
if let Some(ty) = &expr.custom {
|
if let Some(ty) = &expr.custom {
|
||||||
|
@ -57,10 +57,14 @@ impl<'a> Inferencer<'a> {
|
||||||
match &expr.node {
|
match &expr.node {
|
||||||
ExprKind::Name { id, .. } => {
|
ExprKind::Name { id, .. } => {
|
||||||
if !defined_identifiers.contains(id) {
|
if !defined_identifiers.contains(id) {
|
||||||
return Err(format!(
|
if self.function_data.resolver.get_identifier_def(id).is_some() {
|
||||||
"unknown identifier {} (use before def?) at {}",
|
defined_identifiers.push(id.clone());
|
||||||
id, expr.location
|
} else {
|
||||||
));
|
return Err(format!(
|
||||||
|
"unknown identifier {} (use before def?) at {}",
|
||||||
|
id, expr.location
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExprKind::List { elts, .. }
|
ExprKind::List { elts, .. }
|
||||||
|
@ -106,7 +110,7 @@ impl<'a> Inferencer<'a> {
|
||||||
defined_identifiers.push(arg.node.arg.clone());
|
defined_identifiers.push(arg.node.arg.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.check_expr(body, &defined_identifiers)?;
|
self.check_expr(body, &mut defined_identifiers)?;
|
||||||
}
|
}
|
||||||
ExprKind::ListComp { elt, generators, .. } => {
|
ExprKind::ListComp { elt, generators, .. } => {
|
||||||
// in our type inference stage, we already make sure that there is only 1 generator
|
// in our type inference stage, we already make sure that there is only 1 generator
|
||||||
|
@ -115,7 +119,7 @@ impl<'a> Inferencer<'a> {
|
||||||
let mut defined_identifiers = defined_identifiers.to_vec();
|
let mut defined_identifiers = defined_identifiers.to_vec();
|
||||||
self.check_pattern(target, &mut defined_identifiers)?;
|
self.check_pattern(target, &mut defined_identifiers)?;
|
||||||
for term in once(elt.as_ref()).chain(ifs.iter()) {
|
for term in once(elt.as_ref()).chain(ifs.iter()) {
|
||||||
self.check_expr(term, &defined_identifiers)?;
|
self.check_expr(term, &mut defined_identifiers)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExprKind::Call { func, args, keywords } => {
|
ExprKind::Call { func, args, keywords } => {
|
||||||
|
|
|
@ -45,6 +45,7 @@ pub struct FunctionData {
|
||||||
|
|
||||||
pub struct Inferencer<'a> {
|
pub struct Inferencer<'a> {
|
||||||
pub top_level: &'a TopLevelContext,
|
pub top_level: &'a TopLevelContext,
|
||||||
|
pub defined_identifiers: Vec<String>,
|
||||||
pub function_data: &'a mut FunctionData,
|
pub function_data: &'a mut FunctionData,
|
||||||
pub unifier: &'a mut Unifier,
|
pub unifier: &'a mut Unifier,
|
||||||
pub primitives: &'a PrimitiveStore,
|
pub primitives: &'a PrimitiveStore,
|
||||||
|
@ -74,13 +75,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
let stmt = match node.node {
|
let stmt = match node.node {
|
||||||
// we don't want fold over type annotation
|
// we don't want fold over type annotation
|
||||||
ast::StmtKind::AnnAssign { target, annotation, value, simple } => {
|
ast::StmtKind::AnnAssign { target, annotation, value, simple } => {
|
||||||
|
self.infer_pattern(&target)?;
|
||||||
let target = Box::new(self.fold_expr(*target)?);
|
let target = Box::new(self.fold_expr(*target)?);
|
||||||
let value = if let Some(v) = value {
|
let value = if let Some(v) = value {
|
||||||
let ty = Box::new(self.fold_expr(*v)?);
|
let ty = Box::new(self.fold_expr(*v)?);
|
||||||
self.unify(target.custom.unwrap(), ty.custom.unwrap(), &node.location)?;
|
self.unify(target.custom.unwrap(), ty.custom.unwrap(), &node.location)?;
|
||||||
Some(ty)
|
Some(ty)
|
||||||
} else {
|
} else {
|
||||||
None
|
return Err(format!("declaration without definition is not yet supported, at {}", node.location))
|
||||||
};
|
};
|
||||||
let top_level_defs = self.top_level.definitions.read();
|
let top_level_defs = self.top_level.definitions.read();
|
||||||
let annotation_type = self.function_data.resolver.parse_type_annotation(
|
let annotation_type = self.function_data.resolver.parse_type_annotation(
|
||||||
|
@ -97,6 +99,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
node: ast::StmtKind::AnnAssign { target, annotation, value, simple },
|
node: ast::StmtKind::AnnAssign { target, annotation, value, simple },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ast::StmtKind::For { ref target, .. } => {
|
||||||
|
self.infer_pattern(target)?;
|
||||||
|
fold::fold_stmt(self, node)?
|
||||||
|
}
|
||||||
|
ast::StmtKind::Assign { ref targets, .. } => {
|
||||||
|
for target in targets {
|
||||||
|
self.infer_pattern(target)?;
|
||||||
|
}
|
||||||
|
fold::fold_stmt(self, node)?
|
||||||
|
}
|
||||||
_ => fold::fold_stmt(self, node)?,
|
_ => fold::fold_stmt(self, node)?,
|
||||||
};
|
};
|
||||||
match &stmt.node {
|
match &stmt.node {
|
||||||
|
@ -146,7 +158,19 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
};
|
};
|
||||||
let custom = match &expr.node {
|
let custom = match &expr.node {
|
||||||
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
|
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
|
||||||
ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?),
|
ast::ExprKind::Name { id, .. } => {
|
||||||
|
if !self.defined_identifiers.contains(id) {
|
||||||
|
if self.function_data.resolver.get_identifier_def(id.as_str()).is_some() {
|
||||||
|
self.defined_identifiers.push(id.clone());
|
||||||
|
} else {
|
||||||
|
return Err(format!(
|
||||||
|
"unknown identifier {} (use before def?) at {}",
|
||||||
|
id, expr.location
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(self.infer_identifier(id)?)
|
||||||
|
}
|
||||||
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
||||||
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
||||||
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
|
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
|
||||||
|
@ -187,6 +211,24 @@ impl<'a> Inferencer<'a> {
|
||||||
self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location))
|
self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> {
|
||||||
|
match &pattern.node {
|
||||||
|
ExprKind::Name { id, .. } => {
|
||||||
|
if !self.defined_identifiers.contains(id) {
|
||||||
|
self.defined_identifiers.push(id.clone());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
ExprKind::Tuple { elts, .. } => {
|
||||||
|
for elt in elts.iter() {
|
||||||
|
self.infer_pattern(elt)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
_ => Ok(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn build_method_call(
|
fn build_method_call(
|
||||||
&mut self,
|
&mut self,
|
||||||
location: Location,
|
location: Location,
|
||||||
|
@ -228,6 +270,13 @@ impl<'a> Inferencer<'a> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut defined_identifiers = self.defined_identifiers.clone();
|
||||||
|
for arg in args.args.iter() {
|
||||||
|
let name = &arg.node.arg;
|
||||||
|
if !defined_identifiers.contains(name) {
|
||||||
|
defined_identifiers.push(name.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
let fn_args: Vec<_> = args
|
let fn_args: Vec<_> = args
|
||||||
.args
|
.args
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -236,6 +285,7 @@ impl<'a> Inferencer<'a> {
|
||||||
let mut variable_mapping = self.variable_mapping.clone();
|
let mut variable_mapping = self.variable_mapping.clone();
|
||||||
variable_mapping.extend(fn_args.iter().cloned());
|
variable_mapping.extend(fn_args.iter().cloned());
|
||||||
let ret = self.unifier.get_fresh_var().0;
|
let ret = self.unifier.get_fresh_var().0;
|
||||||
|
|
||||||
let mut new_context = Inferencer {
|
let mut new_context = Inferencer {
|
||||||
function_data: self.function_data,
|
function_data: self.function_data,
|
||||||
unifier: self.unifier,
|
unifier: self.unifier,
|
||||||
|
@ -243,6 +293,7 @@ impl<'a> Inferencer<'a> {
|
||||||
virtual_checks: self.virtual_checks,
|
virtual_checks: self.virtual_checks,
|
||||||
calls: self.calls,
|
calls: self.calls,
|
||||||
top_level: self.top_level,
|
top_level: self.top_level,
|
||||||
|
defined_identifiers,
|
||||||
variable_mapping,
|
variable_mapping,
|
||||||
};
|
};
|
||||||
let fun = FunSignature {
|
let fun = FunSignature {
|
||||||
|
@ -279,6 +330,7 @@ impl<'a> Inferencer<'a> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let variable_mapping = self.variable_mapping.clone();
|
let variable_mapping = self.variable_mapping.clone();
|
||||||
|
let defined_identifiers = self.defined_identifiers.clone();
|
||||||
let mut new_context = Inferencer {
|
let mut new_context = Inferencer {
|
||||||
function_data: self.function_data,
|
function_data: self.function_data,
|
||||||
unifier: self.unifier,
|
unifier: self.unifier,
|
||||||
|
@ -287,12 +339,14 @@ impl<'a> Inferencer<'a> {
|
||||||
variable_mapping,
|
variable_mapping,
|
||||||
primitives: self.primitives,
|
primitives: self.primitives,
|
||||||
calls: self.calls,
|
calls: self.calls,
|
||||||
|
defined_identifiers,
|
||||||
};
|
};
|
||||||
let elt = new_context.fold_expr(elt)?;
|
|
||||||
let generator = generators.pop().unwrap();
|
let generator = generators.pop().unwrap();
|
||||||
if generator.is_async {
|
if generator.is_async {
|
||||||
return Err("Async iterator not supported.".to_string());
|
return Err("Async iterator not supported.".to_string());
|
||||||
}
|
}
|
||||||
|
new_context.infer_pattern(&generator.target)?;
|
||||||
|
let elt = new_context.fold_expr(elt)?;
|
||||||
let target = new_context.fold_expr(*generator.target)?;
|
let target = new_context.fold_expr(*generator.target)?;
|
||||||
let iter = new_context.fold_expr(*generator.iter)?;
|
let iter = new_context.fold_expr(*generator.iter)?;
|
||||||
let ifs: Vec<_> = generator
|
let ifs: Vec<_> = generator
|
||||||
|
|
|
@ -313,6 +313,7 @@ impl TestEnvironment {
|
||||||
primitives: &mut self.primitives,
|
primitives: &mut self.primitives,
|
||||||
virtual_checks: &mut self.virtual_checks,
|
virtual_checks: &mut self.virtual_checks,
|
||||||
calls: &mut self.calls,
|
calls: &mut self.calls,
|
||||||
|
defined_identifiers: vec![]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -382,6 +383,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
|
||||||
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
||||||
defined_identifiers.push("virtual".to_string());
|
defined_identifiers.push("virtual".to_string());
|
||||||
let mut inferencer = env.get_inferencer();
|
let mut inferencer = env.get_inferencer();
|
||||||
|
inferencer.defined_identifiers = defined_identifiers.clone();
|
||||||
let statements = parse_program(source).unwrap();
|
let statements = parse_program(source).unwrap();
|
||||||
let statements = statements
|
let statements = statements
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -523,6 +525,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
|
||||||
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
|
||||||
defined_identifiers.push("virtual".to_string());
|
defined_identifiers.push("virtual".to_string());
|
||||||
let mut inferencer = env.get_inferencer();
|
let mut inferencer = env.get_inferencer();
|
||||||
|
inferencer.defined_identifiers = defined_identifiers.clone();
|
||||||
let statements = parse_program(source).unwrap();
|
let statements = parse_program(source).unwrap();
|
||||||
let statements = statements
|
let statements = statements
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
|
@ -112,6 +112,7 @@ fn main() {
|
||||||
primitives: &primitives,
|
primitives: &primitives,
|
||||||
virtual_checks: &mut virtual_checks,
|
virtual_checks: &mut virtual_checks,
|
||||||
calls: &mut calls,
|
calls: &mut calls,
|
||||||
|
defined_identifiers: vec![]
|
||||||
};
|
};
|
||||||
|
|
||||||
let statements = statements
|
let statements = statements
|
||||||
|
@ -124,6 +125,7 @@ fn main() {
|
||||||
inferencer
|
inferencer
|
||||||
.check_block(&statements, &mut identifiers)
|
.check_block(&statements, &mut identifiers)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let top_level = Arc::new(TopLevelContext {
|
let top_level = Arc::new(TopLevelContext {
|
||||||
definitions: Arc::new(RwLock::new(std::mem::take(
|
definitions: Arc::new(RwLock::new(std::mem::take(
|
||||||
&mut *top_level.definitions.write(),
|
&mut *top_level.definitions.write(),
|
||||||
|
|
Loading…
Reference in New Issue