core: move top level def type vars into PrimitiveStore
#418
|
@ -4,7 +4,7 @@ use nac3core::{
|
||||||
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PrimDef,
|
helper::PrimDef,
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
prim_types::{make_ndarray_ty, unpack_ndarray_params},
|
||||||
DefinitionId, TopLevelDef,
|
DefinitionId, TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
|
@ -665,11 +665,11 @@ impl InnerResolver {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
|
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
|
let params = unpack_ndarray_params(unifier, primitives, extracted_ty);
|
||||||
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
&*unifier.get_ty(ty),
|
&*unifier.get_ty(params.dtype),
|
||||||
TypeEnum::TVar { fields: None, range, .. }
|
TypeEnum::TVar { fields: None, range, .. }
|
||||||
if range.is_empty()
|
if range.is_empty()
|
||||||
));
|
));
|
||||||
|
@ -678,10 +678,14 @@ impl InnerResolver {
|
||||||
let actual_ty =
|
let actual_ty =
|
||||||
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
||||||
match actual_ty {
|
match actual_ty {
|
||||||
Ok(t) => match unifier.unify(ty, t) {
|
Ok(t) => match unifier.unify(params.dtype, t) {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
let ndarray_ty =
|
let ndarray_ty = make_ndarray_ty(
|
||||||
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
|
unifier,
|
||||||
|
primitives,
|
||||||
|
Some(params.dtype),
|
||||||
|
Some(params.ndims),
|
||||||
|
);
|
||||||
|
|
||||||
Ok(Ok(ndarray_ty))
|
Ok(Ok(ndarray_ty))
|
||||||
}
|
}
|
||||||
|
@ -984,7 +988,7 @@ impl InnerResolver {
|
||||||
TypeEnum::TObj { obj_id, params, .. }
|
TypeEnum::TObj { obj_id, params, .. }
|
||||||
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
*params.iter().next().unwrap().1
|
*params.get(&ctx.primitives.option_type_tvar.id).unwrap()
|
||||||
}
|
}
|
||||||
_ => unreachable!("must be option type"),
|
_ => unreachable!("must be option type"),
|
||||||
};
|
};
|
||||||
|
|
|
@ -8,7 +8,7 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||||
use crate::toplevel::helper::PrimDef;
|
use crate::toplevel::helper::PrimDef;
|
||||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
use crate::toplevel::prim_types::unpack_ndarray_params;
|
||||||
use crate::typecheck::typedef::Type;
|
use crate::typecheck::typedef::Type;
|
||||||
|
|
||||||
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
||||||
|
@ -66,7 +66,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -128,7 +128,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -206,7 +206,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -273,7 +273,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -339,7 +339,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -385,7 +385,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -425,7 +425,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -490,7 +490,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -544,7 +544,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -594,7 +594,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -692,7 +692,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, a_ty).dtype;
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
@ -792,16 +792,17 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -908,7 +909,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, a_ty).dtype;
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
|
@ -1008,16 +1009,18 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -1088,7 +1091,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, n_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1128,7 +1131,7 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1172,7 +1175,7 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1216,7 +1219,7 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1256,7 +1259,7 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1296,7 +1299,7 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1336,7 +1339,7 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1376,7 +1379,7 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1416,7 +1419,7 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1456,7 +1459,7 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1496,7 +1499,7 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1536,7 +1539,7 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1576,7 +1579,7 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1616,7 +1619,7 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1656,7 +1659,7 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1696,7 +1699,7 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1736,7 +1739,7 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1776,7 +1779,7 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1816,7 +1819,7 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1856,7 +1859,7 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1896,7 +1899,7 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1936,7 +1939,7 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -1976,7 +1979,7 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2016,7 +2019,7 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2056,7 +2059,7 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2096,7 +2099,7 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(z)
|
BasicValueEnum::PointerValue(z)
|
||||||
if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, z_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2136,7 +2139,7 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2176,7 +2179,7 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(z)
|
BasicValueEnum::PointerValue(z)
|
||||||
if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, z_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2216,7 +2219,7 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2256,7 +2259,7 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2296,7 +2299,7 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
BasicValueEnum::PointerValue(x)
|
BasicValueEnum::PointerValue(x)
|
||||||
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
let elem_ty = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x_ty).dtype;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2345,16 +2348,18 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -2412,16 +2417,18 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -2479,16 +2486,18 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -2546,16 +2555,18 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -2612,12 +2623,18 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let is_ndarray2 =
|
let is_ndarray2 =
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype =
|
let dtype = if is_ndarray1 {
|
||||||
if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty };
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
|
||||||
|
} else {
|
||||||
|
x1_ty
|
||||||
|
};
|
||||||
|
|
||||||
let x1_scalar_ty = dtype;
|
let x1_scalar_ty = dtype;
|
||||||
let x2_scalar_ty =
|
let x2_scalar_ty = if is_ndarray2 {
|
||||||
if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty };
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
|
||||||
|
} else {
|
||||||
|
x2_ty
|
||||||
|
};
|
||||||
|
|
||||||
numpy::ndarray_elementwise_binop_impl(
|
numpy::ndarray_elementwise_binop_impl(
|
||||||
generator,
|
generator,
|
||||||
|
@ -2669,16 +2686,18 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -2736,16 +2755,18 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x1_ty).dtype
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, x2_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
|
|
@ -17,7 +17,7 @@ use crate::{
|
||||||
symbol_resolver::{SymbolValue, ValueEnum},
|
symbol_resolver::{SymbolValue, ValueEnum},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PrimDef,
|
helper::PrimDef,
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
prim_types::{make_ndarray_ty, unpack_ndarray_params},
|
||||||
DefinitionId, TopLevelDef,
|
DefinitionId, TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
|
@ -150,7 +150,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
TypeEnum::TObj { obj_id, params, .. }
|
TypeEnum::TObj { obj_id, params, .. }
|
||||||
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
|
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
*params.iter().next().unwrap().1
|
*params.get(&self.primitives.option_type_tvar.id).unwrap()
|
||||||
}
|
}
|
||||||
_ => unreachable!("must be option type"),
|
_ => unreachable!("must be option type"),
|
||||||
};
|
};
|
||||||
|
@ -166,7 +166,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
TypeEnum::TObj { obj_id, params, .. }
|
TypeEnum::TObj { obj_id, params, .. }
|
||||||
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
|
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
*params.iter().next().unwrap().1
|
*params.get(&self.primitives.option_type_tvar.id).unwrap()
|
||||||
}
|
}
|
||||||
_ => unreachable!("must be option type"),
|
_ => unreachable!("must be option type"),
|
||||||
};
|
};
|
||||||
|
@ -188,6 +188,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
&self.module,
|
&self.module,
|
||||||
generator,
|
generator,
|
||||||
&mut self.unifier,
|
&mut self.unifier,
|
||||||
|
&self.primitives,
|
||||||
self.top_level,
|
self.top_level,
|
||||||
&mut self.type_cache,
|
&mut self.type_cache,
|
||||||
ty,
|
ty,
|
||||||
|
@ -205,6 +206,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
&self.module,
|
&self.module,
|
||||||
generator,
|
generator,
|
||||||
&mut self.unifier,
|
&mut self.unifier,
|
||||||
|
&self.primitives,
|
||||||
self.top_level,
|
self.top_level,
|
||||||
&mut self.type_cache,
|
&mut self.type_cache,
|
||||||
&self.primitives,
|
&self.primitives,
|
||||||
|
@ -1190,8 +1192,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
if is_ndarray1 && is_ndarray2 {
|
if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
let ndarray_dtype1 = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty1).dtype;
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
|
let ndarray_dtype2 = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty2).dtype;
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
|
@ -1240,8 +1242,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
Ok(Some(res.as_base_value().into()))
|
||||||
} else {
|
} else {
|
||||||
let (ndarray_dtype, _) =
|
let ndarray_dtype = unpack_ndarray_params(
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
|
&ctx.unifier,
|
||||||
|
&ctx.primitives,
|
||||||
|
if is_ndarray1 { ty1 } else { ty2 },
|
||||||
|
)
|
||||||
|
.dtype;
|
||||||
let ndarray_val = NDArrayValue::from_ptr_val(
|
let ndarray_val = NDArrayValue::from_ptr_val(
|
||||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
|
@ -1427,7 +1433,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let ndarray_dtype = unpack_ndarray_params(&ctx.unifier, &ctx.primitives, ty).dtype;
|
||||||
|
|
||||||
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
|
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
|
||||||
|
|
||||||
|
@ -1511,8 +1517,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
return if is_ndarray1 && is_ndarray2 {
|
return if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
let ndarray_dtype1 =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, left_ty).dtype;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, right_ty).dtype;
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
|
@ -1546,10 +1554,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
Ok(Some(res.as_base_value().into()))
|
||||||
} else {
|
} else {
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
let ndarray_dtype = unpack_ndarray_params(
|
||||||
&mut ctx.unifier,
|
&ctx.unifier,
|
||||||
|
&ctx.primitives,
|
||||||
if is_ndarray1 { left_ty } else { right_ty },
|
if is_ndarray1 { left_ty } else { right_ty },
|
||||||
);
|
)
|
||||||
|
.dtype;
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -2014,7 +2024,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
|
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
|
||||||
if *obj_id == *opt_id =>
|
if *obj_id == *opt_id =>
|
||||||
{
|
{
|
||||||
ctx.get_llvm_type(generator, *params.iter().next().unwrap().1)
|
ctx.get_llvm_type(
|
||||||
|
generator,
|
||||||
|
*params.get(&ctx.primitives.option_type_tvar.id).unwrap(),
|
||||||
|
)
|
||||||
.ptr_type(AddressSpace::default())
|
.ptr_type(AddressSpace::default())
|
||||||
.const_null()
|
.const_null()
|
||||||
.into()
|
.into()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
|
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
|
||||||
symbol_resolver::{StaticValue, SymbolResolver},
|
symbol_resolver::{StaticValue, SymbolResolver},
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
|
toplevel::{helper::PrimDef, prim_types::unpack_ndarray_params, TopLevelContext, TopLevelDef},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
type_inferencer::{CodeLocation, PrimitiveStore},
|
type_inferencer::{CodeLocation, PrimitiveStore},
|
||||||
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
||||||
|
@ -423,6 +423,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
module: &Module<'ctx>,
|
module: &Module<'ctx>,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
top_level: &TopLevelContext,
|
top_level: &TopLevelContext,
|
||||||
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||||
ty: Type,
|
ty: Type,
|
||||||
|
@ -443,18 +444,20 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
module,
|
module,
|
||||||
generator,
|
generator,
|
||||||
unifier,
|
unifier,
|
||||||
|
store,
|
||||||
top_level,
|
top_level,
|
||||||
type_cache,
|
type_cache,
|
||||||
*params.iter().next().unwrap().1,
|
*params.get(&store.option_type_tvar.id).unwrap(),
|
||||||
)
|
)
|
||||||
.ptr_type(AddressSpace::default())
|
.ptr_type(AddressSpace::default())
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
|
let dtype = unpack_ndarray_params(unifier, store, ty).dtype;
|
||||||
let element_type = get_llvm_type(
|
let element_type = get_llvm_type(
|
||||||
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
ctx, module, generator, unifier, store, top_level, type_cache,
|
||||||
|
dtype,
|
||||||
);
|
);
|
||||||
|
|
||||||
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
|
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
|
||||||
|
@ -490,6 +493,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
module,
|
module,
|
||||||
generator,
|
generator,
|
||||||
unifier,
|
unifier,
|
||||||
|
store,
|
||||||
top_level,
|
top_level,
|
||||||
type_cache,
|
type_cache,
|
||||||
fields[&f.0].0,
|
fields[&f.0].0,
|
||||||
|
@ -506,14 +510,17 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let fields = ty
|
let fields = ty
|
||||||
.iter()
|
.iter()
|
||||||
.map(|ty| {
|
.map(|ty| {
|
||||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
|
get_llvm_type(
|
||||||
|
ctx, module, generator, unifier, store, top_level, type_cache, *ty,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
ctx.struct_type(&fields, false).into()
|
ctx.struct_type(&fields, false).into()
|
||||||
}
|
}
|
||||||
TList { ty } => {
|
TList { ty } => {
|
||||||
let element_type =
|
let element_type = get_llvm_type(
|
||||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty);
|
ctx, module, generator, unifier, store, top_level, type_cache, *ty,
|
||||||
|
);
|
||||||
|
|
||||||
ListType::new(generator, ctx, element_type).as_base_type().into()
|
ListType::new(generator, ctx, element_type).as_base_type().into()
|
||||||
}
|
}
|
||||||
|
@ -540,6 +547,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
module: &Module<'ctx>,
|
module: &Module<'ctx>,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
top_level: &TopLevelContext,
|
top_level: &TopLevelContext,
|
||||||
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||||
primitives: &PrimitiveStore,
|
primitives: &PrimitiveStore,
|
||||||
|
@ -550,7 +558,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
return if unifier.unioned(ty, primitives.bool) {
|
return if unifier.unioned(ty, primitives.bool) {
|
||||||
ctx.bool_type().into()
|
ctx.bool_type().into()
|
||||||
} else {
|
} else {
|
||||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
|
get_llvm_type(ctx, module, generator, unifier, store, top_level, type_cache, ty)
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -699,6 +707,7 @@ pub fn gen_func_impl<
|
||||||
&module,
|
&module,
|
||||||
generator,
|
generator,
|
||||||
&mut unifier,
|
&mut unifier,
|
||||||
|
&primitives,
|
||||||
top_level_ctx.as_ref(),
|
top_level_ctx.as_ref(),
|
||||||
&mut type_cache,
|
&mut type_cache,
|
||||||
&primitives,
|
&primitives,
|
||||||
|
@ -715,6 +724,7 @@ pub fn gen_func_impl<
|
||||||
&module,
|
&module,
|
||||||
generator,
|
generator,
|
||||||
&mut unifier,
|
&mut unifier,
|
||||||
|
&primitives,
|
||||||
top_level_ctx.as_ref(),
|
top_level_ctx.as_ref(),
|
||||||
&mut type_cache,
|
&mut type_cache,
|
||||||
&primitives,
|
&primitives,
|
||||||
|
@ -767,6 +777,7 @@ pub fn gen_func_impl<
|
||||||
&module,
|
&module,
|
||||||
generator,
|
generator,
|
||||||
&mut unifier,
|
&mut unifier,
|
||||||
|
&primitives,
|
||||||
top_level_ctx.as_ref(),
|
top_level_ctx.as_ref(),
|
||||||
&mut type_cache,
|
&mut type_cache,
|
||||||
arg.ty,
|
arg.ty,
|
||||||
|
|
|
@ -19,7 +19,7 @@ use crate::{
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PrimDef,
|
helper::PrimDef,
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
prim_types::{make_ndarray_ty, unpack_ndarray_params},
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
},
|
},
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||||
|
@ -1776,7 +1776,7 @@ pub fn gen_ndarray_array<'ctx>(
|
||||||
let obj_ty = fun.0.args[0].ty;
|
let obj_ty = fun.0.args[0].ty;
|
||||||
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
|
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
|
unpack_ndarray_params(&context.unifier, &context.primitives, obj_ty).dtype
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeEnum::TList { ty } => {
|
TypeEnum::TList { ty } => {
|
||||||
|
@ -1916,7 +1916,7 @@ pub fn gen_ndarray_copy<'ctx>(
|
||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
let llvm_usize = generator.get_size_type(context.ctx);
|
||||||
|
|
||||||
let this_ty = obj.as_ref().unwrap().0;
|
let this_ty = obj.as_ref().unwrap().0;
|
||||||
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
|
let this_elem_ty = unpack_ndarray_params(&context.unifier, &context.primitives, this_ty).dtype;
|
||||||
let this_arg =
|
let this_arg =
|
||||||
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ use crate::{
|
||||||
expr::gen_binop_expr,
|
expr::gen_binop_expr,
|
||||||
gen_in_range_check,
|
gen_in_range_check,
|
||||||
},
|
},
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
toplevel::{helper::PrimDef, prim_types::unpack_ndarray_params, DefinitionId, TopLevelDef},
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
|
@ -245,7 +245,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||||
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
||||||
TypeEnum::TList { ty } => *ty,
|
TypeEnum::TList { ty } => *ty,
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
|
unpack_ndarray_params(&ctx.unifier, &ctx.primitives, target.custom.unwrap())
|
||||||
|
.dtype
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
|
@ -24,8 +24,8 @@ use crate::{
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
|
toplevel::{helper::PrimDef, prim_types::make_ndarray_ty},
|
||||||
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
typecheck::typedef::{into_var_map, TypeVar, VarMap},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -301,10 +301,7 @@ struct BuiltinBuilder<'a> {
|
||||||
|
|
||||||
is_some_ty: (Type, bool),
|
is_some_ty: (Type, bool),
|
||||||
unwrap_ty: (Type, bool),
|
unwrap_ty: (Type, bool),
|
||||||
option_tvar: TypeVar,
|
|
||||||
|
|
||||||
ndarray_dtype_tvar: TypeVar,
|
|
||||||
ndarray_ndims_tvar: TypeVar,
|
|
||||||
ndarray_copy_ty: (Type, bool),
|
ndarray_copy_ty: (Type, bool),
|
||||||
ndarray_fill_ty: (Type, bool),
|
ndarray_fill_ty: (Type, bool),
|
||||||
|
|
||||||
|
@ -339,24 +336,19 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
} = *primitives;
|
} = *primitives;
|
||||||
|
|
||||||
// Option-related
|
// Option-related
|
||||||
let (is_some_ty, unwrap_ty, option_tvar) =
|
let (is_some_ty, unwrap_ty) =
|
||||||
if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() {
|
if let TypeEnum::TObj { fields, .. } = unifier.get_ty(option).as_ref() {
|
||||||
(
|
(
|
||||||
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
|
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
|
||||||
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
|
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
|
||||||
iter_type_vars(params).next().unwrap(),
|
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
|
||||||
let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } =
|
let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray) else {
|
||||||
&*unifier.get_ty(ndarray)
|
|
||||||
else {
|
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
|
|
||||||
let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
|
|
||||||
let ndarray_copy_ty =
|
let ndarray_copy_ty =
|
||||||
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
|
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
|
||||||
let ndarray_fill_ty =
|
let ndarray_fill_ty =
|
||||||
|
@ -398,10 +390,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|
|
||||||
is_some_ty,
|
is_some_ty,
|
||||||
unwrap_ty,
|
unwrap_ty,
|
||||||
option_tvar,
|
|
||||||
|
|
||||||
ndarray_dtype_tvar,
|
|
||||||
ndarray_ndims_tvar,
|
|
||||||
ndarray_copy_ty,
|
ndarray_copy_ty,
|
||||||
ndarray_fill_ty,
|
ndarray_fill_ty,
|
||||||
|
|
||||||
|
@ -622,7 +611,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
PrimDef::Option => TopLevelDef::Class {
|
PrimDef::Option => TopLevelDef::Class {
|
||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
object_id: prim.id(),
|
object_id: prim.id(),
|
||||||
type_vars: vec![self.option_tvar.ty],
|
type_vars: vec![self.primitives.option_type_tvar.ty],
|
||||||
fields: vec![],
|
fields: vec![],
|
||||||
methods: vec![
|
methods: vec![
|
||||||
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
|
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
|
||||||
|
@ -642,7 +631,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
simple_name: prim.simple_name().into(),
|
simple_name: prim.simple_name().into(),
|
||||||
signature: self.unwrap_ty.0,
|
signature: self.unwrap_ty.0,
|
||||||
var_id: vec![self.option_tvar.id],
|
var_id: vec![self.primitives.option_type_tvar.id],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
@ -656,7 +645,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
name: prim.name().to_string(),
|
name: prim.name().to_string(),
|
||||||
simple_name: prim.simple_name().into(),
|
simple_name: prim.simple_name().into(),
|
||||||
signature: self.is_some_ty.0,
|
signature: self.is_some_ty.0,
|
||||||
var_id: vec![self.option_tvar.id],
|
var_id: vec![self.primitives.option_type_tvar.id],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
@ -693,13 +682,13 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![FuncArg {
|
args: vec![FuncArg {
|
||||||
name: "n".into(),
|
name: "n".into(),
|
||||||
ty: self.option_tvar.ty,
|
ty: self.primitives.option_type_tvar.ty,
|
||||||
default_value: None,
|
default_value: None,
|
||||||
}],
|
}],
|
||||||
ret: self.primitives.option,
|
ret: self.primitives.option,
|
||||||
vars: into_var_map([self.option_tvar]),
|
vars: into_var_map([self.primitives.option_type_tvar]),
|
||||||
})),
|
})),
|
||||||
var_id: vec![self.option_tvar.id],
|
var_id: vec![self.primitives.option_type_tvar.id],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
@ -735,7 +724,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
PrimDef::NDArray => TopLevelDef::Class {
|
PrimDef::NDArray => TopLevelDef::Class {
|
||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
object_id: prim.id(),
|
object_id: prim.id(),
|
||||||
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
|
type_vars: vec![
|
||||||
|
self.primitives.ndarray_dtype_tvar.ty,
|
||||||
|
self.primitives.ndarray_ndims_tvar.ty,
|
||||||
|
],
|
||||||
fields: Vec::default(),
|
fields: Vec::default(),
|
||||||
methods: vec![
|
methods: vec![
|
||||||
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
|
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
|
||||||
|
@ -751,7 +743,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
simple_name: prim.simple_name().into(),
|
simple_name: prim.simple_name().into(),
|
||||||
signature: self.ndarray_copy_ty.0,
|
signature: self.ndarray_copy_ty.0,
|
||||||
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
|
var_id: vec![
|
||||||
|
self.primitives.ndarray_dtype_tvar.id,
|
||||||
|
self.primitives.ndarray_ndims_tvar.id,
|
||||||
|
],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
@ -768,7 +763,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
simple_name: prim.simple_name().into(),
|
simple_name: prim.simple_name().into(),
|
||||||
signature: self.ndarray_fill_ty.0,
|
signature: self.ndarray_fill_ty.0,
|
||||||
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
|
var_id: vec![
|
||||||
|
self.primitives.ndarray_dtype_tvar.id,
|
||||||
|
self.primitives.ndarray_ndims_tvar.id,
|
||||||
|
],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
use std::convert::TryInto;
|
use std::convert::TryInto;
|
||||||
|
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
use crate::toplevel::prim_types::unpack_ndarray_params;
|
||||||
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap};
|
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVar, TypeVarId, VarMap};
|
||||||
use nac3parser::ast::{Constant, Location};
|
use nac3parser::ast::{Constant, Location};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use strum_macros::EnumIter;
|
use strum_macros::EnumIter;
|
||||||
|
@ -286,6 +286,18 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn make_option_type_tvar(unifier: &mut Unifier) -> TypeVar {
|
||||||
|
unifier.get_fresh_var(Some("option_type_var".into()), None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn make_ndarray_dtype_tvar(unifier: &mut Unifier) -> TypeVar {
|
||||||
|
unifier.get_fresh_var(Some("ndarray_dtype".into()), None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn make_ndarray_ndims_tvar(unifier: &mut Unifier, size_ty: Type) -> TypeVar {
|
||||||
|
unifier.get_fresh_const_generic_var(size_ty, Some("ndarray_ndims".into()), None)
|
||||||
|
}
|
||||||
|
|
||||||
impl TopLevelDef {
|
impl TopLevelDef {
|
||||||
pub fn to_string(&self, unifier: &mut Unifier) -> String {
|
pub fn to_string(&self, unifier: &mut Unifier) -> String {
|
||||||
match self {
|
match self {
|
||||||
|
@ -381,16 +393,16 @@ impl TopLevelComposer {
|
||||||
params: VarMap::new(),
|
params: VarMap::new(),
|
||||||
});
|
});
|
||||||
|
|
||||||
let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None);
|
let option_type_tvar = make_option_type_tvar(&mut unifier);
|
||||||
let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![],
|
args: vec![],
|
||||||
ret: bool,
|
ret: bool,
|
||||||
vars: into_var_map([option_type_var]),
|
vars: into_var_map([option_type_tvar]),
|
||||||
}));
|
}));
|
||||||
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![],
|
args: vec![],
|
||||||
ret: option_type_var.ty,
|
ret: option_type_tvar.ty,
|
||||||
vars: into_var_map([option_type_var]),
|
vars: into_var_map([option_type_tvar]),
|
||||||
}));
|
}));
|
||||||
let option = unifier.add_ty(TypeEnum::TObj {
|
let option = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: PrimDef::Option.id(),
|
obj_id: PrimDef::Option.id(),
|
||||||
|
@ -401,7 +413,7 @@ impl TopLevelComposer {
|
||||||
]
|
]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect::<HashMap<_, _>>(),
|
.collect::<HashMap<_, _>>(),
|
||||||
params: into_var_map([option_type_var]),
|
params: into_var_map([option_type_tvar]),
|
||||||
});
|
});
|
||||||
|
|
||||||
let size_t_ty = match size_t {
|
let size_t_ty = match size_t {
|
||||||
|
@ -410,9 +422,8 @@ impl TopLevelComposer {
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier);
|
||||||
let ndarray_ndims_tvar =
|
let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, size_t_ty);
|
||||||
unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
|
||||||
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
|
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
|
||||||
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![],
|
args: vec![],
|
||||||
|
@ -451,7 +462,10 @@ impl TopLevelComposer {
|
||||||
str,
|
str,
|
||||||
exception,
|
exception,
|
||||||
option,
|
option,
|
||||||
|
option_type_tvar,
|
||||||
ndarray,
|
ndarray,
|
||||||
|
ndarray_dtype_tvar,
|
||||||
|
ndarray_ndims_tvar,
|
||||||
size_t,
|
size_t,
|
||||||
};
|
};
|
||||||
unifier.put_primitive_store(&primitives);
|
unifier.put_primitive_store(&primitives);
|
||||||
|
@ -881,22 +895,26 @@ pub fn parse_parameter_default_value(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Obtains the element type of an array-like type.
|
/// Obtains the element type of an array-like type.
|
||||||
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
|
pub fn arraylike_flatten_element_type(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
|
ty: Type,
|
||||||
|
) -> Type {
|
||||||
match &*unifier.get_ty(ty) {
|
match &*unifier.get_ty(ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
unpack_ndarray_var_tys(unifier, ty).0
|
unpack_ndarray_params(unifier, store, ty).dtype
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
|
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, store, *ty),
|
||||||
_ => ty,
|
_ => ty,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Obtains the number of dimensions of an array-like type.
|
/// Obtains the number of dimensions of an array-like type.
|
||||||
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
|
pub fn arraylike_get_ndims(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) -> u64 {
|
||||||
match &*unifier.get_ty(ty) {
|
match &*unifier.get_ty(ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
|
let ndims = unpack_ndarray_params(unifier, store, ty).ndims;
|
||||||
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
|
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
|
||||||
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
|
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
|
||||||
};
|
};
|
||||||
|
@ -908,7 +926,7 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
|
||||||
u64::try_from(values[0].clone()).unwrap()
|
u64::try_from(values[0].clone()).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1,
|
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, store, *ty) + 1,
|
||||||
_ => 0,
|
_ => 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ pub struct DefinitionId(pub usize);
|
||||||
pub mod builtins;
|
pub mod builtins;
|
||||||
pub mod composer;
|
pub mod composer;
|
||||||
pub mod helper;
|
pub mod helper;
|
||||||
pub mod numpy;
|
pub mod prim_types;
|
||||||
pub mod type_annotation;
|
pub mod type_annotation;
|
||||||
use composer::*;
|
use composer::*;
|
||||||
use type_annotation::*;
|
use type_annotation::*;
|
||||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
||||||
toplevel::helper::PrimDef,
|
toplevel::helper::PrimDef,
|
||||||
typecheck::{
|
typecheck::{
|
||||||
type_inferencer::PrimitiveStore,
|
type_inferencer::PrimitiveStore,
|
||||||
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
typedef::{Type, TypeEnum, Unifier, VarMap},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
@ -57,29 +57,25 @@ pub fn subst_ndarray_tvars(
|
||||||
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub struct NDArrayParams {
|
||||||
|
pub dtype: Type,
|
||||||
|
pub ndims: Type,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the [`Type`]s of `ndarray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn unpack_ndarray_params(
|
||||||
|
unifier: &Unifier,
|
||||||
|
store: &PrimitiveStore,
|
||||||
|
ndarray: Type,
|
||||||
|
) -> NDArrayParams {
|
||||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||||
};
|
};
|
||||||
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
||||||
debug_assert_eq!(params.len(), 2);
|
debug_assert_eq!(params.len(), 2);
|
||||||
|
let dtype = *params.get(&store.ndarray_dtype_tvar.id).unwrap();
|
||||||
params
|
let ndims = *params.get(&store.ndarray_ndims_tvar.id).unwrap();
|
||||||
.iter()
|
NDArrayParams { dtype, ndims }
|
||||||
.sorted_by_key(|(obj_id, _)| *obj_id)
|
|
||||||
.map(|(var_id, ty)| (*var_id, *ty))
|
|
||||||
.collect_vec()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
|
|
||||||
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
|
|
||||||
/// respectively.
|
|
||||||
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
|
|
||||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
|
||||||
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
|
||||||
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
|
|
||||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
|
|
||||||
}
|
}
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::helper::PrimDef;
|
use crate::toplevel::helper::PrimDef;
|
||||||
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
use crate::toplevel::prim_types::{make_ndarray_ty, unpack_ndarray_params};
|
||||||
use crate::typecheck::{
|
use crate::typecheck::{
|
||||||
type_inferencer::*,
|
type_inferencer::*,
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||||
|
@ -369,16 +369,16 @@ pub fn typeof_ndarray_broadcast(
|
||||||
if is_left_ndarray && is_right_ndarray {
|
if is_left_ndarray && is_right_ndarray {
|
||||||
// Perform broadcasting on two ndarray operands.
|
// Perform broadcasting on two ndarray operands.
|
||||||
|
|
||||||
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
|
let left_params = unpack_ndarray_params(unifier, primitives, left);
|
||||||
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
|
let right_params = unpack_ndarray_params(unifier, primitives, right);
|
||||||
|
|
||||||
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
|
assert!(unifier.unioned(left_params.dtype, right_params.dtype));
|
||||||
|
|
||||||
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
|
let left_ty_ndims = match &*unifier.get_ty_immutable(left_params.ndims) {
|
||||||
TypeEnum::TLiteral { values, .. } => values.clone(),
|
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
|
let right_ty_ndims = match &*unifier.get_ty_immutable(right_params.ndims) {
|
||||||
TypeEnum::TLiteral { values, .. } => values.clone(),
|
TypeEnum::TLiteral { values, .. } => values.clone(),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
@ -397,11 +397,11 @@ pub fn typeof_ndarray_broadcast(
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
||||||
|
|
||||||
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
|
Ok(make_ndarray_ty(unifier, primitives, Some(left_params.dtype), Some(res_ndims)))
|
||||||
} else {
|
} else {
|
||||||
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
|
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
|
||||||
|
|
||||||
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
|
let ndarray_ty_dtype = unpack_ndarray_params(unifier, primitives, ndarray_ty).dtype;
|
||||||
|
|
||||||
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
||||||
Ok(ndarray_ty)
|
Ok(ndarray_ty)
|
||||||
|
@ -444,7 +444,7 @@ pub fn typeof_binop(
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator::MatMult => {
|
Operator::MatMult => {
|
||||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
let lhs_ndims = unpack_ndarray_params(unifier, primitives, lhs).ndims;
|
||||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
TypeEnum::TLiteral { values, .. } => {
|
||||||
assert_eq!(values.len(), 1);
|
assert_eq!(values.len(), 1);
|
||||||
|
@ -452,7 +452,7 @@ pub fn typeof_binop(
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
let rhs_ndims = unpack_ndarray_params(unifier, primitives, rhs).ndims;
|
||||||
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
TypeEnum::TLiteral { values, .. } => {
|
||||||
assert_eq!(values.len(), 1);
|
assert_eq!(values.len(), 1);
|
||||||
|
@ -552,7 +552,7 @@ pub fn typeof_unaryop(
|
||||||
|
|
||||||
Unaryop::UAdd | Unaryop::USub => {
|
Unaryop::UAdd | Unaryop::USub => {
|
||||||
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
|
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
|
let dtype = unpack_ndarray_params(unifier, primitives, operand).dtype;
|
||||||
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||||
return Err(if op == Unaryop::UAdd {
|
return Err(if op == Unaryop::UAdd {
|
||||||
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
|
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
|
||||||
|
@ -586,7 +586,7 @@ pub fn typeof_cmpop(
|
||||||
|
|
||||||
Ok(Some(if is_left_ndarray || is_right_ndarray {
|
Ok(Some(if is_left_ndarray || is_right_ndarray {
|
||||||
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
|
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
|
let ndims = unpack_ndarray_params(unifier, primitives, brd).ndims;
|
||||||
|
|
||||||
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
|
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
|
||||||
} else if unifier.unioned(lhs, rhs) {
|
} else if unifier.unioned(lhs, rhs) {
|
||||||
|
@ -653,8 +653,8 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||||
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||||
let ndarray_unsized_t =
|
let ndarray_unsized_t =
|
||||||
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
|
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
|
||||||
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
|
let ndarray_dtype_t = unpack_ndarray_params(unifier, store, ndarray_t).dtype;
|
||||||
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
|
let ndarray_unsized_dtype_t = unpack_ndarray_params(unifier, store, ndarray_unsized_t).dtype;
|
||||||
impl_basic_arithmetic(
|
impl_basic_arithmetic(
|
||||||
unifier,
|
unifier,
|
||||||
store,
|
store,
|
||||||
|
|
|
@ -4,13 +4,15 @@ use std::iter::once;
|
||||||
use std::ops::Not;
|
use std::ops::Not;
|
||||||
use std::{cell::RefCell, sync::Arc};
|
use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
|
use super::typedef::{
|
||||||
|
Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, TypeVar, Unifier, VarMap,
|
||||||
|
};
|
||||||
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
|
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
prim_types::{make_ndarray_ty, unpack_ndarray_params},
|
||||||
TopLevelContext,
|
TopLevelContext,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
@ -49,7 +51,11 @@ pub struct PrimitiveStore {
|
||||||
pub str: Type,
|
pub str: Type,
|
||||||
pub exception: Type,
|
pub exception: Type,
|
||||||
pub option: Type,
|
pub option: Type,
|
||||||
|
/// The contained type of an `Option`
|
||||||
|
pub option_type_tvar: TypeVar,
|
||||||
pub ndarray: Type,
|
pub ndarray: Type,
|
||||||
|
pub ndarray_dtype_tvar: TypeVar,
|
||||||
|
pub ndarray_ndims_tvar: TypeVar,
|
||||||
pub size_t: u32,
|
pub size_t: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -896,7 +902,8 @@ impl<'a> Inferencer<'a> {
|
||||||
|
|
||||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
let ndarray_ndims =
|
||||||
|
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).ndims;
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
||||||
} else {
|
} else {
|
||||||
|
@ -934,9 +941,7 @@ impl<'a> Inferencer<'a> {
|
||||||
|
|
||||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).dtype
|
||||||
|
|
||||||
ndarray_dtype
|
|
||||||
} else {
|
} else {
|
||||||
arg0_ty
|
arg0_ty
|
||||||
};
|
};
|
||||||
|
@ -988,14 +993,14 @@ impl<'a> Inferencer<'a> {
|
||||||
|
|
||||||
let arg0_dtype =
|
let arg0_dtype =
|
||||||
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
unpack_ndarray_var_tys(self.unifier, arg0_ty).0
|
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
arg0_ty
|
arg0_ty
|
||||||
};
|
};
|
||||||
|
|
||||||
let arg1_dtype =
|
let arg1_dtype =
|
||||||
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
unpack_ndarray_var_tys(self.unifier, arg1_ty).0
|
unpack_ndarray_params(self.unifier, self.primitives, arg1_ty).dtype
|
||||||
} else {
|
} else {
|
||||||
arg1_ty
|
arg1_ty
|
||||||
};
|
};
|
||||||
|
@ -1026,7 +1031,8 @@ impl<'a> Inferencer<'a> {
|
||||||
// (float, int32), so convert it to align with the dtype of the first arg
|
// (float, int32), so convert it to align with the dtype of the first arg
|
||||||
let arg1_ty = if id == &"np_ldexp".into() {
|
let arg1_ty = if id == &"np_ldexp".into() {
|
||||||
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
|
let ndims =
|
||||||
|
unpack_ndarray_params(self.unifier, self.primitives, arg1_ty).ndims;
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
|
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
|
||||||
} else {
|
} else {
|
||||||
|
@ -1115,7 +1121,8 @@ impl<'a> Inferencer<'a> {
|
||||||
|
|
||||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
let ndarray_ndims =
|
||||||
|
unpack_ndarray_params(self.unifier, self.primitives, arg0_ty).ndims;
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
||||||
} else {
|
} else {
|
||||||
|
@ -1258,7 +1265,8 @@ impl<'a> Inferencer<'a> {
|
||||||
let ndmin_kw =
|
let ndmin_kw =
|
||||||
keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into()));
|
keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into()));
|
||||||
|
|
||||||
let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap());
|
let ty =
|
||||||
|
arraylike_flatten_element_type(self.unifier, self.primitives, arg0.custom.unwrap());
|
||||||
let ndims = if let Some(ndmin_kw) = ndmin_kw {
|
let ndims = if let Some(ndmin_kw) = ndmin_kw {
|
||||||
match &ndmin_kw.node.value.node {
|
match &ndmin_kw.node.value.node {
|
||||||
ExprKind::Constant { value, .. } => match value {
|
ExprKind::Constant { value, .. } => match value {
|
||||||
|
@ -1266,10 +1274,10 @@ impl<'a> Inferencer<'a> {
|
||||||
_ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])),
|
_ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])),
|
||||||
},
|
},
|
||||||
|
|
||||||
_ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()),
|
_ => arraylike_get_ndims(self.unifier, self.primitives, arg0.custom.unwrap()),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
|
arraylike_get_ndims(self.unifier, self.primitives, arg0.custom.unwrap())
|
||||||
};
|
};
|
||||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||||
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
||||||
|
@ -1666,8 +1674,12 @@ impl<'a> Inferencer<'a> {
|
||||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (_, ndims) =
|
let ndims = unpack_ndarray_params(
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
value.custom.unwrap(),
|
||||||
|
)
|
||||||
|
.ndims;
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
|
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
|
||||||
}
|
}
|
||||||
|
@ -1680,8 +1692,13 @@ impl<'a> Inferencer<'a> {
|
||||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||||
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (_, ndims) =
|
let ndims = unpack_ndarray_params(
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
value.custom.unwrap(),
|
||||||
|
)
|
||||||
|
.ndims;
|
||||||
|
|
||||||
self.infer_subscript_ndarray(value, ty, ndims)
|
self.infer_subscript_ndarray(value, ty, ndims)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
@ -1724,7 +1741,9 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
let ndims =
|
||||||
|
unpack_ndarray_params(self.unifier, self.primitives, value.custom.unwrap())
|
||||||
|
.ndims;
|
||||||
let ndarray_ty =
|
let ndarray_ty =
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
||||||
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
|
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
|
||||||
|
@ -1751,8 +1770,12 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(ty)
|
Ok(ty)
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (_, ndims) =
|
let ndims = unpack_ndarray_params(
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
value.custom.unwrap(),
|
||||||
|
)
|
||||||
|
.ndims;
|
||||||
|
|
||||||
let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
|
let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
use super::super::{magic_methods::with_fields, typedef::*};
|
use super::super::{magic_methods::with_fields, typedef::*};
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::toplevel::helper::{
|
||||||
|
make_ndarray_dtype_tvar, make_ndarray_ndims_tvar, make_option_type_tvar,
|
||||||
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::CodeGenContext,
|
codegen::CodeGenContext,
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
|
@ -132,14 +135,14 @@ impl TestEnvironment {
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: VarMap::new(),
|
params: VarMap::new(),
|
||||||
});
|
});
|
||||||
|
let option_type_tvar = make_option_type_tvar(&mut unifier);
|
||||||
let option = unifier.add_ty(TypeEnum::TObj {
|
let option = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: PrimDef::Option.id(),
|
obj_id: PrimDef::Option.id(),
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: VarMap::new(),
|
params: VarMap::new(),
|
||||||
});
|
});
|
||||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier);
|
||||||
let ndarray_ndims_tvar =
|
let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, uint64);
|
||||||
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
|
|
||||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: PrimDef::NDArray.id(),
|
obj_id: PrimDef::NDArray.id(),
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
|
@ -157,7 +160,10 @@ impl TestEnvironment {
|
||||||
uint32,
|
uint32,
|
||||||
uint64,
|
uint64,
|
||||||
option,
|
option,
|
||||||
|
option_type_tvar,
|
||||||
ndarray,
|
ndarray,
|
||||||
|
ndarray_dtype_tvar,
|
||||||
|
ndarray_ndims_tvar,
|
||||||
size_t: 64,
|
size_t: 64,
|
||||||
};
|
};
|
||||||
unifier.put_primitive_store(&primitives);
|
unifier.put_primitive_store(&primitives);
|
||||||
|
@ -268,16 +274,22 @@ impl TestEnvironment {
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: VarMap::new(),
|
params: VarMap::new(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let option_type_tvar = make_option_type_tvar(&mut unifier);
|
||||||
let option = unifier.add_ty(TypeEnum::TObj {
|
let option = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: PrimDef::Option.id(),
|
obj_id: PrimDef::Option.id(),
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: VarMap::new(),
|
params: VarMap::new(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let ndarray_dtype_tvar = make_ndarray_dtype_tvar(&mut unifier);
|
||||||
|
let ndarray_ndims_tvar = make_ndarray_ndims_tvar(&mut unifier, uint64);
|
||||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: PrimDef::NDArray.id(),
|
obj_id: PrimDef::NDArray.id(),
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: VarMap::new(),
|
params: VarMap::new(),
|
||||||
});
|
});
|
||||||
|
|
||||||
identifier_mapping.insert("None".into(), none);
|
identifier_mapping.insert("None".into(), none);
|
||||||
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
|
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -312,7 +324,10 @@ impl TestEnvironment {
|
||||||
uint32,
|
uint32,
|
||||||
uint64,
|
uint64,
|
||||||
option,
|
option,
|
||||||
|
option_type_tvar,
|
||||||
ndarray,
|
ndarray,
|
||||||
|
ndarray_dtype_tvar,
|
||||||
|
ndarray_ndims_tvar,
|
||||||
size_t: 64,
|
size_t: 64,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue