core: move top level def type vars into PrimitiveStore #418

Closed
lyken wants to merge 2 commits from refactor-primstore into master
13 changed files with 318 additions and 218 deletions

View File

@ -4,7 +4,7 @@ use nac3core::{
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{ toplevel::{
helper::PrimDef, helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, prim_types::{make_ndarray_ty, unpack_ndarray_params},
DefinitionId, TopLevelDef, DefinitionId, TopLevelDef,
}, },
typecheck::{ typecheck::{
@ -665,11 +665,11 @@ impl InnerResolver {
} }
} }
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => { (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); let params = unpack_ndarray_params(unifier, primitives, extracted_ty);
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
if len == 0 { if len == 0 {
assert!(matches!( assert!(matches!(
&*unifier.get_ty(ty), &*unifier.get_ty(params.dtype),
TypeEnum::TVar { fields: None, range, .. } TypeEnum::TVar { fields: None, range, .. }
if range.is_empty() if range.is_empty()
)); ));
@ -678,10 +678,14 @@ impl InnerResolver {
let actual_ty = let actual_ty =
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
match actual_ty { match actual_ty {
Ok(t) => match unifier.unify(ty, t) { Ok(t) => match unifier.unify(params.dtype, t) {
Ok(()) => { Ok(()) => {
let ndarray_ty = let ndarray_ty = make_ndarray_ty(
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims)); unifier,
primitives,
Some(params.dtype),
Some(params.ndims),
);
Ok(Ok(ndarray_ty)) Ok(Ok(ndarray_ty))
} }
@ -984,7 +988,7 @@ impl InnerResolver {
TypeEnum::TObj { obj_id, params, .. } TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
{ {
*params.iter().next().unwrap().1 *params.get(&ctx.primitives.option_type_tvar.id).unwrap()
} }
_ => unreachable!("must be option type"), _ => unreachable!("must be option type"),
}; };

View File

@ -8,7 +8,7 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::prim_types::unpack_ndarray_params;
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::Type;
/// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// Shorthand for [`unreachable!()`] when a type of argument is not supported.
@ -66,7 +66,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -128,7 +128,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -206,7 +206,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -273,7 +273,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -339,7 +339,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -385,7 +385,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -425,7 +425,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -490,7 +490,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -544,7 +544,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -594,7 +594,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -692,7 +692,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, a_ty).dtype;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -792,16 +792,17 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else { } else {
unreachable!() unreachable!()
}; };
@ -908,7 +909,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, a_ty).dtype;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -1008,16 +1009,18 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else { } else {
unreachable!() unreachable!()
}; };
@ -1088,7 +1091,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1128,7 +1131,7 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1172,7 +1175,7 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1216,7 +1219,7 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1256,7 +1259,7 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1296,7 +1299,7 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1336,7 +1339,7 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1376,7 +1379,7 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1416,7 +1419,7 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1456,7 +1459,7 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1496,7 +1499,7 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1536,7 +1539,7 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1576,7 +1579,7 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1616,7 +1619,7 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1656,7 +1659,7 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1696,7 +1699,7 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1736,7 +1739,7 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1776,7 +1779,7 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1816,7 +1819,7 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1856,7 +1859,7 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1896,7 +1899,7 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1936,7 +1939,7 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -1976,7 +1979,7 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -2016,7 +2019,7 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -2056,7 +2059,7 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -2096,7 +2099,7 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(z) BasicValueEnum::PointerValue(z)
if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, z_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -2136,7 +2139,7 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -2176,7 +2179,7 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(z) BasicValueEnum::PointerValue(z)
if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, z_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -2216,7 +2219,7 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -2256,7 +2259,7 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -2296,7 +2299,7 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(x) BasicValueEnum::PointerValue(x)
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -2345,16 +2348,18 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else { } else {
unreachable!() unreachable!()
}; };
@ -2412,16 +2417,18 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else { } else {
unreachable!() unreachable!()
}; };
@ -2479,16 +2486,18 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else { } else {
unreachable!() unreachable!()
}; };
@ -2546,16 +2555,18 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else { } else {
unreachable!() unreachable!()
}; };
@ -2612,12 +2623,18 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
let is_ndarray2 = let is_ndarray2 =
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = let dtype = if is_ndarray1 {
if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else {
x1_ty
};
let x1_scalar_ty = dtype; let x1_scalar_ty = dtype;
let x2_scalar_ty = let x2_scalar_ty = if is_ndarray2 {
if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else {
x2_ty
};
numpy::ndarray_elementwise_binop_impl( numpy::ndarray_elementwise_binop_impl(
generator, generator,
@ -2669,16 +2686,18 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else { } else {
unreachable!() unreachable!()
}; };
@ -2736,16 +2755,18 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
} else { } else {
unreachable!() unreachable!()
}; };

View File

@ -17,7 +17,7 @@ use crate::{
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{ toplevel::{
helper::PrimDef, helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, prim_types::{make_ndarray_ty, unpack_ndarray_params},
DefinitionId, TopLevelDef, DefinitionId, TopLevelDef,
}, },
typecheck::{ typecheck::{
@ -150,7 +150,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
TypeEnum::TObj { obj_id, params, .. } TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{ {
*params.iter().next().unwrap().1 *params.get(&self.primitives.option_type_tvar.id).unwrap()
} }
_ => unreachable!("must be option type"), _ => unreachable!("must be option type"),
}; };
@ -166,7 +166,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
TypeEnum::TObj { obj_id, params, .. } TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() => if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{ {
*params.iter().next().unwrap().1 *params.get(&self.primitives.option_type_tvar.id).unwrap()
} }
_ => unreachable!("must be option type"), _ => unreachable!("must be option type"),
}; };
@ -188,6 +188,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
&self.module, &self.module,
generator, generator,
&mut self.unifier, &mut self.unifier,
&self.primitives,
self.top_level, self.top_level,
&mut self.type_cache, &mut self.type_cache,
ty, ty,
@ -205,6 +206,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
&self.module, &self.module,
generator, generator,
&mut self.unifier, &mut self.unifier,
&self.primitives,
self.top_level, self.top_level,
&mut self.type_cache, &mut self.type_cache,
&self.primitives, &self.primitives,
@ -1190,8 +1192,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
if is_ndarray1 && is_ndarray2 { if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); let ndarray_dtype1 = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty1).dtype;
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); let ndarray_dtype2 = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty2).dtype;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1240,8 +1242,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into())) Ok(Some(res.as_base_value().into()))
} else { } else {
let (ndarray_dtype, _) = let ndarray_dtype = unpack_ndarray_params(
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); &ctx.unifier,
&ctx.primitives,
if is_ndarray1 { ty1 } else { ty2 },
)
.dtype;
let ndarray_val = NDArrayValue::from_ptr_val( let ndarray_val = NDArrayValue::from_ptr_val(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_usize, llvm_usize,
@ -1427,7 +1433,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
} }
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray_dtype = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty).dtype;
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
@ -1511,8 +1517,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
return if is_ndarray1 && is_ndarray2 { return if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); let ndarray_dtype1 =
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); unpack_ndarray_params(&ctx.unifier, &ctx.primitives, left_ty).dtype;
let ndarray_dtype2 =
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, right_ty).dtype;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1546,10 +1554,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into())) Ok(Some(res.as_base_value().into()))
} else { } else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys( let ndarray_dtype = unpack_ndarray_params(
&mut ctx.unifier, &ctx.unifier,
&ctx.primitives,
if is_ndarray1 { left_ty } else { right_ty }, if is_ndarray1 { left_ty } else { right_ty },
); )
.dtype;
let res = numpy::ndarray_elementwise_binop_impl( let res = numpy::ndarray_elementwise_binop_impl(
generator, generator,
ctx, ctx,
@ -2014,7 +2024,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. }) (TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
if *obj_id == *opt_id => if *obj_id == *opt_id =>
{ {
ctx.get_llvm_type(generator, *params.iter().next().unwrap().1) ctx.get_llvm_type(
generator,
*params.get(&ctx.primitives.option_type_tvar.id).unwrap(),
)
.ptr_type(AddressSpace::default()) .ptr_type(AddressSpace::default())
.const_null() .const_null()
.into() .into()

View File

@ -1,7 +1,7 @@
use crate::{ use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver}, symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, toplevel::{helper::PrimDef, prim_types::unpack_ndarray_params, TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -423,6 +423,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &mut G, generator: &mut G,
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
ty: Type, ty: Type,
@ -443,18 +444,20 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
module, module,
generator, generator,
unifier, unifier,
store,
top_level, top_level,
type_cache, type_cache,
*params.iter().next().unwrap().1, *params.get(&store.option_type_tvar.id).unwrap(),
) )
.ptr_type(AddressSpace::default()) .ptr_type(AddressSpace::default())
.into() .into()
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let dtype = unpack_ndarray_params(unifier, store, ty).dtype;
let element_type = get_llvm_type( let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype, ctx, module, generator, unifier, store, top_level, type_cache,
dtype,
); );
NDArrayType::new(generator, ctx, element_type).as_base_type().into() NDArrayType::new(generator, ctx, element_type).as_base_type().into()
@ -490,6 +493,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
module, module,
generator, generator,
unifier, unifier,
store,
top_level, top_level,
type_cache, type_cache,
fields[&f.0].0, fields[&f.0].0,
@ -506,14 +510,17 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let fields = ty let fields = ty
.iter() .iter()
.map(|ty| { .map(|ty| {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) get_llvm_type(
ctx, module, generator, unifier, store, top_level, type_cache, *ty,
)
}) })
.collect_vec(); .collect_vec();
ctx.struct_type(&fields, false).into() ctx.struct_type(&fields, false).into()
} }
TList { ty } => { TList { ty } => {
let element_type = let element_type = get_llvm_type(
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty); ctx, module, generator, unifier, store, top_level, type_cache, *ty,
);
ListType::new(generator, ctx, element_type).as_base_type().into() ListType::new(generator, ctx, element_type).as_base_type().into()
} }
@ -540,6 +547,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &mut G, generator: &mut G,
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
@ -550,7 +558,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
return if unifier.unioned(ty, primitives.bool) { return if unifier.unioned(ty, primitives.bool) {
ctx.bool_type().into() ctx.bool_type().into()
} else { } else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty) get_llvm_type(ctx, module, generator, unifier, store, top_level, type_cache, ty)
}; };
} }
@ -699,6 +707,7 @@ pub fn gen_func_impl<
&module, &module,
generator, generator,
&mut unifier, &mut unifier,
&primitives,
top_level_ctx.as_ref(), top_level_ctx.as_ref(),
&mut type_cache, &mut type_cache,
&primitives, &primitives,
@ -715,6 +724,7 @@ pub fn gen_func_impl<
&module, &module,
generator, generator,
&mut unifier, &mut unifier,
&primitives,
top_level_ctx.as_ref(), top_level_ctx.as_ref(),
&mut type_cache, &mut type_cache,
&primitives, &primitives,
@ -767,6 +777,7 @@ pub fn gen_func_impl<
&module, &module,
generator, generator,
&mut unifier, &mut unifier,
&primitives,
top_level_ctx.as_ref(), top_level_ctx.as_ref(),
&mut type_cache, &mut type_cache,
arg.ty, arg.ty,

