core: move top level def type vars into `PrimitiveStore`

This commit is contained in:
lyken 2024-06-17 13:34:57 +08:00
parent 2abe75d1f4
commit 4a81ca08d2
12 changed files with 316 additions and 216 deletions

View File

@ -4,7 +4,7 @@ use nac3core::{
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
numpy::{make_ndarray_ty, unpack_ndarray_params},
DefinitionId, TopLevelDef,
},
typecheck::{
@ -665,11 +665,11 @@ impl InnerResolver {
}
}
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
let params = unpack_ndarray_params(unifier, primitives, extracted_ty);
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
if len == 0 {
assert!(matches!(
&*unifier.get_ty(ty),
&*unifier.get_ty(params.dtype),
TypeEnum::TVar { fields: None, range, .. }
if range.is_empty()
));
@ -678,10 +678,14 @@ impl InnerResolver {
let actual_ty =
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
match actual_ty {
Ok(t) => match unifier.unify(ty, t) {
Ok(t) => match unifier.unify(params.dtype, t) {
Ok(()) => {
let ndarray_ty =
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
let ndarray_ty = make_ndarray_ty(
unifier,
primitives,
Some(params.dtype),
Some(params.ndims),
);
Ok(Ok(ndarray_ty))
}
@ -984,7 +988,7 @@ impl InnerResolver {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
*params.get(&ctx.primitives.option_type_tvar.id).unwrap()
}
_ => unreachable!("must be option type"),
};

View File

@ -8,7 +8,7 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::toplevel::numpy::unpack_ndarray_params;
use crate::typecheck::typedef::Type;
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
@ -66,7 +66,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -128,7 +128,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -206,7 +206,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -273,7 +273,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -339,7 +339,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -385,7 +385,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -425,7 +425,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -490,7 +490,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -544,7 +544,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -594,7 +594,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -692,7 +692,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, a_ty).dtype;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -792,16 +792,17 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
unreachable!()
};
@ -908,7 +909,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, a_ty).dtype;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -1008,16 +1009,18 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
unreachable!()
};
@ -1088,7 +1091,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1128,7 +1131,7 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1172,7 +1175,7 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1216,7 +1219,7 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1256,7 +1259,7 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1296,7 +1299,7 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1336,7 +1339,7 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1376,7 +1379,7 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1416,7 +1419,7 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1456,7 +1459,7 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1496,7 +1499,7 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1536,7 +1539,7 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1576,7 +1579,7 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1616,7 +1619,7 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1656,7 +1659,7 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1696,7 +1699,7 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1736,7 +1739,7 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1776,7 +1779,7 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1816,7 +1819,7 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1856,7 +1859,7 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1896,7 +1899,7 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1936,7 +1939,7 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -1976,7 +1979,7 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -2016,7 +2019,7 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -2056,7 +2059,7 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -2096,7 +2099,7 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(z)
if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, z_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -2136,7 +2139,7 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -2176,7 +2179,7 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(z)
if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, z_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -2216,7 +2219,7 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -2256,7 +2259,7 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -2296,7 +2299,7 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -2345,16 +2348,18 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
unreachable!()
};
@ -2412,16 +2417,18 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
unreachable!()
};
@ -2479,16 +2486,18 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
unreachable!()
};
@ -2546,16 +2555,18 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
unreachable!()
};
@ -2612,12 +2623,18 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
let is_ndarray2 =
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype =
if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty };
let dtype = if is_ndarray1 {
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else {
x1_ty
};
let x1_scalar_ty = dtype;
let x2_scalar_ty =
if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty };
let x2_scalar_ty = if is_ndarray2 {
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
x2_ty
};
numpy::ndarray_elementwise_binop_impl(
generator,
@ -2669,16 +2686,18 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
unreachable!()
};
@ -2736,16 +2755,18 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
unreachable!()
};

View File

@ -17,7 +17,7 @@ use crate::{
symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
numpy::{make_ndarray_ty, unpack_ndarray_params},
DefinitionId, TopLevelDef,
},
typecheck::{
@ -150,7 +150,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
*params.get(&self.primitives.option_type_tvar.id).unwrap()
}
_ => unreachable!("must be option type"),
};
@ -166,7 +166,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
*params.get(&self.primitives.option_type_tvar.id).unwrap()
}
_ => unreachable!("must be option type"),
};
@ -188,6 +188,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
&self.module,
generator,
&mut self.unifier,
&self.primitives,
self.top_level,
&mut self.type_cache,
ty,
@ -205,6 +206,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
&self.module,
generator,
&mut self.unifier,
&self.primitives,
self.top_level,
&mut self.type_cache,
&self.primitives,
@ -1190,8 +1192,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
let ndarray_dtype1 = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty1).dtype;
let ndarray_dtype2 = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty2).dtype;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1240,8 +1242,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into()))
} else {
let (ndarray_dtype, _) =
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
let ndarray_dtype = unpack_ndarray_params(
&ctx.unifier,
&ctx.primitives,
if is_ndarray1 { ty1 } else { ty2 },
)
.dtype;
let ndarray_val = NDArrayValue::from_ptr_val(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_usize,
@ -1427,7 +1433,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
}
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let ndarray_dtype = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty).dtype;
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
@ -1511,8 +1517,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
return if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
let ndarray_dtype1 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, left_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, right_ty).dtype;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1546,10 +1554,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into()))
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
let ndarray_dtype = unpack_ndarray_params(
&ctx.unifier,
&ctx.primitives,
if is_ndarray1 { left_ty } else { right_ty },
);
)
.dtype;
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
@ -2014,10 +2024,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
if *obj_id == *opt_id =>
{
ctx.get_llvm_type(generator, *params.iter().next().unwrap().1)
.ptr_type(AddressSpace::default())
.const_null()
.into()
ctx.get_llvm_type(
generator,
*params.get(&ctx.primitives.option_type_tvar.id).unwrap(),
)
.ptr_type(AddressSpace::default())
.const_null()
.into()
}
_ => unreachable!("must be option type"),
}

View File

@ -1,7 +1,7 @@
use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_params, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -423,6 +423,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
module: &Module<'ctx>,
generator: &mut G,
unifier: &mut Unifier,
store: &PrimitiveStore,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
ty: Type,
@ -443,18 +444,20 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
module,
generator,
unifier,
store,
top_level,
type_cache,
*params.iter().next().unwrap().1,
*params.get(&store.option_type_tvar.id).unwrap(),
)
.ptr_type(AddressSpace::default())
.into()
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let dtype = unpack_ndarray_params(unifier, store, ty).dtype;
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype,
ctx, module, generator, unifier, store, top_level, type_cache,
dtype,
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
@ -490,6 +493,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
module,
generator,
unifier,
store,
top_level,
type_cache,
fields[&f.0].0,
@ -506,14 +510,17 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let fields = ty
.iter()
.map(|ty| {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
get_llvm_type(
ctx, module, generator, unifier, store, top_level, type_cache, *ty,
)
})
.collect_vec();
ctx.struct_type(&fields, false).into()
}
TList { ty } => {
let element_type =
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty);
let element_type = get_llvm_type(
ctx, module, generator, unifier, store, top_level, type_cache, *ty,
);
ListType::new(generator, ctx, element_type).as_base_type().into()
}
@ -540,6 +547,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
module: &Module<'ctx>,
generator: &mut G,
unifier: &mut Unifier,
store: &PrimitiveStore,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
primitives: &PrimitiveStore,
@ -550,7 +558,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
return if unifier.unioned(ty, primitives.bool) {
ctx.bool_type().into()
} else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
get_llvm_type(ctx, module, generator, unifier, store, top_level, type_cache, ty)
};
}
@ -699,6 +707,7 @@ pub fn gen_func_impl<
&module,
generator,
&mut unifier,
&primitives,
top_level_ctx.as_ref(),
&mut type_cache,
&primitives,
@ -715,6 +724,7 @@ pub fn gen_func_impl<
&module,
generator,
&mut unifier,
&primitives,
top_level_ctx.as_ref(),
&mut type_cache,
&primitives,
@ -767,6 +777,7 @@ pub fn gen_func_impl<
&module,
generator,
&mut unifier,
&primitives,
top_level_ctx.as_ref(),
&mut type_cache,
arg.ty,

View File

@ -19,7 +19,7 @@ use crate::{
symbol_resolver::ValueEnum,
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
numpy::{make_ndarray_ty, unpack_ndarray_params},
DefinitionId,
},
typecheck::typedef::{FunSignature, Type, TypeEnum},
@ -1776,7 +1776,7 @@ pub fn gen_ndarray_array<'ctx>(
let obj_ty = fun.0.args[0].ty;
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
unpack_ndarray_params(&context.unifier, &context.primitives, obj_ty).dtype
}
TypeEnum::TList { ty } => {
@ -1916,7 +1916,7 @@ pub fn gen_ndarray_copy<'ctx>(
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_elem_ty = unpack_ndarray_params(&context.unifier, &context.primitives, this_ty).dtype;
let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;

View File

@ -10,7 +10,7 @@ use crate::{
expr::gen_binop_expr,
gen_in_range_check,
},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_params, DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
use inkwell::{
@ -245,7 +245,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, target.custom.unwrap())
.dtype
}
_ => unreachable!(),
};

