type_inferencer: check defined identifiers during inference

This commit is contained in:
pca006132 2021-08-27 11:13:43 +08:00
parent 35ef0386db
commit a24e204824
5 changed files with 77 additions and 12 deletions

View File

@ -77,6 +77,7 @@ fn test_primitives() {
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers = vec!["a".to_string(), "b".to_string()];
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
@ -85,6 +86,7 @@ fn test_primitives() {
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone()
};
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("b".into(), inferencer.primitives.int32);
@ -95,7 +97,6 @@ fn test_primitives() {
.collect::<Result<Vec<_>, _>>()
.unwrap();
let mut identifiers = vec!["a".to_string(), "b".to_string()];
inferencer.check_block(&statements, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),
@ -235,6 +236,7 @@ fn test_simple_call() {
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers = vec!["a".to_string(), "foo".into()];
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
@ -243,6 +245,7 @@ fn test_simple_call() {
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: identifiers.clone()
};
inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32);
inferencer.variable_mapping.insert("foo".into(), fun_ty);
@ -273,7 +276,6 @@ fn test_simple_call() {
unreachable!()
}
let mut identifiers = vec!["a".to_string(), "foo".into()];
inferencer.check_block(&statements_1, &mut identifiers).unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))),

View File

@ -42,7 +42,7 @@ impl<'a> Inferencer<'a> {
fn check_expr(
&mut self,
expr: &Expr<Option<Type>>,
defined_identifiers: &[String],
defined_identifiers: &mut Vec<String>,
) -> Result<(), String> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom {
@ -57,12 +57,16 @@ impl<'a> Inferencer<'a> {
match &expr.node {
ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) {
if self.function_data.resolver.get_identifier_def(id).is_some() {
defined_identifiers.push(id.clone());
} else {
return Err(format!(
"unknown identifier {} (use before def?) at {}",
id, expr.location
));
}
}
}
ExprKind::List { elts, .. }
| ExprKind::Tuple { elts, .. }
| ExprKind::BoolOp { values: elts, .. } => {
@ -106,7 +110,7 @@ impl<'a> Inferencer<'a> {
defined_identifiers.push(arg.node.arg.clone());
}
}
self.check_expr(body, &defined_identifiers)?;
self.check_expr(body, &mut defined_identifiers)?;
}
ExprKind::ListComp { elt, generators, .. } => {
// in our type inference stage, we already make sure that there is only 1 generator
@ -115,7 +119,7 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.to_vec();
self.check_pattern(target, &mut defined_identifiers)?;
for term in once(elt.as_ref()).chain(ifs.iter()) {
self.check_expr(term, &defined_identifiers)?;
self.check_expr(term, &mut defined_identifiers)?;
}
}
ExprKind::Call { func, args, keywords } => {

View File

@ -45,6 +45,7 @@ pub struct FunctionData {
pub struct Inferencer<'a> {
pub top_level: &'a TopLevelContext,
pub defined_identifiers: Vec<String>,
pub function_data: &'a mut FunctionData,
pub unifier: &'a mut Unifier,
pub primitives: &'a PrimitiveStore,
@ -74,13 +75,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
let stmt = match node.node {
// we don't want fold over type annotation
ast::StmtKind::AnnAssign { target, annotation, value, simple } => {
self.infer_pattern(&target)?;
let target = Box::new(self.fold_expr(*target)?);
let value = if let Some(v) = value {
let ty = Box::new(self.fold_expr(*v)?);
self.unify(target.custom.unwrap(), ty.custom.unwrap(), &node.location)?;
Some(ty)
} else {
None
return Err(format!("declaration without definition is not yet supported, at {}", node.location))
};
let top_level_defs = self.top_level.definitions.read();
let annotation_type = self.function_data.resolver.parse_type_annotation(
@ -97,6 +99,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
node: ast::StmtKind::AnnAssign { target, annotation, value, simple },
}
}
ast::StmtKind::For { ref target, .. } => {
self.infer_pattern(target)?;
fold::fold_stmt(self, node)?
}
ast::StmtKind::Assign { ref targets, .. } => {
for target in targets {
self.infer_pattern(target)?;
}
fold::fold_stmt(self, node)?
}
_ => fold::fold_stmt(self, node)?,
};
match &stmt.node {
@ -146,7 +158,19 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
};
let custom = match &expr.node {
ast::ExprKind::Constant { value, .. } => Some(self.infer_constant(value)?),
ast::ExprKind::Name { id, .. } => Some(self.infer_identifier(id)?),
ast::ExprKind::Name { id, .. } => {
if !self.defined_identifiers.contains(id) {
if self.function_data.resolver.get_identifier_def(id.as_str()).is_some() {
self.defined_identifiers.push(id.clone());
} else {
return Err(format!(
"unknown identifier {} (use before def?) at {}",
id, expr.location
));
}
}
Some(self.infer_identifier(id)?)
}
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
ast::ExprKind::Attribute { value, attr, ctx: _ } => {
@ -187,6 +211,24 @@ impl<'a> Inferencer<'a> {
self.unifier.unify(a, b).map_err(|old| format!("{} at {}", old, location))
}
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), String> {
match &pattern.node {
ExprKind::Name { id, .. } => {
if !self.defined_identifiers.contains(id) {
self.defined_identifiers.push(id.clone());
}
Ok(())
}
ExprKind::Tuple { elts, .. } => {
for elt in elts.iter() {
self.infer_pattern(elt)?;
}
Ok(())
}
_ => Ok(()),
}
}
fn build_method_call(
&mut self,
location: Location,
@ -228,6 +270,13 @@ impl<'a> Inferencer<'a> {
);
}
let mut defined_identifiers = self.defined_identifiers.clone();
for arg in args.args.iter() {
let name = &arg.node.arg;
if !defined_identifiers.contains(name) {
defined_identifiers.push(name.clone());
}
}
let fn_args: Vec<_> = args
.args
.iter()
@ -236,6 +285,7 @@ impl<'a> Inferencer<'a> {
let mut variable_mapping = self.variable_mapping.clone();
variable_mapping.extend(fn_args.iter().cloned());
let ret = self.unifier.get_fresh_var().0;
let mut new_context = Inferencer {
function_data: self.function_data,
unifier: self.unifier,
@ -243,6 +293,7 @@ impl<'a> Inferencer<'a> {
virtual_checks: self.virtual_checks,
calls: self.calls,
top_level: self.top_level,
defined_identifiers,
variable_mapping,
};
let fun = FunSignature {
@ -279,6 +330,7 @@ impl<'a> Inferencer<'a> {
);
}
let variable_mapping = self.variable_mapping.clone();
let defined_identifiers = self.defined_identifiers.clone();
let mut new_context = Inferencer {
function_data: self.function_data,
unifier: self.unifier,
@ -287,12 +339,14 @@ impl<'a> Inferencer<'a> {
variable_mapping,
primitives: self.primitives,
calls: self.calls,
defined_identifiers,
};
let elt = new_context.fold_expr(elt)?;
let generator = generators.pop().unwrap();
if generator.is_async {
return Err("Async iterator not supported.".to_string());
}
new_context.infer_pattern(&generator.target)?;
let elt = new_context.fold_expr(elt)?;
let target = new_context.fold_expr(*generator.target)?;
let iter = new_context.fold_expr(*generator.iter)?;
let ifs: Vec<_> = generator

View File

@ -313,6 +313,7 @@ impl TestEnvironment {
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
defined_identifiers: vec![]
}
}
}
@ -382,6 +383,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()
@ -523,6 +525,7 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
let mut defined_identifiers: Vec<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.push("virtual".to_string());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()

View File

@ -112,6 +112,7 @@ fn main() {
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
defined_identifiers: vec![]
};
let statements = statements
@ -124,6 +125,7 @@ fn main() {
inferencer
.check_block(&statements, &mut identifiers)
.unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Arc::new(RwLock::new(std::mem::take(
&mut *top_level.definitions.write(),