View File

@ -19,7 +19,7 @@ use crate::{
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{
helper::PrimDef, helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, prim_types::{make_ndarray_ty, unpack_ndarray_params},
DefinitionId, DefinitionId,
}, },
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::typedef::{FunSignature, Type, TypeEnum},
@ -1776,7 +1776,7 @@ pub fn gen_ndarray_array<'ctx>(
let obj_ty = fun.0.args[0].ty; let obj_ty = fun.0.args[0].ty;
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 unpack_ndarray_params(&context.unifier, &context.primitives, obj_ty).dtype
} }
TypeEnum::TList { ty } => { TypeEnum::TList { ty } => {
@ -1916,7 +1916,7 @@ pub fn gen_ndarray_copy<'ctx>(
let llvm_usize = generator.get_size_type(context.ctx); let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_elem_ty = unpack_ndarray_params(&context.unifier, &context.primitives, this_ty).dtype;
let this_arg = let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;

View File

@ -10,7 +10,7 @@ use crate::{
expr::gen_binop_expr, expr::gen_binop_expr,
gen_in_range_check, gen_in_range_check,
}, },
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, toplevel::{helper::PrimDef, prim_types::unpack_ndarray_params, DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::typedef::{FunSignature, Type, TypeEnum},
}; };
use inkwell::{ use inkwell::{
@ -245,7 +245,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty, TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 unpack_ndarray_params(&ctx.unifier, &ctx.primitives, target.custom.unwrap())
.dtype
} }
_ => unreachable!(), _ => unreachable!(),
}; };

