core: refactor to use `TypeVarId` and `TypeVar`

This commit is contained in:
lyken 2024-06-13 13:28:39 +08:00
parent dc874f2994
commit f026b48e2a
19 changed files with 309 additions and 302 deletions

View File

@ -9,7 +9,7 @@ use nac3core::{
}, },
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap}, typedef::{iter_type_vars, to_var_map, Type, TypeEnum, TypeVar, Unifier, VarMap},
}, },
}; };
use nac3parser::ast::{self, StrRef}; use nac3parser::ast::{self, StrRef};
@ -317,13 +317,13 @@ impl InnerResolver {
Ok(Ok((primitives.exception, true))) Ok(Ok((primitives.exception, true)))
} else if ty_id == self.primitive_ids.list { } else if ty_id == self.primitive_ids.list {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
let var = unifier.get_dummy_var().0; let var = unifier.get_dummy_var().ty;
let list = unifier.add_ty(TypeEnum::TList { ty: var }); let list = unifier.add_ty(TypeEnum::TList { ty: var });
Ok(Ok((list, false))) Ok(Ok((list, false)))
} else if ty_id == self.primitive_ids.ndarray { } else if ty_id == self.primitive_ids.ndarray {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
let var = unifier.get_dummy_var().0; let var = unifier.get_dummy_var().ty;
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).0; let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty;
let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims)); let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims));
Ok(Ok((ndarray, false))) Ok(Ok((ndarray, false)))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
@ -383,7 +383,7 @@ impl InnerResolver {
} }
if !is_const_generic && needs_defer { if !is_const_generic && needs_defer {
result.push(unifier.get_dummy_var().0); result.push(unifier.get_dummy_var().ty);
} else { } else {
result.push({ result.push({
match self.get_pyty_obj_type(py, constr, unifier, defs, primitives)? { match self.get_pyty_obj_type(py, constr, unifier, defs, primitives)? {
@ -426,9 +426,9 @@ impl InnerResolver {
))); )));
} }
unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).0 unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).ty
} else { } else {
unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0 unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).ty
}; };
Ok(Ok((res, true))) Ok(Ok((res, true)))
@ -568,7 +568,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.virtual_id { } else if ty_id == self.primitive_ids.virtual_id {
Ok(Ok(( Ok(Ok((
{ {
let ty = TypeEnum::TVirtual { ty: unifier.get_dummy_var().0 }; let ty = TypeEnum::TVirtual { ty: unifier.get_dummy_var().ty };
unifier.add_ty(ty) unifier.add_ty(ty)
}, },
false, false,
@ -719,18 +719,16 @@ impl InnerResolver {
unreachable!("must be tobj") unreachable!("must be tobj")
}; };
let var_map = params let var_map = to_var_map(iter_type_vars(params).map(|tvar| {
.iter() let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty)
.map(|(id_var, ty)| {
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty)
else { else {
unreachable!() unreachable!()
}; };
assert_eq!(*id, *id_var); assert_eq!(*id, tvar.id);
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty;
}) TypeVar { id: *id, ty }
.collect::<VarMap>(); }));
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())); return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()));
} }
@ -748,18 +746,16 @@ impl InnerResolver {
} }
(TypeEnum::TObj { params, fields, .. }, false) => { (TypeEnum::TObj { params, fields, .. }, false) => {
self.pyid_to_type.write().insert(py_obj_id, extracted_ty); self.pyid_to_type.write().insert(py_obj_id, extracted_ty);
let var_map = params let var_map = to_var_map(iter_type_vars(params).map(|tvar| {
.iter() let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty)
.map(|(id_var, ty)| {
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty)
else { else {
unreachable!() unreachable!()
}; };
assert_eq!(*id, *id_var); assert_eq!(*id, tvar.id);
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty;
}) TypeVar { id: *id, ty }
.collect::<VarMap>(); }));
let mut instantiate_obj = || { let mut instantiate_obj = || {
// loop through non-function fields of the class to get the instantiated value // loop through non-function fields of the class to get the instantiated value
for field in fields { for field in fields {

View File

@ -3,7 +3,7 @@ use crate::{
toplevel::DefinitionId, toplevel::DefinitionId,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{to_var_map, FunSignature, FuncArg, Type, TypeEnum, TypeVar, TypeVarId, Unifier},
}, },
}; };
@ -51,7 +51,7 @@ pub enum ConcreteTypeEnum {
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
fields: HashMap<StrRef, (ConcreteType, bool)>, fields: HashMap<StrRef, (ConcreteType, bool)>,
params: IndexMap<u32, ConcreteType>, params: IndexMap<TypeVarId, ConcreteType>,
}, },
TVirtual { TVirtual {
ty: ConcreteType, ty: ConcreteType,
@ -59,7 +59,7 @@ pub enum ConcreteTypeEnum {
TFunc { TFunc {
args: Vec<ConcreteFuncArg>, args: Vec<ConcreteFuncArg>,
ret: ConcreteType, ret: ConcreteType,
vars: HashMap<u32, ConcreteType>, vars: HashMap<TypeVarId, ConcreteType>,
}, },
TLiteral { TLiteral {
values: Vec<SymbolValue>, values: Vec<SymbolValue>,
@ -230,7 +230,7 @@ impl ConcreteTypeStore {
return if let Some(ty) = ty { return if let Some(ty) = ty {
*ty *ty
} else { } else {
*ty = Some(unifier.get_dummy_var().0); *ty = Some(unifier.get_dummy_var().ty);
ty.unwrap() ty.unwrap()
}; };
} }
@ -272,10 +272,10 @@ impl ConcreteTypeStore {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
}) })
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
params: params params: to_var_map(params.iter().map(|(&id, cty)| {
.iter() let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) TypeVar { id, ty }
.collect::<VarMap>(), })),
}, },
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args args: args
@ -287,10 +287,10 @@ impl ConcreteTypeStore {
}) })
.collect(), .collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache), ret: self.to_unifier_type(unifier, primitives, *ret, cache),
vars: vars vars: to_var_map(vars.iter().map(|(&id, cty)| {
.iter() let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) TypeVar { id, ty }
.collect::<VarMap>(), })),
}), }),
ConcreteTypeEnum::TLiteral { values, .. } => { ConcreteTypeEnum::TLiteral { values, .. } => {
TypeEnum::TLiteral { values: values.clone(), loc: None } TypeEnum::TLiteral { values: values.clone(), loc: None }

View File

@ -22,7 +22,7 @@ use crate::{
}, },
typecheck::{ typecheck::{
magic_methods::{binop_assign_name, binop_name, unaryop_name}, magic_methods::{binop_assign_name, binop_name, unaryop_name},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
}, },
}; };
use inkwell::{ use inkwell::{
@ -42,7 +42,7 @@ pub fn get_subst_key(
unifier: &mut Unifier, unifier: &mut Unifier,
obj: Option<Type>, obj: Option<Type>,
fun_vars: &VarMap, fun_vars: &VarMap,
filter: Option<&Vec<u32>>, filter: Option<&Vec<TypeVarId>>,
) -> String { ) -> String {
let mut vars = obj let mut vars = obj
.map(|ty| { .map(|ty| {
@ -81,7 +81,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
&mut self, &mut self,
obj: Option<Type>, obj: Option<Type>,
fun: &FunSignature, fun: &FunSignature,
filter: Option<&Vec<u32>>, filter: Option<&Vec<TypeVarId>>,
) -> String { ) -> String {
get_subst_key(&mut self.unifier, obj, &fun.vars, filter) get_subst_key(&mut self.unifier, obj, &fun.vars, filter)
} }

View File

@ -25,7 +25,7 @@ use crate::{
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
typecheck::typedef::VarMap, typecheck::typedef::{iter_type_vars, to_var_map, TypeVar, VarMap},
}; };
use super::*; use super::*;
@ -307,26 +307,26 @@ struct BuiltinBuilder<'a> {
is_some_ty: (Type, bool), is_some_ty: (Type, bool),
unwrap_ty: (Type, bool), unwrap_ty: (Type, bool),
option_tvar: (Type, u32), option_tvar: TypeVar,
ndarray_dtype_tvar: (Type, u32), ndarray_dtype_tvar: TypeVar,
ndarray_ndims_tvar: (Type, u32), ndarray_ndims_tvar: TypeVar,
ndarray_copy_ty: (Type, bool), ndarray_copy_ty: (Type, bool),
ndarray_fill_ty: (Type, bool), ndarray_fill_ty: (Type, bool),
list_int32: Type, list_int32: Type,
num_ty: (Type, u32), num_ty: TypeVar,
num_var_map: VarMap, num_var_map: VarMap,
ndarray_float: Type, ndarray_float: Type,
ndarray_float_2d: Type, ndarray_float_2d: Type,
ndarray_num_ty: Type, ndarray_num_ty: Type,
float_or_ndarray_ty: (Type, u32), float_or_ndarray_ty: TypeVar,
float_or_ndarray_var_map: VarMap, float_or_ndarray_var_map: VarMap,
num_or_ndarray_ty: (Type, u32), num_or_ndarray_ty: TypeVar,
num_or_ndarray_var_map: VarMap, num_or_ndarray_var_map: VarMap,
} }
@ -350,7 +350,7 @@ impl<'a> BuiltinBuilder<'a> {
( (
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
(*params.iter().next().unwrap().1, *params.iter().next().unwrap().0), iter_type_vars(params).next().unwrap(),
) )
} else { } else {
unreachable!() unreachable!()
@ -361,10 +361,8 @@ impl<'a> BuiltinBuilder<'a> {
else { else {
unreachable!() unreachable!()
}; };
let ndarray_dtype_tvar = let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
ndarray_params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(); let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
let ndarray_ndims_tvar =
ndarray_params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap();
let ndarray_copy_ty = let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty = let ndarray_fill_ty =
@ -375,7 +373,7 @@ impl<'a> BuiltinBuilder<'a> {
Some("N".into()), Some("N".into()),
None, None,
); );
let num_var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let num_var_map = to_var_map([num_ty]);
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None);
let ndarray_float_2d = { let ndarray_float_2d = {
@ -389,18 +387,14 @@ impl<'a> BuiltinBuilder<'a> {
make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) make_ndarray_ty(unifier, primitives, Some(float), Some(ndims))
}; };
let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None); let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.ty), None);
let float_or_ndarray_ty = let float_or_ndarray_ty =
unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
let float_or_ndarray_var_map: VarMap = let float_or_ndarray_var_map = to_var_map([float_or_ndarray_ty]);
vec![(float_or_ndarray_ty.1, float_or_ndarray_ty.0)].into_iter().collect();
let num_or_ndarray_ty = let num_or_ndarray_ty =
unifier.get_fresh_var_with_range(&[num_ty.0, ndarray_num_ty], Some("T".into()), None); unifier.get_fresh_var_with_range(&[num_ty.ty, ndarray_num_ty], Some("T".into()), None);
let num_or_ndarray_var_map: VarMap = let num_or_ndarray_var_map = to_var_map([num_ty, num_or_ndarray_ty]);
vec![(num_ty.1, num_ty.0), (num_or_ndarray_ty.1, num_or_ndarray_ty.0)]
.into_iter()
.collect();
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
@ -648,7 +642,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Option => TopLevelDef::Class { PrimDef::Option => TopLevelDef::Class {
name: prim.name().into(), name: prim.name().into(),
object_id: prim.id(), object_id: prim.id(),
type_vars: vec![self.option_tvar.0], type_vars: vec![self.option_tvar.ty],
fields: vec![], fields: vec![],
methods: vec![ methods: vec![
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0), Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
@ -668,7 +662,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0, signature: self.unwrap_ty.0,
var_id: vec![self.option_tvar.1], var_id: vec![self.option_tvar.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -682,7 +676,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().to_string(), name: prim.name().to_string(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0, signature: self.is_some_ty.0,
var_id: vec![self.option_tvar.1], var_id: vec![self.option_tvar.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -719,13 +713,13 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg {
name: "n".into(), name: "n".into(),
ty: self.option_tvar.0, ty: self.option_tvar.ty,
default_value: None, default_value: None,
}], }],
ret: self.primitives.option, ret: self.primitives.option,
vars: VarMap::from([(self.option_tvar.1, self.option_tvar.0)]), vars: to_var_map([self.option_tvar]),
})), })),
var_id: vec![self.option_tvar.1], var_id: vec![self.option_tvar.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -761,7 +755,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::NDArray => TopLevelDef::Class { PrimDef::NDArray => TopLevelDef::Class {
name: prim.name().into(), name: prim.name().into(),
object_id: prim.id(), object_id: prim.id(),
type_vars: vec![self.ndarray_dtype_tvar.0, self.ndarray_ndims_tvar.0], type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
fields: Vec::default(), fields: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0), Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
@ -777,7 +771,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0, signature: self.ndarray_copy_ty.0,
var_id: vec![self.ndarray_dtype_tvar.1, self.ndarray_ndims_tvar.1], var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -794,7 +788,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0, signature: self.ndarray_fill_ty.0,
var_id: vec![self.ndarray_dtype_tvar.1, self.ndarray_ndims_tvar.1], var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -831,10 +825,10 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg {
name: "n".into(), name: "n".into(),
ty: self.num_or_ndarray_ty.0, ty: self.num_or_ndarray_ty.ty,
default_value: None, default_value: None,
}], }],
ret: self.num_or_ndarray_ty.0, ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(), vars: self.num_or_ndarray_var_map.clone(),
})), })),
var_id: Vec::default(), var_id: Vec::default(),
@ -884,9 +878,9 @@ impl<'a> BuiltinBuilder<'a> {
let int_sized = size_variant.of_int(self.primitives); let int_sized = size_variant.of_int(self.primitives);
let ndarray_int_sized = let ndarray_int_sized =
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.0)); make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
let ndarray_float = let ndarray_float =
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.0)); make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
let p0_ty = let p0_ty =
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
@ -898,12 +892,10 @@ impl<'a> BuiltinBuilder<'a> {
create_fn_by_codegen( create_fn_by_codegen(
self.unifier, self.unifier,
&[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] &to_var_map([common_ndim, p0_ty, ret_ty]),
.into_iter()
.collect(),
prim.name(), prim.name(),
ret_ty.0, ret_ty.ty,
&[(p0_ty.0, "n")], &[(p0_ty.ty, "n")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
@ -946,12 +938,12 @@ impl<'a> BuiltinBuilder<'a> {
); );
let ndarray_float = let ndarray_float =
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.0)); make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
// The size variant of the function determines the type of int returned // The size variant of the function determines the type of int returned
let int_sized = size_variant.of_int(self.primitives); let int_sized = size_variant.of_int(self.primitives);
let ndarray_int_sized = let ndarray_int_sized =
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.0)); make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
let p0_ty = let p0_ty =
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
@ -964,12 +956,10 @@ impl<'a> BuiltinBuilder<'a> {
create_fn_by_codegen( create_fn_by_codegen(
self.unifier, self.unifier,
&[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] &to_var_map([common_ndim, p0_ty, ret_ty]),
.into_iter()
.collect(),
prim.name(), prim.name(),
ret_ty.0, ret_ty.ty,
&[(p0_ty.0, "n")], &[(p0_ty.ty, "n")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
@ -1031,7 +1021,7 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { name: "object".into(), ty: tv.0, default_value: None }, FuncArg { name: "object".into(), ty: tv.ty, default_value: None },
FuncArg { FuncArg {
name: "copy".into(), name: "copy".into(),
ty: bool, ty: bool,
@ -1044,9 +1034,9 @@ impl<'a> BuiltinBuilder<'a> {
}, },
], ],
ret: ndarray, ret: ndarray,
vars: VarMap::from([(tv.1, tv.0)]), vars: to_var_map([tv]),
})), })),
var_id: vec![tv.1], var_id: vec![tv.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -1064,12 +1054,12 @@ impl<'a> BuiltinBuilder<'a> {
create_fn_by_codegen( create_fn_by_codegen(
self.unifier, self.unifier,
&[(tv.1, tv.0)].into_iter().collect(), &to_var_map([tv]),
prim.name(), prim.name(),
self.primitives.ndarray, self.primitives.ndarray,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable // type variable
&[(self.list_int32, "shape"), (tv.0, "fill_value")], &[(self.list_int32, "shape"), (tv.ty, "fill_value")],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, obj, fun, args, generator| {
gen_ndarray_full(ctx, &obj, fun, &args, generator) gen_ndarray_full(ctx, &obj, fun, &args, generator)
.map(|val| Some(val.as_basic_value_enum())) .map(|val| Some(val.as_basic_value_enum()))
@ -1287,8 +1277,8 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&self.float_or_ndarray_var_map, &self.float_or_ndarray_var_map,
prim.name(), prim.name(),
self.float_or_ndarray_ty.0, self.float_or_ndarray_ty.ty,
&[(self.float_or_ndarray_ty.0, "n")], &[(self.float_or_ndarray_ty.ty, "n")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
@ -1311,8 +1301,8 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&self.float_or_ndarray_var_map, &self.float_or_ndarray_var_map,
prim.name(), prim.name(),
self.float_or_ndarray_ty.0, self.float_or_ndarray_ty.ty,
&[(self.float_or_ndarray_ty.0, "n")], &[(self.float_or_ndarray_ty.ty, "n")],
Box::new(|ctx, _, fun, args, generator| { Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
@ -1328,9 +1318,9 @@ impl<'a> BuiltinBuilder<'a> {
let PrimitiveStore { uint64, int32, .. } = *self.primitives; let PrimitiveStore { uint64, int32, .. } = *self.primitives;
let tvar = self.unifier.get_fresh_var(Some("L".into()), None); let tvar = self.unifier.get_fresh_var(Some("L".into()), None);
let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.0 }); let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty });
let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None); let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None);
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.0), Some(ndims.0)); let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.ty), Some(ndims.ty));
let arg_ty = self.unifier.get_fresh_var_with_range( let arg_ty = self.unifier.get_fresh_var_with_range(
&[list, ndarray, self.primitives.range], &[list, ndarray, self.primitives.range],
@ -1341,9 +1331,9 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], args: vec![FuncArg { name: "ls".into(), ty: arg_ty.ty, default_value: None }],
ret: int32, ret: int32,
vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), vars: to_var_map([tvar, arg_ty]),
})), })),
var_id: Vec::default(), var_id: Vec::default(),
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1446,10 +1436,10 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { name: "m".into(), ty: self.num_ty.0, default_value: None }, FuncArg { name: "m".into(), ty: self.num_ty.ty, default_value: None },
FuncArg { name: "n".into(), ty: self.num_ty.0, default_value: None }, FuncArg { name: "n".into(), ty: self.num_ty.ty, default_value: None },
], ],
ret: self.num_ty.0, ret: self.num_ty.ty,
vars: self.num_var_map.clone(), vars: self.num_var_map.clone(),
})), })),
var_id: Vec::default(), var_id: Vec::default(),
@ -1484,15 +1474,15 @@ impl<'a> BuiltinBuilder<'a> {
.num_or_ndarray_var_map .num_or_ndarray_var_map
.clone() .clone()
.into_iter() .into_iter()
.chain(once((ret_ty.1, ret_ty.0))) .chain(once((ret_ty.id, ret_ty.ty)))
.collect::<IndexMap<_, _>>(); .collect::<IndexMap<_, _>>();
create_fn_by_codegen( create_fn_by_codegen(
self.unifier, self.unifier,
&var_map, &var_map,
prim.name(), prim.name(),
ret_ty.0, ret_ty.ty,
&[(self.float_or_ndarray_ty.0, "a")], &[(self.float_or_ndarray_ty.ty, "a")],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let a_ty = fun.0.args[0].ty; let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
@ -1512,9 +1502,9 @@ impl<'a> BuiltinBuilder<'a> {
fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef { fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]); debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]);
let x1_ty = self.new_type_or_ndarray_ty(self.num_ty.0); let x1_ty = self.new_type_or_ndarray_ty(self.num_ty.ty);
let x2_ty = self.new_type_or_ndarray_ty(self.num_ty.0); let x2_ty = self.new_type_or_ndarray_ty(self.num_ty.ty);
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; let param_ty = &[(x1_ty.ty, "x1"), (x2_ty.ty, "x2")];
let ret_ty = self.unifier.get_fresh_var(None, None); let ret_ty = self.unifier.get_fresh_var(None, None);
TopLevelDef::Function { TopLevelDef::Function {
@ -1525,12 +1515,10 @@ impl<'a> BuiltinBuilder<'a> {
.iter() .iter()
.map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
.collect(), .collect(),
ret: ret_ty.0, ret: ret_ty.ty,
vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] vars: to_var_map([x1_ty, x2_ty, ret_ty]),
.into_iter()
.collect(),
})), })),
var_id: vec![x1_ty.1, x2_ty.1], var_id: vec![x1_ty.id, x2_ty.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -1564,10 +1552,10 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg {
name: "n".into(), name: "n".into(),
ty: self.num_or_ndarray_ty.0, ty: self.num_or_ndarray_ty.ty,
default_value: None, default_value: None,
}], }],
ret: self.num_or_ndarray_ty.0, ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(), vars: self.num_or_ndarray_var_map.clone(),
})), })),
var_id: Vec::default(), var_id: Vec::default(),
@ -1660,8 +1648,8 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&self.float_or_ndarray_var_map, &self.float_or_ndarray_var_map,
prim.name(), prim.name(),
self.float_or_ndarray_ty.0, self.float_or_ndarray_ty.ty,
&[(self.float_or_ndarray_ty.0, arg_name)], &[(self.float_or_ndarray_ty.ty, arg_name)],
Box::new(move |ctx, _, fun, args, generator| { Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
@ -1745,7 +1733,7 @@ impl<'a> BuiltinBuilder<'a> {
let x1_ty = self.new_type_or_ndarray_ty(x1_ty); let x1_ty = self.new_type_or_ndarray_ty(x1_ty);
let x2_ty = self.new_type_or_ndarray_ty(x2_ty); let x2_ty = self.new_type_or_ndarray_ty(x2_ty);
let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; let param_ty = &[(x1_ty.ty, "x1"), (x2_ty.ty, "x2")];
let ret_ty = self.unifier.get_fresh_var(None, None); let ret_ty = self.unifier.get_fresh_var(None, None);
TopLevelDef::Function { TopLevelDef::Function {
@ -1756,12 +1744,10 @@ impl<'a> BuiltinBuilder<'a> {
.iter() .iter()
.map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
.collect(), .collect(),
ret: ret_ty.0, ret: ret_ty.ty,
vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] vars: to_var_map([x1_ty, x2_ty, ret_ty]),
.into_iter()
.collect(),
})), })),
var_id: vec![ret_ty.1], var_id: vec![ret_ty.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -1794,7 +1780,7 @@ impl<'a> BuiltinBuilder<'a> {
(prim.simple_name().into(), method_ty, prim.id()) (prim.simple_name().into(), method_ty, prim.id())
} }
fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> (Type, u32) { fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar {
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None); let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None);
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None) self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None)