View File

@ -25,7 +25,7 @@ use crate::{
},
symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
typecheck::typedef::{into_var_map, TypeVar, VarMap},
};
use super::*;
@ -301,10 +301,7 @@ struct BuiltinBuilder<'a> {
is_some_ty: (Type, bool),
unwrap_ty: (Type, bool),
option_tvar: TypeVar,
ndarray_dtype_tvar: TypeVar,
ndarray_ndims_tvar: TypeVar,
ndarray_copy_ty: (Type, bool),
ndarray_fill_ty: (Type, bool),
@ -339,24 +336,19 @@ impl<'a> BuiltinBuilder<'a> {
} = *primitives;
// Option-related
let (is_some_ty, unwrap_ty, option_tvar) =
if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() {
let (is_some_ty, unwrap_ty) =
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(option).as_ref() {
(
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
iter_type_vars(params).next().unwrap(),
)
} else {
unreachable!()
};
let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } =
&*unifier.get_ty(ndarray)
else {
let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray) else {
unreachable!()
};
let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty =
@ -398,10 +390,7 @@ impl<'a> BuiltinBuilder<'a> {
is_some_ty,
unwrap_ty,
option_tvar,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
ndarray_copy_ty,
ndarray_fill_ty,
@ -622,7 +611,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Option => TopLevelDef::Class {
name: prim.name().into(),
object_id: prim.id(),
type_vars: vec![self.option_tvar.ty],
type_vars: vec![self.primitives.option_type_tvar.ty],
fields: vec![],
methods: vec![
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
@ -642,7 +631,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0,
var_id: vec![self.option_tvar.id],
var_id: vec![self.primitives.option_type_tvar.id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
@ -656,7 +645,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().to_string(),
simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0,
var_id: vec![self.option_tvar.id],
var_id: vec![self.primitives.option_type_tvar.id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
@ -693,13 +682,13 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "n".into(),
ty: self.option_tvar.ty,
ty: self.primitives.option_type_tvar.ty,
default_value: None,
}],
ret: self.primitives.option,
vars: into_var_map([self.option_tvar]),
vars: into_var_map([self.primitives.option_type_tvar]),
})),
var_id: vec![self.option_tvar.id],
var_id: vec![self.primitives.option_type_tvar.id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
@ -735,7 +724,10 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::NDArray => TopLevelDef::Class {
name: prim.name().into(),
object_id: prim.id(),
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
type_vars: vec![
self.primitives.ndarray_dtype_tvar.ty,
self.primitives.ndarray_ndims_tvar.ty,
],
fields: Vec::default(),
methods: vec![
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
@ -751,7 +743,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
var_id: vec![
self.primitives.ndarray_dtype_tvar.id,
self.primitives.ndarray_ndims_tvar.id,
],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
@ -768,7 +763,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
var_id: vec![
self.primitives.ndarray_dtype_tvar.id,
self.primitives.ndarray_ndims_tvar.id,
],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,

View File

@ -1,8 +1,8 @@
use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap};
use crate::toplevel::numpy::unpack_ndarray_params;
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVar, TypeVarId, VarMap};
use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
@ -286,6 +286,18 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef
]
}
pub fn make_option_type_tvar(unifier: &mut Unifier) -> TypeVar {
unifier.get_fresh_var(Some("option_type_var".into()), None)
}
pub fn make_ndarray_dtype_tvar(unifier: &mut Unifier) -> TypeVar {
unifier.get_fresh_var(Some("ndarray_dtype".into()), None)
}
pub fn make_ndarray_ndims_tvar(unifier: &mut Unifier, size_ty: Type) -> TypeVar {
unifier.get_fresh_const_generic_var(size_ty, Some("ndarray_ndims".into()), None)
}
impl TopLevelDef {
pub fn to_string(&self, unifier: &mut Unifier) -> String {
match self {
@ -381,16 +393,16 @@ impl TopLevelComposer {
params: VarMap::new(),
});
let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None);
let option_type_tvar = make_option_type_tvar(&mut unifier);
let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bool,
vars: into_var_map([option_type_var]),
vars: into_var_map([option_type_tvar]),
}));
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: option_type_var.ty,
vars: into_var_map([option_type_var]),
ret: option_type_tvar.ty,
vars: into_var_map([option_type_tvar]),
}));
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(),
@ -401,7 +413,7 @@ impl TopLevelComposer {
]
.into_iter()
.collect::<HashMap<_, _>>(),
params: into_var_map([option_type_var]),
params: into_var_map([option_type_tvar]),
});
let size_t_ty = match size_t {
@ -410,9 +422,8 @@ impl TopLevelComposer {
_ => unreachable!(),
};
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier);
let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, size_t_ty);
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
@ -451,7 +462,10 @@ impl TopLevelComposer {
str,
exception,
option,
option_type_tvar,
ndarray,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
size_t,
};
unifier.put_primitive_store(&primitives);
@ -881,22 +895,26 @@ pub fn parse_parameter_default_value(
}
/// Obtains the element type of an array-like type.
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
pub fn arraylike_flatten_element_type(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
) -> Type {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(unifier, ty).0
unpack_ndarray_params(unifier, store, ty).dtype
}
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, store, *ty),
_ => ty,
}
}
/// Obtains the number of dimensions of an array-like type.
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
pub fn arraylike_get_ndims(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) -> u64 {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
let ndims = unpack_ndarray_params(unifier, store, ty).ndims;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
};
@ -908,7 +926,7 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
u64::try_from(values[0].clone()).unwrap()
}
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1,
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, store, *ty) + 1,
_ => 0,
}
}

View File

@ -2,7 +2,7 @@ use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
use itertools::Itertools;
@ -57,29 +57,25 @@ pub fn subst_ndarray_tvars(
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
#[derive(Clone, Copy, Debug)]
pub struct NDArrayParams {
pub dtype: Type,
pub ndims: Type,
}
/// Extract the [`Type`]s of `ndarray`.
#[must_use]
pub fn unpack_ndarray_params(
unifier: &Unifier,
store: &PrimitiveStore,
ndarray: Type,
) -> NDArrayParams {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2);
params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
let dtype = *params.get(&store.ndarray_dtype_tvar.id).unwrap();
let ndims = *params.get(&store.ndarray_ndims_tvar.id).unwrap();
NDArrayParams { dtype, ndims }
}

View File

@ -1,6 +1,6 @@
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_params};
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
@ -369,16 +369,16 @@ pub fn typeof_ndarray_broadcast(
if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
let left_params = unpack_ndarray_params(unifier, primitives, left);
let right_params = unpack_ndarray_params(unifier, primitives, right);
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
assert!(unifier.unioned(left_params.dtype, right_params.dtype));
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
let left_ty_ndims = match &*unifier.get_ty_immutable(left_params.ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
let right_ty_ndims = match &*unifier.get_ty_immutable(right_params.ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
@ -397,11 +397,11 @@ pub fn typeof_ndarray_broadcast(
.collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
Ok(make_ndarray_ty(unifier, primitives, Some(left_params.dtype), Some(res_ndims)))
} else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
let ndarray_ty_dtype = unpack_ndarray_params(unifier, primitives, ndarray_ty).dtype;
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty)
@ -444,7 +444,7 @@ pub fn typeof_binop(
}
Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = unpack_ndarray_params(unifier, primitives, lhs).ndims;
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
@ -452,7 +452,7 @@ pub fn typeof_binop(
}
_ => unreachable!(),
};
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims = unpack_ndarray_params(unifier, primitives, rhs).ndims;
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
@ -552,7 +552,7 @@ pub fn typeof_unaryop(
Unaryop::UAdd | Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
let dtype = unpack_ndarray_params(unifier, primitives, operand).dtype;
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
@ -586,7 +586,7 @@ pub fn typeof_cmpop(
Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
let ndims = unpack_ndarray_params(unifier, primitives, brd).ndims;
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
} else if unifier.unioned(lhs, rhs) {
@ -653,8 +653,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t =
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
let ndarray_dtype_t = unpack_ndarray_params(unifier, store, ndarray_t).dtype;
let ndarray_unsized_dtype_t = unpack_ndarray_params(unifier, store, ndarray_unsized_t).dtype;
impl_basic_arithmetic(
unifier,
store,

View File

@ -4,13 +4,15 @@ use std::iter::once;
use std::ops::Not;
use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::typedef::{
Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, TypeVar, Unifier, VarMap,
};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
numpy::{make_ndarray_ty, unpack_ndarray_params},
TopLevelContext,
},
};
@ -49,7 +51,11 @@ pub struct PrimitiveStore {
pub str: Type,
pub exception: Type,
pub option: Type,
/// The contained type of an `Option`
pub option_type_tvar: TypeVar,
pub ndarray: Type,
pub ndarray_dtype_tvar: TypeVar,
pub ndarray_ndims_tvar: TypeVar,
pub size_t: u32,
}
@ -896,7 +902,8 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
let ndarray_ndims =
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).ndims;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
} else {
@ -934,9 +941,7 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
ndarray_dtype
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).dtype
} else {
arg0_ty
};
@ -988,14 +993,14 @@ impl<'a> Inferencer<'a> {
let arg0_dtype =
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg0_ty).0
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).dtype
} else {
arg0_ty
};
let arg1_dtype =
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg1_ty).0
unpack_ndarray_params(self.unifier, self.primitives, arg1_ty).dtype
} else {
arg1_ty
};
@ -1026,7 +1031,8 @@ impl<'a> Inferencer<'a> {
// (float, int32), so convert it to align with the dtype of the first arg
let arg1_ty = if id == &"np_ldexp".into() {
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
let ndims =
unpack_ndarray_params(self.unifier, self.primitives, arg1_ty).ndims;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
} else {
@ -1115,7 +1121,8 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
let ndarray_ndims =
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).ndims;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
} else {
@ -1258,7 +1265,8 @@ impl<'a> Inferencer<'a> {
let ndmin_kw =
keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into()));
let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap());
let ty =
arraylike_flatten_element_type(self.unifier, self.primitives, arg0.custom.unwrap());
let ndims = if let Some(ndmin_kw) = ndmin_kw {
match &ndmin_kw.node.value.node {
ExprKind::Constant { value, .. } => match value {
@ -1266,10 +1274,10 @@ impl<'a> Inferencer<'a> {
_ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])),
},
_ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()),
_ => arraylike_get_ndims(self.unifier, self.primitives, arg0.custom.unwrap()),
}
} else {
arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
arraylike_get_ndims(self.unifier, self.primitives, arg0.custom.unwrap())
};
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
@ -1666,8 +1674,12 @@ impl<'a> Inferencer<'a> {
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndims = unpack_ndarray_params(
self.unifier,
self.primitives,
value.custom.unwrap(),
)
.ndims;
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
}
@ -1680,8 +1692,13 @@ impl<'a> Inferencer<'a> {
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndims = unpack_ndarray_params(
self.unifier,
self.primitives,
value.custom.unwrap(),
)
.ndims;
self.infer_subscript_ndarray(value, ty, ndims)
}
_ => {
@ -1724,7 +1741,9 @@ impl<'a> Inferencer<'a> {
}
}
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndims =
unpack_ndarray_params(self.unifier, self.primitives, value.custom.unwrap())
.ndims;
let ndarray_ty =
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
@ -1751,8 +1770,12 @@ impl<'a> Inferencer<'a> {
Ok(ty)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndims = unpack_ndarray_params(
self.unifier,
self.primitives,
value.custom.unwrap(),
)
.ndims;
let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
.into_iter()

View File

@ -1,5 +1,8 @@
use super::super::{magic_methods::with_fields, typedef::*};
use super::*;
use crate::toplevel::helper::{
make_ndarray_dtype_tvar, make_ndarray_ndims_tvar, make_option_type_tvar,
};
use crate::{
codegen::CodeGenContext,
symbol_resolver::ValueEnum,
@ -132,14 +135,14 @@ impl TestEnvironment {
fields: HashMap::new(),
params: VarMap::new(),
});
let option_type_tvar = make_option_type_tvar(&mut unifier);
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier);
let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, uint64);
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
@ -157,7 +160,10 @@ impl TestEnvironment {
uint32,
uint64,
option,
option_type_tvar,
ndarray,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
size_t: 64,
};
unifier.put_primitive_store(&primitives);
@ -268,16 +274,22 @@ impl TestEnvironment {
fields: HashMap::new(),
params: VarMap::new(),
});
let option_type_tvar = make_option_type_tvar(&mut unifier);
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier);
let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, uint64);
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
identifier_mapping.insert("None".into(), none);
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
.iter()
@ -312,7 +324,10 @@ impl TestEnvironment {
uint32,
uint64,
option,
option_type_tvar,
ndarray,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
size_t: 64,
};