Compare commits

...

11 Commits

21 changed files with 565 additions and 180 deletions

View File

@ -386,7 +386,7 @@ fn gen_rpc_tag(
} else {
let ty_enum = ctx.unifier.get_ty(ty);
match &*ty_enum {
TTuple { ty } => {
TTuple { ty, is_vararg_ctx: false } => {
buffer.push(b't');
buffer.push(ty.len() as u8);
for ty in ty {
@ -700,6 +700,7 @@ pub fn attributes_writeback(
name: i.to_string().into(),
ty: *ty,
default_value: None,
is_vararg: false,
})
.collect(),
ret: ctx.primitives.none,

View File

@ -264,7 +264,7 @@ impl Nac3 {
arg_names.len(),
));
}
for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() {
for (i, FuncArg { ty, default_value, name, .. }) in args.iter().enumerate() {
let in_name = match arg_names.get(i) {
Some(n) => n,
None if default_value.is_none() => {
@ -863,6 +863,7 @@ impl Nac3 {
name: "t".into(),
ty: primitive.int64,
default_value: None,
is_vararg: false,
}],
ret: primitive.none,
vars: VarMap::new(),
@ -882,6 +883,7 @@ impl Nac3 {
name: "dt".into(),
ty: primitive.int64,
default_value: None,
is_vararg: false,
}],
ret: primitive.none,
vars: VarMap::new(),

View File

@ -351,7 +351,7 @@ impl InnerResolver {
Ok(Ok((ndarray, false)))
} else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false }), false)))
} else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false)))
} else if ty_id == self.primitive_ids.none {
@ -555,7 +555,10 @@ impl InnerResolver {
Err(err) => return Ok(Err(err)),
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
};
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true)))
Ok(Ok((
unifier.add_ty(TypeEnum::TTuple { ty: args, is_vararg_ctx: false }),
true,
)))
}
TypeEnum::TObj { params, obj_id, .. } => {
let subst = {
@ -797,7 +800,9 @@ impl InnerResolver {
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))
.collect();
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
Ok(types.map(|types| {
unifier.add_ty(TypeEnum::TTuple { ty: types, is_vararg_ctx: false })
}))
}
// special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below
@ -1196,7 +1201,9 @@ impl InnerResolver {
Ok(Some(ndarray.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {
unreachable!()
};
let tup_tys = ty.iter();
let elements: &PyTuple = obj.downcast()?;

View File

@ -1,15 +1,20 @@
use inkwell::types::BasicTypeEnum;
use inkwell::values::BasicValueEnum;
use inkwell::values::{BasicValueEnum, IntValue};
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools;
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
use crate::codegen::classes::{
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor,
};
use crate::codegen::expr::destructure_range;
use crate::codegen::irrt::calculate_len_for_slice_range;
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::Type;
use crate::typecheck::typedef::{Type, TypeEnum};
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
///
@ -21,6 +26,67 @@ fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -
)
}
/// Invokes the `len` builtin function.
pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let range_ty = ctx.primitives.range;
let (arg_ty, arg) = n;
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
let (start, end, step) = destructure_range(ctx, arg);
calculate_len_for_slice_range(generator, ctx, start, end, step)
} else {
match &*ctx.unifier.get_ty_immutable(arg_ty) {
TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
let zero = llvm_i32.const_zero();
let len = ctx
.build_gep_and_load(
arg.into_pointer_value(),
&[zero, llvm_i32.const_int(1, false)],
None,
)
.into_int_value();
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_usize = generator.get_size_type(ctx.ctx);
let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "")
.unwrap(),
"0:TypeError",
"len() of unsized object",
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
}
_ => unreachable!(),
}
})
}
/// Invokes the `int32` builtin function.
pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,

View File

