forked from M-Labs/nac3
function parameter handling
This commit is contained in:
parent
711482d09c
commit
86ca02796b
|
@ -1,12 +1,9 @@
|
||||||
use std::{convert::TryInto, iter::once};
|
use std::{collections::HashMap, convert::TryInto, iter::once};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
top_level::DefinitionId,
|
symbol_resolver::SymbolValue,
|
||||||
typecheck::typedef::{Type, TypeEnum},
|
top_level::{CodeGenContext, DefinitionId, TopLevelDef},
|
||||||
};
|
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||||
use crate::{
|
|
||||||
top_level::{CodeGenContext, TopLevelDef},
|
|
||||||
typecheck::typedef::FunSignature,
|
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::{BasicType, BasicTypeEnum},
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
@ -65,10 +62,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
let fields = fields.borrow();
|
let fields = fields.borrow();
|
||||||
let fields =
|
let fields =
|
||||||
fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec();
|
fields_list.iter().map(|f| self.get_llvm_type(fields[&f.0])).collect_vec();
|
||||||
self.ctx
|
self.ctx.struct_type(&fields, false).ptr_type(AddressSpace::Generic).into()
|
||||||
.struct_type(&fields, false)
|
|
||||||
.ptr_type(AddressSpace::Generic)
|
|
||||||
.into()
|
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -93,16 +87,20 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
fn gen_call(
|
fn gen_call(
|
||||||
&mut self,
|
&mut self,
|
||||||
obj: Option<(Type, BasicValueEnum<'ctx>)>,
|
obj: Option<(Type, BasicValueEnum<'ctx>)>,
|
||||||
fun: (&FunSignature, DefinitionId),
|
fun: (&FunSignature, DefinitionId),
|
||||||
params: &[BasicValueEnum<'ctx>],
|
params: Vec<(Option<String>, BasicValueEnum<'ctx>)>,
|
||||||
ret: Type,
|
ret: Type,
|
||||||
) -> Option<BasicValueEnum<'ctx>> {
|
) -> Option<BasicValueEnum<'ctx>> {
|
||||||
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
|
let key = self.get_subst_key(obj.map(|(a, _)| a), fun.0);
|
||||||
let defs = self.top_level.definitions.read();
|
let defs = self.top_level.definitions.read();
|
||||||
let definition = defs.get(fun.1.0).unwrap();
|
let definition = defs.get(fun.1 .0).unwrap();
|
||||||
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
|
let val = if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() {
|
||||||
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
|
let symbol = instance_to_symbol.get(&key).unwrap_or_else(|| {
|
||||||
// TODO: codegen for function that are not yet generated
|
// TODO: codegen for function that are not yet generated
|
||||||
|
@ -117,8 +115,19 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
};
|
};
|
||||||
self.module.add_function(symbol, fun_ty, None)
|
self.module.add_function(symbol, fun_ty, None)
|
||||||
});
|
});
|
||||||
// TODO: deal with default parameters and reordering based on keys
|
let mut keys = fun.0.args.clone();
|
||||||
self.builder.build_call(fun_val, params, "call").try_as_basic_value().left()
|
let mut mapping = HashMap::new();
|
||||||
|
for (key, value) in params.into_iter() {
|
||||||
|
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||||
|
}
|
||||||
|
// default value handling
|
||||||
|
for k in keys.into_iter() {
|
||||||
|
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
|
||||||
|
}
|
||||||
|
// reorder the parameters
|
||||||
|
let params =
|
||||||
|
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
|
||||||
|
self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left()
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -158,7 +167,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
let ty = self.ctx.struct_type(&types, false);
|
let ty = self.ctx.struct_type(&types, false);
|
||||||
ty.const_named_struct(&values).into()
|
ty.const_named_struct(&values).into()
|
||||||
}
|
}
|
||||||
_ => unreachable!()
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -261,10 +270,7 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
"tmparr",
|
"tmparr",
|
||||||
);
|
);
|
||||||
let arr_ty = self.ctx.struct_type(
|
let arr_ty = self.ctx.struct_type(
|
||||||
&[
|
&[self.ctx.i32_type().into(), ty.ptr_type(AddressSpace::Generic).into()],
|
||||||
self.ctx.i32_type().into(),
|
|
||||||
ty.ptr_type(AddressSpace::Generic).into(),
|
|
||||||
],
|
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
|
let arr_str_ptr = self.builder.build_alloca(arr_ty, "tmparrstr");
|
||||||
|
@ -445,50 +451,52 @@ impl<'ctx> CodeGenContext<'ctx> {
|
||||||
)
|
)
|
||||||
.fold(None, |prev, (lhs, rhs, op)| {
|
.fold(None, |prev, (lhs, rhs, op)| {
|
||||||
let ty = lhs.custom.unwrap();
|
let ty = lhs.custom.unwrap();
|
||||||
let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
|
let current =
|
||||||
.contains(&ty)
|
if [self.primitives.int32, self.primitives.int64, self.primitives.bool]
|
||||||
{
|
.contains(&ty)
|
||||||
let (lhs, rhs) =
|
{
|
||||||
if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) =
|
let (lhs, rhs) = if let (
|
||||||
(self.gen_expr(lhs), self.gen_expr(rhs))
|
BasicValueEnum::IntValue(lhs),
|
||||||
|
BasicValueEnum::IntValue(rhs),
|
||||||
|
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
||||||
{
|
{
|
||||||
(lhs, rhs)
|
(lhs, rhs)
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
let op = match op {
|
let op = match op {
|
||||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
|
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ,
|
||||||
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
|
ast::Cmpop::NotEq => inkwell::IntPredicate::NE,
|
||||||
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
|
ast::Cmpop::Lt => inkwell::IntPredicate::SLT,
|
||||||
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
|
ast::Cmpop::LtE => inkwell::IntPredicate::SLE,
|
||||||
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
|
ast::Cmpop::Gt => inkwell::IntPredicate::SGT,
|
||||||
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
|
ast::Cmpop::GtE => inkwell::IntPredicate::SGE,
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
self.builder.build_int_compare(op, lhs, rhs, "cmp")
|
self.builder.build_int_compare(op, lhs, rhs, "cmp")
|
||||||
} else if ty == self.primitives.float {
|
} else if ty == self.primitives.float {
|
||||||
let (lhs, rhs) = if let (
|
let (lhs, rhs) = if let (
|
||||||
BasicValueEnum::FloatValue(lhs),
|
BasicValueEnum::FloatValue(lhs),
|
||||||
BasicValueEnum::FloatValue(rhs),
|
BasicValueEnum::FloatValue(rhs),
|
||||||
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
) = (self.gen_expr(lhs), self.gen_expr(rhs))
|
||||||
{
|
{
|
||||||
(lhs, rhs)
|
(lhs, rhs)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let op = match op {
|
||||||
|
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
|
||||||
|
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
|
||||||
|
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
|
||||||
|
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
|
||||||
|
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
|
||||||
|
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
self.builder.build_float_compare(op, lhs, rhs, "cmp")
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unimplemented!()
|
||||||
};
|
};
|
||||||
let op = match op {
|
|
||||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
|
|
||||||
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
|
|
||||||
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
|
|
||||||
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
|
|
||||||
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
|
|
||||||
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
self.builder.build_float_compare(op, lhs, rhs, "cmp")
|
|
||||||
} else {
|
|
||||||
unimplemented!()
|
|
||||||
};
|
|
||||||
prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current))
|
prev.map(|v| self.builder.build_and(v, current, "cmp")).or(Some(current))
|
||||||
})
|
})
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|
|
@ -3,13 +3,15 @@ use crate::typecheck::typedef::Type;
|
||||||
use crate::top_level::DefinitionId;
|
use crate::top_level::DefinitionId;
|
||||||
use rustpython_parser::ast::Expr;
|
use rustpython_parser::ast::Expr;
|
||||||
|
|
||||||
pub enum SymbolValue<'a> {
|
#[derive(Clone, PartialEq)]
|
||||||
|
pub enum SymbolValue {
|
||||||
I32(i32),
|
I32(i32),
|
||||||
I64(i64),
|
I64(i64),
|
||||||
Double(f64),
|
Double(f64),
|
||||||
Bool(bool),
|
Bool(bool),
|
||||||
Tuple(&'a [SymbolValue<'a>]),
|
Tuple(Vec<SymbolValue>),
|
||||||
Bytes(&'a [u8]),
|
// we should think about how to implement bytes later...
|
||||||
|
// Bytes(&'a [u8]),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait SymbolResolver {
|
pub trait SymbolResolver {
|
||||||
|
|
|
@ -77,7 +77,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe
|
||||||
vars: HashMap::new(),
|
vars: HashMap::new(),
|
||||||
args: vec![FuncArg {
|
args: vec![FuncArg {
|
||||||
ty: other,
|
ty: other,
|
||||||
is_optional: false,
|
default_value: None,
|
||||||
name: "other".into()
|
name: "other".into()
|
||||||
}]
|
}]
|
||||||
}))
|
}))
|
||||||
|
@ -97,7 +97,7 @@ pub fn impl_binop(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, othe
|
||||||
vars: HashMap::new(),
|
vars: HashMap::new(),
|
||||||
args: vec![FuncArg {
|
args: vec![FuncArg {
|
||||||
ty: other,
|
ty: other,
|
||||||
is_optional: false,
|
default_value: None,
|
||||||
name: "other".into()
|
name: "other".into()
|
||||||
}]
|
}]
|
||||||
}))
|
}))
|
||||||
|
@ -132,7 +132,7 @@ pub fn impl_cmpop(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other
|
||||||
vars: HashMap::new(),
|
vars: HashMap::new(),
|
||||||
args: vec![FuncArg {
|
args: vec![FuncArg {
|
||||||
ty: other_ty,
|
ty: other_ty,
|
||||||
is_optional: false,
|
default_value: None,
|
||||||
name: "other".into()
|
name: "other".into()
|
||||||
}]
|
}]
|
||||||
}))
|
}))
|
||||||
|
|
|
@ -240,7 +240,7 @@ impl<'a> Inferencer<'a> {
|
||||||
let fun = FunSignature {
|
let fun = FunSignature {
|
||||||
args: fn_args
|
args: fn_args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, is_optional: false })
|
.map(|(k, ty)| FuncArg { name: k.clone(), ty: *ty, default_value: None })
|
||||||
.collect(),
|
.collect(),
|
||||||
ret,
|
ret,
|
||||||
vars: Default::default(),
|
vars: Default::default(),
|
||||||
|
@ -513,7 +513,13 @@ impl<'a> Inferencer<'a> {
|
||||||
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
||||||
let method =
|
let method =
|
||||||
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string();
|
comparison_name(c).ok_or_else(|| "unsupported comparator".to_string())?.to_string();
|
||||||
self.build_method_call(a.location, method, a.custom.unwrap(), vec![b.custom.unwrap()], boolean)?;
|
self.build_method_call(
|
||||||
|
a.location,
|
||||||
|
method,
|
||||||
|
a.custom.unwrap(),
|
||||||
|
vec![b.custom.unwrap()],
|
||||||
|
boolean,
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
Ok(boolean)
|
Ok(boolean)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,8 +6,9 @@ use std::iter::once;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use crate::top_level::DefinitionId;
|
|
||||||
use super::unification_table::{UnificationKey, UnificationTable};
|
use super::unification_table::{UnificationKey, UnificationTable};
|
||||||
|
use crate::symbol_resolver::SymbolValue;
|
||||||
|
use crate::top_level::DefinitionId;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
@ -30,9 +31,7 @@ pub struct Call {
|
||||||
pub struct FuncArg {
|
pub struct FuncArg {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub ty: Type,
|
pub ty: Type,
|
||||||
// TODO: change this to an optional value
|
pub default_value: Option<SymbolValue>,
|
||||||
// for primitive types
|
|
||||||
pub is_optional: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
@ -457,7 +456,7 @@ impl Unifier {
|
||||||
let required: Vec<String> = signature
|
let required: Vec<String> = signature
|
||||||
.args
|
.args
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|v| !v.is_optional)
|
.filter(|v| v.default_value.is_none())
|
||||||
.map(|v| v.name.clone())
|
.map(|v| v.name.clone())
|
||||||
.rev()
|
.rev()
|
||||||
.collect();
|
.collect();
|
||||||
|
@ -516,7 +515,7 @@ impl Unifier {
|
||||||
if x.name != y.name {
|
if x.name != y.name {
|
||||||
return Err("Functions differ in parameter names.".to_string());
|
return Err("Functions differ in parameter names.".to_string());
|
||||||
}
|
}
|
||||||
if x.is_optional != y.is_optional {
|
if x.default_value != y.default_value {
|
||||||
return Err("Functions differ in optional parameters.".to_string());
|
return Err("Functions differ in optional parameters.".to_string());
|
||||||
}
|
}
|
||||||
self.unify(x.ty, y.ty)?;
|
self.unify(x.ty, y.ty)?;
|
||||||
|
|
Loading…
Reference in New Issue