forked from M-Labs/nac3
1
0
Fork 0

function parameter handling

This commit is contained in:
pca006132 2021-08-07 17:25:14 +08:00
parent 711482d09c
commit 86ca02796b
5 changed files with 99 additions and 84 deletions

View File

@ -1,12 +1,9 @@
use std::{convert::TryInto, iter::once};
use std::{collections::HashMap, convert::TryInto, iter::once};
use crate::{
top_level::DefinitionId,
typecheck::typedef::{Type, TypeEnum},
};
use crate::{
top_level::{CodeGenContext, TopLevelDef},
typecheck::typedef::FunSignature,
symbol_resolver::SymbolValue,
top_level::{CodeGenContext, DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
use inkwell::{
types::{BasicType, BasicTypeEnum},
@ -65,10 +62,7 @@ impl<'ctx> CodeGenContext<'ctx> {
let fields = fields.borrow();
let fields =
fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec();
self.ctx
.struct_type(&fields, false)
.ptr_type(AddressSpace::Generic)
.into()
self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
} else {
unreachable!()
};
@ -93,16 +87,20 @@ impl<'ctx> CodeGenContext<'ctx> {
})
}
fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> {
unimplemented!()
}
fn gen_call(
&mut self,
obj: Option<(Type, BasicValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
params: &[BasicValueEnum<'ctx>],
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
ret: Type,
) -> Option<BasicValueEnum<'ctx>> {
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
let defs = self.top_level.definitions.read();
let definition = defs.get(fun.1.0).unwrap();
let definition = defs.get(fun.1 .0).unwrap();
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
// TODO: codegen for function that are not yet generated
@ -117,8 +115,19 @@ impl<'ctx> CodeGenContext<'ctx> {
};
self.module.add_function(symbol, fun_ty, None)
});
// TODO: deal with default parameters and reordering based on keys
self.builder.build_call(fun_val, params, "call").try_as_basic_value().left()
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
for (key, value) in params.into_iter() {
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
}
// default value handling
for k in keys.into_iter() {
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
}
// reorder the parameters
let params =
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
} else {
unreachable!()
};
@ -158,7 +167,7 @@ impl<'ctx> CodeGenContext<'ctx> {
let ty = self.ctx.struct_type(&types, false);
ty.const_named_struct(&values).into()
}
_ => unreachable!()
_ => unreachable!(),
}
}
@ -261,10 +270,7 @@ impl<'ctx> CodeGenContext<'ctx> {
"tmparr",
);
let arr_ty = self.ctx.struct_type(
&[
self.ctx.i32_type().into(),
ty.ptr_type(AddressSpace::Generic).into(),
],
&[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()],
false,
);
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
@ -445,50 +451,52 @@ impl<'ctx> CodeGenContext<'ctx> {
)
.fold(None, |prev, (lhs, rhs, op)| {
let ty = lhs.custom.unwrap();
let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
.contains(&ty)
{
let (lhs, rhs) =
if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) =
(self.gen_expr(lhs), self.gen_expr(rhs))
let current =
if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
.contains(&ty)
{
let (lhs, rhs) = if let (
BasicValueEnum::IntValue(lhs),
BasicValueEnum::IntValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
{
(lhs, rhs)
} else {
unreachable!()
};
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
_ => unreachable!(),
};
self.builder.build_int_compare(op, lhs, rhs, "cmp")
} else if ty == self.primitives.float {
let (lhs, rhs) = if let (
BasicValueEnum::FloatValue(lhs),
BasicValueEnum::FloatValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
{
(lhs, rhs)
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
_ => unreachable!(),
};
self.builder.build_int_compare(op, lhs, rhs, "cmp")
} else if ty == self.primitives.float {
let (lhs, rhs) = if let (
BasicValueEnum::FloatValue(lhs),
BasicValueEnum::FloatValue(rhs),
) = (self.gen_expr(lhs), self.gen_expr(rhs))
{
(lhs, rhs)
} else {
unreachable!()
};
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(),
};
self.builder.build_float_compare(op, lhs, rhs, "cmp")
} else {
unreachable!()
unimplemented!()
};
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(),
};
self.builder.build_float_compare(op, lhs, rhs, "cmp")
} else {
unimplemented!()
};
prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current))
})
.unwrap()

View File