@ -25,6 +25,7 @@ pub struct ConcreteFuncArg {
pub name: StrRef,
pub ty: ConcreteType,
pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
}
#[derive(Clone, Debug)]
@ -46,6 +47,7 @@ pub enum ConcreteTypeEnum {
TPrimitive(Primitive),
TTuple {
ty: Vec<ConcreteType>,
is_vararg_ctx: bool,
},
TObj {
obj_id: DefinitionId,
@ -102,8 +104,16 @@ impl ConcreteTypeStore {
.iter()
.map(|arg| ConcreteFuncArg {
name: arg.name,
ty: self.from_unifier_type(unifier, primitives, arg.ty, cache),
ty: if arg.is_vararg {
let tuple_ty = unifier
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
} else {
self.from_unifier_type(unifier, primitives, arg.ty, cache)
},
default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
})
.collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
@ -158,11 +168,12 @@ impl ConcreteTypeStore {
cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum {
TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple {
TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple {
ty: ty
.iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(),
is_vararg_ctx: *is_vararg_ctx,
},
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id,
@ -248,11 +259,12 @@ impl ConcreteTypeStore {
*cache.get_mut(&cty).unwrap() = Some(ty);
return ty;
}
ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple {
ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple {
ty: ty
.iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(),
is_vararg_ctx: *is_vararg_ctx,
},
ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
@ -277,6 +289,7 @@ impl ConcreteTypeStore {
name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(),
is_vararg: false,
})
.collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache),

View File

@ -267,13 +267,16 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
Constant::Tuple(v) => {
let ty = self.unifier.get_ty(ty);
let types =
if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() };
let (types, is_vararg_ctx) = if let TypeEnum::TTuple { ty, is_vararg_ctx } = &*ty {
(ty.clone(), *is_vararg_ctx)
} else {
unreachable!()
};
let values = zip(types, v.iter())
.map_while(|(ty, v)| self.gen_const(generator, v, ty))
.collect_vec();
if values.len() == v.len() {
if is_vararg_ctx || values.len() == v.len() {
let types = values.iter().map(BasicValueEnum::get_type).collect_vec();
let ty = self.ctx.struct_type(&types, false);
Some(ty.const_named_struct(&values).into())
@ -731,7 +734,15 @@ pub fn gen_func_instance<'ctx>(
let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
args.insert(0, ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None });
args.insert(
0,
ConcreteFuncArg {
name: "self".into(),
ty: zelf,
default_value: None,
is_vararg: false,
},
);
}
let signature = store.add_cty(signature);
@ -763,6 +774,9 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
let param_vals;
let is_extern;
// Ensure that the function object only contains up to 1 vararg parameter
debug_assert!(fun.0.args.iter().filter(|arg| arg.is_vararg).count() <= 1);
let symbol = {
// make sure this lock guard is dropped at the end of this scope...
let def = definition.read();
@ -779,18 +793,46 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
is_extern = instance_to_stmt.is_empty();
let old_key = ctx.get_subst_key(obj.as_ref().map(|a| a.0), fun.0, None);
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
let mut mapping = HashMap::<_, Vec<ValueEnum>>::new();
for (key, value) in params {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
// Find the matching argument
let matching_param = fun
.0
.args
.iter()
.find_or_last(|p| key.is_some_and(|k| k == p.name))
.unwrap();
if matching_param.is_vararg {
if key.is_none() && !keys.is_empty() {
keys.remove(0);
}
if let Some(param) = mapping.get_mut(&matching_param.name) {
param.push(value);
} else {
mapping.insert(key.unwrap_or(matching_param.name), vec![value]);
}
} else {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), vec![value]);
}
}
// default value handling
for k in keys {
if mapping.contains_key(&k.name) {
continue;
}
mapping.insert(
k.name,
ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into(),
if k.is_vararg {
Vec::default()
} else {
vec![ctx
.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty)
.into()]
},
);
}
// reorder the parameters
@ -801,13 +843,15 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
.map(|arg| (mapping.remove(&arg.name).unwrap(), arg.ty))
.collect_vec();
if let Some(obj) = &obj {
real_params.insert(0, (obj.1.clone(), obj.0));
real_params.insert(0, (vec![obj.1.clone()], obj.0));
}
let static_params = real_params
.iter()
.enumerate()
.filter_map(|(i, (v, _))| {
if let ValueEnum::Static(s) = v {
if v.len() != 1 {
None
} else if let ValueEnum::Static(s) = &v[0] {
Some((i, s.clone()))
} else {
None
@ -837,8 +881,13 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
};
param_vals = real_params
.into_iter()
.map(|(p, t)| p.to_basic_value_enum(ctx, generator, t))
.collect::<Result<Vec<_>, String>>()?;
.map(|(ps, t)| {
ps.into_iter().map(|p| p.to_basic_value_enum(ctx, generator, t)).collect()
})
.collect::<Result<Vec<Vec<_>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
instance_to_symbol.get(&key).cloned().ok_or_else(String::new)
}
TopLevelDef::Class { .. } => {
@ -852,7 +901,10 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
let fun_val = ctx.module.get_function(&symbol).unwrap_or_else(|| {
let mut args = fun.0.args.clone();
if let Some(obj) = &obj {
args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None });
args.insert(
0,
FuncArg { name: "self".into(), ty: obj.0, default_value: None, is_vararg: false },
);
}
let ret_type = if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
None

View File

@ -520,8 +520,10 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
};
return ty;
}
TTuple { ty } => {
TTuple { ty, is_vararg_ctx } => {
// a struct with fields in the order present in the tuple
assert!(!is_vararg_ctx, "Tuples in vararg context must be instantiated with the correct number of arguments before calling get_llvm_type");
let fields = ty
.iter()
.map(|ty| {
@ -700,6 +702,7 @@ pub fn gen_func_impl<
name: arg.name,
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
})
.collect_vec(),
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
@ -723,6 +726,19 @@ pub fn gen_func_impl<
let mut params = args
.iter()
.map(|arg| {
let base_ty = if arg.is_vararg {
let TypeEnum::TTuple { ty, is_vararg_ctx: true, .. } =
&*unifier.get_ty_immutable(arg.ty)
else {
unreachable!()
};
debug_assert_eq!(ty.len(), 1);
ty[0]
} else {
arg.ty
};
get_llvm_abi_type(
context,
&module,
@ -731,7 +747,7 @@ pub fn gen_func_impl<
top_level_ctx.as_ref(),
&mut type_cache,
&primitives,
arg.ty,
base_ty,
)
.into()
})
@ -741,9 +757,11 @@ pub fn gen_func_impl<
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
}
let is_vararg = args.iter().any(|arg| arg.is_vararg);
let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => context.void_type().fn_type(&params, false),
Some(ret_type) if !has_sret => ret_type.fn_type(&params, is_vararg),
_ => context.void_type().fn_type(&params, is_vararg),
};
let symbol = &task.symbol_name;

View File

@ -109,8 +109,18 @@ fn test_primitives() {
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature {
args: vec![
FuncArg { name: "a".into(), ty: primitives.int32, default_value: None },
FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "b".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
],
ret: primitives.int32,
vars: VarMap::new(),
@ -253,7 +263,12 @@ fn test_simple_call() {
unifier.top_level = Some(top_level.clone());
let signature = FunSignature {
args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
args: vec![FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
}],
ret: primitives.int32,
vars: VarMap::new(),
};

View File

@ -78,14 +78,14 @@ impl SymbolValue {
}
Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else {
return Err(format!(
"Expected {:?}, but got Tuple",
expected_ty.get_type_name()
));
};
assert_eq!(ty.len(), t.len());
assert!(*is_vararg_ctx || ty.len() == t.len());
let elems = t
.iter()
@ -155,7 +155,7 @@ impl SymbolValue {
SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false })
}
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
}
@ -482,7 +482,7 @@ pub fn parse_type_annotation<T>(
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }))
} else {
Err(HashSet::from(["Expected multiple elements for tuple".into()]))
}

