2021-11-23 07:32:09 +08:00
use std ::convert ::TryInto ;
2024-06-12 15:01:01 +08:00
use strum ::IntoEnumIterator ;
use strum_macros ::EnumIter ;
2021-11-23 07:32:09 +08:00
2024-10-17 15:57:33 +08:00
use nac3parser ::ast ::{ Constant , ExprKind , Location } ;
2024-10-03 12:37:56 +08:00
2024-10-17 15:57:33 +08:00
use super ::{ numpy ::unpack_ndarray_var_tys , * } ;
2024-10-03 12:37:56 +08:00
use crate ::{
symbol_resolver ::SymbolValue ,
typecheck ::typedef ::{ into_var_map , iter_type_vars , Mapping , TypeVarId , VarMap } ,
} ;
2021-08-27 10:21:51 +08:00
2024-06-12 15:01:01 +08:00
/// All primitive types and functions in nac3core.
///
/// NOTE: the position of each variant is load-bearing. [`PrimDef::id`] casts the variant to
/// `usize` to obtain its [`DefinitionId`], so the first variant has ID `0`, the second `1`, and so
/// on. Reordering, removing, or inserting a variant anywhere except at the end shifts the IDs of
/// every variant that follows it.
///
/// [`PrimDef::details`] matches exhaustively on this enum, so every new variant also needs an arm
/// there providing its name (and, for classes, how to fetch its type from a `PrimitiveStore`).
#[ derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq) ]
pub enum PrimDef {
2024-07-02 11:05:05 +08:00
// Classes
2024-06-12 15:01:01 +08:00
Int32 ,
Int64 ,
Float ,
Bool ,
None ,
Range ,
Str ,
Exception ,
UInt32 ,
UInt64 ,
Option ,
2024-07-02 11:05:05 +08:00
List ,
NDArray ,
2024-07-26 10:58:15 +08:00
// Option methods
2024-07-26 10:59:50 +08:00
FunOptionIsSome ,
FunOptionIsNone ,
FunOptionUnwrap ,
2024-07-26 10:58:15 +08:00
// Option-related functions
FunSome ,
// NDArray methods
2024-07-26 10:59:50 +08:00
FunNDArrayCopy ,
FunNDArrayFill ,
2024-07-26 10:58:15 +08:00
// Range methods
FunRangeInit ,
2024-07-26 11:38:23 +08:00
// NumPy factory functions
2024-06-12 15:01:01 +08:00
FunNpNDArray ,
FunNpEmpty ,
FunNpZeros ,
FunNpOnes ,
FunNpFull ,
FunNpArray ,
FunNpEye ,
FunNpIdentity ,
2024-07-26 11:38:23 +08:00
// Miscellaneous NumPy & SciPy functions
2024-06-12 15:01:01 +08:00
FunNpRound ,
FunNpFloor ,
FunNpCeil ,
FunNpMin ,
FunNpMinimum ,
2024-07-12 18:18:28 +08:00
FunNpArgmin ,
2024-06-12 15:01:01 +08:00
FunNpMax ,
FunNpMaximum ,
2024-07-12 18:18:28 +08:00
FunNpArgmax ,
2024-06-12 15:01:01 +08:00
FunNpIsNan ,
FunNpIsInf ,
FunNpSin ,
FunNpCos ,
FunNpExp ,
FunNpExp2 ,
FunNpLog ,
FunNpLog10 ,
FunNpLog2 ,
FunNpFabs ,
FunNpSqrt ,
FunNpRint ,
FunNpTan ,
FunNpArcsin ,
FunNpArccos ,
FunNpArctan ,
FunNpSinh ,
FunNpCosh ,
FunNpTanh ,
FunNpArcsinh ,
FunNpArccosh ,
FunNpArctanh ,
FunNpExpm1 ,
FunNpCbrt ,
FunSpSpecErf ,
FunSpSpecErfc ,
FunSpSpecGamma ,
FunSpSpecGammaln ,
FunSpSpecJ0 ,
FunSpSpecJ1 ,
FunNpArctan2 ,
FunNpCopysign ,
FunNpFmax ,
FunNpFmin ,
FunNpLdExp ,
FunNpHypot ,
FunNpNextAfter ,
2024-07-31 13:16:42 +08:00
FunNpTranspose ,
FunNpReshape ,
2024-07-02 11:05:05 +08:00
2024-07-29 17:30:09 +08:00
// Linalg functions
2024-07-25 12:16:53 +08:00
FunNpDot ,
FunNpLinalgCholesky ,
FunNpLinalgQr ,
FunNpLinalgSvd ,
FunNpLinalgInv ,
FunNpLinalgPinv ,
2024-07-31 18:02:54 +08:00
FunNpLinalgMatrixPower ,
FunNpLinalgDet ,
2024-07-25 12:16:53 +08:00
FunSpLinalgLu ,
FunSpLinalgSchur ,
FunSpLinalgHessenberg ,
2024-07-26 11:38:23 +08:00
// Miscellaneous Python & NAC3 functions
2024-07-26 10:58:15 +08:00
FunInt32 ,
FunInt64 ,
FunUInt32 ,
FunUInt64 ,
FunFloat ,
FunRound ,
FunRound64 ,
FunStr ,
FunBool ,
FunFloor ,
FunFloor64 ,
FunCeil ,
FunCeil64 ,
FunLen ,
FunMin ,
FunMax ,
FunAbs ,
// NOTE: appending new variants here (at the very end) is the only change that keeps every
// existing variant's positional DefinitionId unchanged.
2024-02-26 15:11:00 +08:00
}
2024-06-12 15:09:20 +08:00
/// Associated details of a [`PrimDef`]
pub enum PrimDefDetails {
PrimFunction { name : & 'static str , simple_name : & 'static str } ,
2024-07-18 15:47:40 +08:00
PrimClass { name : & 'static str , get_ty_fn : fn ( & PrimitiveStore ) -> Type } ,
2024-06-12 15:09:20 +08:00
}
2024-06-12 15:01:01 +08:00
impl PrimDef {
/// Get the assigned [`DefinitionId`] of this [`PrimDef`].
2024-02-26 15:11:00 +08:00
///
2024-06-12 15:01:01 +08:00
/// The assigned definition ID is defined by the position this [`PrimDef`] enum unit variant is defined at,
/// with the first `PrimDef`'s definition id being `0`.
2024-02-26 15:11:00 +08:00
#[ must_use ]
2024-06-12 15:01:01 +08:00
pub fn id ( & self ) -> DefinitionId {
DefinitionId ( * self as usize )
2024-02-26 15:11:00 +08:00
}
2024-06-12 15:01:01 +08:00
/// Check if a definition ID is that of a [`PrimDef`].
2024-02-26 15:11:00 +08:00
#[ must_use ]
2024-06-12 15:01:01 +08:00
pub fn contains_id ( id : DefinitionId ) -> bool {
Self ::iter ( ) . any ( | prim | prim . id ( ) = = id )
2024-02-26 15:11:00 +08:00
}
2024-06-12 15:09:20 +08:00
/// Get the definition "simple name" of this [`PrimDef`].
///
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::simple_name`].
///
/// If the [`PrimDef`] is a class, this returns [`None`].
#[ must_use ]
pub fn simple_name ( & self ) -> & 'static str {
match self . details ( ) {
PrimDefDetails ::PrimFunction { simple_name , .. } = > simple_name ,
PrimDefDetails ::PrimClass { .. } = > {
panic! ( " PrimDef {self:?} has no simple_name as it is not a function. " )
}
}
}
/// Get the definition "name" of this [`PrimDef`].
///
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::name`].
///
/// If the [`PrimDef`] is a class, this corresponds to [`TopLevelDef::Class::name`].
#[ must_use ]
pub fn name ( & self ) -> & 'static str {
match self . details ( ) {
2024-07-18 15:47:40 +08:00
PrimDefDetails ::PrimFunction { name , .. } | PrimDefDetails ::PrimClass { name , .. } = > {
name
}
2024-06-12 15:09:20 +08:00
}
}
/// Get the associated details of this [`PrimDef`]
#[ must_use ]
pub fn details ( self ) -> PrimDefDetails {
2024-07-18 15:47:40 +08:00
fn class ( name : & 'static str , get_ty_fn : fn ( & PrimitiveStore ) -> Type ) -> PrimDefDetails {
PrimDefDetails ::PrimClass { name , get_ty_fn }
2024-06-12 15:09:20 +08:00
}
fn fun ( name : & 'static str , simple_name : Option < & 'static str > ) -> PrimDefDetails {
PrimDefDetails ::PrimFunction { simple_name : simple_name . unwrap_or ( name ) , name }
}
match self {
2024-07-26 11:32:12 +08:00
// Classes
2024-07-18 15:47:40 +08:00
PrimDef ::Int32 = > class ( " int32 " , | primitives | primitives . int32 ) ,
PrimDef ::Int64 = > class ( " int64 " , | primitives | primitives . int64 ) ,
PrimDef ::Float = > class ( " float " , | primitives | primitives . float ) ,
PrimDef ::Bool = > class ( " bool " , | primitives | primitives . bool ) ,
PrimDef ::None = > class ( " none " , | primitives | primitives . none ) ,
PrimDef ::Range = > class ( " range " , | primitives | primitives . range ) ,
PrimDef ::Str = > class ( " str " , | primitives | primitives . str ) ,
PrimDef ::Exception = > class ( " Exception " , | primitives | primitives . exception ) ,
PrimDef ::UInt32 = > class ( " uint32 " , | primitives | primitives . uint32 ) ,
PrimDef ::UInt64 = > class ( " uint64 " , | primitives | primitives . uint64 ) ,
PrimDef ::Option = > class ( " Option " , | primitives | primitives . option ) ,
2024-07-26 11:32:12 +08:00
PrimDef ::List = > class ( " list " , | primitives | primitives . list ) ,
PrimDef ::NDArray = > class ( " ndarray " , | primitives | primitives . ndarray ) ,
// Option methods
2024-07-26 10:59:50 +08:00
PrimDef ::FunOptionIsSome = > fun ( " Option.is_some " , Some ( " is_some " ) ) ,
PrimDef ::FunOptionIsNone = > fun ( " Option.is_none " , Some ( " is_none " ) ) ,
PrimDef ::FunOptionUnwrap = > fun ( " Option.unwrap " , Some ( " unwrap " ) ) ,
2024-07-26 11:32:12 +08:00
// Option-related functions
PrimDef ::FunSome = > fun ( " Some " , None ) ,
// NDArray methods
2024-07-26 10:59:50 +08:00
PrimDef ::FunNDArrayCopy = > fun ( " ndarray.copy " , Some ( " copy " ) ) ,
PrimDef ::FunNDArrayFill = > fun ( " ndarray.fill " , Some ( " fill " ) ) ,
2024-07-26 11:32:12 +08:00
// Range methods
PrimDef ::FunRangeInit = > fun ( " range.__init__ " , Some ( " __init__ " ) ) ,
2024-07-26 11:38:23 +08:00
// NumPy factory functions
2024-06-12 15:09:20 +08:00
PrimDef ::FunNpNDArray = > fun ( " np_ndarray " , None ) ,
PrimDef ::FunNpEmpty = > fun ( " np_empty " , None ) ,
PrimDef ::FunNpZeros = > fun ( " np_zeros " , None ) ,
PrimDef ::FunNpOnes = > fun ( " np_ones " , None ) ,
PrimDef ::FunNpFull = > fun ( " np_full " , None ) ,
PrimDef ::FunNpArray = > fun ( " np_array " , None ) ,
PrimDef ::FunNpEye = > fun ( " np_eye " , None ) ,
PrimDef ::FunNpIdentity = > fun ( " np_identity " , None ) ,
2024-07-26 11:38:23 +08:00
// Miscellaneous NumPy & SciPy functions
2024-06-12 15:09:20 +08:00
PrimDef ::FunNpRound = > fun ( " np_round " , None ) ,
PrimDef ::FunNpFloor = > fun ( " np_floor " , None ) ,
PrimDef ::FunNpCeil = > fun ( " np_ceil " , None ) ,
PrimDef ::FunNpMin = > fun ( " np_min " , None ) ,
PrimDef ::FunNpMinimum = > fun ( " np_minimum " , None ) ,
2024-07-12 18:18:28 +08:00
PrimDef ::FunNpArgmin = > fun ( " np_argmin " , None ) ,
2024-06-12 15:09:20 +08:00
PrimDef ::FunNpMax = > fun ( " np_max " , None ) ,
PrimDef ::FunNpMaximum = > fun ( " np_maximum " , None ) ,
2024-07-12 18:18:28 +08:00
PrimDef ::FunNpArgmax = > fun ( " np_argmax " , None ) ,
2024-06-12 15:09:20 +08:00
PrimDef ::FunNpIsNan = > fun ( " np_isnan " , None ) ,
PrimDef ::FunNpIsInf = > fun ( " np_isinf " , None ) ,
PrimDef ::FunNpSin = > fun ( " np_sin " , None ) ,
PrimDef ::FunNpCos = > fun ( " np_cos " , None ) ,
PrimDef ::FunNpExp = > fun ( " np_exp " , None ) ,
PrimDef ::FunNpExp2 = > fun ( " np_exp2 " , None ) ,
PrimDef ::FunNpLog = > fun ( " np_log " , None ) ,
PrimDef ::FunNpLog10 = > fun ( " np_log10 " , None ) ,
PrimDef ::FunNpLog2 = > fun ( " np_log2 " , None ) ,
PrimDef ::FunNpFabs = > fun ( " np_fabs " , None ) ,
PrimDef ::FunNpSqrt = > fun ( " np_sqrt " , None ) ,
PrimDef ::FunNpRint = > fun ( " np_rint " , None ) ,
PrimDef ::FunNpTan = > fun ( " np_tan " , None ) ,
PrimDef ::FunNpArcsin = > fun ( " np_arcsin " , None ) ,
PrimDef ::FunNpArccos = > fun ( " np_arccos " , None ) ,
PrimDef ::FunNpArctan = > fun ( " np_arctan " , None ) ,
PrimDef ::FunNpSinh = > fun ( " np_sinh " , None ) ,
PrimDef ::FunNpCosh = > fun ( " np_cosh " , None ) ,
PrimDef ::FunNpTanh = > fun ( " np_tanh " , None ) ,
PrimDef ::FunNpArcsinh = > fun ( " np_arcsinh " , None ) ,
PrimDef ::FunNpArccosh = > fun ( " np_arccosh " , None ) ,
PrimDef ::FunNpArctanh = > fun ( " np_arctanh " , None ) ,
PrimDef ::FunNpExpm1 = > fun ( " np_expm1 " , None ) ,
PrimDef ::FunNpCbrt = > fun ( " np_cbrt " , None ) ,
PrimDef ::FunSpSpecErf = > fun ( " sp_spec_erf " , None ) ,
PrimDef ::FunSpSpecErfc = > fun ( " sp_spec_erfc " , None ) ,
PrimDef ::FunSpSpecGamma = > fun ( " sp_spec_gamma " , None ) ,
PrimDef ::FunSpSpecGammaln = > fun ( " sp_spec_gammaln " , None ) ,
PrimDef ::FunSpSpecJ0 = > fun ( " sp_spec_j0 " , None ) ,
PrimDef ::FunSpSpecJ1 = > fun ( " sp_spec_j1 " , None ) ,
PrimDef ::FunNpArctan2 = > fun ( " np_arctan2 " , None ) ,
PrimDef ::FunNpCopysign = > fun ( " np_copysign " , None ) ,
PrimDef ::FunNpFmax = > fun ( " np_fmax " , None ) ,
PrimDef ::FunNpFmin = > fun ( " np_fmin " , None ) ,
PrimDef ::FunNpLdExp = > fun ( " np_ldexp " , None ) ,
PrimDef ::FunNpHypot = > fun ( " np_hypot " , None ) ,
PrimDef ::FunNpNextAfter = > fun ( " np_nextafter " , None ) ,
2024-07-31 13:16:42 +08:00
PrimDef ::FunNpTranspose = > fun ( " np_transpose " , None ) ,
PrimDef ::FunNpReshape = > fun ( " np_reshape " , None ) ,
// Linalg functions
2024-07-25 12:16:53 +08:00
PrimDef ::FunNpDot = > fun ( " np_dot " , None ) ,
PrimDef ::FunNpLinalgCholesky = > fun ( " np_linalg_cholesky " , None ) ,
PrimDef ::FunNpLinalgQr = > fun ( " np_linalg_qr " , None ) ,
PrimDef ::FunNpLinalgSvd = > fun ( " np_linalg_svd " , None ) ,
PrimDef ::FunNpLinalgInv = > fun ( " np_linalg_inv " , None ) ,
PrimDef ::FunNpLinalgPinv = > fun ( " np_linalg_pinv " , None ) ,
2024-07-31 18:02:54 +08:00
PrimDef ::FunNpLinalgMatrixPower = > fun ( " np_linalg_matrix_power " , None ) ,
PrimDef ::FunNpLinalgDet = > fun ( " np_linalg_det " , None ) ,
2024-07-25 12:16:53 +08:00
PrimDef ::FunSpLinalgLu = > fun ( " sp_linalg_lu " , None ) ,
PrimDef ::FunSpLinalgSchur = > fun ( " sp_linalg_schur " , None ) ,
PrimDef ::FunSpLinalgHessenberg = > fun ( " sp_linalg_hessenberg " , None ) ,
2024-07-26 11:32:12 +08:00
2024-07-26 11:38:23 +08:00
// Miscellaneous Python & NAC3 functions
2024-07-26 11:32:12 +08:00
PrimDef ::FunInt32 = > fun ( " int32 " , None ) ,
PrimDef ::FunInt64 = > fun ( " int64 " , None ) ,
PrimDef ::FunUInt32 = > fun ( " uint32 " , None ) ,
PrimDef ::FunUInt64 = > fun ( " uint64 " , None ) ,
PrimDef ::FunFloat = > fun ( " float " , None ) ,
PrimDef ::FunRound = > fun ( " round " , None ) ,
PrimDef ::FunRound64 = > fun ( " round64 " , None ) ,
PrimDef ::FunStr = > fun ( " str " , None ) ,
PrimDef ::FunBool = > fun ( " bool " , None ) ,
PrimDef ::FunFloor = > fun ( " floor " , None ) ,
PrimDef ::FunFloor64 = > fun ( " floor64 " , None ) ,
PrimDef ::FunCeil = > fun ( " ceil " , None ) ,
PrimDef ::FunCeil64 = > fun ( " ceil64 " , None ) ,
PrimDef ::FunLen = > fun ( " len " , None ) ,
PrimDef ::FunMin = > fun ( " min " , None ) ,
PrimDef ::FunMax = > fun ( " max " , None ) ,
PrimDef ::FunAbs = > fun ( " abs " , None ) ,
2024-06-12 15:09:20 +08:00
}
}
}
/// Asserts that a [`PrimDef`] is in an allowlist.
///
/// Like `debug_assert!` (which it is implemented with), the check performed by this function is
/// only enabled if `cfg!(debug_assertions)` is true; otherwise the call does nothing.
///
/// # Panics
///
/// When debug assertions are enabled, panics if `prim` is not contained in `allowlist`.
pub fn debug_assert_prim_is_allowed(prim: PrimDef, allowlist: &[PrimDef]) {
    debug_assert!(
        allowlist.contains(&prim),
        " Disallowed primitive definition. Got {prim:?}, but expects it to be in {allowlist:?} "
    );
}
2024-06-17 10:47:49 +08:00
/// Builds the field list of the built-in `Exception` class.
///
/// Each entry is `(field name, field type, mutable)`; every exception field is mutable. The
/// `int32`, `int64` and `str` arguments are the primitive types the fields are declared with.
///
/// See [`TypeEnum::TObj::fields`] and [`TopLevelDef::Class::fields`].
#[must_use]
pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef, Type, bool)> {
    // Every exception field is declared mutable, so only the name and type vary.
    let mutable_field = |field_name: &'static str, field_ty: Type| -> (StrRef, Type, bool) {
        (field_name.into(), field_ty, true)
    };

    vec![
        mutable_field(" __name__ ", int32),
        mutable_field(" __file__ ", str),
        mutable_field(" __line__ ", int32),
        mutable_field(" __col__ ", int32),
        mutable_field(" __func__ ", str),
        mutable_field(" __message__ ", str),
        mutable_field(" __param0__ ", int64),
        mutable_field(" __param1__ ", int64),
        mutable_field(" __param2__ ", int64),
    ]
}
2021-09-07 17:30:15 +08:00
impl TopLevelDef {
2022-02-21 18:27:46 +08:00
pub fn to_string ( & self , unifier : & mut Unifier ) -> String {
2021-09-07 17:30:15 +08:00
match self {
2022-02-21 18:27:46 +08:00
TopLevelDef ::Class { name , ancestors , fields , methods , type_vars , .. } = > {
2021-09-07 17:30:15 +08:00
let fields_str = fields
. iter ( )
2022-02-21 18:27:46 +08:00
. map ( | ( n , ty , _ ) | ( n . to_string ( ) , unifier . stringify ( * ty ) ) )
2021-09-07 17:30:15 +08:00
. collect_vec ( ) ;
2021-09-08 02:27:12 +08:00
2021-09-07 17:30:15 +08:00
let methods_str = methods
. iter ( )
2022-02-21 18:27:46 +08:00
. map ( | ( n , ty , id ) | ( n . to_string ( ) , unifier . stringify ( * ty ) , * id ) )
2021-09-07 17:30:15 +08:00
. collect_vec ( ) ;
format! (
2021-11-05 20:28:21 +08:00
" Class {{ \n name: {:?}, \n ancestors: {:?}, \n fields: {:?}, \n methods: {:?}, \n type_vars: {:?} \n }} " ,
2021-09-07 17:30:15 +08:00
name ,
2021-11-05 20:28:21 +08:00
ancestors . iter ( ) . map ( | ancestor | ancestor . stringify ( unifier ) ) . collect_vec ( ) ,
fields_str . iter ( ) . map ( | ( a , _ ) | a ) . collect_vec ( ) ,
methods_str . iter ( ) . map ( | ( a , b , _ ) | ( a , b ) ) . collect_vec ( ) ,
2022-02-21 17:52:34 +08:00
type_vars . iter ( ) . map ( | id | unifier . stringify ( * id ) ) . collect_vec ( ) ,
2021-09-07 17:30:15 +08:00
)
}
2021-09-08 02:27:12 +08:00
TopLevelDef ::Function { name , signature , var_id , .. } = > format! (
" Function {{ \n name: {:?}, \n sig: {:?}, \n var_id: {:?} \n }} " ,
name ,
2022-02-21 17:52:34 +08:00
unifier . stringify ( * signature ) ,
2021-09-09 00:44:56 +08:00
{
2021-09-12 13:14:46 +08:00
// preserve the order for debug output and test
2021-09-09 00:44:56 +08:00
let mut r = var_id . clone ( ) ;
r . sort_unstable ( ) ;
r
}
2021-09-08 02:27:12 +08:00
) ,
2024-10-03 14:21:11 +08:00
TopLevelDef ::Variable { name , ty , .. } = > {
format! ( " Variable {{ name: {name:?} , ty: {:?} }} " , unifier . stringify ( * ty ) , )
}
2021-09-07 17:30:15 +08:00
}
}
}
2021-08-27 10:21:51 +08:00
impl TopLevelComposer {
2023-12-08 17:43:32 +08:00
#[ must_use ]
2023-12-15 14:02:30 +08:00
pub fn make_primitives ( size_t : u32 ) -> ( PrimitiveStore , Unifier ) {
2021-08-27 10:21:51 +08:00
let mut unifier = Unifier ::new ( ) ;
let int32 = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::Int32 . id ( ) ,
2022-02-21 18:27:46 +08:00
fields : HashMap ::new ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2021-08-27 10:21:51 +08:00
} ) ;
let int64 = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::Int64 . id ( ) ,
2022-02-21 18:27:46 +08:00
fields : HashMap ::new ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2021-08-27 10:21:51 +08:00
} ) ;
let float = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::Float . id ( ) ,
2022-02-21 18:27:46 +08:00
fields : HashMap ::new ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2021-08-27 10:21:51 +08:00
} ) ;
let bool = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::Bool . id ( ) ,
2022-02-21 18:27:46 +08:00
fields : HashMap ::new ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2021-08-27 10:21:51 +08:00
} ) ;
let none = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::None . id ( ) ,
2022-02-21 18:27:46 +08:00
fields : HashMap ::new ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2021-08-27 10:21:51 +08:00
} ) ;
2021-10-23 23:53:36 +08:00
let range = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::Range . id ( ) ,
2024-07-09 13:31:45 +08:00
fields : [
( " start " . into ( ) , ( int32 , true ) ) ,
( " stop " . into ( ) , ( int32 , true ) ) ,
( " step " . into ( ) , ( int32 , true ) ) ,
]
. into_iter ( )
. collect ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2021-10-23 23:53:36 +08:00
} ) ;
2021-11-02 23:22:37 +08:00
let str = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::Str . id ( ) ,
2022-02-21 18:27:46 +08:00
fields : HashMap ::new ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2021-11-02 23:22:37 +08:00
} ) ;
2022-02-12 21:09:23 +08:00
let exception = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::Exception . id ( ) ,
2024-06-17 10:47:49 +08:00
fields : make_exception_fields ( int32 , int64 , str )
. into_iter ( )
. map ( | ( name , ty , mutable ) | ( name , ( ty , mutable ) ) )
. collect ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2022-02-12 21:09:23 +08:00
} ) ;
2022-03-05 03:45:09 +08:00
let uint32 = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::UInt32 . id ( ) ,
2022-03-05 03:45:09 +08:00
fields : HashMap ::new ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2022-03-05 03:45:09 +08:00
} ) ;
let uint64 = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::UInt64 . id ( ) ,
2022-03-05 03:45:09 +08:00
fields : HashMap ::new ( ) ,
2024-03-04 23:38:52 +08:00
params : VarMap ::new ( ) ,
2022-03-05 03:45:09 +08:00
} ) ;
2022-03-26 15:09:15 +08:00
let option_type_var = unifier . get_fresh_var ( Some ( " option_type_var " . into ( ) ) , None ) ;
let is_some_type_fun_ty = unifier . add_ty ( TypeEnum ::TFunc ( FunSignature {
args : vec ! [ ] ,
ret : bool ,
2024-06-13 16:03:32 +08:00
vars : into_var_map ( [ option_type_var ] ) ,
2022-03-26 15:09:15 +08:00
} ) ) ;
let unwrap_fun_ty = unifier . add_ty ( TypeEnum ::TFunc ( FunSignature {
args : vec ! [ ] ,
2024-06-13 13:28:39 +08:00
ret : option_type_var . ty ,
2024-06-13 16:03:32 +08:00
vars : into_var_map ( [ option_type_var ] ) ,
2022-03-26 15:09:15 +08:00
} ) ) ;
let option = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::Option . id ( ) ,
2022-03-26 15:09:15 +08:00
fields : vec ! [
2024-07-26 10:59:50 +08:00
( PrimDef ::FunOptionIsSome . simple_name ( ) . into ( ) , ( is_some_type_fun_ty , true ) ) ,
( PrimDef ::FunOptionIsNone . simple_name ( ) . into ( ) , ( is_some_type_fun_ty , true ) ) ,
( PrimDef ::FunOptionUnwrap . simple_name ( ) . into ( ) , ( unwrap_fun_ty , true ) ) ,
2022-03-26 15:09:15 +08:00
]
. into_iter ( )
. collect ::< HashMap < _ , _ > > ( ) ,
2024-06-13 16:03:32 +08:00
params : into_var_map ( [ option_type_var ] ) ,
2022-03-26 15:09:15 +08:00
} ) ;
2024-02-27 13:39:05 +08:00
let size_t_ty = match size_t {
32 = > uint32 ,
64 = > uint64 ,
_ = > unreachable! ( ) ,
} ;
2024-07-02 11:05:05 +08:00
let list_elem_tvar = unifier . get_fresh_var ( Some ( " list_elem " . into ( ) ) , None ) ;
let list = unifier . add_ty ( TypeEnum ::TObj {
obj_id : PrimDef ::List . id ( ) ,
fields : Mapping ::new ( ) ,
params : into_var_map ( [ list_elem_tvar ] ) ,
} ) ;
2024-02-27 13:39:05 +08:00
let ndarray_dtype_tvar = unifier . get_fresh_var ( Some ( " ndarray_dtype " . into ( ) ) , None ) ;
2024-06-12 14:45:03 +08:00
let ndarray_ndims_tvar =
unifier . get_fresh_const_generic_var ( size_t_ty , Some ( " ndarray_ndims " . into ( ) ) , None ) ;
2024-03-07 13:02:13 +08:00
let ndarray_copy_fun_ret_ty = unifier . get_fresh_var ( None , None ) ;
let ndarray_copy_fun_ty = unifier . add_ty ( TypeEnum ::TFunc ( FunSignature {
args : vec ! [ ] ,
2024-06-13 13:28:39 +08:00
ret : ndarray_copy_fun_ret_ty . ty ,
2024-06-13 16:03:32 +08:00
vars : into_var_map ( [ ndarray_dtype_tvar , ndarray_ndims_tvar ] ) ,
2024-03-07 13:02:13 +08:00
} ) ) ;
2024-03-06 16:53:41 +08:00
let ndarray_fill_fun_ty = unifier . add_ty ( TypeEnum ::TFunc ( FunSignature {
2024-06-12 14:45:03 +08:00
args : vec ! [ FuncArg {
name : " value " . into ( ) ,
2024-06-13 13:28:39 +08:00
ty : ndarray_dtype_tvar . ty ,
2024-06-12 14:45:03 +08:00
default_value : None ,
2024-07-09 15:55:11 +08:00
is_vararg : false ,
2024-06-12 14:45:03 +08:00
} ] ,
2024-03-06 16:53:41 +08:00
ret : none ,
2024-06-13 16:03:32 +08:00
vars : into_var_map ( [ ndarray_dtype_tvar , ndarray_ndims_tvar ] ) ,
2024-03-06 16:53:41 +08:00
} ) ) ;
2024-02-27 13:39:05 +08:00
let ndarray = unifier . add_ty ( TypeEnum ::TObj {
2024-06-12 15:01:01 +08:00
obj_id : PrimDef ::NDArray . id ( ) ,
2024-03-06 16:53:41 +08:00
fields : Mapping ::from ( [
2024-07-26 10:59:50 +08:00
( PrimDef ::FunNDArrayCopy . simple_name ( ) . into ( ) , ( ndarray_copy_fun_ty , true ) ) ,
( PrimDef ::FunNDArrayFill . simple_name ( ) . into ( ) , ( ndarray_fill_fun_ty , true ) ) ,
2024-03-06 16:53:41 +08:00
] ) ,
2024-06-13 16:03:32 +08:00
params : into_var_map ( [ ndarray_dtype_tvar , ndarray_ndims_tvar ] ) ,
2024-02-27 13:39:05 +08:00
} ) ;
2024-06-13 13:28:39 +08:00
unifier . unify ( ndarray_copy_fun_ret_ty . ty , ndarray ) . unwrap ( ) ;
2024-03-07 13:02:13 +08:00
2022-03-26 15:09:15 +08:00
let primitives = PrimitiveStore {
int32 ,
int64 ,
2023-12-08 17:43:32 +08:00
uint32 ,
uint64 ,
2022-03-26 15:09:15 +08:00
float ,
bool ,
none ,
range ,
str ,
exception ,
option ,
2024-07-02 11:05:05 +08:00
list ,
2024-02-27 13:39:05 +08:00
ndarray ,
2023-12-15 14:02:30 +08:00
size_t ,
2022-03-26 15:09:15 +08:00
} ;
2023-12-13 18:23:32 +08:00
unifier . put_primitive_store ( & primitives ) ;
2021-08-27 10:21:51 +08:00
crate ::typecheck ::magic_methods ::set_primitives_magic_methods ( & primitives , & mut unifier ) ;
( primitives , unifier )
}
2023-12-08 17:43:32 +08:00
/// already include the `definition_id` of itself inside the ancestors vector
/// when first registering, the `type_vars`, fields, methods, ancestors are invalid
#[ must_use ]
2021-08-27 10:21:51 +08:00
pub fn make_top_level_class_def (
2024-02-26 15:11:00 +08:00
obj_id : DefinitionId ,
2021-10-16 18:08:13 +08:00
resolver : Option < Arc < dyn SymbolResolver + Send + Sync > > ,
2021-09-22 17:19:27 +08:00
name : StrRef ,
2021-09-20 14:24:16 +08:00
constructor : Option < Type > ,
2022-02-21 18:27:46 +08:00
loc : Option < Location > ,
2021-08-27 10:21:51 +08:00
) -> TopLevelDef {
TopLevelDef ::Class {
2021-09-22 17:19:27 +08:00
name ,
2024-02-26 15:11:00 +08:00
object_id : obj_id ,
2023-12-08 17:43:32 +08:00
type_vars : Vec ::default ( ) ,
fields : Vec ::default ( ) ,
2024-06-19 16:35:03 +08:00
attributes : Vec ::default ( ) ,
2023-12-08 17:43:32 +08:00
methods : Vec ::default ( ) ,
ancestors : Vec ::default ( ) ,
2021-09-19 22:54:06 +08:00
constructor ,
2021-08-27 10:21:51 +08:00
resolver ,
2022-02-21 17:52:34 +08:00
loc ,
2021-08-27 10:21:51 +08:00
}
}
/// when first registering, the type is a invalid value
2023-12-08 17:43:32 +08:00
#[ must_use ]
2021-08-27 10:21:51 +08:00
pub fn make_top_level_function_def (
name : String ,
2021-09-22 17:19:27 +08:00
simple_name : StrRef ,
2021-08-27 10:21:51 +08:00
ty : Type ,
2021-10-16 18:08:13 +08:00
resolver : Option < Arc < dyn SymbolResolver + Send + Sync > > ,
2022-02-21 18:27:46 +08:00
loc : Option < Location > ,
2021-08-27 10:21:51 +08:00
) -> TopLevelDef {
TopLevelDef ::Function {
name ,
2021-09-19 22:54:06 +08:00
simple_name ,
2021-08-27 10:21:51 +08:00
signature : ty ,
2023-12-08 17:43:32 +08:00
var_id : Vec ::default ( ) ,
instance_to_symbol : HashMap ::default ( ) ,
instance_to_stmt : HashMap ::default ( ) ,
2021-08-27 10:21:51 +08:00
resolver ,
2021-09-30 17:07:48 +08:00
codegen_callback : None ,
2022-02-21 17:52:34 +08:00
loc ,
2021-08-27 10:21:51 +08:00
}
}
2024-10-03 14:21:11 +08:00
#[ must_use ]
pub fn make_top_level_variable_def (
name : String ,
simple_name : StrRef ,
ty : Type ,
2024-10-07 16:52:39 +08:00
ty_decl : Option < Expr > ,
2024-10-03 14:21:11 +08:00
resolver : Option < Arc < dyn SymbolResolver + Send + Sync > > ,
loc : Option < Location > ,
) -> TopLevelDef {
TopLevelDef ::Variable { name , simple_name , ty , ty_decl , resolver , loc }
}
2023-12-08 17:43:32 +08:00
/// Builds the qualified name of a method, i.e. `class_name` followed by `.` and `method_name`.
///
/// `class_name` is consumed and appended to, so its buffer is reused for the result.
#[must_use]
pub fn make_class_method_name(class_name: String, method_name: &str) -> String {
    class_name + "." + method_name
}
/// Looks up `method_name` among `class_methods_def` and returns its type and definition ID.
///
/// If several entries share the name, the first one wins.
///
/// # Errors
///
/// Returns a single-message error set if no method named `method_name` exists.
pub fn get_class_method_def_info(
    class_methods_def: &[(StrRef, Type, DefinitionId)],
    method_name: StrRef,
) -> Result<(Type, DefinitionId), HashSet<String>> {
    class_methods_def
        .iter()
        .find(|(name, _, _)| *name == method_name)
        .map(|&(_, ty, def_id)| (ty, def_id))
        .ok_or_else(|| HashSet::from([format!(" no method {method_name} in the current class ")]))
}
2023-12-08 17:43:32 +08:00
/// get the `var_id` of a given `TVar` type
2024-06-13 13:28:39 +08:00
pub fn get_var_id ( var_ty : Type , unifier : & mut Unifier ) -> Result < TypeVarId , HashSet < String > > {
2021-08-31 09:57:07 +08:00
if let TypeEnum ::TVar { id , .. } = unifier . get_ty ( var_ty ) . as_ref ( ) {
Ok ( * id )
} else {
2024-06-12 14:45:03 +08:00
Err ( HashSet ::from ( [ " not type var " . to_string ( ) ] ) )
2021-08-31 09:57:07 +08:00
}
}
pub fn check_overload_function_type (
this : Type ,
other : Type ,
unifier : & mut Unifier ,
type_var_to_concrete_def : & HashMap < Type , TypeAnnotation > ,
) -> bool {
let this = unifier . get_ty ( this ) ;
let this = this . as_ref ( ) ;
2021-08-30 17:38:07 +08:00
let other = unifier . get_ty ( other ) ;
let other = other . as_ref ( ) ;
2023-12-12 13:38:27 +08:00
let (
2022-02-21 18:27:46 +08:00
TypeEnum ::TFunc ( FunSignature { args : this_args , ret : this_ret , .. } ) ,
TypeEnum ::TFunc ( FunSignature { args : other_args , ret : other_ret , .. } ) ,
2024-06-12 14:45:03 +08:00
) = ( this , other )
else {
2023-12-12 13:38:27 +08:00
unreachable! ( " this function must be called with function type " )
} ;
2021-08-30 17:38:07 +08:00
2023-12-12 13:38:27 +08:00
// check args
2024-06-12 14:45:03 +08:00
let args_ok =
this_args
. iter ( )
. map ( | FuncArg { name , ty , .. } | ( name , type_var_to_concrete_def . get ( ty ) . unwrap ( ) ) )
. zip ( other_args . iter ( ) . map ( | FuncArg { name , ty , .. } | {
( name , type_var_to_concrete_def . get ( ty ) . unwrap ( ) )
} ) )
. all ( | ( this , other ) | {
if this . 0 = = & " self " . into ( ) & & this . 0 = = other . 0 {
true
} else {
this . 0 = = other . 0
& & check_overload_type_annotation_compatible ( this . 1 , other . 1 , unifier )
}
} ) ;
2021-08-30 17:38:07 +08:00
2023-12-12 13:38:27 +08:00
// check rets
let ret_ok = check_overload_type_annotation_compatible (
type_var_to_concrete_def . get ( this_ret ) . unwrap ( ) ,
type_var_to_concrete_def . get ( other_ret ) . unwrap ( ) ,
unifier ,
) ;
// return
args_ok & & ret_ok
2021-08-30 17:38:07 +08:00
}
2021-08-30 22:46:50 +08:00
2021-08-31 09:57:07 +08:00
pub fn check_overload_field_type (
this : Type ,
other : Type ,
unifier : & mut Unifier ,
type_var_to_concrete_def : & HashMap < Type , TypeAnnotation > ,
) -> bool {
check_overload_type_annotation_compatible (
type_var_to_concrete_def . get ( & this ) . unwrap ( ) ,
type_var_to_concrete_def . get ( & other ) . unwrap ( ) ,
unifier ,
)
2021-08-30 22:46:50 +08:00
}
2021-09-21 02:48:42 +08:00
2024-08-16 17:42:09 +08:00
/// This function returns the fields that have been initialized in the `__init__` function of a class
/// The function takes as input:
/// * `class_id`: The `object_id` of the class whose function is being evaluated (check `TopLevelDef::Class`)
/// * `definition_ast_list`: A list of ast definitions and statements defined in `TopLevelComposer`
/// * `stmts`: The body of the function being parsed. Each statement is analyzed to check for variable initialization statements
pub fn get_all_assigned_field (
class_id : usize ,
definition_ast_list : & Vec < DefAst > ,
stmts : & [ Stmt < ( ) > ] ,
) -> Result < HashSet < StrRef > , HashSet < String > > {
2021-09-22 17:19:27 +08:00
let mut result = HashSet ::new ( ) ;
2021-09-21 02:48:42 +08:00
for s in stmts {
match & s . node {
ast ::StmtKind ::AnnAssign { target , .. }
if {
if let ast ::ExprKind ::Attribute { value , .. } = & target . node {
if let ast ::ExprKind ::Name { id , .. } = & value . node {
2021-09-22 17:19:27 +08:00
id = = & " self " . into ( )
2021-09-21 02:48:42 +08:00
} else {
false
}
} else {
false
}
} = >
{
2024-06-12 14:45:03 +08:00
return Err ( HashSet ::from ( [ format! (
" redundant type annotation for class fields at {} " ,
s . location
) ] ) )
2021-09-21 02:48:42 +08:00
}
ast ::StmtKind ::Assign { targets , .. } = > {
for t in targets {
if let ast ::ExprKind ::Attribute { value , attr , .. } = & t . node {
if let ast ::ExprKind ::Name { id , .. } = & value . node {
2021-09-22 17:19:27 +08:00
if id = = & " self " . into ( ) {
2021-09-22 17:56:48 +08:00
result . insert ( * attr ) ;
2021-09-21 02:48:42 +08:00
}
}
}
}
}
// TODO: do not check for For and While?
ast ::StmtKind ::For { body , orelse , .. }
| ast ::StmtKind ::While { body , orelse , .. } = > {
2024-08-16 17:42:09 +08:00
result . extend ( Self ::get_all_assigned_field (
class_id ,
definition_ast_list ,
body . as_slice ( ) ,
) ? ) ;
result . extend ( Self ::get_all_assigned_field (
class_id ,
definition_ast_list ,
orelse . as_slice ( ) ,
) ? ) ;
2021-09-21 02:48:42 +08:00
}
ast ::StmtKind ::If { body , orelse , .. } = > {
2024-08-16 17:42:09 +08:00
let inited_for_sure = Self ::get_all_assigned_field (
class_id ,
definition_ast_list ,
body . as_slice ( ) ,
) ?
. intersection ( & Self ::get_all_assigned_field (
class_id ,
definition_ast_list ,
orelse . as_slice ( ) ,
) ? )
. copied ( )
. collect ::< HashSet < _ > > ( ) ;
2021-09-21 02:48:42 +08:00
result . extend ( inited_for_sure ) ;
}
ast ::StmtKind ::Try { body , orelse , finalbody , .. } = > {
2024-08-16 17:42:09 +08:00
let inited_for_sure = Self ::get_all_assigned_field (
class_id ,
definition_ast_list ,
body . as_slice ( ) ,
) ?
. intersection ( & Self ::get_all_assigned_field (
class_id ,
definition_ast_list ,
orelse . as_slice ( ) ,
) ? )
. copied ( )
. collect ::< HashSet < _ > > ( ) ;
2021-09-21 02:48:42 +08:00
result . extend ( inited_for_sure ) ;
2024-08-16 17:42:09 +08:00
result . extend ( Self ::get_all_assigned_field (
class_id ,
definition_ast_list ,
finalbody . as_slice ( ) ,
) ? ) ;
2021-09-21 02:48:42 +08:00
}
ast ::StmtKind ::With { body , .. } = > {
2024-08-16 17:42:09 +08:00
result . extend ( Self ::get_all_assigned_field (
class_id ,
definition_ast_list ,
body . as_slice ( ) ,
) ? ) ;
}
// Variables Initialized in function calls
ast ::StmtKind ::Expr { value , .. } = > {
let ExprKind ::Call { func , .. } = & value . node else {
continue ;
} ;
let ExprKind ::Attribute { value , attr , .. } = & func . node else {
continue ;
} ;
let ExprKind ::Name { id , .. } = & value . node else {
continue ;
} ;
// Need to consider the two cases:
// Case 1) Call to class function i.e. id = `self`
// Case 2) Call to class ancestor function i.e. id = ancestor_name
// We leave checking whether function in case 2 belonged to class ancestor or not to type checker
//
// According to current handling of `self`, function definition are fixed and do not change regardless
// of which object is passed as `self` i.e. virtual polymorphism is not supported
// Therefore, we change class id for case 2 to reflect behavior of our compiler
let class_name = if * id = = " self " . into ( ) {
let ast ::StmtKind ::ClassDef { name , .. } =
& definition_ast_list [ class_id ] . 1. as_ref ( ) . unwrap ( ) . node
else {
unreachable! ( )
} ;
name
} else {
id
} ;
let parent_method = definition_ast_list . iter ( ) . find_map ( | def | {
let (
class_def ,
Some ( ast ::Located {
node : ast ::StmtKind ::ClassDef { name , body , .. } ,
..
} ) ,
) = & def
else {
return None ;
} ;
let TopLevelDef ::Class { object_id : class_id , .. } = & * class_def . read ( )
else {
unreachable! ( )
} ;
if name = = class_name {
body . iter ( ) . find_map ( | m | {
let ast ::StmtKind ::FunctionDef { name , body , .. } = & m . node else {
return None ;
} ;
if * name = = * attr {
return Some ( ( body . clone ( ) , class_id . 0 ) ) ;
}
None
} )
} else {
None
}
} ) ;
// If method body is none then method does not exist
if let Some ( ( method_body , class_id ) ) = parent_method {
result . extend ( Self ::get_all_assigned_field (
class_id ,
definition_ast_list ,
method_body . as_slice ( ) ,
) ? ) ;
} else {
return Err ( HashSet ::from ( [ format! (
" {}.{} not found in class {class_name} at {} " ,
* id , * attr , value . location
) ] ) ) ;
}
2021-09-21 02:48:42 +08:00
}
2023-12-08 17:43:32 +08:00
ast ::StmtKind ::Pass { .. }
| ast ::StmtKind ::Assert { .. }
2024-08-16 17:42:09 +08:00
| ast ::StmtKind ::AnnAssign { .. } = > { }
2021-09-21 02:48:42 +08:00
_ = > {
unimplemented! ( )
}
}
}
Ok ( result )
}
2021-11-23 07:32:09 +08:00
2022-02-21 18:27:46 +08:00
/// Parses the default-value expression of a function parameter into a [`SymbolValue`].
///
/// Associated-function entry point; the parsing itself is done by the module-level
/// `parse_parameter_default_value` in this file.
///
/// # Errors
/// Propagates the set of error messages produced by the module-level parser.
pub fn parse_parameter_default_value(
    default: &ast::Expr,
    resolver: &(dyn SymbolResolver + Send + Sync),
) -> Result<SymbolValue, HashSet<String>> {
    // `self::` names the module-level function, not this associated function.
    self::parse_parameter_default_value(default, resolver)
}
2022-02-21 18:27:46 +08:00
pub fn check_default_param_type (
val : & SymbolValue ,
ty : & TypeAnnotation ,
primitive : & PrimitiveStore ,
unifier : & mut Unifier ,
) -> Result < ( ) , String > {
2022-03-30 03:50:34 +08:00
fn is_compatible (
found : & TypeAnnotation ,
expect : & TypeAnnotation ,
unifier : & mut Unifier ,
primitive : & PrimitiveStore ,
) -> bool {
match ( found , expect ) {
( TypeAnnotation ::Primitive ( f ) , TypeAnnotation ::Primitive ( e ) ) = > {
unifier . unioned ( * f , * e )
2022-03-30 03:14:21 +08:00
}
2022-03-30 03:50:34 +08:00
(
TypeAnnotation ::CustomClass { id : f_id , params : f_param } ,
TypeAnnotation ::CustomClass { id : e_id , params : e_param } ,
) = > {
* f_id = = * e_id
2024-03-27 10:36:02 +08:00
& & * f_id = = primitive . option . obj_id ( unifier ) . unwrap ( )
2022-03-30 03:50:34 +08:00
& & ( f_param . is_empty ( )
| | ( f_param . len ( ) = = 1
& & e_param . len ( ) = = 1
& & is_compatible ( & f_param [ 0 ] , & e_param [ 0 ] , unifier , primitive ) ) )
2022-03-30 03:14:21 +08:00
}
2022-03-30 03:50:34 +08:00
( TypeAnnotation ::Tuple ( f ) , TypeAnnotation ::Tuple ( e ) ) = > {
f . len ( ) = = e . len ( )
& & f . iter ( )
. zip ( e . iter ( ) )
. all ( | ( f , e ) | is_compatible ( f , e , unifier , primitive ) )
}
_ = > false ,
2022-03-30 03:14:21 +08:00
}
2022-03-30 03:50:34 +08:00
}
2023-12-01 15:56:18 +08:00
let found = val . get_type_annotation ( primitive , unifier ) ;
2023-12-08 17:43:32 +08:00
if is_compatible ( & found , ty , unifier , primitive ) {
Ok ( ( ) )
} else {
2021-11-23 07:32:09 +08:00
Err ( format! (
" incompatible default parameter type, expect {}, found {} " ,
ty . stringify ( unifier ) ,
2022-03-30 03:50:34 +08:00
found . stringify ( unifier ) ,
2021-11-23 07:32:09 +08:00
) )
}
}
2024-08-30 18:03:25 +08:00
/// Parses the class type variables and direct parents
/// we only allow single inheritance
pub fn analyze_class_bases (
class_def : & Arc < RwLock < TopLevelDef > > ,
class_ast : & Option < Stmt > ,
temp_def_list : & [ Arc < RwLock < TopLevelDef > > ] ,
unifier : & mut Unifier ,
primitives_store : & PrimitiveStore ,
) -> Result < ( ) , HashSet < String > > {
let mut class_def = class_def . write ( ) ;
let ( class_def_id , class_ancestors , class_bases_ast , class_type_vars , class_resolver ) = {
let TopLevelDef ::Class { object_id , ancestors , type_vars , resolver , .. } =
& mut * class_def
else {
unreachable! ( )
} ;
let Some ( ast ::Located { node : ast ::StmtKind ::ClassDef { bases , .. } , .. } ) = class_ast
else {
unreachable! ( )
} ;
( object_id , ancestors , bases , type_vars , resolver . as_ref ( ) . unwrap ( ) . as_ref ( ) )
} ;
let mut is_generic = false ;
let mut has_base = false ;
// Check class bases for typevars
for b in class_bases_ast {
match & b . node {
// analyze typevars bounded to the class,
// only support things like `class A(Generic[T, V])`,
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
// i.e. only simple names are allowed in the subscript
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
ast ::ExprKind ::Subscript { value , slice , .. } if matches! ( & value . node , ast ::ExprKind ::Name { id , .. } if id = = & " Generic " . into ( ) ) = >
{
if is_generic {
return Err ( HashSet ::from ( [ format! (
" only single Generic[...] is allowed (at {}) " ,
b . location
) ] ) ) ;
}
is_generic = true ;
let type_var_list : Vec < & ast ::Expr < ( ) > > ;
// if `class A(Generic[T, V, G])`
if let ast ::ExprKind ::Tuple { elts , .. } = & slice . node {
type_var_list = elts . iter ( ) . collect_vec ( ) ;
// `class A(Generic[T])`
} else {
type_var_list = vec! [ & * * slice ] ;
}
let type_vars = type_var_list
. into_iter ( )
. map ( | e | {
class_resolver . parse_type_annotation (
temp_def_list ,
unifier ,
primitives_store ,
e ,
)
} )
. collect ::< Result < Vec < _ > , _ > > ( ) ? ;
class_type_vars . extend ( type_vars ) ;
}
ast ::ExprKind ::Name { .. } | ast ::ExprKind ::Subscript { .. } = > {
if has_base {
return Err ( HashSet ::from ( [ format! ( " a class definition can only have at most one base class declaration and one generic declaration (at {} ) " , b . location ) ] ) ) ;
}
has_base = true ;
// the function parse_ast_to make sure that no type var occurred in
// bast_ty if it is a CustomClassKind
let base_ty = parse_ast_to_type_annotation_kinds (
class_resolver ,
temp_def_list ,
unifier ,
primitives_store ,
b ,
vec! [ ( * class_def_id , class_type_vars . clone ( ) ) ]
. into_iter ( )
. collect ::< HashMap < _ , _ > > ( ) ,
) ? ;
if let TypeAnnotation ::CustomClass { .. } = & base_ty {
class_ancestors . push ( base_ty ) ;
} else {
return Err ( HashSet ::from ( [ format! (
" class base declaration can only be custom class (at {}) " ,
b . location
) ] ) ) ;
}
}
_ = > {
return Err ( HashSet ::from ( [ format! (
" unsupported statement in class defintion (at {}) " ,
b . location
) ] ) ) ;
}
}
}
Ok ( ( ) )
}
/// gets all ancestors of a class
pub fn analyze_class_ancestors (
class_def : & Arc < RwLock < TopLevelDef > > ,
temp_def_list : & [ Arc < RwLock < TopLevelDef > > ] ,
) {
// Check if class has a direct parent
let mut class_def = class_def . write ( ) ;
let TopLevelDef ::Class { ancestors , type_vars , object_id , .. } = & mut * class_def else {
unreachable! ( )
} ;
let mut anc_set = HashMap ::new ( ) ;
if let Some ( ancestor ) = ancestors . first ( ) {
let TypeAnnotation ::CustomClass { id , .. } = ancestor else { unreachable! ( ) } ;
let TopLevelDef ::Class { ancestors : parent_ancestors , .. } =
& * temp_def_list [ id . 0 ] . read ( )
else {
unreachable! ( )
} ;
for anc in parent_ancestors . iter ( ) . skip ( 1 ) {
let TypeAnnotation ::CustomClass { id , .. } = anc else { unreachable! ( ) } ;
anc_set . insert ( id , anc . clone ( ) ) ;
}
ancestors . extend ( anc_set . into_values ( ) ) ;
}
// push `self` as first ancestor of class
ancestors . insert ( 0 , make_self_type_annotation ( type_vars . as_slice ( ) , * object_id ) ) ;
}
2021-11-23 07:32:09 +08:00
}
2022-02-21 18:27:46 +08:00
pub fn parse_parameter_default_value (
default : & ast ::Expr ,
resolver : & ( dyn SymbolResolver + Send + Sync ) ,
2023-11-15 17:30:26 +08:00
) -> Result < SymbolValue , HashSet < String > > {
fn handle_constant ( val : & Constant , loc : & Location ) -> Result < SymbolValue , HashSet < String > > {
2021-11-23 07:32:09 +08:00
match val {
2022-03-08 02:30:04 +08:00
Constant ::Int ( v ) = > {
if let Ok ( v ) = ( * v ) . try_into ( ) {
Ok ( SymbolValue ::I32 ( v ) )
} else {
2023-11-15 17:30:26 +08:00
Err ( HashSet ::from ( [ format! ( " integer value out of range at {loc} " ) ] ) )
2021-11-23 07:32:09 +08:00
}
2022-03-08 02:30:04 +08:00
}
2021-11-23 07:32:09 +08:00
Constant ::Float ( v ) = > Ok ( SymbolValue ::Double ( * v ) ) ,
Constant ::Bool ( v ) = > Ok ( SymbolValue ::Bool ( * v ) ) ,
Constant ::Tuple ( tuple ) = > Ok ( SymbolValue ::Tuple (
2022-02-21 18:27:46 +08:00
tuple . iter ( ) . map ( | x | handle_constant ( x , loc ) ) . collect ::< Result < Vec < _ > , _ > > ( ) ? ,
2021-11-23 07:32:09 +08:00
) ) ,
2024-06-12 14:45:03 +08:00
Constant ::None = > Err ( HashSet ::from ( [ format! (
" `None` is not supported, use `none` for option type instead ({loc}) "
) ] ) ) ,
2021-11-23 07:32:09 +08:00
_ = > unimplemented! ( " this constant is not supported at {} " , loc ) ,
}
}
match & default . node {
ast ::ExprKind ::Constant { value , .. } = > handle_constant ( value , & default . location ) ,
2024-06-12 14:45:03 +08:00
ast ::ExprKind ::Call { func , args , .. } if args . len ( ) = = 1 = > match & func . node {
ast ::ExprKind ::Name { id , .. } if * id = = " int64 " . into ( ) = > match & args [ 0 ] . node {
ast ::ExprKind ::Constant { value : Constant ::Int ( v ) , .. } = > {
let v : Result < i64 , _ > = ( * v ) . try_into ( ) ;
match v {
Ok ( v ) = > Ok ( SymbolValue ::I64 ( v ) ) ,
_ = > Err ( HashSet ::from ( [ format! (
" default param value out of range at {} " ,
default . location
) ] ) ) ,
2022-03-08 02:30:04 +08:00
}
2021-11-23 07:32:09 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > Err ( HashSet ::from ( [ format! (
" only allow constant integer here at {} " ,
default . location
) ] ) ) ,
} ,
ast ::ExprKind ::Name { id , .. } if * id = = " uint32 " . into ( ) = > match & args [ 0 ] . node {
ast ::ExprKind ::Constant { value : Constant ::Int ( v ) , .. } = > {
let v : Result < u32 , _ > = ( * v ) . try_into ( ) ;
match v {
Ok ( v ) = > Ok ( SymbolValue ::U32 ( v ) ) ,
_ = > Err ( HashSet ::from ( [ format! (
" default param value out of range at {} " ,
default . location
) ] ) ) ,
2022-03-05 03:45:09 +08:00
}
}
2024-06-12 14:45:03 +08:00
_ = > Err ( HashSet ::from ( [ format! (
" only allow constant integer here at {} " ,
default . location
) ] ) ) ,
} ,
ast ::ExprKind ::Name { id , .. } if * id = = " uint64 " . into ( ) = > match & args [ 0 ] . node {
ast ::ExprKind ::Constant { value : Constant ::Int ( v ) , .. } = > {
let v : Result < u64 , _ > = ( * v ) . try_into ( ) ;
match v {
Ok ( v ) = > Ok ( SymbolValue ::U64 ( v ) ) ,
_ = > Err ( HashSet ::from ( [ format! (
" default param value out of range at {} " ,
default . location
) ] ) ) ,
2022-03-05 03:45:09 +08:00
}
}
2024-06-12 14:45:03 +08:00
_ = > Err ( HashSet ::from ( [ format! (
" only allow constant integer here at {} " ,
default . location
) ] ) ) ,
} ,
ast ::ExprKind ::Name { id , .. } if * id = = " Some " . into ( ) = > Ok ( SymbolValue ::OptionSome (
Box ::new ( parse_parameter_default_value ( & args [ 0 ] , resolver ) ? ) ,
) ) ,
_ = > Err ( HashSet ::from ( [ format! (
" unsupported default parameter at {} " ,
default . location
) ] ) ) ,
} ,
ast ::ExprKind ::Tuple { elts , .. } = > Ok ( SymbolValue ::Tuple (
elts . iter ( )
. map ( | x | parse_parameter_default_value ( x , resolver ) )
. collect ::< Result < Vec < _ > , _ > > ( ) ? ,
2021-11-23 07:32:09 +08:00
) ) ,
2022-03-30 03:14:21 +08:00
ast ::ExprKind ::Name { id , .. } if id = = & " none " . into ( ) = > Ok ( SymbolValue ::OptionNone ) ,
2021-11-23 07:32:09 +08:00
ast ::ExprKind ::Name { id , .. } = > {
2024-06-12 14:45:03 +08:00
resolver . get_default_param_value ( default ) . ok_or_else ( | | {
HashSet ::from ( [ format! (
" `{}` cannot be used as a default parameter at {} \
2023-11-15 17:30:26 +08:00
( not primitive type , option or tuple / not defined ? ) " ,
2024-06-12 14:45:03 +08:00
id , default . location
) ] )
} )
2021-11-23 07:32:09 +08:00
}
2024-06-12 14:45:03 +08:00
_ = > Err ( HashSet ::from ( [ format! (
" unsupported default parameter (not primitive type, option or tuple) at {} " ,
default . location
) ] ) ) ,
2021-11-23 07:32:09 +08:00
}
2021-08-27 10:21:51 +08:00
}
2024-06-04 18:00:42 +08:00
/// Obtains the element type of an array-like type.
pub fn arraylike_flatten_element_type ( unifier : & mut Unifier , ty : Type ) -> Type {
match & * unifier . get_ty ( ty ) {
2024-06-12 15:01:01 +08:00
TypeEnum ::TObj { obj_id , .. } if * obj_id = = PrimDef ::NDArray . id ( ) = > {
2024-06-12 14:45:03 +08:00
unpack_ndarray_var_tys ( unifier , ty ) . 0
}
2024-06-04 18:00:42 +08:00
2024-07-02 11:05:05 +08:00
TypeEnum ::TObj { obj_id , params , .. } if * obj_id = = PrimDef ::List . id ( ) = > {
arraylike_flatten_element_type ( unifier , iter_type_vars ( params ) . next ( ) . unwrap ( ) . ty )
}
2024-06-12 14:45:03 +08:00
_ = > ty ,
2024-06-04 18:00:42 +08:00
}
}
/// Obtains the number of dimensions of an array-like type.
pub fn arraylike_get_ndims ( unifier : & mut Unifier , ty : Type ) -> u64 {
match & * unifier . get_ty ( ty ) {
2024-06-12 15:01:01 +08:00
TypeEnum ::TObj { obj_id , .. } if * obj_id = = PrimDef ::NDArray . id ( ) = > {
2024-06-04 18:00:42 +08:00
let ndims = unpack_ndarray_var_tys ( unifier , ty ) . 1 ;
let TypeEnum ::TLiteral { values , .. } = & * unifier . get_ty_immutable ( ndims ) else {
panic! ( " Expected TLiteral for ndarray.ndims, got {} " , unifier . stringify ( ndims ) )
} ;
if values . len ( ) > 1 {
todo! ( " Getting num of dimensions for ndarray with more than one ndim bound is unimplemented " )
}
u64 ::try_from ( values [ 0 ] . clone ( ) ) . unwrap ( )
}
2024-07-02 11:05:05 +08:00
TypeEnum ::TObj { obj_id , params , .. } if * obj_id = = PrimDef ::List . id ( ) = > {
arraylike_get_ndims ( unifier , iter_type_vars ( params ) . next ( ) . unwrap ( ) . ty ) + 1
}
2024-06-12 14:45:03 +08:00
_ = > 0 ,
2024-06-04 18:00:42 +08:00
}
}
2024-08-20 11:32:01 +08:00
/// Extracts an ndarray's `ndims` [type][`Type`] as a `u64`.
///
/// The `ndims` literal must hold exactly one value.
///
/// # Panics
/// Panics if `ndims_ty` is not a `TLiteral`, if the literal does not hold exactly one value,
/// or if that value cannot be converted to `u64`.
#[must_use]
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
    match &*unifier.get_ty_immutable(ndims_ty) {
        TypeEnum::TLiteral { values, .. } => {
            assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
            u64::try_from(values[0].clone()).unwrap()
        }
        _ => panic!("ndims_ty should be a TLiteral"),
    }
}
/// Returns an ndarray's `ndims` as a typechecker [`Type`]: a literal type obtained from
/// `Unifier::get_fresh_literal` holding the single value `ndims` as a `u64` symbol value.
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
    let ndims_value = SymbolValue::U64(ndims);
    unifier.get_fresh_literal(vec![ndims_value], None)
}