diff --git a/nac3core/src/typecheck/type_inferencer.rs b/nac3core/src/typecheck/type_inferencer.rs index 3ccc3016..6c273ec9 100644 --- a/nac3core/src/typecheck/type_inferencer.rs +++ b/nac3core/src/typecheck/type_inferencer.rs @@ -38,6 +38,12 @@ impl<'a> Fold<()> for Inferencer<'a> { type InferenceResult = Result; impl<'a> Inferencer<'a> { + /// Constrain a <: b + /// Currently implemented as unification + fn constrain(&mut self, a: Type, b: Type) -> Result<(), String> { + self.unifier.unify(a, b) + } + fn build_method_call( &mut self, method: String, @@ -55,7 +61,7 @@ impl<'a> Inferencer<'a> { let call = self.unifier.add_ty(TypeEnum::TCall { calls: vec![call] }); let fields = once((method, call)).collect(); let record = self.unifier.add_ty(TypeEnum::TRecord { fields }); - self.unifier.unify(obj, record)?; + self.constrain(obj, record)?; Ok(ret) } @@ -114,15 +120,15 @@ impl<'a> Inferencer<'a> { fn infer_attribute(&mut self, value: &ast::Expr>, attr: &str) -> InferenceResult { let (attr_ty, _) = self.unifier.get_fresh_var(); let fields = once((attr.to_string(), attr_ty)).collect(); - let parent = self.unifier.add_ty(TypeEnum::TRecord { fields }); - self.unifier.unify(value.custom.unwrap(), parent)?; + let record = self.unifier.add_ty(TypeEnum::TRecord { fields }); + self.constrain(value.custom.unwrap(), record)?; Ok(attr_ty) } fn infer_bool_ops(&mut self, values: &[ast::Expr>]) -> InferenceResult { let b = self.primitives.bool; for v in values { - self.unifier.unify(v.custom.unwrap(), b)?; + self.constrain(v.custom.unwrap(), b)?; } Ok(b) } @@ -181,32 +187,30 @@ impl<'a> Inferencer<'a> { .iter() .flatten() { - self.unifier - .unify(self.primitives.int32, v.custom.unwrap())?; + self.constrain(v.custom.unwrap(), self.primitives.int32)?; } let list = self.unifier.add_ty(TypeEnum::TList { ty }); - self.unifier.unify(value.custom.unwrap(), list)?; + self.constrain(value.custom.unwrap(), list)?; Ok(list) } ast::ExprKind::Constant { value: ast::Constant::Int(val), .. } => { - // the index is a constant, so value can be a sequence (either list/tuple) + // the index is a constant, so value can be a sequence. let ind: i32 = val .try_into() .map_err(|_| "Index must be int32".to_string())?; let map = once((ind, ty)).collect(); let seq = self.unifier.add_ty(TypeEnum::TSeq { map }); - self.unifier.unify(value.custom.unwrap(), seq)?; + self.constrain(value.custom.unwrap(), seq)?; Ok(ty) } _ => { // the index is not a constant, so value can only be a list - self.unifier - .unify(slice.custom.unwrap(), self.primitives.int32)?; + self.constrain(slice.custom.unwrap(), self.primitives.int32)?; let list = self.unifier.add_ty(TypeEnum::TList { ty }); - self.unifier.unify(value.custom.unwrap(), list)?; + self.constrain(value.custom.unwrap(), list)?; Ok(ty) } } @@ -218,10 +222,10 @@ impl<'a> Inferencer<'a> { body: ast::Expr>, orelse: ast::Expr>, ) -> InferenceResult { - self.unifier - .unify(test.custom.unwrap(), self.primitives.bool)?; - self.unifier - .unify(body.custom.unwrap(), orelse.custom.unwrap())?; - Ok(body.custom.unwrap()) + self.constrain(test.custom.unwrap(), self.primitives.bool)?; + let ty = self.unifier.get_fresh_var().0; + self.constrain(body.custom.unwrap(), ty)?; + self.constrain(orelse.custom.unwrap(), ty)?; + Ok(ty) } } diff --git a/nac3core/src/typecheck/typedef.rs b/nac3core/src/typecheck/typedef.rs index 4c7eaa89..e9abbb16 100644 --- a/nac3core/src/typecheck/typedef.rs +++ b/nac3core/src/typecheck/typedef.rs @@ -188,9 +188,7 @@ impl Unifier { ) }; - let (ty_a, ty_b) = { - (ty_a_cell.borrow(), ty_b_cell.borrow()) - }; + let (ty_a, ty_b) = { (ty_a_cell.borrow(), ty_b_cell.borrow()) }; self.occur_check(a, b)?; match (&*ty_a, &*ty_b) {