View File

@ -6,7 +6,7 @@ use crate::{
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
typecheck::{ typecheck::{
type_inferencer::{FunctionData, Inferencer}, type_inferencer::{FunctionData, Inferencer},
typedef::VarMap, typedef::{TypeVar, VarMap},
}, },
}; };
@ -225,7 +225,7 @@ impl TopLevelComposer {
// since later when registering class method, ast will still be used, // since later when registering class method, ast will still be used,
// here push None temporarily, later will move the ast inside // here push None temporarily, later will move the ast inside
let constructor_ty = self.unifier.get_dummy_var().0; let constructor_ty = self.unifier.get_dummy_var().ty;
let mut class_def_ast = ( let mut class_def_ast = (
Arc::new(RwLock::new(Self::make_top_level_class_def( Arc::new(RwLock::new(Self::make_top_level_class_def(
DefinitionId(class_def_id), DefinitionId(class_def_id),
@ -281,7 +281,7 @@ impl TopLevelComposer {
}; };
// dummy method define here // dummy method define here
let dummy_method_type = self.unifier.get_dummy_var().0; let dummy_method_type = self.unifier.get_dummy_var().ty;
class_method_name_def_ids.push(( class_method_name_def_ids.push((
*method_name, *method_name,
RwLock::new(Self::make_top_level_function_def( RwLock::new(Self::make_top_level_function_def(
@ -337,7 +337,7 @@ impl TopLevelComposer {
} }
let fun_name = *name; let fun_name = *name;
let ty_to_be_unified = self.unifier.get_dummy_var().0; let ty_to_be_unified = self.unifier.get_dummy_var().ty;
// add to the definition list // add to the definition list
self.definition_ast_list.push(( self.definition_ast_list.push((
RwLock::new(Self::make_top_level_function_def( RwLock::new(Self::make_top_level_function_def(
@ -452,7 +452,7 @@ impl TopLevelComposer {
// check if all are unique type vars // check if all are unique type vars
let all_unique_type_var = { let all_unique_type_var = {
let mut occurred_type_var_id: HashSet<u32> = HashSet::new(); let mut occurred_type_var_id: HashSet<TypeVarId> = HashSet::new();
type_vars.iter().all(|x| { type_vars.iter().all(|x| {
let ty = unifier.get_ty(*x); let ty = unifier.get_ty(*x);
if let TypeEnum::TVar { id, .. } = ty.as_ref() { if let TypeEnum::TVar { id, .. } = ty.as_ref() {
@ -917,18 +917,19 @@ impl TopLevelComposer {
let type_vars_within = let type_vars_within =
get_type_var_contained_in_type_annotation(&type_annotation) get_type_var_contained_in_type_annotation(&type_annotation)
.into_iter() .into_iter()
.map(|x| -> Result<(u32, Type), HashSet<String>> { .map(|x| -> Result<TypeVar, HashSet<String>> {
let TypeAnnotation::TypeVar(ty) = x else { let TypeAnnotation::TypeVar(ty) = x else {
unreachable!("must be type var annotation kind") unreachable!("must be type var annotation kind")
}; };
Ok((Self::get_var_id(ty, unifier)?, ty)) let id = Self::get_var_id(ty, unifier)?;
Ok(TypeVar { id, ty })
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
for (id, ty) in type_vars_within { for var in type_vars_within {
if let Some(prev_ty) = function_var_map.insert(id, ty) { if let Some(prev_ty) = function_var_map.insert(var.id, var.ty) {
// if already have the type inserted, make sure they are the same thing // if already have the type inserted, make sure they are the same thing
assert_eq!(prev_ty, ty); assert_eq!(prev_ty, var.ty);
} }
} }
@ -982,18 +983,19 @@ impl TopLevelComposer {
let type_vars_within = let type_vars_within =
get_type_var_contained_in_type_annotation(&return_ty_annotation) get_type_var_contained_in_type_annotation(&return_ty_annotation)
.into_iter() .into_iter()
.map(|x| -> Result<(u32, Type), HashSet<String>> { .map(|x| -> Result<TypeVar, HashSet<String>> {
let TypeAnnotation::TypeVar(ty) = x else { let TypeAnnotation::TypeVar(ty) = x else {
unreachable!("must be type var here") unreachable!("must be type var here")
}; };
Ok((Self::get_var_id(ty, unifier)?, ty)) let id = Self::get_var_id(ty, unifier)?;
Ok(TypeVar { id, ty })
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
for (id, ty) in type_vars_within { for var in type_vars_within {
if let Some(prev_ty) = function_var_map.insert(id, ty) { if let Some(prev_ty) = function_var_map.insert(var.id, var.ty) {
// if already have the type inserted, make sure they are the same thing // if already have the type inserted, make sure they are the same thing
assert_eq!(prev_ty, ty); assert_eq!(prev_ty, var.ty);
} }
} }
@ -1177,7 +1179,7 @@ impl TopLevelComposer {
// finish handling type vars // finish handling type vars
let dummy_func_arg = FuncArg { let dummy_func_arg = FuncArg {
name, name,
ty: unifier.get_dummy_var().0, ty: unifier.get_dummy_var().ty,
default_value: match default { default_value: match default {
None => None, None => None,
Some(default) => { Some(default) => {
@ -1240,13 +1242,13 @@ impl TopLevelComposer {
assert_eq!(prev_ty, ty); assert_eq!(prev_ty, ty);
} }
} }
let dummy_return_type = unifier.get_dummy_var().0; let dummy_return_type = unifier.get_dummy_var().ty;
type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); type_var_to_concrete_def.insert(dummy_return_type, annotation.clone());
dummy_return_type dummy_return_type
} else { } else {
// if do not have return annotation, return none // if do not have return annotation, return none
// for uniform handling, still use type annotation // for uniform handling, still use type annotation
let dummy_return_type = unifier.get_dummy_var().0; let dummy_return_type = unifier.get_dummy_var().ty;
type_var_to_concrete_def.insert( type_var_to_concrete_def.insert(
dummy_return_type, dummy_return_type,
TypeAnnotation::Primitive(primitives.none), TypeAnnotation::Primitive(primitives.none),
@ -1286,7 +1288,7 @@ impl TopLevelComposer {
ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => { ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => {
if let ast::ExprKind::Name { id: attr, .. } = &target.node { if let ast::ExprKind::Name { id: attr, .. } = &target.node {
if defined_fields.insert(attr.to_string()) { if defined_fields.insert(attr.to_string()) {
let dummy_field_type = unifier.get_dummy_var().0; let dummy_field_type = unifier.get_dummy_var().ty;
// handle Kernel[T], KernelInvariant[T] // handle Kernel[T], KernelInvariant[T]
let (annotation, mutable) = match &annotation.node { let (annotation, mutable) = match &annotation.node {
@ -1749,7 +1751,7 @@ impl TopLevelComposer {
unreachable!() unreachable!()
}; };
let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; let rigid = unifier.get_fresh_rigid_var(*name, *loc).ty;
no_ranges.push(rigid); no_ranges.push(rigid);
vec![rigid] vec![rigid]
}) })

View File

@ -2,7 +2,7 @@ use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{Mapping, VarMap}; use crate::typecheck::typedef::{to_var_map, Mapping, TypeVarId, VarMap};
use nac3parser::ast::{Constant, Location}; use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use strum_macros::EnumIter; use strum_macros::EnumIter;
@ -377,12 +377,12 @@ impl TopLevelComposer {
let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: bool, ret: bool,
vars: VarMap::from([(option_type_var.1, option_type_var.0)]), vars: to_var_map([option_type_var]),
})); }));
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: option_type_var.0, ret: option_type_var.ty,
vars: VarMap::from([(option_type_var.1, option_type_var.0)]), vars: to_var_map([option_type_var]),
})); }));
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
@ -393,7 +393,7 @@ impl TopLevelComposer {
] ]
.into_iter() .into_iter()
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
params: VarMap::from([(option_type_var.1, option_type_var.0)]), params: to_var_map([option_type_var]),
}); });
let size_t_ty = match size_t { let size_t_ty = match size_t {
@ -408,23 +408,17 @@ impl TopLevelComposer {
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None); let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: ndarray_copy_fun_ret_ty.0, ret: ndarray_copy_fun_ret_ty.ty,
vars: VarMap::from([ vars: to_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
})); }));
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg {
name: "value".into(), name: "value".into(),
ty: ndarray_dtype_tvar.0, ty: ndarray_dtype_tvar.ty,
default_value: None, default_value: None,
}], }],
ret: none, ret: none,
vars: VarMap::from([ vars: to_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
})); }));
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
@ -432,13 +426,10 @@ impl TopLevelComposer {
(PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)), (PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
(PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)), (PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
]), ]),
params: VarMap::from([ params: to_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}); });
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap(); unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap();
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
@ -583,7 +574,7 @@ impl TopLevelComposer {
} }
/// get the `var_id` of a given `TVar` type /// get the `var_id` of a given `TVar` type
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<u32, HashSet<String>> { pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<TypeVarId, HashSet<String>> {
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
Ok(*id) Ok(*id)
} else { } else {

View File

@ -14,7 +14,10 @@ use super::typecheck::typedef::{
use crate::{ use crate::{
codegen::CodeGenerator, codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{type_inferencer::CodeLocation, typedef::CallId}, typecheck::{
type_inferencer::CodeLocation,
typedef::{CallId, TypeVarId},
},
}; };
use inkwell::values::BasicValueEnum; use inkwell::values::BasicValueEnum;
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
@ -119,7 +122,7 @@ pub enum TopLevelDef {
/// Function signature. /// Function signature.
signature: Type, signature: Type,
/// Instantiated type variable IDs. /// Instantiated type variable IDs.
var_id: Vec<u32>, var_id: Vec<TypeVarId>,
/// Function instance to symbol mapping /// Function instance to symbol mapping
/// ///
/// * Key: String representation of type variable values, sorted by variable ID in ascending /// * Key: String representation of type variable values, sorted by variable ID in ascending

View File

@ -2,7 +2,7 @@ use crate::{
toplevel::helper::PrimDef, toplevel::helper::PrimDef,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap}, typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
}, },
}; };
use itertools::Itertools; use itertools::Itertools;
@ -57,7 +57,7 @@ pub fn subst_ndarray_tvars(
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
} }
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(u32, Type)> { fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
}; };
@ -74,7 +74,7 @@ fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(u32, Type)
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds /// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` /// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively. /// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (u32, u32) { pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap() unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
} }

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [239]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(239)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [241]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [246]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [247]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [255]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
] ]

View File

@ -782,11 +782,11 @@ fn make_internal_resolver_with_tvar(
.into_iter() .into_iter()
.map(|(name, range)| { .map(|(name, range)| {
(name, { (name, {
let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice(), None, None); let tvar = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
if print { if print {
println!("{}: {:?}, typevar{}", name, ty, id); println!("{}: {:?}, typevar{}", name, tvar.ty, tvar.id);
} }
ty tvar.ty
}) })
}) })
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>()

View File

@ -123,7 +123,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { } else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).0; let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).ty;
unifier.unify(var, ty).unwrap(); unifier.unify(var, ty).unwrap();
Ok(TypeAnnotation::TypeVar(ty)) Ok(TypeAnnotation::TypeVar(ty))
} else { } else {
@ -426,7 +426,7 @@ pub fn get_type_from_type_annotation_kinds(
*name, *name,
*loc, *loc,
); );
unifier.unify(temp.0, p).is_ok() unifier.unify(temp.ty, p).is_ok()
} }
}; };
if ok { if ok {
@ -451,7 +451,7 @@ pub fn get_type_from_type_annotation_kinds(
// create a temp type var and unify to check compatibility // create a temp type var and unify to check compatibility
p == *tvar || { p == *tvar || {
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc); let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.0, p).is_ok() unifier.unify(temp.ty, p).is_ok()
} }
}; };
if ok { if ok {

View File

@ -103,8 +103,8 @@ pub fn impl_binop(
let (other_ty, other_var_id) = if other_ty.len() == 1 { let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None) (other_ty[0], None)
} else { } else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(ty, Some(var_id)) (tvar.ty, Some(tvar.id))
}; };
let function_vars = if let Some(var_id) = other_var_id { let function_vars = if let Some(var_id) = other_var_id {
@ -113,7 +113,7 @@ pub fn impl_binop(
VarMap::new() VarMap::new()
}; };
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops { for op in ops {
fields.insert(binop_name(*op).into(), { fields.insert(binop_name(*op).into(), {
@ -151,7 +151,7 @@ pub fn impl_binop(
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) { pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops { for op in ops {
fields.insert( fields.insert(
@ -181,8 +181,8 @@ pub fn impl_cmpop(
let (other_ty, other_var_id) = if other_ty.len() == 1 { let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None) (other_ty[0], None)
} else { } else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(ty, Some(var_id)) (tvar.ty, Some(tvar.id))
}; };
let function_vars = if let Some(var_id) = other_var_id { let function_vars = if let Some(var_id) = other_var_id {
@ -191,7 +191,7 @@ pub fn impl_cmpop(
VarMap::new() VarMap::new()
}; };
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops { for op in ops {
fields.insert( fields.insert(
@ -652,7 +652,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
let ndarray_usized_ndims_tvar = let ndarray_usized_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t = let ndarray_unsized_t =
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0)); make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
impl_basic_arithmetic( impl_basic_arithmetic(

View File

@ -432,7 +432,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
let enter = TypeEnum::TFunc(FunSignature { let enter = TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: item.optional_vars.as_ref().map_or_else( ret: item.optional_vars.as_ref().map_or_else(
|| self.unifier.get_dummy_var().0, || self.unifier.get_dummy_var().ty,
|var| var.custom.unwrap(), |var| var.custom.unwrap(),
), ),
vars: VarMap::default(), vars: VarMap::default(),
@ -440,7 +440,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
let enter = self.unifier.add_ty(enter); let enter = self.unifier.add_ty(enter);
let exit = TypeEnum::TFunc(FunSignature { let exit = TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: self.unifier.get_dummy_var().0, ret: self.unifier.get_dummy_var().ty,
vars: VarMap::default(), vars: VarMap::default(),
}); });
let exit = self.unifier.add_ty(exit); let exit = self.unifier.add_ty(exit);
@ -511,7 +511,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
}; };
assert_eq!(*id, *id_var); assert_eq!(*id, *id_var);
(*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).0) (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty)
}) })
.collect::<VarMap>(); .collect::<VarMap>();
Some(self.unifier.subst(self.primitives.option, &var_map).unwrap()) Some(self.unifier.subst(self.primitives.option, &var_map).unwrap())
@ -660,7 +660,7 @@ impl<'a> Inferencer<'a> {
} }
} }
} }
let ret = ret.unwrap_or_else(|| self.unifier.get_dummy_var().0); let ret = ret.unwrap_or_else(|| self.unifier.get_dummy_var().ty);
let call = self.unifier.add_call(Call { let call = self.unifier.add_call(Call {
posargs: params, posargs: params,
@ -706,11 +706,13 @@ impl<'a> Inferencer<'a> {
let fn_args: Vec<_> = args let fn_args: Vec<_> = args
.args .args
.iter() .iter()
.map(|v| (v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).0)) .map(|v| {
(v.node.arg, self.unifier.get_fresh_var(Some(v.node.arg), Some(v.location)).ty)
})
.collect(); .collect();
let mut variable_mapping = self.variable_mapping.clone(); let mut variable_mapping = self.variable_mapping.clone();
variable_mapping.extend(fn_args.iter().copied()); variable_mapping.extend(fn_args.iter().copied());
let ret = self.unifier.get_dummy_var().0; let ret = self.unifier.get_dummy_var().ty;
let mut new_context = Inferencer { let mut new_context = Inferencer {
function_data: self.function_data, function_data: self.function_data,
@ -849,7 +851,7 @@ impl<'a> Inferencer<'a> {
&arg, &arg,
)? )?
} else { } else {
self.unifier.get_dummy_var().0 self.unifier.get_dummy_var().ty
}; };
self.virtual_checks.push((arg0.custom.unwrap(), ty, *func_location)); self.virtual_checks.push((arg0.custom.unwrap(), ty, *func_location));
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
@ -1362,7 +1364,7 @@ impl<'a> Inferencer<'a> {
} }
} }
let ret = self.unifier.get_dummy_var().0; let ret = self.unifier.get_dummy_var().ty;
let call = self.unifier.add_call(Call { let call = self.unifier.add_call(Call {
posargs: args.iter().map(|v| v.custom.unwrap()).collect(), posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
kwargs: keywords kwargs: keywords
@ -1391,7 +1393,7 @@ impl<'a> Inferencer<'a> {
.resolver .resolver
.get_symbol_type(unifier, &self.top_level.definitions.read(), self.primitives, id) .get_symbol_type(unifier, &self.top_level.definitions.read(), self.primitives, id)
.unwrap_or_else(|_| { .unwrap_or_else(|_| {
let ty = unifier.get_dummy_var().0; let ty = unifier.get_dummy_var().ty;
variable_mapping.insert(id, ty); variable_mapping.insert(id, ty);
ty ty
}) })
@ -1420,13 +1422,13 @@ impl<'a> Inferencer<'a> {
ast::Constant::None => { ast::Constant::None => {
report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc) report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc)
} }
ast::Constant::Ellipsis => Ok(self.unifier.get_fresh_var(None, None).0), ast::Constant::Ellipsis => Ok(self.unifier.get_fresh_var(None, None).ty),
_ => report_error("not supported", *loc), _ => report_error("not supported", *loc),
} }
} }
fn infer_list(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult { fn infer_list(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let ty = self.unifier.get_dummy_var().0; let ty = self.unifier.get_dummy_var().ty;
for t in elts { for t in elts {
self.unify(ty, t.custom.unwrap(), &t.location)?; self.unify(ty, t.custom.unwrap(), &t.location)?;
} }
@ -1462,7 +1464,7 @@ impl<'a> Inferencer<'a> {
} }
} }
} else { } else {
let attr_ty = self.unifier.get_dummy_var().0; let attr_ty = self.unifier.get_dummy_var().ty;
let fields = once(( let fields = once((
attr.into(), attr.into(),
RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)), RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)),
@ -1655,7 +1657,7 @@ impl<'a> Inferencer<'a> {
slice: &ast::Expr<Option<Type>>, slice: &ast::Expr<Option<Type>>,
ctx: ExprContext, ctx: ExprContext,
) -> InferenceResult { ) -> InferenceResult {
let ty = self.unifier.get_dummy_var().0; let ty = self.unifier.get_dummy_var().ty;
match &slice.node { match &slice.node {
ExprKind::Slice { lower, upper, step } => { ExprKind::Slice { lower, upper, step } => {
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
@ -1759,7 +1761,7 @@ impl<'a> Inferencer<'a> {
let valid_index_ty = self let valid_index_ty = self
.unifier .unifier
.get_fresh_var_with_range(valid_index_tys.as_slice(), None, None) .get_fresh_var_with_range(valid_index_tys.as_slice(), None, None)
.0; .ty;
self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?; self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?;
self.infer_subscript_ndarray(value, ty, ndims) self.infer_subscript_ndarray(value, ty, ndims)
} }

View File

@ -143,10 +143,7 @@ impl TestEnvironment {
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::from([ params: to_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}); });
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
@ -321,19 +318,19 @@ impl TestEnvironment {
unifier.put_primitive_store(&primitives); unifier.put_primitive_store(&primitives);
let (v0, id) = unifier.get_dummy_var(); let tvar = unifier.get_dummy_var();
let foo_ty = unifier.add_ty(TypeEnum::TObj { let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 1), obj_id: DefinitionId(defs + 1),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(), fields: [("a".into(), (tvar.ty, true))].iter().cloned().collect::<HashMap<_, _>>(),
params: [(id, v0)].iter().cloned().collect::<VarMap>(), params: to_var_map([tvar]),
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Foo".into(), name: "Foo".into(),
object_id: DefinitionId(defs + 1), object_id: DefinitionId(defs + 1),
type_vars: vec![v0], type_vars: vec![tvar.ty],
fields: [("a".into(), v0, true)].into(), fields: [("a".into(), tvar.ty, true)].into(),
methods: Default::default(), methods: Default::default(),
ancestors: Default::default(), ancestors: Default::default(),
resolver: None, resolver: None,
@ -348,7 +345,7 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: foo_ty, ret: foo_ty,
vars: [(id, v0)].iter().cloned().collect(), vars: to_var_map([tvar]),
})), })),
); );

View File

@ -2,7 +2,7 @@ use indexmap::IndexMap;
use itertools::Itertools; use itertools::Itertools;
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Display; use std::fmt::{self, Display};
use std::iter::zip; use std::iter::zip;
use std::rc::Rc; use std::rc::Rc;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
@ -28,8 +28,44 @@ pub struct CallId(pub(super) usize);
pub type Mapping<K, V = Type> = HashMap<K, V>; pub type Mapping<K, V = Type> = HashMap<K, V>;
pub type IndexMapping<K, V = Type> = IndexMap<K, V>; pub type IndexMapping<K, V = Type> = IndexMap<K, V>;
/// The mapping between type variable ID and [unifier type][`Type`]. /// ID of a Python type variable. Specific to `nac3core`.
pub type VarMap = IndexMapping<u32>; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TypeVarId(pub u32);
impl fmt::Display for TypeVarId {
// NOTE: Must output the string fo the ID value. Certain unit tests rely on string comparisons.
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_fmt(format_args!("{}", self.0))
}
}
/// A Python type variable. Used by `nac3core` during type inference.
#[derive(Debug, Clone, Copy)]
pub struct TypeVar {
/// `nac3core`'s internal [`TypeVarId`] of this type variable.
pub id: TypeVarId,
/// The assigned [`Type`] of this Python type variable.
pub ty: Type,
}
/// The mapping between [`TypeVarId`] and [unifier type][`Type`].
pub type VarMap = IndexMapping<TypeVarId>;
/// Build a [`VarMap`] from an iterator of [`TypeVar`]
///
/// The resulting [`VarMap`] wil have the same order as the input iterator.
pub fn to_var_map<I>(vars: I) -> VarMap
where
I: IntoIterator<Item = TypeVar>,
{
vars.into_iter().map(|var| (var.id, var.ty)).collect()
}
/// Get an iterator of [`TypeVar`]s from a [`VarMap`]
pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ {
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
}
#[derive(Clone)] #[derive(Clone)]
pub struct Call { pub struct Call {
@ -127,14 +163,14 @@ impl RecordField {
#[derive(Clone)] #[derive(Clone)]
pub enum TypeEnum { pub enum TypeEnum {
TRigidVar { TRigidVar {
id: u32, id: TypeVarId,
name: Option<StrRef>, name: Option<StrRef>,
loc: Option<Location>, loc: Option<Location>,
}, },
/// A type variable. /// A type variable.
TVar { TVar {
id: u32, id: TypeVarId,
// empty indicates this is not a struct/tuple/list // empty indicates this is not a struct/tuple/list
fields: Option<Mapping<RecordKey, RecordField>>, fields: Option<Mapping<RecordKey, RecordField>>,
// empty indicates no restriction // empty indicates no restriction
@ -295,7 +331,7 @@ impl Unifier {
} }
pub fn add_record(&mut self, fields: Mapping<RecordKey, RecordField>) -> Type { pub fn add_record(&mut self, fields: Mapping<RecordKey, RecordField>) -> Type {
let id = self.var_id + 1; let id = TypeVarId(self.var_id + 1);
self.var_id += 1; self.var_id += 1;
self.add_ty(TypeEnum::TVar { self.add_ty(TypeEnum::TVar {
id, id,
@ -346,24 +382,21 @@ impl Unifier {
self.unification_table.probe_value_immutable(a).clone() self.unification_table.probe_value_immutable(a).clone()
} }
pub fn get_fresh_rigid_var( pub fn get_fresh_rigid_var(&mut self, name: Option<StrRef>, loc: Option<Location>) -> TypeVar {
&mut self, let id = TypeVarId(self.var_id + 1);
name: Option<StrRef>,
loc: Option<Location>,
) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
(self.add_ty(TypeEnum::TRigidVar { id, name, loc }), id) let ty = self.add_ty(TypeEnum::TRigidVar { id, name, loc });
TypeVar { id, ty }
} }
pub fn get_dummy_var(&mut self) -> (Type, u32) { pub fn get_dummy_var(&mut self) -> TypeVar {
self.get_fresh_var_with_range(&[], None, None) self.get_fresh_var_with_range(&[], None, None)
} }
/// Returns a fresh [type variable][TypeEnum::TVar] with no associated range. /// Returns a fresh [type variable][TypeEnum::TVar] with no associated range.
/// ///
/// This type variable can be instantiated by any type. /// This type variable can be instantiated by any type.
pub fn get_fresh_var(&mut self, name: Option<StrRef>, loc: Option<Location>) -> (Type, u32) { pub fn get_fresh_var(&mut self, name: Option<StrRef>, loc: Option<Location>) -> TypeVar {
self.get_fresh_var_with_range(&[], name, loc) self.get_fresh_var_with_range(&[], name, loc)
} }
@ -375,21 +408,20 @@ impl Unifier {
range: &[Type], range: &[Type],
name: Option<StrRef>, name: Option<StrRef>,
loc: Option<Location>, loc: Option<Location>,
) -> (Type, u32) { ) -> TypeVar {
let id = self.var_id + 1;
self.var_id += 1;
let range = range.to_vec(); let range = range.to_vec();
(
self.add_ty(TypeEnum::TVar { let id = TypeVarId(self.var_id + 1);
self.var_id += 1;
let ty = self.add_ty(TypeEnum::TVar {
id, id,
range, range,
fields: None, fields: None,
name, name,
loc, loc,
is_const_generic: false, is_const_generic: false,
}), });
id, TypeVar { id, ty }
)
} }
/// Returns a fresh type representing a constant generic variable with the given underlying type `ty`. /// Returns a fresh type representing a constant generic variable with the given underlying type `ty`.
@ -398,20 +430,18 @@ impl Unifier {
ty: Type, ty: Type,
name: Option<StrRef>, name: Option<StrRef>,
loc: Option<Location>, loc: Option<Location>,
) -> (Type, u32) { ) -> TypeVar {
let id = self.var_id + 1; let id = TypeVarId(self.var_id + 1);
self.var_id += 1; self.var_id += 1;
( let ty = self.add_ty(TypeEnum::TVar {
self.add_ty(TypeEnum::TVar {
id, id,
range: vec![ty], range: vec![ty],
fields: None, fields: None,
name, name,
loc, loc,
is_const_generic: true, is_const_generic: true,
}), });
id, TypeVar { id, ty }
)
} }
/// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`. /// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`.
@ -464,7 +494,7 @@ impl Unifier {
} }
} }
TypeEnum::TObj { params, .. } => { TypeEnum::TObj { params, .. } => {
let (keys, params): (Vec<u32>, Vec<Type>) = params.iter().unzip(); let (keys, params): (Vec<TypeVarId>, Vec<Type>) = params.iter().unzip();
let params = params let params = params
.into_iter() .into_iter()
.map(|ty| self.get_instantiations(ty).unwrap_or_else(|| vec![ty])) .map(|ty| self.get_instantiations(ty).unwrap_or_else(|| vec![ty]))
@ -1014,7 +1044,7 @@ impl Unifier {
pub fn stringify_with_notes( pub fn stringify_with_notes(
&self, &self,
ty: Type, ty: Type,
notes: &mut Option<HashMap<u32, String>>, notes: &mut Option<HashMap<TypeVarId, String>>,
) -> String { ) -> String {
let top_level = self.top_level.clone(); let top_level = self.top_level.clone();
self.internal_stringify( self.internal_stringify(
@ -1043,11 +1073,11 @@ impl Unifier {
ty: Type, ty: Type,
obj_to_name: &mut F, obj_to_name: &mut F,
var_to_name: &mut G, var_to_name: &mut G,
notes: &mut Option<HashMap<u32, String>>, notes: &mut Option<HashMap<TypeVarId, String>>,
) -> String ) -> String
where where
F: FnMut(usize) -> String, F: FnMut(usize) -> String,
G: FnMut(u32) -> String, G: FnMut(TypeVarId) -> String,
{ {
let ty = self.unification_table.probe_value_immutable(ty).clone(); let ty = self.unification_table.probe_value_immutable(ty).clone();
match ty.as_ref() { match ty.as_ref() {
@ -1182,7 +1212,7 @@ impl Unifier {
let mapping = vars let mapping = vars
.into_iter() .into_iter()
.map(|(k, range, name, loc)| { .map(|(k, range, name, loc)| {
(k, self.get_fresh_var_with_range(range.as_ref(), name, loc).0) (k, self.get_fresh_var_with_range(range.as_ref(), name, loc).ty)
}) })
.collect(); .collect();
self.subst(ty, &mapping).unwrap_or(ty) self.subst(ty, &mapping).unwrap_or(ty)
@ -1206,7 +1236,7 @@ impl Unifier {
let cached = cache.get_mut(&a); let cached = cache.get_mut(&a);
if let Some(cached) = cached { if let Some(cached) = cached {
if cached.is_none() { if cached.is_none() {
*cached = Some(self.get_fresh_var(None, None).0); *cached = Some(self.get_fresh_var(None, None).ty);
} }
return *cached; return *cached;
} }
@ -1361,7 +1391,7 @@ impl Unifier {
if range.is_empty() { if range.is_empty() {
Err(()) Err(())
} else { } else {
let id = self.var_id + 1; let id = TypeVarId(self.var_id + 1);
self.var_id += 1; self.var_id += 1;
let ty = TVar { let ty = TVar {
id, id,

View File

@ -110,13 +110,13 @@ impl TestEnvironment {
params: VarMap::new(), params: VarMap::new(),
}), }),
); );
let (v0, id) = unifier.get_dummy_var(); let tvar = unifier.get_dummy_var();
type_mapping.insert( type_mapping.insert(
"Foo".into(), "Foo".into(),
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: DefinitionId(3),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(), fields: [("a".into(), (tvar.ty, true))].iter().cloned().collect::<HashMap<_, _>>(),
params: [(id, v0)].iter().cloned().collect::<VarMap>(), params: to_var_map([tvar]),
}), }),
); );
@ -250,7 +250,7 @@ fn test_unify(
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for i in 1..=variable_count { for i in 1..=variable_count {
let v = env.unifier.get_dummy_var(); let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{}", i), v.0); mapping.insert(format!("v{}", i), v.ty);
} }
// unification may have side effect when we do type resolution, so freeze the types // unification may have side effect when we do type resolution, so freeze the types
// before doing unification. // before doing unification.
@ -315,7 +315,7 @@ fn test_invalid_unification(
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for i in 1..=variable_count { for i in 1..=variable_count {
let v = env.unifier.get_dummy_var(); let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{}", i), v.0); mapping.insert(format!("v{}", i), v.ty);
} }
// unification may have side effect when we do type resolution, so freeze the types // unification may have side effect when we do type resolution, so freeze the types
// before doing unification. // before doing unification.
@ -369,8 +369,8 @@ fn test_virtual() {
.collect::<HashMap<StrRef, _>>(), .collect::<HashMap<StrRef, _>>(),
params: VarMap::new(), params: VarMap::new(),
}); });
let v0 = env.unifier.get_dummy_var().0; let v0 = env.unifier.get_dummy_var().ty;
let v1 = env.unifier.get_dummy_var().0; let v1 = env.unifier.get_dummy_var().ty;
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
@ -403,12 +403,12 @@ fn test_typevar_range() {
// unification between v and int // unification between v and int
// where v in (int, bool) // where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
env.unifier.unify(int, v).unwrap(); env.unifier.unify(int, v).unwrap();
// unification between v and list[int] // unification between v and list[int]
// where v in (int, bool) // where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
assert_eq!( assert_eq!(
env.unify(int_list, v), env.unify(int_list, v),
Err("Expected any one of these types: 0, 2, but got list[0]".to_string()) Err("Expected any one of these types: 0, 2, but got list[0]".to_string())
@ -416,25 +416,25 @@ fn test_typevar_range() {
// unification between v and float // unification between v and float
// where v in (int, bool) // where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
assert_eq!( assert_eq!(
env.unify(float, v), env.unify(float, v),
Err("Expected any one of these types: 0, 2, but got 1".to_string()) Err("Expected any one of these types: 0, 2, but got 1".to_string())
); );
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 }); let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 });
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and int // unification between v and int
// where v in (int, list[v1]), v1 in (int, bool) // where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int, v).unwrap(); env.unifier.unify(int, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and list[int] // unification between v and list[int]
// where v in (int, list[v1]), v1 in (int, bool) // where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int_list, v).unwrap(); env.unifier.unify(int_list, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and list[float] // unification between v and list[float]
// where v in (int, list[v1]), v1 in (int, bool) // where v in (int, list[v1]), v1 in (int, bool)
assert_eq!( assert_eq!(
@ -442,30 +442,30 @@ fn test_typevar_range() {
Err("Expected any one of these types: 0, list[typevar5], but got list[1]\n\nNotes:\n typevar5 ∈ {0, 2}".to_string()) Err("Expected any one of these types: 0, list[typevar5], but got list[1]\n\nNotes:\n typevar5 ∈ {0, 2}".to_string())
); );
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
env.unifier.unify(a, float).unwrap(); env.unifier.unify(a, float).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into())); assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into()));
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).0; let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).ty;
env.unifier.unify(a_list, b_list).unwrap(); env.unifier.unify(a_list, b_list).unwrap();
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float }); let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
env.unifier.unify(a_list, float_list).unwrap(); env.unifier.unify(a_list, float_list).unwrap();
// previous unifications should not affect a and b // previous unifications should not affect a and b
env.unifier.unify(a, int).unwrap(); env.unifier.unify(a, int).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap(); env.unifier.unify(a_list, b_list).unwrap();
@ -477,10 +477,10 @@ fn test_typevar_range() {
.into()) .into())
); );
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_dummy_var().0; let b = env.unifier.get_dummy_var().ty;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap(); env.unifier.unify(a_list, b_list).unwrap();
assert_eq!( assert_eq!(
@ -492,9 +492,9 @@ fn test_typevar_range() {
#[test] #[test]
fn test_rigid_var() { fn test_rigid_var() {
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let a = env.unifier.get_fresh_rigid_var(None, None).0; let a = env.unifier.get_fresh_rigid_var(None, None).ty;
let b = env.unifier.get_fresh_rigid_var(None, None).0; let b = env.unifier.get_fresh_rigid_var(None, None).ty;
let x = env.unifier.get_dummy_var().0; let x = env.unifier.get_dummy_var().ty;
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a }); let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x }); let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
let int = env.parse("int", &HashMap::new()); let int = env.parse("int", &HashMap::new());
@ -522,13 +522,13 @@ fn test_instantiation() {
let obj_map: HashMap<_, _> = let obj_map: HashMap<_, _> =
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect(); [(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v }); let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).0; let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).0; let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
let t = env.unifier.get_dummy_var().0; let t = env.unifier.get_dummy_var().ty;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] }); let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).0; let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
// t = TypeVar('t') // t = TypeVar('t')
// v = TypeVar('v', int, bool) // v = TypeVar('v', int, bool)
// v1 = TypeVar('v1', 'list[v]', int) // v1 = TypeVar('v1', 'list[v]', int)

View File

@ -124,7 +124,7 @@ fn handle_typevar_definition(
)])); )]));
} }
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0) Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).ty)
} }
ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => { ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
@ -155,7 +155,7 @@ fn handle_typevar_definition(
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?; get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?;
let loc = func.location; let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0) Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)
} }
_ => Err(HashSet::from([format!( _ => Err(HashSet::from([format!(