diff --git a/nac3core/src/expression.rs b/nac3core/src/expression.rs index fd434aa601..c91e8c10a4 100644 --- a/nac3core/src/expression.rs +++ b/nac3core/src/expression.rs @@ -10,7 +10,7 @@ type SymTable<'a> = HashMap<&'a str, Rc>; type ParserResult = Result, String>; pub fn parse_expr(ctx: &GlobalContext, sym_table: &SymTable, expr: &Expression) -> ParserResult { - Err("not supported".to_string()) + Err("not supported".into()) } fn parse_constant( @@ -28,18 +28,18 @@ fn parse_constant( // } else if i64::try_from(&value).is_ok() { // Ok(PrimitiveType(INT64_TYPE).into()) // } else { - // Err("integer out of range".to_string()) + // Err("integer out of range".into()) // } } Number::Float { .. } => Ok(PrimitiveType(FLOAT_TYPE).into()), - _ => Err("not supported".to_string()), + _ => Err("not supported".into()), } } fn parse_identifier(_: &GlobalContext, sym_table: &SymTable, name: &str) -> ParserResult { match sym_table.get(name) { Some(v) => Ok(v.clone()), - None => Err("unbounded variable".to_string()), + None => Err("unbounded variable".into()), } } @@ -53,7 +53,7 @@ fn parse_list(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression] let head = types.next().unwrap()?; for v in types { if v? != head { - return Err("inhomogeneous list is not allowed".to_string()); + return Err("inhomogeneous list is not allowed".into()); } } Ok(ParametricType(LIST_TYPE, vec![head]).into()) @@ -67,4 +67,57 @@ fn parse_tuple(ctx: &GlobalContext, sym_table: &SymTable, elements: &[Expression Ok(ParametricType(TUPLE_TYPE, types?).into()) } +fn parse_attribute( + ctx: &GlobalContext, + sym_table: &SymTable, + value: &Expression, + name: String, +) -> ParserResult { + let value = parse_expr(ctx, sym_table, value)?; + if let TypeVariable(id) = value.as_ref() { + let v = ctx.get_variable(*id); + if v.bound.len() == 0 { + return Err("no fields on unbounded type variable".into()); + } + let ty = v.bound[0] + .get_base(ctx) + .and_then(|v| v.fields.get(name.as_str())); + if ty.is_none() { + return Err("unknown field".into()); + } + for x in v.bound[1..].iter() { + let ty1 = x.get_base(ctx).and_then(|v| v.fields.get(name.as_str())); + if ty1 != ty { + return Err("unknown field (type mismatch between variants)".into()); + } + } + return Ok(ty.unwrap().clone()); + } + + match value.get_base(ctx) { + Some(b) => match b.fields.get(name.as_str()) { + Some(t) => Ok(t.clone()), + None => Err("no such field".into()), + }, + None => Err("this object has no fields".into()), + } +} + +fn parse_bool_ops( + ctx: &GlobalContext, + sym_table: &SymTable, + values: &[Expression], +) -> ParserResult { + assert_eq!(values.len(), 2); + let left = parse_expr(ctx, sym_table, &values[0])?; + let right = parse_expr(ctx, sym_table, &values[1])?; + + let b = PrimitiveType(BOOL_TYPE); + if left.as_ref() == &b && right.as_ref() == &b { + Ok(b.into()) + } else { + Err("bool operands must be bool".into()) + } +} + diff --git a/nac3core/src/typedef.rs b/nac3core/src/typedef.rs index d2517deb99..e4886b307b 100644 --- a/nac3core/src/typedef.rs +++ b/nac3core/src/typedef.rs @@ -35,7 +35,7 @@ pub struct FnDef { #[derive(Clone)] pub struct TypeDef<'a> { pub name: &'a str, - pub fields: HashMap<&'a str, Type>, + pub fields: HashMap<&'a str, Rc>, pub methods: HashMap<&'a str, FnDef>, } @@ -194,7 +194,8 @@ impl Type { .collect(), ), _ => self.clone(), - }.into() + } + .into() } pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap> { @@ -210,5 +211,14 @@ impl Type { _ => HashMap::new(), } } -} + pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b GlobalContext) -> Option<&'b TypeDef> { + match self { + Type::PrimitiveType(id) => Some(ctx.get_primitive(*id)), + Type::ClassType(id) + | Type::VirtualClassType(id) => Some(&ctx.get_class(*id).base), + Type::ParametricType(id, _) => Some(&ctx.get_parametric(*id).base), + _ => None, + } + } +}