forked from M-Labs/nac3
function parameter handling
This commit is contained in:
parent
711482d09c
commit
86ca02796b
@ -1,12 +1,9 @@
|
||||
use std::{convert::TryInto, iter::once};
|
||||
use std::{collections::HashMap, convert::TryInto, iter::once};
|
||||
|
||||
use crate::{
|
||||
top_level::DefinitionId,
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
use crate::{
|
||||
top_level::{CodeGenContext, TopLevelDef},
|
||||
typecheck::typedef::FunSignature,
|
||||
symbol_resolver::SymbolValue,
|
||||
top_level::{CodeGenContext, DefinitionId, TopLevelDef},
|
||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||
};
|
||||
use inkwell::{
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
@ -65,10 +62,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
let fields = fields.borrow();
|
||||
let fields =
|
||||
fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec();
|
||||
self.ctx
|
||||
.struct_type(&fields, false)
|
||||
.ptr_type(AddressSpace::Generic)
|
||||
.into()
|
||||
self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
@ -93,16 +87,20 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
})
|
||||
}
|
||||
|
||||
fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn gen_call(
|
||||
&mut self,
|
||||
obj: Option<(Type, BasicValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
params: &[BasicValueEnum<'ctx>],
|
||||
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
|
||||
ret: Type,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
|
||||
let defs = self.top_level.definitions.read();
|
||||
let definition = defs.get(fun.1.0).unwrap();
|
||||
let definition = defs.get(fun.1 .0).unwrap();
|
||||
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
|
||||
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
|
||||
// TODO: codegen for function that are not yet generated
|
||||
@ -117,8 +115,19 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
};
|
||||
self.module.add_function(symbol, fun_ty, None)
|
||||
});
|
||||
// TODO: deal with default parameters and reordering based on keys
|
||||
self.builder.build_call(fun_val, params, "call").try_as_basic_value().left()
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
for (key, value) in params.into_iter() {
|
||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||
}
|
||||
// default value handling
|
||||
for k in keys.into_iter() {
|
||||
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
|
||||
}
|
||||
// reorder the parameters
|
||||
let params =
|
||||
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
|
||||
self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left()
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
@ -158,7 +167,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
let ty = self.ctx.struct_type(&types, false);
|
||||
ty.const_named_struct(&values).into()
|
||||
}
|
||||
_ => unreachable!()
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -261,10 +270,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
"tmparr",
|
||||
);
|
||||
let arr_ty = self.ctx.struct_type(
|
||||
&[
|
||||
self.ctx.i32_type().into(),
|
||||
ty.ptr_type(AddressSpace::Generic).into(),
|
||||
],
|
||||
&[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()],
|
||||
false,
|
||||
);
|
||||
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
|
||||
@ -445,50 +451,52 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||
)
|
||||
.fold(None, |prev, (lhs, rhs, op)| {
|
||||
let ty = lhs.custom.unwrap();
|
||||
let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
|
||||
.contains(&ty)
|
||||
{
|
||||
let (lhs, rhs) =
|
||||
if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) =
|
||||
(self.gen_expr(lhs), self.gen_expr(rhs))
|
||||
let current =
|
||||
if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
|
||||
.contains(&ty)
|
||||
{
|
||||
let (lhs, rhs) = if let (
|
||||
BasicValueEnum::IntValue(lhs),
|
||||
BasicValueEnum::IntValue(rhs),
|
||||
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
||||
{
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
|
||||
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
|
||||
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
|
||||
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
|
||||
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
|
||||
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.builder.build_int_compare(op, lhs, rhs, "cmp")
|
||||
} else if ty == self.primitives.float {
|
||||
let (lhs, rhs) = if let (
|
||||
BasicValueEnum::FloatValue(lhs),
|
||||
BasicValueEnum::FloatValue(rhs),
|
||||
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
||||
{
|
||||
(lhs, rhs)
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
|
||||
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
|
||||
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
|
||||
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
|
||||
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
|
||||
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.builder.build_int_compare(op, lhs, rhs, "cmp")
|
||||
} else if ty == self.primitives.float {
|
||||
let (lhs, rhs) = if let (
|
||||
BasicValueEnum::FloatValue(lhs),
|
||||
BasicValueEnum::FloatValue(rhs),
|
||||
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
||||
{
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
|
||||
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
|
||||
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
|
||||
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
|
||||
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
|
||||
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.builder.build_float_compare(op, lhs, rhs, "cmp")
|
||||
} else {
|
||||
unreachable!()
|
||||
unimplemented!()
|
||||
};
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
|
||||
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
|
||||
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
|
||||
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
|
||||
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
|
||||
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.builder.build_float_compare(op, lhs, rhs, "cmp")
|
||||
} else {
|
||||
unimplemented!()
|
||||
};
|
||||
prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current))
|
||||
})
|
||||
.unwrap()
|
||||
|
@ -3,13 +3,15 @@ use crate::typecheck::typedef::Type;
|
||||
use crate::top_level::DefinitionId;
|
||||
use rustpython_parser::ast::Expr;
|
||||
|
||||
pub enum SymbolValue<'a> {
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub enum SymbolValue {
|
||||
I32(i32),
|
||||
I64(i64),
|
||||
Double(f64),
|
||||
Bool(bool),
|
||||
Tuple(&'a [SymbolValue<'a>]),
|
||||
Bytes(&'a [u8]),
|
||||
Tuple(Vec<SymbolValue>),
|
||||
// we should think about how to implement bytes later...
|
||||
// Bytes(&'a [u8]),
|
||||
}
|
||||
|
||||
pub trait SymbolResolver {
|
||||
|
@ -77,7 +77,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe
|
||||
vars: HashMap::new(),
|
||||
args: vec![FuncArg {
|
||||
ty: other,
|
||||
is_optional: false,
|
||||
default_value: None,
|
||||
name: "other".into()
|
||||
}]
|
||||
}))
|
||||
@ -97,7 +97,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe
|
||||
vars: HashMap::new(),
|
||||
args: vec![FuncArg {
|
||||
ty: other,
|
||||
is_optional: false,
|
||||
default_value: None,
|
||||
name: "other".into()
|
||||
}]
|
||||
}))
|
||||
@ -132,7 +132,7 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other
|
||||
vars: HashMap::new(),
|
||||
args: vec![FuncArg {
|
||||
ty: other_ty,
|
||||
is_optional: false,
|
||||
default_value: None,
|
||||
name: "other".into()
|
||||
}]
|
||||
}))
|
||||
@ -144,15 +144,15 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other
|
||||
/// Add, Sub, Mult, Pow
|
||||
pub fn impl_basic_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[
|
||||
ast::Operator::Add,
|
||||
ast::Operator::Sub,
|
||||
ast::Operator::Add,
|
||||
ast::Operator::Sub,
|
||||
ast::Operator::Mult,
|
||||
])
|
||||
}
|
||||
|
||||
pub fn impl_pow(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Type) {
|
||||
impl_binop(unifier, store, ty, other_ty, ret_ty, &[
|
||||
ast::Operator::Pow,
|
||||
ast::Operator::Pow,
|
||||
])
|
||||
}
|
||||
|
||||
@ -236,9 +236,9 @@ pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
|
||||
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||
let PrimitiveStore {
|
||||
int32: int32_t,
|
||||
int64: int64_t,
|
||||
float: float_t,
|
||||
int32: int32_t,
|
||||
int64: int64_t,
|
||||
float: float_t,
|
||||
bool: bool_t,
|
||||
..
|
||||
} = *store;
|
||||
@ -255,8 +255,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||
impl_not(unifier, store, int32_t);
|
||||
impl_comparison(unifier, store, int32_t, int32_t);
|
||||
impl_eq(unifier, store, int32_t);
|
||||
|
||||
/* int64 ======== */
|
||||
|
||||
/* int64 ======== */
|
||||
impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t);
|
||||
impl_pow(unifier, store, int64_t, &[int64_t], int64_t);
|
||||
impl_bitwise_arithmetic(unifier, store, int64_t);
|
||||
@ -269,8 +269,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||
impl_not(unifier, store, int64_t);
|
||||
impl_comparison(unifier, store, int64_t, int64_t);
|
||||
impl_eq(unifier, store, int64_t);
|
||||
|
||||
/* float ======== */
|
||||
|
||||
/* float ======== */
|
||||
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
|
||||
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
|
||||
impl_div(unifier, store, float_t, &[float_t]);
|
||||
@ -280,8 +280,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||
impl_not(unifier, store, float_t);
|
||||
impl_comparison(unifier, store, float_t, float_t);
|
||||
impl_eq(unifier, store, float_t);
|
||||
|
||||
/* bool ======== */
|
||||
|
||||
/* bool ======== */
|
||||
impl_not(unifier, store, bool_t);
|
||||
impl_eq(unifier, store, bool_t);
|
||||
}
|
||||
}
|
||||
|
@ -240,7 +240,7 @@ impl<'a> Inferencer<'a> {
|
||||
let fun = FunSignature {
|
||||
args: fn_args
|
||||
.iter()
|
||||
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, is_optional: false })
|
||||
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None })
|
||||
.collect(),
|
||||
ret,
|
||||
vars: Default::default(),
|
||||
@ -513,7 +513,13 @@ impl<'a> Inferencer<'a> {
|
||||
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
||||
let method =
|
||||
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string();
|
||||
self.build_method_call(a.location, method, a.custom.unwrap(), vec![b.custom.unwrap()], boolean)?;
|
||||
self.build_method_call(
|
||||
a.location,
|
||||
method,
|
||||
a.custom.unwrap(),
|
||||
vec![b.custom.unwrap()],
|
||||
boolean,
|
||||
)?;
|
||||
}
|
||||
Ok(boolean)
|
||||
}
|
||||
|
@ -6,8 +6,9 @@ use std::iter::once;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::top_level::DefinitionId;
|
||||
use super::unification_table::{UnificationKey, UnificationTable};
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::top_level::DefinitionId;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
@ -30,9 +31,7 @@ pub struct Call {
|
||||
pub struct FuncArg {
|
||||
pub name: String,
|
||||
pub ty: Type,
|
||||
// TODO: change this to an optional value
|
||||
// for primitive types
|
||||
pub is_optional: bool,
|
||||
pub default_value: Option<SymbolValue>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -457,7 +456,7 @@ impl Unifier {
|
||||
let required: Vec<String> = signature
|
||||
.args
|
||||
.iter()
|
||||
.filter(|v| !v.is_optional)
|
||||
.filter(|v| v.default_value.is_none())
|
||||
.map(|v| v.name.clone())
|
||||
.rev()
|
||||
.collect();
|
||||
@ -516,7 +515,7 @@ impl Unifier {
|
||||
if x.name != y.name {
|
||||
return Err("Functions differ in parameter names.".to_string());
|
||||
}
|
||||
if x.is_optional != y.is_optional {
|
||||
if x.default_value != y.default_value {
|
||||
return Err("Functions differ in optional parameters.".to_string());
|
||||
}
|
||||
self.unify(x.ty, y.ty)?;
|
||||
|
Loading…
Reference in New Issue
Block a user