diff --git a/nac3type/src/inference.rs b/nac3type/src/inference.rs index 685c07a6a9..bc875b6f96 100644 --- a/nac3type/src/inference.rs +++ b/nac3type/src/inference.rs @@ -1,3 +1,4 @@ +use super::primitives::*; use super::types::{Type::*, *}; use std::collections::HashMap; use std::rc::Rc; @@ -147,14 +148,14 @@ pub fn resolve_call( } else { return Err("divergent type after substitution".to_string()); } - }, + } PrimitiveType(id) => &ctx.get_primitive(*id), ClassType(id) | VirtualClassType(id) => &ctx.get_class(*id).base, ParametricType(id, _) => &ctx.get_parametric(*id).base, _ => return Err("not supported".to_string()), }; base.methods.get(func) - }, + } None => ctx.get_fn(func), } .ok_or("no such function".to_string())?; diff --git a/nac3type/src/lib.rs b/nac3type/src/lib.rs index 40217dcd46..a3df4ca59b 100644 --- a/nac3type/src/lib.rs +++ b/nac3type/src/lib.rs @@ -2,4 +2,5 @@ extern crate rustpython_parser; mod types; mod inference; +mod primitives; diff --git a/nac3type/src/primitives.rs b/nac3type/src/primitives.rs new file mode 100644 index 0000000000..8357a4a03d --- /dev/null +++ b/nac3type/src/primitives.rs @@ -0,0 +1,127 @@ +use super::types::{Type::*, *}; +use std::collections::HashMap; +use std::rc::Rc; + +pub const TUPLE_TYPE: ParamId = ParamId(0); +pub const LIST_TYPE: ParamId = ParamId(1); + +pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0); +pub const INT32_TYPE: PrimitiveId = PrimitiveId(1); +pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(2); + +fn impl_math(def: &mut TypeDef, ty: &Rc) { + let bin = Rc::new(ParametricType( + TUPLE_TYPE, + vec![SelfType.into(), ty.clone()], + )); + let result = Some(ty.clone()); + let fun = FnDef { + args: bin.clone(), + result, + }; + def.methods.insert("__add__", fun.clone()); + def.methods.insert("__sub__", fun.clone()); + def.methods.insert("__mul__", fun.clone()); + def.methods.insert("__neg__", fun.clone()); + def.methods.insert( + "__truediv__", + FnDef { + args: bin.clone(), + result: Some(PrimitiveType(FLOAT_TYPE).into()), + }, + ); + def.methods.insert("__floordiv__", fun.clone()); + def.methods.insert("__mod__", fun.clone()); + def.methods.insert("__pow__", fun.clone()); +} + +fn impl_bits(def: &mut TypeDef, ty: &Rc) { + let bin = Rc::new(ParametricType( + TUPLE_TYPE, + vec![SelfType.into(), PrimitiveType(INT32_TYPE).into()], + )); + let result = Some(ty.clone()); + let fun = FnDef { + args: bin.clone(), + result, + }; + + def.methods.insert("__lshift__", fun.clone()); + def.methods.insert("__rshift__", fun.clone()); + def.methods.insert( + "__xor__", + FnDef { + args: ParametricType(TUPLE_TYPE, vec![SelfType.into(), ty.clone()]).into(), + result: Some(ty.clone()), + }, + ); +} + +fn impl_eq(def: &mut TypeDef, ty: &Rc) { + let bin = Rc::new(ParametricType( + TUPLE_TYPE, + vec![SelfType.into(), ty.clone()], + )); + let fun = FnDef { + args: bin.clone(), + result: Some(PrimitiveType(BOOL_TYPE).into()), + }; + + def.methods.insert("__eq__", fun.clone()); + def.methods.insert("__ne__", fun.clone()); +} + +fn impl_order(def: &mut TypeDef, ty: &Rc) { + let bin = Rc::new(ParametricType( + TUPLE_TYPE, + vec![SelfType.into(), ty.clone()], + )); + let fun = FnDef { + args: bin.clone(), + result: Some(PrimitiveType(BOOL_TYPE).into()), + }; + + def.methods.insert("__lt__", fun.clone()); + def.methods.insert("__gt__", fun.clone()); + def.methods.insert("__le__", fun.clone()); + def.methods.insert("__ge__", fun.clone()); +} + +pub fn basic_ctx() -> GlobalContext<'static> { + let primitives = [ + TypeDef { + name: "bool", + fields: HashMap::new(), + methods: HashMap::new(), + }, + TypeDef { + name: "int32", + fields: HashMap::new(), + methods: HashMap::new(), + }, + TypeDef { + name: "float", + fields: HashMap::new(), + methods: HashMap::new(), + }, + ] + .to_vec(); + let mut ctx = GlobalContext::new(primitives); + + let b_def = ctx.get_primitive_mut(BOOL_TYPE); + let b = PrimitiveType(BOOL_TYPE).into(); + impl_eq(b_def, &b); + let int32_def = ctx.get_primitive_mut(INT32_TYPE); + let int32 = PrimitiveType(INT32_TYPE).into(); + impl_math(int32_def, &int32); + impl_bits(int32_def, &int32); + impl_order(int32_def, &int32); + impl_eq(int32_def, &int32); + let float_def = ctx.get_primitive_mut(FLOAT_TYPE); + let float = PrimitiveType(FLOAT_TYPE).into(); + impl_math(float_def, &float); + impl_order(float_def, &float); + impl_eq(float_def, &float); + + ctx +} diff --git a/nac3type/src/types.rs b/nac3type/src/types.rs index ad11c36fc8..e48f85ee5f 100644 --- a/nac3type/src/types.rs +++ b/nac3type/src/types.rs @@ -2,16 +2,16 @@ use std::collections::HashMap; use std::rc::Rc; #[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] -pub struct PrimitiveId(usize); +pub struct PrimitiveId(pub(crate) usize); #[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] -pub struct ClassId(usize); +pub struct ClassId(pub(crate) usize); #[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] -pub struct ParamId(usize); +pub struct ParamId(pub(crate) usize); #[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] -pub struct VariableId(usize); +pub struct VariableId(pub(crate) usize); #[derive(PartialEq, Eq, Clone, Hash, Debug)] pub enum Type { @@ -24,38 +24,37 @@ pub enum Type { TypeVariable(VariableId), } +#[derive(Clone)] pub struct FnDef { pub args: Rc, pub result: Option>, } +#[derive(Clone)] pub struct TypeDef<'a> { pub name: &'a str, pub fields: HashMap<&'a str, Type>, pub methods: HashMap<&'a str, FnDef>, } +#[derive(Clone)] pub struct ClassDef<'a> { pub base: TypeDef<'a>, pub parents: Vec, } +#[derive(Clone)] pub struct ParametricDef<'a> { pub base: TypeDef<'a>, pub params: Vec, } +#[derive(Clone)] pub struct VarDef<'a> { pub name: &'a str, pub bound: Vec>, } -pub const TUPLE_TYPE: ParamId = ParamId(0); -pub const LIST_TYPE: ParamId = ParamId(1); - -pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0); -pub const INT32_TYPE: PrimitiveId = PrimitiveId(1); - pub struct GlobalContext<'a> { primitive_defs: Vec>, class_defs: Vec>, @@ -81,15 +80,16 @@ impl<'a> GlobalContext<'a> { }; } - pub fn add_class(&mut self, def: ClassDef<'a>) { + pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId { self.sym_table.insert( def.base.name, Type::ClassType(ClassId(self.class_defs.len())), ); self.class_defs.push(def); + ClassId(self.class_defs.len() - 1) } - pub fn add_parametric(&mut self, def: ParametricDef<'a>) { + pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId { let params = def .params .iter() @@ -100,18 +100,20 @@ impl<'a> GlobalContext<'a> { Type::ParametricType(ParamId(self.parametric_defs.len()), params), ); self.parametric_defs.push(def); + ParamId(self.parametric_defs.len() - 1) } - pub fn add_variable(&mut self, def: VarDef<'a>) { + pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId { self.sym_table.insert( def.name, Type::TypeVariable(VariableId(self.var_defs.len())), ); - self.var_defs.push(def); + self.add_variable_private(def) } - pub fn add_variable_private(&mut self, def: VarDef<'a>) { + pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId { self.var_defs.push(def); + VariableId(self.var_defs.len() - 1) } pub fn add_fn(&'a mut self, name: &'a str, def: FnDef) {