diff --git a/nac3core/src/expression.rs b/nac3core/src/expression.rs new file mode 100644 index 0000000000..4192fa4c0c --- /dev/null +++ b/nac3core/src/expression.rs @@ -0,0 +1,385 @@ +use crate::inference::resolve_call; +use crate::operators::*; +use crate::primitives::*; +use crate::typedef::{GlobalContext, Type, Type::*}; +use rustpython_parser::ast::{ + Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator, + UnaryOperator, +}; +use std::collections::HashMap; +use std::convert::TryInto; +use std::rc::Rc; + +type SymTable<'a> = HashMap<&'a str, Rc>; +type ParserResult = Result>, String>; + +pub fn infer_expr(ctx: &GlobalContext, sym_table: &SymTable, expr: &Expression) -> ParserResult { + match &expr.node { + ExpressionType::Number { value } => infer_constant(ctx, sym_table, value), + ExpressionType::Identifier { name } => infer_identifier(ctx, sym_table, name), + ExpressionType::List { elements } => infer_list(ctx, sym_table, elements), + ExpressionType::Tuple { elements } => infer_tuple(ctx, sym_table, elements), + ExpressionType::Attribute { value, name } => infer_attribute(ctx, sym_table, value, name), + ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, sym_table, values), + ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, sym_table, op, a, b), + ExpressionType::Unop { op, a } => infer_unary_ops(ctx, sym_table, op, a), + ExpressionType::Compare { vals, ops } => infer_compare(ctx, sym_table, vals, ops), + ExpressionType::Call { + args, + function, + keywords, + } => { + if keywords.len() > 0 { + Err("keyword is not supported".into()) + } else { + infer_call(ctx, sym_table, &args, &function) + } + } + ExpressionType::Subscript { a, b } => infer_subscript(ctx, sym_table, a, b), + ExpressionType::IfExpression { test, body, orelse } => { + infer_if_expr(ctx, sym_table, &test, &body, orelse) + } + ExpressionType::Comprehension { kind, generators } => match kind.as_ref() { + ComprehensionKind::List { element } => { + if generators.len() == 1 { + infer_list_comprehension(ctx, sym_table, element, &generators[0]) + } else { + Err("only 1 generator statement is supported".into()) + } + } + _ => Err("only list comprehension is supported".into()), + }, + ExpressionType::True | ExpressionType::False => Ok(Some(PrimitiveType(BOOL_TYPE).into())), + _ => Err("not supported".into()), + } +} + +fn infer_constant( + _: &GlobalContext, + _: &SymTable, + value: &rustpython_parser::ast::Number, +) -> ParserResult { + use rustpython_parser::ast::Number; + match value { + Number::Integer { value } => { + let int32: Result = value.try_into(); + if int32.is_ok() { + Ok(Some(PrimitiveType(INT32_TYPE).into())) + } else { + let int64: Result = value.try_into(); + if int64.is_ok() { + Ok(Some(PrimitiveType(INT64_TYPE).into())) + } else { + Err("integer out of range".into()) + } + } + } + Number::Float { .. } => Ok(Some(PrimitiveType(FLOAT_TYPE).into())), + _ => Err("not supported".into()), + } +} + +fn infer_identifier(_: &GlobalContext, sym_table: &SymTable, name: &str) -> ParserResult { + match sym_table.get(name) { + Some(v) => Ok(Some(v.clone())), + None => Err("unbounded variable".into()), + } +} + +fn infer_list(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression]) -> ParserResult { + if elements.len() == 0 { + return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into())); + } + + let mut types = elements.iter().map(|v| infer_expr(&ctx, sym_table, v)); + + let head = types.next().unwrap()?; + if head.is_none() { + return Err("list elements must have some type".into()); + } + for v in types { + if v? != head { + return Err("inhomogeneous list is not allowed".into()); + } + } + Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into())) +} + +fn infer_tuple(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression]) -> ParserResult { + let types: Result>, String> = elements + .iter() + .map(|v| infer_expr(&ctx, sym_table, v)) + .collect(); + if let Some(t) = types? { + Ok(Some(ParametricType(TUPLE_TYPE, t).into())) + } else { + Err("tuple elements must have some type".into()) + } +} + +fn infer_attribute( + ctx: &GlobalContext, + sym_table: &SymTable, + value: &Expression, + name: &String, +) -> ParserResult { + let value = infer_expr(ctx, sym_table, value)?.ok_or("no value".to_string())?; + if let TypeVariable(id) = value.as_ref() { + let v = ctx.get_variable(*id); + if v.bound.len() == 0 { + return Err("no fields on unbounded type variable".into()); + } + let ty = v.bound[0] + .get_base(ctx) + .and_then(|v| v.fields.get(name.as_str())); + if ty.is_none() { + return Err("unknown field".into()); + } + for x in v.bound[1..].iter() { + let ty1 = x.get_base(ctx).and_then(|v| v.fields.get(name.as_str())); + if ty1 != ty { + return Err("unknown field (type mismatch between variants)".into()); + } + } + return Ok(Some(ty.unwrap().clone())); + } + + match value.get_base(ctx) { + Some(b) => match b.fields.get(name.as_str()) { + Some(t) => Ok(Some(t.clone())), + None => Err("no such field".into()), + }, + None => Err("this object has no fields".into()), + } +} + +fn infer_bool_ops( + ctx: &GlobalContext, + sym_table: &SymTable, + values: &[Expression], +) -> ParserResult { + assert_eq!(values.len(), 2); + let left = infer_expr(ctx, sym_table, &values[0])?.ok_or("no value".to_string())?; + let right = infer_expr(ctx, sym_table, &values[1])?.ok_or("no value".to_string())?; + + let b = PrimitiveType(BOOL_TYPE); + if left.as_ref() == &b && right.as_ref() == &b { + Ok(Some(b.into())) + } else { + Err("bool operands must be bool".into()) + } +} + +fn infer_bin_ops( + ctx: &GlobalContext, + sym_table: &SymTable, + op: &Operator, + left: &Expression, + right: &Expression, +) -> ParserResult { + let left = infer_expr(ctx, sym_table, left)?.ok_or("no value".to_string())?; + let right = infer_expr(ctx, sym_table, right)?.ok_or("no value".to_string())?; + let fun = binop_name(op); + resolve_call(ctx, Some(left), fun, &[right]) +} + +fn infer_unary_ops( + ctx: &GlobalContext, + sym_table: &SymTable, + op: &UnaryOperator, + obj: &Expression, +) -> ParserResult { + let ty = infer_expr(ctx, sym_table, obj)?.ok_or("no value".to_string())?; + if let UnaryOperator::Not = op { + if ty.as_ref() == &PrimitiveType(BOOL_TYPE) { + Ok(Some(ty)) + } else { + Err("logical not must be applied to bool".into()) + } + } else { + resolve_call(ctx, Some(ty), unaryop_name(op), &[]) + } +} + +fn infer_compare( + ctx: &GlobalContext, + sym_table: &SymTable, + vals: &[Expression], + ops: &[Comparison], +) -> ParserResult { + let types: Result>, _> = + vals.iter().map(|v| infer_expr(ctx, sym_table, v)).collect(); + let types = types?; + if types.is_none() { + return Err("comparison operands must have type".into()); + } + let types = types.unwrap(); + let boolean = PrimitiveType(BOOL_TYPE); + let left = &types[..types.len() - 1]; + let right = &types[1..]; + + for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) { + let fun = comparison_name(op).ok_or("unsupported comparison".to_string())?; + let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?; + if ty.is_none() || ty.unwrap().as_ref() != &boolean { + return Err("comparison result must be boolean".into()); + } + } + Ok(Some(boolean.into())) +} + +fn infer_call( + ctx: &GlobalContext, + sym_table: &SymTable, + args: &[Expression], + function: &Expression, +) -> ParserResult { + let types: Result>, _> = + args.iter().map(|v| infer_expr(ctx, sym_table, v)).collect(); + let types = types?; + if types.is_none() { + return Err("function params must have type".into()); + } + + let (obj, fun) = match &function.node { + ExpressionType::Identifier { name } => (None, name), + ExpressionType::Attribute { value, name } => ( + Some(infer_expr(ctx, sym_table, &value)?.ok_or("no value".to_string())?), + name, + ), + _ => return Err("not supported".into()), + }; + resolve_call(ctx, obj, fun.as_str(), &types.unwrap()) +} + +fn infer_subscript( + ctx: &GlobalContext, + sym_table: &SymTable, + a: &Expression, + b: &Expression, +) -> ParserResult { + let a = infer_expr(ctx, sym_table, a)?.ok_or("no value".to_string())?; + let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() { + ls[0].clone() + } else { + return Err("subscript is not supported for types other than list".into()); + }; + + match &b.node { + ExpressionType::Slice { elements } => { + let int32 = Rc::new(PrimitiveType(INT32_TYPE)); + let types: Result>, _> = elements + .iter() + .map(|v| { + if let ExpressionType::None = v.node { + Ok(Some(int32.clone())) + } else { + infer_expr(ctx, sym_table, v) + } + }) + .collect(); + let types = types?.ok_or("slice must have type".to_string())?; + if types.iter().all(|v| v == &int32) { + Ok(Some(a)) + } else { + Err("slice must be int32 type".into()) + } + } + _ => { + let b = infer_expr(ctx, sym_table, b)?.ok_or("no value".to_string())?; + if b.as_ref() == &PrimitiveType(INT32_TYPE) { + Ok(Some(t)) + } else { + Err("index must be either slice or int32".into()) + } + } + } +} + +fn infer_if_expr( + ctx: &GlobalContext, + sym_table: &SymTable, + test: &Expression, + body: &Expression, + orelse: &Expression, +) -> ParserResult { + let test = infer_expr(ctx, sym_table, test)?.ok_or("no value".to_string())?; + if test.as_ref() != &PrimitiveType(BOOL_TYPE) { + return Err("test should be bool".into()); + } + + let body = infer_expr(ctx, sym_table, body)?; + let orelse = infer_expr(ctx, sym_table, orelse)?; + if body.as_ref() == orelse.as_ref() { + Ok(body) + } else { + Err("divergent type".into()) + } +} + +fn infer_simple_binding<'a: 'b, 'b>( + sym_table: &mut SymTable<'b>, + name: &'a Expression, + ty: Rc, +) -> Result<(), String> { + match &name.node { + ExpressionType::Identifier { name } => { + if name == "_" { + Ok(()) + } else if sym_table.get(name.as_str()).is_some() { + Err("duplicated naming".into()) + } else { + sym_table.insert(name.as_str(), ty); + Ok(()) + } + } + ExpressionType::Tuple { elements } => { + if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() { + if elements.len() == ls.len() { + for (a, b) in elements.iter().zip(ls.iter()) { + infer_simple_binding(sym_table, a, b.clone())?; + } + Ok(()) + } else { + Err("different length".into()) + } + } else { + Err("not supported".into()) + } + } + _ => Err("not supported".into()), + } +} + +fn infer_list_comprehension( + ctx: &GlobalContext, + sym_table: &SymTable, + element: &Expression, + comprehension: &Comprehension, +) -> ParserResult { + if comprehension.is_async { + return Err("async is not supported".into()); + } + + // TODO: it may be more efficient to use multi-level table + // but it would better done in a whole program level + let iter = infer_expr(ctx, sym_table, &comprehension.iter)?.ok_or("no value".to_string())?; + if let ParametricType(LIST_TYPE, ls) = iter.as_ref() { + let mut local_sym = sym_table.clone(); + infer_simple_binding(&mut local_sym, &comprehension.target, ls[0].clone())?; + + let boolean = PrimitiveType(BOOL_TYPE); + for test in comprehension.ifs.iter() { + let result = + infer_expr(ctx, &local_sym, test)?.ok_or("no value in test".to_string())?; + if result.as_ref() != &boolean { + return Err("test must be bool".into()); + } + } + let result = infer_expr(ctx, &local_sym, element)?.ok_or("no value")?; + Ok(Some(ParametricType(LIST_TYPE, vec![result]).into())) + } else { + Err("iteration is supported for list only".into()) + } +} + diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index 36962b77ff..ab521b4a75 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -2,6 +2,7 @@ extern crate num_bigint; extern crate inkwell; extern crate rustpython_parser; +pub mod expression; pub mod inference; mod operators; pub mod primitives;