View File

@ -14,9 +14,7 @@ use strum::IntoEnumIterator;
use crate::{
codegen::{
builtin_fns,
classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor},
expr::destructure_range,
irrt::*,
classes::{ProxyValue, RangeValue},
numpy::*,
stmt::exn_constructor,
},
@ -45,10 +43,26 @@ pub fn get_exn_constructor(
name: "msg".into(),
ty: string,
default_value: Some(SymbolValue::Str(String::new())),
is_vararg: false,
},
FuncArg {
name: "param0".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
FuncArg {
name: "param1".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
FuncArg {
name: "param2".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
FuncArg { name: "param0".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
FuncArg { name: "param1".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
FuncArg { name: "param2".into(), ty: int64, default_value: Some(SymbolValue::I64(0)) },
];
let exn_type = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(class_id),
@ -114,7 +128,12 @@ fn create_fn_by_codegen(
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty
.iter()
.map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
.map(|p| FuncArg {
name: p.1.into(),
ty: p.0,
default_value: None,
is_vararg: false,
})
.collect(),
ret: ret_ty,
vars: var_map.clone(),
@ -613,17 +632,24 @@ impl<'a> BuiltinBuilder<'a> {
let make_ctor_signature = |unifier: &mut Unifier| {
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "start".into(), ty: int32, default_value: None },
FuncArg {
name: "start".into(),
ty: int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "stop".into(),
ty: int32,
// placeholder
default_value: Some(SymbolValue::I32(0)),
is_vararg: false,
},
FuncArg {
name: "step".into(),
ty: int32,
default_value: Some(SymbolValue::I32(1)),
is_vararg: false,
},
],
ret: range,
@ -879,6 +905,7 @@ impl<'a> BuiltinBuilder<'a> {
name: "n".into(),
ty: self.option_tvar.ty,
default_value: None,
is_vararg: false,
}],
ret: self.primitives.option,
vars: into_var_map([self.option_tvar]),
@ -1013,6 +1040,7 @@ impl<'a> BuiltinBuilder<'a> {
name: "n".into(),
ty: self.num_or_ndarray_ty.ty,
default_value: None,
is_vararg: false,
}],
ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(),
@ -1232,16 +1260,23 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "object".into(), ty: tv.ty, default_value: None },
FuncArg {
name: "object".into(),
ty: tv.ty,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "copy".into(),
ty: bool,
default_value: Some(SymbolValue::Bool(true)),
is_vararg: false,
},
FuncArg {
name: "ndmin".into(),
ty: int32,
default_value: Some(SymbolValue::U32(0)),
is_vararg: false,
},
],
ret: ndarray,
@ -1283,17 +1318,24 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "N".into(), ty: int32, default_value: None },
FuncArg {
name: "N".into(),
ty: int32,
default_value: None,
is_vararg: false,
},
// TODO(Derppening): Default values current do not work?
FuncArg {
name: "M".into(),
ty: int32,
default_value: Some(SymbolValue::OptionNone),
is_vararg: false,
},
FuncArg {
name: "k".into(),
ty: int32,
default_value: Some(SymbolValue::I32(0)),
is_vararg: false,
},
],
ret: self.ndarray_float_2d,
@ -1337,7 +1379,12 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "s".into(), ty: str, default_value: None }],
args: vec![FuncArg {
name: "s".into(),
ty: str,
default_value: None,
is_vararg: false,
}],
ret: str,
vars: VarMap::default(),
})),
@ -1423,7 +1470,12 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "ls".into(), ty: arg_ty.ty, default_value: None }],
args: vec![FuncArg {
name: "ls".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
}],
ret: int32,
vars: into_var_map([tvar, arg_ty]),
})),
@ -1433,86 +1485,10 @@ impl<'a> BuiltinBuilder<'a> {
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
move |ctx, _, fun, args, generator| {
let range_ty = ctx.primitives.range;
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
let (start, end, step) = destructure_range(ctx, arg);
Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into())
} else {
match &*ctx.unifier.get_ty_immutable(arg_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let len = ctx
.build_gep_and_load(
arg.into_pointer_value(),
&[zero, int32.const_int(1, false)],
None,
)
.into_int_value();
if len.get_type().get_bit_width() == 32 {
Some(len.into())
} else {
Some(
ctx.builder
.build_int_truncate(len, int32, "len2i32")
.map(Into::into)
.unwrap(),
)
}
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let arg = NDArrayValue::from_ptr_val(
arg.into_pointer_value(),
llvm_usize,
None,
);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::NE,
ndims,
llvm_usize.const_zero(),
"",
)
.unwrap(),
"0:TypeError",
&format!("{name}() of unsized object", name = prim.name()),
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
if len.get_type().get_bit_width() == 32 {
Some(len.into())
} else {
Some(
ctx.builder
.build_int_truncate(len, llvm_i32, "len")
.map(Into::into)
.unwrap(),
)
}
}
_ => unreachable!(),
}
})
builtin_fns::call_len(generator, ctx, (arg_ty, arg)).map(|ret| Some(ret.into()))
},
)))),
loc: None,
@ -1528,8 +1504,18 @@ impl<'a> BuiltinBuilder<'a> {
simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "m".into(), ty: self.num_ty.ty, default_value: None },
FuncArg { name: "n".into(), ty: self.num_ty.ty, default_value: None },
FuncArg {
name: "m".into(),
ty: self.num_ty.ty,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "n".into(),
ty: self.num_ty.ty,
default_value: None,
is_vararg: false,
},
],
ret: self.num_ty.ty,
vars: self.num_var_map.clone(),
@ -1611,7 +1597,12 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty
.iter()
.map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
.map(|p| FuncArg {
name: p.1.into(),
ty: p.0,
default_value: None,
is_vararg: false,
})
.collect(),
ret: ret_ty.ty,
vars: into_var_map([x1_ty, x2_ty, ret_ty]),
@ -1652,6 +1643,7 @@ impl<'a> BuiltinBuilder<'a> {
name: "n".into(),
ty: self.num_or_ndarray_ty.ty,
default_value: None,
is_vararg: false,
}],
ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(),
@ -1840,7 +1832,12 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: param_ty
.iter()
.map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None })
.map(|p| FuncArg {
name: p.1.into(),
ty: p.0,
default_value: None,
is_vararg: false,
})
.collect(),
ret: ret_ty.ty,
vars: into_var_map([x1_ty, x2_ty, ret_ty]),

