diff --git a/flake.nix b/flake.nix index 4febca24..7bd28c70 100644 --- a/flake.nix +++ b/flake.nix @@ -161,7 +161,9 @@ clippy pre-commit rustfmt + rust-analyzer ]; + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; }; devShells.x86_64-linux.msys2 = pkgs.mkShell { name = "nac3-dev-shell-msys2"; diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 2e998ff3..376ead60 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1460,6 +1460,7 @@ impl SymbolResolver for Resolver { id: StrRef, _: &mut CodeGenContext<'ctx, '_>, ) -> Option> { + println!("dc"); let sym_value = { let id_to_val = self.0.id_to_pyval.read(); id_to_val.get(&id).cloned() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 38ac9a63..ad8581c9 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1,3 +1,4 @@ +use core::panic; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use crate::{ @@ -51,8 +52,46 @@ pub fn get_subst_key( ) -> String { let mut vars = obj .map(|ty| { - let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; - params.clone() + // let (id, fun_id) = match &*ctx.unifier.get_ty(value.custom.unwrap()) { + // TypeEnum::TObj { obj_id, .. } => { + // let fun_id = { + // let defs = ctx.top_level.definitions.read(); + // let obj_def = defs.get(obj_id.0).unwrap().read(); + // let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; + + // methods.iter().find(|method| method.0 == *attr).unwrap().2 + // }; + // (*obj_id, fun_id) + // } + // TypeEnum::TFunc(sign) => { + // let defs = ctx.top_level.definitions.read(); + // let res = defs.iter().find_map(|def| { + // if let TopLevelDef::Class {object_id, methods, name, .. } = &*def.read() { + // if *name == ctx.unifier.stringify(sign.ret).into() { + // return Some((*object_id, methods.iter().find(|method| method.0 == *attr).unwrap().2)) + // } + // } + // None + // }).unwrap(); + // res + // // unreachable!() + // } + // _ => unreachable!() + // }; + match &*unifier.get_ty(ty) { + TypeEnum::TObj { params, .. } => params.clone(), + TypeEnum::TFunc(sign) => { + let zelf = sign.args.iter().next().unwrap(); + let TypeEnum::TObj { params, .. } = &*unifier.get_ty(zelf.ty) else { + unreachable!() + }; + params.clone() + } + _ => unreachable!() + } + + // let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; + // params.clone() }) .unwrap_or_default(); vars.extend(fun_vars); @@ -932,7 +971,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( } }) .collect_vec(); - + println!("FUnction Val: {:?}", fun_val); Ok(ctx.build_call_or_invoke(fun_val, ¶m_vals, "call")) } @@ -2456,10 +2495,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()), None => { + println!("{}", id); let resolver = ctx.resolver.clone(); if let Some(res) = resolver.get_symbol_value(*id, ctx) { res } else { + println!("Rnter Else Block"); // Allow "raise Exception" short form let def_id = resolver.get_identifier_def(*id).map_err(|e| { format!("{} (at {})", e.iter().next().unwrap(), expr.location) @@ -2792,6 +2833,23 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } } ExprKind::Call { func, args, keywords } => { + let mut args = args.clone(); + let zelf = { + if let Some(arg) = args.get(0) { + if let ExprKind::Name { id, .. } = &arg.node { + if *id == "self".into() { + Some(args.remove(0)) + } else { + None + } + } else { + None + } + } else { + None + } + }; + let mut params = args .iter() .map(|arg| generator.gen_expr(ctx, arg)) @@ -2802,7 +2860,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( if params.len() < args.len() { return Ok(None); } - + println!("{}, {}", params.len(), args.len()); let kw_iter = keywords.iter().map(|kw| { Ok(( Some(*kw.node.arg.as_ref().unwrap()), @@ -2835,20 +2893,33 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; // Handle Class Method calls - let id = if let TypeEnum::TObj { obj_id, .. } = - &*ctx.unifier.get_ty(value.custom.unwrap()) - { - *obj_id - } else { - unreachable!() - }; - let fun_id = { - let defs = ctx.top_level.definitions.read(); - let obj_def = defs.get(id.0).unwrap().read(); - let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; - - methods.iter().find(|method| method.0 == *attr).unwrap().2 + let (id, fun_id) = match &*ctx.unifier.get_ty(value.custom.unwrap()) { + TypeEnum::TObj { obj_id, .. } => { + let fun_id = { + let defs = ctx.top_level.definitions.read(); + let obj_def = defs.get(obj_id.0).unwrap().read(); + let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() }; + + methods.iter().find(|method| method.0 == *attr).unwrap().2 + }; + (*obj_id, fun_id) + } + TypeEnum::TFunc(sign) => { + let defs = ctx.top_level.definitions.read(); + let res = defs.iter().find_map(|def| { + if let TopLevelDef::Class {object_id, methods, name, .. } = &*def.read() { + if *name == ctx.unifier.stringify(sign.ret).into() { + return Some((*object_id, methods.iter().find(|method| method.0 == *attr).unwrap().2)) + } + } + None + }).unwrap(); + res + // unreachable!() + } + _ => unreachable!() }; + // directly generate code for option.unwrap // since it needs to return static value to optimize for kernel invariant if attr == &"unwrap".into() @@ -2923,10 +2994,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( // Reset current_loc back to the location of the call ctx.current_loc = expr.location; + let obj_id = match zelf { + Some(arg) => arg.custom.unwrap(), + None => value.custom.unwrap() + }; return Ok(generator .gen_call( ctx, - Some((value.custom.unwrap(), val)), + Some((obj_id, val)), (&signature, fun_id), params, )? diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 58ae94fd..b4617cf0 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -23,7 +23,7 @@ impl Default for ComposerConfig { } } -type DefAst = (Arc>, Option>); +pub type DefAst = (Arc>, Option>); pub struct TopLevelComposer { // list of top level definitions, same as top level context pub definition_ast_list: Vec, @@ -1723,7 +1723,13 @@ impl TopLevelComposer { if *name != init_str_id { unreachable!("must be init function here") } - let all_inited = Self::get_all_assigned_field(body.as_slice())?; + // let all_inited = Self::get_all_assigned_field(body.as_slice())?; + let all_inited = Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?; + for (f, _, _) in fields { if !all_inited.contains(f) { return Err(HashSet::from([ diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 538e653e..550b19ce 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -3,6 +3,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap}; +use ast::ExprKind; use nac3parser::ast::{Constant, Location}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -677,7 +678,11 @@ impl TopLevelComposer { ) } - pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result, HashSet> { + pub fn get_all_assigned_field( + definition_ast_list: &Vec, + def: &Arc>, + stmts: &[Stmt<()>], + ) -> Result, HashSet> { let mut result = HashSet::new(); for s in stmts { match &s.node { @@ -713,32 +718,151 @@ impl TopLevelComposer { // TODO: do not check for For and While? ast::StmtKind::For { body, orelse, .. } | ast::StmtKind::While { body, orelse, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); - result.extend(Self::get_all_assigned_field(orelse.as_slice())?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + orelse.as_slice(), + )?); } ast::StmtKind::If { body, orelse, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + definition_ast_list, + def, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { - let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? - .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) - .copied() - .collect::>(); + let inited_for_sure = + Self::get_all_assigned_field(definition_ast_list, def, body.as_slice())? + .intersection(&Self::get_all_assigned_field( + definition_ast_list, + def, + orelse.as_slice(), + )?) + .copied() + .collect::>(); result.extend(inited_for_sure); - result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + finalbody.as_slice(), + )?); } ast::StmtKind::With { body, .. } => { - result.extend(Self::get_all_assigned_field(body.as_slice())?); + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?); } - ast::StmtKind::Pass { .. } - | ast::StmtKind::Assert { .. } - | ast::StmtKind::Expr { .. } => {} + // If its a call to __init__function of ancestor extend with ancestor fields + ast::StmtKind::Expr { value, .. } => { + // Check if Expression is a function call to self + if let ExprKind::Call { func, args, .. } = &value.node { + if let ExprKind::Attribute { value, attr: fn_name, .. } = &func.node { + let class_def = def.read(); + let (ancestors, methods) = { + let mut class_methods: HashMap = + HashMap::new(); + let mut class_ancestors: HashMap< + StrRef, + HashMap, + > = HashMap::new(); + + if let TopLevelDef::Class { methods, ancestors, .. } = &*class_def { + for m in methods { + class_methods.insert(m.0, m.2); + } + ancestors.iter().skip(1).for_each(|a| { + if let TypeAnnotation::CustomClass { id, .. } = a { + let anc_def = + definition_ast_list.get(id.0).unwrap().0.read(); + if let TopLevelDef::Class { name, methods, .. } = + &*anc_def + { + let mut temp: HashMap = + HashMap::new(); + for m in methods { + temp.insert(m.0, m.2); + } + // Remove module name suffix from name + let mut name_string = name.to_string(); + let split_loc = + name_string.find(|c| c == '.').unwrap() + 1; + class_ancestors.insert( + name_string.split_off(split_loc).into(), + temp, + ); + } + } + }); + } + (class_ancestors, class_methods) + }; + if let ExprKind::Name { id, .. } = value.node { + if id == "self".into() { + // Get Class methods and fields + let method_id = methods.get(fn_name); + if method_id.is_some() { + if let Some(fn_ast) = &definition_ast_list + .get(method_id.unwrap().0) + .unwrap() + .1 + { + if let ast::StmtKind::FunctionDef { body, .. } = + &fn_ast.node + { + result.extend(Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?); + } + } + } + } else if let Some(ancestor_methods) = ancestors.get(&id) { + // First arg must be `self` when calling ancestor function + if let ExprKind::Name { id, .. } = args[0].node { + if id == "self".into() { + if let Some(method_id) = ancestor_methods.get(fn_name) { + if let Some(fn_ast) = + &definition_ast_list.get(method_id.0).unwrap().1 + { + if let ast::StmtKind::FunctionDef { + body, .. + } = &fn_ast.node + { + result.extend( + Self::get_all_assigned_field( + definition_ast_list, + def, + body.as_slice(), + )?, + ); + } + } + }; + } + } + } + } + } + } + } + ast::StmtKind::Pass { .. } | ast::StmtKind::Assert { .. } => {} _ => { + println!("{:?}", s.node); unimplemented!() } } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index d9380ab1..22c4a64a 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -581,6 +581,7 @@ impl<'a> Fold<()> for Inferencer<'a> { ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), ExprKind::Attribute { value, attr, ctx } => { + println!("Attr Called"); Some(self.infer_attribute(value, *attr, *ctx)?) } ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), @@ -1513,14 +1514,54 @@ impl<'a> Inferencer<'a> { mut args: Vec>, keywords: Vec>, ) -> Result>, HashSet> { + println!("{:?}", func); if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? { return Ok(spec_call_func); } - let func = Box::new(self.fold_expr(func)?); - let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + println!("Trying Args"); + let mut args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; + let (func, arg_self) = { + if let Some(arg) = args.iter().next() { + if let ExprKind::Name { id, .. } = arg.node { + if id == "self".into() { + // args.remove(0); + + + let expr = match func.node { + ExprKind::Call { func, args, keywords } => { + return self.fold_call(func.location, *func, args, keywords); + } + + _ => fold::fold_expr(self, func.clone())?, + }; + + + let ExprKind::Attribute { value, attr, ctx } = &expr.node else { + return report_error("Unsupported Statement", location); + }; + let ty = value.custom.unwrap(); + let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) else { + return report_error("Unsupported Statement", location); + }; + + // Check for ancestors methods + (Box::new(self.fold_expr(func)?), Some(args.remove(0))) + + } else { + (Box::new(self.fold_expr(func)?), None) + } + } else { + (Box::new(self.fold_expr(func)?), None) + } + } else { + (Box::new(self.fold_expr(func)?), None) + } + }; + // let func = Box::new(self.fold_expr(func)?); + println!("Failed"); let keywords = keywords .into_iter() .map(|v| fold::fold_keyword(self, v)) @@ -1539,9 +1580,14 @@ impl<'a> Inferencer<'a> { loc: Some(location), operator_info: None, }; + println!("Try Unigu"); self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) })?; + if let Some(arg) = arg_self { + args.insert(0, arg); + } + println!("Reutrnin"); return Ok(Located { location, custom: Some(sign.ret), @@ -1665,10 +1711,11 @@ impl<'a> Inferencer<'a> { ctx: ExprContext, ) -> InferenceResult { let ty = value.custom.unwrap(); + println!("{:?}", value); if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) { // just a fast path match (fields.get(&attr), ctx == ExprContext::Store) { - (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), + (Some((ty, true)), _) | (Some((ty, false)), false) => {println!("Returning"); Ok(*ty)}, (Some((ty, false)), true) => report_type_error( TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), Some(value.location), @@ -1705,12 +1752,15 @@ impl<'a> Inferencer<'a> { } } else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 + + // Remember to restore p[lz] + println!("jwicef\n"); let result = { self.top_level.definitions.read().iter().find_map(|def| { if let Some(rear_guard) = def.try_read() { - if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { + if let TopLevelDef::Class { name, methods, .. } = &*rear_guard { if name.to_string() == self.unifier.stringify(sign.ret) { - return attributes.iter().find_map(|f| { + return methods.iter().find_map(|f| { if f.0 == attr { return Some(f.clone().1); } @@ -1730,6 +1780,7 @@ impl<'a> Inferencer<'a> { None => self.infer_general_attribute(value, attr, ctx), } } else { + println!("ncfe\n"); self.infer_general_attribute(value, attr, ctx) } } diff --git a/nac3standalone/demo/interpreted.log b/nac3standalone/demo/interpreted.log new file mode 100644 index 00000000..e69de29b diff --git a/nac3standalone/demo/src/inheritance.py b/nac3standalone/demo/src/inheritance.py index d280e3a5..5ca937c8 100644 --- a/nac3standalone/demo/src/inheritance.py +++ b/nac3standalone/demo/src/inheritance.py @@ -4,29 +4,51 @@ from __future__ import annotations def output_int32(x: int32): ... -class A: +class C: + c: int32 a: int32 - - def __init__(self, a: int32): - self.a = a - - def f1(self): - self.f2() - - def f2(self): - output_int32(self.a) - -class B(A): b: int32 + def __init__(self): + self.a = 42 + self.b = 33 + self.c = 12 - def __init__(self, b: int32): - self.a = b + 1 - self.b = b + def test2(self): + output_int32(999) + output_int32(self.a) + output_int32(self.b) + output_int32(self.c) + + self.a = 23 + +class D(C): + def __init__(self): + # C.__init__(self) + self.test() + self.b = 1 + self.c = 2 + C.test2(self) + #self.a() + # self.test() + # C.test2(self) + # self.a = 2 + # __main__.C.__init__(self) + + def test(self): + self.a = 2 + def run() -> int32: - aaa = A(5) - bbb = B(2) - aaa.f1() - bbb.f1() + x = D() + output_int32(x.a) + output_int32(x.b) + output_int32(x.c) + + + + # aaa = A(5) + # bbb = B(2) + # aaa.f1() + # bbb.f1() return 0 diff --git a/nac3standalone/src/basic_symbol_resolver.rs b/nac3standalone/src/basic_symbol_resolver.rs index 5fe0d4f5..48f6cb0e 100644 --- a/nac3standalone/src/basic_symbol_resolver.rs +++ b/nac3standalone/src/basic_symbol_resolver.rs @@ -59,7 +59,7 @@ impl SymbolResolver for Resolver { _: StrRef, _: &mut CodeGenContext<'ctx, '_>, ) -> Option> { - unimplemented!() + None } fn get_identifier_def(&self, id: StrRef) -> Result> { diff --git a/pyo3_output/nac3artiq.so b/pyo3_output/nac3artiq.so new file mode 100755 index 00000000..beb4f236 Binary files /dev/null and b/pyo3_output/nac3artiq.so differ