1
0
forked from M-Labs/nac3

Constant Default Parameter Support (#98)

Add support for constant default parameter

Reviewed-on: M-Labs/nac3#98
Co-authored-by: ychenfo <yc@m-labs.hk>
Co-committed-by: ychenfo <yc@m-labs.hk>
This commit is contained in:
ychenfo 2021-11-23 07:32:09 +08:00 committed by sb10q
parent 49476d06e1
commit 4587088835
10 changed files with 347 additions and 15 deletions

View File

@ -2,7 +2,7 @@ use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace};
use nac3core::{ use nac3core::{
codegen::CodeGenContext, codegen::CodeGenContext,
location::Location, location::Location,
symbol_resolver::SymbolResolver, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
@ -14,7 +14,7 @@ use pyo3::{
types::{PyList, PyModule, PyTuple}, types::{PyList, PyModule, PyTuple},
PyAny, PyObject, PyResult, Python, PyAny, PyObject, PyResult, Python,
}; };
use nac3parser::ast::StrRef; use nac3parser::ast::{self, StrRef};
use std::{ use std::{
cell::RefCell, cell::RefCell,
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
@ -399,9 +399,86 @@ impl Resolver {
} }
} }
} }
fn get_default_param_obj_value(&self, obj: &PyAny, helper: &PythonHelper) -> PyResult<Result<SymbolValue, String>> {
let ty_id: u64 = helper
.id_fn
.call1((helper.type_fn.call1((obj,))?,))?
.extract()?;
Ok(
if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 {
let val: i32 = obj.extract()?;
Ok(SymbolValue::I32(val))
} else if ty_id == self.primitive_ids.int64 {
let val: i64 = obj.extract()?;
Ok(SymbolValue::I64(val))
} else if ty_id == self.primitive_ids.bool {
let val: bool = obj.extract()?;
Ok(SymbolValue::Bool(val))
} else if ty_id == self.primitive_ids.float {
let val: f64 = obj.extract()?;
Ok(SymbolValue::Double(val))
} else if ty_id == self.primitive_ids.tuple {
let elements: &PyTuple = obj.cast_as()?;
let elements: Result<Result<Vec<_>, String>, _> = elements
.iter()
.map(|elem| {
self.get_default_param_obj_value(
elem,
helper
)
})
.collect();
let elements = match elements? {
Ok(el) => el,
Err(err) => return Ok(Err(err))
};
Ok(SymbolValue::Tuple(elements))
} else {
Err("only primitives values and tuple can be default parameter value".into())
}
)
}
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value(
&self,
expr: &ast::Expr
) -> Option<SymbolValue> {
match &expr.node {
ast::ExprKind::Name { id, .. } => {
Python::with_gil(
|py| -> PyResult<Option<SymbolValue>> {
let obj: &PyAny = self.module.extract(py)?;
let members: &PyList = PyModule::import(py, "inspect")?
.getattr("getmembers")?
.call1((obj,))?
.cast_as()?;
let mut sym_value = None;
for member in members.iter() {
let key: &str = member.get_item(0)?.extract()?;
let val = member.get_item(1)?;
if key == id.to_string() {
let builtins = PyModule::import(py, "builtins")?;
let helper = PythonHelper {
id_fn: builtins.getattr("id").unwrap(),
len_fn: builtins.getattr("len").unwrap(),
type_fn: builtins.getattr("type").unwrap(),
};
sym_value = Some(self.get_default_param_obj_value(val, &helper).unwrap().unwrap());
break;
}
}
Ok(sym_value)
}
)
.unwrap()
}
_ => unimplemented!("other type of expr not supported at {}", expr.location)
}
}
fn get_symbol_type( fn get_symbol_type(
&self, &self,
unifier: &mut Unifier, unifier: &mut Unifier,

View File

@ -37,6 +37,10 @@ impl Resolver {
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
fn get_symbol_type( fn get_symbol_type(
&self, &self,
_: &mut Unifier, _: &mut Unifier,

View File

@ -44,6 +44,7 @@ pub trait SymbolResolver {
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
) -> Option<BasicValueEnum<'ctx>>; ) -> Option<BasicValueEnum<'ctx>>;
fn get_symbol_location(&self, str: StrRef) -> Option<Location>; fn get_symbol_location(&self, str: StrRef) -> Option<Location>;
fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option<SymbolValue>;
// handle function call etc. // handle function call etc.
} }

View File

@ -1066,10 +1066,23 @@ impl TopLevelComposer {
and names thould not be the same as the keywords" and names thould not be the same as the keywords"
.into()); .into());
} }
args.args let arg_with_default: Vec<(&ast::Located<ast::ArgData<()>>, Option<&ast::Expr>)> = args
.args
.iter() .iter()
.map(|x| -> Result<FuncArg, String> { .rev()
.zip(args
.defaults
.iter()
.rev()
.map(|x| -> Option<&ast::Expr> { Some(x) })
.chain(std::iter::repeat(None))
).collect_vec();
arg_with_default
.iter()
.rev()
.map(|(x, default)| -> Result<FuncArg, String> {
let annotation = x let annotation = x
.node .node
.annotation .annotation
@ -1121,7 +1134,19 @@ impl TopLevelComposer {
Ok(FuncArg { Ok(FuncArg {
name: x.node.arg, name: x.node.arg,
ty, ty,
default_value: Default::default(), default_value: match default {
None => None,
Some(default) => Some({
let v = Self::parse_parameter_default_value(default, resolver)?;
Self::check_default_param_type(
&v,
&type_annotation,
primitives_store,
unifier
).map_err(|err| format!("{} at {}", err, x.location))?;
v
})
}
}) })
}) })
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
@ -1282,7 +1307,20 @@ impl TopLevelComposer {
} }
let mut result = Vec::new(); let mut result = Vec::new();
for x in &args.args {
let arg_with_default: Vec<(&ast::Located<ast::ArgData<()>>, Option<&ast::Expr>)> = args
.args
.iter()
.rev()
.zip(args
.defaults
.iter()
.rev()
.map(|x| -> Option<&ast::Expr> { Some(x) })
.chain(std::iter::repeat(None))
).collect_vec();
for (x, default) in arg_with_default.into_iter().rev() {
let name = x.node.arg; let name = x.node.arg;
if name != zelf { if name != zelf {
let type_ann = { let type_ann = {
@ -1327,8 +1365,20 @@ impl TopLevelComposer {
let dummy_func_arg = FuncArg { let dummy_func_arg = FuncArg {
name, name,
ty: unifier.get_fresh_var().0, ty: unifier.get_fresh_var().0,
// TODO: default value? default_value: match default {
default_value: None, None => None,
Some(default) => {
if name == "self".into() {
return Err(format!("`self` parameter cannot take default value at {}", x.location));
}
Some({
let v = Self::parse_parameter_default_value(default, class_resolver)?;
Self::check_default_param_type(&v, &type_ann, primitives, unifier)
.map_err(|err| format!("{} at {}", err, x.location))?;
v
})
}
}
}; };
// push the dummy type and the type annotation // push the dummy type and the type annotation
// into the list for later unification // into the list for later unification

View File

@ -1,3 +1,8 @@
use std::convert::TryInto;
use nac3parser::ast::{Constant, Location};
use crate::symbol_resolver::SymbolValue;
use super::*; use super::*;
impl TopLevelDef { impl TopLevelDef {
@ -341,4 +346,121 @@ impl TopLevelComposer {
} }
Ok(result) Ok(result)
} }
pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result<SymbolValue, String> {
parse_parameter_default_value(default, resolver)
}
pub fn check_default_param_type(val: &SymbolValue, ty: &TypeAnnotation, primitive: &PrimitiveStore, unifier: &mut Unifier) -> Result<(), String> {
let res = match val {
SymbolValue::Bool(..) => {
if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.bool) {
None
} else {
Some("bool".to_string())
}
}
SymbolValue::Double(..) => {
if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.float) {
None
} else {
Some("float".to_string())
}
}
SymbolValue::I32(..) => {
if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int32) {
None
} else {
Some("int32".to_string())
}
}
SymbolValue::I64(..) => {
if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int64) {
None
} else {
Some("int64".to_string())
}
}
SymbolValue::Tuple(elts) => {
if let TypeAnnotation::Tuple(elts_ty) = ty {
for (e, t) in elts.iter().zip(elts_ty.iter()) {
Self::check_default_param_type(e, t, primitive, unifier)?
}
if elts.len() != elts_ty.len() {
Some(format!("tuple of length {}", elts.len()))
} else {
None
}
} else {
Some("tuple".to_string())
}
}
};
if let Some(found) = res {
Err(format!(
"incompatible default parameter type, expect {}, found {}",
ty.stringify(unifier),
found
))
} else {
Ok(())
}
}
}
pub fn parse_parameter_default_value(default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync)) -> Result<SymbolValue, String> {
fn handle_constant(val: &Constant, loc: &Location) -> Result<SymbolValue, String> {
match val {
Constant::Int(v) => {
if let Ok(v) = v.try_into() {
Ok(SymbolValue::I32(v))
} else {
Err(format!(
"integer value out of range at {}",
loc
))
}
}
Constant::Float(v) => Ok(SymbolValue::Double(*v)),
Constant::Bool(v) => Ok(SymbolValue::Bool(*v)),
Constant::Tuple(tuple) => Ok(SymbolValue::Tuple(
tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?
)),
_ => unimplemented!("this constant is not supported at {}", loc),
}
}
match &default.node {
ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location),
ast::ExprKind::Call { func, args, .. } if {
match &func.node {
ast::ExprKind::Name { id, .. } => *id == "int64".into(),
_ => false,
}
} => {
if args.len() == 1 {
match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } =>
Ok(SymbolValue::I64(v.try_into().unwrap())),
_ => Err(format!("only allow constant integer here at {}", default.location))
}
} else {
Err(format!("only allow constant integer here at {}", default.location))
}
}
ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts
.iter()
.map(|x| parse_parameter_default_value(x, resolver))
.collect::<Result<Vec<_>, _>>()?
)),
ast::ExprKind::Name { id, .. } => {
resolver.get_default_param_value(default).ok_or_else(
|| format!(
"`{}` cannot be used as a default parameter at {} (not primitive type or tuple / not defined?)",
id,
default.location
)
)
}
_ => unimplemented!("only constant default is supported now at {}", default.location),
}
} }

View File

@ -23,7 +23,7 @@ use inkwell::values::BasicValueEnum;
pub struct DefinitionId(pub usize); pub struct DefinitionId(pub usize);
pub mod composer; pub mod composer;
mod helper; pub mod helper;
mod type_annotation; mod type_annotation;
use composer::*; use composer::*;
use type_annotation::*; use type_annotation::*;

View File

@ -35,6 +35,10 @@ impl ResolverInternal {
struct Resolver(Arc<ResolverInternal>); struct Resolver(Arc<ResolverInternal>);
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
fn get_symbol_type( fn get_symbol_type(
&self, &self,
_: &mut Unifier, _: &mut Unifier,

View File

@ -19,6 +19,10 @@ struct Resolver {
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value(&self, _: &nac3parser::ast::Expr) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!()
}
fn get_symbol_type( fn get_symbol_type(
&self, &self,
_: &mut Unifier, _: &mut Unifier,

View File

@ -2,7 +2,7 @@ use inkwell::values::BasicValueEnum;
use nac3core::{ use nac3core::{
codegen::CodeGenContext, codegen::CodeGenContext,
location::Location, location::Location,
symbol_resolver::SymbolResolver, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
@ -10,13 +10,14 @@ use nac3core::{
}, },
}; };
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use nac3parser::ast::StrRef; use nac3parser::ast::{self, StrRef};
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
pub struct ResolverInternal { pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<StrRef, Type>>, pub id_to_type: Mutex<HashMap<StrRef, Type>>,
pub id_to_def: Mutex<HashMap<StrRef, DefinitionId>>, pub id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
pub class_names: Mutex<HashMap<StrRef, Type>>, pub class_names: Mutex<HashMap<StrRef, Type>>,
pub module_globals: Mutex<HashMap<StrRef, SymbolValue>>,
} }
impl ResolverInternal { impl ResolverInternal {
@ -27,11 +28,24 @@ impl ResolverInternal {
pub fn add_id_type(&self, id: StrRef, ty: Type) { pub fn add_id_type(&self, id: StrRef, ty: Type) {
self.id_to_type.lock().insert(id, ty); self.id_to_type.lock().insert(id, ty);
} }
pub fn add_module_global(&self, id: StrRef, val: SymbolValue) {
self.module_globals.lock().insert(id, val);
}
} }
pub struct Resolver(pub Arc<ResolverInternal>); pub struct Resolver(pub Arc<ResolverInternal>);
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value(&self, expr: &ast::Expr) -> Option<SymbolValue> {
match &expr.node {
ast::ExprKind::Name { id, .. } => {
self.0.module_globals.lock().get(id).cloned()
}
_ => unimplemented!("other type of expr not supported at {}", expr.location)
}
}
fn get_symbol_type( fn get_symbol_type(
&self, &self,
_: &mut Unifier, _: &mut Unifier,

View File

@ -4,8 +4,8 @@ use inkwell::{
OptimizationLevel, OptimizationLevel,
}; };
use nac3core::typecheck::type_inferencer::PrimitiveStore; use nac3core::typecheck::type_inferencer::PrimitiveStore;
use nac3parser::parser; use nac3parser::{ast::{Expr, ExprKind, StmtKind}, parser};
use std::env; use std::{borrow::Borrow, env};
use std::fs; use std::fs;
use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime}; use std::{collections::HashMap, path::Path, sync::Arc, time::SystemTime};
@ -15,7 +15,7 @@ use nac3core::{
WorkerRegistry, WorkerRegistry,
}, },
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{composer::TopLevelComposer, TopLevelDef}, toplevel::{composer::TopLevelComposer, TopLevelDef, helper::parse_parameter_default_value},
typecheck::typedef::FunSignature, typecheck::typedef::FunSignature,
}; };
@ -48,6 +48,7 @@ fn main() {
id_to_type: builtins_ty.into(), id_to_type: builtins_ty.into(),
id_to_def: builtins_def.into(), id_to_def: builtins_def.into(),
class_names: Default::default(), class_names: Default::default(),
module_globals: Default::default(),
} }
.into(); .into();
let resolver = let resolver =
@ -66,6 +67,61 @@ fn main() {
); );
for stmt in parser_result.into_iter() { for stmt in parser_result.into_iter() {
if let StmtKind::Assign { targets, value, .. } = &stmt.node {
fn handle_assignment_pattern(
targets: &[Expr],
value: &Expr,
resolver: &(dyn SymbolResolver + Send + Sync),
internal_resolver: &ResolverInternal,
) -> Result<(), String> {
if targets.len() == 1 {
match &targets[0].node {
ExprKind::Name { id, .. } => {
let val = parse_parameter_default_value(value.borrow(), resolver)?;
internal_resolver.add_module_global(*id, val);
Ok(())
}
ExprKind::List { elts, .. }
| ExprKind::Tuple { elts, .. } => {
handle_assignment_pattern(elts, value, resolver, internal_resolver)?;
Ok(())
}
_ => unreachable!("cannot be assigned")
}
} else {
match &value.node {
ExprKind::List { elts, .. }
| ExprKind::Tuple { elts, .. } => {
if elts.len() != targets.len() {
Err(format!(
"number of elements to unpack does not match (expect {}, found {}) at {}",
targets.len(),
elts.len(),
value.location
))
} else {
for (tar, val) in targets.iter().zip(elts) {
handle_assignment_pattern(
std::slice::from_ref(tar),
val,
resolver,
internal_resolver
)?;
}
Ok(())
}
},
_ => Err(format!("unpack of this expression is not supported at {}", value.location))
}
}
}
if let Err(err) = handle_assignment_pattern(targets, value, resolver.as_ref(), internal_resolver.as_ref()) {
eprintln!("{}", err);
return;
}
continue;
}
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
.register_top_level(stmt, Some(resolver.clone()), "__main__".into()) .register_top_level(stmt, Some(resolver.clone()), "__main__".into())
.unwrap(); .unwrap();