forked from M-Labs/nac3
1
0
Fork 0

Calling class methods statically implemented

This commit is contained in:
= 2024-07-28 23:53:58 +08:00
parent 44487b76ae
commit ab916df19e
10 changed files with 342 additions and 61 deletions

View File

@ -161,7 +161,9 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -1460,6 +1460,7 @@ impl SymbolResolver for Resolver {
id: StrRef, id: StrRef,
_: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
println!("dc");
let sym_value = { let sym_value = {
let id_to_val = self.0.id_to_pyval.read(); let id_to_val = self.0.id_to_pyval.read();
id_to_val.get(&id).cloned() id_to_val.get(&id).cloned()

View File

@ -1,3 +1,4 @@
use core::panic;
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{ use crate::{
@ -51,8 +52,46 @@ pub fn get_subst_key(
) -> String { ) -> String {
let mut vars = obj let mut vars = obj
.map(|ty| { .map(|ty| {
let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; // let (id, fun_id) = match &*ctx.unifier.get_ty(value.custom.unwrap()) {
params.clone() // TypeEnum::TObj { obj_id, .. } => {
// let fun_id = {
// let defs = ctx.top_level.definitions.read();
// let obj_def = defs.get(obj_id.0).unwrap().read();
// let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() };
// methods.iter().find(|method| method.0 == *attr).unwrap().2
// };
// (*obj_id, fun_id)
// }
// TypeEnum::TFunc(sign) => {
// let defs = ctx.top_level.definitions.read();
// let res = defs.iter().find_map(|def| {
// if let TopLevelDef::Class {object_id, methods, name, .. } = &*def.read() {
// if *name == ctx.unifier.stringify(sign.ret).into() {
// return Some((*object_id, methods.iter().find(|method| method.0 == *attr).unwrap().2))
// }
// }
// None
// }).unwrap();
// res
// // unreachable!()
// }
// _ => unreachable!()
// };
match &*unifier.get_ty(ty) {
TypeEnum::TObj { params, .. } => params.clone(),
TypeEnum::TFunc(sign) => {
let zelf = sign.args.iter().next().unwrap();
let TypeEnum::TObj { params, .. } = &*unifier.get_ty(zelf.ty) else {
unreachable!()
};
params.clone()
}
_ => unreachable!()
}
// let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() };
// params.clone()
}) })
.unwrap_or_default(); .unwrap_or_default();
vars.extend(fun_vars); vars.extend(fun_vars);
@ -932,7 +971,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
} }
}) })
.collect_vec(); .collect_vec();
println!("FUnction Val: {:?}", fun_val);
Ok(ctx.build_call_or_invoke(fun_val, &param_vals, "call")) Ok(ctx.build_call_or_invoke(fun_val, &param_vals, "call"))
} }
@ -2456,10 +2495,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()), Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
None => { None => {
println!("{}", id);
let resolver = ctx.resolver.clone(); let resolver = ctx.resolver.clone();
if let Some(res) = resolver.get_symbol_value(*id, ctx) { if let Some(res) = resolver.get_symbol_value(*id, ctx) {
res res
} else { } else {
println!("Rnter Else Block");
// Allow "raise Exception" short form // Allow "raise Exception" short form
let def_id = resolver.get_identifier_def(*id).map_err(|e| { let def_id = resolver.get_identifier_def(*id).map_err(|e| {
format!("{} (at {})", e.iter().next().unwrap(), expr.location) format!("{} (at {})", e.iter().next().unwrap(), expr.location)
@ -2792,6 +2833,23 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
} }
ExprKind::Call { func, args, keywords } => { ExprKind::Call { func, args, keywords } => {
let mut args = args.clone();
let zelf = {
if let Some(arg) = args.get(0) {
if let ExprKind::Name { id, .. } = &arg.node {
if *id == "self".into() {
Some(args.remove(0))
} else {
None
}
} else {
None
}
} else {
None
}
};
let mut params = args let mut params = args
.iter() .iter()
.map(|arg| generator.gen_expr(ctx, arg)) .map(|arg| generator.gen_expr(ctx, arg))
@ -2802,7 +2860,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
if params.len() < args.len() { if params.len() < args.len() {
return Ok(None); return Ok(None);
} }
println!("{}, {}", params.len(), args.len());
let kw_iter = keywords.iter().map(|kw| { let kw_iter = keywords.iter().map(|kw| {
Ok(( Ok((
Some(*kw.node.arg.as_ref().unwrap()), Some(*kw.node.arg.as_ref().unwrap()),
@ -2835,20 +2893,33 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) };
// Handle Class Method calls // Handle Class Method calls
let id = if let TypeEnum::TObj { obj_id, .. } = let (id, fun_id) = match &*ctx.unifier.get_ty(value.custom.unwrap()) {
&*ctx.unifier.get_ty(value.custom.unwrap()) TypeEnum::TObj { obj_id, .. } => {
{ let fun_id = {
*obj_id let defs = ctx.top_level.definitions.read();
} else { let obj_def = defs.get(obj_id.0).unwrap().read();
unreachable!() let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() };
};
let fun_id = {
let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read();
let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() };
methods.iter().find(|method| method.0 == *attr).unwrap().2 methods.iter().find(|method| method.0 == *attr).unwrap().2
};
(*obj_id, fun_id)
}
TypeEnum::TFunc(sign) => {
let defs = ctx.top_level.definitions.read();
let res = defs.iter().find_map(|def| {
if let TopLevelDef::Class {object_id, methods, name, .. } = &*def.read() {
if *name == ctx.unifier.stringify(sign.ret).into() {
return Some((*object_id, methods.iter().find(|method| method.0 == *attr).unwrap().2))
}
}
None
}).unwrap();
res
// unreachable!()
}
_ => unreachable!()
}; };
// directly generate code for option.unwrap // directly generate code for option.unwrap
// since it needs to return static value to optimize for kernel invariant // since it needs to return static value to optimize for kernel invariant
if attr == &"unwrap".into() if attr == &"unwrap".into()
@ -2923,10 +2994,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
// Reset current_loc back to the location of the call // Reset current_loc back to the location of the call
ctx.current_loc = expr.location; ctx.current_loc = expr.location;
let obj_id = match zelf {
Some(arg) => arg.custom.unwrap(),
None => value.custom.unwrap()
};
return Ok(generator return Ok(generator
.gen_call( .gen_call(
ctx, ctx,
Some((value.custom.unwrap(), val)), Some((obj_id, val)),
(&signature, fun_id), (&signature, fun_id),
params, params,
)? )?

View File

@ -23,7 +23,7 @@ impl Default for ComposerConfig {
} }
} }
type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>); pub type DefAst = (Arc<RwLock<TopLevelDef>>, Option<Stmt<()>>);
pub struct TopLevelComposer { pub struct TopLevelComposer {
// list of top level definitions, same as top level context // list of top level definitions, same as top level context
pub definition_ast_list: Vec<DefAst>, pub definition_ast_list: Vec<DefAst>,
@ -1723,7 +1723,13 @@ impl TopLevelComposer {
if *name != init_str_id { if *name != init_str_id {
unreachable!("must be init function here") unreachable!("must be init function here")
} }
let all_inited = Self::get_all_assigned_field(body.as_slice())?; // let all_inited = Self::get_all_assigned_field(body.as_slice())?;
let all_inited = Self::get_all_assigned_field(
definition_ast_list,
def,
body.as_slice(),
)?;
for (f, _, _) in fields { for (f, _, _) in fields {
if !all_inited.contains(f) { if !all_inited.contains(f) {
return Err(HashSet::from([ return Err(HashSet::from([

View File

@ -3,6 +3,7 @@ use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap}; use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
use ast::ExprKind;
use nac3parser::ast::{Constant, Location}; use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use strum_macros::EnumIter; use strum_macros::EnumIter;
@ -677,7 +678,11 @@ impl TopLevelComposer {
) )
} }
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> { pub fn get_all_assigned_field(
definition_ast_list: &Vec<DefAst>,
def: &Arc<RwLock<TopLevelDef>>,
stmts: &[Stmt<()>],
) -> Result<HashSet<StrRef>, HashSet<String>> {
let mut result = HashSet::new(); let mut result = HashSet::new();
for s in stmts { for s in stmts {
match &s.node { match &s.node {
@ -713,32 +718,151 @@ impl TopLevelComposer {
// TODO: do not check for For and While? // TODO: do not check for For and While?
ast::StmtKind::For { body, orelse, .. } ast::StmtKind::For { body, orelse, .. }
| ast::StmtKind::While { body, orelse, .. } => { | ast::StmtKind::While { body, orelse, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?); result.extend(Self::get_all_assigned_field(
result.extend(Self::get_all_assigned_field(orelse.as_slice())?); definition_ast_list,
def,
body.as_slice(),
)?);
result.extend(Self::get_all_assigned_field(
definition_ast_list,
def,
orelse.as_slice(),
)?);
} }
ast::StmtKind::If { body, orelse, .. } => { ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? let inited_for_sure =
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?) Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
.copied() .intersection(&Self::get_all_assigned_field(
.collect::<HashSet<_>>(); definition_ast_list,
def,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
} }
ast::StmtKind::Try { body, orelse, finalbody, .. } => { ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? let inited_for_sure =
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?) Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())?
.copied() .intersection(&Self::get_all_assigned_field(
.collect::<HashSet<_>>(); definition_ast_list,
def,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure); result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); result.extend(Self::get_all_assigned_field(
definition_ast_list,
def,
finalbody.as_slice(),
)?);
} }
ast::StmtKind::With { body, .. } => { ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field(body.as_slice())?); result.extend(Self::get_all_assigned_field(
definition_ast_list,
def,
body.as_slice(),
)?);
} }
ast::StmtKind::Pass { .. } // If its a call to __init__function of ancestor extend with ancestor fields
| ast::StmtKind::Assert { .. } ast::StmtKind::Expr { value, .. } => {
| ast::StmtKind::Expr { .. } => {} // Check if Expression is a function call to self
if let ExprKind::Call { func, args, .. } = &value.node {
if let ExprKind::Attribute { value, attr: fn_name, .. } = &func.node {
let class_def = def.read();
let (ancestors, methods) = {
let mut class_methods: HashMap<StrRef, DefinitionId> =
HashMap::new();
let mut class_ancestors: HashMap<
StrRef,
HashMap<StrRef, DefinitionId>,
> = HashMap::new();
if let TopLevelDef::Class { methods, ancestors, .. } = &*class_def {
for m in methods {
class_methods.insert(m.0, m.2);
}
ancestors.iter().skip(1).for_each(|a| {
if let TypeAnnotation::CustomClass { id, .. } = a {
let anc_def =
definition_ast_list.get(id.0).unwrap().0.read();
if let TopLevelDef::Class { name, methods, .. } =
&*anc_def
{
let mut temp: HashMap<StrRef, DefinitionId> =
HashMap::new();
for m in methods {
temp.insert(m.0, m.2);
}
// Remove module name suffix from name
let mut name_string = name.to_string();
let split_loc =
name_string.find(|c| c == '.').unwrap() + 1;
class_ancestors.insert(
name_string.split_off(split_loc).into(),
temp,
);
}
}
});
}
(class_ancestors, class_methods)
};
if let ExprKind::Name { id, .. } = value.node {
if id == "self".into() {
// Get Class methods and fields
let method_id = methods.get(fn_name);
if method_id.is_some() {
if let Some(fn_ast) = &definition_ast_list
.get(method_id.unwrap().0)
.unwrap()
.1
{
if let ast::StmtKind::FunctionDef { body, .. } =
&fn_ast.node
{
result.extend(Self::get_all_assigned_field(
definition_ast_list,
def,
body.as_slice(),
)?);
}
}
}
} else if let Some(ancestor_methods) = ancestors.get(&id) {
// First arg must be `self` when calling ancestor function
if let ExprKind::Name { id, .. } = args[0].node {
if id == "self".into() {
if let Some(method_id) = ancestor_methods.get(fn_name) {
if let Some(fn_ast) =
&definition_ast_list.get(method_id.0).unwrap().1
{
if let ast::StmtKind::FunctionDef {
body, ..
} = &fn_ast.node
{
result.extend(
Self::get_all_assigned_field(
definition_ast_list,
def,
body.as_slice(),
)?,
);
}
}
};
}
}
}
}
}
}
}
ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } => {}
_ => { _ => {
println!("{:?}", s.node);
unimplemented!() unimplemented!()
} }
} }

View File

@ -581,6 +581,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
ExprKind::Attribute { value, attr, ctx } => { ExprKind::Attribute { value, attr, ctx } => {
println!("Attr Called");
Some(self.infer_attribute(value, *attr, *ctx)?) Some(self.infer_attribute(value, *attr, *ctx)?)
} }
ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
@ -1513,14 +1514,54 @@ impl<'a> Inferencer<'a> {
mut args: Vec<ast::Expr<()>>, mut args: Vec<ast::Expr<()>>,
keywords: Vec<Located<ast::KeywordData>>, keywords: Vec<Located<ast::KeywordData>>,
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> { ) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
println!("{:?}", func);
if let Some(spec_call_func) = if let Some(spec_call_func) =
self.try_fold_special_call(location, &func, &mut args, &keywords)? self.try_fold_special_call(location, &func, &mut args, &keywords)?
{ {
return Ok(spec_call_func); return Ok(spec_call_func);
} }
let func = Box::new(self.fold_expr(func)?); println!("Trying Args");
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?; let mut args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
let (func, arg_self) = {
if let Some(arg) = args.iter().next() {
if let ExprKind::Name { id, .. } = arg.node {
if id == "self".into() {
// args.remove(0);
let expr = match func.node {
ExprKind::Call { func, args, keywords } => {
return self.fold_call(func.location, *func, args, keywords);
}
_ => fold::fold_expr(self, func.clone())?,
};
let ExprKind::Attribute { value, attr, ctx } = &expr.node else {
return report_error("Unsupported Statement", location);
};
let ty = value.custom.unwrap();
let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) else {
return report_error("Unsupported Statement", location);
};
// Check for ancestors methods
(Box::new(self.fold_expr(func)?), Some(args.remove(0)))
} else {
(Box::new(self.fold_expr(func)?), None)
}
} else {
(Box::new(self.fold_expr(func)?), None)
}
} else {
(Box::new(self.fold_expr(func)?), None)
}
};
// let func = Box::new(self.fold_expr(func)?);
println!("Failed");
let keywords = keywords let keywords = keywords
.into_iter() .into_iter()
.map(|v| fold::fold_keyword(self, v)) .map(|v| fold::fold_keyword(self, v))
@ -1539,9 +1580,14 @@ impl<'a> Inferencer<'a> {
loc: Some(location), loc: Some(location),
operator_info: None, operator_info: None,
}; };
println!("Try Unigu");
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
})?; })?;
if let Some(arg) = arg_self {
args.insert(0, arg);
}
println!("Reutrnin");
return Ok(Located { return Ok(Located {
location, location,
custom: Some(sign.ret), custom: Some(sign.ret),
@ -1665,10 +1711,11 @@ impl<'a> Inferencer<'a> {
ctx: ExprContext, ctx: ExprContext,
) -> InferenceResult { ) -> InferenceResult {
let ty = value.custom.unwrap(); let ty = value.custom.unwrap();
println!("{:?}", value);
if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) { if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) {
// just a fast path // just a fast path
match (fields.get(&attr), ctx == ExprContext::Store) { match (fields.get(&attr), ctx == ExprContext::Store) {
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), (Some((ty, true)), _) | (Some((ty, false)), false) => {println!("Returning"); Ok(*ty)},
(Some((ty, false)), true) => report_type_error( (Some((ty, false)), true) => report_type_error(
TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), TypeErrorKind::MutationError(RecordKey::Str(attr), *ty),
Some(value.location), Some(value.location),
@ -1705,12 +1752,15 @@ impl<'a> Inferencer<'a> {
} }
} else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { } else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
// Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1
// Remember to restore p[lz]
println!("jwicef\n");
let result = { let result = {
self.top_level.definitions.read().iter().find_map(|def| { self.top_level.definitions.read().iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() { if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { if let TopLevelDef::Class { name, methods, .. } = &*rear_guard {
if name.to_string() == self.unifier.stringify(sign.ret) { if name.to_string() == self.unifier.stringify(sign.ret) {
return attributes.iter().find_map(|f| { return methods.iter().find_map(|f| {
if f.0 == attr { if f.0 == attr {
return Some(f.clone().1); return Some(f.clone().1);
} }
@ -1730,6 +1780,7 @@ impl<'a> Inferencer<'a> {
None => self.infer_general_attribute(value, attr, ctx), None => self.infer_general_attribute(value, attr, ctx),
} }
} else { } else {
println!("ncfe\n");
self.infer_general_attribute(value, attr, ctx) self.infer_general_attribute(value, attr, ctx)
} }
} }

View File

View File

@ -4,29 +4,51 @@ from __future__ import annotations
def output_int32(x: int32): def output_int32(x: int32):
... ...
class A: class C:
c: int32
a: int32 a: int32
def __init__(self, a: int32):
self.a = a
def f1(self):
self.f2()
def f2(self):
output_int32(self.a)
class B(A):
b: int32 b: int32
def __init__(self):
self.a = 42
self.b = 33
self.c = 12
def test2(self):
output_int32(999)
output_int32(self.a)
output_int32(self.b)
output_int32(self.c)
self.a = 23
class D(C):
def __init__(self):
# C.__init__(self)
self.test()
self.b = 1
self.c = 2
C.test2(self)
#self.a()
# self.test()
# C.test2(self)
# self.a = 2
# __main__.C.__init__(self)
def test(self):
self.a = 2
def __init__(self, b: int32):
self.a = b + 1
self.b = b
def run() -> int32: def run() -> int32:
aaa = A(5) x = D()
bbb = B(2) output_int32(x.a)
aaa.f1() output_int32(x.b)
bbb.f1() output_int32(x.c)
# aaa = A(5)
# bbb = B(2)
# aaa.f1()
# bbb.f1()
return 0 return 0

View File

@ -59,7 +59,7 @@ impl SymbolResolver for Resolver {
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() None
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {

BIN
pyo3_output/nac3artiq.so Executable file

Binary file not shown.