From 86ca02796b5752fc800a5aa56c57df0cb1df57dd Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sat, 7 Aug 2021 17:25:14 +0800 Subject: [PATCH] function parameter handling --- nac3core/src/codegen/expr.rs | 122 ++++++++++-------- nac3core/src/symbol_resolver.rs | 8 +- nac3core/src/typecheck/magic_methods.rs | 32 ++--- nac3core/src/typecheck/type_inferencer/mod.rs | 10 +- nac3core/src/typecheck/typedef/mod.rs | 11 +- 5 files changed, 99 insertions(+), 84 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 09a250aa7..f93975ab8 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1,12 +1,9 @@ -use std::{convert::TryInto, iter::once}; +use std::{collections::HashMap, convert::TryInto, iter::once}; use crate::{ - top_level::DefinitionId, - typecheck::typedef::{Type, TypeEnum}, -}; -use crate::{ - top_level::{CodeGenContext, TopLevelDef}, - typecheck::typedef::FunSignature, + symbol_resolver::SymbolValue, + top_level::{CodeGenContext, DefinitionId, TopLevelDef}, + typecheck::typedef::{FunSignature, Type, TypeEnum}, }; use inkwell::{ types::{BasicType, BasicTypeEnum}, @@ -65,10 +62,7 @@ impl<'ctx> CodeGenContext<'ctx> { let fields = fields.borrow(); let fields = fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec(); - self.ctx - .struct_type(&fields, false) - .ptr_type(AddressSpace::Generic) - .into() + self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into() } else { unreachable!() }; @@ -93,16 +87,20 @@ impl<'ctx> CodeGenContext<'ctx> { }) } + fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> { + unimplemented!() + } + fn gen_call( &mut self, obj: Option<(Type, BasicValueEnum<'ctx>)>, fun: (&FunSignature, DefinitionId), - params: &[BasicValueEnum<'ctx>], + params: Vec<(Option, BasicValueEnum<'ctx>)>, ret: Type, ) -> Option> { let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0); let defs = self.top_level.definitions.read(); - let definition = defs.get(fun.1.0).unwrap(); + let definition = defs.get(fun.1 .0).unwrap(); let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() { let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| { // TODO: codegen for function that are not yet generated @@ -117,8 +115,19 @@ impl<'ctx> CodeGenContext<'ctx> { }; self.module.add_function(symbol, fun_ty, None) }); - // TODO: deal with default parameters and reordering based on keys - self.builder.build_call(fun_val, params, "call").try_as_basic_value().left() + let mut keys = fun.0.args.clone(); + let mut mapping = HashMap::new(); + for (key, value) in params.into_iter() { + mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value); + } + // default value handling + for k in keys.into_iter() { + mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap())); + } + // reorder the parameters + let params = + fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); + self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left() } else { unreachable!() }; @@ -158,7 +167,7 @@ impl<'ctx> CodeGenContext<'ctx> { let ty = self.ctx.struct_type(&types, false); ty.const_named_struct(&values).into() } - _ => unreachable!() + _ => unreachable!(), } } @@ -261,10 +270,7 @@ impl<'ctx> CodeGenContext<'ctx> { "tmparr", ); let arr_ty = self.ctx.struct_type( - &[ - self.ctx.i32_type().into(), - ty.ptr_type(AddressSpace::Generic).into(), - ], + &[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()], false, ); let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr"); @@ -445,50 +451,52 @@ impl<'ctx> CodeGenContext<'ctx> { ) .fold(None, |prev, (lhs, rhs, op)| { let ty = lhs.custom.unwrap(); - let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool] - .contains(&ty) - { - let (lhs, rhs) = - if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = - (self.gen_expr(lhs), self.gen_expr(rhs)) + let current = + if [self.primitives.int32, self.primitives.int64, self.primitives.bool] + .contains(&ty) + { + let (lhs, rhs) = if let ( + BasicValueEnum::IntValue(lhs), + BasicValueEnum::IntValue(rhs), + ) = (self.gen_expr(lhs), self.gen_expr(rhs)) { (lhs, rhs) } else { unreachable!() }; - let op = match op { - ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ, - ast::Cmpop::NotEq => inkwell::IntPredicate::NE, - ast::Cmpop::Lt => inkwell::IntPredicate::SLT, - ast::Cmpop::LtE => inkwell::IntPredicate::SLE, - ast::Cmpop::Gt => inkwell::IntPredicate::SGT, - ast::Cmpop::GtE => inkwell::IntPredicate::SGE, - _ => unreachable!(), - }; - self.builder.build_int_compare(op, lhs, rhs, "cmp") - } else if ty == self.primitives.float { - let (lhs, rhs) = if let ( - BasicValueEnum::FloatValue(lhs), - BasicValueEnum::FloatValue(rhs), - ) = (self.gen_expr(lhs), self.gen_expr(rhs)) - { - (lhs, rhs) + let op = match op { + ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ, + ast::Cmpop::NotEq => inkwell::IntPredicate::NE, + ast::Cmpop::Lt => inkwell::IntPredicate::SLT, + ast::Cmpop::LtE => inkwell::IntPredicate::SLE, + ast::Cmpop::Gt => inkwell::IntPredicate::SGT, + ast::Cmpop::GtE => inkwell::IntPredicate::SGE, + _ => unreachable!(), + }; + self.builder.build_int_compare(op, lhs, rhs, "cmp") + } else if ty == self.primitives.float { + let (lhs, rhs) = if let ( + BasicValueEnum::FloatValue(lhs), + BasicValueEnum::FloatValue(rhs), + ) = (self.gen_expr(lhs), self.gen_expr(rhs)) + { + (lhs, rhs) + } else { + unreachable!() + }; + let op = match op { + ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ, + ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE, + ast::Cmpop::Lt => inkwell::FloatPredicate::OLT, + ast::Cmpop::LtE => inkwell::FloatPredicate::OLE, + ast::Cmpop::Gt => inkwell::FloatPredicate::OGT, + ast::Cmpop::GtE => inkwell::FloatPredicate::OGE, + _ => unreachable!(), + }; + self.builder.build_float_compare(op, lhs, rhs, "cmp") } else { - unreachable!() + unimplemented!() }; - let op = match op { - ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ, - ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE, - ast::Cmpop::Lt => inkwell::FloatPredicate::OLT, - ast::Cmpop::LtE => inkwell::FloatPredicate::OLE, - ast::Cmpop::Gt => inkwell::FloatPredicate::OGT, - ast::Cmpop::GtE => inkwell::FloatPredicate::OGE, - _ => unreachable!(), - }; - self.builder.build_float_compare(op, lhs, rhs, "cmp") - } else { - unimplemented!() - }; prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current)) }) .unwrap() diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 1c9f80c7b..86a43ad76 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -3,13 +3,15 @@ use crate::typecheck::typedef::Type; use crate::top_level::DefinitionId; use rustpython_parser::ast::Expr; -pub enum SymbolValue<'a> { +#[derive(Clone, PartialEq)] +pub enum SymbolValue { I32(i32), I64(i64), Double(f64), Bool(bool), - Tuple(&'a [SymbolValue<'a>]), - Bytes(&'a [u8]), + Tuple(Vec), + // we should think about how to implement bytes later... + // Bytes(&'a [u8]), } pub trait SymbolResolver { diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index ff28da5ed..9d3743949 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -77,7 +77,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe vars: HashMap::new(), args: vec![FuncArg { ty: other, - is_optional: false, + default_value: None, name: "other".into() }] })) @@ -97,7 +97,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe vars: HashMap::new(), args: vec![FuncArg { ty: other, - is_optional: false, + default_value: None, name: "other".into() }] })) @@ -132,7 +132,7 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other vars: HashMap::new(), args: vec![FuncArg { ty: other_ty, - is_optional: false, + default_value: None, name: "other".into() }] })) @@ -144,15 +144,15 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other /// Add, Sub, Mult, Pow pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) { impl_binop(unifier, store, ty, other_ty, ret_ty, &[ - ast::Operator::Add, - ast::Operator::Sub, + ast::Operator::Add, + ast::Operator::Sub, ast::Operator::Mult, ]) } pub fn impl_pow(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) { impl_binop(unifier, store, ty, other_ty, ret_ty, &[ - ast::Operator::Pow, + ast::Operator::Pow, ]) } @@ -236,9 +236,9 @@ pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { let PrimitiveStore { - int32: int32_t, - int64: int64_t, - float: float_t, + int32: int32_t, + int64: int64_t, + float: float_t, bool: bool_t, .. } = *store; @@ -255,8 +255,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_not(unifier, store, int32_t); impl_comparison(unifier, store, int32_t, int32_t); impl_eq(unifier, store, int32_t); - - /* int64 ======== */ + + /* int64 ======== */ impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t); impl_pow(unifier, store, int64_t, &[int64_t], int64_t); impl_bitwise_arithmetic(unifier, store, int64_t); @@ -269,8 +269,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_not(unifier, store, int64_t); impl_comparison(unifier, store, int64_t, int64_t); impl_eq(unifier, store, int64_t); - - /* float ======== */ + + /* float ======== */ impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t); impl_div(unifier, store, float_t, &[float_t]); @@ -280,8 +280,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_not(unifier, store, float_t); impl_comparison(unifier, store, float_t, float_t); impl_eq(unifier, store, float_t); - - /* bool ======== */ + + /* bool ======== */ impl_not(unifier, store, bool_t); impl_eq(unifier, store, bool_t); -} \ No newline at end of file +} diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index e2fdc1727..7f5bcbc87 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -240,7 +240,7 @@ impl<'a> Inferencer<'a> { let fun = FunSignature { args: fn_args .iter() - .map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, is_optional: false }) + .map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None }) .collect(), ret, vars: Default::default(), @@ -513,7 +513,13 @@ impl<'a> Inferencer<'a> { for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { let method = comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string(); - self.build_method_call(a.location, method, a.custom.unwrap(), vec![b.custom.unwrap()], boolean)?; + self.build_method_call( + a.location, + method, + a.custom.unwrap(), + vec![b.custom.unwrap()], + boolean, + )?; } Ok(boolean) } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index c880b6c2d..38e2a9ff0 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -6,8 +6,9 @@ use std::iter::once; use std::rc::Rc; use std::sync::{Arc, Mutex}; -use crate::top_level::DefinitionId; use super::unification_table::{UnificationKey, UnificationTable}; +use crate::symbol_resolver::SymbolValue; +use crate::top_level::DefinitionId; #[cfg(test)] mod test; @@ -30,9 +31,7 @@ pub struct Call { pub struct FuncArg { pub name: String, pub ty: Type, - // TODO: change this to an optional value - // for primitive types - pub is_optional: bool, + pub default_value: Option, } #[derive(Clone)] @@ -457,7 +456,7 @@ impl Unifier { let required: Vec = signature .args .iter() - .filter(|v| !v.is_optional) + .filter(|v| v.default_value.is_none()) .map(|v| v.name.clone()) .rev() .collect(); @@ -516,7 +515,7 @@ impl Unifier { if x.name != y.name { return Err("Functions differ in parameter names.".to_string()); } - if x.is_optional != y.is_optional { + if x.default_value != y.default_value { return Err("Functions differ in optional parameters.".to_string()); } self.unify(x.ty, y.ty)?;