diff --git a/nac3core/src/expression.rs b/nac3core/src/expression.rs new file mode 100644 index 00000000..751a8aba --- /dev/null +++ b/nac3core/src/expression.rs @@ -0,0 +1,1023 @@ +use crate::inference::resolve_call; +use crate::operators::*; +use crate::primitives::*; +use crate::typedef::{GlobalContext, Type, Type::*}; +use rustpython_parser::ast::{ + Comparison, Comprehension, ComprehensionKind, Expression, ExpressionType, Operator, + UnaryOperator, +}; +use std::collections::HashMap; +use std::convert::TryInto; +use std::rc::Rc; + +type SymTable<'a> = HashMap<&'a str, Rc>; +type ParserResult = Result>, String>; + +pub fn parse_expr(ctx: &GlobalContext, sym_table: &SymTable, expr: &Expression) -> ParserResult { + match &expr.node { + ExpressionType::Number { value } => parse_constant(ctx, sym_table, value), + ExpressionType::Identifier { name } => parse_identifier(ctx, sym_table, name), + ExpressionType::List { elements } => parse_list(ctx, sym_table, elements), + ExpressionType::Tuple { elements } => parse_tuple(ctx, sym_table, elements), + ExpressionType::Attribute { value, name } => parse_attribute(ctx, sym_table, value, name), + ExpressionType::BoolOp { values, .. } => parse_bool_ops(ctx, sym_table, values), + ExpressionType::Binop { a, b, op } => parse_bin_ops(ctx, sym_table, op, a, b), + ExpressionType::Unop { op, a } => parse_unary_ops(ctx, sym_table, op, a), + ExpressionType::Compare { vals, ops } => parse_compare(ctx, sym_table, vals, ops), + ExpressionType::Call { + args, + function, + keywords, + } => { + if keywords.len() > 0 { + Err("keyword is not supported".into()) + } else { + parse_call(ctx, sym_table, &args, &function) + } + } + ExpressionType::Subscript { a, b } => parse_subscript(ctx, sym_table, a, b), + ExpressionType::IfExpression { test, body, orelse } => { + parse_if_expr(ctx, sym_table, &test, &body, orelse) + } + ExpressionType::Comprehension { kind, generators } => match kind.as_ref() { + ComprehensionKind::List { element } => { + if generators.len() == 1 { + parse_list_comprehension(ctx, sym_table, element, &generators[0]) + } else { + Err("only 1 generator statement is supported".into()) + } + } + _ => Err("only list comprehension is supported".into()), + }, + ExpressionType::True | ExpressionType::False => Ok(Some(PrimitiveType(BOOL_TYPE).into())), + _ => Err("not supported".into()), + } +} + +fn parse_constant( + _: &GlobalContext, + _: &SymTable, + value: &rustpython_parser::ast::Number, +) -> ParserResult { + use rustpython_parser::ast::Number; + match value { + Number::Integer { value } => { + let int32: Result = value.try_into(); + if int32.is_ok() { + Ok(Some(PrimitiveType(INT32_TYPE).into())) + } else { + let int64: Result = value.try_into(); + if int64.is_ok() { + Ok(Some(PrimitiveType(INT64_TYPE).into())) + } else { + Err("integer out of range".into()) + } + } + } + Number::Float { .. } => Ok(Some(PrimitiveType(FLOAT_TYPE).into())), + _ => Err("not supported".into()), + } +} + +fn parse_identifier(_: &GlobalContext, sym_table: &SymTable, name: &str) -> ParserResult { + match sym_table.get(name) { + Some(v) => Ok(Some(v.clone())), + None => Err("unbounded variable".into()), + } +} + +fn parse_list(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression]) -> ParserResult { + if elements.len() == 0 { + return Ok(Some(ParametricType(LIST_TYPE, vec![BotType.into()]).into())); + } + + let mut types = elements.iter().map(|v| parse_expr(&ctx, sym_table, v)); + + let head = types.next().unwrap()?; + if head.is_none() { + return Err("list elements must have some type".into()); + } + for v in types { + if v? != head { + return Err("inhomogeneous list is not allowed".into()); + } + } + Ok(Some(ParametricType(LIST_TYPE, vec![head.unwrap()]).into())) +} + +fn parse_tuple(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression]) -> ParserResult { + let types: Result>, String> = elements + .iter() + .map(|v| parse_expr(&ctx, sym_table, v)) + .collect(); + if let Some(t) = types? { + Ok(Some(ParametricType(TUPLE_TYPE, t).into())) + } else { + Err("tuple elements must have some type".into()) + } +} + +fn parse_attribute( + ctx: &GlobalContext, + sym_table: &SymTable, + value: &Expression, + name: &String, +) -> ParserResult { + let value = parse_expr(ctx, sym_table, value)?.ok_or("no value".to_string())?; + if let TypeVariable(id) = value.as_ref() { + let v = ctx.get_variable(*id); + if v.bound.len() == 0 { + return Err("no fields on unbounded type variable".into()); + } + let ty = v.bound[0] + .get_base(ctx) + .and_then(|v| v.fields.get(name.as_str())); + if ty.is_none() { + return Err("unknown field".into()); + } + for x in v.bound[1..].iter() { + let ty1 = x.get_base(ctx).and_then(|v| v.fields.get(name.as_str())); + if ty1 != ty { + return Err("unknown field (type mismatch between variants)".into()); + } + } + return Ok(Some(ty.unwrap().clone())); + } + + match value.get_base(ctx) { + Some(b) => match b.fields.get(name.as_str()) { + Some(t) => Ok(Some(t.clone())), + None => Err("no such field".into()), + }, + None => Err("this object has no fields".into()), + } +} + +fn parse_bool_ops( + ctx: &GlobalContext, + sym_table: &SymTable, + values: &[Expression], +) -> ParserResult { + assert_eq!(values.len(), 2); + let left = parse_expr(ctx, sym_table, &values[0])?.ok_or("no value".to_string())?; + let right = parse_expr(ctx, sym_table, &values[1])?.ok_or("no value".to_string())?; + + let b = PrimitiveType(BOOL_TYPE); + if left.as_ref() == &b && right.as_ref() == &b { + Ok(Some(b.into())) + } else { + Err("bool operands must be bool".into()) + } +} + +fn parse_bin_ops( + ctx: &GlobalContext, + sym_table: &SymTable, + op: &Operator, + left: &Expression, + right: &Expression, +) -> ParserResult { + let left = parse_expr(ctx, sym_table, left)?.ok_or("no value".to_string())?; + let right = parse_expr(ctx, sym_table, right)?.ok_or("no value".to_string())?; + let fun = binop_name(op); + resolve_call(ctx, Some(left), fun, &[right]) +} + +fn parse_unary_ops( + ctx: &GlobalContext, + sym_table: &SymTable, + op: &UnaryOperator, + obj: &Expression, +) -> ParserResult { + let ty = parse_expr(ctx, sym_table, obj)?.ok_or("no value".to_string())?; + if let UnaryOperator::Not = op { + if ty.as_ref() == &PrimitiveType(BOOL_TYPE) { + Ok(Some(ty)) + } else { + Err("logical not must be applied to bool".into()) + } + } else { + resolve_call(ctx, Some(ty), unaryop_name(op), &[]) + } +} + +fn parse_compare( + ctx: &GlobalContext, + sym_table: &SymTable, + vals: &[Expression], + ops: &[Comparison], +) -> ParserResult { + let types: Result>, _> = + vals.iter().map(|v| parse_expr(ctx, sym_table, v)).collect(); + let types = types?; + if types.is_none() { + return Err("comparison operands must have type".into()); + } + let types = types.unwrap(); + let boolean = PrimitiveType(BOOL_TYPE); + let left = &types[..types.len() - 1]; + let right = &types[1..]; + + for ((a, b), op) in left.iter().zip(right.iter()).zip(ops.iter()) { + let fun = comparison_name(op).ok_or("unsupported comparison".to_string())?; + let ty = resolve_call(ctx, Some(a.clone()), fun, &[b.clone()])?; + if ty.is_none() || ty.unwrap().as_ref() != &boolean { + return Err("comparison result must be boolean".into()); + } + } + Ok(Some(boolean.into())) +} + +fn parse_call( + ctx: &GlobalContext, + sym_table: &SymTable, + args: &[Expression], + function: &Expression, +) -> ParserResult { + let types: Result>, _> = + args.iter().map(|v| parse_expr(ctx, sym_table, v)).collect(); + let types = types?; + if types.is_none() { + return Err("function params must have type".into()); + } + + let (obj, fun) = match &function.node { + ExpressionType::Identifier { name } => (None, name), + ExpressionType::Attribute { value, name } => ( + Some(parse_expr(ctx, sym_table, &value)?.ok_or("no value".to_string())?), + name, + ), + _ => return Err("not supported".into()), + }; + resolve_call(ctx, obj, fun.as_str(), &types.unwrap()) +} + +fn parse_subscript( + ctx: &GlobalContext, + sym_table: &SymTable, + a: &Expression, + b: &Expression, +) -> ParserResult { + let a = parse_expr(ctx, sym_table, a)?.ok_or("no value".to_string())?; + let t = if let ParametricType(LIST_TYPE, ls) = a.as_ref() { + ls[0].clone() + } else { + return Err("subscript is not supported for types other than list".into()); + }; + + match &b.node { + ExpressionType::Slice { elements } => { + let int32 = Rc::new(PrimitiveType(INT32_TYPE)); + let types: Result>, _> = elements + .iter() + .map(|v| { + if let ExpressionType::None = v.node { + Ok(Some(int32.clone())) + } else { + parse_expr(ctx, sym_table, v) + } + }) + .collect(); + let types = types?.ok_or("slice must have type".to_string())?; + if types.iter().all(|v| v == &int32) { + Ok(Some(a)) + } else { + Err("slice must be int32 type".into()) + } + } + _ => { + let b = parse_expr(ctx, sym_table, b)?.ok_or("no value".to_string())?; + if b.as_ref() == &PrimitiveType(INT32_TYPE) { + Ok(Some(t)) + } else { + Err("index must be either slice or int32".into()) + } + } + } +} + +fn parse_if_expr( + ctx: &GlobalContext, + sym_table: &SymTable, + test: &Expression, + body: &Expression, + orelse: &Expression, +) -> ParserResult { + let test = parse_expr(ctx, sym_table, test)?.ok_or("no value".to_string())?; + if test.as_ref() != &PrimitiveType(BOOL_TYPE) { + return Err("test should be bool".into()); + } + + let body = parse_expr(ctx, sym_table, body)?; + let orelse = parse_expr(ctx, sym_table, orelse)?; + if body.as_ref() == orelse.as_ref() { + Ok(body) + } else { + Err("divergent type".into()) + } +} + +fn parse_simple_binding<'a: 'b, 'b>( + sym_table: &mut SymTable<'b>, + name: &'a Expression, + ty: Rc, +) -> Result<(), String> { + match &name.node { + ExpressionType::Identifier { name } => { + if name == "_" { + Ok(()) + } else if sym_table.get(name.as_str()).is_some() { + Err("duplicated naming".into()) + } else { + sym_table.insert(name.as_str(), ty); + Ok(()) + } + } + ExpressionType::Tuple { elements } => { + if let ParametricType(TUPLE_TYPE, ls) = ty.as_ref() { + if elements.len() == ls.len() { + for (a, b) in elements.iter().zip(ls.iter()) { + parse_simple_binding(sym_table, a, b.clone())?; + } + Ok(()) + } else { + Err("different length".into()) + } + } else { + Err("not supported".into()) + } + } + _ => Err("not supported".into()), + } +} + +fn parse_list_comprehension( + ctx: &GlobalContext, + sym_table: &SymTable, + element: &Expression, + comprehension: &Comprehension, +) -> ParserResult { + if comprehension.is_async { + return Err("async is not supported".into()); + } + + // TODO: it may be more efficient to use multi-level table + // but it would better done in a whole program level + let iter = parse_expr(ctx, sym_table, &comprehension.iter)?.ok_or("no value".to_string())?; + if let ParametricType(LIST_TYPE, ls) = iter.as_ref() { + let mut local_sym = sym_table.clone(); + parse_simple_binding(&mut local_sym, &comprehension.target, ls[0].clone())?; + + let boolean = PrimitiveType(BOOL_TYPE); + for test in comprehension.ifs.iter() { + let result = + parse_expr(ctx, &local_sym, test)?.ok_or("no value in test".to_string())?; + if result.as_ref() != &boolean { + return Err("test must be bool".into()); + } + } + let result = parse_expr(ctx, &local_sym, element)?.ok_or("no value")?; + Ok(Some(ParametricType(LIST_TYPE, vec![result]).into())) + } else { + Err("iteration is supported for list only".into()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::typedef::*; + use rustpython_parser::parser::parse_expression; + + #[test] + fn test_constants() { + let ctx = basic_ctx(); + let sym_table = HashMap::new(); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("123").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("2147483647").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("2147483648").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT64_TYPE).into()); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("9223372036854775807").unwrap(), + ); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT64_TYPE).into()); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("9223372036854775808").unwrap(), + ); + assert_eq!(result, Err("integer out of range".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("123.456").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(FLOAT_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("True").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("False").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); + } + + #[test] + fn test_identifier() { + let ctx = basic_ctx(); + let mut sym_table = HashMap::new(); + sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("abc").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("ab").unwrap()); + assert_eq!(result, Err("unbounded variable".into())); + } + + #[test] + fn test_list() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "foo", + FnDef { + args: vec![], + result: None, + }, + ); + let mut sym_table = HashMap::new(); + sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); + // def is reserved... + sym_table.insert("efg", Rc::new(PrimitiveType(INT32_TYPE))); + sym_table.insert("xyz", Rc::new(PrimitiveType(FLOAT_TYPE))); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("[]").unwrap()); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![BotType.into()]).into() + ); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("[abc]").unwrap()); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() + ); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("[abc, efg]").unwrap()); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() + ); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[abc, efg, xyz]").unwrap(), + ); + assert_eq!(result, Err("inhomogeneous list is not allowed".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("[foo()]").unwrap()); + assert_eq!(result, Err("list elements must have some type".into())); + } + + #[test] + fn test_tuple() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "foo", + FnDef { + args: vec![], + result: None, + }, + ); + let mut sym_table = HashMap::new(); + sym_table.insert("abc", Rc::new(PrimitiveType(INT32_TYPE))); + sym_table.insert("efg", Rc::new(PrimitiveType(FLOAT_TYPE))); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("(abc, efg)").unwrap()); + assert_eq!( + result.unwrap().unwrap(), + ParametricType( + TUPLE_TYPE, + vec![ + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(FLOAT_TYPE).into() + ] + ) + .into() + ); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("(abc, efg, foo())").unwrap(), + ); + assert_eq!(result, Err("tuple elements must have some type".into())); + } + + #[test] + fn test_attribute() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + + let foo = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo", + fields: HashMap::new(), + methods: HashMap::new(), + }, + parents: vec![], + }); + let foo_def = ctx.get_class_mut(foo); + foo_def + .base + .fields + .insert("a", PrimitiveType(INT32_TYPE).into()); + foo_def.base.fields.insert("b", ClassType(foo).into()); + foo_def + .base + .fields + .insert("c", PrimitiveType(INT32_TYPE).into()); + + let bar = ctx.add_class(ClassDef { + base: TypeDef { + name: "Bar", + fields: HashMap::new(), + methods: HashMap::new(), + }, + parents: vec![], + }); + let bar_def = ctx.get_class_mut(bar); + bar_def + .base + .fields + .insert("a", PrimitiveType(INT32_TYPE).into()); + bar_def.base.fields.insert("b", ClassType(bar).into()); + bar_def + .base + .fields + .insert("c", PrimitiveType(FLOAT_TYPE).into()); + + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![], + }); + + let v1 = ctx.add_variable(VarDef { + name: "v1", + bound: vec![ClassType(foo).into(), ClassType(bar).into()], + }); + + let mut sym_table = HashMap::new(); + sym_table.insert("foo", Rc::new(ClassType(foo))); + sym_table.insert("bar", Rc::new(ClassType(bar))); + sym_table.insert("foobar", Rc::new(VirtualClassType(foo))); + sym_table.insert("v0", Rc::new(TypeVariable(v0))); + sym_table.insert("v1", Rc::new(TypeVariable(v1))); + sym_table.insert("bot", Rc::new(BotType)); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("foo.a").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("foo.d").unwrap()); + assert_eq!(result, Err("no such field".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("foobar.a").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("v0.a").unwrap()); + assert_eq!(result, Err("no fields on unbounded type variable".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("v1.a").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + // shall we support this? + let result = parse_expr(&ctx, &sym_table, &parse_expression("v1.b").unwrap()); + assert_eq!( + result, + Err("unknown field (type mismatch between variants)".into()) + ); + // assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("v1.c").unwrap()); + assert_eq!( + result, + Err("unknown field (type mismatch between variants)".into()) + ); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("v1.d").unwrap()); + assert_eq!(result, Err("unknown field".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("none().a").unwrap()); + assert_eq!(result, Err("no value".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("bot.a").unwrap()); + assert_eq!(result, Err("this object has no fields".into())); + } + + #[test] + fn test_bool_ops() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + let sym_table = HashMap::new(); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("True and False").unwrap(), + ); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("True and none()").unwrap(), + ); + assert_eq!(result, Err("no value".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("True and 123").unwrap()); + assert_eq!(result, Err("bool operands must be bool".into())); + } + + #[test] + fn test_bin_ops() { + let mut ctx = basic_ctx(); + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![ + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(INT64_TYPE).into(), + ], + }); + let mut sym_table = HashMap::new(); + sym_table.insert("a", TypeVariable(v0).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("1 + 2 + 3").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("a + a + a").unwrap()); + assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); + } + + #[test] + fn test_unary_ops() { + let mut ctx = basic_ctx(); + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![ + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(INT64_TYPE).into(), + ], + }); + let mut sym_table = HashMap::new(); + sym_table.insert("a", TypeVariable(v0).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("-(123)").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("-a").unwrap()); + assert_eq!(result.unwrap().unwrap(), TypeVariable(v0).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("not True").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("not (1)").unwrap()); + assert_eq!(result, Err("logical not must be applied to bool".into())); + } + + #[test] + fn test_compare() { + let mut ctx = basic_ctx(); + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![ + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(INT64_TYPE).into(), + ], + }); + let mut sym_table = HashMap::new(); + sym_table.insert("a", TypeVariable(v0).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("a == a == a").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(BOOL_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("a == a == 1").unwrap()); + assert_eq!(result, Err("not equal".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("True > False").unwrap()); + assert_eq!(result, Err("no such function".into())); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("True in False").unwrap(), + ); + assert_eq!(result, Err("unsupported comparison".into())); + } + + #[test] + fn test_call() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + + let foo = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo", + fields: HashMap::new(), + methods: HashMap::new(), + }, + parents: vec![], + }); + let foo_def = ctx.get_class_mut(foo); + foo_def.base.methods.insert( + "a", + FnDef { + args: vec![], + result: Some(Rc::new(ClassType(foo))), + }, + ); + + let bar = ctx.add_class(ClassDef { + base: TypeDef { + name: "Bar", + fields: HashMap::new(), + methods: HashMap::new(), + }, + parents: vec![], + }); + let bar_def = ctx.get_class_mut(bar); + bar_def.base.methods.insert( + "a", + FnDef { + args: vec![], + result: Some(Rc::new(ClassType(bar))), + }, + ); + + let v0 = ctx.add_variable(VarDef { + name: "v0", + bound: vec![], + }); + let v1 = ctx.add_variable(VarDef { + name: "v1", + bound: vec![ClassType(foo).into(), ClassType(bar).into()], + }); + let v2 = ctx.add_variable(VarDef { + name: "v2", + bound: vec![ + ClassType(foo).into(), + ClassType(bar).into(), + PrimitiveType(INT32_TYPE).into(), + ], + }); + let mut sym_table = HashMap::new(); + sym_table.insert("foo", Rc::new(ClassType(foo))); + sym_table.insert("bar", Rc::new(ClassType(bar))); + sym_table.insert("foobar", Rc::new(VirtualClassType(foo))); + sym_table.insert("v0", Rc::new(TypeVariable(v0))); + sym_table.insert("v1", Rc::new(TypeVariable(v1))); + sym_table.insert("v2", Rc::new(TypeVariable(v2))); + sym_table.insert("bot", Rc::new(BotType)); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("foo.a()").unwrap()); + assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("v1.a()").unwrap()); + assert_eq!(result.unwrap().unwrap(), TypeVariable(v1).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("foobar.a()").unwrap()); + assert_eq!(result.unwrap().unwrap(), ClassType(foo).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("none().a()").unwrap()); + assert_eq!(result, Err("no value".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("bot.a()").unwrap()); + assert_eq!(result, Err("not supported".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("[][0].a()").unwrap()); + assert_eq!(result, Err("not supported".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("v0.a()").unwrap()); + assert_eq!(result, Err("unbounded type var".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("v2.a()").unwrap()); + assert_eq!(result, Err("no such function".into())); + } + + #[test] + fn parse_subscript() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + let sym_table = HashMap::new(); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("[1, 2, 3][0]").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("[[1]][0][0]").unwrap()); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[1, 2, 3][1:2]").unwrap(), + ); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() + ); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[1, 2, 3][1:2:2]").unwrap(), + ); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(LIST_TYPE, vec![PrimitiveType(INT32_TYPE).into()]).into() + ); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[1, 2, 3][1:1.2]").unwrap(), + ); + assert_eq!(result, Err("slice must be int32 type".into())); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[1, 2, 3][1:none()]").unwrap(), + ); + assert_eq!(result, Err("slice must have type".into())); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[1, 2, 3][1.2]").unwrap(), + ); + assert_eq!(result, Err("index must be either slice or int32".into())); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[1, 2, 3][none()]").unwrap(), + ); + assert_eq!(result, Err("no value".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("none()[1.2]").unwrap()); + assert_eq!(result, Err("no value".into())); + + let result = parse_expr(&ctx, &sym_table, &parse_expression("123[1]").unwrap()); + assert_eq!( + result, + Err("subscript is not supported for types other than list".into()) + ); + } + + #[test] + fn test_if_expr() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + let sym_table = HashMap::new(); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("1 if True else 0").unwrap(), + ); + assert_eq!(result.unwrap().unwrap(), PrimitiveType(INT32_TYPE).into()); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("none() if True else none()").unwrap(), + ); + assert_eq!(result.unwrap(), None); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("none() if 1 else none()").unwrap(), + ); + assert_eq!(result, Err("test should be bool".into())); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("1 if True else none()").unwrap(), + ); + assert_eq!(result, Err("divergent type".into())); + } + + #[test] + fn test_list_comp() { + let mut ctx = basic_ctx(); + ctx.add_fn( + "none", + FnDef { + args: vec![], + result: None, + }, + ); + let int32 = Rc::new(PrimitiveType(INT32_TYPE)); + let mut sym_table = HashMap::new(); + sym_table.insert("z", int32.clone()); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[x for x in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(), + ); + assert_eq!( + result.unwrap().unwrap(), + ParametricType(TUPLE_TYPE, vec![int32.clone(), int32.clone()]).into() + ); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)]][0]").unwrap(), + ); + assert_eq!(result.unwrap().unwrap(), int32.clone()); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x > 0][0]").unwrap(), + ); + assert_eq!(result.unwrap().unwrap(), int32.clone()); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[x for (x, y) in [(1, 2), (2, 3), (3, 4)] if x][0]").unwrap(), + ); + assert_eq!(result, Err("test must be bool".into())); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[y for x in []][0]").unwrap(), + ); + assert_eq!(result, Err("unbounded variable".into())); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[none() for x in []][0]").unwrap(), + ); + assert_eq!(result, Err("no value".into())); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[z for z in []][0]").unwrap(), + ); + assert_eq!(result, Err("duplicated naming".into())); + + let result = parse_expr( + &ctx, + &sym_table, + &parse_expression("[x for x in [] for y in []]").unwrap(), + ); + assert_eq!( + result, + Err("only 1 generator statement is supported".into()) + ); + } +} diff --git a/nac3core/src/inference.rs b/nac3core/src/inference.rs new file mode 100644 index 00000000..1eb351e6 --- /dev/null +++ b/nac3core/src/inference.rs @@ -0,0 +1,591 @@ +use super::typedef::{Type::*, *}; +use std::collections::HashMap; +use std::rc::Rc; + +fn find_subst( + ctx: &GlobalContext, + valuation: &Option<(VariableId, Rc)>, + sub: &mut HashMap>, + mut a: Rc, + mut b: Rc, +) -> Result<(), String> { + // TODO: fix error messages later + if let TypeVariable(id) = a.as_ref() { + if let Some((assumption_id, t)) = valuation { + if assumption_id == id { + a = t.clone(); + } + } + } + + let mut substituted = false; + if let TypeVariable(id) = b.as_ref() { + if let Some(c) = sub.get(&id) { + b = c.clone(); + substituted = true; + } + } + + match (a.as_ref(), b.as_ref()) { + (BotType, _) => Ok(()), + (TypeVariable(id_a), TypeVariable(id_b)) => { + if substituted { + return if id_a == id_b { + Ok(()) + } else { + Err("different variables".to_string()) + }; + } + let v_a = ctx.get_variable(*id_a); + let v_b = ctx.get_variable(*id_b); + if v_b.bound.len() > 0 { + if v_a.bound.len() == 0 { + return Err("unbounded a".to_string()); + } else { + let diff: Vec<_> = v_a + .bound + .iter() + .filter(|x| !v_b.bound.contains(x)) + .collect(); + if diff.len() > 0 { + return Err("different domain".to_string()); + } + } + } + sub.insert(*id_b, a.clone().into()); + Ok(()) + } + (TypeVariable(id_a), _) => { + let v_a = ctx.get_variable(*id_a); + if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() { + Ok(()) + } else { + Err("different domain".to_string()) + } + } + (_, TypeVariable(id_b)) => { + let v_b = ctx.get_variable(*id_b); + if v_b.bound.len() == 0 || v_b.bound.contains(&a) { + sub.insert(*id_b, a.clone().into()); + Ok(()) + } else { + Err("different domain".to_string()) + } + } + (_, VirtualClassType(id_b)) => { + let mut parents; + match a.as_ref() { + ClassType(id_a) => { + parents = [*id_a].to_vec(); + } + VirtualClassType(id_a) => { + parents = [*id_a].to_vec(); + } + _ => { + return Err("cannot substitute non-class type into virtual class".to_string()); + } + }; + while !parents.is_empty() { + if *id_b == parents[0] { + return Ok(()); + } + let c = ctx.get_class(parents.remove(0)); + parents.extend_from_slice(&c.parents); + } + Err("not subtype".to_string()) + } + (ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => { + if id_a != id_b || param_a.len() != param_b.len() { + Err("different parametric types".to_string()) + } else { + for (x, y) in param_a.iter().zip(param_b.iter()) { + find_subst(ctx, valuation, sub, x.clone(), y.clone())?; + } + Ok(()) + } + } + (_, _) => { + if a == b { + Ok(()) + } else { + Err("not equal".to_string()) + } + } + } +} + +fn resolve_call_rec( + ctx: &GlobalContext, + valuation: &Option<(VariableId, Rc)>, + obj: Option>, + func: &str, + args: &[Rc], +) -> Result>, String> { + let mut subst = obj + .as_ref() + .map(|v| v.get_subst(ctx)) + .unwrap_or(HashMap::new()); + + let fun = match &obj { + Some(obj) => { + let base = match obj.as_ref() { + TypeVariable(id) => { + let v = ctx.get_variable(*id); + if v.bound.len() == 0 { + return Err("unbounded type var".to_string()); + } + let results: Result, String> = v + .bound + .iter() + .map(|ins| { + resolve_call_rec( + ctx, + &Some((*id, ins.clone())), + Some(ins.clone()), + func, + args.clone(), + ) + }) + .collect(); + let results = results?; + if results.iter().all(|v| v == &results[0]) { + return Ok(results[0].clone()); + } + let mut results = results.iter().zip(v.bound.iter()).map(|(r, ins)| { + r.as_ref() + .map(|v| v.inv_subst(&[(ins.clone(), obj.clone().into())])) + }); + let first = results.next().unwrap(); + if results.all(|v| v == first) { + return Ok(first); + } else { + return Err("divergent type after substitution".to_string()); + } + } + PrimitiveType(id) => &ctx.get_primitive(*id), + ClassType(id) | VirtualClassType(id) => &ctx.get_class(*id).base, + ParametricType(id, _) => &ctx.get_parametric(*id).base, + _ => return Err("not supported".to_string()), + }; + base.methods.get(func) + } + None => ctx.get_fn(func), + } + .ok_or("no such function".to_string())?; + + if args.len() != fun.args.len() { + return Err("incorrect parameter number".to_string()); + } + for (a, b) in args.iter().zip(fun.args.iter()) { + find_subst(ctx, valuation, &mut subst, a.clone(), b.clone())?; + } + let result = fun.result.as_ref().map(|v| v.subst(&subst)); + Ok(result.map(|result| { + if let SelfType = result { + obj.unwrap() + } else { + result.into() + } + })) +} + +pub fn resolve_call( + ctx: &GlobalContext, + obj: Option>, + func: &str, + args: &[Rc], +) -> Result>, String> { + resolve_call_rec(ctx, &None, obj, func, args) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::primitives::*; + + #[test] + fn test_simple_generic() { + let mut ctx = basic_ctx(); + + assert_eq!( + resolve_call(&ctx, None, "int32", &[PrimitiveType(FLOAT_TYPE).into()]), + Ok(Some(PrimitiveType(INT32_TYPE).into())) + ); + + assert_eq!( + resolve_call(&ctx, None, "int32", &[PrimitiveType(INT32_TYPE).into()],), + Ok(Some(PrimitiveType(INT32_TYPE).into())) + ); + + assert_eq!( + resolve_call(&ctx, None, "float", &[PrimitiveType(INT32_TYPE).into()]), + Ok(Some(PrimitiveType(FLOAT_TYPE).into())) + ); + + assert_eq!( + resolve_call(&ctx, None, "float", &[PrimitiveType(BOOL_TYPE).into()]), + Err("different domain".to_string()) + ); + + assert_eq!( + resolve_call(&ctx, None, "float", &[]), + Err("incorrect parameter number".to_string()) + ); + + let v1 = ctx.add_variable(VarDef { + name: "V1", + bound: vec![ + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(FLOAT_TYPE).into(), + ], + }); + + assert_eq!( + resolve_call(&ctx, None, "float", &[TypeVariable(v1).into()]), + Ok(Some(PrimitiveType(FLOAT_TYPE).into())) + ); + + let v2 = ctx.add_variable(VarDef { + name: "V2", + bound: vec![ + PrimitiveType(BOOL_TYPE).into(), + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(FLOAT_TYPE).into(), + ], + }); + + assert_eq!( + resolve_call(&ctx, None, "float", &[TypeVariable(v2).into()]), + Err("different domain".to_string()) + ); + } + + #[test] + fn test_methods() { + let mut ctx = basic_ctx(); + + let v0 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V0", + bound: vec![], + }))); + let v1 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V1", + bound: vec![ + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(FLOAT_TYPE).into(), + ], + }))); + let v2 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V2", + bound: vec![ + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(FLOAT_TYPE).into(), + ], + }))); + let v3 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V3", + bound: vec![ + PrimitiveType(BOOL_TYPE).into(), + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(FLOAT_TYPE).into(), + ], + }))); + + let int32 = Rc::new(PrimitiveType(INT32_TYPE)); + let int64 = Rc::new(PrimitiveType(INT64_TYPE)); + + // simple cases + assert_eq!( + resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), + Ok(Some(int32.clone())) + ); + + assert_ne!( + resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), + Ok(Some(int64.clone())) + ); + + assert_eq!( + resolve_call(&ctx, Some(int32.clone()), "__add__", &[int64.clone()]), + Err("not equal".to_string()) + ); + + // with type variables + assert_eq!( + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v1.clone()]), + Ok(Some(v1.clone())) + ); + assert_eq!( + resolve_call(&ctx, Some(v0.clone()), "__add__", &[v2.clone()]), + Err("unbounded type var".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v0.clone()]), + Err("different domain".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v2.clone()]), + Err("different domain".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v1.clone()), "__add__", &[v3.clone()]), + Err("different domain".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v3.clone()), "__add__", &[v1.clone()]), + Err("no such function".to_string()) + ); + assert_eq!( + resolve_call(&ctx, Some(v3.clone()), "__add__", &[v3.clone()]), + Err("no such function".to_string()) + ); + } + + #[test] + fn test_multi_generic() { + let mut ctx = basic_ctx(); + let v0 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V0", + bound: vec![], + }))); + let v1 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V1", + bound: vec![], + }))); + let v2 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V2", + bound: vec![], + }))); + let v3 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V3", + bound: vec![], + }))); + + ctx.add_fn( + "foo", + FnDef { + args: vec![v0.clone(), v0.clone(), v1.clone()], + result: Some(v0.clone()), + }, + ); + + ctx.add_fn( + "foo1", + FnDef { + args: vec![ + ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1.clone()]).into(), + ], + result: Some(v0.clone()), + }, + ); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]), + Ok(Some(v2.clone())) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]), + Ok(Some(v2.clone())) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]), + Err("different variables".to_string()) + ); + + assert_eq!( + resolve_call( + &ctx, + None, + "foo1", + &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()] + ), + Ok(Some(v2.clone())) + ); + assert_eq!( + resolve_call( + &ctx, + None, + "foo1", + &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()] + ), + Ok(Some(v2.clone())) + ); + assert_eq!( + resolve_call( + &ctx, + None, + "foo1", + &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v3.clone(), v3.clone()]).into()] + ), + Err("different variables".to_string()) + ); + } + + #[test] + fn test_class_generics() { + let mut ctx = basic_ctx(); + + let list = ctx.get_parametric_mut(LIST_TYPE); + let t = Rc::new(TypeVariable(list.params[0])); + list.base.methods.insert( + "head", + FnDef { + args: vec![], + result: Some(t.clone()), + }, + ); + list.base.methods.insert( + "append", + FnDef { + args: vec![t.clone()], + result: None, + }, + ); + + let v0 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V0", + bound: vec![], + }))); + let v1 = Rc::new(TypeVariable(ctx.add_variable(VarDef { + name: "V1", + bound: vec![], + }))); + + assert_eq!( + resolve_call( + &ctx, + Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), + "head", + &[] + ), + Ok(Some(v0.clone())) + ); + assert_eq!( + resolve_call( + &ctx, + Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), + "append", + &[v0.clone()] + ), + Ok(None) + ); + assert_eq!( + resolve_call( + &ctx, + Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), + "append", + &[v1.clone()] + ), + Err("different variables".to_string()) + ); + } + + #[test] + fn test_virtual_class() { + let mut ctx = basic_ctx(); + + let foo = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo", + methods: HashMap::new(), + fields: HashMap::new(), + }, + parents: vec![], + }); + + let foo1 = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo1", + methods: HashMap::new(), + fields: HashMap::new(), + }, + parents: vec![foo], + }); + + let foo2 = ctx.add_class(ClassDef { + base: TypeDef { + name: "Foo2", + methods: HashMap::new(), + fields: HashMap::new(), + }, + parents: vec![foo1], + }); + + let bar = ctx.add_class(ClassDef { + base: TypeDef { + name: "bar", + methods: HashMap::new(), + fields: HashMap::new(), + }, + parents: vec![], + }); + + ctx.add_fn( + "foo", + FnDef { + args: vec![VirtualClassType(foo).into()], + result: None, + }, + ); + ctx.add_fn( + "foo1", + FnDef { + args: vec![VirtualClassType(foo1).into()], + result: None, + }, + ); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]), + Err("not subtype".to_string()) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]), + Ok(None) + ); + + assert_eq!( + resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]), + Err("not subtype".to_string()) + ); + + // virtual class substitution + assert_eq!( + resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]), + Ok(None) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]), + Ok(None) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]), + Ok(None) + ); + assert_eq!( + resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]), + Err("not subtype".to_string()) + ); + } +} diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index e122233b..ab521b4a 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -2,6 +2,12 @@ extern crate num_bigint; extern crate inkwell; extern crate rustpython_parser; +pub mod expression; +pub mod inference; +mod operators; +pub mod primitives; +pub mod typedef; + use std::error::Error; use std::fmt; use std::path::Path; diff --git a/nac3core/src/operators.rs b/nac3core/src/operators.rs new file mode 100644 index 00000000..b0c248b4 --- /dev/null +++ b/nac3core/src/operators.rs @@ -0,0 +1,58 @@ +use rustpython_parser::ast::{Comparison, Operator, UnaryOperator}; + +pub fn binop_name(op: &Operator) -> &'static str { + match op { + Operator::Add => "__add__", + Operator::Sub => "__sub__", + Operator::Div => "__truediv__", + Operator::Mod => "__mod__", + Operator::Mult => "__mul__", + Operator::Pow => "__pow__", + Operator::BitOr => "__or__", + Operator::BitXor => "__xor__", + Operator::BitAnd => "__and__", + Operator::LShift => "__lshift__", + Operator::RShift => "__rshift__", + Operator::FloorDiv => "__floordiv__", + Operator::MatMult => "__matmul__", + } +} + +pub fn binop_assign_name(op: &Operator) -> &'static str { + match op { + Operator::Add => "__iadd__", + Operator::Sub => "__isub__", + Operator::Div => "__itruediv__", + Operator::Mod => "__imod__", + Operator::Mult => "__imul__", + Operator::Pow => "__ipow__", + Operator::BitOr => "__ior__", + Operator::BitXor => "__ixor__", + Operator::BitAnd => "__iand__", + Operator::LShift => "__ilshift__", + Operator::RShift => "__irshift__", + Operator::FloorDiv => "__ifloordiv__", + Operator::MatMult => "__imatmul__", + } +} + +pub fn unaryop_name(op: &UnaryOperator) -> &'static str { + match op { + UnaryOperator::Pos => "__pos__", + UnaryOperator::Neg => "__neg__", + UnaryOperator::Not => "__not__", + UnaryOperator::Inv => "__inv__", + } +} + +pub fn comparison_name(op: &Comparison) -> Option<&'static str> { + match op { + Comparison::Less => Some("__lt__"), + Comparison::LessOrEqual => Some("__le__"), + Comparison::Greater => Some("__gt__"), + Comparison::GreaterOrEqual => Some("__ge__"), + Comparison::Equal => Some("__eq__"), + Comparison::NotEqual => Some("__ne__"), + _ => None, + } +} diff --git a/nac3core/src/primitives.rs b/nac3core/src/primitives.rs new file mode 100644 index 00000000..25b5cc3d --- /dev/null +++ b/nac3core/src/primitives.rs @@ -0,0 +1,181 @@ +use super::typedef::{Type::*, *}; +use std::collections::HashMap; +use std::rc::Rc; + +pub const TUPLE_TYPE: ParamId = ParamId(0); +pub const LIST_TYPE: ParamId = ParamId(1); + +pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0); +pub const INT32_TYPE: PrimitiveId = PrimitiveId(1); +pub const INT64_TYPE: PrimitiveId = PrimitiveId(2); +pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(3); + +fn impl_math(def: &mut TypeDef, ty: &Rc) { + let result = Some(ty.clone()); + let fun = FnDef { + args: vec![ty.clone()], + result: result.clone(), + }; + def.methods.insert("__add__", fun.clone()); + def.methods.insert("__sub__", fun.clone()); + def.methods.insert("__mul__", fun.clone()); + def.methods.insert("__neg__", FnDef { + args: vec![], + result + }); + def.methods.insert( + "__truediv__", + FnDef { + args: vec![ty.clone()], + result: Some(PrimitiveType(FLOAT_TYPE).into()), + }, + ); + def.methods.insert("__floordiv__", fun.clone()); + def.methods.insert("__mod__", fun.clone()); + def.methods.insert("__pow__", fun.clone()); +} + +fn impl_bits(def: &mut TypeDef, ty: &Rc) { + let result = Some(ty.clone()); + let fun = FnDef { + args: vec![PrimitiveType(INT32_TYPE).into()], + result, + }; + + def.methods.insert("__lshift__", fun.clone()); + def.methods.insert("__rshift__", fun.clone()); + def.methods.insert( + "__xor__", + FnDef { + args: vec![ty.clone()], + result: Some(ty.clone()), + }, + ); +} + +fn impl_eq(def: &mut TypeDef, ty: &Rc) { + let fun = FnDef { + args: vec![ty.clone()], + result: Some(PrimitiveType(BOOL_TYPE).into()), + }; + + def.methods.insert("__eq__", fun.clone()); + def.methods.insert("__ne__", fun.clone()); +} + +fn impl_order(def: &mut TypeDef, ty: &Rc) { + let fun = FnDef { + args: vec![ty.clone()], + result: Some(PrimitiveType(BOOL_TYPE).into()), + }; + + def.methods.insert("__lt__", fun.clone()); + def.methods.insert("__gt__", fun.clone()); + def.methods.insert("__le__", fun.clone()); + def.methods.insert("__ge__", fun.clone()); +} + +pub fn basic_ctx() -> GlobalContext<'static> { + let primitives = [ + TypeDef { + name: "bool", + fields: HashMap::new(), + methods: HashMap::new(), + }, + TypeDef { + name: "int32", + fields: HashMap::new(), + methods: HashMap::new(), + }, + TypeDef { + name: "int64", + fields: HashMap::new(), + methods: HashMap::new(), + }, + TypeDef { + name: "float", + fields: HashMap::new(), + methods: HashMap::new(), + }, + ] + .to_vec(); + let mut ctx = GlobalContext::new(primitives); + + let b_def = ctx.get_primitive_mut(BOOL_TYPE); + let b = PrimitiveType(BOOL_TYPE).into(); + impl_eq(b_def, &b); + let int32_def = ctx.get_primitive_mut(INT32_TYPE); + let int32 = PrimitiveType(INT32_TYPE).into(); + impl_math(int32_def, &int32); + impl_bits(int32_def, &int32); + impl_order(int32_def, &int32); + impl_eq(int32_def, &int32); + let int64_def = ctx.get_primitive_mut(INT64_TYPE); + let int64 = PrimitiveType(INT64_TYPE).into(); + impl_math(int64_def, &int64); + impl_bits(int64_def, &int64); + impl_order(int64_def, &int64); + impl_eq(int64_def, &int64); + let float_def = ctx.get_primitive_mut(FLOAT_TYPE); + let float = PrimitiveType(FLOAT_TYPE).into(); + impl_math(float_def, &float); + impl_order(float_def, &float); + impl_eq(float_def, &float); + + let t = ctx.add_variable_private(VarDef { + name: "T", + bound: vec![], + }); + + ctx.add_parametric(ParametricDef { + base: TypeDef { + name: "tuple", + fields: HashMap::new(), + methods: HashMap::new(), + }, + // we have nothing for tuple, so no param def + params: vec![], + }); + + ctx.add_parametric(ParametricDef { + base: TypeDef { + name: "list", + fields: HashMap::new(), + methods: HashMap::new(), + }, + params: vec![t], + }); + + let i = ctx.add_variable_private(VarDef { + name: "I", + bound: vec![ + PrimitiveType(INT32_TYPE).into(), + PrimitiveType(INT64_TYPE).into(), + PrimitiveType(FLOAT_TYPE).into(), + ], + }); + let args = vec![TypeVariable(i).into()]; + ctx.add_fn( + "int32", + FnDef { + args: args.clone(), + result: Some(PrimitiveType(INT32_TYPE).into()), + }, + ); + ctx.add_fn( + "int64", + FnDef { + args: args.clone(), + result: Some(PrimitiveType(INT64_TYPE).into()), + }, + ); + ctx.add_fn( + "float", + FnDef { + args: args.clone(), + result: Some(PrimitiveType(FLOAT_TYPE).into()), + }, + ); + + ctx +} diff --git a/nac3core/src/typedef.rs b/nac3core/src/typedef.rs new file mode 100644 index 00000000..5fe42279 --- /dev/null +++ b/nac3core/src/typedef.rs @@ -0,0 +1,223 @@ +use std::collections::HashMap; +use std::rc::Rc; + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub struct PrimitiveId(pub(crate) usize); + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub struct ClassId(pub(crate) usize); + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub struct ParamId(pub(crate) usize); + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub struct VariableId(pub(crate) usize); + +#[derive(PartialEq, Eq, Clone, Hash, Debug)] +pub enum Type { + BotType, + SelfType, + PrimitiveType(PrimitiveId), + ClassType(ClassId), + VirtualClassType(ClassId), + ParametricType(ParamId, Vec>), + TypeVariable(VariableId), +} + +#[derive(Clone)] +pub struct FnDef { + // we assume methods first argument to be SelfType, + // so the first argument is not contained here + pub args: Vec>, + pub result: Option>, +} + +#[derive(Clone)] +pub struct TypeDef<'a> { + pub name: &'a str, + pub fields: HashMap<&'a str, Rc>, + pub methods: HashMap<&'a str, FnDef>, +} + +#[derive(Clone)] +pub struct ClassDef<'a> { + pub base: TypeDef<'a>, + pub parents: Vec, +} + +#[derive(Clone)] +pub struct ParametricDef<'a> { + pub base: TypeDef<'a>, + pub params: Vec, +} + +#[derive(Clone)] +pub struct VarDef<'a> { + pub name: &'a str, + pub bound: Vec>, +} + +pub struct GlobalContext<'a> { + primitive_defs: Vec>, + class_defs: Vec>, + parametric_defs: Vec>, + var_defs: Vec>, + sym_table: HashMap<&'a str, Type>, + fn_table: HashMap<&'a str, FnDef>, +} + +impl<'a> GlobalContext<'a> { + pub fn new(primitives: Vec>) -> GlobalContext { + let mut sym_table = HashMap::new(); + for (i, t) in primitives.iter().enumerate() { + sym_table.insert(t.name, Type::PrimitiveType(PrimitiveId(i))); + } + return GlobalContext { + primitive_defs: primitives, + class_defs: Vec::new(), + parametric_defs: Vec::new(), + var_defs: Vec::new(), + fn_table: HashMap::new(), + sym_table, + }; + } + + pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId { + self.sym_table.insert( + def.base.name, + Type::ClassType(ClassId(self.class_defs.len())), + ); + self.class_defs.push(def); + ClassId(self.class_defs.len() - 1) + } + + pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId { + let params = def + .params + .iter() + .map(|&v| Rc::new(Type::TypeVariable(v))) + .collect(); + self.sym_table.insert( + def.base.name, + Type::ParametricType(ParamId(self.parametric_defs.len()), params), + ); + self.parametric_defs.push(def); + ParamId(self.parametric_defs.len() - 1) + } + + pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId { + self.sym_table.insert( + def.name, + Type::TypeVariable(VariableId(self.var_defs.len())), + ); + self.add_variable_private(def) + } + + pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId { + self.var_defs.push(def); + VariableId(self.var_defs.len() - 1) + } + + pub fn add_fn(&mut self, name: &'a str, def: FnDef) { + self.fn_table.insert(name, def); + } + + pub fn get_fn(&self, name: &str) -> Option<&FnDef> { + self.fn_table.get(name) + } + + pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { + self.primitive_defs.get_mut(id.0).unwrap() + } + + pub fn get_primitive(&self, id: PrimitiveId) -> &TypeDef { + self.primitive_defs.get(id.0).unwrap() + } + + pub fn get_class_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> { + self.class_defs.get_mut(id.0).unwrap() + } + + pub fn get_class(&self, id: ClassId) -> &ClassDef { + self.class_defs.get(id.0).unwrap() + } + + pub fn get_parametric_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> { + self.parametric_defs.get_mut(id.0).unwrap() + } + + pub fn get_parametric(&self, id: ParamId) -> &ParametricDef { + self.parametric_defs.get(id.0).unwrap() + } + + pub fn get_variable_mut(&mut self, id: VariableId) -> &mut VarDef<'a> { + self.var_defs.get_mut(id.0).unwrap() + } + + pub fn get_variable(&self, id: VariableId) -> &VarDef { + self.var_defs.get(id.0).unwrap() + } + + pub fn get_type(&self, name: &str) -> Option { + // TODO: change this to handle import + self.sym_table.get(name).map(|v| v.clone()) + } +} + +impl Type { + pub fn subst(&self, map: &HashMap>) -> Type { + match self { + Type::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(), + Type::ParametricType(id, params) => Type::ParametricType( + *id, + params + .iter() + .map(|v| v.as_ref().subst(map).into()) + .collect(), + ), + _ => self.clone(), + } + } + + pub fn inv_subst(&self, map: &[(Rc, Rc)]) -> Rc { + for (from, to) in map.iter() { + if self == from.as_ref() { + return to.clone(); + } + } + match self { + Type::ParametricType(id, params) => Type::ParametricType( + *id, + params + .iter() + .map(|v| v.as_ref().inv_subst(map).into()) + .collect(), + ), + _ => self.clone(), + } + .into() + } + + pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap> { + match self { + Type::ParametricType(id, params) => { + let vars = &ctx.get_parametric(*id).params; + vars.iter() + .zip(params) + .map(|(v, p)| (*v, p.as_ref().clone().into())) + .collect() + } + // if this proves to be slow, we can use option type + _ => HashMap::new(), + } + } + + pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b GlobalContext) -> Option<&'b TypeDef> { + match self { + Type::PrimitiveType(id) => Some(ctx.get_primitive(*id)), + Type::ClassType(id) | Type::VirtualClassType(id) => Some(&ctx.get_class(*id).base), + Type::ParametricType(id, _) => Some(&ctx.get_parametric(*id).base), + _ => None, + } + } +}