simple implementation of classes

This commit is contained in:
pca006132 2021-09-19 22:54:06 +08:00
parent bf1769cef6
commit 4939ff4dbd
7 changed files with 172 additions and 140 deletions

View File

@ -114,10 +114,28 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let symbol = { let symbol = {
// make sure this lock guard is dropped at the end of this scope... // make sure this lock guard is dropped at the end of this scope...
let def = definition.read(); let def = definition.read();
if let TopLevelDef::Function { instance_to_symbol, .. } = &*def { match &*def {
TopLevelDef::Function { instance_to_symbol, .. } => {
instance_to_symbol.get(&key).cloned() instance_to_symbol.get(&key).cloned()
} else { }
unreachable!() TopLevelDef::Class { methods, .. } => {
// TODO: what about other fields that require alloca?
let mut fun_id = None;
for (name, _, id) in methods.iter() {
if name == "__init__" {
fun_id = Some(*id);
}
}
let fun_id = fun_id.unwrap();
let ty = self.get_llvm_type(fun.0.ret).into_pointer_type();
let zelf_ty: BasicTypeEnum = ty.get_element_type().try_into().unwrap();
let zelf = self.builder.build_alloca(zelf_ty, "alloca").into();
let mut sign = fun.0.clone();
sign.ret = self.primitives.none;
self.gen_call(Some((fun.0.ret, zelf)), (&sign, fun_id), params);
return Some(zelf);
}
} }
} }
.unwrap_or_else(|| { .unwrap_or_else(|| {
@ -164,7 +182,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}) })
.collect(); .collect();
let signature = FunSignature { let mut signature = FunSignature {
args: fun args: fun
.0 .0
.args .args
@ -186,6 +204,13 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.collect(), .collect(),
}; };
if let Some(obj) = &obj {
signature.args.insert(
0,
FuncArg { name: "self".into(), ty: obj.0, default_value: None },
);
}
let unifier = (unifier.get_shared_unifier(), *primitives); let unifier = (unifier.get_shared_unifier(), *primitives);
task = Some(CodeGenTask { task = Some(CodeGenTask {
@ -209,7 +234,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| { let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| {
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec(); let mut args = fun.0.args.clone();
if let Some(obj) = &obj {
args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None });
}
let params = args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
let fun_ty = if self.unifier.unioned(fun.0.ret, self.primitives.none) { let fun_ty = if self.unifier.unioned(fun.0.ret, self.primitives.none) {
self.ctx.void_type().fn_type(&params, false) self.ctx.void_type().fn_type(&params, false)
} else { } else {
@ -227,7 +256,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap())); mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
} }
// reorder the parameters // reorder the parameters
let params = fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); let mut params =
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
if let Some(obj) = obj {
params.insert(0, obj.1);
}
self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left() self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
} }
@ -607,9 +640,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
phi.as_basic_value() phi.as_basic_value()
} }
ExprKind::Call { func, args, keywords } => { ExprKind::Call { func, args, keywords } => {
if let ExprKind::Name { id, .. } = &func.as_ref().node {
// TODO: handle primitive casts and function pointers
let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier");
let mut params = let mut params =
args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec();
let kw_iter = keywords.iter().map(|kw| { let kw_iter = keywords.iter().map(|kw| {
@ -623,10 +653,40 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.unifier .unifier
.get_call_signature(*self.calls.get(&expr.location.into()).unwrap()) .get_call_signature(*self.calls.get(&expr.location.into()).unwrap())
.unwrap(); .unwrap();
match &func.as_ref().node {
ExprKind::Name { id, .. } => {
// TODO: handle primitive casts and function pointers
let fun =
self.resolver.get_identifier_def(&id).expect("Unknown identifier");
return self.gen_call(None, (&signature, fun), params); return self.gen_call(None, (&signature, fun), params);
}
ExprKind::Attribute { value, attr, .. } => {
let val = self.gen_expr(value).unwrap();
let id = if let TypeEnum::TObj { obj_id, .. } =
&*self.unifier.get_ty(value.custom.unwrap())
{
*obj_id
} else { } else {
// TODO: method unreachable!()
unimplemented!() };
let fun_id = {
let defs = self.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read();
if let TopLevelDef::Class { methods, .. } = &*obj_def {
let mut fun_id = None;
for (name, _, id) in methods.iter() {
if name == attr {
fun_id = Some(*id);
}
}
fun_id.unwrap()
} else {
unreachable!()
}
};
return self.gen_call(Some((value.custom.unwrap(), val)), (&signature, fun_id), params);
}
_ => unimplemented!(),
} }
} }
ExprKind::Subscript { value, slice, .. } => { ExprKind::Subscript { value, slice, .. } => {

View File

@ -341,8 +341,16 @@ pub fn gen_func<'ctx>(
unifier, unifier,
}; };
let mut returned = false;
for stmt in task.body.iter() { for stmt in task.body.iter() {
code_gen_context.gen_stmt(stmt); returned = code_gen_context.gen_stmt(stmt);
if returned {
break;
}
}
// after static analysis, only void functions can have no return at the end.
if !returned {
code_gen_context.builder.build_return(None);
} }
let CodeGenContext { builder, module, .. } = code_gen_context; let CodeGenContext { builder, module, .. } = code_gen_context;

View File

@ -34,16 +34,18 @@ impl Default for TopLevelComposer {
impl TopLevelComposer { impl TopLevelComposer {
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// return a composer and things to make a "primitive" symbol resolver, so that the symbol
/// resolver can later figure out primitive type definitions when passed a primitive type name /// resolver can later figure out primitive type definitions when passed a primitive type name
pub fn new(builtins: Vec<(String, FunSignature)>) -> (Self, HashMap<String, DefinitionId>, HashMap<String, Type>) { pub fn new(
builtins: Vec<(String, FunSignature)>,
) -> (Self, HashMap<String, DefinitionId>, HashMap<String, Type>) {
let primitives = Self::make_primitives(); let primitives = Self::make_primitives();
let mut definition_ast_list = { let mut definition_ast_list = {
let top_level_def_list = vec![ let top_level_def_list = vec![
Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32"))), Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32", None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64"))), Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64", None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float"))), Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float", None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool"))), Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool", None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none"))), Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none", None))),
]; ];
let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None]; let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None];
izip!(top_level_def_list, ast_list).collect_vec() izip!(top_level_def_list, ast_list).collect_vec()
@ -80,16 +82,14 @@ impl TopLevelComposer {
definition_ast_list.push(( definition_ast_list.push((
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: name.clone(), name: name.clone(),
simple_name: name.clone(),
signature: fun_sig, signature: fun_sig,
instance_to_stmt: HashMap::new(), instance_to_stmt: HashMap::new(),
instance_to_symbol: [("".to_string(), name.clone())] instance_to_symbol: [("".to_string(), name.clone())].iter().cloned().collect(),
.iter()
.cloned()
.collect(),
var_id: Default::default(), var_id: Default::default(),
resolver: None, resolver: None,
})), })),
None None,
)); ));
defined_class_method_name.insert(name.clone()); defined_class_method_name.insert(name.clone());
defined_class_name.insert(name.clone()); defined_class_name.insert(name.clone());
@ -160,11 +160,13 @@ impl TopLevelComposer {
// since later when registering class method, ast will still be used, // since later when registering class method, ast will still be used,
// here push None temporarly, later will move the ast inside // here push None temporarly, later will move the ast inside
let constructor_ty = self.unifier.get_fresh_var().0;
let mut class_def_ast = ( let mut class_def_ast = (
Arc::new(RwLock::new(Self::make_top_level_class_def( Arc::new(RwLock::new(Self::make_top_level_class_def(
class_def_id, class_def_id,
resolver.clone(), resolver.clone(),
name, name,
Some(constructor_ty)
))), ))),
None, None,
); );
@ -215,6 +217,7 @@ impl TopLevelComposer {
method_name.clone(), method_name.clone(),
RwLock::new(Self::make_top_level_function_def( RwLock::new(Self::make_top_level_function_def(
global_class_method_name, global_class_method_name,
method_name.clone(),
// later unify with parsed type // later unify with parsed type
dummy_method_type.0, dummy_method_type.0,
resolver.clone(), resolver.clone(),
@ -251,14 +254,7 @@ impl TopLevelComposer {
self.definition_ast_list.push((def, Some(ast))); self.definition_ast_list.push((def, Some(ast)));
} }
// put the constructor into the def_list Ok((class_name, DefinitionId(class_def_id), Some(constructor_ty)))
self.definition_ast_list.push((
RwLock::new(TopLevelDef::Initializer { class_id: DefinitionId(class_def_id) })
.into(),
None,
));
Ok((class_name, DefinitionId(class_def_id), None))
} }
ast::StmtKind::FunctionDef { name, .. } => { ast::StmtKind::FunctionDef { name, .. } => {
@ -278,6 +274,8 @@ impl TopLevelComposer {
// add to the definition list // add to the definition list
self.definition_ast_list.push(( self.definition_ast_list.push((
RwLock::new(Self::make_top_level_function_def( RwLock::new(Self::make_top_level_function_def(
// TODO: is this fun_name or the above name with mod_path?
name.into(),
name.into(), name.into(),
// dummy here, unify with correct type later // dummy here, unify with correct type later
ty_to_be_unified, ty_to_be_unified,
@ -801,7 +799,7 @@ impl TopLevelComposer {
resolver, resolver,
type_vars, type_vars,
.. ..
} = class_def.deref_mut() } = &mut *class_def
{ {
if let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast { if let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast {
( (
@ -1161,16 +1159,11 @@ impl TopLevelComposer {
/// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function /// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function
fn analyze_function_instance(&mut self) -> Result<(), String> { fn analyze_function_instance(&mut self) -> Result<(), String> {
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num) { for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num)
{
let mut function_def = def.write(); let mut function_def = def.write();
if let TopLevelDef::Function { if let TopLevelDef::Function { instance_to_stmt, name, simple_name, signature, resolver, .. } =
instance_to_stmt, &mut *function_def
name,
signature,
var_id,
resolver,
..
} = &mut *function_def
{ {
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
let FunSignature { args, ret, vars } = &*func_sig.borrow(); let FunSignature { args, ret, vars } = &*func_sig.borrow();
@ -1179,14 +1172,23 @@ impl TopLevelComposer {
if let Some(class_id) = self.method_class.get(&DefinitionId(id)) { if let Some(class_id) = self.method_class.get(&DefinitionId(id)) {
let class_def = self.definition_ast_list.get(class_id.0).unwrap(); let class_def = self.definition_ast_list.get(class_id.0).unwrap();
let class_def = class_def.0.read(); let class_def = class_def.0.read();
if let TopLevelDef::Class { type_vars, .. } = &*class_def { if let TopLevelDef::Class { type_vars, constructor, .. } = &*class_def {
let ty_ann = make_self_type_annotation(type_vars, *class_id); let ty_ann = make_self_type_annotation(type_vars, *class_id);
Some(get_type_from_type_annotation_kinds( let self_ty = get_type_from_type_annotation_kinds(
self.extract_def_list().as_slice(), self.extract_def_list().as_slice(),
&mut self.unifier, &mut self.unifier,
&self.primitives_ty, &self.primitives_ty,
&ty_ann, &ty_ann,
)?) )?;
if simple_name == "__init__" {
let fn_type = self.unifier.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature {
args: args.clone(),
ret: self_ty,
vars: vars.clone()
})));
self.unifier.unify(fn_type, constructor.unwrap())?;
}
Some(self_ty)
} else { } else {
unreachable!("must be class def") unreachable!("must be class def")
} }
@ -1227,8 +1229,7 @@ impl TopLevelComposer {
let inst_ret = self.unifier.subst(*ret, &subst).unwrap_or(*ret); let inst_ret = self.unifier.subst(*ret, &subst).unwrap_or(*ret);
let inst_args = { let inst_args = {
let unifier = &mut self.unifier; let unifier = &mut self.unifier;
args args.iter()
.iter()
.map(|a| FuncArg { .map(|a| FuncArg {
name: a.name.clone(), name: a.name.clone(),
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
@ -1319,15 +1320,15 @@ impl TopLevelComposer {
.sorted() .sorted()
.map(|id| { .map(|id| {
let ty = subst.get(id).unwrap(); let ty = subst.get(id).unwrap();
unifier.stringify(*ty, &mut |id| id.to_string(), &mut |id| id.to_string()) unifier.stringify(
}).join(", ") *ty,
}, &mut |id| id.to_string(),
FunInstance { &mut |id| id.to_string(),
body: fun_body, )
unifier_id: 0, })
calls, .join(", ")
subst,
}, },
FunInstance { body: fun_body, unifier_id: 0, calls, subst },
); );
} }
} else { } else {

View File

@ -50,7 +50,6 @@ impl TopLevelDef {
r r
} }
), ),
TopLevelDef::Initializer { class_id } => format!("Initializer {{ {:?} }}", class_id),
} }
} }
} }
@ -94,6 +93,7 @@ impl TopLevelComposer {
index: usize, index: usize,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
name: &str, name: &str,
constructor: Option<Type>
) -> TopLevelDef { ) -> TopLevelDef {
TopLevelDef::Class { TopLevelDef::Class {
name: name.to_string(), name: name.to_string(),
@ -102,6 +102,7 @@ impl TopLevelComposer {
fields: Default::default(), fields: Default::default(),
methods: Default::default(), methods: Default::default(),
ancestors: Default::default(), ancestors: Default::default(),
constructor,
resolver, resolver,
} }
} }
@ -109,11 +110,13 @@ impl TopLevelComposer {
/// when first registering, the type is a invalid value /// when first registering, the type is a invalid value
pub fn make_top_level_function_def( pub fn make_top_level_function_def(
name: String, name: String,
simple_name: String,
ty: Type, ty: Type,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
) -> TopLevelDef { ) -> TopLevelDef {
TopLevelDef::Function { TopLevelDef::Function {
name, name,
simple_name,
signature: ty, signature: ty,
var_id: Default::default(), var_id: Default::default(),
instance_to_symbol: Default::default(), instance_to_symbol: Default::default(),

View File

@ -53,10 +53,14 @@ pub enum TopLevelDef {
ancestors: Vec<TypeAnnotation>, ancestors: Vec<TypeAnnotation>,
// symbol resolver of the module defined the class, none if it is built-in type // symbol resolver of the module defined the class, none if it is built-in type
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
// constructor type
constructor: Option<Type>,
}, },
Function { Function {
// prefix for symbol, should be unique globally, and not ending with numbers // prefix for symbol, should be unique globally, and not ending with numbers
name: String, name: String,
// simple name, the same as in method/function definition
simple_name: String,
// function signature. // function signature.
signature: Type, signature: Type,
// instantiated type variable IDs // instantiated type variable IDs
@ -75,9 +79,6 @@ pub enum TopLevelDef {
// symbol resolver of the module defined the class // symbol resolver of the module defined the class
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
}, },
Initializer {
class_id: DefinitionId,
},
} }
pub struct TopLevelContext { pub struct TopLevelContext {

View File

@ -1,64 +1,20 @@
def y_scale(maxX: float, minX: float, height: float, width: float, aspectRatio: float) -> float: class A:
return (maxX-minX)*(height/width)*aspectRatio a: int32
def __init__(self, a: int32):
self.a = a
def check_smaller_than_sixteen(i: int32) -> bool: def get_a(self) -> int32:
return i < 16 return self.a
def rec(x: int32): def get_self(self) -> A:
if x > 1: return self
output(x)
rec(x - 1)
return
else:
output(-1)
return
def fib(n: int32) -> int32:
if n <= 2:
return 1
else:
return fib(n - 1) + fib(n - 2)
def draw():
minX = -2.0
maxX = 1.0
width = 78.0
height = 36.0
aspectRatio = 2.0
# test = 1.0 + 1
yScale = y_scale(maxX, minX, height, width, aspectRatio)
y = 0.0
while y < height:
x = 0.0
while x < width:
c_r = minX+x*(maxX-minX)/width
c_i = y*yScale/height-yScale/2.0
z_r = c_r
z_i = c_i
i = 0
while check_smaller_than_sixteen(i):
if z_r*z_r + z_i*z_i > 4.0:
break
new_z_r = (z_r*z_r)-(z_i*z_i) + c_r
z_i = 2.0*z_r*z_i + c_i
z_r = new_z_r
i = i + 1
output(i)
x = x + 1.0
output(-1)
y = y + 1.0
return
def run() -> int32: def run() -> int32:
rec(5) a = A(10)
output(a.a)
output(fib(10)) a = A(20)
output(-1) output(a.a)
output(a.get_a())
draw()
return 0 return 0

View File

@ -1,14 +1,11 @@
use std::time::SystemTime; use std::fs;
use std::{collections::HashSet, fs};
use inkwell::{ use inkwell::{
passes::{PassManager, PassManagerBuilder}, passes::{PassManager, PassManagerBuilder},
targets::*, targets::*,
OptimizationLevel, OptimizationLevel,
}; };
use nac3core::typecheck::type_inferencer::PrimitiveStore; use nac3core::typecheck::type_inferencer::PrimitiveStore;
use parking_lot::RwLock; use rustpython_parser::{parser, ast::StmtKind};
use rustpython_parser::parser;
use std::{collections::HashMap, path::Path, sync::Arc}; use std::{collections::HashMap, path::Path, sync::Arc};
use nac3core::{ use nac3core::{
@ -55,11 +52,17 @@ fn main() {
); );
for stmt in parser::parse_program(&program).unwrap().into_iter() { for stmt in parser::parse_program(&program).unwrap().into_iter() {
let is_class = matches!(stmt.node, StmtKind::ClassDef{ .. });
let (name, def_id, ty) = composer.register_top_level( let (name, def_id, ty) = composer.register_top_level(
stmt, stmt,
Some(resolver.clone()), Some(resolver.clone()),
"__main__".into(), "__main__".into(),
).unwrap(); ).unwrap();
if is_class {
internal_resolver.add_id_type(name.clone(), ty.unwrap());
}
internal_resolver.add_id_def(name.clone(), def_id); internal_resolver.add_id_def(name.clone(), def_id);
if let Some(ty) = ty { if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty); internal_resolver.add_id_type(name, ty);