From b267a656a8e5e09ee6ecd473217712658d47e8af Mon Sep 17 00:00:00 2001 From: pca006132 Date: Sat, 12 Feb 2022 21:09:23 +0800 Subject: [PATCH] nac3core: added exception type and fixed primitive representation - Added `Exception` primitive type and some builtin exception types. Note that all exception types share the same layout, and should inherit from the base `Exception` type. There are some hacks in the toplevel module for handling exception types, we should revisit and fix them later. - Added new primitive types to concrete type module, otherwise there would be some weird type errors. - Changed the representation of strings to CSlice, instead of CString. --- nac3core/src/codegen/concrete_type.rs | 15 +++ nac3core/src/codegen/expr.rs | 33 ++++- nac3core/src/codegen/mod.rs | 45 +++++-- nac3core/src/codegen/stmt.rs | 65 ++++++++++ nac3core/src/symbol_resolver.rs | 67 ++++------ nac3core/src/toplevel/builtins.rs | 101 ++++++++++++++- nac3core/src/toplevel/composer.rs | 122 +++++++++++++----- nac3core/src/toplevel/helper.rs | 24 +++- nac3core/src/toplevel/type_annotation.rs | 2 + nac3core/src/typecheck/type_inferencer/mod.rs | 1 + 10 files changed, 384 insertions(+), 91 deletions(-) diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 140142ef..fd13afac 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -31,6 +31,9 @@ pub enum Primitive { Float, Bool, None, + Range, + Str, + Exception } #[derive(Debug)] @@ -66,6 +69,9 @@ impl ConcreteTypeStore { ConcreteTypeEnum::TPrimitive(Primitive::Float), ConcreteTypeEnum::TPrimitive(Primitive::Bool), ConcreteTypeEnum::TPrimitive(Primitive::None), + ConcreteTypeEnum::TPrimitive(Primitive::Range), + ConcreteTypeEnum::TPrimitive(Primitive::Str), + ConcreteTypeEnum::TPrimitive(Primitive::Exception), ], } } @@ -118,6 +124,12 @@ impl ConcreteTypeStore { ConcreteType(3) } else if unifier.unioned(ty, primitives.none) { ConcreteType(4) + } else if unifier.unioned(ty, primitives.range) { + ConcreteType(5) + } else if unifier.unioned(ty, primitives.str) { + ConcreteType(6) + } else if unifier.unioned(ty, primitives.exception) { + ConcreteType(7) } else if let Some(cty) = cache.get(&ty) { if let Some(cty) = cty { *cty @@ -211,6 +223,9 @@ impl ConcreteTypeStore { Primitive::Float => primitives.float, Primitive::Bool => primitives.bool, Primitive::None => primitives.none, + Primitive::Range => primitives.range, + Primitive::Str => primitives.str, + Primitive::Exception => primitives.exception, }; *cache.get_mut(&cty).unwrap() = Some(ty); return ty; diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index c20be479..b1368ba2 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -76,14 +76,20 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { index } - fn gen_symbol_val(&mut self, val: &SymbolValue) -> BasicValueEnum<'ctx> { + pub fn gen_symbol_val(&mut self, generator: &mut dyn CodeGenerator, val: &SymbolValue) -> BasicValueEnum<'ctx> { match val { SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(), SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(), SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(), SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(), + SymbolValue::Str(v) => { + let str_ptr = self.builder.build_global_string_ptr(v, "const").as_pointer_value().into(); + let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); + let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); + ty.const_named_struct(&[str_ptr, size.into()]).into() + } SymbolValue::Tuple(ls) => { - let vals = ls.iter().map(|v| self.gen_symbol_val(v)).collect_vec(); + let vals = ls.iter().map(|v| self.gen_symbol_val(generator, v)).collect_vec(); let fields = vals.iter().map(|v| v.get_type()).collect_vec(); let ty = self.ctx.struct_type(&fields, false); let ptr = self.builder.build_alloca(ty, "tuple"); @@ -118,7 +124,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ) } - fn gen_const(&mut self, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> { + pub fn gen_const(&mut self, generator: &mut dyn CodeGenerator, value: &Constant, ty: Type) -> BasicValueEnum<'ctx> { match value { Constant::Bool(v) => { assert!(self.unifier.unioned(ty, self.primitives.bool)); @@ -145,7 +151,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { let types = if let TypeEnum::TTuple { ty } = &*ty { ty.clone() } else { unreachable!() }; let values = zip(types.into_iter(), v.iter()) - .map(|(ty, v)| self.gen_const(v, ty)) + .map(|(ty, v)| self.gen_const(generator, v, ty)) .collect_vec(); let types = values.iter().map(BasicValueEnum::get_type).collect_vec(); let ty = self.ctx.struct_type(&types, false); @@ -153,7 +159,16 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } Constant::Str(v) => { assert!(self.unifier.unioned(ty, self.primitives.str)); - self.builder.build_global_string_ptr(v, "const").as_pointer_value().into() + if let Some(v) = self.const_strings.get(v) { + *v + } else { + let str_ptr = self.builder.build_global_string_ptr(v, "const").as_pointer_value().into(); + let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false); + let ty = self.get_llvm_type(generator, self.primitives.str); + let val = ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); + self.const_strings.insert(v.to_string(), val); + val + } } _ => unreachable!(), } @@ -241,6 +256,14 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { _ => unimplemented!(), } } + pub fn gen_string>( + &mut self, + generator: &mut G, + s: S + ) -> BasicValueEnum<'ctx> { + self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str) + } + } pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>( diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index ead7b0c2..1601a3e7 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -314,25 +314,50 @@ pub fn gen_func<'ctx, G: CodeGenerator>( none: unifier.get_representative(primitives.none), range: unifier.get_representative(primitives.range), str: unifier.get_representative(primitives.str), + exception: unifier.get_representative(primitives.exception), }; let mut type_cache: HashMap<_, _> = [ - (unifier.get_representative(primitives.int32), context.i32_type().into()), - (unifier.get_representative(primitives.int64), context.i64_type().into()), - (unifier.get_representative(primitives.float), context.f64_type().into()), - (unifier.get_representative(primitives.bool), context.bool_type().into()), + (primitives.int32, context.i32_type().into()), + (primitives.int64, context.i64_type().into()), + (primitives.float, context.f64_type().into()), + (primitives.bool, context.bool_type().into()), + (primitives.str, { + let str_type = context.opaque_struct_type("str"); + let fields = [ + context.i8_type().ptr_type(AddressSpace::Generic).into(), + generator.get_size_type(context).into(), + ]; + str_type.set_body(&fields, false); + str_type.into() + }), ( - unifier.get_representative(primitives.str), - context.i8_type().ptr_type(AddressSpace::Generic).into(), - ), - ( - unifier.get_representative(primitives.range), - context.i32_type().array_type(3).ptr_type(AddressSpace::Generic).into() + primitives.range, + context.i32_type().array_type(3).ptr_type(AddressSpace::Generic).into(), ), ] .iter() .cloned() .collect(); + type_cache.insert(primitives.exception, { + let exception = context.opaque_struct_type("Exception"); + let int32 = context.i32_type().into(); + let int64 = context.i64_type().into(); + let str_ty = *type_cache.get(&primitives.str).unwrap(); + let fields = [ + int32, + str_ty, + int32, + int32, + str_ty, + str_ty, + int64, + int64, + int64 + ]; + exception.set_body(&fields, false); + exception.ptr_type(AddressSpace::Generic).into() + }); let (args, ret) = if let ConcreteTypeEnum::TFunc { args, ret, .. } = task.store.get(task.signature) diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index d8559f28..f8f42d43 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -402,6 +402,71 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator>( } } then_exited && else_exited +pub fn exn_constructor<'ctx, 'a>( + ctx: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + _fun: (&FunSignature, DefinitionId), + mut args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator +) -> Option> { + let (zelf_ty, zelf) = obj.unwrap(); + let zelf = zelf.to_basic_value_enum(ctx, generator).into_pointer_value(); + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + let zelf_id = { + if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) { + obj_id.0 + } else { + unreachable!() + } + }; + let defs = ctx.top_level.definitions.read(); + let def = defs[zelf_id].read(); + let zelf_name = if let TopLevelDef::Class { name, .. } = &*def { + *name + } else { + unreachable!() + }; + let exception_name = format!("0:{}", zelf_name); + unsafe { + let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id"); + let id = ctx.resolver.get_string_id(&exception_name); + ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false)); + let empty_string = ctx.gen_const(generator, &Constant::Str("".into()), ctx.primitives.str); + let ptr = ctx.builder.build_in_bounds_gep( + zelf, &[zero, int32.const_int(5, false)], "exn.msg"); + let msg = if !args.is_empty() { + args.remove(0).1.to_basic_value_enum(ctx, generator) + } else { + empty_string + }; + ctx.builder.build_store(ptr, msg); + for i in [6, 7, 8].iter() { + let value = if !args.is_empty() { + args.remove(0).1.to_basic_value_enum(ctx, generator) + } else { + ctx.ctx.i64_type().const_zero().into() + }; + let ptr = ctx.builder.build_in_bounds_gep( + zelf, &[zero, int32.const_int(*i, false)], "exn.param"); + ctx.builder.build_store(ptr, value); + } + // set file, func to empty string + for i in [1, 4].iter() { + let ptr = ctx.builder.build_in_bounds_gep( + zelf, &[zero, int32.const_int(*i, false)], "exn.str"); + ctx.builder.build_store(ptr, empty_string); + } + // set ints to zero + for i in [2, 3].iter() { + let ptr = ctx.builder.build_in_bounds_gep( + zelf, &[zero, int32.const_int(*i, false)], "exn.ints"); + ctx.builder.build_store(ptr, zero); + } + } + Some(zelf.into()) +} + } else { unreachable!() } diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 88b00dd3..dcd3c378 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -2,14 +2,17 @@ use std::collections::HashMap; use std::fmt::Debug; use std::{cell::RefCell, sync::Arc}; -use crate::{codegen::CodeGenerator, typecheck::{ - type_inferencer::PrimitiveStore, - typedef::{Type, Unifier}, -}}; use crate::{ codegen::CodeGenContext, toplevel::{DefinitionId, TopLevelDef}, }; +use crate::{ + codegen::CodeGenerator, + typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{Type, Unifier}, + }, +}; use crate::{location::Location, typecheck::typedef::TypeEnum}; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue}; use itertools::{chain, izip}; @@ -20,6 +23,7 @@ use parking_lot::RwLock; pub enum SymbolValue { I32(i32), I64(i64), + Str(String), Double(f64), Bool(bool), Tuple(Vec), @@ -109,7 +113,7 @@ pub trait SymbolResolver { } thread_local! { - static IDENTIFIER_ID: [StrRef; 8] = [ + static IDENTIFIER_ID: [StrRef; 10] = [ "int32".into(), "int64".into(), "float".into(), @@ -117,7 +121,9 @@ thread_local! { "None".into(), "virtual".into(), "list".into(), - "tuple".into() + "tuple".into(), + "str".into(), + "Exception".into(), ]; } @@ -139,6 +145,8 @@ pub fn parse_type_annotation( let virtual_id = ids[5]; let list_id = ids[6]; let tuple_id = ids[7]; + let str_id = ids[8]; + let exn_id = ids[9]; let name_handling = |id: &StrRef, unifier: &mut Unifier| { if *id == int32_id { @@ -151,6 +159,10 @@ pub fn parse_type_annotation( Ok(primitives.bool) } else if *id == none_id { Ok(primitives.none) + } else if *id == str_id { + Ok(primitives.str) + } else if *id == exn_id { + Ok(primitives.exception) } else { let obj_id = resolver.get_identifier_def(*id); if let Some(obj_id) = obj_id { @@ -179,8 +191,7 @@ pub fn parse_type_annotation( } } else { // it could be a type variable - let ty = resolver - .get_symbol_type(unifier, top_level_defs, primitives, *id)?; + let ty = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id)?; if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { Ok(ty) } else { @@ -192,35 +203,17 @@ pub fn parse_type_annotation( let subscript_name_handle = |id: &StrRef, slice: &Expr, unifier: &mut Unifier| { if *id == virtual_id { - let ty = parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - slice, - )?; + let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?; Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) } else if *id == list_id { - let ty = parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - slice, - )?; + let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?; Ok(unifier.add_ty(TypeEnum::TList { ty })) } else if *id == tuple_id { if let Tuple { elts, .. } = &slice.node { let ty = elts .iter() .map(|elt| { - parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - elt, - ) + parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt) }) .collect::, _>>()?; Ok(unifier.add_ty(TypeEnum::TTuple { ty })) @@ -231,23 +224,11 @@ pub fn parse_type_annotation( let types = if let Tuple { elts, .. } = &slice.node { elts.iter() .map(|v| { - parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - v, - ) + parse_type_annotation(resolver, top_level_defs, unifier, primitives, v) }) .collect::, _>>()? } else { - vec![parse_type_annotation( - resolver, - top_level_defs, - unifier, - primitives, - slice, - )?] + vec![parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?] }; let obj_id = resolver diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 7a331fd4..e8c84d71 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,6 +1,6 @@ use super::*; use crate::{ - codegen::{expr::destructure_range, irrt::calculate_len_for_slice_range}, + codegen::{expr::destructure_range, irrt::calculate_len_for_slice_range, stmt::exn_constructor}, symbol_resolver::SymbolValue, }; use inkwell::{FloatPredicate, IntPredicate}; @@ -21,6 +21,47 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let num_ty = primitives.1.get_fresh_var_with_range(&[int32, int64, float, boolean]); let var_map: HashMap<_, _> = vec![(num_ty.1, num_ty.0)].into_iter().collect(); + let exception_fields = vec![ + ("__name__".into(), int32, true), + ("__file__".into(), string, true), + ("__line__".into(), int32, true), + ("__col__".into(), int32, true), + ("__func__".into(), string, true), + ("__message__".into(), string, true), + ("__param0__".into(), int64, true), + ("__param1__".into(), int64, true), + ("__param2__".into(), int64, true), + ]; + let div_by_zero = primitives.1.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(10), + fields: RefCell::new(exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect()), + params: Default::default() + }); + let index_error = primitives.1.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(11), + fields: RefCell::new(exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect()), + params: Default::default() + }); + let exn_cons_args = vec![ + FuncArg { name: "msg".into(), ty: string, + default_value: Some(SymbolValue::Str("".into()))}, + FuncArg { name: "param0".into(), ty: int64, + default_value: Some(SymbolValue::I64(0))}, + FuncArg { name: "param1".into(), ty: int64, + default_value: Some(SymbolValue::I64(0))}, + FuncArg { name: "param2".into(), ty: int64, + default_value: Some(SymbolValue::I64(0))}, + ]; + let div_by_zero_signature = primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + args: exn_cons_args.clone(), + ret: div_by_zero, + vars: Default::default() + }))); + let index_error_signature = primitives.1.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + args: exn_cons_args, + ret: index_error, + vars: Default::default() + }))); let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( 0, @@ -49,6 +90,62 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { None, ))), Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(6, None, "str".into(), None))), + Arc::new(RwLock::new(TopLevelDef::Class { + name: "Exception".into(), + object_id: DefinitionId(7), + type_vars: Default::default(), + fields: exception_fields.clone(), + methods: Default::default(), + ancestors: vec![], + constructor: None, + resolver: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ZeroDivisionError.__init__".into(), + simple_name: "__init__".into(), + signature: div_by_zero_signature, + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))) + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "IndexError.__init__".into(), + simple_name: "__init__".into(), + signature: index_error_signature, + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))) + })), + Arc::new(RwLock::new(TopLevelDef::Class { + name: "ZeroDivisionError".into(), + object_id: DefinitionId(10), + type_vars: Default::default(), + fields: exception_fields.clone(), + methods: vec![("__init__".into(), div_by_zero_signature, DefinitionId(8))], + ancestors: vec![ + TypeAnnotation::CustomClass { id: DefinitionId(10), params: Default::default() }, + TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() } + ], + constructor: Some(div_by_zero_signature), + resolver: None, + })), + Arc::new(RwLock::new(TopLevelDef::Class { + name: "IndexError".into(), + object_id: DefinitionId(11), + type_vars: Default::default(), + fields: exception_fields, + methods: vec![("__init__".into(), index_error_signature, DefinitionId(9))], + ancestors: vec![ + TypeAnnotation::CustomClass { id: DefinitionId(11), params: Default::default() }, + TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() } + ], + constructor: Some(index_error_signature), + resolver: None, + })), Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), @@ -609,6 +706,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ( izip!(top_level_def_list, ast_list).collect_vec(), &[ + "ZeroDivisionError", + "IndexError", "int32", "int64", "float", diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index fb748fe2..762aaed0 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -4,7 +4,8 @@ use nac3parser::ast::fold::Fold; use crate::{ typecheck::type_inferencer::{FunctionData, Inferencer}, - codegen::expr::get_subst_key, + codegen::{expr::get_subst_key, stmt::exn_constructor}, + symbol_resolver::SymbolValue, }; use super::*; @@ -90,8 +91,13 @@ impl TopLevelComposer { assert!(name == *simple_name); builtin_ty.insert(name, *signature); builtin_id.insert(name, DefinitionId(id)); - } else { - unreachable!() + } else if let TopLevelDef::Class { name, constructor, object_id, type_vars, .. } = &*def { + assert!(id == object_id.0); + assert!(type_vars.is_empty()); + if let Some(constructor) = constructor { + builtin_ty.insert(*name, *constructor); + } + builtin_id.insert(*name, DefinitionId(id)); } } @@ -471,7 +477,6 @@ impl TopLevelComposer { let unifier = self.unifier.borrow_mut(); // first, only push direct parent into the list - // skip 5 to skip analyzing the primitives for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { let mut class_def = class_def.write(); let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { @@ -540,7 +545,6 @@ impl TopLevelComposer { // second, get all ancestors let mut ancestors_store: HashMap> = Default::default(); - // skip 5 to skip analyzing the primitives for (class_def, _) in self.definition_ast_list.iter().skip(self.builtin_num) { let class_def = class_def.read(); let (class_ancestors, class_id) = { @@ -562,8 +566,7 @@ impl TopLevelComposer { } // insert the ancestors to the def list - // skip 5 to skip analyzing the primitives - for (class_def, _) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { + for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { let mut class_def = class_def.write(); let (class_ancestors, class_id, class_type_vars) = { if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = @@ -581,6 +584,26 @@ impl TopLevelComposer { // insert self type annotation to the front of the vector to maintain the order class_ancestors .insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id)); + + // special case classes that inherit from Exception + if class_ancestors.iter().any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { + // if inherited from Exception, the body should be a pass + if let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node { + if body.len() != 1 || !matches!(body[0].node, ast::StmtKind::Pass { .. }) { + return Err("Classes inherited from exception should have `pass` as body".into()); + } + } else { + unreachable!() + } + } + } + + // deal with ancestor of Exception object + if let TopLevelDef::Class { name, ancestors, object_id, .. } = &mut *self.definition_ast_list[7].0.write() { + assert_eq!(*name, "Exception".into()); + ancestors.push(make_self_type_annotation(&[], *object_id)); + } else { + unreachable!(); } Ok(()) @@ -595,7 +618,6 @@ impl TopLevelComposer { let mut type_var_to_concrete_def: HashMap = HashMap::new(); - // skip 5 to skip analyzing the primitives for (class_def, class_ast) in def_ast_list.iter().skip(self.builtin_num) { if matches!(&*class_def.read(), TopLevelDef::Class { .. }) { Self::analyze_single_class_methods_fields( @@ -610,8 +632,6 @@ impl TopLevelComposer { } } - // println!("type_var_to_concrete_def1: {:?}", type_var_to_concrete_def); - // handle the inheritanced methods and fields let mut current_ancestor_depth: usize = 2; loop { @@ -646,19 +666,8 @@ impl TopLevelComposer { } } - // println!("type_var_to_concrete_def3: {:?}\n", type_var_to_concrete_def); - // unification of previously assigned typevar for (ty, def) in type_var_to_concrete_def { - // println!( - // "{:?}_{} -> {:?}\n", - // ty, - // unifier.stringify(ty, - // &mut |id| format!("class{}", id), - // &mut |id| format!("tvar{}", id) - // ), - // def - // ); let target_ty = get_type_from_type_annotation_kinds(&temp_def_list, unifier, primitives, &def)?; unifier.unify(ty, target_ty)?; @@ -946,7 +955,7 @@ impl TopLevelComposer { )) } } - + if name == &"__init__".into() && !defined_paramter_name.contains(&zelf) { return Err(format!("__init__ method must have a `self` parameter (at {})", b.location)); } @@ -1301,10 +1310,14 @@ impl TopLevelComposer { fn analyze_function_instance(&mut self) -> Result<(), String> { // first get the class contructor type correct for the following type check in function body // also do class field instantiation check - for (def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) { + let init_str_id = "__init__".into(); + let mut definition_extension = Vec::new(); + let mut constructors = Vec::new(); + for (i, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.builtin_num) { let class_def = def.read(); if let TopLevelDef::Class { constructor, + ancestors, methods, fields, type_vars, @@ -1314,13 +1327,54 @@ impl TopLevelComposer { .. } = &*class_def { + let self_type = get_type_from_type_annotation_kinds( + self.extract_def_list().as_slice(), + &mut self.unifier, + &self.primitives_ty, + &make_self_type_annotation(type_vars, *object_id), + )?; + if ancestors.iter().any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) { + // create constructor for these classes + let string = self.primitives_ty.str; + let int64 = self.primitives_ty.int64; + let signature = self.unifier.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature { + args: vec![ + FuncArg { name: "msg".into(), ty: string, + default_value: Some(SymbolValue::Str("".into()))}, + FuncArg { name: "param0".into(), ty: int64, + default_value: Some(SymbolValue::I64(0))}, + FuncArg { name: "param1".into(), ty: int64, + default_value: Some(SymbolValue::I64(0))}, + FuncArg { name: "param2".into(), ty: int64, + default_value: Some(SymbolValue::I64(0))}, + ], + ret: self_type, + vars: Default::default() + }))); + let cons_fun = TopLevelDef::Function { + name: format!("{}.{}", class_name, "__init__"), + simple_name: init_str_id, + signature, + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new(exn_constructor)))) + }; + constructors.push((i, signature, definition_extension.len())); + definition_extension.push((Arc::new(RwLock::new(cons_fun)), None)); + self.unifier + .unify(constructor.unwrap(), signature) + .map_err(|old| format!("{} (at {})", old, ast.as_ref().unwrap().location))?; + continue; + } let mut init_id: Option = None; // get the class contructor type correct let (contor_args, contor_type_vars) = { let mut constructor_args: Vec = Vec::new(); let mut type_vars: HashMap = HashMap::new(); for (name, func_sig, id) in methods { - if name == &"__init__".into() { + if *name == init_str_id { init_id = Some(*id); if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() { let FunSignature { args, vars, .. } = &*sig.borrow(); @@ -1333,12 +1387,6 @@ impl TopLevelComposer { } (constructor_args, type_vars) }; - let self_type = get_type_from_type_annotation_kinds( - self.extract_def_list().as_slice(), - &mut self.unifier, - &self.primitives_ty, - &make_self_type_annotation(type_vars, *object_id), - )?; let contor_type = self.unifier.add_ty(TypeEnum::TFunc( FunSignature { args: contor_args, ret: self_type, vars: contor_type_vars } .into(), @@ -1352,7 +1400,7 @@ impl TopLevelComposer { let init_ast = self.definition_ast_list.get(init_id.0).unwrap().1.as_ref().unwrap(); if let ast::StmtKind::FunctionDef { name, body, .. } = &init_ast.node { - if name != &"__init__".into() { + if *name != init_str_id { unreachable!("must be init function here") } let all_inited = Self::get_all_assigned_field(body.as_slice())?; @@ -1370,11 +1418,23 @@ impl TopLevelComposer { } } } + for (i, signature, id) in constructors.into_iter() { + if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() { + methods.push((init_str_id, signature, + DefinitionId(self.definition_ast_list.len() + id))); + } else { + unreachable!() + } + } + self.definition_ast_list.extend_from_slice(&definition_extension); let ctx = Arc::new(self.make_top_level_context()); // type inference inside function body for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.builtin_num) { + if ast.is_none() { + continue; + } let mut function_def = def.write(); if let TopLevelDef::Function { instance_to_stmt, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index dcc74dd1..6d33a339 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -90,7 +90,22 @@ impl TopLevelComposer { fields: HashMap::new().into(), params: HashMap::new().into(), }); - let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str }; + let exception = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(7), + fields: vec![ + ("__name__".into(), (int32, true)), + ("__file__".into(), (int32, true)), + ("__line__".into(), (int32, true)), + ("__col__".into(), (int32, true)), + ("__func__".into(), (str, true)), + ("__message__".into(), (str, true)), + ("__param0__".into(), (int64, true)), + ("__param1__".into(), (int64, true)), + ("__param2__".into(), (int64, true)), + ].into_iter().collect::>().into(), + params: HashMap::new().into(), + }); + let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception }; crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); (primitives, unifier) } @@ -381,6 +396,13 @@ impl TopLevelComposer { Some("int64".to_string()) } } + SymbolValue::Str(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.str) { + None + } else { + Some("str".to_string()) + } + } SymbolValue::Tuple(elts) => { if let TypeAnnotation::Tuple(elts_ty) = ty { for (e, t) in elts.iter().zip(elts_ty.iter()) { diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 8fb6c7f9..72bc59c6 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -63,6 +63,8 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::Primitive(primitives.none)) } else if id == &"str".into() { Ok(TypeAnnotation::Primitive(primitives.str)) + } else if id == &"Exception".into() { + Ok(TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() }) } else if let Some(obj_id) = resolver.get_identifier_def(*id) { let type_vars = { let def_read = top_level_defs[obj_id.0].try_read(); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 9a303749..87bf8235 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -37,6 +37,7 @@ pub struct PrimitiveStore { pub none: Type, pub range: Type, pub str: Type, + pub exception: Type, } pub struct FunctionData {