core: Fix Literal use in variable type annotation

This commit is contained in:
David Mak 2024-02-06 12:29:21 +08:00
parent 27011f385b
commit ac18fb312c
2 changed files with 36 additions and 10 deletions

View File

@ -16,7 +16,7 @@ use crate::{
},
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip};
use itertools::{chain, Itertools, izip};
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock;
@ -354,7 +354,7 @@ pub trait SymbolResolver {
}
thread_local! {
static IDENTIFIER_ID: [StrRef; 12] = [
static IDENTIFIER_ID: [StrRef; 13] = [
"int32".into(),
"int64".into(),
"float".into(),
@ -367,6 +367,7 @@ thread_local! {
"Exception".into(),
"uint32".into(),
"uint64".into(),
"Literal".into(),
];
}
@ -392,6 +393,7 @@ pub fn parse_type_annotation<T>(
let exn_id = ids[9];
let uint32_id = ids[10];
let uint64_id = ids[11];
let literal_id = ids[12];
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
if *id == int32_id {
@ -491,6 +493,27 @@ pub fn parse_type_annotation<T>(
"Expected multiple elements for tuple".into()
]))
}
} else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([
format!("Expected literal in type argument for Literal at {}", elt.location),
]))
}
};
let values = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(|elt| parse_literal(elt))
.collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_literal(slice)?]
}.into_iter().flatten().collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
@ -554,6 +577,9 @@ pub fn parse_type_annotation<T>(
]))
}
}
Constant { value, .. } => SymbolValue::from_constant_inferred(value, unifier)
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location.clone())))
.map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([
format!("unsupported type expression at {}", expr.location),
])),

View File

@ -8,33 +8,33 @@ def consume_ndarray_2(n: ndarray[float, Literal[2]]):
pass
def test_ndarray_ctor():
n = np_ndarray([1])
n: ndarray[float, Literal[1]] = np_ndarray([1])
consume_ndarray_1(n)
def test_ndarray_empty():
n = np_empty([1])
n: ndarray[float, 1] = np_empty([1])
consume_ndarray_1(n)
def test_ndarray_zeros():
n = np_zeros([1])
n: ndarray[float, 1] = np_zeros([1])
consume_ndarray_1(n)
def test_ndarray_ones():
n = np_ones([1])
n: ndarray[float, 1] = np_ones([1])
consume_ndarray_1(n)
def test_ndarray_full():
n_float = np_full([1], 2.0)
n_float: ndarray[float, 1] = np_full([1], 2.0)
consume_ndarray_1(n_float)
n_i32 = np_full([1], 2)
n_i32: ndarray[int32, 1] = np_full([1], 2)
consume_ndarray_i32_1(n_i32)
def test_ndarray_eye():
n = np_eye(2)
n: ndarray[float, 2] = np_eye(2)
consume_ndarray_2(n)
def test_ndarray_identity():
n = np_identity(2)
n: ndarray[float, 2] = np_identity(2)
consume_ndarray_2(n)
def run() -> int32: