diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 3ff5e0e4c..53b852f51 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -16,7 +16,7 @@ use crate::{ }, }; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; -use itertools::{chain, izip}; +use itertools::{chain, Itertools, izip}; use nac3parser::ast::{Constant, Expr, Location, StrRef}; use parking_lot::RwLock; @@ -354,7 +354,7 @@ pub trait SymbolResolver { } thread_local! { - static IDENTIFIER_ID: [StrRef; 12] = [ + static IDENTIFIER_ID: [StrRef; 13] = [ "int32".into(), "int64".into(), "float".into(), @@ -367,6 +367,7 @@ thread_local! { "Exception".into(), "uint32".into(), "uint64".into(), + "Literal".into(), ]; } @@ -392,6 +393,7 @@ pub fn parse_type_annotation( let exn_id = ids[9]; let uint32_id = ids[10]; let uint64_id = ids[11]; + let literal_id = ids[12]; let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { if *id == int32_id { @@ -491,6 +493,27 @@ pub fn parse_type_annotation( "Expected multiple elements for tuple".into() ])) } + } else if *id == literal_id { + let mut parse_literal = |elt: &Expr| { + let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?; + let ty_enum = &*unifier.get_ty_immutable(ty); + match ty_enum { + TypeEnum::TLiteral { values, .. } => Ok(values.clone()), + _ => Err(HashSet::from([ + format!("Expected literal in type argument for Literal at {}", elt.location), + ])) + } + }; + + let values = if let Tuple { elts, .. } = &slice.node { + elts.iter() + .map(|elt| parse_literal(elt)) + .collect::, _>>()? + } else { + vec![parse_literal(slice)?] + }.into_iter().flatten().collect_vec(); + + Ok(unifier.get_fresh_literal(values, Some(slice.location))) } else { let types = if let Tuple { elts, .. } = &slice.node { elts.iter() @@ -554,6 +577,9 @@ pub fn parse_type_annotation( ])) } } + Constant { value, .. } => SymbolValue::from_constant_inferred(value, unifier) + .map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location.clone()))) + .map_err(|err| HashSet::from([err])), _ => Err(HashSet::from([ format!("unsupported type expression at {}", expr.location), ])), diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 98794a7b8..ecaf81bdc 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -8,33 +8,33 @@ def consume_ndarray_2(n: ndarray[float, Literal[2]]): pass def test_ndarray_ctor(): - n = np_ndarray([1]) + n: ndarray[float, Literal[1]] = np_ndarray([1]) consume_ndarray_1(n) def test_ndarray_empty(): - n = np_empty([1]) + n: ndarray[float, 1] = np_empty([1]) consume_ndarray_1(n) def test_ndarray_zeros(): - n = np_zeros([1]) + n: ndarray[float, 1] = np_zeros([1]) consume_ndarray_1(n) def test_ndarray_ones(): - n = np_ones([1]) + n: ndarray[float, 1] = np_ones([1]) consume_ndarray_1(n) def test_ndarray_full(): - n_float = np_full([1], 2.0) + n_float: ndarray[float, 1] = np_full([1], 2.0) consume_ndarray_1(n_float) - n_i32 = np_full([1], 2) + n_i32: ndarray[int32, 1] = np_full([1], 2) consume_ndarray_i32_1(n_i32) def test_ndarray_eye(): - n = np_eye(2) + n: ndarray[float, 2] = np_eye(2) consume_ndarray_2(n) def test_ndarray_identity(): - n = np_identity(2) + n: ndarray[float, 2] = np_identity(2) consume_ndarray_2(n) def run() -> int32: