1
0
forked from M-Labs/nac3

function parameter handling

This commit is contained in:
pca006132 2021-08-07 17:25:14 +08:00
parent 711482d09c
commit 86ca02796b
5 changed files with 99 additions and 84 deletions

View File

@ -1,12 +1,9 @@
use std::{convert::TryInto, iter::once}; use std::{collections::HashMap, convert::TryInto, iter::once};
use crate::{ use crate::{
top_level::DefinitionId, symbol_resolver::SymbolValue,
typecheck::typedef::{Type, TypeEnum}, top_level::{CodeGenContext, DefinitionId, TopLevelDef},
}; typecheck::typedef::{FunSignature, Type, TypeEnum},
use crate::{
top_level::{CodeGenContext, TopLevelDef},
typecheck::typedef::FunSignature,
}; };
use inkwell::{ use inkwell::{
types::{BasicType, BasicTypeEnum}, types::{BasicType, BasicTypeEnum},
@ -65,10 +62,7 @@ impl<'ctx> CodeGenContext<'ctx> {
let fields = fields.borrow(); let fields = fields.borrow();
let fields = let fields =
fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec(); fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec();
self.ctx self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
.struct_type(&fields, false)
.ptr_type(AddressSpace::Generic)
.into()
} else { } else {
unreachable!() unreachable!()
}; };
@ -93,16 +87,20 @@ impl<'ctx> CodeGenContext<'ctx> {
}) })
} }
fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> {
unimplemented!()
}
fn gen_call( fn gen_call(
&mut self, &mut self,
obj: Option<(Type, BasicValueEnum<'ctx>)>, obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
params: &[BasicValueEnum<'ctx>], params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
ret: Type, ret: Type,
) -> Option<BasicValueEnum<'ctx>> { ) -> Option<BasicValueEnum<'ctx>> {
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0); let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
let defs = self.top_level.definitions.read(); let defs = self.top_level.definitions.read();
let definition = defs.get(fun.1.0).unwrap(); let definition = defs.get(fun.1 .0).unwrap();
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() { let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| { let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
// TODO: codegen for function that are not yet generated // TODO: codegen for function that are not yet generated
@ -117,8 +115,19 @@ impl<'ctx> CodeGenContext<'ctx> {
}; };
self.module.add_function(symbol, fun_ty, None) self.module.add_function(symbol, fun_ty, None)
}); });
// TODO: deal with default parameters and reordering based on keys let mut keys = fun.0.args.clone();
self.builder.build_call(fun_val, params, "call").try_as_basic_value().left() let mut mapping = HashMap::new();
for (key, value) in params.into_iter() {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
}
// default value handling
for k in keys.into_iter() {
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
}
// reorder the parameters
let params =
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
} else { } else {
unreachable!() unreachable!()
}; };
@ -158,7 +167,7 @@ impl<'ctx> CodeGenContext<'ctx> {
let ty = self.ctx.struct_type(&types, false); let ty = self.ctx.struct_type(&types, false);
ty.const_named_struct(&values).into() ty.const_named_struct(&values).into()
} }
_ => unreachable!() _ => unreachable!(),
} }
} }
@ -261,10 +270,7 @@ impl<'ctx> CodeGenContext<'ctx> {
"tmparr", "tmparr",
); );
let arr_ty = self.ctx.struct_type( let arr_ty = self.ctx.struct_type(
&[ &[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()],
self.ctx.i32_type().into(),
ty.ptr_type(AddressSpace::Generic).into(),
],
false, false,
); );
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr"); let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
@ -445,50 +451,52 @@ impl<'ctx> CodeGenContext<'ctx> {
) )
.fold(None, |prev, (lhs, rhs, op)| { .fold(None, |prev, (lhs, rhs, op)| {
let ty = lhs.custom.unwrap(); let ty = lhs.custom.unwrap();
let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool] let current =
.contains(&ty) if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
{ .contains(&ty)
let (lhs, rhs) = {
if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = let (lhs, rhs) = if let (
(self.gen_expr(lhs), self.gen_expr(rhs)) BasicValueEnum::IntValue(lhs),
BasicValueEnum::IntValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
{ {
(lhs, rhs) (lhs, rhs)
} else { } else {
unreachable!() unreachable!()
}; };
let op = match op { let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ, ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
ast::Cmpop::NotEq => inkwell::IntPredicate::NE, ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
ast::Cmpop::Lt => inkwell::IntPredicate::SLT, ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
ast::Cmpop::LtE => inkwell::IntPredicate::SLE, ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
ast::Cmpop::Gt => inkwell::IntPredicate::SGT, ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
ast::Cmpop::GtE => inkwell::IntPredicate::SGE, ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
_ => unreachable!(), _ => unreachable!(),
}; };
self.builder.build_int_compare(op, lhs, rhs, "cmp") self.builder.build_int_compare(op, lhs, rhs, "cmp")
} else if ty == self.primitives.float { } else if ty == self.primitives.float {
let (lhs, rhs) = if let ( let (lhs, rhs) = if let (
BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(lhs),
BasicValueEnum::FloatValue(rhs), BasicValueEnum::FloatValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs)) ) = (self.gen_expr(lhs), self.gen_expr(rhs))
{ {
(lhs, rhs) (lhs, rhs)
} else {
unreachable!()
};
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(),
};
self.builder.build_float_compare(op, lhs, rhs, "cmp")
} else { } else {
unreachable!() unimplemented!()
}; };
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(),
};
self.builder.build_float_compare(op, lhs, rhs, "cmp")
} else {
unimplemented!()
};
prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current)) prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current))
}) })
.unwrap() .unwrap()

View File

@ -3,13 +3,15 @@ use crate::typecheck::typedef::Type;
use crate::top_level::DefinitionId; use crate::top_level::DefinitionId;
use rustpython_parser::ast::Expr; use rustpython_parser::ast::Expr;
pub enum SymbolValue<'a> { #[derive(Clone, PartialEq)]
pub enum SymbolValue {
I32(i32), I32(i32),
I64(i64), I64(i64),
Double(f64), Double(f64),
Bool(bool), Bool(bool),
Tuple(&'a [SymbolValue<'a>]), Tuple(Vec<SymbolValue>),
Bytes(&'a [u8]), // we should think about how to implement bytes later...
// Bytes(&'a [u8]),
} }
pub trait SymbolResolver { pub trait SymbolResolver {

View File

@ -77,7 +77,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe
vars: HashMap::new(), vars: HashMap::new(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other, ty: other,
is_optional: false, default_value: None,
name: "other".into() name: "other".into()
}] }]
})) }))
@ -97,7 +97,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe
vars: HashMap::new(), vars: HashMap::new(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other, ty: other,
is_optional: false, default_value: None,
name: "other".into() name: "other".into()
}] }]
})) }))
@ -132,7 +132,7 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other
vars: HashMap::new(), vars: HashMap::new(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other_ty, ty: other_ty,
is_optional: false, default_value: None,
name: "other".into() name: "other".into()
}] }]
})) }))
@ -144,15 +144,15 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other
/// Add, Sub, Mult, Pow /// Add, Sub, Mult, Pow
pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) { pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ impl_binop(unifier, store, ty, other_ty, ret_ty, &[
ast::Operator::Add, ast::Operator::Add,
ast::Operator::Sub, ast::Operator::Sub,
ast::Operator::Mult, ast::Operator::Mult,
]) ])
} }
pub fn impl_pow(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) { pub fn impl_pow(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ impl_binop(unifier, store, ty, other_ty, ret_ty, &[
ast::Operator::Pow, ast::Operator::Pow,
]) ])
} }
@ -236,9 +236,9 @@ pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
let PrimitiveStore { let PrimitiveStore {
int32: int32_t, int32: int32_t,
int64: int64_t, int64: int64_t,
float: float_t, float: float_t,
bool: bool_t, bool: bool_t,
.. ..
} = *store; } = *store;
@ -255,8 +255,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_not(unifier, store, int32_t); impl_not(unifier, store, int32_t);
impl_comparison(unifier, store, int32_t, int32_t); impl_comparison(unifier, store, int32_t, int32_t);
impl_eq(unifier, store, int32_t); impl_eq(unifier, store, int32_t);
/* int64 ======== */ /* int64 ======== */
impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t); impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t);
impl_pow(unifier, store, int64_t, &[int64_t], int64_t); impl_pow(unifier, store, int64_t, &[int64_t], int64_t);
impl_bitwise_arithmetic(unifier, store, int64_t); impl_bitwise_arithmetic(unifier, store, int64_t);
@ -269,8 +269,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_not(unifier, store, int64_t); impl_not(unifier, store, int64_t);
impl_comparison(unifier, store, int64_t, int64_t); impl_comparison(unifier, store, int64_t, int64_t);
impl_eq(unifier, store, int64_t); impl_eq(unifier, store, int64_t);
/* float ======== */ /* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t); impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
impl_div(unifier, store, float_t, &[float_t]); impl_div(unifier, store, float_t, &[float_t]);
@ -280,8 +280,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_not(unifier, store, float_t); impl_not(unifier, store, float_t);
impl_comparison(unifier, store, float_t, float_t); impl_comparison(unifier, store, float_t, float_t);
impl_eq(unifier, store, float_t); impl_eq(unifier, store, float_t);
/* bool ======== */ /* bool ======== */
impl_not(unifier, store, bool_t); impl_not(unifier, store, bool_t);
impl_eq(unifier, store, bool_t); impl_eq(unifier, store, bool_t);
} }

