use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{Mapping, VarMap}; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; use super::*; /// All primitive types and functions in nac3core. #[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)] pub enum PrimDef { Int32, Int64, Float, Bool, None, Range, Str, Exception, UInt32, UInt64, Option, OptionIsSome, OptionIsNone, OptionUnwrap, NDArray, NDArrayCopy, NDArrayFill, FunInt32, FunInt64, FunUInt32, FunUInt64, FunFloat, FunNpNDArray, FunNpEmpty, FunNpZeros, FunNpOnes, FunNpFull, FunNpArray, FunNpEye, FunNpIdentity, FunRound, FunRound64, FunNpRound, FunRange, FunStr, FunBool, FunFloor, FunFloor64, FunNpFloor, FunCeil, FunCeil64, FunNpCeil, FunLen, FunMin, FunNpMin, FunNpMinimum, FunMax, FunNpMax, FunNpMaximum, FunAbs, FunNpIsNan, FunNpIsInf, FunNpSin, FunNpCos, FunNpExp, FunNpExp2, FunNpLog, FunNpLog10, FunNpLog2, FunNpFabs, FunNpSqrt, FunNpRint, FunNpTan, FunNpArcsin, FunNpArccos, FunNpArctan, FunNpSinh, FunNpCosh, FunNpTanh, FunNpArcsinh, FunNpArccosh, FunNpArctanh, FunNpExpm1, FunNpCbrt, FunSpSpecErf, FunSpSpecErfc, FunSpSpecGamma, FunSpSpecGammaln, FunSpSpecJ0, FunSpSpecJ1, FunNpArctan2, FunNpCopysign, FunNpFmax, FunNpFmin, FunNpLdExp, FunNpHypot, FunNpNextAfter, FunSome, } /// Associated details of a [`PrimDef`] pub enum PrimDefDetails { PrimFunction { name: &'static str, simple_name: &'static str }, PrimClass { name: &'static str }, } impl PrimDef { /// Get the assigned [`DefinitionId`] of this [`PrimDef`]. /// /// The assigned definition ID is defined by the position this [`PrimDef`] enum unit variant is defined at, /// with the first `PrimDef`'s definition id being `0`. #[must_use] pub fn id(&self) -> DefinitionId { DefinitionId(*self as usize) } /// Check if a definition ID is that of a [`PrimDef`]. #[must_use] pub fn contains_id(id: DefinitionId) -> bool { Self::iter().any(|prim| prim.id() == id) } /// Get the definition "simple name" of this [`PrimDef`]. /// /// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::simple_name`]. /// /// If the [`PrimDef`] is a class, this returns [`None`]. #[must_use] pub fn simple_name(&self) -> &'static str { match self.details() { PrimDefDetails::PrimFunction { simple_name, .. } => simple_name, PrimDefDetails::PrimClass { .. } => { panic!("PrimDef {self:?} has no simple_name as it is not a function.") } } } /// Get the definition "name" of this [`PrimDef`]. /// /// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::name`]. /// /// If the [`PrimDef`] is a class, this corresponds to [`TopLevelDef::Class::name`]. #[must_use] pub fn name(&self) -> &'static str { match self.details() { PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name, } } /// Get the associated details of this [`PrimDef`] #[must_use] pub fn details(self) -> PrimDefDetails { fn class(name: &'static str) -> PrimDefDetails { PrimDefDetails::PrimClass { name } } fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails { PrimDefDetails::PrimFunction { simple_name: simple_name.unwrap_or(name), name } } match self { PrimDef::Int32 => class("int32"), PrimDef::Int64 => class("int64"), PrimDef::Float => class("float"), PrimDef::Bool => class("bool"), PrimDef::None => class("none"), PrimDef::Range => class("range"), PrimDef::Str => class("str"), PrimDef::Exception => class("Exception"), PrimDef::UInt32 => class("uint32"), PrimDef::UInt64 => class("uint64"), PrimDef::Option => class("Option"), PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")), PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")), PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")), PrimDef::NDArray => class("ndarray"), PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")), PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")), PrimDef::FunInt32 => fun("int32", None), PrimDef::FunInt64 => fun("int64", None), PrimDef::FunUInt32 => fun("uint32", None), PrimDef::FunUInt64 => fun("uint64", None), PrimDef::FunFloat => fun("float", None), PrimDef::FunNpNDArray => fun("np_ndarray", None), PrimDef::FunNpEmpty => fun("np_empty", None), PrimDef::FunNpZeros => fun("np_zeros", None), PrimDef::FunNpOnes => fun("np_ones", None), PrimDef::FunNpFull => fun("np_full", None), PrimDef::FunNpArray => fun("np_array", None), PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunRound => fun("round", None), PrimDef::FunRound64 => fun("round64", None), PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunRange => fun("range", None), PrimDef::FunStr => fun("str", None), PrimDef::FunBool => fun("bool", None), PrimDef::FunFloor => fun("floor", None), PrimDef::FunFloor64 => fun("floor64", None), PrimDef::FunNpFloor => fun("np_floor", None), PrimDef::FunCeil => fun("ceil", None), PrimDef::FunCeil64 => fun("ceil64", None), PrimDef::FunNpCeil => fun("np_ceil", None), PrimDef::FunLen => fun("len", None), PrimDef::FunMin => fun("min", None), PrimDef::FunNpMin => fun("np_min", None), PrimDef::FunNpMinimum => fun("np_minimum", None), PrimDef::FunMax => fun("max", None), PrimDef::FunNpMax => fun("np_max", None), PrimDef::FunNpMaximum => fun("np_maximum", None), PrimDef::FunAbs => fun("abs", None), PrimDef::FunNpIsNan => fun("np_isnan", None), PrimDef::FunNpIsInf => fun("np_isinf", None), PrimDef::FunNpSin => fun("np_sin", None), PrimDef::FunNpCos => fun("np_cos", None), PrimDef::FunNpExp => fun("np_exp", None), PrimDef::FunNpExp2 => fun("np_exp2", None), PrimDef::FunNpLog => fun("np_log", None), PrimDef::FunNpLog10 => fun("np_log10", None), PrimDef::FunNpLog2 => fun("np_log2", None), PrimDef::FunNpFabs => fun("np_fabs", None), PrimDef::FunNpSqrt => fun("np_sqrt", None), PrimDef::FunNpRint => fun("np_rint", None), PrimDef::FunNpTan => fun("np_tan", None), PrimDef::FunNpArcsin => fun("np_arcsin", None), PrimDef::FunNpArccos => fun("np_arccos", None), PrimDef::FunNpArctan => fun("np_arctan", None), PrimDef::FunNpSinh => fun("np_sinh", None), PrimDef::FunNpCosh => fun("np_cosh", None), PrimDef::FunNpTanh => fun("np_tanh", None), PrimDef::FunNpArcsinh => fun("np_arcsinh", None), PrimDef::FunNpArccosh => fun("np_arccosh", None), PrimDef::FunNpArctanh => fun("np_arctanh", None), PrimDef::FunNpExpm1 => fun("np_expm1", None), PrimDef::FunNpCbrt => fun("np_cbrt", None), PrimDef::FunSpSpecErf => fun("sp_spec_erf", None), PrimDef::FunSpSpecErfc => fun("sp_spec_erfc", None), PrimDef::FunSpSpecGamma => fun("sp_spec_gamma", None), PrimDef::FunSpSpecGammaln => fun("sp_spec_gammaln", None), PrimDef::FunSpSpecJ0 => fun("sp_spec_j0", None), PrimDef::FunSpSpecJ1 => fun("sp_spec_j1", None), PrimDef::FunNpArctan2 => fun("np_arctan2", None), PrimDef::FunNpCopysign => fun("np_copysign", None), PrimDef::FunNpFmax => fun("np_fmax", None), PrimDef::FunNpFmin => fun("np_fmin", None), PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunSome => fun("Some", None), } } } /// Asserts that a [`PrimDef`] is in an allowlist. /// /// Like `debug_assert!`, this statements of this function are only /// enabled if `cfg!(debug_assertions)` is true. pub fn debug_assert_prim_is_allowed(prim: PrimDef, allowlist: &[PrimDef]) { if cfg!(debug_assertions) { let allowed = allowlist.iter().any(|p| *p == prim); assert!( allowed, "Disallowed primitive definition. Got {prim:?}, but expects it to be in {allowlist:?}" ); } } impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => { let fields_str = fields .iter() .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) .collect_vec(); let methods_str = methods .iter() .map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id)) .collect_vec(); format!( "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", name, ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(), fields_str.iter().map(|(a, _)| a).collect_vec(), methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(), type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(), ) } TopLevelDef::Function { name, signature, var_id, .. } => format!( "Function {{\nname: {:?},\nsig: {:?},\nvar_id: {:?}\n}}", name, unifier.stringify(*signature), { // preserve the order for debug output and test let mut r = var_id.clone(); r.sort_unstable(); r } ), } } } impl TopLevelComposer { #[must_use] pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) { let mut unifier = Unifier::new(); let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Int32.id(), fields: HashMap::new(), params: VarMap::new(), }); let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Int64.id(), fields: HashMap::new(), params: VarMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Float.id(), fields: HashMap::new(), params: VarMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Bool.id(), fields: HashMap::new(), params: VarMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::None.id(), fields: HashMap::new(), params: VarMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Range.id(), fields: HashMap::new(), params: VarMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Str.id(), fields: HashMap::new(), params: VarMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Exception.id(), fields: vec![ ("__name__".into(), (int32, true)), ("__file__".into(), (str, true)), ("__line__".into(), (int32, true)), ("__col__".into(), (int32, true)), ("__func__".into(), (str, true)), ("__message__".into(), (str, true)), ("__param0__".into(), (int64, true)), ("__param1__".into(), (int64, true)), ("__param2__".into(), (int64, true)), ] .into_iter() .collect::>(), params: VarMap::new(), }); let uint32 = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::UInt32.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint64 = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::UInt64.id(), fields: HashMap::new(), params: VarMap::new(), }); let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None); let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], ret: bool, vars: VarMap::from([(option_type_var.1, option_type_var.0)]), })); let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], ret: option_type_var.0, vars: VarMap::from([(option_type_var.1, option_type_var.0)]), })); let option = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::Option.id(), fields: vec![ (PrimDef::OptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::OptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::OptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)), ] .into_iter() .collect::>(), params: VarMap::from([(option_type_var.1, option_type_var.0)]), }); let size_t_ty = match size_t { 32 => uint32, 64 => uint64, _ => unreachable!(), }; let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None); let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], ret: ndarray_copy_fun_ret_ty.0, vars: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), ]), })); let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "value".into(), ty: ndarray_dtype_tvar.0, default_value: None, }], ret: none, vars: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), ]), })); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PrimDef::NDArray.id(), fields: Mapping::from([ (PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)), (PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)), ]), params: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), ]), }); unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap(); let primitives = PrimitiveStore { int32, int64, uint32, uint64, float, bool, none, range, str, exception, option, ndarray, size_t, }; unifier.put_primitive_store(&primitives); crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); (primitives, unifier) } /// already include the `definition_id` of itself inside the ancestors vector /// when first registering, the `type_vars`, fields, methods, ancestors are invalid #[must_use] pub fn make_top_level_class_def( obj_id: DefinitionId, resolver: Option>, name: StrRef, constructor: Option, loc: Option, ) -> TopLevelDef { TopLevelDef::Class { name, object_id: obj_id, type_vars: Vec::default(), fields: Vec::default(), methods: Vec::default(), ancestors: Vec::default(), constructor, resolver, loc, } } /// when first registering, the type is a invalid value #[must_use] pub fn make_top_level_function_def( name: String, simple_name: StrRef, ty: Type, resolver: Option>, loc: Option, ) -> TopLevelDef { TopLevelDef::Function { name, simple_name, signature: ty, var_id: Vec::default(), instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver, codegen_callback: None, loc, } } #[must_use] pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String { class_name.push('.'); class_name.push_str(method_name); class_name } pub fn get_class_method_def_info( class_methods_def: &[(StrRef, Type, DefinitionId)], method_name: StrRef, ) -> Result<(Type, DefinitionId), HashSet> { for (name, ty, def_id) in class_methods_def { if name == &method_name { return Ok((*ty, *def_id)); } } Err(HashSet::from([format!("no method {method_name} in the current class")])) } /// get all base class def id of a class, excluding itself. \ /// this function should called only after the direct parent is set /// and before all the ancestors are set /// and when we allow single inheritance \ /// the order of the returned list is from the child to the deepest ancestor pub fn get_all_ancestors_helper( child: &TypeAnnotation, temp_def_list: &[Arc>], ) -> Result, HashSet> { let mut result: Vec = Vec::new(); let mut parent = Self::get_parent(child, temp_def_list); while let Some(p) = parent { parent = Self::get_parent(&p, temp_def_list); let p_id = if let TypeAnnotation::CustomClass { id, .. } = &p { *id } else { unreachable!("must be class kind annotation") }; // check cycle let no_cycle = result.iter().all(|x| { let TypeAnnotation::CustomClass { id, .. } = x else { unreachable!("must be class kind annotation") }; id.0 != p_id.0 }); if no_cycle { result.push(p); } else { return Err(HashSet::from(["cyclic inheritance detected".into()])); } } Ok(result) } /// should only be called when finding all ancestors, so panic when wrong fn get_parent( child: &TypeAnnotation, temp_def_list: &[Arc>], ) -> Option { let child_id = if let TypeAnnotation::CustomClass { id, .. } = child { *id } else { unreachable!("should be class type annotation") }; let child_def = temp_def_list.get(child_id.0).unwrap(); let child_def = child_def.read(); let TopLevelDef::Class { ancestors, .. } = &*child_def else { unreachable!("child must be top level class def") }; if ancestors.is_empty() { None } else { Some(ancestors[0].clone()) } } /// get the `var_id` of a given `TVar` type pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result> { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { Ok(*id) } else { Err(HashSet::from(["not type var".to_string()])) } } pub fn check_overload_function_type( this: Type, other: Type, unifier: &mut Unifier, type_var_to_concrete_def: &HashMap, ) -> bool { let this = unifier.get_ty(this); let this = this.as_ref(); let other = unifier.get_ty(other); let other = other.as_ref(); let ( TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }), TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }), ) = (this, other) else { unreachable!("this function must be called with function type") }; // check args let args_ok = this_args .iter() .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) .zip(other_args.iter().map(|FuncArg { name, ty, .. }| { (name, type_var_to_concrete_def.get(ty).unwrap()) })) .all(|(this, other)| { if this.0 == &"self".into() && this.0 == other.0 { true } else { this.0 == other.0 && check_overload_type_annotation_compatible(this.1, other.1, unifier) } }); // check rets let ret_ok = check_overload_type_annotation_compatible( type_var_to_concrete_def.get(this_ret).unwrap(), type_var_to_concrete_def.get(other_ret).unwrap(), unifier, ); // return args_ok && ret_ok } pub fn check_overload_field_type( this: Type, other: Type, unifier: &mut Unifier, type_var_to_concrete_def: &HashMap, ) -> bool { check_overload_type_annotation_compatible( type_var_to_concrete_def.get(&this).unwrap(), type_var_to_concrete_def.get(&other).unwrap(), unifier, ) } pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result, HashSet> { let mut result = HashSet::new(); for s in stmts { match &s.node { ast::StmtKind::AnnAssign { target, .. } if { if let ast::ExprKind::Attribute { value, .. } = &target.node { if let ast::ExprKind::Name { id, .. } = &value.node { id == &"self".into() } else { false } } else { false } } => { return Err(HashSet::from([format!( "redundant type annotation for class fields at {}", s.location )])) } ast::StmtKind::Assign { targets, .. } => { for t in targets { if let ast::ExprKind::Attribute { value, attr, .. } = &t.node { if let ast::ExprKind::Name { id, .. } = &value.node { if id == &"self".into() { result.insert(*attr); } } } } } // TODO: do not check for For and While? ast::StmtKind::For { body, orelse, .. } | ast::StmtKind::While { body, orelse, .. } => { result.extend(Self::get_all_assigned_field(body.as_slice())?); result.extend(Self::get_all_assigned_field(orelse.as_slice())?); } ast::StmtKind::If { body, orelse, .. } => { let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .copied() .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .copied() .collect::>(); result.extend(inited_for_sure); result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); } ast::StmtKind::With { body, .. } => { result.extend(Self::get_all_assigned_field(body.as_slice())?); } ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } | ast::StmtKind::Expr { .. } => {} _ => { unimplemented!() } } } Ok(result) } pub fn parse_parameter_default_value( default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync), ) -> Result> { parse_parameter_default_value(default, resolver) } pub fn check_default_param_type( val: &SymbolValue, ty: &TypeAnnotation, primitive: &PrimitiveStore, unifier: &mut Unifier, ) -> Result<(), String> { fn is_compatible( found: &TypeAnnotation, expect: &TypeAnnotation, unifier: &mut Unifier, primitive: &PrimitiveStore, ) -> bool { match (found, expect) { (TypeAnnotation::Primitive(f), TypeAnnotation::Primitive(e)) => { unifier.unioned(*f, *e) } ( TypeAnnotation::CustomClass { id: f_id, params: f_param }, TypeAnnotation::CustomClass { id: e_id, params: e_param }, ) => { *f_id == *e_id && *f_id == primitive.option.obj_id(unifier).unwrap() && (f_param.is_empty() || (f_param.len() == 1 && e_param.len() == 1 && is_compatible(&f_param[0], &e_param[0], unifier, primitive))) } (TypeAnnotation::Tuple(f), TypeAnnotation::Tuple(e)) => { f.len() == e.len() && f.iter() .zip(e.iter()) .all(|(f, e)| is_compatible(f, e, unifier, primitive)) } _ => false, } } let found = val.get_type_annotation(primitive, unifier); if is_compatible(&found, ty, unifier, primitive) { Ok(()) } else { Err(format!( "incompatible default parameter type, expect {}, found {}", ty.stringify(unifier), found.stringify(unifier), )) } } } pub fn parse_parameter_default_value( default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync), ) -> Result> { fn handle_constant(val: &Constant, loc: &Location) -> Result> { match val { Constant::Int(v) => { if let Ok(v) = (*v).try_into() { Ok(SymbolValue::I32(v)) } else { Err(HashSet::from([format!("integer value out of range at {loc}")])) } } Constant::Float(v) => Ok(SymbolValue::Double(*v)), Constant::Bool(v) => Ok(SymbolValue::Bool(*v)), Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()?, )), Constant::None => Err(HashSet::from([format!( "`None` is not supported, use `none` for option type instead ({loc})" )])), _ => unimplemented!("this constant is not supported at {}", loc), } } match &default.node { ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), ast::ExprKind::Call { func, args, .. } if args.len() == 1 => match &func.node { ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node { ast::ExprKind::Constant { value: Constant::Int(v), .. } => { let v: Result = (*v).try_into(); match v { Ok(v) => Ok(SymbolValue::I64(v)), _ => Err(HashSet::from([format!( "default param value out of range at {}", default.location )])), } } _ => Err(HashSet::from([format!( "only allow constant integer here at {}", default.location )])), }, ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node { ast::ExprKind::Constant { value: Constant::Int(v), .. } => { let v: Result = (*v).try_into(); match v { Ok(v) => Ok(SymbolValue::U32(v)), _ => Err(HashSet::from([format!( "default param value out of range at {}", default.location )])), } } _ => Err(HashSet::from([format!( "only allow constant integer here at {}", default.location )])), }, ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node { ast::ExprKind::Constant { value: Constant::Int(v), .. } => { let v: Result = (*v).try_into(); match v { Ok(v) => Ok(SymbolValue::U64(v)), _ => Err(HashSet::from([format!( "default param value out of range at {}", default.location )])), } } _ => Err(HashSet::from([format!( "only allow constant integer here at {}", default.location )])), }, ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(SymbolValue::OptionSome( Box::new(parse_parameter_default_value(&args[0], resolver)?), )), _ => Err(HashSet::from([format!( "unsupported default parameter at {}", default.location )])), }, ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple( elts.iter() .map(|x| parse_parameter_default_value(x, resolver)) .collect::, _>>()?, )), ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone), ast::ExprKind::Name { id, .. } => { resolver.get_default_param_value(default).ok_or_else(|| { HashSet::from([format!( "`{}` cannot be used as a default parameter at {} \ (not primitive type, option or tuple / not defined?)", id, default.location )]) }) } _ => Err(HashSet::from([format!( "unsupported default parameter (not primitive type, option or tuple) at {}", default.location )])), } } /// Obtains the element type of an array-like type. pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { match &*unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { unpack_ndarray_var_tys(unifier, ty).0 } TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), _ => ty, } } /// Obtains the number of dimensions of an array-like type. pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { match &*unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let ndims = unpack_ndarray_var_tys(unifier, ty).1; let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) }; if values.len() > 1 { todo!("Getting num of dimensions for ndarray with more than one ndim bound is unimplemented") } u64::try_from(values[0].clone()).unwrap() } TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1, _ => 0, } }