forked from M-Labs/nac3
1
0
Fork 0

core: Fix Literal use in variable type annotation

This commit is contained in:
David Mak 2024-02-06 12:29:21 +08:00 committed by sb10q
parent 1963c30744
commit 5cecb2bb74
2 changed files with 36 additions and 10 deletions

View File

@ -16,7 +16,7 @@ use crate::{
}, },
}; };
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip}; use itertools::{chain, Itertools, izip};
use nac3parser::ast::{Constant, Expr, Location, StrRef}; use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
@ -354,7 +354,7 @@ pub trait SymbolResolver {
} }
thread_local! { thread_local! {
static IDENTIFIER_ID: [StrRef; 12] = [ static IDENTIFIER_ID: [StrRef; 13] = [
"int32".into(), "int32".into(),
"int64".into(), "int64".into(),
"float".into(), "float".into(),
@ -367,6 +367,7 @@ thread_local! {
"Exception".into(), "Exception".into(),
"uint32".into(), "uint32".into(),
"uint64".into(), "uint64".into(),
"Literal".into(),
]; ];
} }
@ -392,6 +393,7 @@ pub fn parse_type_annotation<T>(
let exn_id = ids[9]; let exn_id = ids[9];
let uint32_id = ids[10]; let uint32_id = ids[10];
let uint64_id = ids[11]; let uint64_id = ids[11];
let literal_id = ids[12];
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
if *id == int32_id { if *id == int32_id {
@ -491,6 +493,27 @@ pub fn parse_type_annotation<T>(
"Expected multiple elements for tuple".into() "Expected multiple elements for tuple".into()
])) ]))
} }
} else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([
format!("Expected literal in type argument for Literal at {}", elt.location),
]))
}
};
let values = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|elt| parse_literal(elt))
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_literal(slice)?]
}.into_iter().flatten().collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else { } else {
let types = if let Tuple { elts, .. } = &slice.node { let types = if let Tuple { elts, .. } = &slice.node {
elts.iter() elts.iter()
@ -554,6 +577,9 @@ pub fn parse_type_annotation<T>(
])) ]))
} }
} }
Constant { value, .. } => SymbolValue::from_constant_inferred(value, unifier)
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location.clone())))
.map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([ _ => Err(HashSet::from([
format!("unsupported type expression at {}", expr.location), format!("unsupported type expression at {}", expr.location),
])), ])),

View File

@ -8,33 +8,33 @@ def consume_ndarray_2(n: ndarray[float, Literal[2]]):
pass pass
def test_ndarray_ctor(): def test_ndarray_ctor():
n = np_ndarray([1]) n: ndarray[float, Literal[1]] = np_ndarray([1])
consume_ndarray_1(n) consume_ndarray_1(n)
def test_ndarray_empty(): def test_ndarray_empty():
n = np_empty([1]) n: ndarray[float, 1] = np_empty([1])
consume_ndarray_1(n) consume_ndarray_1(n)
def test_ndarray_zeros(): def test_ndarray_zeros():
n = np_zeros([1]) n: ndarray[float, 1] = np_zeros([1])
consume_ndarray_1(n) consume_ndarray_1(n)
def test_ndarray_ones(): def test_ndarray_ones():
n = np_ones([1]) n: ndarray[float, 1] = np_ones([1])
consume_ndarray_1(n) consume_ndarray_1(n)
def test_ndarray_full(): def test_ndarray_full():
n_float = np_full([1], 2.0) n_float: ndarray[float, 1] = np_full([1], 2.0)
consume_ndarray_1(n_float) consume_ndarray_1(n_float)
n_i32 = np_full([1], 2) n_i32: ndarray[int32, 1] = np_full([1], 2)
consume_ndarray_i32_1(n_i32) consume_ndarray_i32_1(n_i32)
def test_ndarray_eye(): def test_ndarray_eye():
n = np_eye(2) n: ndarray[float, 2] = np_eye(2)
consume_ndarray_2(n) consume_ndarray_2(n)
def test_ndarray_identity(): def test_ndarray_identity():
n = np_identity(2) n: ndarray[float, 2] = np_identity(2)
consume_ndarray_2(n) consume_ndarray_2(n)
def run() -> int32: def run() -> int32: