simple implementation of classes

This commit is contained in:
pca006132 2021-09-19 22:54:06 +08:00
parent bf1769cef6
commit 4939ff4dbd
7 changed files with 172 additions and 140 deletions

View File

@ -114,10 +114,28 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
let symbol = {
// make sure this lock guard is dropped at the end of this scope...
let def = definition.read();
if let TopLevelDef::Function { instance_to_symbol, .. } = &*def {
instance_to_symbol.get(&key).cloned()
} else {
unreachable!()
match &*def {
TopLevelDef::Function { instance_to_symbol, .. } => {
instance_to_symbol.get(&key).cloned()
}
TopLevelDef::Class { methods, .. } => {
// TODO: what about other fields that require alloca?
let mut fun_id = None;
for (name, _, id) in methods.iter() {
if name == "__init__" {
fun_id = Some(*id);
}
}
let fun_id = fun_id.unwrap();
let ty = self.get_llvm_type(fun.0.ret).into_pointer_type();
let zelf_ty: BasicTypeEnum = ty.get_element_type().try_into().unwrap();
let zelf = self.builder.build_alloca(zelf_ty, "alloca").into();
let mut sign = fun.0.clone();
sign.ret = self.primitives.none;
self.gen_call(Some((fun.0.ret, zelf)), (&sign, fun_id), params);
return Some(zelf);
}
}
}
.unwrap_or_else(|| {
@ -164,7 +182,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
})
.collect();
let signature = FunSignature {
let mut signature = FunSignature {
args: fun
.0
.args
@ -186,6 +204,13 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
.collect(),
};
if let Some(obj) = &obj {
signature.args.insert(
0,
FuncArg { name: "self".into(), ty: obj.0, default_value: None },
);
}
let unifier = (unifier.get_shared_unifier(), *primitives);
task = Some(CodeGenTask {
@ -209,7 +234,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
}
let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| {
let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
let mut args = fun.0.args.clone();
if let Some(obj) = &obj {
args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None });
}
let params = args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec();
let fun_ty = if self.unifier.unioned(fun.0.ret, self.primitives.none) {
self.ctx.void_type().fn_type(&params, false)
} else {
@ -227,7 +256,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap()));
}
// reorder the parameters
let params = fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
let mut params =
fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec();
if let Some(obj) = obj {
params.insert(0, obj.1);
}
self.builder.build_call(fun_val, &params, "call").try_as_basic_value().left()
}
@ -607,26 +640,53 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
phi.as_basic_value()
}
ExprKind::Call { func, args, keywords } => {
if let ExprKind::Name { id, .. } = &func.as_ref().node {
// TODO: handle primitive casts and function pointers
let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier");
let mut params =
args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec();
let kw_iter = keywords.iter().map(|kw| {
(
Some(kw.node.arg.as_ref().unwrap().clone()),
self.gen_expr(&kw.node.value).unwrap(),
)
});
params.extend(kw_iter);
let signature = self
.unifier
.get_call_signature(*self.calls.get(&expr.location.into()).unwrap())
.unwrap();
return self.gen_call(None, (&signature, fun), params);
} else {
// TODO: method
unimplemented!()
let mut params =
args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec();
let kw_iter = keywords.iter().map(|kw| {
(
Some(kw.node.arg.as_ref().unwrap().clone()),
self.gen_expr(&kw.node.value).unwrap(),
)
});
params.extend(kw_iter);
let signature = self
.unifier
.get_call_signature(*self.calls.get(&expr.location.into()).unwrap())
.unwrap();
match &func.as_ref().node {
ExprKind::Name { id, .. } => {
// TODO: handle primitive casts and function pointers
let fun =
self.resolver.get_identifier_def(&id).expect("Unknown identifier");
return self.gen_call(None, (&signature, fun), params);
}
ExprKind::Attribute { value, attr, .. } => {
let val = self.gen_expr(value).unwrap();
let id = if let TypeEnum::TObj { obj_id, .. } =
&*self.unifier.get_ty(value.custom.unwrap())
{
*obj_id
} else {
unreachable!()
};
let fun_id = {
let defs = self.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read();
if let TopLevelDef::Class { methods, .. } = &*obj_def {
let mut fun_id = None;
for (name, _, id) in methods.iter() {
if name == attr {
fun_id = Some(*id);
}
}
fun_id.unwrap()
} else {
unreachable!()
}
};
return self.gen_call(Some((value.custom.unwrap(), val)), (&signature, fun_id), params);
}
_ => unimplemented!(),
}
}
ExprKind::Subscript { value, slice, .. } => {

View File

@ -341,8 +341,16 @@ pub fn gen_func<'ctx>(
unifier,
};
let mut returned = false;
for stmt in task.body.iter() {
code_gen_context.gen_stmt(stmt);
returned = code_gen_context.gen_stmt(stmt);
if returned {
break;
}
}
// after static analysis, only void functions can have no return at the end.
if !returned {
code_gen_context.builder.build_return(None);
}
let CodeGenContext { builder, module, .. } = code_gen_context;

View File

@ -34,16 +34,18 @@ impl Default for TopLevelComposer {
impl TopLevelComposer {
/// return a composer and things to make a "primitive" symbol resolver, so that the symbol
/// resolver can later figure out primitive type definitions when passed a primitive type name
pub fn new(builtins: Vec<(String, FunSignature)>) -> (Self, HashMap<String, DefinitionId>, HashMap<String, Type>) {
pub fn new(
builtins: Vec<(String, FunSignature)>,
) -> (Self, HashMap<String, DefinitionId>, HashMap<String, Type>) {
let primitives = Self::make_primitives();
let mut definition_ast_list = {
let top_level_def_list = vec![
Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32"))),
Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64"))),
Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float"))),
Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool"))),
Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none"))),
Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32", None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64", None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float", None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool", None))),
Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none", None))),
];
let ast_list: Vec<Option<ast::Stmt<()>>> = vec![None, None, None, None, None];
izip!(top_level_def_list, ast_list).collect_vec()
@ -80,16 +82,14 @@ impl TopLevelComposer {
definition_ast_list.push((
Arc::new(RwLock::new(TopLevelDef::Function {
name: name.clone(),
simple_name: name.clone(),
signature: fun_sig,
instance_to_stmt: HashMap::new(),
instance_to_symbol: [("".to_string(), name.clone())]
.iter()
.cloned()
.collect(),
instance_to_symbol: [("".to_string(), name.clone())].iter().cloned().collect(),
var_id: Default::default(),
resolver: None,
})),
None
None,
));
defined_class_method_name.insert(name.clone());
defined_class_name.insert(name.clone());
@ -160,11 +160,13 @@ impl TopLevelComposer {
// since later when registering class method, ast will still be used,
// here push None temporarly, later will move the ast inside
let constructor_ty = self.unifier.get_fresh_var().0;
let mut class_def_ast = (
Arc::new(RwLock::new(Self::make_top_level_class_def(
class_def_id,
resolver.clone(),
name,
Some(constructor_ty)
))),
None,
);
@ -215,6 +217,7 @@ impl TopLevelComposer {
method_name.clone(),
RwLock::new(Self::make_top_level_function_def(
global_class_method_name,
method_name.clone(),
// later unify with parsed type
dummy_method_type.0,
resolver.clone(),
@ -251,14 +254,7 @@ impl TopLevelComposer {
self.definition_ast_list.push((def, Some(ast)));
}
// put the constructor into the def_list
self.definition_ast_list.push((
RwLock::new(TopLevelDef::Initializer { class_id: DefinitionId(class_def_id) })
.into(),
None,
));
Ok((class_name, DefinitionId(class_def_id), None))
Ok((class_name, DefinitionId(class_def_id), Some(constructor_ty)))
}
ast::StmtKind::FunctionDef { name, .. } => {
@ -278,6 +274,8 @@ impl TopLevelComposer {
// add to the definition list
self.definition_ast_list.push((
RwLock::new(Self::make_top_level_function_def(
// TODO: is this fun_name or the above name with mod_path?
name.into(),
name.into(),
// dummy here, unify with correct type later
ty_to_be_unified,
@ -801,7 +799,7 @@ impl TopLevelComposer {
resolver,
type_vars,
..
} = class_def.deref_mut()
} = &mut *class_def
{
if let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast {
(
@ -1161,32 +1159,36 @@ impl TopLevelComposer {
/// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function
fn analyze_function_instance(&mut self) -> Result<(), String> {
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num) {
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num)
{
let mut function_def = def.write();
if let TopLevelDef::Function {
instance_to_stmt,
name,
signature,
var_id,
resolver,
..
} = &mut *function_def
if let TopLevelDef::Function { instance_to_stmt, name, simple_name, signature, resolver, .. } =
&mut *function_def
{
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
let FunSignature { args, ret, vars } = &*func_sig.borrow();
// None if is not class method
let self_type = {
if let Some(class_id) = self.method_class.get(&DefinitionId(id)) {
let class_def = self.definition_ast_list.get(class_id.0).unwrap();
let class_def = class_def.0.read();
if let TopLevelDef::Class { type_vars, .. } = &*class_def {
if let TopLevelDef::Class { type_vars, constructor, .. } = &*class_def {
let ty_ann = make_self_type_annotation(type_vars, *class_id);
Some(get_type_from_type_annotation_kinds(
let self_ty = get_type_from_type_annotation_kinds(
self.extract_def_list().as_slice(),
&mut self.unifier,
&self.primitives_ty,
&ty_ann,
)?)
)?;
if simple_name == "__init__" {
let fn_type = self.unifier.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature {
args: args.clone(),
ret: self_ty,
vars: vars.clone()
})));
self.unifier.unify(fn_type, constructor.unwrap())?;
}
Some(self_ty)
} else {
unreachable!("must be class def")
}
@ -1227,8 +1229,7 @@ impl TopLevelComposer {
let inst_ret = self.unifier.subst(*ret, &subst).unwrap_or(*ret);
let inst_args = {
let unifier = &mut self.unifier;
args
.iter()
args.iter()
.map(|a| FuncArg {
name: a.name.clone(),
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
@ -1319,15 +1320,15 @@ impl TopLevelComposer {
.sorted()
.map(|id| {
let ty = subst.get(id).unwrap();
unifier.stringify(*ty, &mut |id| id.to_string(), &mut |id| id.to_string())
}).join(", ")
},
FunInstance {
body: fun_body,
unifier_id: 0,
calls,
subst,
unifier.stringify(
*ty,
&mut |id| id.to_string(),
&mut |id| id.to_string(),
)
})
.join(", ")
},
FunInstance { body: fun_body, unifier_id: 0, calls, subst },
);
}
} else {

View File

@ -50,7 +50,6 @@ impl TopLevelDef {
r
}
),
TopLevelDef::Initializer { class_id } => format!("Initializer {{ {:?} }}", class_id),
}
}
}
@ -94,6 +93,7 @@ impl TopLevelComposer {
index: usize,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
name: &str,
constructor: Option<Type>
) -> TopLevelDef {
TopLevelDef::Class {
name: name.to_string(),
@ -102,6 +102,7 @@ impl TopLevelComposer {
fields: Default::default(),
methods: Default::default(),
ancestors: Default::default(),
constructor,
resolver,
}
}
@ -109,11 +110,13 @@ impl TopLevelComposer {
/// when first registering, the type is a invalid value
pub fn make_top_level_function_def(
name: String,
simple_name: String,
ty: Type,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
) -> TopLevelDef {
TopLevelDef::Function {
name,
simple_name,
signature: ty,
var_id: Default::default(),
instance_to_symbol: Default::default(),

View File

@ -53,10 +53,14 @@ pub enum TopLevelDef {
ancestors: Vec<TypeAnnotation>,
// symbol resolver of the module defined the class, none if it is built-in type
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
// constructor type
constructor: Option<Type>,
},
Function {
// prefix for symbol, should be unique globally, and not ending with numbers
name: String,
// simple name, the same as in method/function definition
simple_name: String,
// function signature.
signature: Type,
// instantiated type variable IDs
@ -75,9 +79,6 @@ pub enum TopLevelDef {
// symbol resolver of the module defined the class
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
},
Initializer {
class_id: DefinitionId,
},
}
pub struct TopLevelContext {

View File

@ -1,64 +1,20 @@
def y_scale(maxX: float, minX: float, height: float, width: float, aspectRatio: float) -> float:
return (maxX-minX)*(height/width)*aspectRatio
class A:
a: int32
def __init__(self, a: int32):
self.a = a
def check_smaller_than_sixteen(i: int32) -> bool:
return i < 16
def get_a(self) -> int32:
return self.a
def rec(x: int32):
if x > 1:
output(x)
rec(x - 1)
return
else:
output(-1)
return
def fib(n: int32) -> int32:
if n <= 2:
return 1
else:
return fib(n - 1) + fib(n - 2)
def draw():
minX = -2.0
maxX = 1.0
width = 78.0
height = 36.0
aspectRatio = 2.0
# test = 1.0 + 1
yScale = y_scale(maxX, minX, height, width, aspectRatio)
y = 0.0
while y < height:
x = 0.0
while x < width:
c_r = minX+x*(maxX-minX)/width
c_i = y*yScale/height-yScale/2.0
z_r = c_r
z_i = c_i
i = 0
while check_smaller_than_sixteen(i):
if z_r*z_r + z_i*z_i > 4.0:
break
new_z_r = (z_r*z_r)-(z_i*z_i) + c_r
z_i = 2.0*z_r*z_i + c_i
z_r = new_z_r
i = i + 1
output(i)
x = x + 1.0
output(-1)
y = y + 1.0
return
def get_self(self) -> A:
return self
def run() -> int32:
rec(5)
a = A(10)
output(a.a)
output(fib(10))
output(-1)
draw()
a = A(20)
output(a.a)
output(a.get_a())
return 0

View File

@ -1,14 +1,11 @@
use std::time::SystemTime;
use std::{collections::HashSet, fs};
use std::fs;
use inkwell::{
passes::{PassManager, PassManagerBuilder},
targets::*,
OptimizationLevel,
};
use nac3core::typecheck::type_inferencer::PrimitiveStore;
use parking_lot::RwLock;
use rustpython_parser::parser;
use rustpython_parser::{parser, ast::StmtKind};
use std::{collections::HashMap, path::Path, sync::Arc};
use nac3core::{
@ -55,11 +52,17 @@ fn main() {
);
for stmt in parser::parse_program(&program).unwrap().into_iter() {
let is_class = matches!(stmt.node, StmtKind::ClassDef{ .. });
let (name, def_id, ty) = composer.register_top_level(
stmt,
Some(resolver.clone()),
"__main__".into(),
).unwrap();
if is_class {
internal_resolver.add_id_type(name.clone(), ty.unwrap());
}
internal_resolver.add_id_def(name.clone(), def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty);