forked from M-Labs/nac3
1
0
Fork 0

core/typedef: WIP - Add OptionType and NDArrayType

This commit is contained in:
David Mak 2024-06-27 16:33:27 +08:00
parent da4dec08a5
commit 6892a4848e
18 changed files with 717 additions and 407 deletions

View File

@ -7,7 +7,7 @@ use nac3core::{
},
symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, DefinitionId, GenCall},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap},
typecheck::typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, VarMap},
};
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
@ -23,7 +23,7 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use nac3core::toplevel::numpy::unpack_ndarray_var_tys;
use nac3core::toplevel::primitive_type;
use std::{
collections::hash_map::DefaultHasher,
collections::HashMap,
@ -399,7 +399,9 @@ fn gen_rpc_tag(
gen_rpc_tag(ctx, *ty, buffer)?;
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let ndarray_ty = primitive_type::NDArrayType::create(ty, &mut ctx.unifier);
let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
let ndarray_ndims = if let TLiteral { values, .. } =
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
{
@ -645,7 +647,7 @@ pub fn attributes_writeback(
let ty = ty.unwrap();
match &*ctx.unifier.get_ty(ty) {
TypeEnum::TObj { fields, obj_id, .. }
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier) =>
{
// we only care about primitive attributes
// for non-primitive attributes, they should be in another global

View File

@ -11,11 +11,7 @@ use nac3core::{
CodeGenContext, CodeGenerator,
},
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
toplevel::{helper::PrimDef, primitive_type, DefinitionId, TopLevelDef},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVar, Unifier, VarMap},
@ -337,13 +333,18 @@ impl InnerResolver {
// do not handle type var param and concrete check here
let var = unifier.get_dummy_var().ty;
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty;
let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims));
Ok(Ok((ndarray, false)))
let ndarray = primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(var),
Some(ndims),
);
Ok(Ok((ndarray.into(), false)))
} else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false)))
Ok(Ok((primitives.option.into(), false)))
} else if ty_id == self.primitive_ids.none {
unreachable!("none cannot be typeid")
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() {
@ -510,7 +511,16 @@ impl InnerResolver {
));
}
Ok(Ok((make_ndarray_ty(unifier, primitives, Some(ty.0), None), true)))
Ok(Ok((
primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(ty.0),
None,
)
.into(),
true,
)))
}
TypeEnum::TTuple { .. } => {
let args = match args
@ -719,7 +729,9 @@ impl InnerResolver {
}
}
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
let ndarray = primitive_type::NDArrayType::create(extracted_ty, unifier);
let ty = ndarray.dtype_tvar(unifier).ty;
let ndims = ndarray.ndims_tvar(unifier).ty;
let len: usize = obj.getattr("ndim")?.extract()?;
if len == 0 {
assert!(matches!(
@ -734,10 +746,14 @@ impl InnerResolver {
match dtype_ty {
Ok((t, _)) => match unifier.unify(ty, t) {
Ok(()) => {
let ndarray_ty =
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
let ndarray_ty = primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(ty),
Some(ndims),
);
Ok(Ok(ndarray_ty))
Ok(Ok(ndarray_ty.into()))
}
Err(e) => Ok(Err(format!(
"type error ({}) for the ndarray",
@ -760,7 +776,7 @@ impl InnerResolver {
// special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below
(TypeEnum::TObj { obj_id, params, .. }, false)
if *obj_id == primitives.option.obj_id(unifier).unwrap() =>
if *obj_id == primitives.option.obj_id(unifier) =>
{
let Ok(field_data) = obj.getattr("_nac3_option") else {
unreachable!("cannot be None")
@ -785,7 +801,7 @@ impl InnerResolver {
.map(TypeVar::into)
.collect::<VarMap>()
});
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()));
return Ok(Ok(unifier.subst(primitives.option.into(), &var_map).unwrap()));
}
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
@ -1038,8 +1054,9 @@ impl InnerResolver {
} else {
unreachable!("must be ndarray")
};
let (ndarray_dtype, ndarray_ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let ndarray_ty = primitive_type::NDArrayType::create(ndarray_ty, &mut ctx.unifier);
let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
@ -1186,7 +1203,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.option {
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier) =>
{
*params.iter().next().unwrap().1
}

View File

@ -8,8 +8,8 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::Type;
use crate::toplevel::primitive_type;
use crate::typecheck::typedef::{GenericObjectType, Type};
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
///
@ -66,7 +66,9 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -128,7 +130,9 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -206,7 +210,9 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -273,7 +279,9 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -328,7 +336,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -374,7 +384,9 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -414,7 +426,9 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -475,7 +489,9 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -529,7 +545,9 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -579,7 +597,9 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
@ -660,7 +680,9 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -751,16 +773,24 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
unreachable!()
};
@ -850,7 +880,9 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -941,16 +973,24 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
unreachable!()
};
@ -1008,7 +1048,9 @@ where
if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let arg_elem_ty = primitive_type::NDArrayType::create(arg_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
@ -1370,16 +1412,24 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
unreachable!()
};
@ -1437,16 +1487,24 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
unreachable!()
};
@ -1504,16 +1562,24 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
unreachable!()
};
@ -1571,16 +1637,24 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
unreachable!()
};
@ -1637,12 +1711,22 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
let is_ndarray2 =
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype =
if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty };
let dtype = if is_ndarray1 {
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
x1_ty
};
let x1_scalar_ty = dtype;
let x2_scalar_ty =
if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty };
let x2_scalar_ty = if is_ndarray2 {
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
x2_ty
};
numpy::ndarray_elementwise_binop_impl(
generator,
@ -1694,16 +1778,24 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
unreachable!()
};
@ -1761,16 +1853,24 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1
} else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
unreachable!()
};

View File

@ -1,5 +1,9 @@
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
use crate::toplevel::primitive_type;
use crate::toplevel::primitive_type::OptionType;
use crate::typecheck::typedef::GenericObjectType;
use crate::{
codegen::{
classes::{
@ -15,11 +19,7 @@ use crate::{
CodeGenContext, CodeGenTask,
},
symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::{
magic_methods::{binop_assign_name, binop_name, unaryop_name},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
@ -36,8 +36,6 @@ use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
};
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
pub fn get_subst_key(
unifier: &mut Unifier,
obj: Option<Type>,
@ -162,14 +160,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
self.builder.build_load(ptr, "tup_val").unwrap()
}
SymbolValue::OptionSome(v) => {
let ty = match self.unifier.get_ty_immutable(ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type"),
};
let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
let val = self.gen_symbol_val(generator, v, ty);
let ptr = generator
.gen_var_alloc(self, val.get_type(), Some("default_opt_some"))
@ -178,14 +169,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ptr.into()
}
SymbolValue::OptionNone => {
let ty = match self.unifier.get_ty_immutable(ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type"),
};
let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
let actual_ptr_type =
self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default());
actual_ptr_type.const_null().into()
@ -1206,8 +1190,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
let ndarray_dtype1 = primitive_type::NDArrayType::create(ty1, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(ty2, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1256,8 +1244,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into()))
} else {
let (ndarray_dtype, _) =
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
let ndarray_dtype = primitive_type::NDArrayType::create(
if is_ndarray1 { ty1 } else { ty2 },
&mut ctx.unifier,
)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_val = NDArrayValue::from_ptr_val(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_usize,
@ -1443,7 +1435,9 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
}
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let ndarray_dtype = primitive_type::NDArrayType::create(ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
@ -1527,8 +1521,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
return if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
let ndarray_dtype1 = primitive_type::NDArrayType::create(left_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 =
primitive_type::NDArrayType::create(right_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1562,10 +1561,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into()))
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
let ndarray_dtype = primitive_type::NDArrayType::create(
if is_ndarray1 { left_ty } else { right_ty },
);
&mut ctx.unifier,
)
.dtype_tvar(&mut ctx.unifier)
.ty;
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
@ -1788,9 +1789,13 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let ndarray_ty = primitive_type::NDArrayType::from_primitive(
&mut ctx.unifier,
&ctx.primitives,
Some(ty),
Some(ndarray_ndims_ty),
);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty.into()).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
@ -2082,7 +2087,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ExprKind::Name { id, .. } if id == &"none".into() => {
match (
ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(),
ctx.unifier.get_ty(ctx.primitives.option).as_ref(),
ctx.unifier.get_ty(ctx.primitives.option.into()).as_ref(),
) {
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
if *obj_id == *opt_id =>
@ -2464,8 +2469,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
};
// directly generate code for option.unwrap
// since it needs to return static value to optimize for kernel invariant
if attr == &"unwrap".into()
&& id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap()
if attr == &"unwrap".into() && id == ctx.primitives.option.obj_id(&ctx.unifier)
{
match val {
ValueEnum::Static(v) => {

View File

@ -1,7 +1,7 @@
use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -47,6 +47,9 @@ pub mod stmt;
#[cfg(test)]
mod test;
use crate::toplevel::primitive_type;
use crate::toplevel::primitive_type::OptionType;
use crate::typecheck::typedef::GenericObjectType;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
@ -457,7 +460,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let dtype = primitive_type::NDArrayType::create(ty, unifier)
.dtype_tvar(unifier)
.ty;
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
@ -634,7 +639,10 @@ pub fn gen_func_impl<
range: unifier.get_representative(primitives.range),
str: unifier.get_representative(primitives.str),
exception: unifier.get_representative(primitives.exception),
option: unifier.get_representative(primitives.option),
option: OptionType::create(
unifier.get_representative(primitives.option.into()),
&mut unifier,
),
..primitives
};

View File

@ -17,12 +17,8 @@ use crate::{
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId,
},
typecheck::typedef::{FunSignature, Type, TypeEnum},
toplevel::{helper::PrimDef, primitive_type, DefinitionId},
typecheck::typedef::{FunSignature, GenericObjectType, Type, TypeEnum},
};
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{
@ -38,12 +34,17 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let ndarray_ty = primitive_type::NDArrayType::from_primitive(
&mut ctx.unifier,
&ctx.primitives,
Some(elem_ty),
None,
);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray_t = ctx
.get_llvm_type(generator, ndarray_ty)
.get_llvm_type(generator, ndarray_ty.into())
.into_pointer_type()
.get_element_type()
.into_struct_type();
@ -1799,7 +1800,9 @@ pub fn gen_ndarray_array<'ctx>(
let obj_ty = fun.0.args[0].ty;
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
primitive_type::NDArrayType::create(obj_ty, &mut context.unifier)
.dtype_tvar(&mut context.unifier)
.ty
}
TypeEnum::TList { ty } => {
@ -1939,7 +1942,9 @@ pub fn gen_ndarray_copy<'ctx>(
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_elem_ty = primitive_type::NDArrayType::create(this_ty, &mut context.unifier)
.dtype_tvar(&mut context.unifier)
.ty;
let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;

View File

@ -4,13 +4,15 @@ use super::{
irrt::{handle_slice_indices, list_slice_assignment},
CodeGenContext, CodeGenerator,
};
use crate::toplevel::primitive_type;
use crate::typecheck::typedef::GenericObjectType;
use crate::{
codegen::{
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::gen_binop_expr,
gen_in_range_check,
},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
use inkwell::{
@ -245,7 +247,9 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
primitive_type::NDArrayType::create(target.custom.unwrap(), &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
}
_ => unreachable!(),
};

View File

@ -3,6 +3,7 @@ use std::rc::Rc;
use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display};
use crate::typecheck::typedef::GenericObjectType;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
@ -43,7 +44,7 @@ impl SymbolValue {
) -> Result<Self, String> {
match constant {
Constant::None => {
if unifier.unioned(expected_ty, primitives.option) {
if unifier.unioned(expected_ty, primitives.option.into()) {
Ok(SymbolValue::OptionNone)
} else {
Err(format!("Expected {expected_ty:?}, but got Option"))
@ -157,7 +158,7 @@ impl SymbolValue {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
}
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option.into(),
}
}
@ -183,13 +184,13 @@ impl SymbolValue {
TypeAnnotation::Tuple(vs_tys)
}
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
id: primitives.option.obj_id(unifier),
params: Vec::default(),
},
SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
id: primitives.option.obj_id(unifier),
params: vec![ty],
}
}

View File

@ -11,7 +11,6 @@ use inkwell::{
use itertools::Either;
use strum::IntoEnumIterator;
use crate::typecheck::typedef::{GenericObjectType, GenericTypeAdapter};
use crate::{
codegen::{
builtin_fns,
@ -25,7 +24,7 @@ use crate::{
stmt::exn_constructor,
},
symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
toplevel::helper::PrimDef,
typecheck::typedef::{into_var_map, TypeVar, VarMap},
};
@ -304,10 +303,7 @@ struct BuiltinBuilder<'a> {
is_some_ty: (Type, bool),
unwrap_ty: (Type, bool),
option_tvar: TypeVar,
ndarray_dtype_tvar: TypeVar,
ndarray_ndims_tvar: TypeVar,
ndarray_copy_ty: (Type, bool),
ndarray_fill_ty: (Type, bool),
@ -316,9 +312,9 @@ struct BuiltinBuilder<'a> {
num_ty: TypeVar,
num_var_map: VarMap,
ndarray_float: Type,
ndarray_float_2d: Type,
ndarray_num_ty: Type,
ndarray_float: primitive_type::NDArrayType,
ndarray_float_2d: primitive_type::NDArrayType,
ndarray_num_ty: primitive_type::NDArrayType,
float_or_ndarray_ty: TypeVar,
float_or_ndarray_var_map: VarMap,
@ -345,24 +341,19 @@ impl<'a> BuiltinBuilder<'a> {
} = *primitives;
// Option-related
let (is_some_ty, unwrap_ty, option_tvar) =
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option) {
let option = GenericTypeAdapter::create(option, unifier);
let (is_some_ty, unwrap_ty) =
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option.into()) {
(
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
option.get_var_at(unifier, 0).unwrap(),
)
} else {
unreachable!()
};
let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray) else {
let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray.into()) else {
unreachable!()
};
let ndarray = GenericTypeAdapter::create(ndarray, unifier);
let ndarray_dtype_tvar = ndarray.get_var_at(unifier, 0).unwrap();
let ndarray_ndims_tvar = ndarray.get_var_at(unifier, 1).unwrap();
let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty =
@ -375,7 +366,8 @@ impl<'a> BuiltinBuilder<'a> {
);
let num_var_map = into_var_map([num_ty]);
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None);
let ndarray_float =
primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(float), None);
let ndarray_float_2d = {
let value = match primitives.size_t {
64 => SymbolValue::U64(2u64),
@ -384,16 +376,28 @@ impl<'a> BuiltinBuilder<'a> {
};
let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None });
make_ndarray_ty(unifier, primitives, Some(float), Some(ndims))
primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(float),
Some(ndims),
)
};
let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.ty), None);
let float_or_ndarray_ty =
unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
let ndarray_num_ty =
primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(num_ty.ty), None);
let float_or_ndarray_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float.into()],
Some("T".into()),
None,
);
let float_or_ndarray_var_map = into_var_map([float_or_ndarray_ty]);
let num_or_ndarray_ty =
unifier.get_fresh_var_with_range(&[num_ty.ty, ndarray_num_ty], Some("T".into()), None);
let num_or_ndarray_ty = unifier.get_fresh_var_with_range(
&[num_ty.ty, ndarray_num_ty.into()],
Some("T".into()),
None,
);
let num_or_ndarray_var_map = into_var_map([num_ty, num_or_ndarray_ty]);
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
@ -406,10 +410,7 @@ impl<'a> BuiltinBuilder<'a> {
is_some_ty,
unwrap_ty,
option_tvar,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
ndarray_copy_ty,
ndarray_fill_ty,
@ -633,7 +634,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Option => TopLevelDef::Class {
name: prim.name().into(),
object_id: prim.id(),
type_vars: vec![self.option_tvar.ty],
type_vars: vec![self.primitives.option.type_tvar(self.unifier).ty],
fields: Vec::default(),
attributes: Vec::default(),
methods: vec![
@ -654,7 +655,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0,
var_id: vec![self.option_tvar.id],
var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
@ -668,7 +669,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().to_string(),
simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0,
var_id: vec![self.option_tvar.id],
var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
@ -699,19 +700,22 @@ impl<'a> BuiltinBuilder<'a> {
loc: None,
},
PrimDef::FunSome => TopLevelDef::Function {
PrimDef::FunSome => {
let option_tvar = self.primitives.option.type_tvar(self.unifier);
TopLevelDef::Function {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "n".into(),
ty: self.option_tvar.ty,
ty: option_tvar.ty,
default_value: None,
}],
ret: self.primitives.option,
vars: into_var_map([self.option_tvar]),
ret: self.primitives.option.into(),
vars: into_var_map([option_tvar]),
})),
var_id: vec![self.option_tvar.id],
var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
@ -728,7 +732,8 @@ impl<'a> BuiltinBuilder<'a> {
},
)))),
loc: None,
},
}
}
_ => {
unreachable!()
@ -737,7 +742,7 @@ impl<'a> BuiltinBuilder<'a> {
}
/// Build the class `ndarray` and its associated methods.
fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef {
fn build_ndarray_class_related(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill],
@ -747,7 +752,10 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::NDArray => TopLevelDef::Class {
name: prim.name().into(),
object_id: prim.id(),
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
type_vars: vec![
self.primitives.ndarray.dtype_tvar(self.unifier).ty,
self.primitives.ndarray.ndims_tvar(self.unifier).ty,
],
fields: Vec::default(),
attributes: Vec::default(),
methods: vec![
@ -764,7 +772,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
var_id: vec![
self.primitives.ndarray.dtype_tvar(self.unifier).id,
self.primitives.ndarray.ndims_tvar(self.unifier).id,
],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
@ -781,7 +792,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
var_id: vec![
self.primitives.ndarray.dtype_tvar(self.unifier).id,
self.primitives.ndarray.ndims_tvar(self.unifier).id,
],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
@ -870,15 +884,26 @@ impl<'a> BuiltinBuilder<'a> {
// The size variant of the function determines the size of the returned int.
let int_sized = size_variant.of_int(self.primitives);
let ndarray_int_sized =
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
let ndarray_float =
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(int_sized),
Some(common_ndim.ty),
);
let ndarray_float = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(float),
Some(common_ndim.ty),
);
let p0_ty =
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
let p0_ty = self.unifier.get_fresh_var_with_range(
&[float, ndarray_float.into()],
Some("T".into()),
None,
);
let ret_ty = self.unifier.get_fresh_var_with_range(
&[int_sized, ndarray_int_sized],
&[int_sized, ndarray_int_sized.into()],
Some("R".into()),
None,
);
@ -930,19 +955,30 @@ impl<'a> BuiltinBuilder<'a> {
None,
);
let ndarray_float =
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
let ndarray_float = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(float),
Some(common_ndim.ty),
);
// The size variant of the function determines the type of int returned
let int_sized = size_variant.of_int(self.primitives);
let ndarray_int_sized =
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(int_sized),
Some(common_ndim.ty),
);
let p0_ty =
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
let p0_ty = self.unifier.get_fresh_var_with_range(
&[float, ndarray_float.into()],
Some("T".into()),
None,
);
let ret_ty = self.unifier.get_fresh_var_with_range(
&[int_sized, ndarray_int_sized],
&[int_sized, ndarray_int_sized.into()],
Some("R".into()),
None,
);
@ -1005,7 +1041,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float,
self.ndarray_float.into(),
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim {
@ -1051,7 +1087,7 @@ impl<'a> BuiltinBuilder<'a> {
default_value: Some(SymbolValue::U32(0)),
},
],
ret: ndarray,
ret: ndarray.into(),
vars: into_var_map([tv]),
})),
var_id: vec![tv.id],
@ -1074,7 +1110,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier,
&into_var_map([tv]),
prim.name(),
self.primitives.ndarray,
self.primitives.ndarray.into(),
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
@ -1103,7 +1139,7 @@ impl<'a> BuiltinBuilder<'a> {
default_value: Some(SymbolValue::I32(0)),
},
],
ret: self.ndarray_float_2d,
ret: self.ndarray_float_2d.into(),
vars: VarMap::default(),
})),
var_id: Vec::default(),
@ -1123,7 +1159,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float_2d,
self.ndarray_float_2d.into(),
&[(int32, "n")],
Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_identity(ctx, &obj, fun, &args, generator)
@ -1338,10 +1374,15 @@ impl<'a> BuiltinBuilder<'a> {
let tvar = self.unifier.get_fresh_var(Some("L".into()), None);
let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty });
let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None);
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.ty), Some(ndims.ty));
let ndarray = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(tvar.ty),
Some(ndims.ty),
);
let arg_ty = self.unifier.get_fresh_var_with_range(
&[list, ndarray, self.primitives.range],
&[list, ndarray.into(), self.primitives.range],
Some("I".into()),
None,
);
@ -1799,8 +1840,13 @@ impl<'a> BuiltinBuilder<'a> {
}
fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar {
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None);
let ndarray = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(scalar_ty),
None,
);
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None)
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray.into()], Some("T".into()), None)
}
}

