forked from M-Labs/nac3
Handle polymorphism as special calls
This commit is contained in:
parent
7e3d87f841
commit
4069852503
@ -23,7 +23,7 @@ impl Default for ComposerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
|
||||
pub type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
|
||||
pub struct TopLevelComposer {
|
||||
// list of top level definitions, same as top level context
|
||||
pub definition_ast_list: Vec<DefAst>,
|
||||
@ -1801,7 +1801,12 @@ impl TopLevelComposer {
|
||||
if *name != init_str_id {
|
||||
unreachable!("must be init function here")
|
||||
}
|
||||
let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
||||
// let all_inited = Self::get_all_assigned_field(body.as_slice())?;
|
||||
let all_inited = Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?;
|
||||
for (f, _, _) in fields {
|
||||
if !all_inited.contains(f) {
|
||||
return Err(HashSet::from([
|
||||
|
@ -3,6 +3,7 @@ use std::convert::TryInto;
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
|
||||
use ast::ExprKind;
|
||||
use nac3parser::ast::{Constant, Location};
|
||||
use strum::IntoEnumIterator;
|
||||
use strum_macros::EnumIter;
|
||||
@ -732,8 +733,11 @@ impl TopLevelComposer {
|
||||
unifier,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> {
|
||||
pub fn get_all_assigned_field(
|
||||
definition_ast_list: &Vec<DefAst>,
|
||||
def: &Arc<RwLock<TopLevelDef>>,
|
||||
stmts: &[Stmt<()>],
|
||||
) -> Result<HashSet<StrRef>, HashSet<String>> {
|
||||
let mut result = HashSet::new();
|
||||
for s in stmts {
|
||||
match &s.node {
|
||||
@ -769,32 +773,151 @@ impl TopLevelComposer {
|
||||
// TODO: do not check for For and While?
|
||||
ast::StmtKind::For { body, orelse, .. }
|
||||
| ast::StmtKind::While { body, orelse, .. } => {
|
||||
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
orelse.as_slice(),
|
||||
)?);
|
||||
}
|
||||
ast::StmtKind::If { body, orelse, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
let inited_for_sure =
|
||||
Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
orelse.as_slice(),
|
||||
)?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
result.extend(inited_for_sure);
|
||||
}
|
||||
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
let inited_for_sure =
|
||||
Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
orelse.as_slice(),
|
||||
)?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
result.extend(inited_for_sure);
|
||||
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
finalbody.as_slice(),
|
||||
)?);
|
||||
}
|
||||
ast::StmtKind::With { body, .. } => {
|
||||
result.extend(Self::get_all_assigned_field(body.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?);
|
||||
}
|
||||
ast::StmtKind::Pass { .. }
|
||||
| ast::StmtKind::Assert { .. }
|
||||
| ast::StmtKind::Expr { .. } => {}
|
||||
// If its a call to __init__function of ancestor extend with ancestor fields
|
||||
ast::StmtKind::Expr { value, .. } => {
|
||||
// Check if Expression is a function call to self
|
||||
if let ExprKind::Call { func, args, .. } = &value.node {
|
||||
if let ExprKind::Attribute { value, attr: fn_name, .. } = &func.node {
|
||||
let class_def = def.read();
|
||||
let (ancestors, methods) = {
|
||||
let mut class_methods: HashMap<StrRef, DefinitionId> =
|
||||
HashMap::new();
|
||||
let mut class_ancestors: HashMap<
|
||||
StrRef,
|
||||
HashMap<StrRef, DefinitionId>,
|
||||
> = HashMap::new();
|
||||
|
||||
if let TopLevelDef::Class { methods, ancestors, .. } = &*class_def {
|
||||
for m in methods {
|
||||
class_methods.insert(m.0, m.2);
|
||||
}
|
||||
ancestors.iter().skip(1).for_each(|a| {
|
||||
if let TypeAnnotation::CustomClass { id, .. } = a {
|
||||
let anc_def =
|
||||
definition_ast_list.get(id.0).unwrap().0.read();
|
||||
if let TopLevelDef::Class { name, methods, .. } =
|
||||
&*anc_def
|
||||
{
|
||||
let mut temp: HashMap<StrRef, DefinitionId> =
|
||||
HashMap::new();
|
||||
for m in methods {
|
||||
temp.insert(m.0, m.2);
|
||||
}
|
||||
// Remove module name suffix from name
|
||||
let mut name_string = name.to_string();
|
||||
let split_loc =
|
||||
name_string.find(|c| c == '.').unwrap() + 1;
|
||||
class_ancestors.insert(
|
||||
name_string.split_off(split_loc).into(),
|
||||
temp,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
(class_ancestors, class_methods)
|
||||
};
|
||||
if let ExprKind::Name { id, .. } = value.node {
|
||||
if id == "self".into() {
|
||||
// Get Class methods and fields
|
||||
let method_id = methods.get(fn_name);
|
||||
if method_id.is_some() {
|
||||
if let Some(fn_ast) = &definition_ast_list
|
||||
.get(method_id.unwrap().0)
|
||||
.unwrap()
|
||||
.1
|
||||
{
|
||||
if let ast::StmtKind::FunctionDef { body, .. } =
|
||||
&fn_ast.node
|
||||
{
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(ancestor_methods) = ancestors.get(&id) {
|
||||
// First arg must be `self` when calling ancestor function
|
||||
if let ExprKind::Name { id, .. } = args[0].node {
|
||||
if id == "self".into() {
|
||||
if let Some(method_id) = ancestor_methods.get(fn_name) {
|
||||
if let Some(fn_ast) =
|
||||
&definition_ast_list.get(method_id.0).unwrap().1
|
||||
{
|
||||
if let ast::StmtKind::FunctionDef {
|
||||
body, ..
|
||||
} = &fn_ast.node
|
||||
{
|
||||
result.extend(
|
||||
Self::get_all_assigned_field(
|
||||
definition_ast_list,
|
||||
def,
|
||||
body.as_slice(),
|
||||
)?,
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } => {}
|
||||
|
||||
_ => {
|
||||
println!("{:?}", s.node);
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ use super::{
|
||||
RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap,
|
||||
},
|
||||
};
|
||||
use crate::toplevel::type_annotation::TypeAnnotation;
|
||||
use crate::{
|
||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||
toplevel::{
|
||||
@ -1029,7 +1030,97 @@ impl<'a> Inferencer<'a> {
|
||||
keywords: &[Located<ast::KeywordData>],
|
||||
) -> Result<Option<ast::Expr<Option<Type>>>, InferenceError> {
|
||||
let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
|
||||
return Ok(None);
|
||||
// Must have self as input
|
||||
if args.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let Located { node: ExprKind::Attribute { value, attr: method_name, ctx }, .. } = func
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
let ExprKind::Name { id: class_name, .. } = &value.node else { return Ok(None) };
|
||||
|
||||
// Check whether first param is self
|
||||
let first_arg = args.remove(0);
|
||||
let Located { node: ExprKind::Name { id: param_name, .. }, .. } = first_arg else {
|
||||
return Ok(None);
|
||||
};
|
||||
if param_name != "self".into() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Get Method from ancestors
|
||||
let zelf = &self.fold_expr(first_arg)?;
|
||||
let def_id = self.unifier.get_ty(zelf.custom.unwrap());
|
||||
let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() };
|
||||
let defs = self.top_level.definitions.read();
|
||||
let result = {
|
||||
if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() {
|
||||
ancestors.iter().find_map(|f| {
|
||||
println!("{}", f.stringify(self.unifier));
|
||||
let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() };
|
||||
let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read() else {
|
||||
unreachable!()
|
||||
};
|
||||
let name = name.to_string();
|
||||
let (_, name) = name.split_once('.').unwrap();
|
||||
println!("Comparing against => {name}, {class_name}");
|
||||
if name == class_name.to_string() {
|
||||
return methods.iter().find_map(|f| {
|
||||
if f.0 == *method_name {
|
||||
return Some(f.1);
|
||||
}
|
||||
None
|
||||
});
|
||||
}
|
||||
None
|
||||
})
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(result) else { return Ok(None) };
|
||||
|
||||
let args = args
|
||||
.iter_mut()
|
||||
.map(|v| self.fold_expr(v.clone()))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
// args: vec![FuncArg {
|
||||
// name: "n".into(),
|
||||
// ty: arg0.custom.unwrap(),
|
||||
// default_value: None,
|
||||
// is_vararg: false,
|
||||
// }],
|
||||
// ret,
|
||||
// vars: VarMap::new(),
|
||||
// }));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(sign.ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(result),
|
||||
location: func.location,
|
||||
node: ExprKind::Attribute {
|
||||
value: Box::new(Located {
|
||||
location: func.location,
|
||||
custom: zelf.custom,
|
||||
node: ExprKind::Name { id: *class_name, ctx: *ctx },
|
||||
}),
|
||||
attr: *method_name,
|
||||
ctx: *ctx,
|
||||
},
|
||||
}),
|
||||
args,
|
||||
keywords: vec![],
|
||||
},
|
||||
}));
|
||||
};
|
||||
|
||||
// handle special functions that cannot be typed in the usual way...
|
||||
@ -1631,13 +1722,85 @@ impl<'a> Inferencer<'a> {
|
||||
return Ok(spec_call_func);
|
||||
}
|
||||
|
||||
let func = Box::new(self.fold_expr(func)?);
|
||||
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
|
||||
let keywords = keywords
|
||||
.into_iter()
|
||||
.map(|v| fold::fold_keyword(self, v))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
println!("===============================");
|
||||
println!("=======Printing Func details=======");
|
||||
println!("Fun Location => {}", func.location);
|
||||
println!("Fun Node => {}", func.node.name());
|
||||
println!("Fun Args => {}", args.len());
|
||||
if !args.is_empty() {
|
||||
println!("First ArgNode => {}", args[0].node.name());
|
||||
}
|
||||
|
||||
if let ExprKind::Attribute { value, attr, .. } = &func.node {
|
||||
println!("Function Attributes");
|
||||
println!("Attr Name => {}", attr);
|
||||
println!("Value node => {}", value.node.name());
|
||||
if let ExprKind::Name { id: class_id, .. } = value.node {
|
||||
println!("Value Node ID => {class_id}");
|
||||
|
||||
// This ID is the parent class name
|
||||
// Resolve definition of class from self and get the ancestor list
|
||||
|
||||
let zelf = &self.fold_expr(args[0].clone()).unwrap();
|
||||
println!("Unification Key => {}", self.unifier.stringify(zelf.custom.unwrap()));
|
||||
let def_id = self.unifier.get_ty(zelf.custom.unwrap());
|
||||
let TypeEnum::TObj { obj_id, .. } = def_id.as_ref() else { unreachable!() };
|
||||
let defs = self.top_level.definitions.read();
|
||||
let result = {
|
||||
if let TopLevelDef::Class { ancestors, .. } = &*defs[obj_id.0].read() {
|
||||
ancestors.iter().find_map(|f| {
|
||||
let TypeAnnotation::CustomClass { id, .. } = f else { unreachable!() };
|
||||
let TopLevelDef::Class { name, methods, .. } = &*defs[id.0].read()
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
let name = name.to_string();
|
||||
let (_, name) = name.split_once('.').unwrap();
|
||||
println!("Comparing against => {name}, {class_id}");
|
||||
if name == class_id.to_string() {
|
||||
return methods.iter().find_map(|f| {
|
||||
if f.0 == *attr {
|
||||
return Some(f.1);
|
||||
}
|
||||
None
|
||||
});
|
||||
}
|
||||
None
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
println!("Function in Selected Parent Class");
|
||||
// Construct new call add type checking later if it works
|
||||
let args = args
|
||||
.iter()
|
||||
.map(|v| self.fold_expr(v.clone()))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
// let func = Box::new(self.fold_expr(func.clone()).unwrap());
|
||||
// let ty = self.unifier.get_ty(result);
|
||||
println!("Function Type => {}", self.unifier.stringify(result));
|
||||
|
||||
// Now I have the unification key of the call
|
||||
// and vars for the call
|
||||
// Need to make call
|
||||
// Use special case for ref
|
||||
|
||||
// let expr = ExprKind::Attribute { value: (), attr: (), ctx: () }
|
||||
println!("======================");
|
||||
}
|
||||
}
|
||||
println!("=======Ending Func details=======");
|
||||
|
||||
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
|
||||
let func = Box::new(self.fold_expr(func)?);
|
||||
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) {
|
||||
if sign.vars.is_empty() {
|
||||
let call = Call {
|
||||
|
0
nac3standalone/demo/interpreted.log
Normal file
0
nac3standalone/demo/interpreted.log
Normal file
@ -6,27 +6,31 @@ def output_int32(x: int32):
|
||||
|
||||
class A:
|
||||
a: int32
|
||||
|
||||
def __init__(self, a: int32):
|
||||
self.a = a
|
||||
def __init__(self, val: int32):
|
||||
self.a = val
|
||||
# self.f1()
|
||||
|
||||
def f1(self):
|
||||
self.f2()
|
||||
|
||||
def f2(self):
|
||||
output_int32(self.a)
|
||||
|
||||
class B(A):
|
||||
b: int32
|
||||
|
||||
def __init__(self, b: int32):
|
||||
self.a = b + 1
|
||||
self.b = b
|
||||
|
||||
def __init__(self, val1: int32, val2: int32):
|
||||
A.__init__(self, val1)
|
||||
self.b = val2
|
||||
|
||||
def f2(self):
|
||||
# A.f1(self)
|
||||
output_int32(self.b)
|
||||
|
||||
def run() -> int32:
|
||||
aaa = A(5)
|
||||
bbb = B(2)
|
||||
aaa.f1()
|
||||
bbb.f1()
|
||||
c1 = B(2, 4)
|
||||
# c1.f2()
|
||||
|
||||
|
||||
|
||||
# aaa = A(5)
|
||||
# bbb = B(2)
|
||||
# aaa.f1()
|
||||
# bbb.f1()
|
||||
return 0
|
||||
|
@ -59,7 +59,7 @@ impl SymbolResolver for Resolver {
|
||||
_: StrRef,
|
||||
_: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Option<ValueEnum<'ctx>> {
|
||||
unimplemented!()
|
||||
None
|
||||
}
|
||||
|
||||
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
|
||||
|
Loading…
Reference in New Issue
Block a user