From 4939ff4dbdadcb2fa7860b379afa1605193f0a62 Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sun, 19 Sep 2021 22:54:06 +0800 Subject: [PATCH] simple implementation of classes --- nac3core/src/codegen/expr.rs | 114 +++++++++++++++++++++++------- nac3core/src/codegen/mod.rs | 10 ++- nac3core/src/toplevel/composer.rs | 93 ++++++++++++------------ nac3core/src/toplevel/helper.rs | 5 +- nac3core/src/toplevel/mod.rs | 7 +- nac3standalone/mandelbrot.py | 70 ++++-------------- nac3standalone/src/main.rs | 13 ++-- 7 files changed, 172 insertions(+), 140 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 013785e5f..4d290e27c 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -114,10 +114,28 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let symbol = { // make sure this lock guard is dropped at the end of this scope... let def = definition.read(); - if let TopLevelDef::Function { instance_to_symbol, .. } = &*def { - instance_to_symbol.get(&key).cloned() - } else { - unreachable!() + match &*def { + TopLevelDef::Function { instance_to_symbol, .. } => { + instance_to_symbol.get(&key).cloned() + } + TopLevelDef::Class { methods, .. } => { + // TODO: what about other fields that require alloca? + let mut fun_id = None; + for (name, _, id) in methods.iter() { + if name == "__init__" { + fun_id = Some(*id); + } + } + let fun_id = fun_id.unwrap(); + + let ty = self.get_llvm_type(fun.0.ret).into_pointer_type(); + let zelf_ty: BasicTypeEnum = ty.get_element_type().try_into().unwrap(); + let zelf = self.builder.build_alloca(zelf_ty, "alloca").into(); + let mut sign = fun.0.clone(); + sign.ret = self.primitives.none; + self.gen_call(Some((fun.0.ret, zelf)), (&sign, fun_id), params); + return Some(zelf); + } } } .unwrap_or_else(|| { @@ -164,7 +182,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { }) .collect(); - let signature = FunSignature { + let mut signature = FunSignature { args: fun .0 .args @@ -186,6 +204,13 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { .collect(), }; + if let Some(obj) = &obj { + signature.args.insert( + 0, + FuncArg { name: "self".into(), ty: obj.0, default_value: None }, + ); + } + let unifier = (unifier.get_shared_unifier(), *primitives); task = Some(CodeGenTask { @@ -209,7 +234,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| { - let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec(); + let mut args = fun.0.args.clone(); + if let Some(obj) = &obj { + args.insert(0, FuncArg { name: "self".into(), ty: obj.0, default_value: None }); + } + let params = args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec(); let fun_ty = if self.unifier.unioned(fun.0.ret, self.primitives.none) { self.ctx.void_type().fn_type(¶ms, false) } else { @@ -227,7 +256,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { mapping.insert(k.name, self.gen_symbol_val(&k.default_value.unwrap())); } // reorder the parameters - let params = fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); + let mut params = + fun.0.args.iter().map(|arg| mapping.remove(&arg.name).unwrap()).collect_vec(); + if let Some(obj) = obj { + params.insert(0, obj.1); + } self.builder.build_call(fun_val, ¶ms, "call").try_as_basic_value().left() } @@ -607,26 +640,53 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { phi.as_basic_value() } ExprKind::Call { func, args, keywords } => { - if let ExprKind::Name { id, .. } = &func.as_ref().node { - // TODO: handle primitive casts and function pointers - let fun = self.resolver.get_identifier_def(&id).expect("Unknown identifier"); - let mut params = - args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); - let kw_iter = keywords.iter().map(|kw| { - ( - Some(kw.node.arg.as_ref().unwrap().clone()), - self.gen_expr(&kw.node.value).unwrap(), - ) - }); - params.extend(kw_iter); - let signature = self - .unifier - .get_call_signature(*self.calls.get(&expr.location.into()).unwrap()) - .unwrap(); - return self.gen_call(None, (&signature, fun), params); - } else { - // TODO: method - unimplemented!() + let mut params = + args.iter().map(|arg| (None, self.gen_expr(arg).unwrap())).collect_vec(); + let kw_iter = keywords.iter().map(|kw| { + ( + Some(kw.node.arg.as_ref().unwrap().clone()), + self.gen_expr(&kw.node.value).unwrap(), + ) + }); + params.extend(kw_iter); + let signature = self + .unifier + .get_call_signature(*self.calls.get(&expr.location.into()).unwrap()) + .unwrap(); + match &func.as_ref().node { + ExprKind::Name { id, .. } => { + // TODO: handle primitive casts and function pointers + let fun = + self.resolver.get_identifier_def(&id).expect("Unknown identifier"); + return self.gen_call(None, (&signature, fun), params); + } + ExprKind::Attribute { value, attr, .. } => { + let val = self.gen_expr(value).unwrap(); + let id = if let TypeEnum::TObj { obj_id, .. } = + &*self.unifier.get_ty(value.custom.unwrap()) + { + *obj_id + } else { + unreachable!() + }; + let fun_id = { + let defs = self.top_level.definitions.read(); + let obj_def = defs.get(id.0).unwrap().read(); + if let TopLevelDef::Class { methods, .. } = &*obj_def { + let mut fun_id = None; + for (name, _, id) in methods.iter() { + if name == attr { + fun_id = Some(*id); + } + } + fun_id.unwrap() + } else { + unreachable!() + } + }; + return self.gen_call(Some((value.custom.unwrap(), val)), (&signature, fun_id), params); + } + _ => unimplemented!(), } } ExprKind::Subscript { value, slice, .. } => { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 1c3e4e16d..413572906 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -341,8 +341,16 @@ pub fn gen_func<'ctx>( unifier, }; + let mut returned = false; for stmt in task.body.iter() { - code_gen_context.gen_stmt(stmt); + returned = code_gen_context.gen_stmt(stmt); + if returned { + break; + } + } + // after static analysis, only void functions can have no return at the end. + if !returned { + code_gen_context.builder.build_return(None); } let CodeGenContext { builder, module, .. } = code_gen_context; diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 15523d1d9..d1099a642 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -34,16 +34,18 @@ impl Default for TopLevelComposer { impl TopLevelComposer { /// return a composer and things to make a "primitive" symbol resolver, so that the symbol /// resolver can later figure out primitive type definitions when passed a primitive type name - pub fn new(builtins: Vec<(String, FunSignature)>) -> (Self, HashMap, HashMap) { + pub fn new( + builtins: Vec<(String, FunSignature)>, + ) -> (Self, HashMap, HashMap) { let primitives = Self::make_primitives(); let mut definition_ast_list = { let top_level_def_list = vec![ - Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool"))), - Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none"))), + Arc::new(RwLock::new(Self::make_top_level_class_def(0, None, "int32", None))), + Arc::new(RwLock::new(Self::make_top_level_class_def(1, None, "int64", None))), + Arc::new(RwLock::new(Self::make_top_level_class_def(2, None, "float", None))), + Arc::new(RwLock::new(Self::make_top_level_class_def(3, None, "bool", None))), + Arc::new(RwLock::new(Self::make_top_level_class_def(4, None, "none", None))), ]; let ast_list: Vec>> = vec![None, None, None, None, None]; izip!(top_level_def_list, ast_list).collect_vec() @@ -69,10 +71,10 @@ impl TopLevelComposer { let mut defined_class_name: HashSet = Default::default(); let mut defined_function_name: HashSet = Default::default(); let method_class: HashMap = Default::default(); - + let mut built_in_id: HashMap = Default::default(); let mut built_in_ty: HashMap = Default::default(); - + for (name, sig) in builtins { let fun_sig = unifier.add_ty(TypeEnum::TFunc(RefCell::new(sig))); built_in_ty.insert(name.clone(), fun_sig); @@ -80,22 +82,20 @@ impl TopLevelComposer { definition_ast_list.push(( Arc::new(RwLock::new(TopLevelDef::Function { name: name.clone(), + simple_name: name.clone(), signature: fun_sig, instance_to_stmt: HashMap::new(), - instance_to_symbol: [("".to_string(), name.clone())] - .iter() - .cloned() - .collect(), + instance_to_symbol: [("".to_string(), name.clone())].iter().cloned().collect(), var_id: Default::default(), resolver: None, })), - None + None, )); defined_class_method_name.insert(name.clone()); defined_class_name.insert(name.clone()); defined_function_name.insert(name); } - + ( TopLevelComposer { built_in_num: definition_ast_list.len(), @@ -160,11 +160,13 @@ impl TopLevelComposer { // since later when registering class method, ast will still be used, // here push None temporarly, later will move the ast inside + let constructor_ty = self.unifier.get_fresh_var().0; let mut class_def_ast = ( Arc::new(RwLock::new(Self::make_top_level_class_def( class_def_id, resolver.clone(), name, + Some(constructor_ty) ))), None, ); @@ -215,6 +217,7 @@ impl TopLevelComposer { method_name.clone(), RwLock::new(Self::make_top_level_function_def( global_class_method_name, + method_name.clone(), // later unify with parsed type dummy_method_type.0, resolver.clone(), @@ -251,14 +254,7 @@ impl TopLevelComposer { self.definition_ast_list.push((def, Some(ast))); } - // put the constructor into the def_list - self.definition_ast_list.push(( - RwLock::new(TopLevelDef::Initializer { class_id: DefinitionId(class_def_id) }) - .into(), - None, - )); - - Ok((class_name, DefinitionId(class_def_id), None)) + Ok((class_name, DefinitionId(class_def_id), Some(constructor_ty))) } ast::StmtKind::FunctionDef { name, .. } => { @@ -278,6 +274,8 @@ impl TopLevelComposer { // add to the definition list self.definition_ast_list.push(( RwLock::new(Self::make_top_level_function_def( + // TODO: is this fun_name or the above name with mod_path? + name.into(), name.into(), // dummy here, unify with correct type later ty_to_be_unified, @@ -801,7 +799,7 @@ impl TopLevelComposer { resolver, type_vars, .. - } = class_def.deref_mut() + } = &mut *class_def { if let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast { ( @@ -1161,32 +1159,36 @@ impl TopLevelComposer { /// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function fn analyze_function_instance(&mut self) -> Result<(), String> { - for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num) { + for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num) + { let mut function_def = def.write(); - if let TopLevelDef::Function { - instance_to_stmt, - name, - signature, - var_id, - resolver, - .. - } = &mut *function_def + if let TopLevelDef::Function { instance_to_stmt, name, simple_name, signature, resolver, .. } = + &mut *function_def { - if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { + if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { let FunSignature { args, ret, vars } = &*func_sig.borrow(); // None if is not class method let self_type = { if let Some(class_id) = self.method_class.get(&DefinitionId(id)) { let class_def = self.definition_ast_list.get(class_id.0).unwrap(); let class_def = class_def.0.read(); - if let TopLevelDef::Class { type_vars, .. } = &*class_def { + if let TopLevelDef::Class { type_vars, constructor, .. } = &*class_def { let ty_ann = make_self_type_annotation(type_vars, *class_id); - Some(get_type_from_type_annotation_kinds( + let self_ty = get_type_from_type_annotation_kinds( self.extract_def_list().as_slice(), &mut self.unifier, &self.primitives_ty, &ty_ann, - )?) + )?; + if simple_name == "__init__" { + let fn_type = self.unifier.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + args: args.clone(), + ret: self_ty, + vars: vars.clone() + }))); + self.unifier.unify(fn_type, constructor.unwrap())?; + } + Some(self_ty) } else { unreachable!("must be class def") } @@ -1227,8 +1229,7 @@ impl TopLevelComposer { let inst_ret = self.unifier.subst(*ret, &subst).unwrap_or(*ret); let inst_args = { let unifier = &mut self.unifier; - args - .iter() + args.iter() .map(|a| FuncArg { name: a.name.clone(), ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), @@ -1319,15 +1320,15 @@ impl TopLevelComposer { .sorted() .map(|id| { let ty = subst.get(id).unwrap(); - unifier.stringify(*ty, &mut |id| id.to_string(), &mut |id| id.to_string()) - }).join(", ") - }, - FunInstance { - body: fun_body, - unifier_id: 0, - calls, - subst, + unifier.stringify( + *ty, + &mut |id| id.to_string(), + &mut |id| id.to_string(), + ) + }) + .join(", ") }, + FunInstance { body: fun_body, unifier_id: 0, calls, subst }, ); } } else { diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index dd70fa908..f128b2757 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -50,7 +50,6 @@ impl TopLevelDef { r } ), - TopLevelDef::Initializer { class_id } => format!("Initializer {{ {:?} }}", class_id), } } } @@ -94,6 +93,7 @@ impl TopLevelComposer { index: usize, resolver: Option>>, name: &str, + constructor: Option ) -> TopLevelDef { TopLevelDef::Class { name: name.to_string(), @@ -102,6 +102,7 @@ impl TopLevelComposer { fields: Default::default(), methods: Default::default(), ancestors: Default::default(), + constructor, resolver, } } @@ -109,11 +110,13 @@ impl TopLevelComposer { /// when first registering, the type is a invalid value pub fn make_top_level_function_def( name: String, + simple_name: String, ty: Type, resolver: Option>>, ) -> TopLevelDef { TopLevelDef::Function { name, + simple_name, signature: ty, var_id: Default::default(), instance_to_symbol: Default::default(), diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 206aefe9c..4076931fe 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -53,10 +53,14 @@ pub enum TopLevelDef { ancestors: Vec, // symbol resolver of the module defined the class, none if it is built-in type resolver: Option>>, + // constructor type + constructor: Option, }, Function { // prefix for symbol, should be unique globally, and not ending with numbers name: String, + // simple name, the same as in method/function definition + simple_name: String, // function signature. signature: Type, // instantiated type variable IDs @@ -75,9 +79,6 @@ pub enum TopLevelDef { // symbol resolver of the module defined the class resolver: Option>>, }, - Initializer { - class_id: DefinitionId, - }, } pub struct TopLevelContext { diff --git a/nac3standalone/mandelbrot.py b/nac3standalone/mandelbrot.py index 7afd28a7c..07d0bcbf6 100644 --- a/nac3standalone/mandelbrot.py +++ b/nac3standalone/mandelbrot.py @@ -1,64 +1,20 @@ -def y_scale(maxX: float, minX: float, height: float, width: float, aspectRatio: float) -> float: - return (maxX-minX)*(height/width)*aspectRatio +class A: + a: int32 + def __init__(self, a: int32): + self.a = a -def check_smaller_than_sixteen(i: int32) -> bool: - return i < 16 + def get_a(self) -> int32: + return self.a -def rec(x: int32): - if x > 1: - output(x) - rec(x - 1) - return - else: - output(-1) - return - -def fib(n: int32) -> int32: - if n <= 2: - return 1 - else: - return fib(n - 1) + fib(n - 2) - -def draw(): - minX = -2.0 - maxX = 1.0 - width = 78.0 - height = 36.0 - aspectRatio = 2.0 - - # test = 1.0 + 1 - - yScale = y_scale(maxX, minX, height, width, aspectRatio) - - y = 0.0 - while y < height: - x = 0.0 - while x < width: - c_r = minX+x*(maxX-minX)/width - c_i = y*yScale/height-yScale/2.0 - z_r = c_r - z_i = c_i - i = 0 - while check_smaller_than_sixteen(i): - if z_r*z_r + z_i*z_i > 4.0: - break - new_z_r = (z_r*z_r)-(z_i*z_i) + c_r - z_i = 2.0*z_r*z_i + c_i - z_r = new_z_r - i = i + 1 - output(i) - x = x + 1.0 - output(-1) - y = y + 1.0 - - return + def get_self(self) -> A: + return self def run() -> int32: - rec(5) + a = A(10) + output(a.a) - output(fib(10)) - output(-1) - - draw() + a = A(20) + output(a.a) + output(a.get_a()) return 0 diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 07b994157..acbfb00d8 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -1,14 +1,11 @@ -use std::time::SystemTime; -use std::{collections::HashSet, fs}; - +use std::fs; use inkwell::{ passes::{PassManager, PassManagerBuilder}, targets::*, OptimizationLevel, }; use nac3core::typecheck::type_inferencer::PrimitiveStore; -use parking_lot::RwLock; -use rustpython_parser::parser; +use rustpython_parser::{parser, ast::StmtKind}; use std::{collections::HashMap, path::Path, sync::Arc}; use nac3core::{ @@ -55,11 +52,17 @@ fn main() { ); for stmt in parser::parse_program(&program).unwrap().into_iter() { + let is_class = matches!(stmt.node, StmtKind::ClassDef{ .. }); let (name, def_id, ty) = composer.register_top_level( stmt, Some(resolver.clone()), "__main__".into(), ).unwrap(); + + if is_class { + internal_resolver.add_id_type(name.clone(), ty.unwrap()); + } + internal_resolver.add_id_def(name.clone(), def_id); if let Some(ty) = ty { internal_resolver.add_id_type(name, ty);