From f026b48e2ac444c8c3335ddf77c68cc86591367f Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 13 Jun 2024 13:28:39 +0800 Subject: [PATCH] core: refactor to use `TypeVarId` and `TypeVar` --- nac3artiq/src/symbol_resolver.rs | 56 +++---- nac3core/src/codegen/concrete_type.rs | 24 +-- nac3core/src/codegen/expr.rs | 6 +- nac3core/src/toplevel/builtins.rs | 154 ++++++++---------- nac3core/src/toplevel/composer.rs | 42 ++--- nac3core/src/toplevel/helper.rs | 33 ++-- nac3core/src/toplevel/mod.rs | 7 +- nac3core/src/toplevel/numpy.rs | 6 +- ...el__test__test_analyze__generic_class.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/toplevel/test.rs | 6 +- nac3core/src/toplevel/type_annotation.rs | 6 +- nac3core/src/typecheck/magic_methods.rs | 16 +- nac3core/src/typecheck/type_inferencer/mod.rs | 30 ++-- .../src/typecheck/type_inferencer/test.rs | 17 +- nac3core/src/typecheck/typedef/mod.rs | 124 ++++++++------ nac3core/src/typecheck/typedef/test.rs | 70 ++++---- nac3standalone/src/main.rs | 4 +- 19 files changed, 309 insertions(+), 302 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 16cb5ca..6ef6e95 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -9,7 +9,7 @@ use nac3core::{ }, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{Type, TypeEnum, Unifier, VarMap}, + typedef::{iter_type_vars, to_var_map, Type, TypeEnum, TypeVar, Unifier, VarMap}, }, }; use nac3parser::ast::{self, StrRef}; @@ -317,13 +317,13 @@ impl InnerResolver { Ok(Ok((primitives.exception, true))) } else if ty_id == self.primitive_ids.list { // do not handle type var param and concrete check here - let var = unifier.get_dummy_var().0; + let var = unifier.get_dummy_var().ty; let list = unifier.add_ty(TypeEnum::TList { ty: var }); Ok(Ok((list, false))) } else if ty_id == self.primitive_ids.ndarray { // do not handle type var param and concrete check here - let var = unifier.get_dummy_var().0; - let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).0; + let var = unifier.get_dummy_var().ty; + let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty; let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims)); Ok(Ok((ndarray, false))) } else if ty_id == self.primitive_ids.tuple { @@ -383,7 +383,7 @@ impl InnerResolver { } if !is_const_generic && needs_defer { - result.push(unifier.get_dummy_var().0); + result.push(unifier.get_dummy_var().ty); } else { result.push({ match self.get_pyty_obj_type(py, constr, unifier, defs, primitives)? { @@ -426,9 +426,9 @@ impl InnerResolver { ))); } - unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).0 + unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).ty } else { - unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0 + unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).ty }; Ok(Ok((res, true))) @@ -568,7 +568,7 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.virtual_id { Ok(Ok(( { - let ty = TypeEnum::TVirtual { ty: unifier.get_dummy_var().0 }; + let ty = TypeEnum::TVirtual { ty: unifier.get_dummy_var().ty }; unifier.add_ty(ty) }, false, @@ -719,18 +719,16 @@ impl InnerResolver { unreachable!("must be tobj") }; - let var_map = params - .iter() - .map(|(id_var, ty)| { - let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) - else { - unreachable!() - }; + let var_map = to_var_map(iter_type_vars(params).map(|tvar| { + let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty) + else { + unreachable!() + }; - assert_eq!(*id, *id_var); - (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) - }) - .collect::(); + assert_eq!(*id, tvar.id); + let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty; + TypeVar { id: *id, ty } + })); return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())); } @@ -748,18 +746,16 @@ impl InnerResolver { } (TypeEnum::TObj { params, fields, .. }, false) => { self.pyid_to_type.write().insert(py_obj_id, extracted_ty); - let var_map = params - .iter() - .map(|(id_var, ty)| { - let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) - else { - unreachable!() - }; + let var_map = to_var_map(iter_type_vars(params).map(|tvar| { + let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty) + else { + unreachable!() + }; - assert_eq!(*id, *id_var); - (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) - }) - .collect::(); + assert_eq!(*id, tvar.id); + let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty; + TypeVar { id: *id, ty } + })); let mut instantiate_obj = || { // loop through non-function fields of the class to get the instantiated value for field in fields { diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 4b55654..7ac198a 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -3,7 +3,7 @@ use crate::{ toplevel::DefinitionId, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, + typedef::{to_var_map, FunSignature, FuncArg, Type, TypeEnum, TypeVar, TypeVarId, Unifier}, }, }; @@ -51,7 +51,7 @@ pub enum ConcreteTypeEnum { TObj { obj_id: DefinitionId, fields: HashMap, - params: IndexMap, + params: IndexMap, }, TVirtual { ty: ConcreteType, @@ -59,7 +59,7 @@ pub enum ConcreteTypeEnum { TFunc { args: Vec, ret: ConcreteType, - vars: HashMap, + vars: HashMap, }, TLiteral { values: Vec, @@ -230,7 +230,7 @@ impl ConcreteTypeStore { return if let Some(ty) = ty { *ty } else { - *ty = Some(unifier.get_dummy_var().0); + *ty = Some(unifier.get_dummy_var().ty); ty.unwrap() }; } @@ -272,10 +272,10 @@ impl ConcreteTypeStore { (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) }) .collect::>(), - params: params - .iter() - .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) - .collect::(), + params: to_var_map(params.iter().map(|(&id, cty)| { + let ty = self.to_unifier_type(unifier, primitives, *cty, cache); + TypeVar { id, ty } + })), }, ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { args: args @@ -287,10 +287,10 @@ impl ConcreteTypeStore { }) .collect(), ret: self.to_unifier_type(unifier, primitives, *ret, cache), - vars: vars - .iter() - .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) - .collect::(), + vars: to_var_map(vars.iter().map(|(&id, cty)| { + let ty = self.to_unifier_type(unifier, primitives, *cty, cache); + TypeVar { id, ty } + })), }), ConcreteTypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values: values.clone(), loc: None } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index b1b32b4..d507548 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -22,7 +22,7 @@ use crate::{ }, typecheck::{ magic_methods::{binop_assign_name, binop_name, unaryop_name}, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, + typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, }, }; use inkwell::{ @@ -42,7 +42,7 @@ pub fn get_subst_key( unifier: &mut Unifier, obj: Option, fun_vars: &VarMap, - filter: Option<&Vec>, + filter: Option<&Vec>, ) -> String { let mut vars = obj .map(|ty| { @@ -81,7 +81,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { &mut self, obj: Option, fun: &FunSignature, - filter: Option<&Vec>, + filter: Option<&Vec>, ) -> String { get_subst_key(&mut self.unifier, obj, &fun.vars, filter) } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index a0dbb5d..bfecd65 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -25,7 +25,7 @@ use crate::{ }, symbol_resolver::SymbolValue, toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, - typecheck::typedef::VarMap, + typecheck::typedef::{iter_type_vars, to_var_map, TypeVar, VarMap}, }; use super::*; @@ -307,26 +307,26 @@ struct BuiltinBuilder<'a> { is_some_ty: (Type, bool), unwrap_ty: (Type, bool), - option_tvar: (Type, u32), + option_tvar: TypeVar, - ndarray_dtype_tvar: (Type, u32), - ndarray_ndims_tvar: (Type, u32), + ndarray_dtype_tvar: TypeVar, + ndarray_ndims_tvar: TypeVar, ndarray_copy_ty: (Type, bool), ndarray_fill_ty: (Type, bool), list_int32: Type, - num_ty: (Type, u32), + num_ty: TypeVar, num_var_map: VarMap, ndarray_float: Type, ndarray_float_2d: Type, ndarray_num_ty: Type, - float_or_ndarray_ty: (Type, u32), + float_or_ndarray_ty: TypeVar, float_or_ndarray_var_map: VarMap, - num_or_ndarray_ty: (Type, u32), + num_or_ndarray_ty: TypeVar, num_or_ndarray_var_map: VarMap, } @@ -350,7 +350,7 @@ impl<'a> BuiltinBuilder<'a> { ( *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), - (*params.iter().next().unwrap().1, *params.iter().next().unwrap().0), + iter_type_vars(params).next().unwrap(), ) } else { unreachable!() @@ -361,10 +361,8 @@ impl<'a> BuiltinBuilder<'a> { else { unreachable!() }; - let ndarray_dtype_tvar = - ndarray_params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(); - let ndarray_ndims_tvar = - ndarray_params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap(); + let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap(); + let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap(); let ndarray_copy_ty = *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); let ndarray_fill_ty = @@ -375,7 +373,7 @@ impl<'a> BuiltinBuilder<'a> { Some("N".into()), None, ); - let num_var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); + let num_var_map = to_var_map([num_ty]); let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); let ndarray_float_2d = { @@ -389,18 +387,14 @@ impl<'a> BuiltinBuilder<'a> { make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) }; - let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None); + let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.ty), None); let float_or_ndarray_ty = unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); - let float_or_ndarray_var_map: VarMap = - vec![(float_or_ndarray_ty.1, float_or_ndarray_ty.0)].into_iter().collect(); + let float_or_ndarray_var_map = to_var_map([float_or_ndarray_ty]); let num_or_ndarray_ty = - unifier.get_fresh_var_with_range(&[num_ty.0, ndarray_num_ty], Some("T".into()), None); - let num_or_ndarray_var_map: VarMap = - vec![(num_ty.1, num_ty.0), (num_or_ndarray_ty.1, num_or_ndarray_ty.0)] - .into_iter() - .collect(); + unifier.get_fresh_var_with_range(&[num_ty.ty, ndarray_num_ty], Some("T".into()), None); + let num_or_ndarray_var_map = to_var_map([num_ty, num_or_ndarray_ty]); let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); @@ -648,7 +642,7 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::Option => TopLevelDef::Class { name: prim.name().into(), object_id: prim.id(), - type_vars: vec![self.option_tvar.0], + type_vars: vec![self.option_tvar.ty], fields: vec![], methods: vec![ Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0), @@ -668,7 +662,7 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.unwrap_ty.0, - var_id: vec![self.option_tvar.1], + var_id: vec![self.option_tvar.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -682,7 +676,7 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().to_string(), simple_name: prim.simple_name().into(), signature: self.is_some_ty.0, - var_id: vec![self.option_tvar.1], + var_id: vec![self.option_tvar.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -719,13 +713,13 @@ impl<'a> BuiltinBuilder<'a> { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "n".into(), - ty: self.option_tvar.0, + ty: self.option_tvar.ty, default_value: None, }], ret: self.primitives.option, - vars: VarMap::from([(self.option_tvar.1, self.option_tvar.0)]), + vars: to_var_map([self.option_tvar]), })), - var_id: vec![self.option_tvar.1], + var_id: vec![self.option_tvar.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -761,7 +755,7 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::NDArray => TopLevelDef::Class { name: prim.name().into(), object_id: prim.id(), - type_vars: vec![self.ndarray_dtype_tvar.0, self.ndarray_ndims_tvar.0], + type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty], fields: Vec::default(), methods: vec![ Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0), @@ -777,7 +771,7 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.ndarray_copy_ty.0, - var_id: vec![self.ndarray_dtype_tvar.1, self.ndarray_ndims_tvar.1], + var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -794,7 +788,7 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.ndarray_fill_ty.0, - var_id: vec![self.ndarray_dtype_tvar.1, self.ndarray_ndims_tvar.1], + var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -831,10 +825,10 @@ impl<'a> BuiltinBuilder<'a> { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "n".into(), - ty: self.num_or_ndarray_ty.0, + ty: self.num_or_ndarray_ty.ty, default_value: None, }], - ret: self.num_or_ndarray_ty.0, + ret: self.num_or_ndarray_ty.ty, vars: self.num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), @@ -884,9 +878,9 @@ impl<'a> BuiltinBuilder<'a> { let int_sized = size_variant.of_int(self.primitives); let ndarray_int_sized = - make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.0)); + make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); let ndarray_float = - make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.0)); + make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); let p0_ty = self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); @@ -898,12 +892,10 @@ impl<'a> BuiltinBuilder<'a> { create_fn_by_codegen( self.unifier, - &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), + &to_var_map([common_ndim, p0_ty, ret_ty]), prim.name(), - ret_ty.0, - &[(p0_ty.0, "n")], + ret_ty.ty, + &[(p0_ty.ty, "n")], Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; @@ -946,12 +938,12 @@ impl<'a> BuiltinBuilder<'a> { ); let ndarray_float = - make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.0)); + make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); // The size variant of the function determines the type of int returned let int_sized = size_variant.of_int(self.primitives); let ndarray_int_sized = - make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.0)); + make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); let p0_ty = self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); @@ -964,12 +956,10 @@ impl<'a> BuiltinBuilder<'a> { create_fn_by_codegen( self.unifier, - &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), + &to_var_map([common_ndim, p0_ty, ret_ty]), prim.name(), - ret_ty.0, - &[(p0_ty.0, "n")], + ret_ty.ty, + &[(p0_ty.ty, "n")], Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; @@ -1031,7 +1021,7 @@ impl<'a> BuiltinBuilder<'a> { simple_name: prim.simple_name().into(), signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ - FuncArg { name: "object".into(), ty: tv.0, default_value: None }, + FuncArg { name: "object".into(), ty: tv.ty, default_value: None }, FuncArg { name: "copy".into(), ty: bool, @@ -1044,9 +1034,9 @@ impl<'a> BuiltinBuilder<'a> { }, ], ret: ndarray, - vars: VarMap::from([(tv.1, tv.0)]), + vars: to_var_map([tv]), })), - var_id: vec![tv.1], + var_id: vec![tv.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -1064,12 +1054,12 @@ impl<'a> BuiltinBuilder<'a> { create_fn_by_codegen( self.unifier, - &[(tv.1, tv.0)].into_iter().collect(), + &to_var_map([tv]), prim.name(), self.primitives.ndarray, // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a // type variable - &[(self.list_int32, "shape"), (tv.0, "fill_value")], + &[(self.list_int32, "shape"), (tv.ty, "fill_value")], Box::new(move |ctx, obj, fun, args, generator| { gen_ndarray_full(ctx, &obj, fun, &args, generator) .map(|val| Some(val.as_basic_value_enum())) @@ -1287,8 +1277,8 @@ impl<'a> BuiltinBuilder<'a> { self.unifier, &self.float_or_ndarray_var_map, prim.name(), - self.float_or_ndarray_ty.0, - &[(self.float_or_ndarray_ty.0, "n")], + self.float_or_ndarray_ty.ty, + &[(self.float_or_ndarray_ty.ty, "n")], Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; @@ -1311,8 +1301,8 @@ impl<'a> BuiltinBuilder<'a> { self.unifier, &self.float_or_ndarray_var_map, prim.name(), - self.float_or_ndarray_ty.0, - &[(self.float_or_ndarray_ty.0, "n")], + self.float_or_ndarray_ty.ty, + &[(self.float_or_ndarray_ty.ty, "n")], Box::new(|ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; @@ -1328,9 +1318,9 @@ impl<'a> BuiltinBuilder<'a> { let PrimitiveStore { uint64, int32, .. } = *self.primitives; let tvar = self.unifier.get_fresh_var(Some("L".into()), None); - let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.0 }); + let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty }); let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None); - let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.0), Some(ndims.0)); + let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.ty), Some(ndims.ty)); let arg_ty = self.unifier.get_fresh_var_with_range( &[list, ndarray, self.primitives.range], @@ -1341,9 +1331,9 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), simple_name: prim.simple_name().into(), signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], + args: vec![FuncArg { name: "ls".into(), ty: arg_ty.ty, default_value: None }], ret: int32, - vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), + vars: to_var_map([tvar, arg_ty]), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -1446,10 +1436,10 @@ impl<'a> BuiltinBuilder<'a> { simple_name: prim.simple_name().into(), signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ - FuncArg { name: "m".into(), ty: self.num_ty.0, default_value: None }, - FuncArg { name: "n".into(), ty: self.num_ty.0, default_value: None }, + FuncArg { name: "m".into(), ty: self.num_ty.ty, default_value: None }, + FuncArg { name: "n".into(), ty: self.num_ty.ty, default_value: None }, ], - ret: self.num_ty.0, + ret: self.num_ty.ty, vars: self.num_var_map.clone(), })), var_id: Vec::default(), @@ -1484,15 +1474,15 @@ impl<'a> BuiltinBuilder<'a> { .num_or_ndarray_var_map .clone() .into_iter() - .chain(once((ret_ty.1, ret_ty.0))) + .chain(once((ret_ty.id, ret_ty.ty))) .collect::>(); create_fn_by_codegen( self.unifier, &var_map, prim.name(), - ret_ty.0, - &[(self.float_or_ndarray_ty.0, "a")], + ret_ty.ty, + &[(self.float_or_ndarray_ty.ty, "a")], Box::new(move |ctx, _, fun, args, generator| { let a_ty = fun.0.args[0].ty; let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; @@ -1512,9 +1502,9 @@ impl<'a> BuiltinBuilder<'a> { fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef { debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]); - let x1_ty = self.new_type_or_ndarray_ty(self.num_ty.0); - let x2_ty = self.new_type_or_ndarray_ty(self.num_ty.0); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let x1_ty = self.new_type_or_ndarray_ty(self.num_ty.ty); + let x2_ty = self.new_type_or_ndarray_ty(self.num_ty.ty); + let param_ty = &[(x1_ty.ty, "x1"), (x2_ty.ty, "x2")]; let ret_ty = self.unifier.get_fresh_var(None, None); TopLevelDef::Function { @@ -1525,12 +1515,10 @@ impl<'a> BuiltinBuilder<'a> { .iter() .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), + ret: ret_ty.ty, + vars: to_var_map([x1_ty, x2_ty, ret_ty]), })), - var_id: vec![x1_ty.1, x2_ty.1], + var_id: vec![x1_ty.id, x2_ty.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -1564,10 +1552,10 @@ impl<'a> BuiltinBuilder<'a> { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "n".into(), - ty: self.num_or_ndarray_ty.0, + ty: self.num_or_ndarray_ty.ty, default_value: None, }], - ret: self.num_or_ndarray_ty.0, + ret: self.num_or_ndarray_ty.ty, vars: self.num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), @@ -1660,8 +1648,8 @@ impl<'a> BuiltinBuilder<'a> { self.unifier, &self.float_or_ndarray_var_map, prim.name(), - self.float_or_ndarray_ty.0, - &[(self.float_or_ndarray_ty.0, arg_name)], + self.float_or_ndarray_ty.ty, + &[(self.float_or_ndarray_ty.ty, arg_name)], Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; @@ -1745,7 +1733,7 @@ impl<'a> BuiltinBuilder<'a> { let x1_ty = self.new_type_or_ndarray_ty(x1_ty); let x2_ty = self.new_type_or_ndarray_ty(x2_ty); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let param_ty = &[(x1_ty.ty, "x1"), (x2_ty.ty, "x2")]; let ret_ty = self.unifier.get_fresh_var(None, None); TopLevelDef::Function { @@ -1756,12 +1744,10 @@ impl<'a> BuiltinBuilder<'a> { .iter() .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), + ret: ret_ty.ty, + vars: to_var_map([x1_ty, x2_ty, ret_ty]), })), - var_id: vec![ret_ty.1], + var_id: vec![ret_ty.id], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, @@ -1794,7 +1780,7 @@ impl<'a> BuiltinBuilder<'a> { (prim.simple_name().into(), method_ty, prim.id()) } - fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> (Type, u32) { + fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar { let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None); self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index ccedd82..ff0e82c 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -6,7 +6,7 @@ use crate::{ symbol_resolver::SymbolValue, typecheck::{ type_inferencer::{FunctionData, Inferencer}, - typedef::VarMap, + typedef::{TypeVar, VarMap}, }, }; @@ -225,7 +225,7 @@ impl TopLevelComposer { // since later when registering class method, ast will still be used, // here push None temporarily, later will move the ast inside - let constructor_ty = self.unifier.get_dummy_var().0; + let constructor_ty = self.unifier.get_dummy_var().ty; let mut class_def_ast = ( Arc::new(RwLock::new(Self::make_top_level_class_def( DefinitionId(class_def_id), @@ -281,7 +281,7 @@ impl TopLevelComposer { }; // dummy method define here - let dummy_method_type = self.unifier.get_dummy_var().0; + let dummy_method_type = self.unifier.get_dummy_var().ty; class_method_name_def_ids.push(( *method_name, RwLock::new(Self::make_top_level_function_def( @@ -337,7 +337,7 @@ impl TopLevelComposer { } let fun_name = *name; - let ty_to_be_unified = self.unifier.get_dummy_var().0; + let ty_to_be_unified = self.unifier.get_dummy_var().ty; // add to the definition list self.definition_ast_list.push(( RwLock::new(Self::make_top_level_function_def( @@ -452,7 +452,7 @@ impl TopLevelComposer { // check if all are unique type vars let all_unique_type_var = { - let mut occurred_type_var_id: HashSet = HashSet::new(); + let mut occurred_type_var_id: HashSet = HashSet::new(); type_vars.iter().all(|x| { let ty = unifier.get_ty(*x); if let TypeEnum::TVar { id, .. } = ty.as_ref() { @@ -917,18 +917,19 @@ impl TopLevelComposer { let type_vars_within = get_type_var_contained_in_type_annotation(&type_annotation) .into_iter() - .map(|x| -> Result<(u32, Type), HashSet> { + .map(|x| -> Result> { let TypeAnnotation::TypeVar(ty) = x else { unreachable!("must be type var annotation kind") }; - Ok((Self::get_var_id(ty, unifier)?, ty)) + let id = Self::get_var_id(ty, unifier)?; + Ok(TypeVar { id, ty }) }) .collect::, _>>()?; - for (id, ty) in type_vars_within { - if let Some(prev_ty) = function_var_map.insert(id, ty) { + for var in type_vars_within { + if let Some(prev_ty) = function_var_map.insert(var.id, var.ty) { // if already have the type inserted, make sure they are the same thing - assert_eq!(prev_ty, ty); + assert_eq!(prev_ty, var.ty); } } @@ -982,18 +983,19 @@ impl TopLevelComposer { let type_vars_within = get_type_var_contained_in_type_annotation(&return_ty_annotation) .into_iter() - .map(|x| -> Result<(u32, Type), HashSet> { + .map(|x| -> Result> { let TypeAnnotation::TypeVar(ty) = x else { unreachable!("must be type var here") }; - Ok((Self::get_var_id(ty, unifier)?, ty)) + let id = Self::get_var_id(ty, unifier)?; + Ok(TypeVar { id, ty }) }) .collect::, _>>()?; - for (id, ty) in type_vars_within { - if let Some(prev_ty) = function_var_map.insert(id, ty) { + for var in type_vars_within { + if let Some(prev_ty) = function_var_map.insert(var.id, var.ty) { // if already have the type inserted, make sure they are the same thing - assert_eq!(prev_ty, ty); + assert_eq!(prev_ty, var.ty); } } @@ -1177,7 +1179,7 @@ impl TopLevelComposer { // finish handling type vars let dummy_func_arg = FuncArg { name, - ty: unifier.get_dummy_var().0, + ty: unifier.get_dummy_var().ty, default_value: match default { None => None, Some(default) => { @@ -1240,13 +1242,13 @@ impl TopLevelComposer { assert_eq!(prev_ty, ty); } } - let dummy_return_type = unifier.get_dummy_var().0; + let dummy_return_type = unifier.get_dummy_var().ty; type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); dummy_return_type } else { // if do not have return annotation, return none // for uniform handling, still use type annotation - let dummy_return_type = unifier.get_dummy_var().0; + let dummy_return_type = unifier.get_dummy_var().ty; type_var_to_concrete_def.insert( dummy_return_type, TypeAnnotation::Primitive(primitives.none), @@ -1286,7 +1288,7 @@ impl TopLevelComposer { ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => { if let ast::ExprKind::Name { id: attr, .. } = &target.node { if defined_fields.insert(attr.to_string()) { - let dummy_field_type = unifier.get_dummy_var().0; + let dummy_field_type = unifier.get_dummy_var().ty; // handle Kernel[T], KernelInvariant[T] let (annotation, mutable) = match &annotation.node { @@ -1749,7 +1751,7 @@ impl TopLevelComposer { unreachable!() }; - let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; + let rigid = unifier.get_fresh_rigid_var(*name, *loc).ty; no_ranges.push(rigid); vec![rigid] }) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 154740b..8cc77a9 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -2,7 +2,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; -use crate::typecheck::typedef::{Mapping, VarMap}; +use crate::typecheck::typedef::{to_var_map, Mapping, TypeVarId, VarMap}; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -377,12 +377,12 @@ impl TopLevelComposer { let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], ret: bool, - vars: VarMap::from([(option_type_var.1, option_type_var.0)]), + vars: to_var_map([option_type_var]), })); let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], - ret: option_type_var.0, - vars: VarMap::from([(option_type_var.1, option_type_var.0)]), + ret: option_type_var.ty, + vars: to_var_map([option_type_var]), })); let option = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Option.id(), @@ -393,7 +393,7 @@ impl TopLevelComposer { ] .into_iter() .collect::>(), - params: VarMap::from([(option_type_var.1, option_type_var.0)]), + params: to_var_map([option_type_var]), }); let size_t_ty = match size_t { @@ -408,23 +408,17 @@ impl TopLevelComposer { let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None); let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], - ret: ndarray_copy_fun_ret_ty.0, - vars: VarMap::from([ - (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), - (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), - ]), + ret: ndarray_copy_fun_ret_ty.ty, + vars: to_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), })); let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "value".into(), - ty: ndarray_dtype_tvar.0, + ty: ndarray_dtype_tvar.ty, default_value: None, }], ret: none, - vars: VarMap::from([ - (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), - (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), - ]), + vars: to_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), })); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::NDArray.id(), @@ -432,13 +426,10 @@ impl TopLevelComposer { (PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)), (PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)), ]), - params: VarMap::from([ - (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), - (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), - ]), + params: to_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), }); - unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap(); + unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap(); let primitives = PrimitiveStore { int32, @@ -583,7 +574,7 @@ impl TopLevelComposer { } /// get the `var_id` of a given `TVar` type - pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result> { + pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result> { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { Ok(*id) } else { diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index f7fa92b..21e2fac 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -14,7 +14,10 @@ use super::typecheck::typedef::{ use crate::{ codegen::CodeGenerator, symbol_resolver::{SymbolResolver, ValueEnum}, - typecheck::{type_inferencer::CodeLocation, typedef::CallId}, + typecheck::{ + type_inferencer::CodeLocation, + typedef::{CallId, TypeVarId}, + }, }; use inkwell::values::BasicValueEnum; use itertools::{izip, Itertools}; @@ -119,7 +122,7 @@ pub enum TopLevelDef { /// Function signature. signature: Type, /// Instantiated type variable IDs. - var_id: Vec, + var_id: Vec, /// Function instance to symbol mapping /// /// * Key: String representation of type variable values, sorted by variable ID in ascending diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index b6e0ca5..63f6173 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -2,7 +2,7 @@ use crate::{ toplevel::helper::PrimDef, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{Type, TypeEnum, Unifier, VarMap}, + typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap}, }, }; use itertools::Itertools; @@ -57,7 +57,7 @@ pub fn subst_ndarray_tvars( unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) } -fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(u32, Type)> { +fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; @@ -74,7 +74,7 @@ fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(u32, Type) /// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds /// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` /// respectively. -pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (u32, u32) { +pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) { unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap() } diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index fb41467..33e4433 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [239]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(239)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 75b4d4b..8f3d9bf 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [241]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [246]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(241)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(246)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index ba6b236..d58f5f1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [247]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [255]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n", ] diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index b9514da..dc4251e 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -782,11 +782,11 @@ fn make_internal_resolver_with_tvar( .into_iter() .map(|(name, range)| { (name, { - let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice(), None, None); + let tvar = unifier.get_fresh_var_with_range(range.as_slice(), None, None); if print { - println!("{}: {:?}, typevar{}", name, ty, id); + println!("{}: {:?}, typevar{}", name, tvar.ty, tvar.id); } - ty + tvar.ty }) }) .collect::>() diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index e7fe0c7..7b5c405 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -123,7 +123,7 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) } else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { - let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).0; + let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).ty; unifier.unify(var, ty).unwrap(); Ok(TypeAnnotation::TypeVar(ty)) } else { @@ -426,7 +426,7 @@ pub fn get_type_from_type_annotation_kinds( *name, *loc, ); - unifier.unify(temp.0, p).is_ok() + unifier.unify(temp.ty, p).is_ok() } }; if ok { @@ -451,7 +451,7 @@ pub fn get_type_from_type_annotation_kinds( // create a temp type var and unify to check compatibility p == *tvar || { let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc); - unifier.unify(temp.0, p).is_ok() + unifier.unify(temp.ty, p).is_ok() } }; if ok { diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index c1e5ac6..f2b995e 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -103,8 +103,8 @@ pub fn impl_binop( let (other_ty, other_var_id) = if other_ty.len() == 1 { (other_ty[0], None) } else { - let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); - (ty, Some(var_id)) + let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + (tvar.ty, Some(tvar.id)) }; let function_vars = if let Some(var_id) = other_var_id { @@ -113,7 +113,7 @@ pub fn impl_binop( VarMap::new() }; - let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); + let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for op in ops { fields.insert(binop_name(*op).into(), { @@ -151,7 +151,7 @@ pub fn impl_binop( pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option, ops: &[Unaryop]) { with_fields(unifier, ty, |unifier, fields| { - let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); + let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for op in ops { fields.insert( @@ -181,8 +181,8 @@ pub fn impl_cmpop( let (other_ty, other_var_id) = if other_ty.len() == 1 { (other_ty[0], None) } else { - let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); - (ty, Some(var_id)) + let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + (tvar.ty, Some(tvar.id)) }; let function_vars = if let Some(var_id) = other_var_id { @@ -191,7 +191,7 @@ pub fn impl_cmpop( VarMap::new() }; - let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); + let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for op in ops { fields.insert( @@ -652,7 +652,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); let ndarray_unsized_t = - make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0)); + make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty)); let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); impl_basic_arithmetic( diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index c3773b7..251b91e 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -432,7 +432,7 @@ impl<'a> Fold<()> for Inferencer<'a> { let enter = TypeEnum::TFunc(FunSignature { args: vec![], ret: item.optional_vars.as_ref().map_or_else( - || self.unifier.get_dummy_var().0, + || self.unifier.get_dummy_var().ty, |var| var.custom.unwrap(), ), vars: VarMap::default(), @@ -440,7 +440,7 @@ impl<'a> Fold<()> for Inferencer<'a> { let enter = self.unifier.add_ty(enter); let exit = TypeEnum::TFunc(FunSignature { args: vec![], - ret: self.unifier.get_dummy_var().0, + ret: self.unifier.get_dummy_var().ty, vars: VarMap::default(), }); let exit = self.unifier.add_ty(exit); @@ -511,7 +511,7 @@ impl<'a> Fold<()> for Inferencer<'a> { }; assert_eq!(*id, *id_var); - (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).0) + (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty) }) .collect::(); Some(self.unifier.subst(self.primitives.option, &var_map).unwrap()) @@ -660,7 +660,7 @@ impl<'a> Inferencer<'a> { } } } - let ret = ret.unwrap_or_else(|| self.unifier.get_dummy_var().0); + let ret = ret.unwrap_or_else(|| self.unifier.get_dummy_var().ty); let call = self.unifier.add_call(Call { posargs: params, @@ -706,11 +706,13 @@ impl<'a> Inferencer<'a> { let fn_args: Vec<_> = args .args .iter() - .map(|v| (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).0)) + .map(|v| { + (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).ty) + }) .collect(); let mut variable_mapping = self.variable_mapping.clone(); variable_mapping.extend(fn_args.iter().copied()); - let ret = self.unifier.get_dummy_var().0; + let ret = self.unifier.get_dummy_var().ty; let mut new_context = Inferencer { function_data: self.function_data, @@ -849,7 +851,7 @@ impl<'a> Inferencer<'a> { &arg, )? } else { - self.unifier.get_dummy_var().0 + self.unifier.get_dummy_var().ty }; self.virtual_checks.push((arg0.custom.unwrap(), ty, *func_location)); let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); @@ -1362,7 +1364,7 @@ impl<'a> Inferencer<'a> { } } - let ret = self.unifier.get_dummy_var().0; + let ret = self.unifier.get_dummy_var().ty; let call = self.unifier.add_call(Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), kwargs: keywords @@ -1391,7 +1393,7 @@ impl<'a> Inferencer<'a> { .resolver .get_symbol_type(unifier, &self.top_level.definitions.read(), self.primitives, id) .unwrap_or_else(|_| { - let ty = unifier.get_dummy_var().0; + let ty = unifier.get_dummy_var().ty; variable_mapping.insert(id, ty); ty }) @@ -1420,13 +1422,13 @@ impl<'a> Inferencer<'a> { ast::Constant::None => { report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc) } - ast::Constant::Ellipsis => Ok(self.unifier.get_fresh_var(None, None).0), + ast::Constant::Ellipsis => Ok(self.unifier.get_fresh_var(None, None).ty), _ => report_error("not supported", *loc), } } fn infer_list(&mut self, elts: &[ast::Expr>]) -> InferenceResult { - let ty = self.unifier.get_dummy_var().0; + let ty = self.unifier.get_dummy_var().ty; for t in elts { self.unify(ty, t.custom.unwrap(), &t.location)?; } @@ -1462,7 +1464,7 @@ impl<'a> Inferencer<'a> { } } } else { - let attr_ty = self.unifier.get_dummy_var().0; + let attr_ty = self.unifier.get_dummy_var().ty; let fields = once(( attr.into(), RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)), @@ -1655,7 +1657,7 @@ impl<'a> Inferencer<'a> { slice: &ast::Expr>, ctx: ExprContext, ) -> InferenceResult { - let ty = self.unifier.get_dummy_var().0; + let ty = self.unifier.get_dummy_var().ty; match &slice.node { ExprKind::Slice { lower, upper, step } => { for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { @@ -1759,7 +1761,7 @@ impl<'a> Inferencer<'a> { let valid_index_ty = self .unifier .get_fresh_var_with_range(valid_index_tys.as_slice(), None, None) - .0; + .ty; self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?; self.infer_subscript_ndarray(value, ty, ndims) } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index f9fec50..45c9788 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -143,10 +143,7 @@ impl TestEnvironment { let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::NDArray.id(), fields: HashMap::new(), - params: VarMap::from([ - (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), - (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), - ]), + params: to_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), }); let primitives = PrimitiveStore { int32, @@ -321,19 +318,19 @@ impl TestEnvironment { unifier.put_primitive_store(&primitives); - let (v0, id) = unifier.get_dummy_var(); + let tvar = unifier.get_dummy_var(); let foo_ty = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(defs + 1), - fields: [("a".into(), (v0, true))].iter().cloned().collect::>(), - params: [(id, v0)].iter().cloned().collect::(), + fields: [("a".into(), (tvar.ty, true))].iter().cloned().collect::>(), + params: to_var_map([tvar]), }); top_level_defs.push( RwLock::new(TopLevelDef::Class { name: "Foo".into(), object_id: DefinitionId(defs + 1), - type_vars: vec![v0], - fields: [("a".into(), v0, true)].into(), + type_vars: vec![tvar.ty], + fields: [("a".into(), tvar.ty, true)].into(), methods: Default::default(), ancestors: Default::default(), resolver: None, @@ -348,7 +345,7 @@ impl TestEnvironment { unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], ret: foo_ty, - vars: [(id, v0)].iter().cloned().collect(), + vars: to_var_map([tvar]), })), ); diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 13d9605..855074b 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -2,7 +2,7 @@ use indexmap::IndexMap; use itertools::Itertools; use std::cell::RefCell; use std::collections::HashMap; -use std::fmt::Display; +use std::fmt::{self, Display}; use std::iter::zip; use std::rc::Rc; use std::sync::{Arc, Mutex}; @@ -28,8 +28,44 @@ pub struct CallId(pub(super) usize); pub type Mapping = HashMap; pub type IndexMapping = IndexMap; -/// The mapping between type variable ID and [unifier type][`Type`]. -pub type VarMap = IndexMapping; +/// ID of a Python type variable. Specific to `nac3core`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct TypeVarId(pub u32); + +impl fmt::Display for TypeVarId { + // NOTE: Must output the string fo the ID value. Certain unit tests rely on string comparisons. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_fmt(format_args!("{}", self.0)) + } +} + +/// A Python type variable. Used by `nac3core` during type inference. +#[derive(Debug, Clone, Copy)] +pub struct TypeVar { + /// `nac3core`'s internal [`TypeVarId`] of this type variable. + pub id: TypeVarId, + + /// The assigned [`Type`] of this Python type variable. + pub ty: Type, +} + +/// The mapping between [`TypeVarId`] and [unifier type][`Type`]. +pub type VarMap = IndexMapping; + +/// Build a [`VarMap`] from an iterator of [`TypeVar`] +/// +/// The resulting [`VarMap`] wil have the same order as the input iterator. +pub fn to_var_map(vars: I) -> VarMap +where + I: IntoIterator, +{ + vars.into_iter().map(|var| (var.id, var.ty)).collect() +} + +/// Get an iterator of [`TypeVar`]s from a [`VarMap`] +pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator + '_ { + var_map.iter().map(|(&id, &ty)| TypeVar { id, ty }) +} #[derive(Clone)] pub struct Call { @@ -127,14 +163,14 @@ impl RecordField { #[derive(Clone)] pub enum TypeEnum { TRigidVar { - id: u32, + id: TypeVarId, name: Option, loc: Option, }, /// A type variable. TVar { - id: u32, + id: TypeVarId, // empty indicates this is not a struct/tuple/list fields: Option>, // empty indicates no restriction @@ -295,7 +331,7 @@ impl Unifier { } pub fn add_record(&mut self, fields: Mapping) -> Type { - let id = self.var_id + 1; + let id = TypeVarId(self.var_id + 1); self.var_id += 1; self.add_ty(TypeEnum::TVar { id, @@ -346,24 +382,21 @@ impl Unifier { self.unification_table.probe_value_immutable(a).clone() } - pub fn get_fresh_rigid_var( - &mut self, - name: Option, - loc: Option, - ) -> (Type, u32) { - let id = self.var_id + 1; + pub fn get_fresh_rigid_var(&mut self, name: Option, loc: Option) -> TypeVar { + let id = TypeVarId(self.var_id + 1); self.var_id += 1; - (self.add_ty(TypeEnum::TRigidVar { id, name, loc }), id) + let ty = self.add_ty(TypeEnum::TRigidVar { id, name, loc }); + TypeVar { id, ty } } - pub fn get_dummy_var(&mut self) -> (Type, u32) { + pub fn get_dummy_var(&mut self) -> TypeVar { self.get_fresh_var_with_range(&[], None, None) } /// Returns a fresh [type variable][TypeEnum::TVar] with no associated range. /// /// This type variable can be instantiated by any type. - pub fn get_fresh_var(&mut self, name: Option, loc: Option) -> (Type, u32) { + pub fn get_fresh_var(&mut self, name: Option, loc: Option) -> TypeVar { self.get_fresh_var_with_range(&[], name, loc) } @@ -375,21 +408,20 @@ impl Unifier { range: &[Type], name: Option, loc: Option, - ) -> (Type, u32) { - let id = self.var_id + 1; - self.var_id += 1; + ) -> TypeVar { let range = range.to_vec(); - ( - self.add_ty(TypeEnum::TVar { - id, - range, - fields: None, - name, - loc, - is_const_generic: false, - }), + + let id = TypeVarId(self.var_id + 1); + self.var_id += 1; + let ty = self.add_ty(TypeEnum::TVar { id, - ) + range, + fields: None, + name, + loc, + is_const_generic: false, + }); + TypeVar { id, ty } } /// Returns a fresh type representing a constant generic variable with the given underlying type `ty`. @@ -398,20 +430,18 @@ impl Unifier { ty: Type, name: Option, loc: Option, - ) -> (Type, u32) { - let id = self.var_id + 1; + ) -> TypeVar { + let id = TypeVarId(self.var_id + 1); self.var_id += 1; - ( - self.add_ty(TypeEnum::TVar { - id, - range: vec![ty], - fields: None, - name, - loc, - is_const_generic: true, - }), + let ty = self.add_ty(TypeEnum::TVar { id, - ) + range: vec![ty], + fields: None, + name, + loc, + is_const_generic: true, + }); + TypeVar { id, ty } } /// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`. @@ -464,7 +494,7 @@ impl Unifier { } } TypeEnum::TObj { params, .. } => { - let (keys, params): (Vec, Vec) = params.iter().unzip(); + let (keys, params): (Vec, Vec) = params.iter().unzip(); let params = params .into_iter() .map(|ty| self.get_instantiations(ty).unwrap_or_else(|| vec![ty])) @@ -1014,7 +1044,7 @@ impl Unifier { pub fn stringify_with_notes( &self, ty: Type, - notes: &mut Option>, + notes: &mut Option>, ) -> String { let top_level = self.top_level.clone(); self.internal_stringify( @@ -1043,11 +1073,11 @@ impl Unifier { ty: Type, obj_to_name: &mut F, var_to_name: &mut G, - notes: &mut Option>, + notes: &mut Option>, ) -> String where F: FnMut(usize) -> String, - G: FnMut(u32) -> String, + G: FnMut(TypeVarId) -> String, { let ty = self.unification_table.probe_value_immutable(ty).clone(); match ty.as_ref() { @@ -1182,7 +1212,7 @@ impl Unifier { let mapping = vars .into_iter() .map(|(k, range, name, loc)| { - (k, self.get_fresh_var_with_range(range.as_ref(), name, loc).0) + (k, self.get_fresh_var_with_range(range.as_ref(), name, loc).ty) }) .collect(); self.subst(ty, &mapping).unwrap_or(ty) @@ -1206,7 +1236,7 @@ impl Unifier { let cached = cache.get_mut(&a); if let Some(cached) = cached { if cached.is_none() { - *cached = Some(self.get_fresh_var(None, None).0); + *cached = Some(self.get_fresh_var(None, None).ty); } return *cached; } @@ -1361,7 +1391,7 @@ impl Unifier { if range.is_empty() { Err(()) } else { - let id = self.var_id + 1; + let id = TypeVarId(self.var_id + 1); self.var_id += 1; let ty = TVar { id, diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 1c6267e..6867819 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -110,13 +110,13 @@ impl TestEnvironment { params: VarMap::new(), }), ); - let (v0, id) = unifier.get_dummy_var(); + let tvar = unifier.get_dummy_var(); type_mapping.insert( "Foo".into(), unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(3), - fields: [("a".into(), (v0, true))].iter().cloned().collect::>(), - params: [(id, v0)].iter().cloned().collect::(), + fields: [("a".into(), (tvar.ty, true))].iter().cloned().collect::>(), + params: to_var_map([tvar]), }), ); @@ -250,7 +250,7 @@ fn test_unify( let mut mapping = HashMap::new(); for i in 1..=variable_count { let v = env.unifier.get_dummy_var(); - mapping.insert(format!("v{}", i), v.0); + mapping.insert(format!("v{}", i), v.ty); } // unification may have side effect when we do type resolution, so freeze the types // before doing unification. @@ -315,7 +315,7 @@ fn test_invalid_unification( let mut mapping = HashMap::new(); for i in 1..=variable_count { let v = env.unifier.get_dummy_var(); - mapping.insert(format!("v{}", i), v.0); + mapping.insert(format!("v{}", i), v.ty); } // unification may have side effect when we do type resolution, so freeze the types // before doing unification. @@ -369,8 +369,8 @@ fn test_virtual() { .collect::>(), params: VarMap::new(), }); - let v0 = env.unifier.get_dummy_var().0; - let v1 = env.unifier.get_dummy_var().0; + let v0 = env.unifier.get_dummy_var().ty; + let v1 = env.unifier.get_dummy_var().ty; let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); @@ -403,12 +403,12 @@ fn test_typevar_range() { // unification between v and int // where v in (int, bool) - let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; + let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty; env.unifier.unify(int, v).unwrap(); // unification between v and list[int] // where v in (int, bool) - let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; + let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty; assert_eq!( env.unify(int_list, v), Err("Expected any one of these types: 0, 2, but got list[0]".to_string()) @@ -416,25 +416,25 @@ fn test_typevar_range() { // unification between v and float // where v in (int, bool) - let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; + let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty; assert_eq!( env.unify(float, v), Err("Expected any one of these types: 0, 2, but got 1".to_string()) ); - let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; + let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty; let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 }); - let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty; // unification between v and int // where v in (int, list[v1]), v1 in (int, bool) env.unifier.unify(int, v).unwrap(); - let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty; // unification between v and list[int] // where v in (int, list[v1]), v1 in (int, bool) env.unifier.unify(int_list, v).unwrap(); - let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; + let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty; // unification between v and list[float] // where v in (int, list[v1]), v1 in (int, bool) assert_eq!( @@ -442,30 +442,30 @@ fn test_typevar_range() { Err("Expected any one of these types: 0, list[typevar5], but got list[1]\n\nNotes:\n typevar5 ∈ {0, 2}".to_string()) ); - let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; - let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty; env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, float).unwrap(); - let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; - let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty; env.unifier.unify(a, b).unwrap(); assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into())); - let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; - let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty; let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); - let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; + let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty; let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); - let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).0; + let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).ty; env.unifier.unify(a_list, b_list).unwrap(); let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float }); env.unifier.unify(a_list, float_list).unwrap(); // previous unifications should not affect a and b env.unifier.unify(a, int).unwrap(); - let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; - let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty; + let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty; let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); env.unifier.unify(a_list, b_list).unwrap(); @@ -477,10 +477,10 @@ fn test_typevar_range() { .into()) ); - let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; - let b = env.unifier.get_dummy_var().0; + let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty; + let b = env.unifier.get_dummy_var().ty; let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); - let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; + let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty; let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); env.unifier.unify(a_list, b_list).unwrap(); assert_eq!( @@ -492,9 +492,9 @@ fn test_typevar_range() { #[test] fn test_rigid_var() { let mut env = TestEnvironment::new(); - let a = env.unifier.get_fresh_rigid_var(None, None).0; - let b = env.unifier.get_fresh_rigid_var(None, None).0; - let x = env.unifier.get_dummy_var().0; + let a = env.unifier.get_fresh_rigid_var(None, None).ty; + let b = env.unifier.get_fresh_rigid_var(None, None).ty; + let x = env.unifier.get_dummy_var().ty; let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a }); let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x }); let int = env.parse("int", &HashMap::new()); @@ -522,13 +522,13 @@ fn test_instantiation() { let obj_map: HashMap<_, _> = [(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect(); - let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; + let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty; let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v }); - let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).0; - let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).0; - let t = env.unifier.get_dummy_var().0; + let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty; + let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty; + let t = env.unifier.get_dummy_var().ty; let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] }); - let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).0; + let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty; // t = TypeVar('t') // v = TypeVar('v', int, bool) // v1 = TypeVar('v1', 'list[v]', int) diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 128c8d1..506b62a 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -124,7 +124,7 @@ fn handle_typevar_definition( )])); } - Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0) + Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).ty) } ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => { @@ -155,7 +155,7 @@ fn handle_typevar_definition( get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?; let loc = func.location; - Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0) + Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty) } _ => Err(HashSet::from([format!(