core/typedef: WIP - Add OptionType and NDArrayType

This commit is contained in:
David Mak 2024-06-27 16:33:27 +08:00
parent da4dec08a5
commit 6892a4848e
18 changed files with 717 additions and 407 deletions

View File

@ -7,7 +7,7 @@ use nac3core::{
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, DefinitionId, GenCall}, toplevel::{helper::PrimDef, DefinitionId, GenCall},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap}, typecheck::typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, VarMap},
}; };
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
@ -23,7 +23,7 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use nac3core::toplevel::numpy::unpack_ndarray_var_tys; use nac3core::toplevel::primitive_type;
use std::{ use std::{
collections::hash_map::DefaultHasher, collections::hash_map::DefaultHasher,
collections::HashMap, collections::HashMap,
@ -399,7 +399,9 @@ fn gen_rpc_tag(
gen_rpc_tag(ctx, *ty, buffer)?; gen_rpc_tag(ctx, *ty, buffer)?;
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray_ty = primitive_type::NDArrayType::create(ty, &mut ctx.unifier);
let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
let ndarray_ndims = if let TLiteral { values, .. } = let ndarray_ndims = if let TLiteral { values, .. } =
&*ctx.unifier.get_ty_immutable(ndarray_ndims) &*ctx.unifier.get_ty_immutable(ndarray_ndims)
{ {
@ -645,7 +647,7 @@ pub fn attributes_writeback(
let ty = ty.unwrap(); let ty = ty.unwrap();
match &*ctx.unifier.get_ty(ty) { match &*ctx.unifier.get_ty(ty) {
TypeEnum::TObj { fields, obj_id, .. } TypeEnum::TObj { fields, obj_id, .. }
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier) =>
{ {
// we only care about primitive attributes // we only care about primitive attributes
// for non-primitive attributes, they should be in another global // for non-primitive attributes, they should be in another global

View File

@ -11,11 +11,7 @@ use nac3core::{
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{ toplevel::{helper::PrimDef, primitive_type, DefinitionId, TopLevelDef},
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVar, Unifier, VarMap}, typedef::{Type, TypeEnum, TypeVar, Unifier, VarMap},
@ -337,13 +333,18 @@ impl InnerResolver {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
let var = unifier.get_dummy_var().ty; let var = unifier.get_dummy_var().ty;
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty; let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty;
let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims)); let ndarray = primitive_type::NDArrayType::from_primitive(
Ok(Ok((ndarray, false))) unifier,
primitives,
Some(var),
Some(ndims),
);
Ok(Ok((ndarray.into(), false)))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if ty_id == self.primitive_ids.option { } else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false))) Ok(Ok((primitives.option.into(), false)))
} else if ty_id == self.primitive_ids.none { } else if ty_id == self.primitive_ids.none {
unreachable!("none cannot be typeid") unreachable!("none cannot be typeid")
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() { } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() {
@ -510,7 +511,16 @@ impl InnerResolver {
)); ));
} }
Ok(Ok((make_ndarray_ty(unifier, primitives, Some(ty.0), None), true))) Ok(Ok((
primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(ty.0),
None,
)
.into(),
true,
)))
} }
TypeEnum::TTuple { .. } => { TypeEnum::TTuple { .. } => {
let args = match args let args = match args
@ -719,7 +729,9 @@ impl InnerResolver {
} }
} }
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => { (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); let ndarray = primitive_type::NDArrayType::create(extracted_ty, unifier);
let ty = ndarray.dtype_tvar(unifier).ty;
let ndims = ndarray.ndims_tvar(unifier).ty;
let len: usize = obj.getattr("ndim")?.extract()?; let len: usize = obj.getattr("ndim")?.extract()?;
if len == 0 { if len == 0 {
assert!(matches!( assert!(matches!(
@ -734,10 +746,14 @@ impl InnerResolver {
match dtype_ty { match dtype_ty {
Ok((t, _)) => match unifier.unify(ty, t) { Ok((t, _)) => match unifier.unify(ty, t) {
Ok(()) => { Ok(()) => {
let ndarray_ty = let ndarray_ty = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims)); unifier,
primitives,
Some(ty),
Some(ndims),
);
Ok(Ok(ndarray_ty)) Ok(Ok(ndarray_ty.into()))
} }
Err(e) => Ok(Err(format!( Err(e) => Ok(Err(format!(
"type error ({}) for the ndarray", "type error ({}) for the ndarray",
@ -760,7 +776,7 @@ impl InnerResolver {
// special handling for option type since its class member layout in python side // special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below // is special and cannot be mapped directly to a nac3 type as below
(TypeEnum::TObj { obj_id, params, .. }, false) (TypeEnum::TObj { obj_id, params, .. }, false)
if *obj_id == primitives.option.obj_id(unifier).unwrap() => if *obj_id == primitives.option.obj_id(unifier) =>
{ {
let Ok(field_data) = obj.getattr("_nac3_option") else { let Ok(field_data) = obj.getattr("_nac3_option") else {
unreachable!("cannot be None") unreachable!("cannot be None")
@ -785,7 +801,7 @@ impl InnerResolver {
.map(TypeVar::into) .map(TypeVar::into)
.collect::<VarMap>() .collect::<VarMap>()
}); });
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())); return Ok(Ok(unifier.subst(primitives.option.into(), &var_map).unwrap()));
} }
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
@ -1038,8 +1054,9 @@ impl InnerResolver {
} else { } else {
unreachable!("must be ndarray") unreachable!("must be ndarray")
}; };
let (ndarray_dtype, ndarray_ndims) = let ndarray_ty = primitive_type::NDArrayType::create(ndarray_ty, &mut ctx.unifier);
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
@ -1186,7 +1203,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.option { } else if ty_id == self.primitive_ids.option {
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() { let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. } TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier) =>
{ {
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} }

View File

@ -8,8 +8,8 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::primitive_type;
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::{GenericObjectType, Type};
/// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// Shorthand for [`unreachable!()`] when a type of argument is not supported.
/// ///
@ -66,7 +66,9 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -128,7 +130,9 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -206,7 +210,9 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -273,7 +279,9 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -328,7 +336,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -374,7 +384,9 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -414,7 +426,9 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -475,7 +489,9 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -529,7 +545,9 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -579,7 +597,9 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -660,7 +680,9 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -751,16 +773,24 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -850,7 +880,9 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -941,16 +973,24 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1008,7 +1048,9 @@ where
if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let arg_elem_ty = primitive_type::NDArrayType::create(arg_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -1370,16 +1412,24 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1437,16 +1487,24 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1504,16 +1562,24 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1571,16 +1637,24 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1637,12 +1711,22 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
let is_ndarray2 = let is_ndarray2 =
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = let dtype = if is_ndarray1 {
if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
x1_ty
};
let x1_scalar_ty = dtype; let x1_scalar_ty = dtype;
let x2_scalar_ty = let x2_scalar_ty = if is_ndarray2 {
if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
x2_ty
};
numpy::ndarray_elementwise_binop_impl( numpy::ndarray_elementwise_binop_impl(
generator, generator,
@ -1694,16 +1778,24 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1761,16 +1853,24 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };

View File

@ -1,5 +1,9 @@
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
use crate::toplevel::primitive_type;
use crate::toplevel::primitive_type::OptionType;
use crate::typecheck::typedef::GenericObjectType;
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ classes::{
@ -15,11 +19,7 @@ use crate::{
CodeGenContext, CodeGenTask, CodeGenContext, CodeGenTask,
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{ toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
typecheck::{ typecheck::{
magic_methods::{binop_assign_name, binop_name, unaryop_name}, magic_methods::{binop_assign_name, binop_name, unaryop_name},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
@ -36,8 +36,6 @@ use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
}; };
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
pub fn get_subst_key( pub fn get_subst_key(
unifier: &mut Unifier, unifier: &mut Unifier,
obj: Option<Type>, obj: Option<Type>,
@ -162,14 +160,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
self.builder.build_load(ptr, "tup_val").unwrap() self.builder.build_load(ptr, "tup_val").unwrap()
} }
SymbolValue::OptionSome(v) => { SymbolValue::OptionSome(v) => {
let ty = match self.unifier.get_ty_immutable(ty).as_ref() { let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type"),
};
let val = self.gen_symbol_val(generator, v, ty); let val = self.gen_symbol_val(generator, v, ty);
let ptr = generator let ptr = generator
.gen_var_alloc(self, val.get_type(), Some("default_opt_some")) .gen_var_alloc(self, val.get_type(), Some("default_opt_some"))
@ -178,14 +169,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ptr.into() ptr.into()
} }
SymbolValue::OptionNone => { SymbolValue::OptionNone => {
let ty = match self.unifier.get_ty_immutable(ty).as_ref() { let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type"),
};
let actual_ptr_type = let actual_ptr_type =
self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default()); self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default());
actual_ptr_type.const_null().into() actual_ptr_type.const_null().into()
@ -1206,8 +1190,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
if is_ndarray1 && is_ndarray2 { if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); let ndarray_dtype1 = primitive_type::NDArrayType::create(ty1, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(ty2, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1256,8 +1244,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into())) Ok(Some(res.as_base_value().into()))
} else { } else {
let (ndarray_dtype, _) = let ndarray_dtype = primitive_type::NDArrayType::create(
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); if is_ndarray1 { ty1 } else { ty2 },
&mut ctx.unifier,
)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_val = NDArrayValue::from_ptr_val( let ndarray_val = NDArrayValue::from_ptr_val(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_usize, llvm_usize,
@ -1443,7 +1435,9 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
} }
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray_dtype = primitive_type::NDArrayType::create(ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
@ -1527,8 +1521,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
return if is_ndarray1 && is_ndarray2 { return if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(left_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 =
primitive_type::NDArrayType::create(right_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1562,10 +1561,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into())) Ok(Some(res.as_base_value().into()))
} else { } else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys( let ndarray_dtype = primitive_type::NDArrayType::create(
&mut ctx.unifier,
if is_ndarray1 { left_ty } else { right_ty }, if is_ndarray1 { left_ty } else { right_ty },
); &mut ctx.unifier,
)
.dtype_tvar(&mut ctx.unifier)
.ty;
let res = numpy::ndarray_elementwise_binop_impl( let res = numpy::ndarray_elementwise_binop_impl(
generator, generator,
ctx, ctx,
@ -1788,9 +1789,13 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(), ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None, None,
); );
let ndarray_ty = let ndarray_ty = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty)); &mut ctx.unifier,
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); &ctx.primitives,
Some(ty),
Some(ndarray_ndims_ty),
);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty.into()).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
@ -2082,7 +2087,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ExprKind::Name { id, .. } if id == &"none".into() => { ExprKind::Name { id, .. } if id == &"none".into() => {
match ( match (
ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(), ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(),
ctx.unifier.get_ty(ctx.primitives.option).as_ref(), ctx.unifier.get_ty(ctx.primitives.option.into()).as_ref(),
) { ) {
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. }) (TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
if *obj_id == *opt_id => if *obj_id == *opt_id =>
@ -2464,8 +2469,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}; };
// directly generate code for option.unwrap // directly generate code for option.unwrap
// since it needs to return static value to optimize for kernel invariant // since it needs to return static value to optimize for kernel invariant
if attr == &"unwrap".into() if attr == &"unwrap".into() && id == ctx.primitives.option.obj_id(&ctx.unifier)
&& id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap()
{ {
match val { match val {
ValueEnum::Static(v) => { ValueEnum::Static(v) => {

View File

@ -1,7 +1,7 @@
use crate::{ use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver}, symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -47,6 +47,9 @@ pub mod stmt;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
use crate::toplevel::primitive_type;
use crate::toplevel::primitive_type::OptionType;
use crate::typecheck::typedef::GenericObjectType;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator}; pub use generator::{CodeGenerator, DefaultCodeGenerator};
@ -457,7 +460,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let dtype = primitive_type::NDArrayType::create(ty, unifier)
.dtype_tvar(unifier)
.ty;
let element_type = get_llvm_type( let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype, ctx, module, generator, unifier, top_level, type_cache, dtype,
); );
@ -634,7 +639,10 @@ pub fn gen_func_impl<
range: unifier.get_representative(primitives.range), range: unifier.get_representative(primitives.range),
str: unifier.get_representative(primitives.str), str: unifier.get_representative(primitives.str),
exception: unifier.get_representative(primitives.exception), exception: unifier.get_representative(primitives.exception),
option: unifier.get_representative(primitives.option), option: OptionType::create(
unifier.get_representative(primitives.option.into()),
&mut unifier,
),
..primitives ..primitives
}; };

View File

@ -17,12 +17,8 @@ use crate::{
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{helper::PrimDef, primitive_type, DefinitionId},
helper::PrimDef, typecheck::typedef::{FunSignature, GenericObjectType, Type, TypeEnum},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId,
},
typecheck::typedef::{FunSignature, Type, TypeEnum},
}; };
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{ use inkwell::{
@ -38,12 +34,17 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); let ndarray_ty = primitive_type::NDArrayType::from_primitive(
&mut ctx.unifier,
&ctx.primitives,
Some(elem_ty),
None,
);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray_t = ctx let llvm_ndarray_t = ctx
.get_llvm_type(generator, ndarray_ty) .get_llvm_type(generator, ndarray_ty.into())
.into_pointer_type() .into_pointer_type()
.get_element_type() .get_element_type()
.into_struct_type(); .into_struct_type();
@ -1799,7 +1800,9 @@ pub fn gen_ndarray_array<'ctx>(
let obj_ty = fun.0.args[0].ty; let obj_ty = fun.0.args[0].ty;
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 primitive_type::NDArrayType::create(obj_ty, &mut context.unifier)
.dtype_tvar(&mut context.unifier)
.ty
} }
TypeEnum::TList { ty } => { TypeEnum::TList { ty } => {
@ -1939,7 +1942,9 @@ pub fn gen_ndarray_copy<'ctx>(
let llvm_usize = generator.get_size_type(context.ctx); let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_elem_ty = primitive_type::NDArrayType::create(this_ty, &mut context.unifier)
.dtype_tvar(&mut context.unifier)
.ty;
let this_arg = let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;

View File

@ -4,13 +4,15 @@ use super::{
irrt::{handle_slice_indices, list_slice_assignment}, irrt::{handle_slice_indices, list_slice_assignment},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
use crate::toplevel::primitive_type;
use crate::typecheck::typedef::GenericObjectType;
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::gen_binop_expr, expr::gen_binop_expr,
gen_in_range_check, gen_in_range_check,
}, },
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::typedef::{FunSignature, Type, TypeEnum},
}; };
use inkwell::{ use inkwell::{
@ -245,7 +247,9 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty, TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 primitive_type::NDArrayType::create(target.custom.unwrap(), &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} }
_ => unreachable!(), _ => unreachable!(),
}; };

View File

@ -3,6 +3,7 @@ use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display}; use std::{collections::HashMap, collections::HashSet, fmt::Display};
use crate::typecheck::typedef::GenericObjectType;
use crate::{ use crate::{
codegen::{CodeGenContext, CodeGenerator}, codegen::{CodeGenContext, CodeGenerator},
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef}, toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
@ -43,7 +44,7 @@ impl SymbolValue {
) -> Result<Self, String> { ) -> Result<Self, String> {
match constant { match constant {
Constant::None => { Constant::None => {
if unifier.unioned(expected_ty, primitives.option) { if unifier.unioned(expected_ty, primitives.option.into()) {
Ok(SymbolValue::OptionNone) Ok(SymbolValue::OptionNone)
} else { } else {
Err(format!("Expected {expected_ty:?}, but got Option")) Err(format!("Expected {expected_ty:?}, but got Option"))
@ -157,7 +158,7 @@ impl SymbolValue {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>(); let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys }) unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
} }
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option.into(),
} }
} }
@ -183,13 +184,13 @@ impl SymbolValue {
TypeAnnotation::Tuple(vs_tys) TypeAnnotation::Tuple(vs_tys)
} }
SymbolValue::OptionNone => TypeAnnotation::CustomClass { SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(), id: primitives.option.obj_id(unifier),
params: Vec::default(), params: Vec::default(),
}, },
SymbolValue::OptionSome(v) => { SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier); let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass { TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(), id: primitives.option.obj_id(unifier),
params: vec![ty], params: vec![ty],
} }
} }

View File

@ -11,7 +11,6 @@ use inkwell::{
use itertools::Either; use itertools::Either;
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use crate::typecheck::typedef::{GenericObjectType, GenericTypeAdapter};
use crate::{ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
@ -25,7 +24,7 @@ use crate::{
stmt::exn_constructor, stmt::exn_constructor,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, toplevel::helper::PrimDef,
typecheck::typedef::{into_var_map, TypeVar, VarMap}, typecheck::typedef::{into_var_map, TypeVar, VarMap},
}; };
@ -304,10 +303,7 @@ struct BuiltinBuilder<'a> {
is_some_ty: (Type, bool), is_some_ty: (Type, bool),
unwrap_ty: (Type, bool), unwrap_ty: (Type, bool),
option_tvar: TypeVar,
ndarray_dtype_tvar: TypeVar,
ndarray_ndims_tvar: TypeVar,
ndarray_copy_ty: (Type, bool), ndarray_copy_ty: (Type, bool),
ndarray_fill_ty: (Type, bool), ndarray_fill_ty: (Type, bool),
@ -316,9 +312,9 @@ struct BuiltinBuilder<'a> {
num_ty: TypeVar, num_ty: TypeVar,
num_var_map: VarMap, num_var_map: VarMap,
ndarray_float: Type, ndarray_float: primitive_type::NDArrayType,
ndarray_float_2d: Type, ndarray_float_2d: primitive_type::NDArrayType,
ndarray_num_ty: Type, ndarray_num_ty: primitive_type::NDArrayType,
float_or_ndarray_ty: TypeVar, float_or_ndarray_ty: TypeVar,
float_or_ndarray_var_map: VarMap, float_or_ndarray_var_map: VarMap,
@ -345,24 +341,19 @@ impl<'a> BuiltinBuilder<'a> {
} = *primitives; } = *primitives;
// Option-related // Option-related
let (is_some_ty, unwrap_ty, option_tvar) = let (is_some_ty, unwrap_ty) =
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option) { if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option.into()) {
let option = GenericTypeAdapter::create(option, unifier);
( (
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
option.get_var_at(unifier, 0).unwrap(),
) )
} else { } else {
unreachable!() unreachable!()
}; };
let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray) else { let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray.into()) else {
unreachable!() unreachable!()
}; };
let ndarray = GenericTypeAdapter::create(ndarray, unifier);
let ndarray_dtype_tvar = ndarray.get_var_at(unifier, 0).unwrap();
let ndarray_ndims_tvar = ndarray.get_var_at(unifier, 1).unwrap();
let ndarray_copy_ty = let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty = let ndarray_fill_ty =
@ -375,7 +366,8 @@ impl<'a> BuiltinBuilder<'a> {
); );
let num_var_map = into_var_map([num_ty]); let num_var_map = into_var_map([num_ty]);
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); let ndarray_float =
primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(float), None);
let ndarray_float_2d = { let ndarray_float_2d = {
let value = match primitives.size_t { let value = match primitives.size_t {
64 => SymbolValue::U64(2u64), 64 => SymbolValue::U64(2u64),
@ -384,16 +376,28 @@ impl<'a> BuiltinBuilder<'a> {
}; };
let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None }); let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None });
make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(float),
Some(ndims),
)
}; };
let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.ty), None); let ndarray_num_ty =
let float_or_ndarray_ty = primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(num_ty.ty), None);
unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); let float_or_ndarray_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float.into()],
Some("T".into()),
None,
);
let float_or_ndarray_var_map = into_var_map([float_or_ndarray_ty]); let float_or_ndarray_var_map = into_var_map([float_or_ndarray_ty]);
let num_or_ndarray_ty = let num_or_ndarray_ty = unifier.get_fresh_var_with_range(
unifier.get_fresh_var_with_range(&[num_ty.ty, ndarray_num_ty], Some("T".into()), None); &[num_ty.ty, ndarray_num_ty.into()],
Some("T".into()),
None,
);
let num_or_ndarray_var_map = into_var_map([num_ty, num_or_ndarray_ty]); let num_or_ndarray_var_map = into_var_map([num_ty, num_or_ndarray_ty]);
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
@ -406,10 +410,7 @@ impl<'a> BuiltinBuilder<'a> {
is_some_ty, is_some_ty,
unwrap_ty, unwrap_ty,
option_tvar,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
ndarray_copy_ty, ndarray_copy_ty,
ndarray_fill_ty, ndarray_fill_ty,
@ -633,7 +634,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Option => TopLevelDef::Class { PrimDef::Option => TopLevelDef::Class {
name: prim.name().into(), name: prim.name().into(),
object_id: prim.id(), object_id: prim.id(),
type_vars: vec![self.option_tvar.ty], type_vars: vec![self.primitives.option.type_tvar(self.unifier).ty],
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
@ -654,7 +655,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0, signature: self.unwrap_ty.0,
var_id: vec![self.option_tvar.id], var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -668,7 +669,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().to_string(), name: prim.name().to_string(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0, signature: self.is_some_ty.0,
var_id: vec![self.option_tvar.id], var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -699,19 +700,22 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::FunSome => TopLevelDef::Function { PrimDef::FunSome => {
let option_tvar = self.primitives.option.type_tvar(self.unifier);
TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg {
name: "n".into(), name: "n".into(),
ty: self.option_tvar.ty, ty: option_tvar.ty,
default_value: None, default_value: None,
}], }],
ret: self.primitives.option, ret: self.primitives.option.into(),
vars: into_var_map([self.option_tvar]), vars: into_var_map([option_tvar]),
})), })),
var_id: vec![self.option_tvar.id], var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -728,7 +732,8 @@ impl<'a> BuiltinBuilder<'a> {
}, },
)))), )))),
loc: None, loc: None,
}, }
}
_ => { _ => {
unreachable!() unreachable!()
@ -737,7 +742,7 @@ impl<'a> BuiltinBuilder<'a> {
} }
/// Build the class `ndarray` and its associated methods. /// Build the class `ndarray` and its associated methods.
fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef { fn build_ndarray_class_related(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed( debug_assert_prim_is_allowed(
prim, prim,
&[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill], &[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill],
@ -747,7 +752,10 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::NDArray => TopLevelDef::Class { PrimDef::NDArray => TopLevelDef::Class {
name: prim.name().into(), name: prim.name().into(),
object_id: prim.id(), object_id: prim.id(),
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty], type_vars: vec![
self.primitives.ndarray.dtype_tvar(self.unifier).ty,
self.primitives.ndarray.ndims_tvar(self.unifier).ty,
],
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
@ -764,7 +772,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0, signature: self.ndarray_copy_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], var_id: vec![
self.primitives.ndarray.dtype_tvar(self.unifier).id,
self.primitives.ndarray.ndims_tvar(self.unifier).id,
],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -781,7 +792,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0, signature: self.ndarray_fill_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], var_id: vec![
self.primitives.ndarray.dtype_tvar(self.unifier).id,
self.primitives.ndarray.ndims_tvar(self.unifier).id,
],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -870,15 +884,26 @@ impl<'a> BuiltinBuilder<'a> {
// The size variant of the function determines the size of the returned int. // The size variant of the function determines the size of the returned int.
let int_sized = size_variant.of_int(self.primitives); let int_sized = size_variant.of_int(self.primitives);
let ndarray_int_sized = let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); self.unifier,
let ndarray_float = self.primitives,
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); Some(int_sized),
Some(common_ndim.ty),
);
let ndarray_float = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(float),
Some(common_ndim.ty),
);
let p0_ty = let p0_ty = self.unifier.get_fresh_var_with_range(
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); &[float, ndarray_float.into()],
Some("T".into()),
None,
);
let ret_ty = self.unifier.get_fresh_var_with_range( let ret_ty = self.unifier.get_fresh_var_with_range(
&[int_sized, ndarray_int_sized], &[int_sized, ndarray_int_sized.into()],
Some("R".into()), Some("R".into()),
None, None,
); );
@ -930,19 +955,30 @@ impl<'a> BuiltinBuilder<'a> {
None, None,
); );
let ndarray_float = let ndarray_float = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); self.unifier,
self.primitives,
Some(float),
Some(common_ndim.ty),
);
// The size variant of the function determines the type of int returned // The size variant of the function determines the type of int returned
let int_sized = size_variant.of_int(self.primitives); let int_sized = size_variant.of_int(self.primitives);
let ndarray_int_sized = let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); self.unifier,
self.primitives,
Some(int_sized),
Some(common_ndim.ty),
);
let p0_ty = let p0_ty = self.unifier.get_fresh_var_with_range(
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); &[float, ndarray_float.into()],
Some("T".into()),
None,
);
let ret_ty = self.unifier.get_fresh_var_with_range( let ret_ty = self.unifier.get_fresh_var_with_range(
&[int_sized, ndarray_int_sized], &[int_sized, ndarray_int_sized.into()],
Some("R".into()), Some("R".into()),
None, None,
); );
@ -1005,7 +1041,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&VarMap::new(), &VarMap::new(),
prim.name(), prim.name(),
self.ndarray_float, self.ndarray_float.into(),
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim { let func = match prim {
@ -1051,7 +1087,7 @@ impl<'a> BuiltinBuilder<'a> {
default_value: Some(SymbolValue::U32(0)), default_value: Some(SymbolValue::U32(0)),
}, },
], ],
ret: ndarray, ret: ndarray.into(),
vars: into_var_map([tv]), vars: into_var_map([tv]),
})), })),
var_id: vec![tv.id], var_id: vec![tv.id],
@ -1074,7 +1110,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&into_var_map([tv]), &into_var_map([tv]),
prim.name(), prim.name(),
self.primitives.ndarray, self.primitives.ndarray.into(),
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable // type variable
&[(self.list_int32, "shape"), (tv.ty, "fill_value")], &[(self.list_int32, "shape"), (tv.ty, "fill_value")],
@ -1103,7 +1139,7 @@ impl<'a> BuiltinBuilder<'a> {
default_value: Some(SymbolValue::I32(0)), default_value: Some(SymbolValue::I32(0)),
}, },
], ],
ret: self.ndarray_float_2d, ret: self.ndarray_float_2d.into(),
vars: VarMap::default(), vars: VarMap::default(),
})), })),
var_id: Vec::default(), var_id: Vec::default(),
@ -1123,7 +1159,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&VarMap::new(), &VarMap::new(),
prim.name(), prim.name(),
self.ndarray_float_2d, self.ndarray_float_2d.into(),
&[(int32, "n")], &[(int32, "n")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_identity(ctx, &obj, fun, &args, generator) gen_ndarray_identity(ctx, &obj, fun, &args, generator)
@ -1338,10 +1374,15 @@ impl<'a> BuiltinBuilder<'a> {
let tvar = self.unifier.get_fresh_var(Some("L".into()), None); let tvar = self.unifier.get_fresh_var(Some("L".into()), None);
let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty }); let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty });
let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None); let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None);
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.ty), Some(ndims.ty)); let ndarray = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(tvar.ty),
Some(ndims.ty),
);
let arg_ty = self.unifier.get_fresh_var_with_range( let arg_ty = self.unifier.get_fresh_var_with_range(
&[list, ndarray, self.primitives.range], &[list, ndarray.into(), self.primitives.range],
Some("I".into()), Some("I".into()),
None, None,
); );
@ -1799,8 +1840,13 @@ impl<'a> BuiltinBuilder<'a> {
} }
fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar { fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar {
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None); let ndarray = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(scalar_ty),
None,
);
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None) self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray.into()], Some("T".into()), None)
} }
} }

