From 03cdca606de86419825dd8499793788f6a1bd7c3 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Tue, 23 Nov 2021 01:02:42 +0800 Subject: [PATCH] nac3core: check type for default parameter --- nac3core/src/toplevel/composer.rs | 19 +++++++++-- nac3core/src/toplevel/helper.rs | 56 +++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 232e0c4..677542b 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1135,8 +1135,16 @@ impl TopLevelComposer { ty, default_value: match default { None => None, - Some(default) => - Some(Self::parse_parameter_default_value(default, resolver)?) + Some(default) => Some({ + let v = Self::parse_parameter_default_value(default, resolver)?; + Self::check_default_param_type( + &v, + &type_annotation, + primitives_store, + unifier + ).map_err(|err| format!("{} at {}", err, x.location))?; + v + }) } }) }) @@ -1353,7 +1361,12 @@ impl TopLevelComposer { if name == "self".into() { return Err(format!("`self` parameter cannot take default value at {}", x.location)); } - Some(Self::parse_parameter_default_value(default, class_resolver)?) + Some({ + let v = Self::parse_parameter_default_value(default, class_resolver)?; + Self::check_default_param_type(&v, &type_ann, primitives, unifier) + .map_err(|err| format!("{} at {}", err, x.location))?; + v + }) } } }; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 2dcb359..5c8b85e 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -350,6 +350,62 @@ impl TopLevelComposer { pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result { parse_parameter_default_value(default, resolver) } + + pub fn check_default_param_type(val: &SymbolValue, ty: &TypeAnnotation, primitive: &PrimitiveStore, unifier: &mut Unifier) -> Result<(), String> { + let res = match val { + SymbolValue::Bool(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.bool) { + None + } else { + Some("bool".to_string()) + } + } + SymbolValue::Double(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.float) { + None + } else { + Some("float".to_string()) + } + } + SymbolValue::I32(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int32) { + None + } else { + Some("int32".to_string()) + } + } + SymbolValue::I64(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int64) { + None + } else { + Some("int64".to_string()) + } + } + SymbolValue::Tuple(elts) => { + if let TypeAnnotation::Tuple(elts_ty) = ty { + for (e, t) in elts.iter().zip(elts_ty.iter()) { + Self::check_default_param_type(e, t, primitive, unifier)? + } + if elts.len() != elts_ty.len() { + Some(format!("tuple of length {}", elts.len())) + } else { + None + } + } else { + Some("tuple".to_string()) + } + } + }; + if let Some(found) = res { + Err(format!( + "incompatible default parameter type, expect {}, found {}", + ty.stringify(unifier), + found + )) + } else { + Ok(()) + } + } } pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result {