From 6a64c9d1de006c731c4d45d96c7598ee42dc3001 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 10 Jul 2024 12:27:59 +0800 Subject: [PATCH] core/typecheck/typedef: Add is_vararg_ctx to TTuple --- nac3artiq/src/codegen.rs | 2 +- nac3artiq/src/symbol_resolver.rs | 15 +- nac3core/src/codegen/concrete_type.rs | 16 +- nac3core/src/codegen/expr.rs | 9 +- nac3core/src/codegen/mod.rs | 4 +- nac3core/src/codegen/stmt.rs | 5 +- nac3core/src/symbol_resolver.rs | 8 +- nac3core/src/toplevel/builtins.rs | 2 + nac3core/src/toplevel/type_annotation.rs | 2 +- nac3core/src/typecheck/function_check.rs | 2 +- nac3core/src/typecheck/type_error.rs | 7 +- nac3core/src/typecheck/type_inferencer/mod.rs | 14 +- nac3core/src/typecheck/typedef/mod.rs | 139 ++++++++++++++---- nac3core/src/typecheck/typedef/test.rs | 9 +- 14 files changed, 174 insertions(+), 60 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 9b0b00d96..be9fecc3c 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -386,7 +386,7 @@ fn gen_rpc_tag( } else { let ty_enum = ctx.unifier.get_ty(ty); match &*ty_enum { - TTuple { ty } => { + TTuple { ty, is_vararg_ctx: false } => { buffer.push(b't'); buffer.push(ty.len() as u8); for ty in ty { diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 2e998ff39..31d0864c8 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -351,7 +351,7 @@ impl InnerResolver { Ok(Ok((ndarray, false))) } else if ty_id == self.primitive_ids.tuple { // do not handle type var param and concrete check here - Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) + Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false }), false))) } else if ty_id == self.primitive_ids.option { Ok(Ok((primitives.option, false))) } else if ty_id == self.primitive_ids.none { @@ -555,7 +555,10 @@ impl InnerResolver { Err(err) => return Ok(Err(err)), _ => return Ok(Err("tuple type needs at least 1 type parameters".to_string())) }; - Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true))) + Ok(Ok(( + unifier.add_ty(TypeEnum::TTuple { ty: args, is_vararg_ctx: false }), + true, + ))) } TypeEnum::TObj { params, obj_id, .. } => { let subst = { @@ -797,7 +800,9 @@ impl InnerResolver { .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives)) .collect(); let types = types?; - Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) + Ok(types.map(|types| { + unifier.add_ty(TypeEnum::TTuple { ty: types, is_vararg_ctx: false }) + })) } // special handling for option type since its class member layout in python side // is special and cannot be mapped directly to a nac3 type as below @@ -1203,7 +1208,9 @@ impl InnerResolver { Ok(Some(ndarray.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); - let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() }; + let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { + unreachable!() + }; let tup_tys = ty.iter(); let elements: &PyTuple = obj.downcast()?; diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index ff0777757..8680beeea 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -47,6 +47,7 @@ pub enum ConcreteTypeEnum { TPrimitive(Primitive), TTuple { ty: Vec, + is_vararg_ctx: bool, }, TObj { obj_id: DefinitionId, @@ -103,7 +104,14 @@ impl ConcreteTypeStore { .iter() .map(|arg| ConcreteFuncArg { name: arg.name, - ty: self.from_unifier_type(unifier, primitives, arg.ty, cache), + ty: if arg.is_vararg { + let tuple_ty = unifier + .add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true }); + + self.from_unifier_type(unifier, primitives, tuple_ty, cache) + } else { + self.from_unifier_type(unifier, primitives, arg.ty, cache) + }, default_value: arg.default_value.clone(), is_vararg: arg.is_vararg, }) @@ -160,11 +168,12 @@ impl ConcreteTypeStore { cache.insert(ty, None); let ty_enum = unifier.get_ty(ty); let result = match &*ty_enum { - TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple { + TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple { ty: ty .iter() .map(|t| self.from_unifier_type(unifier, primitives, *t, cache)) .collect(), + is_vararg_ctx: *is_vararg_ctx, }, TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { obj_id: *obj_id, @@ -250,11 +259,12 @@ impl ConcreteTypeStore { *cache.get_mut(&cty).unwrap() = Some(ty); return ty; } - ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple { + ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple { ty: ty .iter() .map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache)) .collect(), + is_vararg_ctx: *is_vararg_ctx, }, ConcreteTypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 847ddb73f..b22d3cc2a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -267,13 +267,16 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } Constant::Tuple(v) => { let ty = self.unifier.get_ty(ty); - let types = - if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() }; + let (types, is_vararg_ctx) = if let TypeEnum::TTuple { ty, is_vararg_ctx } = &*ty { + (ty.clone(), *is_vararg_ctx) + } else { + unreachable!() + }; let values = zip(types, v.iter()) .map_while(|(ty, v)| self.gen_const(generator, v, ty)) .collect_vec(); - if values.len() == v.len() { + if is_vararg_ctx || values.len() == v.len() { let types = values.iter().map(BasicValueEnum::get_type).collect_vec(); let ty = self.ctx.struct_type(&types, false); Some(ty.const_named_struct(&values).into()) diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index c5d07ca79..7bc8c9892 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -538,8 +538,10 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( }; return ty; } - TTuple { ty } => { + TTuple { ty, is_vararg_ctx } => { // a struct with fields in the order present in the tuple + assert!(!is_vararg_ctx, "Tuples in vararg context must be instantiated with the correct number of arguments before calling get_llvm_type"); + let fields = ty .iter() .map(|ty| { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index c5148ad56..ea5869d67 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -197,7 +197,7 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>( }; // NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer. - let TypeEnum::TTuple { ty: tuple_tys } = &*ctx.unifier.get_ty(value_ty) else { + let TypeEnum::TTuple { ty: tuple_tys, .. } = &*ctx.unifier.get_ty(value_ty) else { unreachable!(); }; @@ -252,7 +252,8 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>( ctx.builder.build_load(psub_tuple_val, "starred_target_value").unwrap(); // Create the typechecker type of the sub-tuple - let sub_tuple_ty = ctx.unifier.add_ty(TypeEnum::TTuple { ty: val_tys.to_vec() }); + let sub_tuple_ty = + ctx.unifier.add_ty(TypeEnum::TTuple { ty: val_tys.to_vec(), is_vararg_ctx: false }); // Now assign with that sub-tuple to the starred target. generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index ce9c0985c..9d7084b91 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -78,14 +78,14 @@ impl SymbolValue { } Constant::Tuple(t) => { let expected_ty = unifier.get_ty(expected_ty); - let TypeEnum::TTuple { ty } = expected_ty.as_ref() else { + let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else { return Err(format!( "Expected {:?}, but got Tuple", expected_ty.get_type_name() )); }; - assert_eq!(ty.len(), t.len()); + assert!(*is_vararg_ctx || ty.len() == t.len()); let elems = t .iter() @@ -155,7 +155,7 @@ impl SymbolValue { SymbolValue::Bool(_) => primitives.bool, SymbolValue::Tuple(vs) => { let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::>(); - unifier.add_ty(TypeEnum::TTuple { ty: vs_tys }) + unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false }) } SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, } @@ -482,7 +482,7 @@ pub fn parse_type_annotation( parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt) }) .collect::, _>>()?; - Ok(unifier.add_ty(TypeEnum::TTuple { ty })) + Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false })) } else { Err(HashSet::from(["Expected multiple elements for tuple".into()])) } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 247031634..be8687eac 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -2083,6 +2083,7 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunSpLinalgHessenberg => { let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { ty: vec![self.ndarray_float_2d, self.ndarray_float_2d], + is_vararg_ctx: false, }); create_fn_by_codegen( self.unifier, @@ -2112,6 +2113,7 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::FunNpLinalgSvd => { let ret_ty = self.unifier.add_ty(TypeEnum::TTuple { ty: vec![self.ndarray_float_2d, self.ndarray_float, self.ndarray_float_2d], + is_vararg_ctx: false, }); create_fn_by_codegen( self.unifier, diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 3f9b61a0b..827f5330f 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -552,7 +552,7 @@ pub fn get_type_from_type_annotation_kinds( ) }) .collect::, _>>()?; - Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys })) + Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys, is_vararg_ctx: false })) } } } diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index 86fc8b8ee..b37909942 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -223,7 +223,7 @@ impl<'a> Inferencer<'a> { ] .iter() .any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)), - TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)), + TypeEnum::TTuple { ty, .. } => ty.iter().all(|t| self.check_return_value_ty(*t)), _ => false, } } diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index 0d84b87ea..706ba5b44 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -183,9 +183,10 @@ impl<'a> Display for DisplayTypeError<'a> { } result } - (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) - if ty1.len() != ty2.len() => - { + ( + TypeEnum::TTuple { ty: ty1, is_vararg_ctx: is_vararg1 }, + TypeEnum::TTuple { ty: ty2, is_vararg_ctx: is_vararg2 }, + ) if !is_vararg1 && !is_vararg2 && ty1.len() != ty2.len() => { let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); write!(f, "Tuple length mismatch: got {t1} and {t2}") diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 24fe7a12a..9ac503a17 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -973,13 +973,14 @@ impl<'a> Inferencer<'a> { ])); } } - TypeEnum::TTuple { ty: tuple_element_types } => { + TypeEnum::TTuple { ty: tuple_element_types, .. } => { // Handle 2. A tuple of int32s // Typecheck // The expected type is just the tuple but with all its elements being int32. let expected_ty = self.unifier.add_ty(TypeEnum::TTuple { ty: tuple_element_types.iter().map(|_| self.primitives.int32).collect_vec(), + is_vararg_ctx: false, }); self.unifier.unify(shape_ty, expected_ty).map_err(|err| { HashSet::from([err @@ -1714,7 +1715,7 @@ impl<'a> Inferencer<'a> { ast::Constant::Tuple(vals) => { let ty: Result, _> = vals.iter().map(|x| self.infer_constant(x, loc)).collect(); - Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? })) + Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty?, is_vararg_ctx: false })) } ast::Constant::Str(_) => Ok(self.primitives.str), ast::Constant::None => { @@ -1748,7 +1749,7 @@ impl<'a> Inferencer<'a> { #[allow(clippy::unnecessary_wraps)] fn infer_tuple(&mut self, elts: &[ast::Expr>]) -> InferenceResult { let ty = elts.iter().map(|x| x.custom.unwrap()).collect(); - Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) + Ok(self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false })) } /// Checks for non-class attributes @@ -1985,7 +1986,7 @@ impl<'a> Inferencer<'a> { rhs_ty: Type, ) -> Result>>, InferenceError> { // TODO: Allow bidirectional typechecking? Currently RHS's type has to be resolved. - let TypeEnum::TTuple { ty: rhs_tys } = &*self.unifier.get_ty(rhs_ty) else { + let TypeEnum::TTuple { ty: rhs_tys, .. } = &*self.unifier.get_ty(rhs_ty) else { // TODO: Allow RHS AST-aware error reporting return report_error( "LHS target list pattern requires RHS to be a tuple type", @@ -2055,7 +2056,10 @@ impl<'a> Inferencer<'a> { // Fold the starred target if let ExprKind::Starred { value: target, .. } = target_starred.node { - let ty = self.unifier.add_ty(TypeEnum::TTuple { ty: rhs_tys_starred.to_vec() }); + let ty = self.unifier.add_ty(TypeEnum::TTuple { + ty: rhs_tys_starred.to_vec(), + is_vararg_ctx: false, + }); let folded_target = self.fold_assign_target(*target, ty)?; folded_targets.push(Located { location: target_starred.location, diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index f4f3b9ca4..ef65ab664 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -1,5 +1,6 @@ use indexmap::IndexMap; -use itertools::Itertools; +use itertools::{repeat_n, Itertools}; +use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop}; use std::cell::RefCell; use std::collections::HashMap; use std::fmt::{self, Display}; @@ -8,8 +9,6 @@ use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::{borrow::Cow, collections::HashSet}; -use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop}; - use super::magic_methods::Binop; use super::type_error::{TypeError, TypeErrorKind}; use super::unification_table::{UnificationKey, UnificationTable}; @@ -234,6 +233,12 @@ pub enum TypeEnum { TTuple { /// The types of elements present in this tuple. ty: Vec, + + /// Whether this tuple is used in a vararg context. + /// + /// If `true`, `ty` must only contain one type, and the tuple is assumed to contain any + /// number of `ty`-typed values. + is_vararg_ctx: bool, }, /// An object type. @@ -528,7 +533,7 @@ impl Unifier { TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| { ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec() }), - TypeEnum::TTuple { ty } => { + TypeEnum::TTuple { ty, is_vararg_ctx } => { let tuples = ty .iter() .map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) @@ -538,7 +543,12 @@ impl Unifier { None } else { Some( - tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(), + tuples + .into_iter() + .map(|ty| { + self.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: *is_vararg_ctx }) + }) + .collect(), ) } } @@ -582,7 +592,7 @@ impl Unifier { TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), - TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), + TTuple { ty, .. } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TObj { params: vars, .. } => { vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) } @@ -974,7 +984,10 @@ impl Unifier { self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => { + ( + TVar { fields: Some(fields), range, is_const_generic: false, .. }, + TTuple { ty, .. }, + ) => { let len = i32::try_from(ty.len()).unwrap(); for (k, v) in fields { match *k { @@ -1071,15 +1084,47 @@ impl Unifier { self.set_a_to_b(a, b); } - (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { - if ty1.len() != ty2.len() { - return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); - } - for (x, y) in ty1.iter().zip(ty2.iter()) { - if self.unify_impl(*x, *y, false).is_err() { - return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); + ( + TTuple { ty: ty1, is_vararg_ctx: is_vararg1 }, + TTuple { ty: ty2, is_vararg_ctx: is_vararg2 }, + ) => { + // Rules for Tuples: + // - ty1: is_vararg && ty2: is_vararg -> ty1[0] == ty2[0] + // - ty1: is_vararg && ty2: !is_vararg -> type error (not enough info to infer the correct number of arguments) + // - ty1: !is_vararg && ty2: is_vararg -> ty1[..] == ty2[0] + // - ty1: !is_vararg && ty2: !is_vararg -> ty1.len() == ty2.len() && ty1[i] == ty2[i] + + debug_assert!(!is_vararg1 || ty1.len() == 1); + debug_assert!(!is_vararg2 || ty2.len() == 1); + + match (*is_vararg1, *is_vararg2) { + (true, true) => { + if self.unify_impl(ty1[0], ty2[0], false).is_err() { + return Self::incompatible_types(a, b); + } + } + (true, false) => return Self::incompatible_types(a, b), + + (false, true) => { + for y in ty2 { + if self.unify_impl(ty1[0], *y, false).is_err() { + return Self::incompatible_types(a, b); + } + } + } + (false, false) => { + if ty1.len() != ty2.len() { + return Self::incompatible_types(a, b); + } + + for (x, y) in ty1.iter().zip(ty2.iter()) { + if self.unify_impl(*x, *y, false).is_err() { + return Self::incompatible_types(a, b); + } + } } } + self.set_a_to_b(a, b); } (TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => { @@ -1322,10 +1367,22 @@ impl Unifier { TypeEnum::TLiteral { values, .. } => { format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", ")) } - TypeEnum::TTuple { ty } => { - let mut fields = - ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); - format!("tuple[{}]", fields.join(", ")) + TypeEnum::TTuple { ty, is_vararg_ctx } => { + if *is_vararg_ctx { + debug_assert_eq!(ty.len(), 1); + let field = self.internal_stringify( + *ty.iter().next().unwrap(), + obj_to_name, + var_to_name, + notes, + ); + format!("tuple[*{field}]") + } else { + let mut fields = ty + .iter() + .map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); + format!("tuple[{}]", fields.join(", ")) + } } TypeEnum::TVirtual { ty } => { format!( @@ -1446,7 +1503,7 @@ impl Unifier { match &*ty { TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, TypeEnum::TVar { id, .. } => mapping.get(id).copied(), - TypeEnum::TTuple { ty } => { + TypeEnum::TTuple { ty, is_vararg_ctx } => { let mut new_ty = Cow::from(ty); for (i, t) in ty.iter().enumerate() { if let Some(t1) = self.subst_impl(*t, mapping, cache) { @@ -1454,7 +1511,10 @@ impl Unifier { } } if matches!(new_ty, Cow::Owned(_)) { - Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() })) + Some(self.add_ty(TypeEnum::TTuple { + ty: new_ty.into_owned(), + is_vararg_ctx: *is_vararg_ctx, + })) } else { None } @@ -1614,16 +1674,37 @@ impl Unifier { } } (TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())), - (TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => { - let ty: Vec<_> = zip(ty1.iter(), ty2.iter()) - .map(|(a, b)| self.get_intersection(*a, *b)) - .try_collect()?; - if ty.iter().any(Option::is_some) { - Ok(Some(self.add_ty(TTuple { - ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(), - }))) + ( + TTuple { ty: ty1, is_vararg_ctx: is_vararg1 }, + TTuple { ty: ty2, is_vararg_ctx: is_vararg2 }, + ) => { + if *is_vararg1 && *is_vararg2 { + let isect_ty = self.get_intersection(ty1[0], ty2[0])?; + Ok(isect_ty.map(|ty| self.add_ty(TTuple { ty: vec![ty], is_vararg_ctx: true }))) } else { - Ok(None) + let zip_iter: Box> = + match (*is_vararg1, *is_vararg2) { + (true, _) => Box::new(repeat_n(&ty1[0], ty2.len()).zip(ty2.iter())), + (_, false) => Box::new(ty1.iter().zip(repeat_n(&ty2[0], ty1.len()))), + _ => { + if ty1.len() != ty2.len() { + return Err(()); + } + + Box::new(ty1.iter().zip(ty2.iter())) + } + }; + + let ty: Vec<_> = + zip_iter.map(|(a, b)| self.get_intersection(*a, *b)).try_collect()?; + Ok(if ty.iter().any(Option::is_some) { + Some(self.add_ty(TTuple { + ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(), + is_vararg_ctx: false, + })) + } else { + None + }) } } // TODO(Derppening): #444 diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 451f5f01a..435c134db 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -28,7 +28,10 @@ impl Unifier { TypeEnum::TVar { fields: Some(map1), .. }, TypeEnum::TVar { fields: Some(map2), .. }, ) => self.map_eq2(map1, map2), - (TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => { + ( + TypeEnum::TTuple { ty: ty1, is_vararg_ctx: false }, + TypeEnum::TTuple { ty: ty2, is_vararg_ctx: false }, + ) => { ty1.len() == ty2.len() && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) } @@ -178,7 +181,7 @@ impl TestEnvironment { ty.push(result.0); s = result.1; } - (self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..]) + (self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }), &s[1..]) } "Record" => { let mut s = &typ[end..]; @@ -608,7 +611,7 @@ fn test_instantiation() { let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty; let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty; let t = env.unifier.get_dummy_var().ty; - let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] }); + let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2], is_vararg_ctx: false }); let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty; // t = TypeVar('t') // v = TypeVar('v', int, bool)