View File

@ -24,8 +24,8 @@ use crate::{
stmt::exn_constructor, stmt::exn_constructor,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, toplevel::{helper::PrimDef, prim_types::make_ndarray_ty},
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, typecheck::typedef::{into_var_map, TypeVar, VarMap},
}; };
use super::*; use super::*;
@ -301,10 +301,7 @@ struct BuiltinBuilder<'a> {
is_some_ty: (Type, bool), is_some_ty: (Type, bool),
unwrap_ty: (Type, bool), unwrap_ty: (Type, bool),
option_tvar: TypeVar,
ndarray_dtype_tvar: TypeVar,
ndarray_ndims_tvar: TypeVar,
ndarray_copy_ty: (Type, bool), ndarray_copy_ty: (Type, bool),
ndarray_fill_ty: (Type, bool), ndarray_fill_ty: (Type, bool),
@ -339,24 +336,19 @@ impl<'a> BuiltinBuilder<'a> {
} = *primitives; } = *primitives;
// Option-related // Option-related
let (is_some_ty, unwrap_ty, option_tvar) = let (is_some_ty, unwrap_ty) =
if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { if let TypeEnum::TObj { fields, .. } = unifier.get_ty(option).as_ref() {
( (
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
iter_type_vars(params).next().unwrap(),
) )
} else { } else {
unreachable!() unreachable!()
}; };
let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } = let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray) else {
&*unifier.get_ty(ndarray)
else {
unreachable!() unreachable!()
}; };
let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
let ndarray_copy_ty = let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty = let ndarray_fill_ty =
@ -398,10 +390,7 @@ impl<'a> BuiltinBuilder<'a> {
is_some_ty, is_some_ty,
unwrap_ty, unwrap_ty,
option_tvar,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
ndarray_copy_ty, ndarray_copy_ty,
ndarray_fill_ty, ndarray_fill_ty,
@ -622,7 +611,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Option => TopLevelDef::Class { PrimDef::Option => TopLevelDef::Class {
name: prim.name().into(), name: prim.name().into(),
object_id: prim.id(), object_id: prim.id(),
type_vars: vec![self.option_tvar.ty], type_vars: vec![self.primitives.option_type_tvar.ty],
fields: vec![], fields: vec![],
methods: vec![ methods: vec![
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0), Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
@ -642,7 +631,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0, signature: self.unwrap_ty.0,
var_id: vec![self.option_tvar.id], var_id: vec![self.primitives.option_type_tvar.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -656,7 +645,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().to_string(), name: prim.name().to_string(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0, signature: self.is_some_ty.0,
var_id: vec![self.option_tvar.id], var_id: vec![self.primitives.option_type_tvar.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -693,13 +682,13 @@ impl<'a> BuiltinBuilder<'a> {
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg {
name: "n".into(), name: "n".into(),
ty: self.option_tvar.ty, ty: self.primitives.option_type_tvar.ty,
default_value: None, default_value: None,
}], }],
ret: self.primitives.option, ret: self.primitives.option,
vars: into_var_map([self.option_tvar]), vars: into_var_map([self.primitives.option_type_tvar]),
})), })),
var_id: vec![self.option_tvar.id], var_id: vec![self.primitives.option_type_tvar.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -735,7 +724,10 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::NDArray => TopLevelDef::Class { PrimDef::NDArray => TopLevelDef::Class {
name: prim.name().into(), name: prim.name().into(),
object_id: prim.id(), object_id: prim.id(),
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty], type_vars: vec![
self.primitives.ndarray_dtype_tvar.ty,
self.primitives.ndarray_ndims_tvar.ty,
],
fields: Vec::default(), fields: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0), Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
@ -751,7 +743,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0, signature: self.ndarray_copy_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], var_id: vec![
self.primitives.ndarray_dtype_tvar.id,
self.primitives.ndarray_ndims_tvar.id,
],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -768,7 +763,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0, signature: self.ndarray_fill_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], var_id: vec![
self.primitives.ndarray_dtype_tvar.id,
self.primitives.ndarray_ndims_tvar.id,
],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,

