1
0
forked from M-Labs/nac3

nac3core: change the place to unify constructor type for function body type check

add really basic field initialize check
This commit is contained in:
ychenfo 2021-09-20 23:26:04 +08:00
parent dd1be541b8
commit e66693282c
4 changed files with 74 additions and 20 deletions

View File

@ -1160,13 +1160,69 @@ impl TopLevelComposer {
/// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function /// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function
fn analyze_function_instance(&mut self) -> Result<(), String> { fn analyze_function_instance(&mut self) -> Result<(), String> {
// first get the class contructor type correct for the following type check in function body
// also do class field instantiation check
for (def, _) in self.definition_ast_list.iter().skip(self.built_in_num) {
let class_def = def.read();
if let TopLevelDef::Class {
constructor,
methods,
fields,
type_vars,
name,
object_id,
resolver: _,
..
} = &*class_def
{
let mut has_init = false;
// get the class contructor type correct
let (contor_args, contor_type_vars) = {
let mut constructor_args: Vec<FuncArg> = Vec::new();
let mut type_vars: HashMap<u32, Type> = HashMap::new();
for (name, func_sig, ..) in methods {
if name == "__init__" {
has_init = true;
if let TypeEnum::TFunc(sig) = self.unifier.get_ty(*func_sig).as_ref() {
let FunSignature { args, vars, .. } = &*sig.borrow();
constructor_args.extend_from_slice(args);
type_vars.extend(vars);
} else {
unreachable!("must be typeenum::tfunc")
}
}
}
(constructor_args, type_vars)
};
let self_type = get_type_from_type_annotation_kinds(
self.extract_def_list().as_slice(),
&mut self.unifier,
&self.primitives_ty,
&make_self_type_annotation(type_vars, *object_id),
)?;
let contor_type = self.unifier.add_ty(TypeEnum::TFunc(
FunSignature { args: contor_args, ret: self_type, vars: contor_type_vars }
.into(),
));
self.unifier.unify(constructor.unwrap(), contor_type)?;
// class field instantiation check
// TODO: this is a really simple one, more check later
if !has_init && !fields.is_empty() {
return Err(format!("fields of class {} not fully initialized", name))
}
}
}
// type inference inside function body
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num) for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num)
{ {
let mut function_def = def.write(); let mut function_def = def.write();
if let TopLevelDef::Function { if let TopLevelDef::Function {
instance_to_stmt, instance_to_stmt,
name, name,
simple_name, simple_name: _,
signature, signature,
resolver, resolver,
.. ..
@ -1179,7 +1235,7 @@ impl TopLevelComposer {
if let Some(class_id) = self.method_class.get(&DefinitionId(id)) { if let Some(class_id) = self.method_class.get(&DefinitionId(id)) {
let class_def = self.definition_ast_list.get(class_id.0).unwrap(); let class_def = self.definition_ast_list.get(class_id.0).unwrap();
let class_def = class_def.0.read(); let class_def = class_def.0.read();
if let TopLevelDef::Class { type_vars, constructor, .. } = &*class_def { if let TopLevelDef::Class { type_vars, .. } = &*class_def {
let ty_ann = make_self_type_annotation(type_vars, *class_id); let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds( let self_ty = get_type_from_type_annotation_kinds(
self.extract_def_list().as_slice(), self.extract_def_list().as_slice(),
@ -1187,16 +1243,6 @@ impl TopLevelComposer {
&self.primitives_ty, &self.primitives_ty,
&ty_ann, &ty_ann,
)?; )?;
if simple_name == "__init__" {
let fn_type = self.unifier.add_ty(TypeEnum::TFunc(
RefCell::new(FunSignature {
args: args.clone(),
ret: self_ty,
vars: vars.clone(),
}),
));
self.unifier.unify(fn_type, constructor.unwrap())?;
}
Some(self_ty) Some(self_ty)
} else { } else {
unreachable!("must be class def") unreachable!("must be class def")

View File

@ -9,10 +9,11 @@ int output(int x) {
putchar('\n'); putchar('\n');
} else { } else {
if(x < strlen(chars)) { if(x < strlen(chars)) {
putchar(chars[x]); // putchar(chars[x]);
printf("%d\n", x);
} else { } else {
// printf("ERROR\n"); // printf("ERROR\n");
printf("%d", x); printf("%d\n", x);
} }
} }
return 0; return 0;

View File

@ -1,7 +1,9 @@
class A: class A:
a: int32 a: int32
b: B
def __init__(self, a: int32): def __init__(self, a: int32):
self.a = a self.a = a
self.b = B(a + 1)
def get_a(self) -> int32: def get_a(self) -> int32:
return self.a return self.a
@ -9,6 +11,15 @@ class A:
def get_self(self) -> A: def get_self(self) -> A:
return self return self
def get_b(self) -> B:
return self.b
class B:
b: int32
def __init__(self, a: int32):
self.b = a
def run() -> int32: def run() -> int32:
a = A(10) a = A(10)
output(a.a) output(a.a)
@ -16,5 +27,6 @@ def run() -> int32:
a = A(20) a = A(20)
output(a.a) output(a.a)
output(a.get_a()) output(a.get_a())
output(a.get_b().b)
return 0 return 0

View File

@ -5,7 +5,7 @@ use inkwell::{
OptimizationLevel, OptimizationLevel,
}; };
use nac3core::typecheck::type_inferencer::PrimitiveStore; use nac3core::typecheck::type_inferencer::PrimitiveStore;
use rustpython_parser::{parser, ast::StmtKind}; use rustpython_parser::parser;
use std::{collections::HashMap, path::Path, sync::Arc}; use std::{collections::HashMap, path::Path, sync::Arc};
use nac3core::{ use nac3core::{
@ -52,17 +52,12 @@ fn main() {
); );
for stmt in parser::parse_program(&program).unwrap().into_iter() { for stmt in parser::parse_program(&program).unwrap().into_iter() {
let is_class = matches!(stmt.node, StmtKind::ClassDef{ .. });
let (name, def_id, ty) = composer.register_top_level( let (name, def_id, ty) = composer.register_top_level(
stmt, stmt,
Some(resolver.clone()), Some(resolver.clone()),
"__main__".into(), "__main__".into(),
).unwrap(); ).unwrap();
if is_class {
internal_resolver.add_id_type(name.clone(), ty.unwrap());
}
internal_resolver.add_id_def(name.clone(), def_id); internal_resolver.add_id_def(name.clone(), def_id);
if let Some(ty) = ty { if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty); internal_resolver.add_id_type(name, ty);