Compare commits

...

4 Commits

14 changed files with 613 additions and 213 deletions

View File

@ -26,7 +26,7 @@ pub struct Location {
impl fmt::Display for Location { impl fmt::Display for Location {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}: line {} column {}", self.file.0, self.row, self.column) write!(f, "{}:{}:{}", self.file.0, self.row, self.column)
} }
} }

View File

@ -207,12 +207,12 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
value: &Constant, value: &Constant,
ty: Type, ty: Type,
) -> BasicValueEnum<'ctx> { ) -> Option<BasicValueEnum<'ctx>> {
match value { match value {
Constant::Bool(v) => { Constant::Bool(v) => {
assert!(self.unifier.unioned(ty, self.primitives.bool)); assert!(self.unifier.unioned(ty, self.primitives.bool));
let ty = self.ctx.i8_type(); let ty = self.ctx.i8_type();
ty.const_int(if *v { 1 } else { 0 }, false).into() Some(ty.const_int(if *v { 1 } else { 0 }, false).into())
} }
Constant::Int(val) => { Constant::Int(val) => {
let ty = if self.unifier.unioned(ty, self.primitives.int32) let ty = if self.unifier.unioned(ty, self.primitives.int32)
@ -226,28 +226,33 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} else { } else {
unreachable!(); unreachable!();
}; };
ty.const_int(*val as u64, false).into() Some(ty.const_int(*val as u64, false).into())
} }
Constant::Float(v) => { Constant::Float(v) => {
assert!(self.unifier.unioned(ty, self.primitives.float)); assert!(self.unifier.unioned(ty, self.primitives.float));
let ty = self.ctx.f64_type(); let ty = self.ctx.f64_type();
ty.const_float(*v).into() Some(ty.const_float(*v).into())
} }
Constant::Tuple(v) => { Constant::Tuple(v) => {
let ty = self.unifier.get_ty(ty); let ty = self.unifier.get_ty(ty);
let types = let types =
if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() }; if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() };
let values = zip(types.into_iter(), v.iter()) let values = zip(types.into_iter(), v.iter())
.map(|(ty, v)| self.gen_const(generator, v, ty)) .map_while(|(ty, v)| self.gen_const(generator, v, ty))
.collect_vec(); .collect_vec();
let types = values.iter().map(BasicValueEnum::get_type).collect_vec();
let ty = self.ctx.struct_type(&types, false); if values.len() == v.len() {
ty.const_named_struct(&values).into() let types = values.iter().map(BasicValueEnum::get_type).collect_vec();
let ty = self.ctx.struct_type(&types, false);
Some(ty.const_named_struct(&values).into())
} else {
None
}
} }
Constant::Str(v) => { Constant::Str(v) => {
assert!(self.unifier.unioned(ty, self.primitives.str)); assert!(self.unifier.unioned(ty, self.primitives.str));
if let Some(v) = self.const_strings.get(v) { if let Some(v) = self.const_strings.get(v) {
*v Some(*v)
} else { } else {
let str_ptr = let str_ptr =
self.builder.build_global_string_ptr(v, "const").as_pointer_value().into(); self.builder.build_global_string_ptr(v, "const").as_pointer_value().into();
@ -256,9 +261,22 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let val = let val =
ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into();
self.const_strings.insert(v.to_string(), val); self.const_strings.insert(v.to_string(), val);
val Some(val)
} }
} }
Constant::Ellipsis => {
let msg = self.gen_string(generator, "");
self.raise_exn(
generator,
"0:NotImplementedError",
msg,
[None, None, None],
self.current_loc,
);
None
}
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -481,7 +499,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
s: S, s: S,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str) self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str).unwrap()
} }
pub fn raise_exn( pub fn raise_exn(
@ -1211,7 +1229,10 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
Ok(Some(match &expr.node { Ok(Some(match &expr.node {
ExprKind::Constant { value, .. } => { ExprKind::Constant { value, .. } => {
let ty = expr.custom.unwrap(); let ty = expr.custom.unwrap();
ctx.gen_const(generator, value, ty).into() let Some(const_val) = ctx.gen_const(generator, value, ty) else {
return Ok(None)
};
const_val.into()
} }
ExprKind::Name { id, .. } if id == &"none".into() => { ExprKind::Name { id, .. } if id == &"none".into() => {
match ( match (

View File

@ -604,7 +604,7 @@ pub fn exn_constructor<'ctx, 'a>(
let msg = if !args.is_empty() { let msg = if !args.is_empty() {
args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.str)? args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.str)?
} else { } else {
empty_string empty_string.unwrap()
}; };
ctx.builder.build_store(ptr, msg); ctx.builder.build_store(ptr, msg);
for i in [6, 7, 8].iter() { for i in [6, 7, 8].iter() {
@ -627,7 +627,7 @@ pub fn exn_constructor<'ctx, 'a>(
&[zero, int32.const_int(*i, false)], &[zero, int32.const_int(*i, false)],
"exn.str", "exn.str",
); );
ctx.builder.build_store(ptr, empty_string); ctx.builder.build_store(ptr, empty_string.unwrap());
} }
// set ints to zero // set ints to zero
for i in [2, 3].iter() { for i in [2, 3].iter() {

View File

@ -1,11 +1,12 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::Arc; use std::sync::Arc;
use std::{collections::HashMap, fmt::Display}; use std::{collections::HashMap, fmt::Display};
use std::rc::Rc;
use crate::typecheck::typedef::TypeEnum; use crate::typecheck::typedef::TypeEnum;
use crate::{ use crate::{
codegen::CodeGenContext, codegen::CodeGenContext,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation},
}; };
use crate::{ use crate::{
codegen::CodeGenerator, codegen::CodeGenerator,
@ -16,7 +17,7 @@ use crate::{
}; };
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip}; use itertools::{chain, izip};
use nac3parser::ast::{Expr, Location, StrRef}; use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
@ -33,6 +34,147 @@ pub enum SymbolValue {
OptionNone, OptionNone,
} }
impl SymbolValue {
/// Creates a [SymbolValue] from a [Constant].
///
/// * `constant` - The constant to create the value from.
/// * `expected_ty` - The expected type of the [SymbolValue].
pub fn from_constant(
constant: &Constant,
expected_ty: Type,
primitives: &PrimitiveStore,
unifier: &mut Unifier
) -> Result<Self, String> {
match constant {
Constant::None => {
if unifier.unioned(expected_ty, primitives.option) {
Ok(SymbolValue::OptionNone)
} else {
Err(format!("Expected {:?}, but got Option", expected_ty))
}
}
Constant::Bool(b) => {
if unifier.unioned(expected_ty, primitives.bool) {
Ok(SymbolValue::Bool(*b))
} else {
Err(format!("Expected {:?}, but got bool", expected_ty))
}
}
Constant::Str(s) => {
if unifier.unioned(expected_ty, primitives.str) {
Ok(SymbolValue::Str(s.to_string()))
} else {
Err(format!("Expected {:?}, but got str", expected_ty))
}
},
Constant::Int(i) => {
if unifier.unioned(expected_ty, primitives.int32) {
i32::try_from(*i)
.map(|val| SymbolValue::I32(val))
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.int64) {
i64::try_from(*i)
.map(|val| SymbolValue::I64(val))
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint32) {
u32::try_from(*i)
.map(|val| SymbolValue::U32(val))
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint64) {
u64::try_from(*i)
.map(|val| SymbolValue::U64(val))
.map_err(|e| e.to_string())
} else {
Err(format!("Expected {:?}, but got int", expected_ty))
}
}
Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
return Err(format!("Expected {:?}, but got Tuple", expected_ty.get_type_name()))
};
assert_eq!(ty.len(), t.len());
let elems = t.into_iter()
.zip(ty)
.map(|(constant, ty)| Self::from_constant(constant, *ty, primitives, unifier))
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => {
if unifier.unioned(expected_ty, primitives.float) {
Ok(SymbolValue::Double(*f))
} else {
Err(format!("Expected {:?}, but got float", expected_ty))
}
},
_ => Err(format!("Unsupported value type {:?}", constant)),
}
}
/// Returns the [Type] representing the data type of this value.
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
match self {
SymbolValue::I32(_) => primitives.int32,
SymbolValue::I64(_) => primitives.int64,
SymbolValue::U32(_) => primitives.uint32,
SymbolValue::U64(_) => primitives.uint64,
SymbolValue::Str(_) => primitives.str,
SymbolValue::Double(_) => primitives.float,
SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
.map(|v| v.get_type(primitives, unifier))
.collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple {
ty: vs_tys,
})
}
SymbolValue::OptionSome(_) => primitives.option,
SymbolValue::OptionNone => primitives.option,
}
}
/// Returns the [TypeAnnotation] representing the data type of this value.
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
match self {
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool),
SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float),
SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32),
SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64),
SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32),
SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64),
SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str),
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
.map(|v| v.get_type_annotation(primitives, unifier))
.collect::<Vec<_>>();
TypeAnnotation::Tuple(vs_tys)
}
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.get_obj_id(unifier),
params: Default::default(),
},
SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass {
id: primitives.option.get_obj_id(unifier),
params: vec![ty],
}
}
}
}
/// Returns the [TypeEnum] representing the data type of this value.
pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty)
}
}
impl Display for SymbolValue { impl Display for SymbolValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {

View File

@ -58,6 +58,7 @@ impl TopLevelComposer {
let mut unifier = primitives.1; let mut unifier = primitives.1;
let mut keyword_list: HashSet<StrRef> = HashSet::from_iter(vec![ let mut keyword_list: HashSet<StrRef> = HashSet::from_iter(vec![
"Generic".into(), "Generic".into(),
"Const".into(),
"virtual".into(), "virtual".into(),
"list".into(), "list".into(),
"tuple".into(), "tuple".into(),
@ -401,6 +402,7 @@ impl TopLevelComposer {
let class_resolver = class_resolver.deref(); let class_resolver = class_resolver.deref();
let mut is_generic = false; let mut is_generic = false;
let mut is_const_generic = false;
for b in class_bases_ast { for b in class_bases_ast {
match &b.node { match &b.node {
// analyze typevars bounded to the class, // analyze typevars bounded to the class,
@ -408,66 +410,77 @@ impl TopLevelComposer {
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported // things like `class A(Generic[T, V, ImportedModule.T])` is not supported
// i.e. only simple names are allowed in the subscript // i.e. only simple names are allowed in the subscript
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params // should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. } => {
if { match &value.node {
matches!( ast::ExprKind::Name { id, .. } if id == &"Generic".into() || id == &"Const".into() => {
&value.node, if id == &"Generic".into() {
ast::ExprKind::Name { id, .. } if id == &"Generic".into() if !is_generic {
) is_generic = true;
} => } else {
{ return Err(format!(
if !is_generic { "only single Generic[...] is allowed (at {})",
is_generic = true; b.location
} else { ));
return Err(format!( }
"only single Generic[...] is allowed (at {})", } else if id == &"Const".into() {
b.location if !is_const_generic {
)); is_const_generic = true;
} } else {
return Err(format!(
let type_var_list: Vec<&ast::Expr<()>>; "only single Const[...] is allowed (at {})",
// if `class A(Generic[T, V, G])` b.location
if let ast::ExprKind::Tuple { elts, .. } = &slice.node { ));
type_var_list = elts.iter().collect_vec(); }
// `class A(Generic[T])`
} else {
type_var_list = vec![slice.deref()];
}
// parse the type vars
let type_vars = type_var_list
.into_iter()
.map(|e| {
class_resolver.parse_type_annotation(
&temp_def_list,
unifier,
primitives_store,
e,
)
})
.collect::<Result<Vec<_>, _>>()?;
// check if all are unique type vars
let all_unique_type_var = {
let mut occurred_type_var_id: HashSet<u32> = HashSet::new();
type_vars.iter().all(|x| {
let ty = unifier.get_ty(*x);
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
occurred_type_var_id.insert(*id)
} else { } else {
false unreachable!()
} }
})
};
if !all_unique_type_var {
return Err(format!(
"duplicate type variable occurs (at {})",
slice.location
));
}
// add to TopLevelDef let type_var_list: Vec<&ast::Expr<()>>;
class_def_type_vars.extend(type_vars); // if `class A(Generic[T, V, G])`
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
type_var_list = elts.iter().collect_vec();
// `class A(Generic[T])`
} else {
type_var_list = vec![slice.deref()];
}
// parse the type vars
let type_vars = type_var_list
.into_iter()
.map(|e| {
class_resolver.parse_type_annotation(
&temp_def_list,
unifier,
primitives_store,
e,
)
})
.collect::<Result<Vec<_>, _>>()?;
// check if all are unique type vars
let all_unique_type_var = {
let mut occurred_type_var_id: HashSet<u32> = HashSet::new();
type_vars.iter().all(|x| {
let ty = unifier.get_ty(*x);
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
occurred_type_var_id.insert(*id)
} else {
false
}
})
};
if !all_unique_type_var {
return Err(format!(
"duplicate type variable occurs (at {})",
slice.location
));
}
// add to TopLevelDef
class_def_type_vars.extend(type_vars);
}
_ => continue,
}
} }
// if others, do nothing in this function // if others, do nothing in this function
@ -536,7 +549,7 @@ impl TopLevelComposer {
ast::ExprKind::Subscript { value, .. } ast::ExprKind::Subscript { value, .. }
if matches!( if matches!(
&value.node, &value.node,
ast::ExprKind::Name { id, .. } if id == &"Generic".into() ast::ExprKind::Name { id, .. } if id == &"Generic".into() || id == &"Const".into()
) )
) { ) {
continue; continue;
@ -560,6 +573,7 @@ impl TopLevelComposer {
&primitive_types, &primitive_types,
b, b,
vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(), vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(),
None,
)?; )?;
if let TypeAnnotation::CustomClass { .. } = &base_ty { if let TypeAnnotation::CustomClass { .. } = &base_ty {
@ -894,6 +908,7 @@ impl TopLevelComposer {
// NOTE: since only class need this, for function // NOTE: since only class need this, for function
// it should be fine to be empty map // it should be fine to be empty map
HashMap::new(), HashMap::new(),
None,
)?; )?;
let type_vars_within = let type_vars_within =
@ -961,6 +976,7 @@ impl TopLevelComposer {
// NOTE: since only class need this, for function // NOTE: since only class need this, for function
// it should be fine to be empty map // it should be fine to be empty map
HashMap::new(), HashMap::new(),
None,
)? )?
}; };
@ -1158,6 +1174,7 @@ impl TopLevelComposer {
vec![(class_id, class_type_vars_def.clone())] vec![(class_id, class_type_vars_def.clone())]
.into_iter() .into_iter()
.collect(), .collect(),
None,
)? )?
}; };
// find type vars within this method parameter type annotation // find type vars within this method parameter type annotation
@ -1221,6 +1238,7 @@ impl TopLevelComposer {
primitives, primitives,
result, result,
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
None,
)?; )?;
// find type vars within this return type annotation // find type vars within this return type annotation
let type_vars_within = let type_vars_within =
@ -1317,6 +1335,7 @@ impl TopLevelComposer {
primitives, primitives,
annotation.as_ref(), annotation.as_ref(),
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), vec![(class_id, class_type_vars_def.clone())].into_iter().collect(),
None,
)?; )?;
// find type vars within this return type annotation // find type vars within this return type annotation
let type_vars_within = let type_vars_within =
@ -1735,7 +1754,7 @@ impl TopLevelComposer {
.iter() .iter()
.map(|(_, ty)| { .map(|(_, ty)| {
unifier.get_instantiations(*ty).unwrap_or_else(|| { unifier.get_instantiations(*ty).unwrap_or_else(|| {
if let TypeEnum::TVar { name, loc, .. } = &*unifier.get_ty(*ty) if let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty)
{ {
let rigid = unifier.get_fresh_rigid_var(*name, *loc).0; let rigid = unifier.get_fresh_rigid_var(*name, *loc).0;
no_ranges.push(rigid); no_ranges.push(rigid);

View File

@ -416,40 +416,6 @@ impl TopLevelComposer {
primitive: &PrimitiveStore, primitive: &PrimitiveStore,
unifier: &mut Unifier, unifier: &mut Unifier,
) -> Result<(), String> { ) -> Result<(), String> {
fn type_default_param(
val: &SymbolValue,
primitive: &PrimitiveStore,
unifier: &mut Unifier,
) -> TypeAnnotation {
match val {
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitive.bool),
SymbolValue::Double(..) => TypeAnnotation::Primitive(primitive.float),
SymbolValue::I32(..) => TypeAnnotation::Primitive(primitive.int32),
SymbolValue::I64(..) => TypeAnnotation::Primitive(primitive.int64),
SymbolValue::U32(..) => TypeAnnotation::Primitive(primitive.uint32),
SymbolValue::U64(..) => TypeAnnotation::Primitive(primitive.uint64),
SymbolValue::Str(..) => TypeAnnotation::Primitive(primitive.str),
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
.map(|v| type_default_param(v, primitive, unifier))
.collect::<Vec<_>>();
TypeAnnotation::Tuple(vs_tys)
}
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitive.option.get_obj_id(unifier),
params: Default::default(),
},
SymbolValue::OptionSome(v) => {
let ty = type_default_param(v, primitive, unifier);
TypeAnnotation::CustomClass {
id: primitive.option.get_obj_id(unifier),
params: vec![ty],
}
}
}
}
fn is_compatible( fn is_compatible(
found: &TypeAnnotation, found: &TypeAnnotation,
expect: &TypeAnnotation, expect: &TypeAnnotation,
@ -481,7 +447,7 @@ impl TopLevelComposer {
} }
} }
let found = type_default_param(val, primitive, unifier); let found = val.get_type_annotation(primitive, unifier);
if !is_compatible(&found, ty, unifier, primitive) { if !is_compatible(&found, ty, unifier, primitive) {
Err(format!( Err(format!(
"incompatible default parameter type, expect {}, found {}", "incompatible default parameter type, expect {}, found {}",

View File

@ -361,7 +361,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["application of type vars to generic class is not currently supported (at unknown: line 4 column 24)"]; vec!["application of type vars to generic class is not currently supported (at unknown:4:24)"];
"err no type var in generic app" "err no type var in generic app"
)] )]
#[test_case( #[test_case(
@ -417,7 +417,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
def __init__(): def __init__():
pass pass
"}], "}],
vec!["__init__ method must have a `self` parameter (at unknown: line 2 column 5)"]; vec!["__init__ method must have a `self` parameter (at unknown:2:5)"];
"err no self_1" "err no self_1"
)] )]
#[test_case( #[test_case(
@ -439,7 +439,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
"} "}
], ],
vec!["a class definition can only have at most one base class declaration and one generic declaration (at unknown: line 1 column 24)"]; vec!["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"];
"err multiple inheritance" "err multiple inheritance"
)] )]
#[test_case( #[test_case(
@ -507,7 +507,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["duplicate definition of class `A` (at unknown: line 1 column 1)"]; vec!["duplicate definition of class `A` (at unknown:1:1)"];
"class same name" "class same name"
)] )]
fn test_analyze(source: Vec<&str>, res: Vec<&str>) { fn test_analyze(source: Vec<&str>, res: Vec<&str>) {

View File

@ -1,3 +1,4 @@
use crate::symbol_resolver::SymbolValue;
use super::*; use super::*;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -12,6 +13,16 @@ pub enum TypeAnnotation {
// can only be CustomClassKind // can only be CustomClassKind
Virtual(Box<TypeAnnotation>), Virtual(Box<TypeAnnotation>),
TypeVar(Type), TypeVar(Type),
/// A constant used in the context of a const-generic variable.
Constant {
/// The non-type variable associated with this constant.
///
/// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the
/// const generic variable of which this constant is associated with.
ty: Type,
/// The constant value of this constant.
value: SymbolValue
},
List(Box<TypeAnnotation>), List(Box<TypeAnnotation>),
Tuple(Vec<TypeAnnotation>), Tuple(Vec<TypeAnnotation>),
} }
@ -47,6 +58,7 @@ impl TypeAnnotation {
} }
) )
} }
Constant { value, .. } => format!("Const({value})"),
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)), List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => { Tuple(types) => {
@ -56,6 +68,12 @@ impl TypeAnnotation {
} }
} }
/// Parses an AST expression `expr` into a [TypeAnnotation].
///
/// * `locked` - A [HashMap] containing the IDs of known definitions, mapped to a [Vec] of all
/// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [None] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T>( pub fn parse_ast_to_type_annotation_kinds<T>(
resolver: &(dyn SymbolResolver + Send + Sync), resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
@ -64,6 +82,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
expr: &ast::Expr<T>, expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here // the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>>, locked: HashMap<DefinitionId, Vec<Type>>,
type_var: Option<Type>,
) -> Result<TypeAnnotation, String> { ) -> Result<TypeAnnotation, String> {
let name_handle = |id: &StrRef, let name_handle = |id: &StrRef,
unifier: &mut Unifier, unifier: &mut Unifier,
@ -127,7 +146,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
slice: &ast::Expr<T>, slice: &ast::Expr<T>,
unifier: &mut Unifier, unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>>| { mut locked: HashMap<DefinitionId, Vec<Type>>| {
if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()].contains(id) if vec!["virtual".into(), "Generic".into(), "Const".into(), "list".into(), "tuple".into(), "Option".into()].contains(id)
{ {
return Err(format!("keywords cannot be class name (at {})", expr.location)); return Err(format!("keywords cannot be class name (at {})", expr.location));
} }
@ -161,7 +180,8 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} }
let result = params_ast let result = params_ast
.iter() .iter()
.map(|x| { .enumerate()
.map(|(idx, x)| {
parse_ast_to_type_annotation_kinds( parse_ast_to_type_annotation_kinds(
resolver, resolver,
top_level_defs, top_level_defs,
@ -172,6 +192,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
locked.insert(obj_id, type_vars.clone()); locked.insert(obj_id, type_vars.clone());
locked.clone() locked.clone()
}, },
Some(type_vars[idx]),
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -190,6 +211,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
}; };
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
}; };
match &expr.node { match &expr.node {
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked), ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
// virtual // virtual
@ -205,6 +227,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
if !matches!(def, TypeAnnotation::CustomClass { .. }) { if !matches!(def, TypeAnnotation::CustomClass { .. }) {
unreachable!("must be concretized custom class kind in the virtual") unreachable!("must be concretized custom class kind in the virtual")
@ -225,6 +248,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
Ok(TypeAnnotation::List(def_ann.into())) Ok(TypeAnnotation::List(def_ann.into()))
} }
@ -242,6 +266,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
let id = let id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() { if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
@ -275,6 +300,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
e, e,
locked.clone(), locked.clone(),
None,
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -290,6 +316,31 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} }
} }
ast::ExprKind::Constant { value, .. } => {
let type_var = type_var.expect("Expect type variable to be present");
let ntv_ty_enum = unifier.get_ty_immutable(type_var);
let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else {
unreachable!()
};
let underlying_ty = underlying_ty[0];
let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier)?;
if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) {
return Err(format!(
"expression {} is not allowed for constant type annotation (at {})",
value.to_string(),
expr.location
))
}
Ok(TypeAnnotation::Constant {
ty: type_var,
value,
})
}
_ => Err(format!("unsupported expression for type annotation (at {})", expr.location)), _ => Err(format!("unsupported expression for type annotation (at {})", expr.location)),
} }
} }
@ -308,94 +359,130 @@ pub fn get_type_from_type_annotation_kinds(
TypeAnnotation::CustomClass { id: obj_id, params } => { TypeAnnotation::CustomClass { id: obj_id, params } => {
let def_read = top_level_defs[obj_id.0].read(); let def_read = top_level_defs[obj_id.0].read();
let class_def: &TopLevelDef = def_read.deref(); let class_def: &TopLevelDef = def_read.deref();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def { let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else {
if type_vars.len() != params.len() { unreachable!("should be class def here")
Err(format!( };
"unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
params.len()
))
} else {
let param_ty = params
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list
)
})
.collect::<Result<Vec<_>, _>>()?;
let subst = { if type_vars.len() != params.len() {
// check for compatible range return Err(format!(
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check "unexpected number of type parameters: expected {} but got {}",
let mut result: HashMap<u32, Type> = HashMap::new(); type_vars.len(),
for (tvar, p) in type_vars.iter().zip(param_ty) { params.len()
if let TypeEnum::TVar { id, range, fields: None, name, loc } = ))
unifier.get_ty(*tvar).as_ref() }
{
let ok: bool = { let param_ty = params
// create a temp type var and unify to check compatibility .iter()
p == *tvar || { .map(|x| {
let temp = unifier.get_fresh_var_with_range( get_type_from_type_annotation_kinds(
range.as_slice(), top_level_defs,
*name, unifier,
*loc, primitives,
); x,
unifier.unify(temp.0, p).is_ok() subst_list
} )
}; })
if ok { .collect::<Result<Vec<_>, _>>()?;
result.insert(*id, p);
} else { let subst = {
return Err(format!( // check for compatible range
"cannot apply type {} to type variable with id {:?}", // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
unifier.internal_stringify( let mut result: HashMap<u32, Type> = HashMap::new();
p, for (tvar, p) in type_vars.iter().zip(param_ty) {
&mut |id| format!("class{}", id), match unifier.get_ty(*tvar).as_ref() {
&mut |id| format!("typevar{}", id), TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => {
&mut None let ok: bool = {
), // create a temp type var and unify to check compatibility
*id p == *tvar || {
)); let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.0, p).is_ok()
} }
};
if ok {
result.insert(*id, p);
} else { } else {
unreachable!("must be generic type var") return Err(format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{}", id),
&mut |id| format!("typevar{}", id),
&mut None
),
*id
));
} }
} }
result
}; TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
let mut tobj_fields = methods let ty = range[0];
.iter() let ok: bool = {
.map(|(name, ty, _)| { // create a temp type var and unify to check compatibility
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); p == *tvar || {
// methods are immutable let temp = unifier.get_fresh_const_generic_var(
(*name, (subst_ty, false)) ty,
}) *name,
.collect::<HashMap<_, _>>(); *loc,
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| { );
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty); unifier.unify(temp.0, p).is_ok()
(*name, (subst_ty, *mutability)) }
})); };
let need_subst = !subst.is_empty(); if ok {
let ty = unifier.add_ty(TypeEnum::TObj { result.insert(*id, p);
obj_id: *obj_id, } else {
fields: tobj_fields, return Err(format!(
params: subst, "cannot apply type {} to type variable {}",
}); unifier.stringify(p),
if need_subst { name.unwrap_or_else(|| format!("typevar{id}").into()),
subst_list.as_mut().map(|wl| wl.push(ty)); ))
}
}
_ => unreachable!("must be generic type var"),
} }
Ok(ty)
} }
} else { result
unreachable!("should be class def here") };
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
subst_list.as_mut().map(|wl| wl.push(ty));
} }
Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Constant { ty, value, .. } => {
let ty_enum = unifier.get_ty(*ty);
let (ty, loc) = match &*ty_enum {
TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } => {
(ntv_underlying_ty[0], loc)
}
_ => unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()),
};
let var = unifier.get_fresh_constant(value.clone(), ty, *loc);
Ok(var)
}
TypeAnnotation::Virtual(ty) => { TypeAnnotation::Virtual(ty) => {
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
@ -470,7 +557,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
result.extend(get_type_var_contained_in_type_annotation(a)); result.extend(get_type_var_contained_in_type_annotation(a));
} }
} }
TypeAnnotation::Primitive(..) => {} TypeAnnotation::Primitive(..) | TypeAnnotation::Constant { .. } => {}
} }
result result
} }

View File

@ -62,7 +62,7 @@ impl<'a> Inferencer<'a> {
) -> Result<(), String> { ) -> Result<(), String> {
// there are some cases where the custom field is None // there are some cases where the custom field is None
if let Some(ty) = &expr.custom { if let Some(ty) = &expr.custom {
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
return Err(format!( return Err(format!(
"expected concrete type at {} but got {}", "expected concrete type at {} but got {}",
expr.location, expr.location,

View File

@ -964,6 +964,7 @@ impl<'a> Inferencer<'a> {
ast::Constant::Str(_) => Ok(self.primitives.str), ast::Constant::Str(_) => Ok(self.primitives.str),
ast::Constant::None ast::Constant::None
=> report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc), => report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc),
ast::Constant::Ellipsis => Ok(self.unifier.get_fresh_var(None, None).0),
_ => report_error("not supported", *loc), _ => report_error("not supported", *loc),
} }
} }

View File

@ -134,6 +134,17 @@ pub enum TypeEnum {
range: Vec<Type>, range: Vec<Type>,
name: Option<StrRef>, name: Option<StrRef>,
loc: Option<Location>, loc: Option<Location>,
/// Whether this type variable refers to a const-generic variable.
is_const_generic: bool,
},
/// A constant for substitution into a const generic variable.
TConstant {
/// The value of the constant.
value: SymbolValue,
/// The underlying type of the value.
ty: Type,
loc: Option<Location>,
}, },
/// A tuple type. /// A tuple type.
@ -178,6 +189,7 @@ impl TypeEnum {
match self { match self {
TypeEnum::TRigidVar { .. } => "TRigidVar", TypeEnum::TRigidVar { .. } => "TRigidVar",
TypeEnum::TVar { .. } => "TVar", TypeEnum::TVar { .. } => "TVar",
TypeEnum::TConstant { .. } => "TConstant",
TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TTuple { .. } => "TTuple",
TypeEnum::TList { .. } => "TList", TypeEnum::TList { .. } => "TList",
TypeEnum::TObj { .. } => "TObj", TypeEnum::TObj { .. } => "TObj",
@ -263,6 +275,7 @@ impl Unifier {
fields: Some(fields), fields: Some(fields),
name: None, name: None,
loc: None, loc: None,
is_const_generic: false,
}) })
} }
@ -336,7 +349,33 @@ impl Unifier {
let id = self.var_id + 1; let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
let range = range.to_vec(); let range = range.to_vec();
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc }), id) (self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id)
}
/// Returns a fresh type representing a constant generic variable with the given underlying type
/// `ty`.
pub fn get_fresh_const_generic_var(
&mut self,
ty: Type,
name: Option<StrRef>,
loc: Option<Location>,
) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1;
(self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id)
}
/// Returns a fresh type representing a [fresh constant][TypeEnum::TConstant] with the given
/// `value` and type `ty`.
pub fn get_fresh_constant(
&mut self,
value: SymbolValue,
ty: Type,
loc: Option<Location>,
) -> Type {
assert!(matches!(self.get_ty(ty).as_ref(), TypeEnum::TObj { .. }));
self.add_ty(TypeEnum::TConstant { ty, value, loc })
} }
/// Unification would not unify rigid variables with other types, but we want to do this for /// Unification would not unify rigid variables with other types, but we want to do this for
@ -412,7 +451,7 @@ impl Unifier {
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*; use TypeEnum::*;
match &*self.get_ty(a) { match &*self.get_ty(a) {
TRigidVar { .. } => true, TRigidVar { .. } | TConstant { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TList { ty } => self.is_concrete(*ty, allowed_typevars), TList { ty } => self.is_concrete(*ty, allowed_typevars),
@ -560,8 +599,8 @@ impl Unifier {
}; };
match (&*ty_a, &*ty_b) { match (&*ty_a, &*ty_b) {
( (
TVar { fields: fields1, id, name: name1, loc: loc1, .. }, TVar { fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, .. },
TVar { fields: fields2, id: id2, name: name2, loc: loc2, .. }, TVar { fields: fields2, id: id2, name: name2, loc: loc2, is_const_generic: false, .. },
) => { ) => {
let new_fields = match (fields1, fields2) { let new_fields = match (fields1, fields2) {
(None, None) => None, (None, None) => None,
@ -616,10 +655,11 @@ impl Unifier {
range, range,
name: name1.or(*name2), name: name1.or(*name2),
loc: loc1.or(*loc2), loc: loc1.or(*loc2),
is_const_generic: false,
}), }),
); );
} }
(TVar { fields: None, range, .. }, _) => { (TVar { fields: None, range, is_const_generic: false, .. }, _) => {
// We check for the range of the type variable to see if unification is allowed. // We check for the range of the type variable to see if unification is allowed.
// Note that although b may be compatible with a, we may have to constrain type // Note that although b may be compatible with a, we may have to constrain type
// variables in b to make sure that instantiations of b would always be compatible // variables in b to make sure that instantiations of b would always be compatible
@ -636,7 +676,7 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, .. }, TTuple { ty }) => { (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
let len = ty.len() as i32; let len = ty.len() as i32;
for (k, v) in fields.iter() { for (k, v) in fields.iter() {
match *k { match *k {
@ -666,7 +706,7 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { fields: Some(fields), range, .. }, TList { ty }) => { (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
for (k, v) in fields.iter() { for (k, v) in fields.iter() {
match *k { match *k {
RecordKey::Int(_) => { RecordKey::Int(_) => {
@ -681,6 +721,35 @@ impl Unifier {
self.unify_impl(x, b, false)?; self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { id: id1, range: ty1, is_const_generic: true, .. }, TVar { id: id2, range: ty2, .. }) => {
let ty1 = ty1[0];
let ty2 = ty2[0];
if id1 != id2 {
self.unify_impl(ty1, ty2, false)?;
}
self.set_a_to_b(a, b);
}
(TVar { range: ty1, is_const_generic: true, .. }, TConstant { ty: ty2, .. }) => {
let ty1 = ty1[0];
self.unify_impl(ty1, *ty2, false)?;
self.set_a_to_b(a, b);
}
(TConstant { value: val1, ty: ty1, .. }, TConstant { value: val2, ty: ty2, .. }) => {
if val1 != val2 {
eprintln!("VALUE MISMATCH: lhs={val1:?} rhs={val2:?} eq={}", val1 == val2);
return self.incompatible_types(a, b)
}
self.unify_impl(*ty1, *ty2, false)?;
self.set_a_to_b(a, b);
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => { (TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() { if ty1.len() != ty2.len() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
@ -775,7 +844,14 @@ impl Unifier {
if id1 != id2 { if id1 != id2 {
self.incompatible_types(a, b)?; self.incompatible_types(a, b)?;
} }
for (x, y) in zip(params1.values(), params2.values()) {
// Sort the type arguments by its UnificationKey first, since `HashMap::iter` visits
// all K-V pairs "in arbitrary order"
let (tv1, tv2) = (
params1.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
params2.iter().sorted_by_key(|(k, _)| *k).map(|(_, v)| v).collect_vec(),
);
for (x, y) in zip(tv1, tv2) {
if self.unify_impl(*x, *y, false).is_err() { if self.unify_impl(*x, *y, false).is_err() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None)); return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
}; };
@ -928,6 +1004,9 @@ impl Unifier {
}; };
n n
} }
TypeEnum::TConstant { value, .. } => {
format!("const({value})")
}
TypeEnum::TTuple { ty } => { TypeEnum::TTuple { ty } => {
let mut fields = let mut fields =
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes)); ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
@ -983,8 +1062,8 @@ impl Unifier {
} }
} }
/// Unifies `a` and `b` together, and set the value to the value of `b`.
fn set_a_to_b(&mut self, a: Type, b: Type) { fn set_a_to_b(&mut self, a: Type, b: Type) {
// unify a and b together, and set the value to b's value.
let table = &mut self.unification_table; let table = &mut self.unification_table;
let ty_b = table.probe_value(b).clone(); let ty_b = table.probe_value(b).clone();
table.unify(a, b); table.unify(a, b);
@ -1207,6 +1286,7 @@ impl Unifier {
range, range,
name: name2.or(*name), name: name2.or(*name),
loc: loc2.or(*loc), loc: loc2.or(*loc),
is_const_generic: false,
}; };
Ok(Some(self.unification_table.new_key(ty.into()))) Ok(Some(self.unification_table.new_key(ty.into())))
} }

View File

@ -9,7 +9,7 @@ import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
from scipy import special from scipy import special
from typing import TypeVar, Generic from typing import TypeVar, Generic, Any
T = TypeVar('T') T = TypeVar('T')
class Option(Generic[T]): class Option(Generic[T]):
@ -94,11 +94,20 @@ def patch(module):
else: else:
raise NotImplementedError raise NotImplementedError
def TypeVarDummy(zelf, name, *constraints):
if len(constraints) == 1:
zelf.__init_base__(name, *constraints, Any)
else:
zelf.__init_base__(name, *constraints)
module.int32 = int32 module.int32 = int32
module.int64 = int64 module.int64 = int64
module.uint32 = uint32 module.uint32 = uint32
module.uint64 = uint64 module.uint64 = uint64
module.TypeVar = TypeVar module.TypeVar = TypeVar
module.ConstGeneric = TypeVar
module.ConstGeneric.__init_base__ = TypeVar.__init__
module.ConstGeneric.__init__ = TypeVarDummy
module.Generic = Generic module.Generic = Generic
module.extern = extern module.extern = extern
module.Option = Option module.Option = Option

View File

@ -0,0 +1,50 @@
A = ConstGeneric("A", int32)
B = ConstGeneric("B", uint32)
T = TypeVar("T")
class ConstGenericClass(Generic[A]):
def __init__(self):
pass
class ConstGeneric2Class(Generic[A, B]):
def __init__(self):
pass
class HybridGenericClass2(Generic[A, T]):
pass
class HybridGenericClass3(Generic[T, A, B]):
pass
def make_generic_2() -> ConstGenericClass[2]:
return ...
def make_generic2_1_2() -> ConstGeneric2Class[1, 2]:
return ...
def make_hybrid_class_2_int32() -> HybridGenericClass2[2, int32]:
return ...
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, 0, 1]:
return ...
def consume_generic_2(instance: ConstGenericClass[2]):
pass
def consume_generic2_1_2(instance: ConstGeneric2Class[1, 2]):
pass
def consume_hybrid_class_2_i32(instance: HybridGenericClass2[2, int32]):
pass
def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, 0, 1]):
pass
def f():
consume_generic_2(make_generic_2())
consume_generic2_1_2(make_generic2_1_2())
consume_hybrid_class_2_i32(make_hybrid_class_2_int32())
consume_hybrid_class_i32_0_1(make_hybrid_class_i32_0_1())
def run() -> int32:
return 0

View File

@ -25,7 +25,7 @@ use nac3core::{
}, },
}; };
use nac3parser::{ use nac3parser::{
ast::{Expr, ExprKind, StmtKind}, ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
parser, parser,
}; };
@ -76,13 +76,18 @@ fn handle_typevar_definition(
) -> Result<Type, String> { ) -> Result<Type, String> {
let ExprKind::Call { func, args, .. } = &var.node else { let ExprKind::Call { func, args, .. } = &var.node else {
return Err(format!( return Err(format!(
"expression {:?} cannot be handled as a TypeVar in global scope", "expression {:?} cannot be handled as a TypeVar or ConstGeneric in global scope",
var var
)) ))
}; };
match &func.node { match &func.node {
ExprKind::Name { id, .. } if id == &"TypeVar".into() => { ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
unreachable!("Expected string constant for first parameter of `TypeVar`, got {:?}", &args[0].node)
};
let generic_name: StrRef = ty_name.to_string().into();
let constraints = args let constraints = args
.iter() .iter()
.skip(1) .skip(1)
@ -94,17 +99,34 @@ fn handle_typevar_definition(
primitives, primitives,
x, x,
Default::default(), Default::default(),
None,
)?; )?;
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None def_list, unifier, primitives, &ty, &mut None
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.get_fresh_var_with_range(&constraints, None, None).0) let loc = func.location;
if constraints.len() == 1 {
return Err(format!("A single constraint is not allowed (at {})", loc))
}
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0)
} }
ExprKind::Name { id, .. } if id == &"NonTypeVar".into() => { ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
assert_eq!(args.len(), 2); if args.len() != 2 {
return Err(format!("Expected 2 arguments for `ConstGeneric`, got {}", args.len()))
}
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(format!(
"Expected string constant for first parameter of `ConstGeneric`, got {:?}",
&args[0].node
))
};
let generic_name: StrRef = ty_name.to_string().into();
let ty = parse_ast_to_type_annotation_kinds( let ty = parse_ast_to_type_annotation_kinds(
resolver, resolver,
@ -113,11 +135,14 @@ fn handle_typevar_definition(
primitives, primitives,
&args[1], &args[1],
Default::default(), Default::default(),
None,
)?; )?;
let constraint = get_type_from_type_annotation_kinds( let constraint = get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None def_list, unifier, primitives, &ty, &mut None
)?; )?;
Ok(unifier.get_fresh_var_with_range(&[constraint], None, None).0) let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0)
} }
_ => Err(format!( _ => Err(format!(