View File

@ -1,8 +1,8 @@
use std::convert::TryInto; use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::prim_types::unpack_ndarray_params;
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap}; use crate::typecheck::typedef::{into_var_map, Mapping, TypeVar, TypeVarId, VarMap};
use nac3parser::ast::{Constant, Location}; use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use strum_macros::EnumIter; use strum_macros::EnumIter;
@ -286,6 +286,18 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef
] ]
} }
pub fn make_option_type_tvar(unifier: &mut Unifier) -> TypeVar {
unifier.get_fresh_var(Some("option_type_var".into()), None)
}
pub fn make_ndarray_dtype_tvar(unifier: &mut Unifier) -> TypeVar {
unifier.get_fresh_var(Some("ndarray_dtype".into()), None)
}
pub fn make_ndarray_ndims_tvar(unifier: &mut Unifier, size_ty: Type) -> TypeVar {
unifier.get_fresh_const_generic_var(size_ty, Some("ndarray_ndims".into()), None)
}
impl TopLevelDef { impl TopLevelDef {
pub fn to_string(&self, unifier: &mut Unifier) -> String { pub fn to_string(&self, unifier: &mut Unifier) -> String {
match self { match self {
@ -381,16 +393,16 @@ impl TopLevelComposer {
params: VarMap::new(), params: VarMap::new(),
}); });
let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None); let option_type_tvar = make_option_type_tvar(&mut unifier);
let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: bool, ret: bool,
vars: into_var_map([option_type_var]), vars: into_var_map([option_type_tvar]),
})); }));
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: option_type_var.ty, ret: option_type_tvar.ty,
vars: into_var_map([option_type_var]), vars: into_var_map([option_type_tvar]),
})); }));
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
@ -401,7 +413,7 @@ impl TopLevelComposer {
] ]
.into_iter() .into_iter()
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
params: into_var_map([option_type_var]), params: into_var_map([option_type_tvar]),
}); });
let size_t_ty = match size_t { let size_t_ty = match size_t {
@ -410,9 +422,8 @@ impl TopLevelComposer {
_ => unreachable!(), _ => unreachable!(),
}; };
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier);
let ndarray_ndims_tvar = let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, size_t_ty);
unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None); let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
@ -451,7 +462,10 @@ impl TopLevelComposer {
str, str,
exception, exception,
option, option,
option_type_tvar,
ndarray, ndarray,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
size_t, size_t,
}; };
unifier.put_primitive_store(&primitives); unifier.put_primitive_store(&primitives);
@ -881,22 +895,26 @@ pub fn parse_parameter_default_value(
} }
/// Obtains the element type of an array-like type. /// Obtains the element type of an array-like type.
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { pub fn arraylike_flatten_element_type(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
) -> Type {
match &*unifier.get_ty(ty) { match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(unifier, ty).0 unpack_ndarray_params(unifier, store, ty).dtype
} }
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, store, *ty),
_ => ty, _ => ty,
} }
} }
/// Obtains the number of dimensions of an array-like type. /// Obtains the number of dimensions of an array-like type.
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { pub fn arraylike_get_ndims(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) -> u64 {
match &*unifier.get_ty(ty) { match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1; let ndims = unpack_ndarray_params(unifier, store, ty).ndims;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
}; };
@ -908,7 +926,7 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
u64::try_from(values[0].clone()).unwrap() u64::try_from(values[0].clone()).unwrap()
} }
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1, TypeEnum::TList { ty } => arraylike_get_ndims(unifier, store, *ty) + 1,
_ => 0, _ => 0,
} }
} }

