diff --git a/nac3core/src/expression.rs b/nac3core/src/expression.rs index 29b9565b53..45f641331e 100644 --- a/nac3core/src/expression.rs +++ b/nac3core/src/expression.rs @@ -2,7 +2,10 @@ use crate::inference::resolve_call; use crate::operators::*; use crate::primitives::*; use crate::typedef::{GlobalContext, Type, Type::*}; -use rustpython_parser::ast::{Comparison, Expression, ExpressionType, Operator, UnaryOperator}; +use rustpython_parser::ast::{ + Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator, + UnaryOperator, +}; use std::collections::HashMap; use std::rc::Rc; @@ -10,7 +13,39 @@ type SymTable<'a> = HashMap<&'a str, Rc>; type ParserResult = Result>, String>; pub fn parse_expr(ctx: &GlobalContext, sym_table: &SymTable, expr: &Expression) -> ParserResult { - Err("not supported".into()) + match &expr.node { + ExpressionType::Number { value } => parse_constant(ctx, sym_table, value), + ExpressionType::Identifier { name } => parse_identifier(ctx, sym_table, name), + ExpressionType::List { elements } => parse_list(ctx, sym_table, elements), + ExpressionType::Tuple { elements } => parse_tuple(ctx, sym_table, elements), + ExpressionType::Attribute { value, name } => parse_attribute(ctx, sym_table, value, name), + ExpressionType::BoolOp { values, .. } => parse_bool_ops(ctx, sym_table, values), + ExpressionType::Binop { a, b, op } => parse_bin_ops(ctx, sym_table, op, a, b), + ExpressionType::Unop { op, a } => parse_unary_ops(ctx, sym_table, op, a), + ExpressionType::Compare { vals, ops } => parse_compare(ctx, sym_table, vals, ops), + ExpressionType::Call { args, function, keywords} => { + if keywords.len() > 0 { + Err("keyword is not supported".into()) + } else { + parse_call(ctx, sym_table, &args, &function) + } + }, + ExpressionType::Subscript { a, b } => parse_subscript(ctx, sym_table, a, b), + ExpressionType::IfExpression { test, body, orelse } => { + parse_if_expr(ctx, sym_table, &test, &body, orelse) + } + ExpressionType::Comprehension { kind, generators } => match kind.as_ref() { + ComprehensionKind::List { element } => { + if generators.len() == 1 { + parse_list_comprehension(ctx, sym_table, element, &generators[0]) + } else { + Err("only 1 generator statement is supported".into()) + } + } + _ => Err("only list comprehension is supported".into()), + }, + _ => Err("not supported".into()), + } } fn parse_constant( @@ -78,7 +113,7 @@ fn parse_attribute( ctx: &GlobalContext, sym_table: &SymTable, value: &Expression, - name: String, + name: &String, ) -> ParserResult { let value = parse_expr(ctx, sym_table, value)?.ok_or("no value".to_string())?; if let TypeVariable(id) = value.as_ref() { @@ -141,3 +176,198 @@ fn parse_bin_ops( resolve_call(ctx, Some(left), fun, &[right], &mut assumptions) } +fn parse_unary_ops( + ctx: &GlobalContext, + sym_table: &SymTable, + op: &UnaryOperator, + obj: &Expression, +) -> ParserResult { + let ty = parse_expr(ctx, sym_table, obj)?.ok_or("no value".to_string())?; + let mut assumptions = HashMap::new(); + if let UnaryOperator::Not = op { + if ty.as_ref() == &PrimitiveType(BOOL_TYPE) { + Ok(Some(ty)) + } else { + Err("logical not must be applied to bool".into()) + } + } else { + resolve_call(ctx, Some(ty), unaryop_name(op), &[], &mut assumptions) + } +} + +fn parse_compare( + ctx: &GlobalContext, + sym_table: &SymTable, + vals: &[Expression], + ops: &[Comparison], +) -> ParserResult { + let types: Result>, _> = + vals.iter().map(|v| parse_expr(ctx, sym_table, v)).collect(); + let types = types?; + if types.is_none() { + return Err("comparison operands must have type".into()); + } + let types = types.unwrap(); + let boolean = PrimitiveType(BOOL_TYPE); + let left = &types[..types.len() - 1]; + let right = &types[1..]; + let mut assumptions = HashMap::new(); + + for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) { + let fun = comparison_name(op).ok_or("unsupported comparison".to_string())?; + let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()], &mut assumptions)?; + if ty.is_none() || ty.unwrap().as_ref() != &boolean { + return Err("comparison result must be boolean".into()); + } + } + Ok(Some(boolean.into())) +} + +fn parse_call( + ctx: &GlobalContext, + sym_table: &SymTable, + args: &[Expression], + function: &Expression, +) -> ParserResult { + let types: Result>, _> = + args.iter().map(|v| parse_expr(ctx, sym_table, v)).collect(); + let types = types?; + if types.is_none() { + return Err("function params must have type".into()); + } + let mut assumptions = HashMap::new(); + + let (obj, fun) = match &function.node { + ExpressionType::Identifier { name } => (None, name), + ExpressionType::Attribute { value, name } => ( + Some(parse_expr(ctx, sym_table, &value)?.ok_or("no value".to_string())?), + name, + ), + _ => return Err("not supported".into()), + }; + resolve_call(ctx, obj, fun.as_str(), &types.unwrap(), &mut assumptions) +} + +fn parse_subscript( + ctx: &GlobalContext, + sym_table: &SymTable, + a: &Expression, + b: &Expression, +) -> ParserResult { + let a = parse_expr(ctx, sym_table, a)?.ok_or("no value".to_string())?; + let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() { + ls[0].clone() + } else { + return Err("subscript is not supported for types other than list".into()); + }; + + match &b.node { + ExpressionType::Slice { elements } => { + let types: Result>, _> = elements + .iter() + .map(|v| parse_expr(ctx, sym_table, v)) + .collect(); + let types = types?.ok_or("slice must have type".to_string())?; + let int32 = PrimitiveType(INT32_TYPE); + if types.iter().all(|v| v.as_ref() == &int32) { + Ok(Some(a)) + } else { + Err("slice must be int32 type".into()) + } + } + _ => { + let b = parse_expr(ctx, sym_table, b)?.ok_or("no value".to_string())?; + if b.as_ref() == &PrimitiveType(INT32_TYPE) { + Ok(Some(t)) + } else { + Err("index must be either slice or int32".into()) + } + } + } +} + +fn parse_if_expr( + ctx: &GlobalContext, + sym_table: &SymTable, + test: &Expression, + body: &Expression, + orelse: &Expression, +) -> ParserResult { + let test = parse_expr(ctx, sym_table, test)?.ok_or("no value".to_string())?; + if test.as_ref() != &PrimitiveType(BOOL_TYPE) { + return Err("test should be bool".into()); + } + + let body = parse_expr(ctx, sym_table, body)?.ok_or("no value".to_string())?; + let orelse = parse_expr(ctx, sym_table, orelse)?.ok_or("no value".to_string())?; + if body.as_ref() == orelse.as_ref() { + Ok(Some(body)) + } else { + Err("divergent type".into()) + } +} + +fn parse_simple_binding<'a: 'b, 'b>( + sym_table: &mut SymTable<'b>, + name: &'a Expression, + ty: Rc, +) -> Result<(), String> { + match &name.node { + ExpressionType::Identifier { name } => { + if name == "_" { + Ok(()) + } else if sym_table.get(name.as_str()).is_some() { + Err("duplicated naming".into()) + } else { + sym_table.insert(name.as_str(), ty); + Ok(()) + } + } + ExpressionType::Tuple { elements } => { + if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() { + if elements.len() == ls.len() { + for (a, b) in elements.iter().zip(ls.iter()) { + parse_simple_binding(sym_table, a, b.clone())?; + } + Ok(()) + } else { + Err("different length".into()) + } + } else { + Err("not supported".into()) + } + } + _ => Err("not supported".into()), + } +} + +fn parse_list_comprehension( + ctx: &GlobalContext, + sym_table: &SymTable, + element: &Expression, + comprehension: &Comprehension, +) -> ParserResult { + if comprehension.is_async { + return Err("async is not supported".into()); + } + + // TODO: it may be more efficient to use multi-level table + // but it would better done in a whole program level + let iter = parse_expr(ctx, sym_table, &comprehension.iter)?.ok_or("no value".to_string())?; + if let ParametricType(LIST_TYPE, ls) = iter.as_ref() { + let mut local_sym = sym_table.clone(); + parse_simple_binding(&mut local_sym, &comprehension.target, ls[0].clone())?; + + let boolean = PrimitiveType(BOOL_TYPE); + for test in comprehension.ifs.iter() { + let result = parse_expr(ctx, &local_sym, test)?.ok_or("no value in test".to_string())?; + if result.as_ref() != &boolean { + return Err("test must be bool".into()); + } + } + parse_expr(ctx, &local_sym, element) + } else { + Err("iteration is supported for list only".into()) + } +} + diff --git a/nac3core/src/operators.rs b/nac3core/src/operators.rs index a33f4aa882..5619b2e46a 100644 --- a/nac3core/src/operators.rs +++ b/nac3core/src/operators.rs @@ -2,39 +2,57 @@ use rustpython_parser::ast::{Comparison, Operator, UnaryOperator}; pub fn binop_name(op: &Operator) -> &'static str { match op { - Operator::Add => "add", - Operator::Sub => "sub", - Operator::Div => "truediv", - Operator::Mod => "mod", - Operator::Mult => "mul", - Operator::Pow => "pow", - Operator::BitOr => "or", - Operator::BitXor => "xor", - Operator::BitAnd => "and", - Operator::LShift => "lshift", - Operator::RShift => "rshift", - Operator::FloorDiv => "floordiv", - Operator::MatMult => "matmul", + Operator::Add => "__add__", + Operator::Sub => "__sub__", + Operator::Div => "__truediv__", + Operator::Mod => "__mod__", + Operator::Mult => "__mul__", + Operator::Pow => "__pow__", + Operator::BitOr => "__or__", + Operator::BitXor => "__xor__", + Operator::BitAnd => "__and__", + Operator::LShift => "__lshift__", + Operator::RShift => "__rshift__", + Operator::FloorDiv => "__floordiv__", + Operator::MatMult => "__matmul__", + } +} + +pub fn binop_assign_name(op: &Operator) -> &'static str { + match op { + Operator::Add => "__iadd__", + Operator::Sub => "__isub__", + Operator::Div => "__itruediv__", + Operator::Mod => "__imod__", + Operator::Mult => "__imul__", + Operator::Pow => "__ipow__", + Operator::BitOr => "__ior__", + Operator::BitXor => "__ixor__", + Operator::BitAnd => "__iand__", + Operator::LShift => "__ilshift__", + Operator::RShift => "__irshift__", + Operator::FloorDiv => "__ifloordiv__", + Operator::MatMult => "__imatmul__", } } pub fn unaryop_name(op: &UnaryOperator) -> &'static str { match op { - UnaryOperator::Pos => "pos", - UnaryOperator::Neg => "neg", - UnaryOperator::Not => "not", - UnaryOperator::Inv => "inv", + UnaryOperator::Pos => "__pos__", + UnaryOperator::Neg => "__neg__", + UnaryOperator::Not => "__not__", + UnaryOperator::Inv => "__inv__", } } pub fn comparison_name(op: &Comparison) -> Option<&'static str> { match op { - Comparison::Less => Some("lt"), - Comparison::LessOrEqual => Some("le"), - Comparison::Greater => Some("gt"), - Comparison::GreaterOrEqual => Some("ge"), - Comparison::Equal => Some("eq"), - Comparison::NotEqual => Some("ne"), + Comparison::Less => Some("__lt__"), + Comparison::LessOrEqual => Some("__le__"), + Comparison::Greater => Some("__gt__"), + Comparison::GreaterOrEqual => Some("__ge__"), + Comparison::Equal => Some("__eq__"), + Comparison::NotEqual => Some("__ne__"), _ => None, } }