From 8ade8c7b1fffb60f6873c50bb0515671cb25b00b Mon Sep 17 00:00:00 2001 From: pca006132 Date: Thu, 31 Dec 2020 16:36:01 +0800 Subject: [PATCH] fixed expression.rs and renamed to expression_inference --- nac3core/src/context/inference_context.rs | 7 +- nac3core/src/expression.rs | 1023 --------------------- nac3core/src/expression_inference.rs | 1007 ++++++++++++++++++++ nac3core/src/lib.rs | 2 +- 4 files changed, 1014 insertions(+), 1025 deletions(-) delete mode 100644 nac3core/src/expression.rs create mode 100644 nac3core/src/expression_inference.rs diff --git a/nac3core/src/context/inference_context.rs b/nac3core/src/context/inference_context.rs index f321f2e1c..04fb49d04 100644 --- a/nac3core/src/context/inference_context.rs +++ b/nac3core/src/context/inference_context.rs @@ -113,10 +113,15 @@ impl<'a> InferenceContext<'a> { } } + /// check if an identifier is already defined + pub fn defined(&self, name: &str) -> bool { + self.sym_table.get(name).is_some() + } + /// get the type of an identifier /// may return error if the identifier is not defined, and cannot be resolved with the /// resolution function. - pub fn resolve(&mut self, name: &'a str) -> Result { + pub fn resolve(&mut self, name: & str) -> Result { if let Some((t, x)) = self.sym_table.get(name) { if *x { Ok(t.clone()) diff --git a/nac3core/src/expression.rs b/nac3core/src/expression.rs deleted file mode 100644 index 1e24bbdba..000000000 --- a/nac3core/src/expression.rs +++ /dev/null @@ -1,1023 +0,0 @@ -use crate::inference::resolve_call; -use crate::operators::*; -use crate::primitives::*; -use crate::typedef::{GlobalContext, Type, Type::*}; -use rustpython_parser::ast::{ - Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator, - UnaryOperator, -}; -use std::collections::HashMap; -use std::convert::TryInto; -use std::rc::Rc; - -type SymTable<'a> = HashMap<&'a str, Rc>; -type ParserResult = Result>, String>; - -pub fn infer_expr(ctx: &GlobalContext, sym_table: &SymTable, expr: &Expression) -> ParserResult { - match &expr.node { - ExpressionType::Number { value } => infer_constant(ctx, sym_table, value), - ExpressionType::Identifier { name } => infer_identifier(ctx, sym_table, name), - ExpressionType::List { elements } => infer_list(ctx, sym_table, elements), - ExpressionType::Tuple { elements } => infer_tuple(ctx, sym_table, elements), - ExpressionType::Attribute { value, name } => infer_attribute(ctx, sym_table, value, name), - ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, sym_table, values), - ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, sym_table, op, a, b), - ExpressionType::Unop { op, a } => infer_unary_ops(ctx, sym_table, op, a), - ExpressionType::Compare { vals, ops } => infer_compare(ctx, sym_table, vals, ops), - ExpressionType::Call { - args, - function, - keywords, - } => { - if keywords.len() > 0 { - Err("keyword is not supported".into()) - } else { - infer_call(ctx, sym_table, &args, &function) - } - } - ExpressionType::Subscript { a, b } => infer_subscript(ctx, sym_table, a, b), - ExpressionType::IfExpression { test, body, orelse } => { - infer_if_expr(ctx, sym_table, &test, &body, orelse) - } - ExpressionType::Comprehension { kind, generators } => match kind.as_ref() { - ComprehensionKind::List { element } => { - if generators.len() == 1 { - infer_list_comprehension(ctx, sym_table, element, &generators[0]) - } else { - Err("only 1 generator statement is supported".into()) - } - } - _ => Err("only list comprehension is supported".into()), - }, - ExpressionType::True | ExpressionType::False => Ok(Some(PrimitiveType(BOOL_TYPE).into())), - _ => Err("not supported".into()), - } -} - -fn infer_constant( - _: &GlobalContext, - _: &SymTable, - value: &rustpython_parser::ast::Number, -) -> ParserResult { - use rustpython_parser::ast::Number; - match value { - Number::Integer { value } => { - let int32: Result = value.try_into(); - if int32.is_ok() { - Ok(Some(PrimitiveType(INT32_TYPE).into())) - } else { - let int64: Result = value.try_into(); - if int64.is_ok() { - Ok(Some(PrimitiveType(INT64_TYPE).into())) - } else { - Err("integer out of range".into()) - } - } - } - Number::Float { .. } => Ok(Some(PrimitiveType(FLOAT_TYPE).into())), - _ => Err("not supported".into()), - } -} - -fn infer_identifier(_: &GlobalContext, sym_table: &SymTable, name: &str) -> ParserResult { - match sym_table.get(name) { - Some(v) => Ok(Some(v.clone())), - None => Err("unbounded variable".into()), - } -} - -fn infer_list(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression]) -> ParserResult { - if elements.len() == 0 { - return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into())); - } - - let mut types = elements.iter().map(|v| infer_expr(&ctx, sym_table, v)); - - let head = types.next().unwrap()?; - if head.is_none() { - return Err("list elements must have some type".into()); - } - for v in types { - if v? != head { - return Err("inhomogeneous list is not allowed".into()); - } - } - Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into())) -} - -fn infer_tuple(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression]) -> ParserResult { - let types: Result>, String> = elements - .iter() - .map(|v| infer_expr(&ctx, sym_table, v)) - .collect(); - if let Some(t) = types? { - Ok(Some(ParametricType(TUPLE_TYPE, t).into())) - } else { - Err("tuple elements must have some type".into()) - } -} - -fn infer_attribute( - ctx: &GlobalContext, - sym_table: &SymTable, - value: &Expression, - name: &String, -) -> ParserResult { - let value = infer_expr(ctx, sym_table, value)?.ok_or("no value".to_string())?; - if let TypeVariable(id) = value.as_ref() { - let v = ctx.get_variable(*id); - if v.bound.len() == 0 { - return Err("no fields on unbounded type variable".into()); - } - let ty = v.bound[0] - .get_base(ctx) - .and_then(|v| v.fields.get(name.as_str())); - if ty.is_none() { - return Err("unknown field".into()); - } - for x in v.bound[1..].iter() { - let ty1 = x.get_base(ctx).and_then(|v| v.fields.get(name.as_str())); - if ty1 != ty { - return Err("unknown field (type mismatch between variants)".into()); - } - } - return Ok(Some(ty.unwrap().clone())); - } - - match value.get_base(ctx) { - Some(b) => match b.fields.get(name.as_str()) { - Some(t) => Ok(Some(t.clone())), - None => Err("no such field".into()), - }, - None => Err("this object has no fields".into()), - } -} - -fn infer_bool_ops( - ctx: &GlobalContext, - sym_table: &SymTable, - values: &[Expression], -) -> ParserResult { - assert_eq!(values.len(), 2); - let left = infer_expr(ctx, sym_table, &values[0])?.ok_or("no value".to_string())?; - let right = infer_expr(ctx, sym_table, &values[1])?.ok_or("no value".to_string())?; - - let b = PrimitiveType(BOOL_TYPE); - if left.as_ref() == &b && right.as_ref() == &b { - Ok(Some(b.into())) - } else { - Err("bool operands must be bool".into()) - } -} - -fn infer_bin_ops( - ctx: &GlobalContext, - sym_table: &SymTable, - op: &Operator, - left: &Expression, - right: &Expression, -) -> ParserResult { - let left = infer_expr(ctx, sym_table, left)?.ok_or("no value".to_string())?; - let right = infer_expr(ctx, sym_table, right)?.ok_or("no value".to_string())?; - let fun = binop_name(op); - resolve_call(ctx, Some(left), fun, &[right]) -} - -fn infer_unary_ops( - ctx: &GlobalContext, - sym_table: &SymTable, - op: &UnaryOperator, - obj: &Expression, -) -> ParserResult { - let ty = infer_expr(ctx, sym_table, obj)?.ok_or("no value".to_string())?; - if let UnaryOperator::Not = op { - if ty.as_ref() == &PrimitiveType(BOOL_TYPE) { - Ok(Some(ty)) - } else { - Err("logical not must be applied to bool".into()) - } - } else { - resolve_call(ctx, Some(ty), unaryop_name(op), &[]) - } -} - -fn infer_compare( - ctx: &GlobalContext, - sym_table: &SymTable, - vals: &[Expression], - ops: &[Comparison], -) -> ParserResult { - let types: Result>, _> = - vals.iter().map(|v| infer_expr(ctx, sym_table, v)).collect(); - let types = types?; - if types.is_none() { - return Err("comparison operands must have type".into()); - } - let types = types.unwrap(); - let boolean = PrimitiveType(BOOL_TYPE); - let left = &types[..types.len() - 1]; - let right = &types[1..]; - - for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) { - let fun = comparison_name(op).ok_or("unsupported comparison".to_string())?; - let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?; - if ty.is_none() || ty.unwrap().as_ref() != &boolean { - return Err("comparison result must be boolean".into()); - } - } - Ok(Some(boolean.into())) -} - -fn infer_call( - ctx: &GlobalContext, - sym_table: &SymTable, - args: &[Expression], - function: &Expression, -) -> ParserResult { - let types: Result>, _> = - args.iter().map(|v| infer_expr(ctx, sym_table, v)).collect(); - let types = types?; - if types.is_none() { - return Err("function params must have type".into()); - } - - let (obj, fun) = match &function.node { - ExpressionType::Identifier { name } => (None, name), - ExpressionType::Attribute { value, name } => ( - Some(infer_expr(ctx, sym_table, &value)?.ok_or("no value".to_string())?), - name, - ), - _ => return Err("not supported".into()), - }; - resolve_call(ctx, obj, fun.as_str(), &types.unwrap()) -} - -fn infer_subscript( - ctx: &GlobalContext, - sym_table: &SymTable, - a: &Expression, - b: &Expression, -) -> ParserResult { - let a = infer_expr(ctx, sym_table, a)?.ok_or("no value".to_string())?; - let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() { - ls[0].clone() - } else { - return Err("subscript is not supported for types other than list".into()); - }; - - match &b.node { - ExpressionType::Slice { elements } => { - let int32 = Rc::new(PrimitiveType(INT32_TYPE)); - let types: Result>, _> = elements - .iter() - .map(|v| { - if let ExpressionType::None = v.node { - Ok(Some(int32.clone())) - } else { - infer_expr(ctx, sym_table, v) - } - }) - .collect(); - let types = types?.ok_or("slice must have type".to_string())?; - if types.iter().all(|v| v == &int32) { - Ok(Some(a)) - } else { - Err("slice must be int32 type".into()) - } - } - _ => { - let b = infer_expr(ctx, sym_table, b)?.ok_or("no value".to_string())?; - if b.as_ref() == &PrimitiveType(INT32_TYPE) { - Ok(Some(t)) - } else { - Err("index must be either slice or int32".into()) - } - } - } -} - -fn infer_if_expr( - ctx: &GlobalContext, - sym_table: &SymTable, - test: &Expression, - body: &Expression, - orelse: &Expression, -) -> ParserResult { - let test = infer_expr(ctx, sym_table, test)?.ok_or("no value".to_string())?; - if test.as_ref() != &PrimitiveType(BOOL_TYPE) { - return Err("test should be bool".into()); - } - - let body = infer_expr(ctx, sym_table, body)?; - let orelse = infer_expr(ctx, sym_table, orelse)?; - if body.as_ref() == orelse.as_ref() { - Ok(body) - } else { - Err("divergent type".into()) - } -} - -fn infer_simple_binding<'a: 'b, 'b>( - sym_table: &mut SymTable<'b>, - name: &'a Expression, - ty: Rc, -) -> Result<(), String> { - match &name.node { - ExpressionType::Identifier { name } => { - if name == "_" { - Ok(()) - } else if sym_table.get(name.as_str()).is_some() { - Err("duplicated naming".into()) - } else { - sym_table.insert(name.as_str(), ty); - Ok(()) - } - } - ExpressionType::Tuple { elements } => { - if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() { - if elements.len() == ls.len() { - for (a, b) in elements.iter().zip(ls.iter()) { - infer_simple_binding(sym_table, a, b.clone())?; - } - Ok(()) - } else { - Err("different length".into()) - } - } else { - Err("not supported".into()) - } - } - _ => Err("not supported".into()), - } -} - -fn infer_list_comprehension( - ctx: &GlobalContext, - sym_table: &SymTable, - element: &Expression, - comprehension: &Comprehension, -) -> ParserResult { - if comprehension.is_async { - return Err("async is not supported".into()); - } - - // TODO: it may be more efficient to use multi-level table - // but it would better done in a whole program level - let iter = infer_expr(ctx, sym_table, &comprehension.iter)?.ok_or("no value".to_string())?; - if let ParametricType(LIST_TYPE, ls) = iter.as_ref() { - let mut local_sym = sym_table.clone(); - infer_simple_binding(&mut local_sym, &comprehension.target, ls[0].clone())?; - - let boolean = PrimitiveType(BOOL_TYPE); - for test in comprehension.ifs.iter() { - let result = - infer_expr(ctx, &local_sym, test)?.ok_or("no value in test".to_string())?; - if result.as_ref() != &boolean { - return Err("test must be bool".into()); - } - } - let result = infer_expr(ctx, &local_sym, element)?.ok_or("no value")?; - Ok(Some(ParametricType(LIST_TYPE, vec![result]).into())) - } else { - Err("iteration is supported for list only".into()) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::typedef::*; - use rustpython_parser::parser::parse_expression; - - #[test] - fn test_constants() { - let ctx = basic_ctx(); - let sym_table = HashMap::new(); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("123").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("2147483647").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("2147483648").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT64_TYPE).into()); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("9223372036854775807").unwrap(), - ); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT64_TYPE).into()); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("9223372036854775808").unwrap(), - ); - assert_eq!(result, Err("integer out of range".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("123.456").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(FLOAT_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("True").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("False").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); - } - - #[test] - fn test_identifier() { - let ctx = basic_ctx(); - let mut sym_table = HashMap::new(); - sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("abc").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("ab").unwrap()); - assert_eq!(result, Err("unbounded variable".into())); - } - - #[test] - fn test_list() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "foo", - FnDef { - args: vec![], - result: None, - }, - ); - let mut sym_table = HashMap::new(); - sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); - // def is reserved... - sym_table.insert("efg", Rc::new(PrimitiveType(INT32_TYPE))); - sym_table.insert("xyz", Rc::new(PrimitiveType(FLOAT_TYPE))); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("[]").unwrap()); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![BotType.into()]).into() - ); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("[abc]").unwrap()); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() - ); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("[abc, efg]").unwrap()); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() - ); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[abc, efg, xyz]").unwrap(), - ); - assert_eq!(result, Err("inhomogeneous list is not allowed".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("[foo()]").unwrap()); - assert_eq!(result, Err("list elements must have some type".into())); - } - - #[test] - fn test_tuple() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "foo", - FnDef { - args: vec![], - result: None, - }, - ); - let mut sym_table = HashMap::new(); - sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); - sym_table.insert("efg", Rc::new(PrimitiveType(FLOAT_TYPE))); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("(abc, efg)").unwrap()); - assert_eq!( - result.unwrap().unwrap(), - ParametricType( - TUPLE_TYPE, - vec![ - PrimitiveType(INT32_TYPE).into(), - PrimitiveType(FLOAT_TYPE).into() - ] - ) - .into() - ); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("(abc, efg, foo())").unwrap(), - ); - assert_eq!(result, Err("tuple elements must have some type".into())); - } - - #[test] - fn test_attribute() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - - let foo = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo", - fields: HashMap::new(), - methods: HashMap::new(), - }, - parents: vec![], - }); - let foo_def = ctx.get_class_mut(foo); - foo_def - .base - .fields - .insert("a", PrimitiveType(INT32_TYPE).into()); - foo_def.base.fields.insert("b", ClassType(foo).into()); - foo_def - .base - .fields - .insert("c", PrimitiveType(INT32_TYPE).into()); - - let bar = ctx.add_class(ClassDef { - base: TypeDef { - name: "Bar", - fields: HashMap::new(), - methods: HashMap::new(), - }, - parents: vec![], - }); - let bar_def = ctx.get_class_mut(bar); - bar_def - .base - .fields - .insert("a", PrimitiveType(INT32_TYPE).into()); - bar_def.base.fields.insert("b", ClassType(bar).into()); - bar_def - .base - .fields - .insert("c", PrimitiveType(FLOAT_TYPE).into()); - - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![], - }); - - let v1 = ctx.add_variable(VarDef { - name: "v1", - bound: vec![ClassType(foo).into(), ClassType(bar).into()], - }); - - let mut sym_table = HashMap::new(); - sym_table.insert("foo", Rc::new(ClassType(foo))); - sym_table.insert("bar", Rc::new(ClassType(bar))); - sym_table.insert("foobar", Rc::new(VirtualClassType(foo))); - sym_table.insert("v0", Rc::new(TypeVariable(v0))); - sym_table.insert("v1", Rc::new(TypeVariable(v1))); - sym_table.insert("bot", Rc::new(BotType)); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("foo.a").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("foo.d").unwrap()); - assert_eq!(result, Err("no such field".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("foobar.a").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("v0.a").unwrap()); - assert_eq!(result, Err("no fields on unbounded type variable".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.a").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - // shall we support this? - let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.b").unwrap()); - assert_eq!( - result, - Err("unknown field (type mismatch between variants)".into()) - ); - // assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.c").unwrap()); - assert_eq!( - result, - Err("unknown field (type mismatch between variants)".into()) - ); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.d").unwrap()); - assert_eq!(result, Err("unknown field".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("none().a").unwrap()); - assert_eq!(result, Err("no value".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("bot.a").unwrap()); - assert_eq!(result, Err("this object has no fields".into())); - } - - #[test] - fn test_bool_ops() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - let sym_table = HashMap::new(); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("True and False").unwrap(), - ); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("True and none()").unwrap(), - ); - assert_eq!(result, Err("no value".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("True and 123").unwrap()); - assert_eq!(result, Err("bool operands must be bool".into())); - } - - #[test] - fn test_bin_ops() { - let mut ctx = basic_ctx(); - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![ - PrimitiveType(INT32_TYPE).into(), - PrimitiveType(INT64_TYPE).into(), - ], - }); - let mut sym_table = HashMap::new(); - sym_table.insert("a", TypeVariable(v0).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("1 + 2 + 3").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("a + a + a").unwrap()); - assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); - } - - #[test] - fn test_unary_ops() { - let mut ctx = basic_ctx(); - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![ - PrimitiveType(INT32_TYPE).into(), - PrimitiveType(INT64_TYPE).into(), - ], - }); - let mut sym_table = HashMap::new(); - sym_table.insert("a", TypeVariable(v0).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("-(123)").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("-a").unwrap()); - assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("not True").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("not (1)").unwrap()); - assert_eq!(result, Err("logical not must be applied to bool".into())); - } - - #[test] - fn test_compare() { - let mut ctx = basic_ctx(); - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![ - PrimitiveType(INT32_TYPE).into(), - PrimitiveType(INT64_TYPE).into(), - ], - }); - let mut sym_table = HashMap::new(); - sym_table.insert("a", TypeVariable(v0).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("a == a == a").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("a == a == 1").unwrap()); - assert_eq!(result, Err("not equal".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("True > False").unwrap()); - assert_eq!(result, Err("no such function".into())); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("True in False").unwrap(), - ); - assert_eq!(result, Err("unsupported comparison".into())); - } - - #[test] - fn test_call() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - - let foo = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo", - fields: HashMap::new(), - methods: HashMap::new(), - }, - parents: vec![], - }); - let foo_def = ctx.get_class_mut(foo); - foo_def.base.methods.insert( - "a", - FnDef { - args: vec![], - result: Some(Rc::new(ClassType(foo))), - }, - ); - - let bar = ctx.add_class(ClassDef { - base: TypeDef { - name: "Bar", - fields: HashMap::new(), - methods: HashMap::new(), - }, - parents: vec![], - }); - let bar_def = ctx.get_class_mut(bar); - bar_def.base.methods.insert( - "a", - FnDef { - args: vec![], - result: Some(Rc::new(ClassType(bar))), - }, - ); - - let v0 = ctx.add_variable(VarDef { - name: "v0", - bound: vec![], - }); - let v1 = ctx.add_variable(VarDef { - name: "v1", - bound: vec![ClassType(foo).into(), ClassType(bar).into()], - }); - let v2 = ctx.add_variable(VarDef { - name: "v2", - bound: vec![ - ClassType(foo).into(), - ClassType(bar).into(), - PrimitiveType(INT32_TYPE).into(), - ], - }); - let mut sym_table = HashMap::new(); - sym_table.insert("foo", Rc::new(ClassType(foo))); - sym_table.insert("bar", Rc::new(ClassType(bar))); - sym_table.insert("foobar", Rc::new(VirtualClassType(foo))); - sym_table.insert("v0", Rc::new(TypeVariable(v0))); - sym_table.insert("v1", Rc::new(TypeVariable(v1))); - sym_table.insert("v2", Rc::new(TypeVariable(v2))); - sym_table.insert("bot", Rc::new(BotType)); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("foo.a()").unwrap()); - assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.a()").unwrap()); - assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("foobar.a()").unwrap()); - assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("none().a()").unwrap()); - assert_eq!(result, Err("no value".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("bot.a()").unwrap()); - assert_eq!(result, Err("not supported".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("[][0].a()").unwrap()); - assert_eq!(result, Err("not supported".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("v0.a()").unwrap()); - assert_eq!(result, Err("unbounded type var".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("v2.a()").unwrap()); - assert_eq!(result, Err("no such function".into())); - } - - #[test] - fn infer_subscript() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - let sym_table = HashMap::new(); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("[1, 2, 3][0]").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("[[1]][0][0]").unwrap()); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[1, 2, 3][1:2]").unwrap(), - ); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() - ); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[1, 2, 3][1:2:2]").unwrap(), - ); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() - ); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[1, 2, 3][1:1.2]").unwrap(), - ); - assert_eq!(result, Err("slice must be int32 type".into())); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[1, 2, 3][1:none()]").unwrap(), - ); - assert_eq!(result, Err("slice must have type".into())); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[1, 2, 3][1.2]").unwrap(), - ); - assert_eq!(result, Err("index must be either slice or int32".into())); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[1, 2, 3][none()]").unwrap(), - ); - assert_eq!(result, Err("no value".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("none()[1.2]").unwrap()); - assert_eq!(result, Err("no value".into())); - - let result = infer_expr(&ctx, &sym_table, &parse_expression("123[1]").unwrap()); - assert_eq!( - result, - Err("subscript is not supported for types other than list".into()) - ); - } - - #[test] - fn test_if_expr() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - let sym_table = HashMap::new(); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("1 if True else 0").unwrap(), - ); - assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("none() if True else none()").unwrap(), - ); - assert_eq!(result.unwrap(), None); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("none() if 1 else none()").unwrap(), - ); - assert_eq!(result, Err("test should be bool".into())); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("1 if True else none()").unwrap(), - ); - assert_eq!(result, Err("divergent type".into())); - } - - #[test] - fn test_list_comp() { - let mut ctx = basic_ctx(); - ctx.add_fn( - "none", - FnDef { - args: vec![], - result: None, - }, - ); - let int32 = Rc::new(PrimitiveType(INT32_TYPE)); - let mut sym_table = HashMap::new(); - sym_table.insert("z", int32.clone()); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(), - ); - assert_eq!( - result.unwrap().unwrap(), - ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into() - ); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(), - ); - assert_eq!(result.unwrap().unwrap(), int32.clone()); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap(), - ); - assert_eq!(result.unwrap().unwrap(), int32.clone()); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap(), - ); - assert_eq!(result, Err("test must be bool".into())); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[y for x in []][0]").unwrap(), - ); - assert_eq!(result, Err("unbounded variable".into())); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[none() for x in []][0]").unwrap(), - ); - assert_eq!(result, Err("no value".into())); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[z for z in []][0]").unwrap(), - ); - assert_eq!(result, Err("duplicated naming".into())); - - let result = infer_expr( - &ctx, - &sym_table, - &parse_expression("[x for x in [] for y in []]").unwrap(), - ); - assert_eq!( - result, - Err("only 1 generator statement is supported".into()) - ); - } -} diff --git a/nac3core/src/expression_inference.rs b/nac3core/src/expression_inference.rs new file mode 100644 index 000000000..369285266 --- /dev/null +++ b/nac3core/src/expression_inference.rs @@ -0,0 +1,1007 @@ +use crate::inference_core::resolve_call; +use crate::magic_methods::*; +use crate::primitives::*; +use crate::typedef::{Type, TypeEnum::*}; +use crate::context::InferenceContext; +use rustpython_parser::ast::{ + Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator, + UnaryOperator, +}; +use std::convert::TryInto; + +type ParserResult = Result, String>; + +pub fn infer_expr<'b: 'a, 'a>(ctx: &mut InferenceContext<'a>, expr: &'b Expression) -> ParserResult { + match &expr.node { + ExpressionType::Number { value } => infer_constant(ctx, value), + ExpressionType::Identifier { name } => infer_identifier(ctx, name), + ExpressionType::List { elements } => infer_list(ctx, elements), + ExpressionType::Tuple { elements } => infer_tuple(ctx, elements), + ExpressionType::Attribute { value, name } => infer_attribute(ctx, value, name), + ExpressionType::BoolOp { values, .. } => infer_bool_ops(ctx, values), + ExpressionType::Binop { a, b, op } => infer_bin_ops(ctx, op, a, b), + ExpressionType::Unop { op, a } => infer_unary_ops(ctx, op, a), + ExpressionType::Compare { vals, ops } => infer_compare(ctx, vals, ops), + ExpressionType::Call { + args, + function, + keywords, + } => { + if keywords.len() > 0 { + Err("keyword is not supported".into()) + } else { + infer_call(ctx, &args, &function) + } + } + ExpressionType::Subscript { a, b } => infer_subscript(ctx, a, b), + ExpressionType::IfExpression { test, body, orelse } => { + infer_if_expr(ctx, &test, &body, orelse) + } + ExpressionType::Comprehension { kind, generators } => match kind.as_ref() { + ComprehensionKind::List { element } => { + if generators.len() == 1 { + infer_list_comprehension(ctx, element, &generators[0]) + } else { + Err("only 1 generator statement is supported".into()) + } + } + _ => Err("only list comprehension is supported".into()), + }, + ExpressionType::True | ExpressionType::False => Ok(Some(ctx.get_primitive(BOOL_TYPE))), + _ => Err("not supported".into()), + } +} + +fn infer_constant( + ctx: &mut InferenceContext, + value: &rustpython_parser::ast::Number, +) -> ParserResult { + use rustpython_parser::ast::Number; + match value { + Number::Integer { value } => { + let int32: Result = value.try_into(); + if int32.is_ok() { + Ok(Some(ctx.get_primitive(INT32_TYPE))) + } else { + let int64: Result = value.try_into(); + if int64.is_ok() { + Ok(Some(ctx.get_primitive(INT64_TYPE))) + } else { + Err("integer out of range".into()) + } + } + } + Number::Float { .. } => Ok(Some(ctx.get_primitive(FLOAT_TYPE))), + _ => Err("not supported".into()), + } +} + +fn infer_identifier(ctx: &mut InferenceContext, name: &str) -> ParserResult { + Ok(Some(ctx.resolve(name)?)) +} + +fn infer_list<'b: 'a, 'a>(ctx: &mut InferenceContext<'a>, elements: &'b [Expression]) -> ParserResult { + if elements.len() == 0 { + return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into())); + } + + let mut types = elements.iter().map(|v| infer_expr(ctx, v)); + + let head = types.next().unwrap()?; + if head.is_none() { + return Err("list elements must have some type".into()); + } + for v in types { + if v? != head { + return Err("inhomogeneous list is not allowed".into()); + } + } + Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into())) +} + +fn infer_tuple<'b: 'a, 'a>(ctx: &mut InferenceContext<'a>, elements: &'b [Expression]) -> ParserResult { + let types: Result>, String> = elements + .iter() + .map(|v| infer_expr(ctx, v)) + .collect(); + if let Some(t) = types? { + Ok(Some(ParametricType(TUPLE_TYPE, t).into())) + } else { + Err("tuple elements must have some type".into()) + } +} + +fn infer_attribute<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + value: &'a Expression, + name: &String, +) -> ParserResult { + let value = infer_expr(ctx, value)?.ok_or("no value".to_string())?; + if let TypeVariable(id) = value.as_ref() { + let v = ctx.get_variable_def(*id); + if v.bound.len() == 0 { + return Err("no fields on unbounded type variable".into()); + } + let ty = v.bound[0] + .get_base(ctx) + .and_then(|v| v.fields.get(name.as_str())); + if ty.is_none() { + return Err("unknown field".into()); + } + for x in v.bound[1..].iter() { + let ty1 = x.get_base(ctx).and_then(|v| v.fields.get(name.as_str())); + if ty1 != ty { + return Err("unknown field (type mismatch between variants)".into()); + } + } + return Ok(Some(ty.unwrap().clone())); + } + + match value.get_base(ctx) { + Some(b) => match b.fields.get(name.as_str()) { + Some(t) => Ok(Some(t.clone())), + None => Err("no such field".into()), + }, + None => Err("this object has no fields".into()), + } +} + +fn infer_bool_ops<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + values: &'a [Expression], +) -> ParserResult { + assert_eq!(values.len(), 2); + let left = infer_expr(ctx, &values[0])?.ok_or("no value".to_string())?; + let right = infer_expr(ctx, &values[1])?.ok_or("no value".to_string())?; + + let b = ctx.get_primitive(BOOL_TYPE); + if left == b && right == b { + Ok(Some(b.into())) + } else { + Err("bool operands must be bool".into()) + } +} + +fn infer_bin_ops<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + op: &Operator, + left: &'b Expression, + right: &'b Expression, +) -> ParserResult { + let left = infer_expr(ctx, left)?.ok_or("no value".to_string())?; + let right = infer_expr(ctx, right)?.ok_or("no value".to_string())?; + let fun = binop_name(op); + resolve_call(ctx, Some(left), fun, &[right]) +} + +fn infer_unary_ops<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + op: &UnaryOperator, + obj: &'b Expression, +) -> ParserResult { + let ty = infer_expr(ctx, obj)?.ok_or("no value".to_string())?; + if let UnaryOperator::Not = op { + if ty == ctx.get_primitive(BOOL_TYPE) { + Ok(Some(ty)) + } else { + Err("logical not must be applied to bool".into()) + } + } else { + resolve_call(ctx, Some(ty), unaryop_name(op), &[]) + } +} + +fn infer_compare<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + vals: &'b [Expression], + ops: &'b [Comparison], +) -> ParserResult { + let types: Result>, _> = + vals.iter().map(|v| infer_expr(ctx, v)).collect(); + let types = types?; + if types.is_none() { + return Err("comparison operands must have type".into()); + } + let types = types.unwrap(); + let boolean = ctx.get_primitive(BOOL_TYPE); + let left = &types[..types.len() - 1]; + let right = &types[1..]; + + for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) { + let fun = comparison_name(op).ok_or("unsupported comparison".to_string())?; + let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?; + if ty.is_none() || ty.unwrap() != boolean { + return Err("comparison result must be boolean".into()); + } + } + Ok(Some(boolean.into())) +} + +fn infer_call<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + args: &'b [Expression], + function: &'b Expression, +) -> ParserResult { + let types: Result>, _> = + args.iter().map(|v| infer_expr(ctx, v)).collect(); + let types = types?; + if types.is_none() { + return Err("function params must have type".into()); + } + + let (obj, fun) = match &function.node { + ExpressionType::Identifier { name } => (None, name), + ExpressionType::Attribute { value, name } => ( + Some(infer_expr(ctx, &value)?.ok_or("no value".to_string())?), + name, + ), + _ => return Err("not supported".into()), + }; + resolve_call(ctx, obj, fun.as_str(), &types.unwrap()) +} + +fn infer_subscript<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + a: &'b Expression, + b: &'b Expression, +) -> ParserResult { + let a = infer_expr(ctx, a)?.ok_or("no value".to_string())?; + let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() { + ls[0].clone() + } else { + return Err("subscript is not supported for types other than list".into()); + }; + + match &b.node { + ExpressionType::Slice { elements } => { + let int32 = ctx.get_primitive(INT32_TYPE); + let types: Result>, _> = elements + .iter() + .map(|v| { + if let ExpressionType::None = v.node { + Ok(Some(int32.clone())) + } else { + infer_expr(ctx, v) + } + }) + .collect(); + let types = types?.ok_or("slice must have type".to_string())?; + if types.iter().all(|v| v == &int32) { + Ok(Some(a)) + } else { + Err("slice must be int32 type".into()) + } + } + _ => { + let b = infer_expr(ctx, b)?.ok_or("no value".to_string())?; + if b == ctx.get_primitive(INT32_TYPE) { + Ok(Some(t)) + } else { + Err("index must be either slice or int32".into()) + } + } + } +} + +fn infer_if_expr<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + test: &'b Expression, + body: &'b Expression, + orelse: &'b Expression, +) -> ParserResult { + let test = infer_expr(ctx, test)?.ok_or("no value".to_string())?; + if test != ctx.get_primitive(BOOL_TYPE) { + return Err("test should be bool".into()); + } + + let body = infer_expr(ctx, body)?; + let orelse = infer_expr(ctx, orelse)?; + if body.as_ref() == orelse.as_ref() { + Ok(body) + } else { + Err("divergent type".into()) + } +} + +fn infer_simple_binding<'a: 'b, 'b>( + ctx: &mut InferenceContext<'b>, + name: &'a Expression, + ty: Type, +) -> Result<(), String> { + match &name.node { + ExpressionType::Identifier { name } => { + if name == "_" { + Ok(()) + } else if ctx.defined(name.as_str()) { + Err("duplicated naming".into()) + } else { + ctx.assign(name.as_str(), ty)?; + Ok(()) + } + } + ExpressionType::Tuple { elements } => { + if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() { + if elements.len() == ls.len() { + for (a, b) in elements.iter().zip(ls.iter()) { + infer_simple_binding(ctx, a, b.clone())?; + } + Ok(()) + } else { + Err("different length".into()) + } + } else { + Err("not supported".into()) + } + } + _ => Err("not supported".into()), + } +} + +fn infer_list_comprehension<'b: 'a, 'a>( + ctx: &mut InferenceContext<'a>, + element: &'b Expression, + comprehension: &'b Comprehension, +) -> ParserResult { + if comprehension.is_async { + return Err("async is not supported".into()); + } + + let iter = infer_expr(ctx, &comprehension.iter)?.ok_or("no value".to_string())?; + if let ParametricType(LIST_TYPE, ls) = iter.as_ref() { + ctx.with_scope(|ctx| { + infer_simple_binding(ctx, &comprehension.target, ls[0].clone())?; + + let boolean = ctx.get_primitive(BOOL_TYPE); + for test in comprehension.ifs.iter() { + let result = + infer_expr(ctx, test)?.ok_or("no value in test".to_string())?; + if result != boolean { + return Err("test must be bool".into()); + } + } + let result = infer_expr(ctx, element)?.ok_or("no value")?; + Ok(Some(ParametricType(LIST_TYPE, vec![result]).into())) + }).1 + } else { + Err("iteration is supported for list only".into()) + } +} + +// #[cfg(test)] +// mod test { +// use super::*; +// use crate::typedef::*; +// use rustpython_parser::parser::parse_expression; + +// #[test] +// fn test_constants() { +// let ctx = basic_ctx(); +// let sym_table = HashMap::new(); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("123").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("2147483647").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("2147483648").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT64_TYPE).into()); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("9223372036854775807").unwrap(), +// ); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT64_TYPE).into()); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("9223372036854775808").unwrap(), +// ); +// assert_eq!(result, Err("integer out of range".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("123.456").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(FLOAT_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("True").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("False").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); +// } + +// #[test] +// fn test_identifier() { +// let ctx = basic_ctx(); +// let mut sym_table = HashMap::new(); +// sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("abc").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("ab").unwrap()); +// assert_eq!(result, Err("unbounded variable".into())); +// } + +// #[test] +// fn test_list() { +// let mut ctx = basic_ctx(); +// ctx.add_fn( +// "foo", +// FnDef { +// args: vec![], +// result: None, +// }, +// ); +// let mut sym_table = HashMap::new(); +// sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); +// // def is reserved... +// sym_table.insert("efg", Rc::new(PrimitiveType(INT32_TYPE))); +// sym_table.insert("xyz", Rc::new(PrimitiveType(FLOAT_TYPE))); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("[]").unwrap()); +// assert_eq!( +// result.unwrap().unwrap(), +// ParametricType(LIST_TYPE, vec![BotType.into()]).into() +// ); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("[abc]").unwrap()); +// assert_eq!( +// result.unwrap().unwrap(), +// ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() +// ); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("[abc, efg]").unwrap()); +// assert_eq!( +// result.unwrap().unwrap(), +// ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() +// ); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[abc, efg, xyz]").unwrap(), +// ); +// assert_eq!(result, Err("inhomogeneous list is not allowed".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("[foo()]").unwrap()); +// assert_eq!(result, Err("list elements must have some type".into())); +// } + +// #[test] +// fn test_tuple() { +// let mut ctx = basic_ctx(); +// ctx.add_fn( +// "foo", +// FnDef { +// args: vec![], +// result: None, +// }, +// ); +// let mut sym_table = HashMap::new(); +// sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); +// sym_table.insert("efg", Rc::new(PrimitiveType(FLOAT_TYPE))); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("(abc, efg)").unwrap()); +// assert_eq!( +// result.unwrap().unwrap(), +// ParametricType( +// TUPLE_TYPE, +// vec![ +// PrimitiveType(INT32_TYPE).into(), +// PrimitiveType(FLOAT_TYPE).into() +// ] +// ) +// .into() +// ); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("(abc, efg, foo())").unwrap(), +// ); +// assert_eq!(result, Err("tuple elements must have some type".into())); +// } + +// #[test] +// fn test_attribute() { +// let mut ctx = basic_ctx(); +// ctx.add_fn( +// "none", +// FnDef { +// args: vec![], +// result: None, +// }, +// ); + +// let foo = ctx.add_class(ClassDef { +// base: TypeDef { +// name: "Foo", +// fields: HashMap::new(), +// methods: HashMap::new(), +// }, +// parents: vec![], +// }); +// let foo_def = ctx.get_class_mut(foo); +// foo_def +// .base +// .fields +// .insert("a", PrimitiveType(INT32_TYPE).into()); +// foo_def.base.fields.insert("b", ClassType(foo).into()); +// foo_def +// .base +// .fields +// .insert("c", PrimitiveType(INT32_TYPE).into()); + +// let bar = ctx.add_class(ClassDef { +// base: TypeDef { +// name: "Bar", +// fields: HashMap::new(), +// methods: HashMap::new(), +// }, +// parents: vec![], +// }); +// let bar_def = ctx.get_class_mut(bar); +// bar_def +// .base +// .fields +// .insert("a", PrimitiveType(INT32_TYPE).into()); +// bar_def.base.fields.insert("b", ClassType(bar).into()); +// bar_def +// .base +// .fields +// .insert("c", PrimitiveType(FLOAT_TYPE).into()); + +// let v0 = ctx.add_variable(VarDef { +// name: "v0", +// bound: vec![], +// }); + +// let v1 = ctx.add_variable(VarDef { +// name: "v1", +// bound: vec![ClassType(foo).into(), ClassType(bar).into()], +// }); + +// let mut sym_table = HashMap::new(); +// sym_table.insert("foo", Rc::new(ClassType(foo))); +// sym_table.insert("bar", Rc::new(ClassType(bar))); +// sym_table.insert("foobar", Rc::new(VirtualClassType(foo))); +// sym_table.insert("v0", Rc::new(TypeVariable(v0))); +// sym_table.insert("v1", Rc::new(TypeVariable(v1))); +// sym_table.insert("bot", Rc::new(BotType)); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("foo.a").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("foo.d").unwrap()); +// assert_eq!(result, Err("no such field".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("foobar.a").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("v0.a").unwrap()); +// assert_eq!(result, Err("no fields on unbounded type variable".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.a").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// // shall we support this? +// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.b").unwrap()); +// assert_eq!( +// result, +// Err("unknown field (type mismatch between variants)".into()) +// ); +// // assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.c").unwrap()); +// assert_eq!( +// result, +// Err("unknown field (type mismatch between variants)".into()) +// ); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.d").unwrap()); +// assert_eq!(result, Err("unknown field".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("none().a").unwrap()); +// assert_eq!(result, Err("no value".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("bot.a").unwrap()); +// assert_eq!(result, Err("this object has no fields".into())); +// } + +// #[test] +// fn test_bool_ops() { +// let mut ctx = basic_ctx(); +// ctx.add_fn( +// "none", +// FnDef { +// args: vec![], +// result: None, +// }, +// ); +// let sym_table = HashMap::new(); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("True and False").unwrap(), +// ); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("True and none()").unwrap(), +// ); +// assert_eq!(result, Err("no value".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("True and 123").unwrap()); +// assert_eq!(result, Err("bool operands must be bool".into())); +// } + +// #[test] +// fn test_bin_ops() { +// let mut ctx = basic_ctx(); +// let v0 = ctx.add_variable(VarDef { +// name: "v0", +// bound: vec![ +// PrimitiveType(INT32_TYPE).into(), +// PrimitiveType(INT64_TYPE).into(), +// ], +// }); +// let mut sym_table = HashMap::new(); +// sym_table.insert("a", TypeVariable(v0).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("1 + 2 + 3").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("a + a + a").unwrap()); +// assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); +// } + +// #[test] +// fn test_unary_ops() { +// let mut ctx = basic_ctx(); +// let v0 = ctx.add_variable(VarDef { +// name: "v0", +// bound: vec![ +// PrimitiveType(INT32_TYPE).into(), +// PrimitiveType(INT64_TYPE).into(), +// ], +// }); +// let mut sym_table = HashMap::new(); +// sym_table.insert("a", TypeVariable(v0).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("-(123)").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("-a").unwrap()); +// assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("not True").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("not (1)").unwrap()); +// assert_eq!(result, Err("logical not must be applied to bool".into())); +// } + +// #[test] +// fn test_compare() { +// let mut ctx = basic_ctx(); +// let v0 = ctx.add_variable(VarDef { +// name: "v0", +// bound: vec![ +// PrimitiveType(INT32_TYPE).into(), +// PrimitiveType(INT64_TYPE).into(), +// ], +// }); +// let mut sym_table = HashMap::new(); +// sym_table.insert("a", TypeVariable(v0).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("a == a == a").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("a == a == 1").unwrap()); +// assert_eq!(result, Err("not equal".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("True > False").unwrap()); +// assert_eq!(result, Err("no such function".into())); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("True in False").unwrap(), +// ); +// assert_eq!(result, Err("unsupported comparison".into())); +// } + +// #[test] +// fn test_call() { +// let mut ctx = basic_ctx(); +// ctx.add_fn( +// "none", +// FnDef { +// args: vec![], +// result: None, +// }, +// ); + +// let foo = ctx.add_class(ClassDef { +// base: TypeDef { +// name: "Foo", +// fields: HashMap::new(), +// methods: HashMap::new(), +// }, +// parents: vec![], +// }); +// let foo_def = ctx.get_class_mut(foo); +// foo_def.base.methods.insert( +// "a", +// FnDef { +// args: vec![], +// result: Some(Rc::new(ClassType(foo))), +// }, +// ); + +// let bar = ctx.add_class(ClassDef { +// base: TypeDef { +// name: "Bar", +// fields: HashMap::new(), +// methods: HashMap::new(), +// }, +// parents: vec![], +// }); +// let bar_def = ctx.get_class_mut(bar); +// bar_def.base.methods.insert( +// "a", +// FnDef { +// args: vec![], +// result: Some(Rc::new(ClassType(bar))), +// }, +// ); + +// let v0 = ctx.add_variable(VarDef { +// name: "v0", +// bound: vec![], +// }); +// let v1 = ctx.add_variable(VarDef { +// name: "v1", +// bound: vec![ClassType(foo).into(), ClassType(bar).into()], +// }); +// let v2 = ctx.add_variable(VarDef { +// name: "v2", +// bound: vec![ +// ClassType(foo).into(), +// ClassType(bar).into(), +// PrimitiveType(INT32_TYPE).into(), +// ], +// }); +// let mut sym_table = HashMap::new(); +// sym_table.insert("foo", Rc::new(ClassType(foo))); +// sym_table.insert("bar", Rc::new(ClassType(bar))); +// sym_table.insert("foobar", Rc::new(VirtualClassType(foo))); +// sym_table.insert("v0", Rc::new(TypeVariable(v0))); +// sym_table.insert("v1", Rc::new(TypeVariable(v1))); +// sym_table.insert("v2", Rc::new(TypeVariable(v2))); +// sym_table.insert("bot", Rc::new(BotType)); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("foo.a()").unwrap()); +// assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("v1.a()").unwrap()); +// assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("foobar.a()").unwrap()); +// assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("none().a()").unwrap()); +// assert_eq!(result, Err("no value".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("bot.a()").unwrap()); +// assert_eq!(result, Err("not supported".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("[][0].a()").unwrap()); +// assert_eq!(result, Err("not supported".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("v0.a()").unwrap()); +// assert_eq!(result, Err("unbounded type var".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("v2.a()").unwrap()); +// assert_eq!(result, Err("no such function".into())); +// } + +// #[test] +// fn infer_subscript() { +// let mut ctx = basic_ctx(); +// ctx.add_fn( +// "none", +// FnDef { +// args: vec![], +// result: None, +// }, +// ); +// let sym_table = HashMap::new(); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("[1, 2, 3][0]").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("[[1]][0][0]").unwrap()); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[1, 2, 3][1:2]").unwrap(), +// ); +// assert_eq!( +// result.unwrap().unwrap(), +// ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() +// ); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[1, 2, 3][1:2:2]").unwrap(), +// ); +// assert_eq!( +// result.unwrap().unwrap(), +// ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() +// ); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[1, 2, 3][1:1.2]").unwrap(), +// ); +// assert_eq!(result, Err("slice must be int32 type".into())); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[1, 2, 3][1:none()]").unwrap(), +// ); +// assert_eq!(result, Err("slice must have type".into())); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[1, 2, 3][1.2]").unwrap(), +// ); +// assert_eq!(result, Err("index must be either slice or int32".into())); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[1, 2, 3][none()]").unwrap(), +// ); +// assert_eq!(result, Err("no value".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("none()[1.2]").unwrap()); +// assert_eq!(result, Err("no value".into())); + +// let result = infer_expr(&ctx, &sym_table, &parse_expression("123[1]").unwrap()); +// assert_eq!( +// result, +// Err("subscript is not supported for types other than list".into()) +// ); +// } + +// #[test] +// fn test_if_expr() { +// let mut ctx = basic_ctx(); +// ctx.add_fn( +// "none", +// FnDef { +// args: vec![], +// result: None, +// }, +// ); +// let sym_table = HashMap::new(); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("1 if True else 0").unwrap(), +// ); +// assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("none() if True else none()").unwrap(), +// ); +// assert_eq!(result.unwrap(), None); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("none() if 1 else none()").unwrap(), +// ); +// assert_eq!(result, Err("test should be bool".into())); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("1 if True else none()").unwrap(), +// ); +// assert_eq!(result, Err("divergent type".into())); +// } + +// #[test] +// fn test_list_comp() { +// let mut ctx = basic_ctx(); +// ctx.add_fn( +// "none", +// FnDef { +// args: vec![], +// result: None, +// }, +// ); +// let int32 = Rc::new(PrimitiveType(INT32_TYPE)); +// let mut sym_table = HashMap::new(); +// sym_table.insert("z", int32.clone()); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(), +// ); +// assert_eq!( +// result.unwrap().unwrap(), +// ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into() +// ); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(), +// ); +// assert_eq!(result.unwrap().unwrap(), int32.clone()); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap(), +// ); +// assert_eq!(result.unwrap().unwrap(), int32.clone()); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap(), +// ); +// assert_eq!(result, Err("test must be bool".into())); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[y for x in []][0]").unwrap(), +// ); +// assert_eq!(result, Err("unbounded variable".into())); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[none() for x in []][0]").unwrap(), +// ); +// assert_eq!(result, Err("no value".into())); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[z for z in []][0]").unwrap(), +// ); +// assert_eq!(result, Err("duplicated naming".into())); + +// let result = infer_expr( +// &ctx, +// &sym_table, +// &parse_expression("[x for x in [] for y in []]").unwrap(), +// ); +// assert_eq!( +// result, +// Err("only 1 generator statement is supported".into()) +// ); +// } +// } diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index b38e5925d..514988fb2 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -2,7 +2,7 @@ extern crate num_bigint; extern crate inkwell; extern crate rustpython_parser; -// pub mod expression; +pub mod expression_inference; pub mod inference_core; mod magic_methods; pub mod primitives;