View File

@ -859,7 +859,72 @@ impl TopLevelComposer {
let resolver = &**resolver;
let mut function_var_map = VarMap::new();
let arg_types = {
let vararg = args
.vararg
.as_ref()
.map(|vararg| -> Result<_, HashSet<String>> {
let vararg = vararg.as_ref();
let annotation = vararg
.node
.annotation
.as_ref()
.ok_or_else(|| {
HashSet::from([format!(
"function parameter `{}` needs type annotation at {}",
vararg.node.arg, vararg.location
)])
})?
.as_ref();
let type_annotation = parse_ast_to_type_annotation_kinds(
resolver,
temp_def_list.as_slice(),
unifier,
primitives_store,
annotation,
// NOTE: since only class need this, for function
// it should be fine to be empty map
HashMap::new(),
)?;
let type_vars_within =
get_type_var_contained_in_type_annotation(&type_annotation)
.into_iter()
.map(|x| -> Result<TypeVar, HashSet<String>> {
let TypeAnnotation::TypeVar(ty) = x else {
unreachable!("must be type var annotation kind")
};
let id = Self::get_var_id(ty, unifier)?;
Ok(TypeVar { id, ty })
})
.collect::<Result<Vec<_>, _>>()?;
for var in type_vars_within {
if let Some(prev_ty) = function_var_map.insert(var.id, var.ty) {
// if already have the type inserted, make sure they are the same thing
assert_eq!(prev_ty, var.ty);
}
}
let ty = get_type_from_type_annotation_kinds(
temp_def_list.as_ref(),
unifier,
&type_annotation,
&mut None,
)?;
Ok(FuncArg {
name: vararg.node.arg,
ty,
default_value: Some(SymbolValue::Tuple(Vec::default())),
is_vararg: true,
})
})
.transpose()?;
let mut arg_types = {
// make sure no duplicate parameter
let mut defined_parameter_name: HashSet<_> = HashSet::new();
for x in &args.args {
@ -959,11 +1024,18 @@ impl TopLevelComposer {
v
}),
},
is_vararg: false,
})
})
.collect::<Result<Vec<_>, _>>()?
};
if let Some(vararg) = vararg {
arg_types.push(vararg);
};
let arg_types = arg_types;
let return_ty = {
if let Some(returns) = returns {
let return_ty_annotation = {
@ -1214,6 +1286,7 @@ impl TopLevelComposer {
})
}
},
is_vararg: false,
};
// push the dummy type and the type annotation
// into the list for later unification
@ -1638,21 +1711,25 @@ impl TopLevelComposer {
name: "msg".into(),
ty: string,
default_value: Some(SymbolValue::Str(String::new())),
is_vararg: false,
},
FuncArg {
name: "param0".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
FuncArg {
name: "param1".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
FuncArg {
name: "param2".into(),
ty: int64,
default_value: Some(SymbolValue::I64(0)),
is_vararg: false,
},
],
ret: self_type,
@ -1858,6 +1935,7 @@ impl TopLevelComposer {
name: a.name,
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
default_value: a.default_value.clone(),
is_vararg: false,
})
.collect_vec()
};