View File

@ -240,7 +240,7 @@ impl<'a> Inferencer<'a> {
let fun = FunSignature { let fun = FunSignature {
args: fn_args args: fn_args
.iter() .iter()
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, is_optional: false }) .map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None })
.collect(), .collect(),
ret, ret,
vars: Default::default(), vars: Default::default(),
@ -513,7 +513,13 @@ impl<'a> Inferencer<'a> {
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method = let method =
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string(); comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string();
self.build_method_call(a.location, method, a.custom.unwrap(), vec![b.custom.unwrap()], boolean)?; self.build_method_call(
a.location,
method,
a.custom.unwrap(),
vec![b.custom.unwrap()],
boolean,
)?;
} }
Ok(boolean) Ok(boolean)
} }

View File

@ -6,8 +6,9 @@ use std::iter::once;
use std::rc::Rc; use std::rc::Rc;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use crate::top_level::DefinitionId;
use super::unification_table::{UnificationKey, UnificationTable}; use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue;
use crate::top_level::DefinitionId;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
@ -30,9 +31,7 @@ pub struct Call {
pub struct FuncArg { pub struct FuncArg {
pub name: String, pub name: String,
pub ty: Type, pub ty: Type,
// TODO: change this to an optional value pub default_value: Option<SymbolValue>,
// for primitive types
pub is_optional: bool,
} }
#[derive(Clone)] #[derive(Clone)]
@ -457,7 +456,7 @@ impl Unifier {
let required: Vec<String> = signature let required: Vec<String> = signature
.args .args
.iter() .iter()
.filter(|v| !v.is_optional) .filter(|v| v.default_value.is_none())
.map(|v| v.name.clone()) .map(|v| v.name.clone())
.rev() .rev()
.collect(); .collect();
@ -516,7 +515,7 @@ impl Unifier {
if x.name != y.name { if x.name != y.name {
return Err("Functions differ in parameter names.".to_string()); return Err("Functions differ in parameter names.".to_string());
} }
if x.is_optional != y.is_optional { if x.default_value != y.default_value {
return Err("Functions differ in optional parameters.".to_string()); return Err("Functions differ in optional parameters.".to_string());
} }
self.unify(x.ty, y.ty)?; self.unify(x.ty, y.ty)?;