forked from M-Labs/nac3
core: Fix Literal use in variable type annotation
This commit is contained in:
parent
1963c30744
commit
5cecb2bb74
|
@ -16,7 +16,7 @@ use crate::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
|
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
|
||||||
use itertools::{chain, izip};
|
use itertools::{chain, Itertools, izip};
|
||||||
use nac3parser::ast::{Constant, Expr, Location, StrRef};
|
use nac3parser::ast::{Constant, Expr, Location, StrRef};
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
|
|
||||||
|
@ -354,7 +354,7 @@ pub trait SymbolResolver {
|
||||||
}
|
}
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
static IDENTIFIER_ID: [StrRef; 12] = [
|
static IDENTIFIER_ID: [StrRef; 13] = [
|
||||||
"int32".into(),
|
"int32".into(),
|
||||||
"int64".into(),
|
"int64".into(),
|
||||||
"float".into(),
|
"float".into(),
|
||||||
|
@ -367,6 +367,7 @@ thread_local! {
|
||||||
"Exception".into(),
|
"Exception".into(),
|
||||||
"uint32".into(),
|
"uint32".into(),
|
||||||
"uint64".into(),
|
"uint64".into(),
|
||||||
|
"Literal".into(),
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -392,6 +393,7 @@ pub fn parse_type_annotation<T>(
|
||||||
let exn_id = ids[9];
|
let exn_id = ids[9];
|
||||||
let uint32_id = ids[10];
|
let uint32_id = ids[10];
|
||||||
let uint64_id = ids[11];
|
let uint64_id = ids[11];
|
||||||
|
let literal_id = ids[12];
|
||||||
|
|
||||||
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
|
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
|
||||||
if *id == int32_id {
|
if *id == int32_id {
|
||||||
|
@ -491,6 +493,27 @@ pub fn parse_type_annotation<T>(
|
||||||
"Expected multiple elements for tuple".into()
|
"Expected multiple elements for tuple".into()
|
||||||
]))
|
]))
|
||||||
}
|
}
|
||||||
|
} else if *id == literal_id {
|
||||||
|
let mut parse_literal = |elt: &Expr<T>| {
|
||||||
|
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
|
||||||
|
let ty_enum = &*unifier.get_ty_immutable(ty);
|
||||||
|
match ty_enum {
|
||||||
|
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
|
||||||
|
_ => Err(HashSet::from([
|
||||||
|
format!("Expected literal in type argument for Literal at {}", elt.location),
|
||||||
|
]))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let values = if let Tuple { elts, .. } = &slice.node {
|
||||||
|
elts.iter()
|
||||||
|
.map(|elt| parse_literal(elt))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?
|
||||||
|
} else {
|
||||||
|
vec![parse_literal(slice)?]
|
||||||
|
}.into_iter().flatten().collect_vec();
|
||||||
|
|
||||||
|
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
|
||||||
} else {
|
} else {
|
||||||
let types = if let Tuple { elts, .. } = &slice.node {
|
let types = if let Tuple { elts, .. } = &slice.node {
|
||||||
elts.iter()
|
elts.iter()
|
||||||
|
@ -554,6 +577,9 @@ pub fn parse_type_annotation<T>(
|
||||||
]))
|
]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Constant { value, .. } => SymbolValue::from_constant_inferred(value, unifier)
|
||||||
|
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location.clone())))
|
||||||
|
.map_err(|err| HashSet::from([err])),
|
||||||
_ => Err(HashSet::from([
|
_ => Err(HashSet::from([
|
||||||
format!("unsupported type expression at {}", expr.location),
|
format!("unsupported type expression at {}", expr.location),
|
||||||
])),
|
])),
|
||||||
|
|
|
@ -8,33 +8,33 @@ def consume_ndarray_2(n: ndarray[float, Literal[2]]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_ndarray_ctor():
|
def test_ndarray_ctor():
|
||||||
n = np_ndarray([1])
|
n: ndarray[float, Literal[1]] = np_ndarray([1])
|
||||||
consume_ndarray_1(n)
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
def test_ndarray_empty():
|
def test_ndarray_empty():
|
||||||
n = np_empty([1])
|
n: ndarray[float, 1] = np_empty([1])
|
||||||
consume_ndarray_1(n)
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
def test_ndarray_zeros():
|
def test_ndarray_zeros():
|
||||||
n = np_zeros([1])
|
n: ndarray[float, 1] = np_zeros([1])
|
||||||
consume_ndarray_1(n)
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
def test_ndarray_ones():
|
def test_ndarray_ones():
|
||||||
n = np_ones([1])
|
n: ndarray[float, 1] = np_ones([1])
|
||||||
consume_ndarray_1(n)
|
consume_ndarray_1(n)
|
||||||
|
|
||||||
def test_ndarray_full():
|
def test_ndarray_full():
|
||||||
n_float = np_full([1], 2.0)
|
n_float: ndarray[float, 1] = np_full([1], 2.0)
|
||||||
consume_ndarray_1(n_float)
|
consume_ndarray_1(n_float)
|
||||||
n_i32 = np_full([1], 2)
|
n_i32: ndarray[int32, 1] = np_full([1], 2)
|
||||||
consume_ndarray_i32_1(n_i32)
|
consume_ndarray_i32_1(n_i32)
|
||||||
|
|
||||||
def test_ndarray_eye():
|
def test_ndarray_eye():
|
||||||
n = np_eye(2)
|
n: ndarray[float, 2] = np_eye(2)
|
||||||
consume_ndarray_2(n)
|
consume_ndarray_2(n)
|
||||||
|
|
||||||
def test_ndarray_identity():
|
def test_ndarray_identity():
|
||||||
n = np_identity(2)
|
n: ndarray[float, 2] = np_identity(2)
|
||||||
consume_ndarray_2(n)
|
consume_ndarray_2(n)
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
|
|
Loading…
Reference in New Issue