View File

@ -448,6 +448,7 @@ impl TopLevelComposer {
name: "value".into(),
ty: ndarray_dtype_tvar.ty,
default_value: None,
is_vararg: false,
}],
ret: none,
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),

View File

@ -502,7 +502,7 @@ pub fn get_type_from_type_annotation_kinds(
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys, is_vararg_ctx: false }))
}
}
}

View File

@ -218,7 +218,7 @@ impl<'a> Inferencer<'a> {
]
.iter()
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)),
TypeEnum::TTuple { ty, .. } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false,
}
}

View File

@ -197,6 +197,7 @@ pub fn impl_binop(
ty: other_ty,
default_value: None,
name: "other".into(),
is_vararg: false,
}],
})),
false,
@ -261,6 +262,7 @@ pub fn impl_cmpop(
ty: other_ty,
default_value: None,
name: "other".into(),
is_vararg: false,
}],
})),
false,

View File

@ -183,9 +183,10 @@ impl<'a> Display for DisplayTypeError<'a> {
}
result
}
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 })
if ty1.len() != ty2.len() =>
{
(
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) if !is_vararg1 && !is_vararg2 && ty1.len() != ty2.len() => {
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {t1} and {t2}")

View File

@ -763,7 +763,7 @@ impl<'a> Inferencer<'a> {
let fun = FunSignature {
args: fn_args
.iter()
.map(|(k, ty)| FuncArg { name: *k, ty: *ty, default_value: None })
.map(|(k, ty)| FuncArg { name: *k, ty: *ty, default_value: None, is_vararg: false })
.collect(),
ret,
vars: VarMap::default(),
@ -977,13 +977,14 @@ impl<'a> Inferencer<'a> {
]));
}
}
TypeEnum::TTuple { ty: tuple_element_types } => {
TypeEnum::TTuple { ty: tuple_element_types, .. } => {
// Handle 2. A tuple of int32s
// Typecheck
// The expected type is just the tuple but with all its elements being int32.
let expected_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: tuple_element_types.iter().map(|_| self.primitives.int32).collect_vec(),
is_vararg_ctx: false,
});
self.unifier.unify(shape_ty, expected_ty).map_err(|err| {
HashSet::from([err
@ -1110,6 +1111,7 @@ impl<'a> Inferencer<'a> {
name: "n".into(),
ty: arg0.custom.unwrap(),
default_value: None,
is_vararg: false,
}],
ret,
vars: VarMap::new(),
@ -1148,6 +1150,7 @@ impl<'a> Inferencer<'a> {
name: "a".into(),
ty: arg0.custom.unwrap(),
default_value: None,
is_vararg: false,
}],
ret,
vars: VarMap::new(),
@ -1248,8 +1251,18 @@ impl<'a> Inferencer<'a> {
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
FuncArg { name: "x2".into(), ty: arg1.custom.unwrap(), default_value: None },
FuncArg {
name: "x1".into(),
ty: arg0.custom.unwrap(),
default_value: None,
is_vararg: false,
},
FuncArg {
name: "x2".into(),
ty: arg1.custom.unwrap(),
default_value: None,
is_vararg: false,
},
],
ret,
vars: VarMap::new(),
@ -1329,6 +1342,7 @@ impl<'a> Inferencer<'a> {
name: "n".into(),
ty: arg0.custom.unwrap(),
default_value: None,
is_vararg: false,
}],
ret,
vars: VarMap::new(),
@ -1370,6 +1384,7 @@ impl<'a> Inferencer<'a> {
name: "shape".into(),
ty: shape.custom.unwrap(),
default_value: None,
is_vararg: false,
}],
ret,
vars: VarMap::new(),
@ -1413,11 +1428,17 @@ impl<'a> Inferencer<'a> {
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None },
FuncArg {
name: "shape".into(),
ty: arg0.custom.unwrap(),
default_value: None,
is_vararg: false,
},
FuncArg {
name: "fill_value".into(),
ty: arg1.custom.unwrap(),
default_value: None,
is_vararg: false,
},
],
ret,
@ -1472,16 +1493,19 @@ impl<'a> Inferencer<'a> {
name: "object".into(),
ty: arg0.custom.unwrap(),
default_value: None,
is_vararg: false,
},
FuncArg {
name: "copy".into(),
ty: self.primitives.bool,
default_value: Some(SymbolValue::Bool(true)),
is_vararg: false,
},
FuncArg {
name: "ndmin".into(),
ty: self.primitives.int32,
default_value: Some(SymbolValue::U32(0)),
is_vararg: false,
},
],
ret,
@ -1539,6 +1563,7 @@ impl<'a> Inferencer<'a> {
loc: Some(location),
operator_info: None,
};
println!("{}", self.unifier.stringify(func.custom.unwrap()));
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
})?;
@ -1603,7 +1628,7 @@ impl<'a> Inferencer<'a> {
ast::Constant::Tuple(vals) => {
let ty: Result<Vec<_>, _> =
vals.iter().map(|x| self.infer_constant(x, loc)).collect();
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? }))
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty?, is_vararg_ctx: false }))
}
ast::Constant::Str(_) => Ok(self.primitives.str),
ast::Constant::None => {
@ -1637,7 +1662,7 @@ impl<'a> Inferencer<'a> {
#[allow(clippy::unnecessary_wraps)]
fn infer_tuple(&mut self, elts: &[ast::Expr<Option<Type>>]) -> InferenceResult {
let ty = elts.iter().map(|x| x.custom.unwrap()).collect();
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }))
}
/// Checks for non-class attributes