View File

@ -1,14 +1,13 @@
use std::convert::TryInto; use std::convert::TryInto;
use super::*;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::primitive_type::{NDArrayType, OptionType};
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap}; use crate::typecheck::typedef::{into_var_map, GenericObjectType, Mapping, TypeVarId, VarMap};
use nac3parser::ast::{Constant, Location}; use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use strum_macros::EnumIter; use strum_macros::EnumIter;
use super::*;
/// All primitive types and functions in nac3core. /// All primitive types and functions in nac3core.
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)] #[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
pub enum PrimDef { pub enum PrimDef {
@ -403,6 +402,7 @@ impl TopLevelComposer {
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
params: into_var_map([option_type_var]), params: into_var_map([option_type_var]),
}); });
let option = OptionType::create(option, &mut unifier);
let size_t_ty = match size_t { let size_t_ty = match size_t {
32 => uint32, 32 => uint32,
@ -436,8 +436,9 @@ impl TopLevelComposer {
]), ]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });
let ndarray = NDArrayType::create(ndarray, &mut unifier);
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap(); unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray.into()).unwrap();
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
@ -747,7 +748,7 @@ impl TopLevelComposer {
TypeAnnotation::CustomClass { id: e_id, params: e_param }, TypeAnnotation::CustomClass { id: e_id, params: e_param },
) => { ) => {
*f_id == *e_id *f_id == *e_id
&& *f_id == primitive.option.obj_id(unifier).unwrap() && *f_id == primitive.option.obj_id(unifier)
&& (f_param.is_empty() && (f_param.is_empty()
|| (f_param.len() == 1 || (f_param.len() == 1
&& e_param.len() == 1 && e_param.len() == 1
@ -885,7 +886,7 @@ pub fn parse_parameter_default_value(
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
match &*unifier.get_ty(ty) { match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(unifier, ty).0 NDArrayType::create(ty, unifier).dtype_tvar(unifier).ty
} }
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
@ -897,7 +898,7 @@ pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
match &*unifier.get_ty(ty) { match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1; let ndims = NDArrayType::create(ty, unifier).ndims_tvar(unifier).ty;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
}; };

View File

@ -30,7 +30,7 @@ pub struct DefinitionId(pub usize);
pub mod builtins; pub mod builtins;
pub mod composer; pub mod composer;
pub mod helper; pub mod helper;
pub mod numpy; pub mod primitive_type;
pub mod type_annotation; pub mod type_annotation;
use composer::*; use composer::*;
use type_annotation::*; use type_annotation::*;

View File

@ -1,85 +0,0 @@
use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
};
use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
if dtype.is_none() && ndims.is_none() {
return ndarray;
}
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
if let Some(ndims) = ndims {
tvar_subst.insert(tvar_ids[1], ndims);
}
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2);
params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
}

View File

@ -0,0 +1,98 @@
use crate::toplevel::helper::PrimDef;
use crate::typecheck::type_inferencer::PrimitiveStore;
use crate::typecheck::typedef::{GenericObjectType, Type, TypeVar, Unifier, VarMap};
#[derive(Clone, Copy)]
pub struct OptionType(Type);
impl OptionType {
pub fn from_primitive(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
type_ty: Option<Type>,
) -> Self {
primitives.option.subst(unifier, type_ty)
}
pub fn type_tvar(&self, unifier: &mut Unifier) -> TypeVar {
self.get_var_at(unifier, 0).unwrap()
}
#[must_use]
pub fn subst(&self, unifier: &mut Unifier, type_ty: Option<Type>) -> Self {
let new_vars = [(self.type_tvar(unifier).id, type_ty)]
.into_iter()
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
.collect::<VarMap>();
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
OptionType(new_ty)
}
}
impl GenericObjectType for OptionType {
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::Option.id()) {
Some(OptionType(ty))
} else {
None
}
}
fn get_type(&self) -> Type {
self.0
}
}
#[derive(Clone, Copy)]
pub struct NDArrayType(Type);
impl NDArrayType {
pub fn from_primitive(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Self {
primitives.ndarray.subst(unifier, dtype, ndims)
}
pub fn dtype_tvar(&self, unifier: &mut Unifier) -> TypeVar {
self.get_var_at(unifier, 0).unwrap()
}
pub fn ndims_tvar(&self, unifier: &mut Unifier) -> TypeVar {
self.get_var_at(unifier, 1).unwrap()
}
#[must_use]
pub fn subst(
&self,
unifier: &mut Unifier,
dtype_ty: Option<Type>,
ndims_ty: Option<Type>,
) -> Self {
let new_vars =
[(self.dtype_tvar(unifier).id, dtype_ty), (self.ndims_tvar(unifier).id, ndims_ty)]
.into_iter()
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
.collect::<VarMap>();
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
NDArrayType(new_ty)
}
}
impl GenericObjectType for NDArrayType {
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
Some(NDArrayType(ty))
} else {
None
}
}
fn get_type(&self) -> Type {
self.0
}
}

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::VarMap; use crate::typecheck::typedef::{GenericObjectType, VarMap};
use nac3parser::ast::Constant; use nac3parser::ast::Constant;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -267,12 +267,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
slice.as_ref(), slice.as_ref(),
locked, locked,
)?; )?;
let id = let id = primitives.option.obj_id(unifier);
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
*obj_id
} else {
unreachable!()
};
Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] }) Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] })
} }