View File

@ -30,7 +30,7 @@ pub struct DefinitionId(pub usize);
pub mod builtins; pub mod builtins;
pub mod composer; pub mod composer;
pub mod helper; pub mod helper;
pub mod numpy; pub mod prim_types;
pub mod type_annotation; pub mod type_annotation;
use composer::*; use composer::*;
use type_annotation::*; use type_annotation::*;

View File

@ -2,7 +2,7 @@ use crate::{
toplevel::helper::PrimDef, toplevel::helper::PrimDef,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap}, typedef::{Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use itertools::Itertools; use itertools::Itertools;
@ -57,29 +57,25 @@ pub fn subst_ndarray_tvars(
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
} }
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> { #[derive(Clone, Copy, Debug)]
pub struct NDArrayParams {
pub dtype: Type,
pub ndims: Type,
}
/// Extract the [`Type`]s of `ndarray`.
#[must_use]
pub fn unpack_ndarray_params(
unifier: &Unifier,
store: &PrimitiveStore,
ndarray: Type,
) -> NDArrayParams {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
}; };
debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2); debug_assert_eq!(params.len(), 2);
let dtype = *params.get(&store.ndarray_dtype_tvar.id).unwrap();
params let ndims = *params.get(&store.ndarray_ndims_tvar.id).unwrap();
.iter() NDArrayParams { dtype, ndims }
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
} }

View File

