From 45870888356b870fa6f4f91aaa463b29d41d94b2 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Tue, 23 Nov 2021 07:32:09 +0800 Subject: [PATCH] Constant Default Parameter Support (#98) Add support for constant default parameter Reviewed-on: https://git.m-labs.hk/M-Labs/nac3/pulls/98 Co-authored-by: ychenfo Co-committed-by: ychenfo --- nac3artiq/src/symbol_resolver.rs | 81 +++++++++++- nac3core/src/codegen/test.rs | 4 + nac3core/src/symbol_resolver.rs | 1 + nac3core/src/toplevel/composer.rs | 64 ++++++++- nac3core/src/toplevel/helper.rs | 122 ++++++++++++++++++ nac3core/src/toplevel/mod.rs | 2 +- nac3core/src/toplevel/test.rs | 4 + .../src/typecheck/type_inferencer/test.rs | 4 + nac3standalone/src/basic_symbol_resolver.rs | 18 ++- nac3standalone/src/main.rs | 62 ++++++++- 10 files changed, 347 insertions(+), 15 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index e420bd80..16525096 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -2,7 +2,7 @@ use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace}; use nac3core::{ codegen::CodeGenContext, location::Location, - symbol_resolver::SymbolResolver, + symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, @@ -14,7 +14,7 @@ use pyo3::{ types::{PyList, PyModule, PyTuple}, PyAny, PyObject, PyResult, Python, }; -use nac3parser::ast::StrRef; +use nac3parser::ast::{self, StrRef}; use std::{ cell::RefCell, collections::{HashMap, HashSet}, @@ -399,9 +399,86 @@ impl Resolver { } } } + + fn get_default_param_obj_value(&self, obj: &PyAny, helper: &PythonHelper) -> PyResult> { + let ty_id: u64 = helper + .id_fn + .call1((helper.type_fn.call1((obj,))?,))? + .extract()?; + Ok( + if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { + let val: i32 = obj.extract()?; + Ok(SymbolValue::I32(val)) + } else if ty_id == self.primitive_ids.int64 { + let val: i64 = obj.extract()?; + Ok(SymbolValue::I64(val)) + } else if ty_id == self.primitive_ids.bool { + let val: bool = obj.extract()?; + Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.float { + let val: f64 = obj.extract()?; + Ok(SymbolValue::Double(val)) + } else if ty_id == self.primitive_ids.tuple { + let elements: &PyTuple = obj.cast_as()?; + let elements: Result, String>, _> = elements + .iter() + .map(|elem| { + self.get_default_param_obj_value( + elem, + helper + ) + }) + .collect(); + let elements = match elements? { + Ok(el) => el, + Err(err) => return Ok(Err(err)) + }; + Ok(SymbolValue::Tuple(elements)) + } else { + Err("only primitives values and tuple can be default parameter value".into()) + } + ) + } } impl SymbolResolver for Resolver { + fn get_default_param_value( + &self, + expr: &ast::Expr + ) -> Option { + match &expr.node { + ast::ExprKind::Name { id, .. } => { + Python::with_gil( + |py| -> PyResult> { + let obj: &PyAny = self.module.extract(py)?; + let members: &PyList = PyModule::import(py, "inspect")? + .getattr("getmembers")? + .call1((obj,))? + .cast_as()?; + let mut sym_value = None; + for member in members.iter() { + let key: &str = member.get_item(0)?.extract()?; + let val = member.get_item(1)?; + if key == id.to_string() { + let builtins = PyModule::import(py, "builtins")?; + let helper = PythonHelper { + id_fn: builtins.getattr("id").unwrap(), + len_fn: builtins.getattr("len").unwrap(), + type_fn: builtins.getattr("type").unwrap(), + }; + sym_value = Some(self.get_default_param_obj_value(val, &helper).unwrap().unwrap()); + break; + } + } + Ok(sym_value) + } + ) + .unwrap() + } + _ => unimplemented!("other type of expr not supported at {}", expr.location) + } + } + fn get_symbol_type( &self, unifier: &mut Unifier, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index ce958ff3..f9e2810f 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -37,6 +37,10 @@ impl Resolver { } impl SymbolResolver for Resolver { + fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + unimplemented!() + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 161d78a9..e8d594da 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -44,6 +44,7 @@ pub trait SymbolResolver { ctx: &mut CodeGenContext<'ctx, 'a>, ) -> Option>; fn get_symbol_location(&self, str: StrRef) -> Option; + fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option; // handle function call etc. } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index e7c17cd0..b0afaf26 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1066,10 +1066,23 @@ impl TopLevelComposer { and names thould not be the same as the keywords" .into()); } - - args.args + + let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = args + .args .iter() - .map(|x| -> Result { + .rev() + .zip(args + .defaults + .iter() + .rev() + .map(|x| -> Option<&ast::Expr> { Some(x) }) + .chain(std::iter::repeat(None)) + ).collect_vec(); + + arg_with_default + .iter() + .rev() + .map(|(x, default)| -> Result { let annotation = x .node .annotation @@ -1121,7 +1134,19 @@ impl TopLevelComposer { Ok(FuncArg { name: x.node.arg, ty, - default_value: Default::default(), + default_value: match default { + None => None, + Some(default) => Some({ + let v = Self::parse_parameter_default_value(default, resolver)?; + Self::check_default_param_type( + &v, + &type_annotation, + primitives_store, + unifier + ).map_err(|err| format!("{} at {}", err, x.location))?; + v + }) + } }) }) .collect::, _>>()? @@ -1282,7 +1307,20 @@ impl TopLevelComposer { } let mut result = Vec::new(); - for x in &args.args { + + let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = args + .args + .iter() + .rev() + .zip(args + .defaults + .iter() + .rev() + .map(|x| -> Option<&ast::Expr> { Some(x) }) + .chain(std::iter::repeat(None)) + ).collect_vec(); + + for (x, default) in arg_with_default.into_iter().rev() { let name = x.node.arg; if name != zelf { let type_ann = { @@ -1327,8 +1365,20 @@ impl TopLevelComposer { let dummy_func_arg = FuncArg { name, ty: unifier.get_fresh_var().0, - // TODO: default value? - default_value: None, + default_value: match default { + None => None, + Some(default) => { + if name == "self".into() { + return Err(format!("`self` parameter cannot take default value at {}", x.location)); + } + Some({ + let v = Self::parse_parameter_default_value(default, class_resolver)?; + Self::check_default_param_type(&v, &type_ann, primitives, unifier) + .map_err(|err| format!("{} at {}", err, x.location))?; + v + }) + } + } }; // push the dummy type and the type annotation // into the list for later unification diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 35da97c9..5c8b85e7 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,3 +1,8 @@ +use std::convert::TryInto; + +use nac3parser::ast::{Constant, Location}; +use crate::symbol_resolver::SymbolValue; + use super::*; impl TopLevelDef { @@ -341,4 +346,121 @@ impl TopLevelComposer { } Ok(result) } + + pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result { + parse_parameter_default_value(default, resolver) + } + + pub fn check_default_param_type(val: &SymbolValue, ty: &TypeAnnotation, primitive: &PrimitiveStore, unifier: &mut Unifier) -> Result<(), String> { + let res = match val { + SymbolValue::Bool(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.bool) { + None + } else { + Some("bool".to_string()) + } + } + SymbolValue::Double(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.float) { + None + } else { + Some("float".to_string()) + } + } + SymbolValue::I32(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int32) { + None + } else { + Some("int32".to_string()) + } + } + SymbolValue::I64(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int64) { + None + } else { + Some("int64".to_string()) + } + } + SymbolValue::Tuple(elts) => { + if let TypeAnnotation::Tuple(elts_ty) = ty { + for (e, t) in elts.iter().zip(elts_ty.iter()) { + Self::check_default_param_type(e, t, primitive, unifier)? + } + if elts.len() != elts_ty.len() { + Some(format!("tuple of length {}", elts.len())) + } else { + None + } + } else { + Some("tuple".to_string()) + } + } + }; + if let Some(found) = res { + Err(format!( + "incompatible default parameter type, expect {}, found {}", + ty.stringify(unifier), + found + )) + } else { + Ok(()) + } + } +} + +pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result { + fn handle_constant(val: &Constant, loc: &Location) -> Result { + match val { + Constant::Int(v) => { + if let Ok(v) = v.try_into() { + Ok(SymbolValue::I32(v)) + } else { + Err(format!( + "integer value out of range at {}", + loc + )) + } + } + Constant::Float(v) => Ok(SymbolValue::Double(*v)), + Constant::Bool(v) => Ok(SymbolValue::Bool(*v)), + Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( + tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()? + )), + _ => unimplemented!("this constant is not supported at {}", loc), + } + } + match &default.node { + ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), + ast::ExprKind::Call { func, args, .. } if { + match &func.node { + ast::ExprKind::Name { id, .. } => *id == "int64".into(), + _ => false, + } + } => { + if args.len() == 1 { + match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(v), .. } => + Ok(SymbolValue::I64(v.try_into().unwrap())), + _ => Err(format!("only allow constant integer here at {}", default.location)) + } + } else { + Err(format!("only allow constant integer here at {}", default.location)) + } + } + ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts + .iter() + .map(|x| parse_parameter_default_value(x, resolver)) + .collect::, _>>()? + )), + ast::ExprKind::Name { id, .. } => { + resolver.get_default_param_value(default).ok_or_else( + || format!( + "`{}` cannot be used as a default parameter at {} (not primitive type or tuple / not defined?)", + id, + default.location + ) + ) + } + _ => unimplemented!("only constant default is supported now at {}", default.location), + } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index c96a7cac..f6fc15b0 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -23,7 +23,7 @@ use inkwell::values::BasicValueEnum; pub struct DefinitionId(pub usize); pub mod composer; -mod helper; +pub mod helper; mod type_annotation; use composer::*; use type_annotation::*; diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index e2795474..0f4fd83f 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -35,6 +35,10 @@ impl ResolverInternal { struct Resolver(Arc); impl SymbolResolver for Resolver { + fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + unimplemented!() + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index d3aa09fa..da51aa9d 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -19,6 +19,10 @@ struct Resolver { } impl SymbolResolver for Resolver { + fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + unimplemented!() + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 82c1ac31..9983d507 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -2,7 +2,7 @@ use inkwell::values::BasicValueEnum; use nac3core::{ codegen::CodeGenContext, location::Location, - symbol_resolver::SymbolResolver, + symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, @@ -10,13 +10,14 @@ use nac3core::{ }, }; use parking_lot::{Mutex, RwLock}; -use nac3parser::ast::StrRef; +use nac3parser::ast::{self, StrRef}; use std::{collections::HashMap, sync::Arc}; pub struct ResolverInternal { pub id_to_type: Mutex>, pub id_to_def: Mutex>, pub class_names: Mutex>, + pub module_globals: Mutex>, } impl ResolverInternal { @@ -27,11 +28,24 @@ impl ResolverInternal { pub fn add_id_type(&self, id: StrRef, ty: Type) { self.id_to_type.lock().insert(id, ty); } + + pub fn add_module_global(&self, id: StrRef, val: SymbolValue) { + self.module_globals.lock().insert(id, val); + } } pub struct Resolver(pub Arc); impl SymbolResolver for Resolver { + fn get_default_param_value(&self, expr: &ast::Expr) -> Option { + match &expr.node { + ast::ExprKind::Name { id, .. } => { + self.0.module_globals.lock().get(id).cloned() + } + _ => unimplemented!("other type of expr not supported at {}", expr.location) + } + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 17b61b2b..f6c4be68 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -4,8 +4,8 @@ use inkwell::{ OptimizationLevel, }; use nac3core::typecheck::type_inferencer::PrimitiveStore; -use nac3parser::parser; -use std::env; +use nac3parser::{ast::{Expr, ExprKind, StmtKind}, parser}; +use std::{borrow::Borrow, env}; use std::fs; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; @@ -15,7 +15,7 @@ use nac3core::{ WorkerRegistry, }, symbol_resolver::SymbolResolver, - toplevel::{composer::TopLevelComposer, TopLevelDef}, + toplevel::{composer::TopLevelComposer, TopLevelDef, helper::parse_parameter_default_value}, typecheck::typedef::FunSignature, }; @@ -48,6 +48,7 @@ fn main() { id_to_type: builtins_ty.into(), id_to_def: builtins_def.into(), class_names: Default::default(), + module_globals: Default::default(), } .into(); let resolver = @@ -66,6 +67,61 @@ fn main() { ); for stmt in parser_result.into_iter() { + if let StmtKind::Assign { targets, value, .. } = &stmt.node { + fn handle_assignment_pattern( + targets: &[Expr], + value: &Expr, + resolver: &(dyn SymbolResolver + Send + Sync), + internal_resolver: &ResolverInternal, + ) -> Result<(), String> { + if targets.len() == 1 { + match &targets[0].node { + ExprKind::Name { id, .. } => { + let val = parse_parameter_default_value(value.borrow(), resolver)?; + internal_resolver.add_module_global(*id, val); + Ok(()) + } + ExprKind::List { elts, .. } + | ExprKind::Tuple { elts, .. } => { + handle_assignment_pattern(elts, value, resolver, internal_resolver)?; + Ok(()) + } + _ => unreachable!("cannot be assigned") + } + } else { + match &value.node { + ExprKind::List { elts, .. } + | ExprKind::Tuple { elts, .. } => { + if elts.len() != targets.len() { + Err(format!( + "number of elements to unpack does not match (expect {}, found {}) at {}", + targets.len(), + elts.len(), + value.location + )) + } else { + for (tar, val) in targets.iter().zip(elts) { + handle_assignment_pattern( + std::slice::from_ref(tar), + val, + resolver, + internal_resolver + )?; + } + Ok(()) + } + }, + _ => Err(format!("unpack of this expression is not supported at {}", value.location)) + } + } + } + if let Err(err) = handle_assignment_pattern(targets, value, resolver.as_ref(), internal_resolver.as_ref()) { + eprintln!("{}", err); + return; + } + continue; + } + let (name, def_id, ty) = composer .register_top_level(stmt, Some(resolver.clone()), "__main__".into()) .unwrap();