diff --git a/nac3core/src/type_check/context/top_level_context.rs b/nac3core/src/type_check/context/top_level_context.rs index d7c4ca23..b2148d9f 100644 --- a/nac3core/src/type_check/context/top_level_context.rs +++ b/nac3core/src/type_check/context/top_level_context.rs @@ -130,7 +130,9 @@ impl<'a> TopLevelContext<'a> { } pub fn get_type(&self, name: &str) -> Option { - // TODO: handle parametric types + // TODO: handle name visibility + // possibly by passing a function from outside to tell what names are allowed, and what are + // not... self.sym_table.get(name).cloned() } } diff --git a/nac3core/src/type_check/signature.rs b/nac3core/src/type_check/signature.rs index 04d3fc4b..1383cfb8 100644 --- a/nac3core/src/type_check/signature.rs +++ b/nac3core/src/type_check/signature.rs @@ -1,6 +1,9 @@ /// obtain class and function signature from AST -use super::primitives::PRIMITIVES; +use super::context::TopLevelContext; +use super::primitives::*; +use super::typedef::*; use rustpython_parser::ast::{ExpressionType, Statement, StatementType, StringGroup}; +use std::collections::HashMap; fn typename_from_expr<'b: 'a, 'a>(typenames: &mut Vec<&'a str>, expr: &'b ExpressionType) { match expr { @@ -33,7 +36,61 @@ fn typename_from_fn<'b: 'a, 'a>(typenames: &mut Vec<&'a str>, fun: &'b Statement } } -pub fn resolve_classes<'b: 'a, 'a>(stmts: &'b [Statement]) -> (Vec<&'a str>, Vec<&'a str>) { +fn name_from_expr<'b: 'a, 'a>(expr: &'b ExpressionType) -> &'a str { + match &expr { + ExpressionType::Identifier { name } => &name, + ExpressionType::String { value } => match value { + StringGroup::Constant { value } => &value, + _ => unimplemented!(), + }, + _ => unimplemented!(), + } +} + +fn type_from_expr<'b: 'a, 'a>( + ctx: &'a TopLevelContext, + expr: &'b ExpressionType, +) -> Result { + match expr { + ExpressionType::Identifier { name } => { + ctx.get_type(name).ok_or_else(|| "no such type".into()) + } + ExpressionType::String { value } => match value { + StringGroup::Constant { value } => { + ctx.get_type(&value).ok_or_else(|| "no such type".into()) + } + _ => unimplemented!(), + }, + ExpressionType::Subscript { a, b } => { + if let ExpressionType::Identifier { name } = &a.node { + match name.as_str() { + "list" => { + let ty = type_from_expr(ctx, &b.node)?; + Ok(TypeEnum::ParametricType(LIST_TYPE, vec![ty]).into()) + } + "tuple" => { + if let ExpressionType::Tuple { elements } = &b.node { + let ty_list: Result, _> = elements + .iter() + .map(|v| type_from_expr(ctx, &v.node)) + .collect(); + Ok(TypeEnum::ParametricType(TUPLE_TYPE, ty_list?).into()) + } else { + Err("unsupported format".into()) + } + } + _ => Err("no such parameterized type".into()), + } + } else { + // we require a to be an identifier, for a[b] + Err("unsupported format".into()) + } + } + _ => Err("unsupported format".into()), + } +} + +pub fn get_typenames<'b: 'a, 'a>(stmts: &'b [Statement]) -> (Vec<&'a str>, Vec<&'a str>) { let mut classes = Vec::new(); let mut typenames = Vec::new(); for stmt in stmts.iter() { @@ -45,15 +102,8 @@ pub fn resolve_classes<'b: 'a, 'a>(stmts: &'b [Statement]) -> (Vec<&'a str>, Vec // and annotations classes.push(&name[..]); for base in bases.iter() { - let name = match &base.node { - ExpressionType::Identifier { name } => name, - ExpressionType::String { value } => match value { - StringGroup::Constant { value } => value, - _ => unimplemented!(), - }, - _ => unimplemented!(), - }; - typenames.push(&name[..]); + let name = name_from_expr(&base.node); + typenames.push(name); } // may check if fields/functions are not duplicated for stmt in body.iter() { @@ -84,6 +134,82 @@ pub fn resolve_classes<'b: 'a, 'a>(stmts: &'b [Statement]) -> (Vec<&'a str>, Vec (classes, unknowns) } +fn resolve_function<'b: 'a, 'a>( + ctx: &'a TopLevelContext, + fun: &'b StatementType, +) -> Result { + if let StatementType::FunctionDef { args, returns, .. } = &fun { + let args: Result, _> = args + .args + .iter() + .map(|arg| type_from_expr(ctx, &arg.annotation.as_ref().unwrap().node)) + .collect(); + let args = args?; + let result = match returns { + Some(v) => Some(type_from_expr(ctx, &v.node)?), + None => None, + }; + Ok(FnDef { args, result }) + } else { + unreachable!() + } +} + +pub fn resolve_signatures<'b: 'a, 'a>(ctx: &'a mut TopLevelContext<'a>, stmts: &'b [Statement]) { + for stmt in stmts.iter() { + match &stmt.node { + StatementType::ClassDef { + name, bases, body, .. + } => { + let mut parents = Vec::new(); + for base in bases.iter() { + let name = name_from_expr(&base.node); + let c = ctx.get_type(name).unwrap(); + let id = if let TypeEnum::ClassType(id) = c.as_ref() { + *id + } else { + unreachable!() + }; + parents.push(id); + } + + let mut fields = HashMap::new(); + let mut functions = HashMap::new(); + + for stmt in body.iter() { + match &stmt.node { + StatementType::AnnAssign { + target, annotation, .. + } => { + let name = name_from_expr(&target.node); + let ty = type_from_expr(ctx, &annotation.node).unwrap(); + fields.insert(name, ty); + } + StatementType::FunctionDef { name, .. } => { + functions.insert(&name[..], resolve_function(ctx, &stmt.node).unwrap()); + } + _ => unimplemented!(), + } + } + + let class = ctx.get_type(name).unwrap(); + let class = if let TypeEnum::ClassType(id) = class.as_ref() { + ctx.get_class_def_mut(*id) + } else { + unreachable!() + }; + class.parents.extend_from_slice(&parents); + class.base.fields.clone_from(&fields); + class.base.methods.clone_from(&functions); + } + StatementType::FunctionDef { name, .. } => { + ctx.add_fn(&name[..], resolve_function(ctx, &stmt.node).unwrap()); + } + _ => unimplemented!(), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -108,7 +234,7 @@ mod tests { return b " }) .unwrap(); - let (mut classes, mut unknowns) = resolve_classes(&ast.statements); + let (mut classes, mut unknowns) = get_typenames(&ast.statements); let classes_count = classes.len(); let unknowns_count = unknowns.len(); classes.sort();