forked from M-Labs/nac3
1
0
Fork 0

added primitive codegen test

This commit is contained in:
pca006132 2021-08-12 13:55:15 +08:00
parent 3a93e2b048
commit 77943a8117
4 changed files with 267 additions and 24 deletions

View File

@ -14,7 +14,7 @@ use inkwell::{
use itertools::{chain, izip, zip, Itertools}; use itertools::{chain, izip, zip, Itertools};
use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator}; use rustpython_parser::ast::{self, Boolop, Constant, Expr, ExprKind, Operator};
impl<'ctx> CodeGenContext<'ctx> { impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String { fn get_subst_key(&mut self, obj: Option<Type>, fun: &FunSignature) -> String {
let mut vars = obj let mut vars = obj
.map(|ty| { .map(|ty| {

View File

@ -16,21 +16,23 @@ use inkwell::{
AddressSpace, AddressSpace,
}; };
use itertools::Itertools; use itertools::Itertools;
use rayon::current_thread_index; use rustpython_parser::ast::Stmt;
use rustpython_parser::ast::{Stmt, StmtKind};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
mod expr; mod expr;
mod stmt; mod stmt;
pub struct CodeGenContext<'ctx> { #[cfg(test)]
mod test;
pub struct CodeGenContext<'ctx, 'a> {
pub ctx: &'ctx Context, pub ctx: &'ctx Context,
pub builder: Builder<'ctx>, pub builder: Builder<'ctx>,
pub module: Module<'ctx>, pub module: Module<'ctx>,
pub top_level: &'ctx TopLevelContext, pub top_level: &'a TopLevelContext,
pub unifier: Unifier, pub unifier: Unifier,
pub resolver: Box<dyn SymbolResolver>, pub resolver: Arc<dyn SymbolResolver>,
pub var_assignment: HashMap<String, PointerValue<'ctx>>, pub var_assignment: HashMap<String, PointerValue<'ctx>>,
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>, pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
@ -45,9 +47,9 @@ pub struct CodeGenTask {
pub subst: Vec<(Type, Type)>, pub subst: Vec<(Type, Type)>,
pub symbol_name: String, pub symbol_name: String,
pub signature: FunSignature, pub signature: FunSignature,
pub body: Stmt<Option<Type>>, pub body: Vec<Stmt<Option<Type>>>,
pub unifier_index: usize, pub unifier_index: usize,
pub resolver: Box<dyn SymbolResolver>, pub resolver: Arc<dyn SymbolResolver>,
} }
fn get_llvm_type<'ctx>( fn get_llvm_type<'ctx>(
@ -60,7 +62,7 @@ fn get_llvm_type<'ctx>(
use TypeEnum::*; use TypeEnum::*;
// we assume the type cache should already contain primitive types, // we assume the type cache should already contain primitive types,
// and they should be passed by value instead of passing as pointer. // and they should be passed by value instead of passing as pointer.
type_cache.get(&ty).cloned().unwrap_or_else(|| match &*unifier.get_ty(ty) { type_cache.get(&unifier.get_representative(ty)).cloned().unwrap_or_else(|| match &*unifier.get_ty(ty) {
TObj { obj_id, fields, .. } => { TObj { obj_id, fields, .. } => {
// a struct with fields in the order of declaration // a struct with fields in the order of declaration
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
@ -97,16 +99,13 @@ fn get_llvm_type<'ctx>(
}) })
} }
pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) { pub fn gen_func<'ctx>(context: &'ctx Context, builder: Builder<'ctx>, module: Module<'ctx>, task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) -> Module<'ctx> {
// unwrap_or(0) is for unit tests without using rayon // unwrap_or(0) is for unit tests without using rayon
let thread_id = current_thread_index().unwrap_or(0);
let (mut unifier, primitives) = { let (mut unifier, primitives) = {
let unifiers = top_level_ctx.unifiers.read(); let unifiers = top_level_ctx.unifiers.read();
let (unifier, primitives) = &unifiers[task.unifier_index]; let (unifier, primitives) = &unifiers[task.unifier_index];
(Unifier::from_shared_unifier(unifier), *primitives) (Unifier::from_shared_unifier(unifier), *primitives)
}; };
let contexts = top_level_ctx.conetexts.read();
let context = contexts[thread_id].lock();
for (a, b) in task.subst.iter() { for (a, b) in task.subst.iter() {
// this should be unification between variables and concrete types // this should be unification between variables and concrete types
@ -124,10 +123,10 @@ pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) {
}; };
let mut type_cache: HashMap<_, _> = [ let mut type_cache: HashMap<_, _> = [
(primitives.int32, context.i32_type().into()), (unifier.get_representative(primitives.int32), context.i32_type().into()),
(primitives.int64, context.i64_type().into()), (unifier.get_representative(primitives.int64), context.i64_type().into()),
(primitives.float, context.f64_type().into()), (unifier.get_representative(primitives.float), context.f64_type().into()),
(primitives.bool, context.bool_type().into()), (unifier.get_representative(primitives.bool), context.bool_type().into()),
] ]
.iter() .iter()
.cloned() .cloned()
@ -155,8 +154,6 @@ pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) {
.fn_type(&params, false) .fn_type(&params, false)
}; };
let builder = context.create_builder();
let module = context.create_module(&task.symbol_name);
let fn_val = module.add_function(&task.symbol_name, fn_type, None); let fn_val = module.add_function(&task.symbol_name, fn_type, None);
let init_bb = context.append_basic_block(fn_val, "init"); let init_bb = context.append_basic_block(fn_val, "init");
builder.position_at_end(init_bb); builder.position_at_end(init_bb);
@ -189,9 +186,9 @@ pub fn gen_func(task: CodeGenTask, top_level_ctx: Arc<TopLevelContext>) {
unifier, unifier,
}; };
if let StmtKind::FunctionDef { body, .. } = &task.body.node { for stmt in task.body.iter() {
for stmt in body.iter() {
code_gen_context.gen_stmt(stmt); code_gen_context.gen_stmt(stmt);
} }
}
code_gen_context.module
} }

View File

@ -3,7 +3,7 @@ use crate::typecheck::typedef::Type;
use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; use inkwell::values::{BasicValue, BasicValueEnum, PointerValue};
use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind}; use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind};
impl<'ctx> CodeGenContext<'ctx> { impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> { fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> {
// put the alloca in init block // put the alloca in init block
let current = self.builder.get_insert_block().unwrap(); let current = self.builder.get_insert_block().unwrap();

View File

@ -0,0 +1,246 @@
use super::{gen_func, CodeGenTask};
use crate::{
location::Location,
symbol_resolver::{SymbolResolver, SymbolValue},
top_level::{DefinitionId, TopLevelContext},
typecheck::{
magic_methods::set_primitives_magic_methods,
type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore},
typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier},
},
};
use indoc::indoc;
use inkwell::context::Context;
use parking_lot::RwLock;
use rustpython_parser::{ast::fold::Fold, parser::parse_program};
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Clone)]
struct Resolver {
id_to_type: HashMap<String, Type>,
id_to_def: HashMap<String, DefinitionId>,
class_names: HashMap<String, Type>,
}
impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
self.id_to_type.get(str).cloned()
}
fn get_symbol_value(&self, _: &str) -> Option<SymbolValue> {
unimplemented!()
}
fn get_symbol_location(&self, _: &str) -> Option<Location> {
unimplemented!()
}
fn get_identifier_def(&self, id: &str) -> Option<DefinitionId> {
self.id_to_def.get(id).cloned()
}
}
struct TestEnvironment {
pub unifier: Unifier,
pub function_data: FunctionData,
pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>,
pub virtual_checks: Vec<(Type, Type)>,
pub calls: HashMap<CodeLocation, Arc<Call>>,
pub top_level: TopLevelContext,
}
impl TestEnvironment {
pub fn basic_test_env() -> TestEnvironment {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new().into(),
params: HashMap::new(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new().into(),
params: HashMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new().into(),
params: HashMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: HashMap::new().into(),
params: HashMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4),
fields: HashMap::new().into(),
params: HashMap::new(),
});
let primitives = PrimitiveStore { int32, int64, float, bool, none };
set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name = [
(0, "int32".to_string()),
(1, "int64".to_string()),
(2, "float".to_string()),
(3, "bool".to_string()),
(4, "none".to_string()),
]
.iter()
.cloned()
.collect();
let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver>;
TestEnvironment {
unifier,
top_level: TopLevelContext {
definitions: Default::default(),
unifiers: Default::default(),
conetexts: Default::default(),
},
function_data: FunctionData {
resolver,
bound_variables: Vec::new(),
return_type: Some(primitives.int32),
},
primitives,
id_to_name,
identifier_mapping,
virtual_checks: Vec::new(),
calls: HashMap::new(),
}
}
fn get_inferencer(&mut self) -> Inferencer {
Inferencer {
top_level: &self.top_level,
function_data: &mut self.function_data,
unifier: &mut self.unifier,
variable_mapping: Default::default(),
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
}
}
}
#[test]
fn test_primitives() {
let mut env = TestEnvironment::basic_test_env();
let context = Context::create();
let module = context.create_module("test");
let builder = context.create_builder();
let signature = FunSignature {
args: vec![
FuncArg { name: "a".to_string(), ty: env.primitives.int32, default_value: None },
FuncArg { name: "b".to_string(), ty: env.primitives.int32, default_value: None },
],
ret: env.primitives.int32,
vars: HashMap::new(),
};
let mut inferencer = env.get_inferencer();
let source = indoc! { "
c = a + b
d = a if c == 1 else 0
return d
"};
let statements = parse_program(source).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
let top_level = Arc::new(TopLevelContext {
definitions: Default::default(),
unifiers: Arc::new(RwLock::new(vec![(env.unifier.get_shared_unifier(), env.primitives)])),
conetexts: Default::default(),
});
let task = CodeGenTask {
subst: Default::default(),
symbol_name: "testing".to_string(),
body: statements,
unifier_index: 0,
resolver: env.function_data.resolver.clone(),
signature,
};
let module = gen_func(&context, builder, module, task, top_level);
// the following IR is equivalent to
// ```
// ; ModuleID = 'test.ll'
// source_filename = "test"
//
// ; Function Attrs: norecurse nounwind readnone
// define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 {
// init:
// %add = add i32 %1, %0
// %cmp = icmp eq i32 %add, 1
// %ifexpr = select i1 %cmp, i32 %0, i32 0
// ret i32 %ifexpr
// }
//
// attributes #0 = { norecurse nounwind readnone }
// ```
// after O2 optimization
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
define i32 @testing(i32 %0, i32 %1) {
init:
%a = alloca i32
store i32 %0, i32* %a
%b = alloca i32
store i32 %1, i32* %b
%tmp = alloca i32
%tmp4 = alloca i32
br label %body
body: ; preds = %init
%load = load i32, i32* %a
%load1 = load i32, i32* %b
%add = add i32 %load, %load1
store i32 %add, i32* %tmp
%load2 = load i32, i32* %tmp
%cmp = icmp eq i32 %load2, 1
br i1 %cmp, label %then, label %else
then: ; preds = %body
%load3 = load i32, i32* %a
br label %cont
else: ; preds = %body
br label %cont
cont: ; preds = %else, %then
%ifexpr = phi i32 [ %load3, %then ], [ 0, %else ]
store i32 %ifexpr, i32* %tmp4
%load5 = load i32, i32* %tmp4
ret i32 %load5
}
"}
.trim();
let ir = module.print_to_string().to_string();
println!("src:\n{}", source);
println!("IR:\n{}", ir);
assert_eq!(expected, ir.trim());
}