View File

@ -83,7 +83,12 @@ impl TestEnvironment {
});
with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32,
vars: VarMap::new(),
}));
@ -224,7 +229,12 @@ impl TestEnvironment {
});
with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32,
vars: VarMap::new(),
}));

View File

@ -1,15 +1,14 @@
use indexmap::IndexMap;
use itertools::Itertools;
use itertools::{repeat_n, Itertools};
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::{self, Display};
use std::iter::zip;
use std::iter::{repeat_with, zip};
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet};
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use super::magic_methods::Binop;
use super::type_error::{TypeError, TypeErrorKind};
use super::unification_table::{UnificationKey, UnificationTable};
@ -115,6 +114,7 @@ pub struct FuncArg {
pub name: StrRef,
pub ty: Type,
pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
}
impl FuncArg {
@ -233,6 +233,12 @@ pub enum TypeEnum {
TTuple {
/// The types of elements present in this tuple.
ty: Vec<Type>,
/// Whether this tuple is used in a vararg context.
///
/// If `true`, `ty` must only contain one type, and the tuple is assumed to contain any
/// number of `ty`-typed values.
is_vararg_ctx: bool,
},
/// An object type.
@ -527,7 +533,7 @@ impl Unifier {
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
}),
TypeEnum::TTuple { ty } => {
TypeEnum::TTuple { ty, is_vararg_ctx } => {
let tuples = ty
.iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
@ -537,7 +543,12 @@ impl Unifier {
None
} else {
Some(
tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
tuples
.into_iter()
.map(|ty| {
self.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: *is_vararg_ctx })
})
.collect(),
)
}
}
@ -581,7 +592,7 @@ impl Unifier {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TTuple { ty, .. } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
}
@ -649,6 +660,7 @@ impl Unifier {
// Get details about the function signature/parameters.
let num_params = signature.args.len();
let is_vararg = signature.args.iter().any(|arg| arg.is_vararg);
// Force the type vars in `b` and `signature' to be up-to-date.
let b = self.instantiate_fun(b, signature);
@ -737,7 +749,7 @@ impl Unifier {
};
// Check for "too many arguments"
if num_params < posargs.len() {
if !is_vararg && num_params < posargs.len() {
let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params;
@ -761,7 +773,7 @@ impl Unifier {
.collect();
// Now consume all positional arguments and typecheck them.
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
for (param, &arg_ty) in zip(signature.args.iter().chain(repeat_with(|| signature.args.iter().last().unwrap())), posargs) {
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
let param_info = param_info_by_name.get_mut(&param.name).unwrap();
param_info.has_been_supplied = true;
@ -959,7 +971,10 @@ impl Unifier {
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
(
TVar { fields: Some(fields), range, is_const_generic: false, .. },
TTuple { ty, .. },
) => {
let len = i32::try_from(ty.len()).unwrap();
for (k, v) in fields {
match *k {
@ -1056,15 +1071,43 @@ impl Unifier {
self.set_a_to_b(a, b);
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
}
for (x, y) in ty1.iter().zip(ty2.iter()) {
if self.unify_impl(*x, *y, false).is_err() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
(
TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) => {
// Rules for Tuples:
// - ty1: is_vararg && ty2: is_vararg -> ty1[0] == ty2[0]
// - ty1: is_vararg && ty2: !is_vararg -> type error (not enough info to infer the correct number of arguments)
// - ty1: !is_vararg && ty2: is_vararg -> ty1[..] == ty2[0]
// - ty1: !is_vararg && ty2: !is_vararg -> ty1.len() == ty2.len() && ty1[i] == ty2[i]
debug_assert!(!is_vararg1 || ty1.len() == 1);
debug_assert!(!is_vararg2 || ty2.len() == 1);
match (*is_vararg1, *is_vararg2) {
(true, true) => {
if self.unify_impl(ty1[0], ty2[0], false).is_err() {
return Self::incompatible_types(a, b);
}
}
(true, false) => return Self::incompatible_types(a, b),
(false, true) => {
for y in ty2 {
if self.unify_impl(ty1[0], *y, false).is_err() {
return Self::incompatible_types(a, b);
}
}
}
(false, false) => {
for (x, y) in ty1.iter().zip(ty2.iter()) {
if self.unify_impl(*x, *y, false).is_err() {
return Self::incompatible_types(a, b);
}
}
}
}
self.set_a_to_b(a, b);
}
(TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => {
@ -1307,10 +1350,22 @@ impl Unifier {
TypeEnum::TLiteral { values, .. } => {
format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", "))
}
TypeEnum::TTuple { ty } => {
let mut fields =
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
format!("tuple[{}]", fields.join(", "))
TypeEnum::TTuple { ty, is_vararg_ctx } => {
if *is_vararg_ctx {
debug_assert_eq!(ty.len(), 1);
let field = self.internal_stringify(
*ty.iter().next().unwrap(),
obj_to_name,
var_to_name,
notes,
);
format!("tuple[*{field}]")
} else {
let mut fields = ty
.iter()
.map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
format!("tuple[{}]", fields.join(", "))
}
}
TypeEnum::TVirtual { ty } => {
format!(
@ -1335,17 +1390,21 @@ impl Unifier {
.args
.iter()
.map(|arg| {
let vararg_prefix = if arg.is_vararg { "*" } else { "" };
if let Some(dv) = &arg.default_value {
format!(
"{}:{}={}",
"{}:{}{}={}",
arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes),
dv
)
} else {
format!(
"{}:{}",
"{}:{}{}",
arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes)
)
}
@ -1431,7 +1490,7 @@ impl Unifier {
match &*ty {
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
TypeEnum::TTuple { ty } => {
TypeEnum::TTuple { ty, is_vararg_ctx } => {
let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() {
if let Some(t1) = self.subst_impl(*t, mapping, cache) {
@ -1439,7 +1498,10 @@ impl Unifier {
}
}
if matches!(new_ty, Cow::Owned(_)) {
Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() }))
Some(self.add_ty(TypeEnum::TTuple {
ty: new_ty.into_owned(),
is_vararg_ctx: *is_vararg_ctx,
}))
} else {
None
}
@ -1599,16 +1661,37 @@ impl Unifier {
}
}
(TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())),
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => {
let ty: Vec<_> = zip(ty1.iter(), ty2.iter())
.map(|(a, b)| self.get_intersection(*a, *b))
.try_collect()?;
if ty.iter().any(Option::is_some) {
Ok(Some(self.add_ty(TTuple {
ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
})))
(
TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) => {
if *is_vararg1 && *is_vararg2 {
let isect_ty = self.get_intersection(ty1[0], ty2[0])?;
Ok(isect_ty.map(|ty| self.add_ty(TTuple { ty: vec![ty], is_vararg_ctx: true })))
} else {
Ok(None)
let zip_iter: Box<dyn Iterator<Item = (&Type, &Type)>> =
match (*is_vararg1, *is_vararg2) {
(true, _) => Box::new(repeat_n(&ty1[0], ty2.len()).zip(ty2.iter())),
(_, false) => Box::new(ty1.iter().zip(repeat_n(&ty2[0], ty1.len()))),
_ => {
if ty1.len() != ty2.len() {
return Err(());
}
Box::new(ty1.iter().zip(ty2.iter()))
}
};
let ty: Vec<_> =
zip_iter.map(|(a, b)| self.get_intersection(*a, *b)).try_collect()?;
Ok(if ty.iter().any(Option::is_some) {
Some(self.add_ty(TTuple {
ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
is_vararg_ctx: false,
}))
} else {
None
})
}
}
// TODO(Derppening): #444

View File

@ -28,7 +28,10 @@ impl Unifier {
TypeEnum::TVar { fields: Some(map1), .. },
TypeEnum::TVar { fields: Some(map2), .. },
) => self.map_eq2(map1, map2),
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
(
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: false },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: false },
) => {
ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
}
@ -178,7 +181,7 @@ impl TestEnvironment {
ty.push(result.0);
s = result.1;
}
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
(self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }), &s[1..])
}
"Record" => {
let mut s = &typ[end..];
@ -608,7 +611,7 @@ fn test_instantiation() {
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
let t = env.unifier.get_dummy_var().ty;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2], is_vararg_ctx: false });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
// t = TypeVar('t')
// v = TypeVar('v', int, bool)

View File

@ -0,0 +1,11 @@
def f(*args: int32):
pass
def run() -> int32:
f()
f(1)
f(1, 2)
f(1, 2, 3)
return 0