View File

@ -1,9 +1,9 @@
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::toplevel::primitive_type;
use crate::typecheck::{ use crate::typecheck::{
type_inferencer::*, type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, Unifier, VarMap},
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
@ -369,8 +369,12 @@ pub fn typeof_ndarray_broadcast(
if is_left_ndarray && is_right_ndarray { if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands. // Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left); let left_ty = primitive_type::NDArrayType::create(left, unifier);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right); let left_ty_dtype = left_ty.dtype_tvar(unifier).ty;
let left_ty_ndims = left_ty.ndims_tvar(unifier).ty;
let right_ty = primitive_type::NDArrayType::create(right, unifier);
let right_ty_dtype = right_ty.dtype_tvar(unifier).ty;
let right_ty_ndims = right_ty.ndims_tvar(unifier).ty;
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype)); assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
@ -397,11 +401,18 @@ pub fn typeof_ndarray_broadcast(
.collect_vec(); .collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None); let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims))) Ok(primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(left_ty_dtype),
Some(res_ndims),
)
.into())
} else { } else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) }; let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty); let ndarray_ty_dtype =
primitive_type::NDArrayType::create(ndarray_ty, unifier).ndims_tvar(unifier).ty;
if unifier.unioned(ndarray_ty_dtype, scalar_ty) { if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty) Ok(ndarray_ty)
@ -444,7 +455,8 @@ pub fn typeof_binop(
} }
Operator::MatMult => { Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); let lhs_ndims =
primitive_type::NDArrayType::create(lhs, unifier).ndims_tvar(unifier).ty;
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1); assert_eq!(values.len(), 1);
@ -452,7 +464,8 @@ pub fn typeof_binop(
} }
_ => unreachable!(), _ => unreachable!(),
}; };
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); let rhs_ndims =
primitive_type::NDArrayType::create(rhs, unifier).ndims_tvar(unifier).ty;
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1); assert_eq!(values.len(), 1);
@ -526,7 +539,7 @@ pub fn typeof_unaryop(
let operand_obj_id = operand.obj_id(unifier); let operand_obj_id = operand.obj_id(unifier);
if op == Unaryop::Not if op == Unaryop::Not
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier))
{ {
return Err( return Err(
"The truth value of an array with more than one element is ambiguous".to_string() "The truth value of an array with more than one element is ambiguous".to_string()
@ -552,7 +565,8 @@ pub fn typeof_unaryop(
Unaryop::UAdd | Unaryop::USub => { Unaryop::UAdd | Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) { if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); let dtype =
primitive_type::NDArrayType::create(operand, unifier).dtype_tvar(unifier).ty;
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) { if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if op == Unaryop::UAdd { return Err(if op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
@ -586,9 +600,15 @@ pub fn typeof_cmpop(
Ok(Some(if is_left_ndarray || is_right_ndarray { Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd); let ndims = primitive_type::NDArrayType::create(brd, unifier).ndims_tvar(unifier).ty;
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims)) primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(primitives.bool),
Some(ndims),
)
.into()
} else if unifier.unioned(lhs, rhs) { } else if unifier.unioned(lhs, rhs) {
primitives.bool primitives.bool
} else { } else {
@ -611,64 +631,108 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* int ======== */ /* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] { for t in [int32_t, int64_t, uint32_t, uint64_t] {
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None); let ndarray_int_t =
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None); primitive_type::NDArrayType::from_primitive(unifier, store, Some(t), None);
impl_pow(unifier, store, t, &[t, ndarray_int_t], None); impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t); impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t, ndarray_int_t], None); impl_div(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None); impl_floordiv(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_mod(unifier, store, t, &[t, ndarray_int_t], None); impl_mod(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_invert(unifier, store, t, Some(t)); impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t, Some(bool_t)); impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None); impl_comparison(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_eq(unifier, store, t, &[t, ndarray_int_t], None); impl_eq(unifier, store, t, &[t, ndarray_int_t.into()], None);
} }
for t in [int32_t, int64_t] { for t in [int32_t, int64_t] {
impl_sign(unifier, store, t, Some(t)); impl_sign(unifier, store, t, Some(t));
} }
/* float ======== */ /* float ======== */
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None); let ndarray_float_t =
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None); primitive_type::NDArrayType::from_primitive(unifier, store, Some(float_t), None);
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None); let ndarray_int32_t =
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None); primitive_type::NDArrayType::from_primitive(unifier, store, Some(int32_t), None);
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_pow(
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None); unifier,
store,
float_t,
&[int32_t, float_t, ndarray_int32_t.into(), ndarray_float_t.into()],
None,
);
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_sign(unifier, store, float_t, Some(float_t)); impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t, Some(bool_t)); impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
/* bool ======== */ /* bool ======== */
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None); let ndarray_bool_t =
primitive_type::NDArrayType::from_primitive(unifier, store, Some(bool_t), None);
impl_invert(unifier, store, bool_t, Some(int32_t)); impl_invert(unifier, store, bool_t, Some(int32_t));
impl_not(unifier, store, bool_t, Some(bool_t)); impl_not(unifier, store, bool_t, Some(bool_t));
impl_sign(unifier, store, bool_t, Some(int32_t)); impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t.into()], None);
/* ndarray ===== */ /* ndarray ===== */
let ndarray_usized_ndims_tvar = let ndarray_usized_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t = let ndarray_unsized_t = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty)); unifier,
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); store,
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); None,
Some(ndarray_usized_ndims_tvar.ty),
);
let ndarray_dtype_t = ndarray_t.dtype_tvar(unifier).ty;
let ndarray_unsized_dtype_t = ndarray_unsized_t.dtype_tvar(unifier).ty;
impl_basic_arithmetic( impl_basic_arithmetic(
unifier, unifier,
store, store,
ndarray_t, ndarray_t.into(),
&[ndarray_unsized_t, ndarray_unsized_dtype_t], &[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_pow(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_div(unifier, store, ndarray_t.into(), &[ndarray_t.into(), ndarray_dtype_t], None);
impl_floordiv(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_mod(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_matmul(unifier, store, ndarray_t.into(), &[ndarray_t.into()], Some(ndarray_t.into()));
impl_sign(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
impl_invert(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
impl_eq(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_comparison(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None, None,
); );
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
} }

View File

@ -4,14 +4,16 @@ use std::iter::once;
use std::ops::Not; use std::ops::Not;
use std::{cell::RefCell, sync::Arc}; use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::typedef::{
Call, FunSignature, FuncArg, GenericObjectType, RecordField, Type, TypeEnum, Unifier, VarMap,
};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::toplevel::primitive_type::{NDArrayType, OptionType};
use crate::toplevel::TopLevelDef; use crate::toplevel::TopLevelDef;
use crate::{ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue}, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{ toplevel::{
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}, helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext, TopLevelContext,
}, },
}; };
@ -49,8 +51,8 @@ pub struct PrimitiveStore {
pub range: Type, pub range: Type,
pub str: Type, pub str: Type,
pub exception: Type, pub exception: Type,
pub option: Type, pub option: OptionType,
pub ndarray: Type, pub ndarray: NDArrayType,
pub size_t: u32, pub size_t: u32,
} }
@ -97,8 +99,8 @@ impl IntoIterator for &PrimitiveStore {
self.range, self.range,
self.str, self.str,
self.exception, self.exception,
self.option, self.option.into(),
self.ndarray, self.ndarray.into(),
] ]
.into_iter() .into_iter()
} }
@ -528,7 +530,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
// the name `none` is special since it may have different types // the name `none` is special since it may have different types
if id == &"none".into() { if id == &"none".into() {
if let TypeEnum::TObj { params, .. } = if let TypeEnum::TObj { params, .. } =
self.unifier.get_ty_immutable(self.primitives.option).as_ref() &*self.unifier.get_ty_immutable(self.primitives.option.into())
{ {
let var_map = params let var_map = params
.iter() .iter()
@ -543,7 +545,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
(*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty) (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty)
}) })
.collect::<VarMap>(); .collect::<VarMap>();
Some(self.unifier.subst(self.primitives.option, &var_map).unwrap()) Some(self.unifier.subst(self.primitives.option.into(), &var_map).unwrap())
} else { } else {
unreachable!("must be tobj") unreachable!("must be tobj")
} }
@ -1062,9 +1064,16 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); let ndarray_ndims =
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(target_ty),
Some(ndarray_ndims),
)
.into()
} else { } else {
target_ty target_ty
}; };
@ -1100,9 +1109,7 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
ndarray_dtype
} else { } else {
arg0_ty arg0_ty
}; };
@ -1154,14 +1161,14 @@ impl<'a> Inferencer<'a> {
let arg0_dtype = let arg0_dtype =
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg0_ty).0 NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
} else { } else {
arg0_ty arg0_ty
}; };
let arg1_dtype = let arg1_dtype =
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg1_ty).0 NDArrayType::create(arg1_ty, self.unifier).dtype_tvar(self.unifier).ty
} else { } else {
arg1_ty arg1_ty
}; };
@ -1192,9 +1199,17 @@ impl<'a> Inferencer<'a> {
// (float, int32), so convert it to align with the dtype of the first arg // (float, int32), so convert it to align with the dtype of the first arg
let arg1_ty = if id == &"np_ldexp".into() { let arg1_ty = if id == &"np_ldexp".into() {
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); // let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
let ndims =
NDArrayType::create(arg1_ty, self.unifier).ndims_tvar(self.unifier).ty;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims)) NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(target_ty),
Some(ndims),
)
.into()
} else { } else {
target_ty target_ty
} }
@ -1281,9 +1296,16 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); let ndarray_ndims =
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(target_ty),
Some(ndarray_ndims),
)
.into()
} else { } else {
target_ty target_ty
}; };
@ -1323,7 +1345,7 @@ impl<'a> Inferencer<'a> {
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape` self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty( let ret = NDArrayType::from_primitive(
self.unifier, self.unifier,
self.primitives, self.primitives,
Some(self.primitives.float), Some(self.primitives.float),
@ -1335,13 +1357,13 @@ impl<'a> Inferencer<'a> {
ty: shape.custom.unwrap(), ty: shape.custom.unwrap(),
default_value: None, default_value: None,
}], }],
ret, ret: ret.into(),
vars: VarMap::new(), vars: VarMap::new(),
})); }));
return Ok(Some(Located { return Ok(Some(Located {
location, location,
custom: Some(ret), custom: Some(ret.into()),
node: ExprKind::Call { node: ExprKind::Call {
func: Box::new(Located { func: Box::new(Located {
custom: Some(custom), custom: Some(custom),
@ -1374,7 +1396,8 @@ impl<'a> Inferencer<'a> {
let ty = arg1.custom.unwrap(); let ty = arg1.custom.unwrap();
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); let ret =
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None }, FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None },
@ -1384,13 +1407,13 @@ impl<'a> Inferencer<'a> {
default_value: None, default_value: None,
}, },
], ],
ret, ret: ret.into(),
vars: VarMap::new(), vars: VarMap::new(),
})); }));
return Ok(Some(Located { return Ok(Some(Located {
location, location,
custom: Some(ret), custom: Some(ret.into()),
node: ExprKind::Call { node: ExprKind::Call {
func: Box::new(Located { func: Box::new(Located {
custom: Some(custom), custom: Some(custom),
@ -1428,7 +1451,8 @@ impl<'a> Inferencer<'a> {
arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
}; };
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); let ret =
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
@ -1448,13 +1472,13 @@ impl<'a> Inferencer<'a> {
default_value: Some(SymbolValue::U32(0)), default_value: Some(SymbolValue::U32(0)),
}, },
], ],
ret, ret: ret.into(),
vars: VarMap::new(), vars: VarMap::new(),
})); }));
return Ok(Some(Located { return Ok(Some(Located {
location, location,
custom: Some(ret), custom: Some(ret.into()),
node: ExprKind::Call { node: ExprKind::Call {
func: Box::new(Located { func: Box::new(Located {
custom: Some(custom), custom: Some(custom),
@ -1803,9 +1827,13 @@ impl<'a> Inferencer<'a> {
TypeEnum::TVar { is_const_generic: false, .. } TypeEnum::TVar { is_const_generic: false, .. }
)); ));
let constrained_ty = let constrained_ty = NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims)); self.unifier,
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?; self.primitives,
Some(dummy_tvar),
Some(ndims),
);
self.constrain(value.custom.unwrap(), constrained_ty.into(), &value.location)?;
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else { let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims)) panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
@ -1871,10 +1899,14 @@ impl<'a> Inferencer<'a> {
let ndims_ty = self let ndims_ty = self
.unifier .unifier
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None); .get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
let subscripted_ty = let subscripted_ty = NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty)); self.unifier,
self.primitives,
Some(dummy_tvar),
Some(ndims_ty),
);
Ok(subscripted_ty) Ok(subscripted_ty.into())
} }
} }
@ -1893,10 +1925,17 @@ impl<'a> Inferencer<'a> {
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); .ndims_tvar(self.unifier)
.ty;
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(ty),
Some(ndims),
)
.into()
} }
_ => unreachable!(), _ => unreachable!(),
@ -1907,8 +1946,10 @@ impl<'a> Inferencer<'a> {
ExprKind::Constant { value: ast::Constant::Int(val), .. } => { ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
match &*self.unifier.get_ty(value.custom.unwrap()) { match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); .ndims_tvar(self.unifier)
.ty;
self.infer_subscript_ndarray(value, slice, ty, ndims) self.infer_subscript_ndarray(value, slice, ty, ndims)
} }
_ => { _ => {
@ -1951,7 +1992,10 @@ impl<'a> Inferencer<'a> {
} }
} }
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
.ndims_tvar(self.unifier)
.ty;
self.infer_subscript_ndarray(value, slice, ty, ndims) self.infer_subscript_ndarray(value, slice, ty, ndims)
} }
_ => { _ => {
@ -1975,8 +2019,9 @@ impl<'a> Inferencer<'a> {
Ok(ty) Ok(ty)
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); .ndims_tvar(self.unifier)
.ty;
let valid_index_tys = [self.primitives.int32, self.primitives.isize()] let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
.into_iter() .into_iter()

View File

@ -139,6 +139,7 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let option = OptionType::create(option, &mut unifier);
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
@ -147,6 +148,7 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });
let ndarray = NDArrayType::create(ndarray, &mut unifier);
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
int64, int64,
@ -273,11 +275,13 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let option = OptionType::create(option, &mut unifier);
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let ndarray = NDArrayType::create(ndarray, &mut unifier);
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
.iter() .iter()

View File

@ -137,7 +137,8 @@ where
#[must_use] #[must_use]
fn get_type(&self) -> Type; fn get_type(&self) -> Type;
/// See [`Type::obj_id`]. /// Similar to [`Type::obj_id`], except that the [`DefinitionId`] is not wrapped within an
/// [`Option`].
#[must_use] #[must_use]
fn obj_id(&self, unifier: &Unifier) -> DefinitionId { fn obj_id(&self, unifier: &Unifier) -> DefinitionId {
self.get_type().obj_id(unifier).unwrap() self.get_type().obj_id(unifier).unwrap()