View File

@ -1,14 +1,13 @@
use std::convert::TryInto;
use super::*;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap};
use crate::toplevel::primitive_type::{NDArrayType, OptionType};
use crate::typecheck::typedef::{into_var_map, GenericObjectType, Mapping, TypeVarId, VarMap};
use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
use super::*;
/// All primitive types and functions in nac3core.
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
pub enum PrimDef {
@ -403,6 +402,7 @@ impl TopLevelComposer {
.collect::<HashMap<_, _>>(),
params: into_var_map([option_type_var]),
});
let option = OptionType::create(option, &mut unifier);
let size_t_ty = match size_t {
32 => uint32,
@ -436,8 +436,9 @@ impl TopLevelComposer {
]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
});
let ndarray = NDArrayType::create(ndarray, &mut unifier);
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap();
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray.into()).unwrap();
let primitives = PrimitiveStore {
int32,
@ -747,7 +748,7 @@ impl TopLevelComposer {
TypeAnnotation::CustomClass { id: e_id, params: e_param },
) => {
*f_id == *e_id
&& *f_id == primitive.option.obj_id(unifier).unwrap()
&& *f_id == primitive.option.obj_id(unifier)
&& (f_param.is_empty()
|| (f_param.len() == 1
&& e_param.len() == 1
@ -885,7 +886,7 @@ pub fn parse_parameter_default_value(
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(unifier, ty).0
NDArrayType::create(ty, unifier).dtype_tvar(unifier).ty
}
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
@ -897,7 +898,7 @@ pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
let ndims = NDArrayType::create(ty, unifier).ndims_tvar(unifier).ty;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
};

View File

@ -30,7 +30,7 @@ pub struct DefinitionId(pub usize);
pub mod builtins;
pub mod composer;
pub mod helper;
pub mod numpy;
pub mod primitive_type;
pub mod type_annotation;
use composer::*;
use type_annotation::*;

View File

@ -1,85 +0,0 @@
use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
};
use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
if dtype.is_none() && ndims.is_none() {
return ndarray;
}
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
if let Some(ndims) = ndims {
tvar_subst.insert(tvar_ids[1], ndims);
}
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2);
params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
}

