From b6dfcfcc389d055bbf8528d3d8e797a9d186f482 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 1 Dec 2023 15:56:18 +0800 Subject: [PATCH] core: Move some SymbolValue functions to symbol_resolver.rs --- nac3core/src/symbol_resolver.rs | 146 +++++++++++++++++++++++++++++++- nac3core/src/toplevel/helper.rs | 36 +------- 2 files changed, 145 insertions(+), 37 deletions(-) diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index cce05c4a..8e61e5ea 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -1,11 +1,12 @@ use std::fmt::Debug; use std::sync::Arc; use std::{collections::HashMap, fmt::Display}; +use std::rc::Rc; use crate::typecheck::typedef::TypeEnum; use crate::{ codegen::CodeGenContext, - toplevel::{DefinitionId, TopLevelDef}, + toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation}, }; use crate::{ codegen::CodeGenerator, @@ -16,7 +17,7 @@ use crate::{ }; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; use itertools::{chain, izip}; -use nac3parser::ast::{Expr, Location, StrRef}; +use nac3parser::ast::{Constant, Expr, Location, StrRef}; use parking_lot::RwLock; #[derive(Clone, PartialEq, Debug)] @@ -33,6 +34,147 @@ pub enum SymbolValue { OptionNone, } +impl SymbolValue { + /// Creates a [SymbolValue] from a [Constant]. + /// + /// * `constant` - The constant to create the value from. + /// * `expected_ty` - The expected type of the [SymbolValue]. + pub fn from_constant( + constant: &Constant, + expected_ty: Type, + primitives: &PrimitiveStore, + unifier: &mut Unifier + ) -> Result { + match constant { + Constant::None => { + if unifier.unioned(expected_ty, primitives.option) { + Ok(SymbolValue::OptionNone) + } else { + Err(format!("Expected {:?}, but got Option", expected_ty)) + } + } + Constant::Bool(b) => { + if unifier.unioned(expected_ty, primitives.bool) { + Ok(SymbolValue::Bool(*b)) + } else { + Err(format!("Expected {:?}, but got bool", expected_ty)) + } + } + Constant::Str(s) => { + if unifier.unioned(expected_ty, primitives.str) { + Ok(SymbolValue::Str(s.to_string())) + } else { + Err(format!("Expected {:?}, but got str", expected_ty)) + } + }, + Constant::Int(i) => { + if unifier.unioned(expected_ty, primitives.int32) { + i32::try_from(*i) + .map(|val| SymbolValue::I32(val)) + .map_err(|e| e.to_string()) + } else if unifier.unioned(expected_ty, primitives.int64) { + i64::try_from(*i) + .map(|val| SymbolValue::I64(val)) + .map_err(|e| e.to_string()) + } else if unifier.unioned(expected_ty, primitives.uint32) { + u32::try_from(*i) + .map(|val| SymbolValue::U32(val)) + .map_err(|e| e.to_string()) + } else if unifier.unioned(expected_ty, primitives.uint64) { + u64::try_from(*i) + .map(|val| SymbolValue::U64(val)) + .map_err(|e| e.to_string()) + } else { + Err(format!("Expected {:?}, but got int", expected_ty)) + } + } + Constant::Tuple(t) => { + let expected_ty = unifier.get_ty(expected_ty); + let TypeEnum::TTuple { ty } = expected_ty.as_ref() else { + return Err(format!("Expected {:?}, but got Tuple", expected_ty.get_type_name())) + }; + + assert_eq!(ty.len(), t.len()); + + let elems = t.into_iter() + .zip(ty) + .map(|(constant, ty)| Self::from_constant(constant, *ty, primitives, unifier)) + .collect::, _>>()?; + Ok(SymbolValue::Tuple(elems)) + } + Constant::Float(f) => { + if unifier.unioned(expected_ty, primitives.float) { + Ok(SymbolValue::Double(*f)) + } else { + Err(format!("Expected {:?}, but got float", expected_ty)) + } + }, + _ => Err(format!("Unsupported value type {:?}", constant)), + } + } + + /// Returns the [Type] representing the data type of this value. + pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type { + match self { + SymbolValue::I32(_) => primitives.int32, + SymbolValue::I64(_) => primitives.int64, + SymbolValue::U32(_) => primitives.uint32, + SymbolValue::U64(_) => primitives.uint64, + SymbolValue::Str(_) => primitives.str, + SymbolValue::Double(_) => primitives.float, + SymbolValue::Bool(_) => primitives.bool, + SymbolValue::Tuple(vs) => { + let vs_tys = vs + .iter() + .map(|v| v.get_type(primitives, unifier)) + .collect::>(); + unifier.add_ty(TypeEnum::TTuple { + ty: vs_tys, + }) + } + SymbolValue::OptionSome(_) => primitives.option, + SymbolValue::OptionNone => primitives.option, + } + } + + /// Returns the [TypeAnnotation] representing the data type of this value. + pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation { + match self { + SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool), + SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float), + SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32), + SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64), + SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32), + SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64), + SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str), + SymbolValue::Tuple(vs) => { + let vs_tys = vs + .iter() + .map(|v| v.get_type_annotation(primitives, unifier)) + .collect::>(); + TypeAnnotation::Tuple(vs_tys) + } + SymbolValue::OptionNone => TypeAnnotation::CustomClass { + id: primitives.option.get_obj_id(unifier), + params: Default::default(), + }, + SymbolValue::OptionSome(v) => { + let ty = v.get_type_annotation(primitives, unifier); + TypeAnnotation::CustomClass { + id: primitives.option.get_obj_id(unifier), + params: vec![ty], + } + } + } + } + + /// Returns the [TypeEnum] representing the data type of this value. + pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc { + let ty = self.get_type(primitives, unifier); + unifier.get_ty(ty) + } +} + impl Display for SymbolValue { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 73eebab5..014996a6 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -416,40 +416,6 @@ impl TopLevelComposer { primitive: &PrimitiveStore, unifier: &mut Unifier, ) -> Result<(), String> { - fn type_default_param( - val: &SymbolValue, - primitive: &PrimitiveStore, - unifier: &mut Unifier, - ) -> TypeAnnotation { - match val { - SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitive.bool), - SymbolValue::Double(..) => TypeAnnotation::Primitive(primitive.float), - SymbolValue::I32(..) => TypeAnnotation::Primitive(primitive.int32), - SymbolValue::I64(..) => TypeAnnotation::Primitive(primitive.int64), - SymbolValue::U32(..) => TypeAnnotation::Primitive(primitive.uint32), - SymbolValue::U64(..) => TypeAnnotation::Primitive(primitive.uint64), - SymbolValue::Str(..) => TypeAnnotation::Primitive(primitive.str), - SymbolValue::Tuple(vs) => { - let vs_tys = vs - .iter() - .map(|v| type_default_param(v, primitive, unifier)) - .collect::>(); - TypeAnnotation::Tuple(vs_tys) - } - SymbolValue::OptionNone => TypeAnnotation::CustomClass { - id: primitive.option.get_obj_id(unifier), - params: Default::default(), - }, - SymbolValue::OptionSome(v) => { - let ty = type_default_param(v, primitive, unifier); - TypeAnnotation::CustomClass { - id: primitive.option.get_obj_id(unifier), - params: vec![ty], - } - } - } - } - fn is_compatible( found: &TypeAnnotation, expect: &TypeAnnotation, @@ -481,7 +447,7 @@ impl TopLevelComposer { } } - let found = type_default_param(val, primitive, unifier); + let found = val.get_type_annotation(primitive, unifier); if !is_compatible(&found, ty, unifier, primitive) { Err(format!( "incompatible default parameter type, expect {}, found {}",