copy top level assignment logic from nac3standalone

This commit is contained in:
Sébastien Bourdeauducq 2024-09-12 22:52:50 +08:00
parent c5fe70ec9a
commit d0ed143579

View File

@ -1,11 +1,11 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::num::NonZeroUsize;
use std::path::Path;
use std::process::Command;
use std::sync::Arc;
use parking_lot::Mutex;
use eframe::egui;
use parking_lot::{Mutex, RwLock};
use nac3core::codegen;
use nac3core::inkwell;
@ -17,6 +17,188 @@ use nac3core::typecheck::{type_inferencer, typedef};
mod basic_symbol_resolver;
use basic_symbol_resolver::{Resolver, ResolverInternal};
fn handle_typevar_definition(
var: &nac3parser::ast::Expr,
resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync),
def_list: &[Arc<RwLock<toplevel::TopLevelDef>>],
unifier: &mut nac3core::typecheck::typedef::Unifier,
primitives: &type_inferencer::PrimitiveStore,
) -> Result<typedef::Type, HashSet<String>> {
let nac3parser::ast::ExprKind::Call { func, args, .. } = &var.node else {
return Err(HashSet::from([format!(
"expression {var:?} cannot be handled as a generic parameter in global scope"
)]));
};
match &func.node {
nac3parser::ast::ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
let nac3parser::ast::ExprKind::Constant {
value: nac3parser::ast::Constant::Str(ty_name),
..
} = &args[0].node
else {
return Err(HashSet::from([format!(
"Expected string constant for first parameter of `TypeVar`, got {:?}",
&args[0].node
)]));
};
let generic_name: nac3parser::ast::StrRef = ty_name.to_string().into();
let constraints = args
.iter()
.skip(1)
.map(|x| -> Result<typedef::Type, HashSet<String>> {
let ty = toplevel::type_annotation::parse_ast_to_type_annotation_kinds(
resolver,
def_list,
unifier,
primitives,
x,
HashMap::new(),
)?;
toplevel::type_annotation::get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None,
)
})
.collect::<Result<Vec<_>, _>>()?;
let loc = func.location;
if constraints.len() == 1 {
return Err(HashSet::from([format!(
"A single constraint is not allowed (at {loc})"
)]));
}
Ok(unifier
.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc))
.ty)
}
nac3parser::ast::ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
if args.len() != 2 {
return Err(HashSet::from([format!(
"Expected 2 arguments for `ConstGeneric`, got {}",
args.len()
)]));
}
let nac3parser::ast::ExprKind::Constant {
value: nac3parser::ast::Constant::Str(ty_name),
..
} = &args[0].node
else {
return Err(HashSet::from([format!(
"Expected string constant for first parameter of `ConstGeneric`, got {:?}",
&args[0].node
)]));
};
let generic_name: nac3parser::ast::StrRef = ty_name.to_string().into();
let ty = toplevel::type_annotation::parse_ast_to_type_annotation_kinds(
resolver,
def_list,
unifier,
primitives,
&args[1],
HashMap::new(),
)?;
let constraint = toplevel::type_annotation::get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None,
)?;
let loc = func.location;
Ok(unifier
.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc))
.ty)
}
_ => Err(HashSet::from([format!(
"expression {var:?} cannot be handled as a generic parameter in global scope"
)])),
}
}
fn handle_assignment_pattern(
targets: &[nac3parser::ast::Expr],
value: &nac3parser::ast::Expr,
resolver: &(dyn nac3core::symbol_resolver::SymbolResolver + Send + Sync),
internal_resolver: &ResolverInternal,
def_list: &[Arc<RwLock<toplevel::TopLevelDef>>],
unifier: &mut nac3core::typecheck::typedef::Unifier,
primitives: &type_inferencer::PrimitiveStore,
) -> Result<(), String> {
if targets.len() == 1 {
match &targets[0].node {
nac3parser::ast::ExprKind::Name { id, .. } => {
if let Ok(var) =
handle_typevar_definition(value, resolver, def_list, unifier, primitives)
{
internal_resolver.add_id_type(*id, var);
Ok(())
} else if let Ok(val) =
toplevel::helper::parse_parameter_default_value(value, resolver)
{
internal_resolver.add_module_global(*id, val);
Ok(())
} else {
Err(format!("fails to evaluate this expression `{:?}` as a constant or generic parameter at {}",
targets[0].node,
targets[0].location,
))
}
}
nac3parser::ast::ExprKind::List { elts, .. }
| nac3parser::ast::ExprKind::Tuple { elts, .. } => {
handle_assignment_pattern(
elts,
value,
resolver,
internal_resolver,
def_list,
unifier,
primitives,
)?;
Ok(())
}
_ => Err(format!(
"assignment to {:?} is not supported at {}",
targets[0], targets[0].location
)),
}
} else {
match &value.node {
nac3parser::ast::ExprKind::List { elts, .. }
| nac3parser::ast::ExprKind::Tuple { elts, .. } => {
if elts.len() == targets.len() {
for (tar, val) in targets.iter().zip(elts) {
handle_assignment_pattern(
std::slice::from_ref(tar),
val,
resolver,
internal_resolver,
def_list,
unifier,
primitives,
)?;
}
Ok(())
} else {
Err(format!(
"number of elements to unpack does not match (expect {}, found {}) at {}",
targets.len(),
elts.len(),
value.location
))
}
}
_ => Err(format!(
"unpack of this expression is not supported at {}",
value.location
)),
}
}
}
fn compile(code: &String, run_symbol: &String, output_filename: &Path) -> Result<(), String> {
let mut target_machine_options = codegen::CodeGenTargetMachineOptions::from_host();
target_machine_options.reloc_mode = inkwell::targets::RelocMode::PIC;
@ -59,20 +241,35 @@ fn compile(code: &String, run_symbol: &String, output_filename: &Path) -> Result
}
};
for mut stmt in parser_result {
if let nac3parser::ast::StmtKind::FunctionDef { name, .. } = &mut stmt.node {
if name.to_string() == "run" {
*name = run_symbol.as_str().into();
}
}
match composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true) {
Ok((name, def_id, ty)) => {
internal_resolver.add_id_def(name, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty);
if let nac3parser::ast::StmtKind::Assign { targets, value, .. } = stmt.node {
let def_list = composer.extract_def_list();
let unifier = &mut composer.unifier;
let primitives = &composer.primitives_ty;
handle_assignment_pattern(
&targets,
&value,
resolver.as_ref(),
internal_resolver.as_ref(),
&def_list,
unifier,
primitives,
)?;
} else {
if let nac3parser::ast::StmtKind::FunctionDef { name, .. } = &mut stmt.node {
if name.to_string() == "run" {
*name = run_symbol.as_str().into();
}
}
Err(err) => {
return Err(format!("composer error: {}", err));
match composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true) {
Ok((name, def_id, ty)) => {
internal_resolver.add_id_def(name, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty);
}
}
Err(err) => {
return Err(format!("composer error: {}", err));
}
}
}
}