View File

@ -0,0 +1,98 @@
use crate::toplevel::helper::PrimDef;
use crate::typecheck::type_inferencer::PrimitiveStore;
use crate::typecheck::typedef::{GenericObjectType, Type, TypeVar, Unifier, VarMap};
#[derive(Clone, Copy)]
pub struct OptionType(Type);
impl OptionType {
pub fn from_primitive(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
type_ty: Option<Type>,
) -> Self {
primitives.option.subst(unifier, type_ty)
}
pub fn type_tvar(&self, unifier: &mut Unifier) -> TypeVar {
self.get_var_at(unifier, 0).unwrap()
}
#[must_use]
pub fn subst(&self, unifier: &mut Unifier, type_ty: Option<Type>) -> Self {
let new_vars = [(self.type_tvar(unifier).id, type_ty)]
.into_iter()
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
.collect::<VarMap>();
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
OptionType(new_ty)
}
}
impl GenericObjectType for OptionType {
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::Option.id()) {
Some(OptionType(ty))
} else {
None
}
}
fn get_type(&self) -> Type {
self.0
}
}
#[derive(Clone, Copy)]
pub struct NDArrayType(Type);
impl NDArrayType {
pub fn from_primitive(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Self {
primitives.ndarray.subst(unifier, dtype, ndims)
}
pub fn dtype_tvar(&self, unifier: &mut Unifier) -> TypeVar {
self.get_var_at(unifier, 0).unwrap()
}
pub fn ndims_tvar(&self, unifier: &mut Unifier) -> TypeVar {
self.get_var_at(unifier, 1).unwrap()
}
#[must_use]
pub fn subst(
&self,
unifier: &mut Unifier,
dtype_ty: Option<Type>,
ndims_ty: Option<Type>,
) -> Self {
let new_vars =
[(self.dtype_tvar(unifier).id, dtype_ty), (self.ndims_tvar(unifier).id, ndims_ty)]
.into_iter()
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
.collect::<VarMap>();
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
NDArrayType(new_ty)
}
}
impl GenericObjectType for NDArrayType {
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
Some(NDArrayType(ty))
} else {
None
}
}
fn get_type(&self) -> Type {
self.0
}
}