@ -3,13 +3,15 @@ use crate::typecheck::typedef::Type;
use crate::top_level::DefinitionId;
use rustpython_parser::ast::Expr;
pub enum SymbolValue<'a> {
#[derive(Clone, PartialEq)]
pub enum SymbolValue {
I32(i32),
I64(i64),
Double(f64),
Bool(bool),
Tuple(&'a [SymbolValue<'a>]),
Bytes(&'a [u8]),
Tuple(Vec<SymbolValue>),
// we should think about how to implement bytes later...
// Bytes(&'a [u8]),
}
pub trait SymbolResolver {

View File

@ -77,7 +77,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe
vars: HashMap::new(),
args: vec![FuncArg {
ty: other,
is_optional: false,
default_value: None,
name: "other".into()
}]
}))
@ -97,7 +97,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe
vars: HashMap::new(),
args: vec![FuncArg {
ty: other,
is_optional: false,
default_value: None,
name: "other".into()
}]
}))
@ -132,7 +132,7 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other
vars: HashMap::new(),
args: vec![FuncArg {
ty: other_ty,
is_optional: false,
default_value: None,
name: "other".into()
}]
}))
@ -144,15 +144,15 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other
/// Add, Sub, Mult, Pow
pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[
ast::Operator::Add,
ast::Operator::Sub,
ast::Operator::Add,
ast::Operator::Sub,
ast::Operator::Mult,
])
}
pub fn impl_pow(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[
ast::Operator::Pow,
ast::Operator::Pow,
])
}
@ -236,9 +236,9 @@ pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
let PrimitiveStore {
int32: int32_t,
int64: int64_t,
float: float_t,
int32: int32_t,
int64: int64_t,
float: float_t,
bool: bool_t,
..
} = *store;
@ -255,8 +255,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_not(unifier, store, int32_t);
impl_comparison(unifier, store, int32_t, int32_t);
impl_eq(unifier, store, int32_t);
/* int64 ======== */
/* int64 ======== */
impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t);
impl_pow(unifier, store, int64_t, &[int64_t], int64_t);
impl_bitwise_arithmetic(unifier, store, int64_t);
@ -269,8 +269,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_not(unifier, store, int64_t);
impl_comparison(unifier, store, int64_t, int64_t);
impl_eq(unifier, store, int64_t);
/* float ======== */
/* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
impl_div(unifier, store, float_t, &[float_t]);
@ -280,8 +280,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_not(unifier, store, float_t);
impl_comparison(unifier, store, float_t, float_t);
impl_eq(unifier, store, float_t);
/* bool ======== */
/* bool ======== */
impl_not(unifier, store, bool_t);
impl_eq(unifier, store, bool_t);
}
}

View File

@ -240,7 +240,7 @@ impl<'a> Inferencer<'a> {
let fun = FunSignature {
args: fn_args
.iter()
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, is_optional: false })
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None })
.collect(),
ret,
vars: Default::default(),
@ -513,7 +513,13 @@ impl<'a> Inferencer<'a> {
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method =
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string();
self.build_method_call(a.location, method, a.custom.unwrap(), vec![b.custom.unwrap()], boolean)?;
self.build_method_call(
a.location,
method,
a.custom.unwrap(),
vec![b.custom.unwrap()],
boolean,
)?;
}
Ok(boolean)
}

View File

@ -6,8 +6,9 @@ use std::iter::once;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use crate::top_level::DefinitionId;
use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue;
use crate::top_level::DefinitionId;
#[cfg(test)]
mod test;
@ -30,9 +31,7 @@ pub struct Call {
pub struct FuncArg {
pub name: String,
pub ty: Type,
// TODO: change this to an optional value
// for primitive types
pub is_optional: bool,
pub default_value: Option<SymbolValue>,
}
#[derive(Clone)]
@ -457,7 +456,7 @@ impl Unifier {
let required: Vec<String> = signature
.args
.iter()
.filter(|v| !v.is_optional)
.filter(|v| v.default_value.is_none())
.map(|v| v.name.clone())
.rev()
.collect();
@ -516,7 +515,7 @@ impl Unifier {
if x.name != y.name {
return Err("Functions differ in parameter names.".to_string());
}
if x.is_optional != y.is_optional {
if x.default_value != y.default_value {
return Err("Functions differ in optional parameters.".to_string());
}
self.unify(x.ty, y.ty)?;