From d0ed1435797ed9027425a05e8d8d8cb937579d1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Bourdeauducq?= Date: Thu, 12 Sep 2024 22:52:50 +0800 Subject: [PATCH] copy top level assignment logic from nac3standalone --- src/main.rs | 225 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 211 insertions(+), 14 deletions(-) diff --git a/src/main.rs b/src/main.rs index 169a917..900a443 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,11 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::num::NonZeroUsize; use std::path::Path; use std::process::Command; use std::sync::Arc; -use parking_lot::Mutex; use eframe::egui; +use parking_lot::{Mutex, RwLock}; use nac3core::codegen; use nac3core::inkwell; @@ -17,6 +17,188 @@ use nac3core::typecheck::{type_inferencer, typedef}; mod basic_symbol_resolver; use basic_symbol_resolver::{Resolver, ResolverInternal}; +fn handle_typevar_definition( + var: &nac3parser::ast::Expr, + resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync), + def_list: &[Arc>], + unifier: &mut nac3core::typecheck::typedef::Unifier, + primitives: &type_inferencer::PrimitiveStore, +) -> Result> { + let nac3parser::ast::ExprKind::Call { func, args, .. } = &var.node else { + return Err(HashSet::from([format!( + "expression {var:?} cannot be handled as a generic parameter in global scope" + )])); + }; + + match &func.node { + nac3parser::ast::ExprKind::Name { id, .. } if id == &"TypeVar".into() => { + let nac3parser::ast::ExprKind::Constant { + value: nac3parser::ast::Constant::Str(ty_name), + .. + } = &args[0].node + else { + return Err(HashSet::from([format!( + "Expected string constant for first parameter of `TypeVar`, got {:?}", + &args[0].node + )])); + }; + let generic_name: nac3parser::ast::StrRef = ty_name.to_string().into(); + + let constraints = args + .iter() + .skip(1) + .map(|x| -> Result> { + let ty = toplevel::type_annotation::parse_ast_to_type_annotation_kinds( + resolver, + def_list, + unifier, + primitives, + x, + HashMap::new(), + )?; + toplevel::type_annotation::get_type_from_type_annotation_kinds( + def_list, unifier, primitives, &ty, &mut None, + ) + }) + .collect::, _>>()?; + let loc = func.location; + + if constraints.len() == 1 { + return Err(HashSet::from([format!( + "A single constraint is not allowed (at {loc})" + )])); + } + + Ok(unifier + .get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)) + .ty) + } + + nac3parser::ast::ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => { + if args.len() != 2 { + return Err(HashSet::from([format!( + "Expected 2 arguments for `ConstGeneric`, got {}", + args.len() + )])); + } + + let nac3parser::ast::ExprKind::Constant { + value: nac3parser::ast::Constant::Str(ty_name), + .. + } = &args[0].node + else { + return Err(HashSet::from([format!( + "Expected string constant for first parameter of `ConstGeneric`, got {:?}", + &args[0].node + )])); + }; + let generic_name: nac3parser::ast::StrRef = ty_name.to_string().into(); + + let ty = toplevel::type_annotation::parse_ast_to_type_annotation_kinds( + resolver, + def_list, + unifier, + primitives, + &args[1], + HashMap::new(), + )?; + let constraint = toplevel::type_annotation::get_type_from_type_annotation_kinds( + def_list, unifier, primitives, &ty, &mut None, + )?; + let loc = func.location; + + Ok(unifier + .get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)) + .ty) + } + + _ => Err(HashSet::from([format!( + "expression {var:?} cannot be handled as a generic parameter in global scope" + )])), + } +} + +fn handle_assignment_pattern( + targets: &[nac3parser::ast::Expr], + value: &nac3parser::ast::Expr, + resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync), + internal_resolver: &ResolverInternal, + def_list: &[Arc>], + unifier: &mut nac3core::typecheck::typedef::Unifier, + primitives: &type_inferencer::PrimitiveStore, +) -> Result<(), String> { + if targets.len() == 1 { + match &targets[0].node { + nac3parser::ast::ExprKind::Name { id, .. } => { + if let Ok(var) = + handle_typevar_definition(value, resolver, def_list, unifier, primitives) + { + internal_resolver.add_id_type(*id, var); + Ok(()) + } else if let Ok(val) = + toplevel::helper::parse_parameter_default_value(value, resolver) + { + internal_resolver.add_module_global(*id, val); + Ok(()) + } else { + Err(format!("fails to evaluate this expression `{:?}` as a constant or generic parameter at {}", + targets[0].node, + targets[0].location, + )) + } + } + nac3parser::ast::ExprKind::List { elts, .. } + | nac3parser::ast::ExprKind::Tuple { elts, .. } => { + handle_assignment_pattern( + elts, + value, + resolver, + internal_resolver, + def_list, + unifier, + primitives, + )?; + Ok(()) + } + _ => Err(format!( + "assignment to {:?} is not supported at {}", + targets[0], targets[0].location + )), + } + } else { + match &value.node { + nac3parser::ast::ExprKind::List { elts, .. } + | nac3parser::ast::ExprKind::Tuple { elts, .. } => { + if elts.len() == targets.len() { + for (tar, val) in targets.iter().zip(elts) { + handle_assignment_pattern( + std::slice::from_ref(tar), + val, + resolver, + internal_resolver, + def_list, + unifier, + primitives, + )?; + } + Ok(()) + } else { + Err(format!( + "number of elements to unpack does not match (expect {}, found {}) at {}", + targets.len(), + elts.len(), + value.location + )) + } + } + _ => Err(format!( + "unpack of this expression is not supported at {}", + value.location + )), + } + } +} + fn compile(code: &String, run_symbol: &String, output_filename: &Path) -> Result<(), String> { let mut target_machine_options = codegen::CodeGenTargetMachineOptions::from_host(); target_machine_options.reloc_mode = inkwell::targets::RelocMode::PIC; @@ -59,20 +241,35 @@ fn compile(code: &String, run_symbol: &String, output_filename: &Path) -> Result } }; for mut stmt in parser_result { - if let nac3parser::ast::StmtKind::FunctionDef { name, .. } = &mut stmt.node { - if name.to_string() == "run" { - *name = run_symbol.as_str().into(); - } - } - match composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true) { - Ok((name, def_id, ty)) => { - internal_resolver.add_id_def(name, def_id); - if let Some(ty) = ty { - internal_resolver.add_id_type(name, ty); + if let nac3parser::ast::StmtKind::Assign { targets, value, .. } = stmt.node { + let def_list = composer.extract_def_list(); + let unifier = &mut composer.unifier; + let primitives = &composer.primitives_ty; + handle_assignment_pattern( + &targets, + &value, + resolver.as_ref(), + internal_resolver.as_ref(), + &def_list, + unifier, + primitives, + )?; + } else { + if let nac3parser::ast::StmtKind::FunctionDef { name, .. } = &mut stmt.node { + if name.to_string() == "run" { + *name = run_symbol.as_str().into(); } } - Err(err) => { - return Err(format!("composer error: {}", err)); + match composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true) { + Ok((name, def_id, ty)) => { + internal_resolver.add_id_def(name, def_id); + if let Some(ty) = ty { + internal_resolver.add_id_type(name, ty); + } + } + Err(err) => { + return Err(format!("composer error: {}", err)); + } } } }