forked from M-Labs/nac3
added primitive codegen test
This commit is contained in:
parent
3a93e2b048
commit
77943a8117
|
@ -14,7 +14,7 @@ use inkwell::{
|
|||
use itertools::{chain, izip, zip, Itertools};
|
||||
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
|
||||
|
||||
impl<'ctx> CodeGenContext<'ctx> {
|
||||
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
|
||||
let mut vars = obj
|
||||
.map(|ty| {
|
||||
|
|
|
@ -16,21 +16,23 @@ use inkwell::{
|
|||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use rayon::current_thread_index;
|
||||
use rustpython_parser::ast::{Stmt, StmtKind};
|
||||
use rustpython_parser::ast::Stmt;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod expr;
|
||||
mod stmt;
|
||||
|
||||
pub struct CodeGenContext<'ctx> {
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
pub struct CodeGenContext<'ctx, 'a> {
|
||||
pub ctx: &'ctx Context,
|
||||
pub builder: Builder<'ctx>,
|
||||
pub module: Module<'ctx>,
|
||||
pub top_level: &'ctx TopLevelContext,
|
||||
pub top_level: &'a TopLevelContext,
|
||||
pub unifier: Unifier,
|
||||
pub resolver: Box<dyn SymbolResolver>,
|
||||
pub resolver: Arc<dyn SymbolResolver>,
|
||||
pub var_assignment: HashMap<String, PointerValue<'ctx>>,
|
||||
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||
pub primitives: PrimitiveStore,
|
||||
|
@ -45,9 +47,9 @@ pub struct CodeGenTask {
|
|||
pub subst: Vec<(Type, Type)>,
|
||||
pub symbol_name: String,
|
||||
pub signature: FunSignature,
|
||||
pub body: Stmt<Option<Type>>,
|
||||
pub body: Vec<Stmt<Option<Type>>>,
|
||||
pub unifier_index: usize,
|
||||
pub resolver: Box<dyn SymbolResolver>,
|
||||
pub resolver: Arc<dyn SymbolResolver>,
|
||||
}
|
||||
|
||||
fn get_llvm_type<'ctx>(
|
||||
|
@ -60,7 +62,7 @@ fn get_llvm_type<'ctx>(
|
|||
use TypeEnum::*;
|
||||
// we assume the type cache should already contain primitive types,
|
||||
// and they should be passed by value instead of passing as pointer.
|
||||
type_cache.get(&ty).cloned().unwrap_or_else(|| match &*unifier.get_ty(ty) {
|
||||
type_cache.get(&unifier.get_representative(ty)).cloned().unwrap_or_else(|| match &*unifier.get_ty(ty) {
|
||||
TObj { obj_id, fields, .. } => {
|
||||
// a struct with fields in the order of declaration
|
||||
let defs = top_level.definitions.read();
|
||||
|
@ -97,16 +99,13 @@ fn get_llvm_type<'ctx>(
|
|||
})
|
||||
}
|
||||
|
||||
pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) {
|
||||
pub fn gen_func<'ctx>(context: &'ctx Context, builder: Builder<'ctx>, module: Module<'ctx>, task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) -> Module<'ctx> {
|
||||
// unwrap_or(0) is for unit tests without using rayon
|
||||
let thread_id = current_thread_index().unwrap_or(0);
|
||||
let (mut unifier, primitives) = {
|
||||
let unifiers = top_level_ctx.unifiers.read();
|
||||
let (unifier, primitives) = &unifiers[task.unifier_index];
|
||||
(Unifier::from_shared_unifier(unifier), *primitives)
|
||||
};
|
||||
let contexts = top_level_ctx.conetexts.read();
|
||||
let context = contexts[thread_id].lock();
|
||||
|
||||
for (a, b) in task.subst.iter() {
|
||||
// this should be unification between variables and concrete types
|
||||
|
@ -124,10 +123,10 @@ pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) {
|
|||
};
|
||||
|
||||
let mut type_cache: HashMap<_, _> = [
|
||||
(primitives.int32, context.i32_type().into()),
|
||||
(primitives.int64, context.i64_type().into()),
|
||||
(primitives.float, context.f64_type().into()),
|
||||
(primitives.bool, context.bool_type().into()),
|
||||
(unifier.get_representative(primitives.int32), context.i32_type().into()),
|
||||
(unifier.get_representative(primitives.int64), context.i64_type().into()),
|
||||
(unifier.get_representative(primitives.float), context.f64_type().into()),
|
||||
(unifier.get_representative(primitives.bool), context.bool_type().into()),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
|
@ -155,8 +154,6 @@ pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) {
|
|||
.fn_type(¶ms, false)
|
||||
};
|
||||
|
||||
let builder = context.create_builder();
|
||||
let module = context.create_module(&task.symbol_name);
|
||||
let fn_val = module.add_function(&task.symbol_name, fn_type, None);
|
||||
let init_bb = context.append_basic_block(fn_val, "init");
|
||||
builder.position_at_end(init_bb);
|
||||
|
@ -189,9 +186,9 @@ pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) {
|
|||
unifier,
|
||||
};
|
||||
|
||||
if let StmtKind::FunctionDef { body, .. } = &task.body.node {
|
||||
for stmt in body.iter() {
|
||||
for stmt in task.body.iter() {
|
||||
code_gen_context.gen_stmt(stmt);
|
||||
}
|
||||
}
|
||||
|
||||
code_gen_context.module
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use crate::typecheck::typedef::Type;
|
|||
use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
|
||||
use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind};
|
||||
|
||||
impl<'ctx> CodeGenContext<'ctx> {
|
||||
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> {
|
||||
// put the alloca in init block
|
||||
let current = self.builder.get_insert_block().unwrap();
|
||||
|
|
|
@ -0,0 +1,246 @@
|
|||
use super::{gen_func, CodeGenTask};
|
||||
use crate::{
|
||||
location::Location,
|
||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||
top_level::{DefinitionId, TopLevelContext},
|
||||
typecheck::{
|
||||
magic_methods::set_primitives_magic_methods,
|
||||
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
|
||||
typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||
},
|
||||
};
|
||||
use indoc::indoc;
|
||||
use inkwell::context::Context;
|
||||
use parking_lot::RwLock;
|
||||
use rustpython_parser::{ast::fold::Fold, parser::parse_program};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Resolver {
|
||||
id_to_type: HashMap<String, Type>,
|
||||
id_to_def: HashMap<String, DefinitionId>,
|
||||
class_names: HashMap<String, Type>,
|
||||
}
|
||||
|
||||
impl SymbolResolver for Resolver {
|
||||
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
||||
self.id_to_type.get(str).cloned()
|
||||
}
|
||||
|
||||
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_symbol_location(&self, _: &str) -> Option<Location> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
|
||||
self.id_to_def.get(id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
struct TestEnvironment {
|
||||
pub unifier: Unifier,
|
||||
pub function_data: FunctionData,
|
||||
pub primitives: PrimitiveStore,
|
||||
pub id_to_name: HashMap<usize, String>,
|
||||
pub identifier_mapping: HashMap<String, Type>,
|
||||
pub virtual_checks: Vec<(Type, Type)>,
|
||||
pub calls: HashMap<CodeLocation, Arc<Call>>,
|
||||
pub top_level: TopLevelContext,
|
||||
}
|
||||
|
||||
impl TestEnvironment {
|
||||
pub fn basic_test_env() -> TestEnvironment {
|
||||
let mut unifier = Unifier::new();
|
||||
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(1),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
let float = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(2),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(3),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
let none = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(4),
|
||||
fields: HashMap::new().into(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
let primitives = PrimitiveStore { int32, int64, float, bool, none };
|
||||
set_primitives_magic_methods(&primitives, &mut unifier);
|
||||
|
||||
let id_to_name = [
|
||||
(0, "int32".to_string()),
|
||||
(1, "int64".to_string()),
|
||||
(2, "float".to_string()),
|
||||
(3, "bool".to_string()),
|
||||
(4, "none".to_string()),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut identifier_mapping = HashMap::new();
|
||||
identifier_mapping.insert("None".into(), none);
|
||||
|
||||
let resolver = Arc::new(Resolver {
|
||||
id_to_type: identifier_mapping.clone(),
|
||||
id_to_def: Default::default(),
|
||||
class_names: Default::default(),
|
||||
}) as Arc<dyn SymbolResolver>;
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
top_level: TopLevelContext {
|
||||
definitions: Default::default(),
|
||||
unifiers: Default::default(),
|
||||
conetexts: Default::default(),
|
||||
},
|
||||
function_data: FunctionData {
|
||||
resolver,
|
||||
bound_variables: Vec::new(),
|
||||
return_type: Some(primitives.int32),
|
||||
},
|
||||
primitives,
|
||||
id_to_name,
|
||||
identifier_mapping,
|
||||
virtual_checks: Vec::new(),
|
||||
calls: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_inferencer(&mut self) -> Inferencer {
|
||||
Inferencer {
|
||||
top_level: &self.top_level,
|
||||
function_data: &mut self.function_data,
|
||||
unifier: &mut self.unifier,
|
||||
variable_mapping: Default::default(),
|
||||
primitives: &mut self.primitives,
|
||||
virtual_checks: &mut self.virtual_checks,
|
||||
calls: &mut self.calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_primitives() {
|
||||
let mut env = TestEnvironment::basic_test_env();
|
||||
let context = Context::create();
|
||||
let module = context.create_module("test");
|
||||
let builder = context.create_builder();
|
||||
|
||||
let signature = FunSignature {
|
||||
args: vec![
|
||||
FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None },
|
||||
FuncArg { name: "b".to_string(), ty: env.primitives.int32, default_value: None },
|
||||
],
|
||||
ret: env.primitives.int32,
|
||||
vars: HashMap::new(),
|
||||
};
|
||||
|
||||
let mut inferencer = env.get_inferencer();
|
||||
let source = indoc! { "
|
||||
c = a + b
|
||||
d = a if c == 1 else 0
|
||||
return d
|
||||
"};
|
||||
let statements = parse_program(source).unwrap();
|
||||
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|v| inferencer.fold_stmt(v))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap();
|
||||
|
||||
let top_level = Arc::new(TopLevelContext {
|
||||
definitions: Default::default(),
|
||||
unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])),
|
||||
conetexts: Default::default(),
|
||||
});
|
||||
|
||||
let task = CodeGenTask {
|
||||
subst: Default::default(),
|
||||
symbol_name: "testing".to_string(),
|
||||
body: statements,
|
||||
unifier_index: 0,
|
||||
resolver: env.function_data.resolver.clone(),
|
||||
signature,
|
||||
};
|
||||
|
||||
let module = gen_func(&context, builder, module, task, top_level);
|
||||
// the following IR is equivalent to
|
||||
// ```
|
||||
// ; ModuleID = 'test.ll'
|
||||
// source_filename = "test"
|
||||
//
|
||||
// ; Function Attrs: norecurse nounwind readnone
|
||||
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
|
||||
// init:
|
||||
// %add = add i32 %1, %0
|
||||
// %cmp = icmp eq i32 %add, 1
|
||||
// %ifexpr = select i1 %cmp, i32 %0, i32 0
|
||||
// ret i32 %ifexpr
|
||||
// }
|
||||
//
|
||||
// attributes #0 = { norecurse nounwind readnone }
|
||||
// ```
|
||||
// after O2 optimization
|
||||
|
||||
let expected = indoc! {"
|
||||
; ModuleID = 'test'
|
||||
source_filename = \"test\"
|
||||
|
||||
define i32 @testing(i32 %0, i32 %1) {
|
||||
init:
|
||||
%a = alloca i32
|
||||
store i32 %0, i32* %a
|
||||
%b = alloca i32
|
||||
store i32 %1, i32* %b
|
||||
%tmp = alloca i32
|
||||
%tmp4 = alloca i32
|
||||
br label %body
|
||||
|
||||
body: ; preds = %init
|
||||
%load = load i32, i32* %a
|
||||
%load1 = load i32, i32* %b
|
||||
%add = add i32 %load, %load1
|
||||
store i32 %add, i32* %tmp
|
||||
%load2 = load i32, i32* %tmp
|
||||
%cmp = icmp eq i32 %load2, 1
|
||||
br i1 %cmp, label %then, label %else
|
||||
|
||||
then: ; preds = %body
|
||||
%load3 = load i32, i32* %a
|
||||
br label %cont
|
||||
|
||||
else: ; preds = %body
|
||||
br label %cont
|
||||
|
||||
cont: ; preds = %else, %then
|
||||
%ifexpr = phi i32 [ %load3, %then ], [ 0, %else ]
|
||||
store i32 %ifexpr, i32* %tmp4
|
||||
%load5 = load i32, i32* %tmp4
|
||||
ret i32 %load5
|
||||
}
|
||||
"}
|
||||
.trim();
|
||||
let ir = module.print_to_string().to_string();
|
||||
println!("src:\n{}", source);
|
||||
println!("IR:\n{}", ir);
|
||||
assert_eq!(expected, ir.trim());
|
||||
}
|
Loading…
Reference in New Issue