diff --git a/nac3core/src/expression.rs b/nac3core/src/expression.rs index c91e8c10a4..29b9565b53 100644 --- a/nac3core/src/expression.rs +++ b/nac3core/src/expression.rs @@ -1,13 +1,13 @@ use crate::inference::resolve_call; +use crate::operators::*; use crate::primitives::*; use crate::typedef::{GlobalContext, Type, Type::*}; -use rustpython_parser::ast::{Expression, ExpressionType}; +use rustpython_parser::ast::{Comparison, Expression, ExpressionType, Operator, UnaryOperator}; use std::collections::HashMap; -use std::convert::TryFrom; use std::rc::Rc; type SymTable<'a> = HashMap<&'a str, Rc>; -type ParserResult = Result, String>; +type ParserResult = Result>, String>; pub fn parse_expr(ctx: &GlobalContext, sym_table: &SymTable, expr: &Expression) -> ParserResult { Err("not supported".into()) @@ -22,7 +22,7 @@ fn parse_constant( match value { Number::Integer { .. } => { // not check the range now - Ok(PrimitiveType(INT32_TYPE).into()) + Ok(Some(PrimitiveType(INT32_TYPE).into())) // if i32::try_from(&value).is_ok() { // Ok(PrimitiveType(INT32_TYPE).into()) // } else if i64::try_from(&value).is_ok() { @@ -31,40 +31,47 @@ fn parse_constant( // Err("integer out of range".into()) // } } - Number::Float { .. } => Ok(PrimitiveType(FLOAT_TYPE).into()), + Number::Float { .. } => Ok(Some(PrimitiveType(FLOAT_TYPE).into())), _ => Err("not supported".into()), } } fn parse_identifier(_: &GlobalContext, sym_table: &SymTable, name: &str) -> ParserResult { match sym_table.get(name) { - Some(v) => Ok(v.clone()), + Some(v) => Ok(Some(v.clone())), None => Err("unbounded variable".into()), } } fn parse_list(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression]) -> ParserResult { if elements.len() == 0 { - return Ok(ParametricType(LIST_TYPE, vec![BotType.into()]).into()); + return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into())); } let mut types = elements.iter().map(|v| parse_expr(&ctx, sym_table, v)); let head = types.next().unwrap()?; + if head.is_none() { + return Err("list elements must have some type".into()); + } for v in types { if v? != head { return Err("inhomogeneous list is not allowed".into()); } } - Ok(ParametricType(LIST_TYPE, vec![head]).into()) + Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into())) } fn parse_tuple(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression]) -> ParserResult { - let types: Result, String> = elements + let types: Result>, String> = elements .iter() .map(|v| parse_expr(&ctx, sym_table, v)) .collect(); - Ok(ParametricType(TUPLE_TYPE, types?).into()) + if let Some(t) = types? { + Ok(Some(ParametricType(TUPLE_TYPE, t).into())) + } else { + Err("tuple elements must have some type".into()) + } } fn parse_attribute( @@ -73,7 +80,7 @@ fn parse_attribute( value: &Expression, name: String, ) -> ParserResult { - let value = parse_expr(ctx, sym_table, value)?; + let value = parse_expr(ctx, sym_table, value)?.ok_or("no value".to_string())?; if let TypeVariable(id) = value.as_ref() { let v = ctx.get_variable(*id); if v.bound.len() == 0 { @@ -91,12 +98,12 @@ fn parse_attribute( return Err("unknown field (type mismatch between variants)".into()); } } - return Ok(ty.unwrap().clone()); + return Ok(Some(ty.unwrap().clone())); } match value.get_base(ctx) { Some(b) => match b.fields.get(name.as_str()) { - Some(t) => Ok(t.clone()), + Some(t) => Ok(Some(t.clone())), None => Err("no such field".into()), }, None => Err("this object has no fields".into()), @@ -109,15 +116,28 @@ fn parse_bool_ops( values: &[Expression], ) -> ParserResult { assert_eq!(values.len(), 2); - let left = parse_expr(ctx, sym_table, &values[0])?; - let right = parse_expr(ctx, sym_table, &values[1])?; + let left = parse_expr(ctx, sym_table, &values[0])?.ok_or("no value".to_string())?; + let right = parse_expr(ctx, sym_table, &values[1])?.ok_or("no value".to_string())?; let b = PrimitiveType(BOOL_TYPE); if left.as_ref() == &b && right.as_ref() == &b { - Ok(b.into()) + Ok(Some(b.into())) } else { Err("bool operands must be bool".into()) } } +fn parse_bin_ops( + ctx: &GlobalContext, + sym_table: &SymTable, + op: &Operator, + left: &Expression, + right: &Expression, +) -> ParserResult { + let left = parse_expr(ctx, sym_table, left)?.ok_or("no value".to_string())?; + let right = parse_expr(ctx, sym_table, right)?.ok_or("no value".to_string())?; + let fun = binop_name(op); + let mut assumptions = HashMap::new(); + resolve_call(ctx, Some(left), fun, &[right], &mut assumptions) +} diff --git a/nac3core/src/operators.rs b/nac3core/src/operators.rs index 31990d6976..a33f4aa882 100644 --- a/nac3core/src/operators.rs +++ b/nac3core/src/operators.rs @@ -1,6 +1,6 @@ use rustpython_parser::ast::{Comparison, Operator, UnaryOperator}; -pub fn binop_name(op: Operator) -> &'static str { +pub fn binop_name(op: &Operator) -> &'static str { match op { Operator::Add => "add", Operator::Sub => "sub", @@ -18,7 +18,7 @@ pub fn binop_name(op: Operator) -> &'static str { } } -pub fn unaryop_name(op: UnaryOperator) -> &'static str { +pub fn unaryop_name(op: &UnaryOperator) -> &'static str { match op { UnaryOperator::Pos => "pos", UnaryOperator::Neg => "neg", @@ -27,7 +27,7 @@ pub fn unaryop_name(op: UnaryOperator) -> &'static str { } } -pub fn comparison_name(op: Comparison) -> Option<&'static str> { +pub fn comparison_name(op: &Comparison) -> Option<&'static str> { match op { Comparison::Less => Some("lt"), Comparison::LessOrEqual => Some("le"),