@ -1,6 +1,6 @@
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::toplevel::prim_types::{make_ndarray_ty, unpack_ndarray_params};
use crate::typecheck::{ use crate::typecheck::{
type_inferencer::*, type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
@ -369,16 +369,16 @@ pub fn typeof_ndarray_broadcast(
if is_left_ndarray && is_right_ndarray { if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands. // Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left); let left_params = unpack_ndarray_params(unifier, primitives, left);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right); let right_params = unpack_ndarray_params(unifier, primitives, right);
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype)); assert!(unifier.unioned(left_params.dtype, right_params.dtype));
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) { let left_ty_ndims = match &*unifier.get_ty_immutable(left_params.ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(), TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(), _ => unreachable!(),
}; };
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) { let right_ty_ndims = match &*unifier.get_ty_immutable(right_params.ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(), TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(), _ => unreachable!(),
}; };
@ -397,11 +397,11 @@ pub fn typeof_ndarray_broadcast(
.collect_vec(); .collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None); let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims))) Ok(make_ndarray_ty(unifier, primitives, Some(left_params.dtype), Some(res_ndims)))
} else { } else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) }; let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty); let ndarray_ty_dtype = unpack_ndarray_params(unifier, primitives, ndarray_ty).dtype;
if unifier.unioned(ndarray_ty_dtype, scalar_ty) { if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty) Ok(ndarray_ty)
@ -444,7 +444,7 @@ pub fn typeof_binop(
} }
Operator::MatMult => { Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); let lhs_ndims = unpack_ndarray_params(unifier, primitives, lhs).ndims;
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1); assert_eq!(values.len(), 1);
@ -452,7 +452,7 @@ pub fn typeof_binop(
} }
_ => unreachable!(), _ => unreachable!(),
}; };
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); let rhs_ndims = unpack_ndarray_params(unifier, primitives, rhs).ndims;
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1); assert_eq!(values.len(), 1);
@ -552,7 +552,7 @@ pub fn typeof_unaryop(
Unaryop::UAdd | Unaryop::USub => { Unaryop::UAdd | Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) { if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); let dtype = unpack_ndarray_params(unifier, primitives, operand).dtype;
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) { if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if op == Unaryop::UAdd { return Err(if op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
@ -586,7 +586,7 @@ pub fn typeof_cmpop(
Ok(Some(if is_left_ndarray || is_right_ndarray { Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd); let ndims = unpack_ndarray_params(unifier, primitives, brd).ndims;
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims)) make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
} else if unifier.unioned(lhs, rhs) { } else if unifier.unioned(lhs, rhs) {
@ -653,8 +653,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t = let ndarray_unsized_t =
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty)); make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); let ndarray_dtype_t = unpack_ndarray_params(unifier, store, ndarray_t).dtype;
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); let ndarray_unsized_dtype_t = unpack_ndarray_params(unifier, store, ndarray_unsized_t).dtype;
impl_basic_arithmetic( impl_basic_arithmetic(
unifier, unifier,
store, store,

View File

@ -4,13 +4,15 @@ use std::iter::once;
use std::ops::Not; use std::ops::Not;
use std::{cell::RefCell, sync::Arc}; use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::typedef::{
Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, TypeVar, Unifier, VarMap,
};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::{ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue}, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{ toplevel::{
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}, helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, prim_types::{make_ndarray_ty, unpack_ndarray_params},
TopLevelContext, TopLevelContext,
}, },
}; };
@ -49,7 +51,11 @@ pub struct PrimitiveStore {
pub str: Type, pub str: Type,
pub exception: Type, pub exception: Type,
pub option: Type, pub option: Type,
/// The contained type of an `Option`
pub option_type_tvar: TypeVar,
pub ndarray: Type, pub ndarray: Type,
pub ndarray_dtype_tvar: TypeVar,
pub ndarray_ndims_tvar: TypeVar,
pub size_t: u32, pub size_t: u32,
} }
@ -896,7 +902,8 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); let ndarray_ndims =
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).ndims;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
} else { } else {
@ -934,9 +941,7 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).dtype
ndarray_dtype
} else { } else {
arg0_ty arg0_ty
}; };
@ -988,14 +993,14 @@ impl<'a> Inferencer<'a> {
let arg0_dtype = let arg0_dtype =
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg0_ty).0 unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).dtype
} else { } else {
arg0_ty arg0_ty
}; };
let arg1_dtype = let arg1_dtype =
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg1_ty).0 unpack_ndarray_params(self.unifier, self.primitives, arg1_ty).dtype
} else { } else {
arg1_ty arg1_ty
}; };
@ -1026,7 +1031,8 @@ impl<'a> Inferencer<'a> {
// (float, int32), so convert it to align with the dtype of the first arg // (float, int32), so convert it to align with the dtype of the first arg
let arg1_ty = if id == &"np_ldexp".into() { let arg1_ty = if id == &"np_ldexp".into() {
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); let ndims =
unpack_ndarray_params(self.unifier, self.primitives, arg1_ty).ndims;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims)) make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
} else { } else {
@ -1115,7 +1121,8 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); let ndarray_ndims =
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).ndims;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
} else { } else {
@ -1258,7 +1265,8 @@ impl<'a> Inferencer<'a> {
let ndmin_kw = let ndmin_kw =
keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into())); keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into()));
let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap()); let ty =
arraylike_flatten_element_type(self.unifier, self.primitives, arg0.custom.unwrap());
let ndims = if let Some(ndmin_kw) = ndmin_kw { let ndims = if let Some(ndmin_kw) = ndmin_kw {
match &ndmin_kw.node.value.node { match &ndmin_kw.node.value.node {
ExprKind::Constant { value, .. } => match value { ExprKind::Constant { value, .. } => match value {
@ -1266,10 +1274,10 @@ impl<'a> Inferencer<'a> {
_ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])), _ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])),
}, },
_ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()), _ => arraylike_get_ndims(self.unifier, self.primitives, arg0.custom.unwrap()),
} }
} else { } else {
arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) arraylike_get_ndims(self.unifier, self.primitives, arg0.custom.unwrap())
}; };
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
@ -1666,8 +1674,12 @@ impl<'a> Inferencer<'a> {
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let ndims = unpack_ndarray_params(
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); self.unifier,
self.primitives,
value.custom.unwrap(),
)
.ndims;
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
} }
@ -1680,8 +1692,13 @@ impl<'a> Inferencer<'a> {
ExprKind::Constant { value: ast::Constant::Int(val), .. } => { ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
match &*self.unifier.get_ty(value.custom.unwrap()) { match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let ndims = unpack_ndarray_params(
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); self.unifier,
self.primitives,
value.custom.unwrap(),
)
.ndims;
self.infer_subscript_ndarray(value, ty, ndims) self.infer_subscript_ndarray(value, ty, ndims)
} }
_ => { _ => {
@ -1724,7 +1741,9 @@ impl<'a> Inferencer<'a> {
} }
} }
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); let ndims =
unpack_ndarray_params(self.unifier, self.primitives, value.custom.unwrap())
.ndims;
let ndarray_ty = let ndarray_ty =
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?; self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
@ -1751,8 +1770,12 @@ impl<'a> Inferencer<'a> {
Ok(ty) Ok(ty)
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let ndims = unpack_ndarray_params(
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); self.unifier,
self.primitives,
value.custom.unwrap(),
)
.ndims;
let valid_index_tys = [self.primitives.int32, self.primitives.isize()] let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
.into_iter() .into_iter()

