diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index e420bd8..1652509 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -2,7 +2,7 @@ use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace}; use nac3core::{ codegen::CodeGenContext, location::Location, - symbol_resolver::SymbolResolver, + symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{DefinitionId, TopLevelDef}, typecheck::{ type_inferencer::PrimitiveStore, @@ -14,7 +14,7 @@ use pyo3::{ types::{PyList, PyModule, PyTuple}, PyAny, PyObject, PyResult, Python, }; -use nac3parser::ast::StrRef; +use nac3parser::ast::{self, StrRef}; use std::{ cell::RefCell, collections::{HashMap, HashSet}, @@ -399,9 +399,86 @@ impl Resolver { } } } + + fn get_default_param_obj_value(&self, obj: &PyAny, helper: &PythonHelper) -> PyResult> { + let ty_id: u64 = helper + .id_fn + .call1((helper.type_fn.call1((obj,))?,))? + .extract()?; + Ok( + if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { + let val: i32 = obj.extract()?; + Ok(SymbolValue::I32(val)) + } else if ty_id == self.primitive_ids.int64 { + let val: i64 = obj.extract()?; + Ok(SymbolValue::I64(val)) + } else if ty_id == self.primitive_ids.bool { + let val: bool = obj.extract()?; + Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.float { + let val: f64 = obj.extract()?; + Ok(SymbolValue::Double(val)) + } else if ty_id == self.primitive_ids.tuple { + let elements: &PyTuple = obj.cast_as()?; + let elements: Result, String>, _> = elements + .iter() + .map(|elem| { + self.get_default_param_obj_value( + elem, + helper + ) + }) + .collect(); + let elements = match elements? { + Ok(el) => el, + Err(err) => return Ok(Err(err)) + }; + Ok(SymbolValue::Tuple(elements)) + } else { + Err("only primitives values and tuple can be default parameter value".into()) + } + ) + } } impl SymbolResolver for Resolver { + fn get_default_param_value( + &self, + expr: &ast::Expr + ) -> Option { + match &expr.node { + ast::ExprKind::Name { id, .. } => { + Python::with_gil( + |py| -> PyResult> { + let obj: &PyAny = self.module.extract(py)?; + let members: &PyList = PyModule::import(py, "inspect")? + .getattr("getmembers")? + .call1((obj,))? + .cast_as()?; + let mut sym_value = None; + for member in members.iter() { + let key: &str = member.get_item(0)?.extract()?; + let val = member.get_item(1)?; + if key == id.to_string() { + let builtins = PyModule::import(py, "builtins")?; + let helper = PythonHelper { + id_fn: builtins.getattr("id").unwrap(), + len_fn: builtins.getattr("len").unwrap(), + type_fn: builtins.getattr("type").unwrap(), + }; + sym_value = Some(self.get_default_param_obj_value(val, &helper).unwrap().unwrap()); + break; + } + } + Ok(sym_value) + } + ) + .unwrap() + } + _ => unimplemented!("other type of expr not supported at {}", expr.location) + } + } + fn get_symbol_type( &self, unifier: &mut Unifier, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index ce958ff..f9e2810 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -37,6 +37,10 @@ impl Resolver { } impl SymbolResolver for Resolver { + fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + unimplemented!() + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 161d78a..e8d594d 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -44,6 +44,7 @@ pub trait SymbolResolver { ctx: &mut CodeGenContext<'ctx, 'a>, ) -> Option>; fn get_symbol_location(&self, str: StrRef) -> Option; + fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option; // handle function call etc. } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 1cb1680..232e0c4 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1136,7 +1136,7 @@ impl TopLevelComposer { default_value: match default { None => None, Some(default) => - Some(Self::parse_parameter_default_value(default)?) + Some(Self::parse_parameter_default_value(default, resolver)?) } }) }) @@ -1353,7 +1353,7 @@ impl TopLevelComposer { if name == "self".into() { return Err(format!("`self` parameter cannot take default value at {}", x.location)); } - Some(Self::parse_parameter_default_value(default)?) + Some(Self::parse_parameter_default_value(default, class_resolver)?) } } }; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 397a10d..6d46d9d 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -347,51 +347,64 @@ impl TopLevelComposer { Ok(result) } - pub fn parse_parameter_default_value(default: &ast::Expr) -> Result { - fn handle_constant(val: &Constant, loc: &Location) -> Result { - match val { - Constant::Int(v) => { - if let Ok(v) = v.try_into() { - Ok(SymbolValue::I32(v)) - } else { - Err(format!( - "int64 default parameter should be specified explicitly by `int64()` at {}", - loc - )) - } - } - Constant::Float(v) => Ok(SymbolValue::Double(*v)), - Constant::Bool(v) => Ok(SymbolValue::Bool(*v)), - Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( - tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()? - )), - _ => unimplemented!("this constant is not supported at {}", loc), - } - } - match &default.node { - ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), - ast::ExprKind::Call { func, args, .. } if { - match &func.node { - ast::ExprKind::Name { id, .. } => *id == "int64".into(), - _ => false, - } - } => { - if args.len() == 1 { - match &args[0].node { - ast::ExprKind::Constant { value: Constant::Int(v), .. } => - Ok(SymbolValue::I64(v.try_into().unwrap())), - _ => Err(format!("only allow constant integer here at {}", default.location)) - } - } else { - Err(format!("only allow constant integer here at {}", default.location)) - } - } - ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts - .iter() - .map(|x| Self::parse_parameter_default_value(x)) - .collect::, _>>()? - )), - _ => unimplemented!("only constant default is supported now at {}", default.location), - } + pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result { + parse_parameter_default_value(default, resolver) + } +} + +pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result { + fn handle_constant(val: &Constant, loc: &Location) -> Result { + match val { + Constant::Int(v) => { + if let Ok(v) = v.try_into() { + Ok(SymbolValue::I32(v)) + } else { + Err(format!( + "int64 default parameter should be specified explicitly by `int64()` at {}", + loc + )) + } + } + Constant::Float(v) => Ok(SymbolValue::Double(*v)), + Constant::Bool(v) => Ok(SymbolValue::Bool(*v)), + Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( + tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()? + )), + _ => unimplemented!("this constant is not supported at {}", loc), + } + } + match &default.node { + ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), + ast::ExprKind::Call { func, args, .. } if { + match &func.node { + ast::ExprKind::Name { id, .. } => *id == "int64".into(), + _ => false, + } + } => { + if args.len() == 1 { + match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(v), .. } => + Ok(SymbolValue::I64(v.try_into().unwrap())), + _ => Err(format!("only allow constant integer here at {}", default.location)) + } + } else { + Err(format!("only allow constant integer here at {}", default.location)) + } + } + ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts + .iter() + .map(|x| parse_parameter_default_value(x, resolver)) + .collect::, _>>()? + )), + ast::ExprKind::Name { id, .. } => { + resolver.get_default_param_value(default).ok_or_else( + || format!( + "this module global `{}` cannot be used as a default parameter at {} (should be primitive type or tuple)", + id, + default.location + ) + ) + } + _ => unimplemented!("only constant default is supported now at {}", default.location), } } diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index c96a7ca..f6fc15b 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -23,7 +23,7 @@ use inkwell::values::BasicValueEnum; pub struct DefinitionId(pub usize); pub mod composer; -mod helper; +pub mod helper; mod type_annotation; use composer::*; use type_annotation::*; diff --git a/nac3core/src/toplevel/test.rs b/nac3core/src/toplevel/test.rs index e279547..0f4fd83 100644 --- a/nac3core/src/toplevel/test.rs +++ b/nac3core/src/toplevel/test.rs @@ -35,6 +35,10 @@ impl ResolverInternal { struct Resolver(Arc); impl SymbolResolver for Resolver { + fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + unimplemented!() + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index d3aa09f..da51aa9 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -19,6 +19,10 @@ struct Resolver { } impl SymbolResolver for Resolver { + fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option { + unimplemented!() + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 82c1ac3..31b3713 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -2,21 +2,22 @@ use inkwell::values::BasicValueEnum; use nac3core::{ codegen::CodeGenContext, location::Location, - symbol_resolver::SymbolResolver, - toplevel::{DefinitionId, TopLevelDef}, + symbol_resolver::{SymbolResolver, SymbolValue}, + toplevel::{DefinitionId, TopLevelDef, helper::parse_parameter_default_value}, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, Unifier}, }, }; use parking_lot::{Mutex, RwLock}; -use nac3parser::ast::StrRef; +use nac3parser::ast::{self, StrRef}; use std::{collections::HashMap, sync::Arc}; pub struct ResolverInternal { pub id_to_type: Mutex>, pub id_to_def: Mutex>, pub class_names: Mutex>, + pub module_globals: Mutex>, } impl ResolverInternal { @@ -27,11 +28,27 @@ impl ResolverInternal { pub fn add_id_type(&self, id: StrRef, ty: Type) { self.id_to_type.lock().insert(id, ty); } + + pub fn add_module_global(&self, id: StrRef, expr: &ast::Expr) { + self.module_globals.lock().insert(id, expr.clone()); + } } pub struct Resolver(pub Arc); impl SymbolResolver for Resolver { + fn get_default_param_value(&self, expr: &ast::Expr) -> Option { + match &expr.node { + ast::ExprKind::Name { id, .. } => { + let expr = self.0.module_globals.lock().get(id).cloned(); + expr.map(|x| { + parse_parameter_default_value(&x, self).unwrap() + }) + } + _ => unimplemented!("other type of expr not supported at {}", expr.location) + } + } + fn get_symbol_type( &self, _: &mut Unifier, diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 5b4e31d..33fc0e6 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -4,7 +4,7 @@ use inkwell::{ OptimizationLevel, }; use nac3core::typecheck::type_inferencer::PrimitiveStore; -use nac3parser::parser; +use nac3parser::{ast::{ExprKind, StmtKind}, parser}; use std::env; use std::fs; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; @@ -48,6 +48,7 @@ fn main() { id_to_type: builtins_ty.into(), id_to_def: builtins_def.into(), class_names: Default::default(), + module_globals: Default::default(), } .into(); let resolver = @@ -66,6 +67,18 @@ fn main() { ); for stmt in parser_result.into_iter() { + // handle module globals + if let StmtKind::Assign { targets, value, .. } = &stmt.node { + if targets.len() == 1 { + if let ExprKind::Name { id, .. } = &targets[0].node { + internal_resolver.add_module_global(*id, value); + } + } else { + unimplemented!("only single assign supported now") + } + continue; + } + let (name, def_id, ty) = composer .register_top_level(stmt, Some(resolver.clone()), "__main__".into()) .unwrap();