View File

@ -1,7 +1,7 @@
use super::*;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::VarMap;
use crate::typecheck::typedef::{GenericObjectType, VarMap};
use nac3parser::ast::Constant;
#[derive(Clone, Debug)]
@ -267,12 +267,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
slice.as_ref(),
locked,
)?;
let id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
*obj_id
} else {
unreachable!()
};
let id = primitives.option.obj_id(unifier);
Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] })
}

View File

@ -1,9 +1,9 @@
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::toplevel::primitive_type;
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, Unifier, VarMap},
};
use itertools::Itertools;
use nac3parser::ast::StrRef;
@ -369,8 +369,12 @@ pub fn typeof_ndarray_broadcast(
if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
let left_ty = primitive_type::NDArrayType::create(left, unifier);
let left_ty_dtype = left_ty.dtype_tvar(unifier).ty;
let left_ty_ndims = left_ty.ndims_tvar(unifier).ty;
let right_ty = primitive_type::NDArrayType::create(right, unifier);
let right_ty_dtype = right_ty.dtype_tvar(unifier).ty;
let right_ty_ndims = right_ty.ndims_tvar(unifier).ty;
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
@ -397,11 +401,18 @@ pub fn typeof_ndarray_broadcast(
.collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
Ok(primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(left_ty_dtype),
Some(res_ndims),
)
.into())
} else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
let ndarray_ty_dtype =
primitive_type::NDArrayType::create(ndarray_ty, unifier).ndims_tvar(unifier).ty;
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty)
@ -444,7 +455,8 @@ pub fn typeof_binop(
}
Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims =
primitive_type::NDArrayType::create(lhs, unifier).ndims_tvar(unifier).ty;
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
@ -452,7 +464,8 @@ pub fn typeof_binop(
}
_ => unreachable!(),
};
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims =
primitive_type::NDArrayType::create(rhs, unifier).ndims_tvar(unifier).ty;
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
@ -526,7 +539,7 @@ pub fn typeof_unaryop(
let operand_obj_id = operand.obj_id(unifier);
if op == Unaryop::Not
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap())
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier))
{
return Err(
"The truth value of an array with more than one element is ambiguous".to_string()
@ -552,7 +565,8 @@ pub fn typeof_unaryop(
Unaryop::UAdd | Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
let dtype =
primitive_type::NDArrayType::create(operand, unifier).dtype_tvar(unifier).ty;
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
@ -586,9 +600,15 @@ pub fn typeof_cmpop(
Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
let ndims = primitive_type::NDArrayType::create(brd, unifier).ndims_tvar(unifier).ty;
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(primitives.bool),
Some(ndims),
)
.into()
} else if unifier.unioned(lhs, rhs) {
primitives.bool
} else {
@ -611,64 +631,108 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] {
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
let ndarray_int_t =
primitive_type::NDArrayType::from_primitive(unifier, store, Some(t), None);
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
impl_div(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_floordiv(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_mod(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
impl_comparison(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_eq(unifier, store, t, &[t, ndarray_int_t.into()], None);
}
for t in [int32_t, int64_t] {
impl_sign(unifier, store, t, Some(t));
}
/* float ======== */
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
let ndarray_float_t =
primitive_type::NDArrayType::from_primitive(unifier, store, Some(float_t), None);
let ndarray_int32_t =
primitive_type::NDArrayType::from_primitive(unifier, store, Some(int32_t), None);
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_pow(
unifier,
store,
float_t,
&[int32_t, float_t, ndarray_int32_t.into(), ndarray_float_t.into()],
None,
);
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
/* bool ======== */
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
let ndarray_bool_t =
primitive_type::NDArrayType::from_primitive(unifier, store, Some(bool_t), None);
impl_invert(unifier, store, bool_t, Some(int32_t));
impl_not(unifier, store, bool_t, Some(bool_t));
impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t.into()], None);
/* ndarray ===== */
let ndarray_usized_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t =
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
let ndarray_unsized_t = primitive_type::NDArrayType::from_primitive(
unifier,
store,
None,
Some(ndarray_usized_ndims_tvar.ty),
);
let ndarray_dtype_t = ndarray_t.dtype_tvar(unifier).ty;
let ndarray_unsized_dtype_t = ndarray_unsized_t.dtype_tvar(unifier).ty;
impl_basic_arithmetic(
unifier,
store,
ndarray_t,
&[ndarray_unsized_t, ndarray_unsized_dtype_t],
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_pow(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_div(unifier, store, ndarray_t.into(), &[ndarray_t.into(), ndarray_dtype_t], None);
impl_floordiv(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_mod(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_matmul(unifier, store, ndarray_t.into(), &[ndarray_t.into()], Some(ndarray_t.into()));
impl_sign(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
impl_invert(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
impl_eq(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_comparison(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
}

View File

@ -4,14 +4,16 @@ use std::iter::once;
use std::ops::Not;
use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::typedef::{
Call, FunSignature, FuncArg, GenericObjectType, RecordField, Type, TypeEnum, Unifier, VarMap,
};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::toplevel::primitive_type::{NDArrayType, OptionType};
use crate::toplevel::TopLevelDef;
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext,
},
};
@ -49,8 +51,8 @@ pub struct PrimitiveStore {
pub range: Type,
pub str: Type,
pub exception: Type,
pub option: Type,
pub ndarray: Type,
pub option: OptionType,
pub ndarray: NDArrayType,
pub size_t: u32,
}
@ -97,8 +99,8 @@ impl IntoIterator for &PrimitiveStore {
self.range,
self.str,
self.exception,
self.option,
self.ndarray,
self.option.into(),
self.ndarray.into(),
]
.into_iter()
}
@ -528,7 +530,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
// the name `none` is special since it may have different types
if id == &"none".into() {
if let TypeEnum::TObj { params, .. } =
self.unifier.get_ty_immutable(self.primitives.option).as_ref()
&*self.unifier.get_ty_immutable(self.primitives.option.into())
{
let var_map = params
.iter()
@ -543,7 +545,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
(*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty)
})
.collect::<VarMap>();
Some(self.unifier.subst(self.primitives.option, &var_map).unwrap())
Some(self.unifier.subst(self.primitives.option.into(), &var_map).unwrap())
} else {
unreachable!("must be tobj")
}
@ -1062,9 +1064,16 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
let ndarray_ndims =
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(target_ty),
Some(ndarray_ndims),
)
.into()
} else {
target_ty
};
@ -1100,9 +1109,7 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
ndarray_dtype
NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
} else {
arg0_ty
};
@ -1154,14 +1161,14 @@ impl<'a> Inferencer<'a> {
let arg0_dtype =
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg0_ty).0
NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
} else {
arg0_ty
};
let arg1_dtype =
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg1_ty).0
NDArrayType::create(arg1_ty, self.unifier).dtype_tvar(self.unifier).ty
} else {
arg1_ty
};
@ -1192,9 +1199,17 @@ impl<'a> Inferencer<'a> {
// (float, int32), so convert it to align with the dtype of the first arg
let arg1_ty = if id == &"np_ldexp".into() {
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
// let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
let ndims =
NDArrayType::create(arg1_ty, self.unifier).ndims_tvar(self.unifier).ty;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(target_ty),
Some(ndims),
)
.into()
} else {
target_ty
}
@ -1281,9 +1296,16 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
let ndarray_ndims =
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(target_ty),
Some(ndarray_ndims),
)
.into()
} else {
target_ty
};
@ -1323,7 +1345,7 @@ impl<'a> Inferencer<'a> {
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(
let ret = NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(self.primitives.float),
@ -1335,13 +1357,13 @@ impl<'a> Inferencer<'a> {
ty: shape.custom.unwrap(),
default_value: None,
}],
ret,
ret: ret.into(),
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
custom: Some(ret.into()),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
@ -1374,7 +1396,8 @@ impl<'a> Inferencer<'a> {
let ty = arg1.custom.unwrap();
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
let ret =
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None },
@ -1384,13 +1407,13 @@ impl<'a> Inferencer<'a> {
default_value: None,
},
],
ret,
ret: ret.into(),
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
custom: Some(ret.into()),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
@ -1428,7 +1451,8 @@ impl<'a> Inferencer<'a> {
arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
};
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
let ret =
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
@ -1448,13 +1472,13 @@ impl<'a> Inferencer<'a> {
default_value: Some(SymbolValue::U32(0)),
},
],
ret,
ret: ret.into(),
vars: VarMap::new(),
}));
return Ok(Some(Located {
location,
custom: Some(ret),
custom: Some(ret.into()),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(custom),
@ -1803,9 +1827,13 @@ impl<'a> Inferencer<'a> {
TypeEnum::TVar { is_const_generic: false, .. }
));
let constrained_ty =
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims));
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
let constrained_ty = NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(dummy_tvar),
Some(ndims),
);
self.constrain(value.custom.unwrap(), constrained_ty.into(), &value.location)?;
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
@ -1871,10 +1899,14 @@ impl<'a> Inferencer<'a> {
let ndims_ty = self
.unifier
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
let subscripted_ty =
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty));
let subscripted_ty = NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(dummy_tvar),
Some(ndims_ty),
);
Ok(subscripted_ty)
Ok(subscripted_ty.into())
}
}
@ -1893,10 +1925,17 @@ impl<'a> Inferencer<'a> {
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
.ndims_tvar(self.unifier)
.ty;
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(ty),
Some(ndims),
)
.into()
}
_ => unreachable!(),
@ -1907,8 +1946,10 @@ impl<'a> Inferencer<'a> {
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
.ndims_tvar(self.unifier)
.ty;
self.infer_subscript_ndarray(value, slice, ty, ndims)
}
_ => {
@ -1951,7 +1992,10 @@ impl<'a> Inferencer<'a> {
}
}
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
.ndims_tvar(self.unifier)
.ty;
self.infer_subscript_ndarray(value, slice, ty, ndims)
}
_ => {
@ -1975,8 +2019,9 @@ impl<'a> Inferencer<'a> {
Ok(ty)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
.ndims_tvar(self.unifier)
.ty;
let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
.into_iter()

View File

@ -139,6 +139,7 @@ impl TestEnvironment {
fields: HashMap::new(),
params: VarMap::new(),
});
let option = OptionType::create(option, &mut unifier);
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
@ -147,6 +148,7 @@ impl TestEnvironment {
fields: HashMap::new(),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
});
let ndarray = NDArrayType::create(ndarray, &mut unifier);
let primitives = PrimitiveStore {
int32,
int64,
@ -273,11 +275,13 @@ impl TestEnvironment {
fields: HashMap::new(),
params: VarMap::new(),
});
let option = OptionType::create(option, &mut unifier);
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let ndarray = NDArrayType::create(ndarray, &mut unifier);
identifier_mapping.insert("None".into(), none);
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
.iter()

View File

@ -137,7 +137,8 @@ where
#[must_use]
fn get_type(&self) -> Type;
/// See [`Type::obj_id`].
/// Similar to [`Type::obj_id`], except that the [`DefinitionId`] is not wrapped within an
/// [`Option`].
#[must_use]
fn obj_id(&self, unifier: &Unifier) -> DefinitionId {
self.get_type().obj_id(unifier).unwrap()