View File

@ -1,5 +1,8 @@
use super::super::{magic_methods::with_fields, typedef::*}; use super::super::{magic_methods::with_fields, typedef::*};
use super::*; use super::*;
use crate::toplevel::helper::{
make_ndarray_dtype_tvar, make_ndarray_ndims_tvar, make_option_type_tvar,
};
use crate::{ use crate::{
codegen::CodeGenContext, codegen::CodeGenContext,
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
@ -132,14 +135,14 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let option_type_tvar = make_option_type_tvar(&mut unifier);
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier);
let ndarray_ndims_tvar = let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, uint64);
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(), fields: HashMap::new(),
@ -157,7 +160,10 @@ impl TestEnvironment {
uint32, uint32,
uint64, uint64,
option, option,
option_type_tvar,
ndarray, ndarray,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
size_t: 64, size_t: 64,
}; };
unifier.put_primitive_store(&primitives); unifier.put_primitive_store(&primitives);
@ -268,16 +274,22 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let option_type_tvar = make_option_type_tvar(&mut unifier);
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier);
let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, uint64);
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
.iter() .iter()
@ -312,7 +324,10 @@ impl TestEnvironment {
uint32, uint32,
uint64, uint64,
option, option,
option_type_tvar,
ndarray, ndarray,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
size_t: 64, size_t: 64,
}; };