forked from M-Labs/nac3
core/typedef: WIP - Add OptionType and NDArrayType
This commit is contained in:
parent
da4dec08a5
commit
6892a4848e
@ -7,7 +7,7 @@ use nac3core::{
|
|||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{helper::PrimDef, DefinitionId, GenCall},
|
toplevel::{helper::PrimDef, DefinitionId, GenCall},
|
||||||
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap},
|
typecheck::typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, VarMap},
|
||||||
};
|
};
|
||||||
|
|
||||||
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
||||||
@ -23,7 +23,7 @@ use pyo3::{
|
|||||||
|
|
||||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||||
|
|
||||||
use nac3core::toplevel::numpy::unpack_ndarray_var_tys;
|
use nac3core::toplevel::primitive_type;
|
||||||
use std::{
|
use std::{
|
||||||
collections::hash_map::DefaultHasher,
|
collections::hash_map::DefaultHasher,
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
@ -399,7 +399,9 @@ fn gen_rpc_tag(
|
|||||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||||
}
|
}
|
||||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let ndarray_ty = primitive_type::NDArrayType::create(ty, &mut ctx.unifier);
|
||||||
|
let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
|
||||||
|
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
|
||||||
let ndarray_ndims = if let TLiteral { values, .. } =
|
let ndarray_ndims = if let TLiteral { values, .. } =
|
||||||
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
||||||
{
|
{
|
||||||
@ -645,7 +647,7 @@ pub fn attributes_writeback(
|
|||||||
let ty = ty.unwrap();
|
let ty = ty.unwrap();
|
||||||
match &*ctx.unifier.get_ty(ty) {
|
match &*ctx.unifier.get_ty(ty) {
|
||||||
TypeEnum::TObj { fields, obj_id, .. }
|
TypeEnum::TObj { fields, obj_id, .. }
|
||||||
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier) =>
|
||||||
{
|
{
|
||||||
// we only care about primitive attributes
|
// we only care about primitive attributes
|
||||||
// for non-primitive attributes, they should be in another global
|
// for non-primitive attributes, they should be in another global
|
||||||
|
@ -11,11 +11,7 @@ use nac3core::{
|
|||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
||||||
toplevel::{
|
toplevel::{helper::PrimDef, primitive_type, DefinitionId, TopLevelDef},
|
||||||
helper::PrimDef,
|
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
|
||||||
DefinitionId, TopLevelDef,
|
|
||||||
},
|
|
||||||
typecheck::{
|
typecheck::{
|
||||||
type_inferencer::PrimitiveStore,
|
type_inferencer::PrimitiveStore,
|
||||||
typedef::{Type, TypeEnum, TypeVar, Unifier, VarMap},
|
typedef::{Type, TypeEnum, TypeVar, Unifier, VarMap},
|
||||||
@ -337,13 +333,18 @@ impl InnerResolver {
|
|||||||
// do not handle type var param and concrete check here
|
// do not handle type var param and concrete check here
|
||||||
let var = unifier.get_dummy_var().ty;
|
let var = unifier.get_dummy_var().ty;
|
||||||
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty;
|
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty;
|
||||||
let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims));
|
let ndarray = primitive_type::NDArrayType::from_primitive(
|
||||||
Ok(Ok((ndarray, false)))
|
unifier,
|
||||||
|
primitives,
|
||||||
|
Some(var),
|
||||||
|
Some(ndims),
|
||||||
|
);
|
||||||
|
Ok(Ok((ndarray.into(), false)))
|
||||||
} else if ty_id == self.primitive_ids.tuple {
|
} else if ty_id == self.primitive_ids.tuple {
|
||||||
// do not handle type var param and concrete check here
|
// do not handle type var param and concrete check here
|
||||||
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
||||||
} else if ty_id == self.primitive_ids.option {
|
} else if ty_id == self.primitive_ids.option {
|
||||||
Ok(Ok((primitives.option, false)))
|
Ok(Ok((primitives.option.into(), false)))
|
||||||
} else if ty_id == self.primitive_ids.none {
|
} else if ty_id == self.primitive_ids.none {
|
||||||
unreachable!("none cannot be typeid")
|
unreachable!("none cannot be typeid")
|
||||||
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() {
|
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() {
|
||||||
@ -510,7 +511,16 @@ impl InnerResolver {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Ok((make_ndarray_ty(unifier, primitives, Some(ty.0), None), true)))
|
Ok(Ok((
|
||||||
|
primitive_type::NDArrayType::from_primitive(
|
||||||
|
unifier,
|
||||||
|
primitives,
|
||||||
|
Some(ty.0),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
true,
|
||||||
|
)))
|
||||||
}
|
}
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let args = match args
|
let args = match args
|
||||||
@ -719,7 +729,9 @@ impl InnerResolver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
|
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
|
let ndarray = primitive_type::NDArrayType::create(extracted_ty, unifier);
|
||||||
|
let ty = ndarray.dtype_tvar(unifier).ty;
|
||||||
|
let ndims = ndarray.ndims_tvar(unifier).ty;
|
||||||
let len: usize = obj.getattr("ndim")?.extract()?;
|
let len: usize = obj.getattr("ndim")?.extract()?;
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
@ -734,10 +746,14 @@ impl InnerResolver {
|
|||||||
match dtype_ty {
|
match dtype_ty {
|
||||||
Ok((t, _)) => match unifier.unify(ty, t) {
|
Ok((t, _)) => match unifier.unify(ty, t) {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
let ndarray_ty =
|
let ndarray_ty = primitive_type::NDArrayType::from_primitive(
|
||||||
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
|
unifier,
|
||||||
|
primitives,
|
||||||
|
Some(ty),
|
||||||
|
Some(ndims),
|
||||||
|
);
|
||||||
|
|
||||||
Ok(Ok(ndarray_ty))
|
Ok(Ok(ndarray_ty.into()))
|
||||||
}
|
}
|
||||||
Err(e) => Ok(Err(format!(
|
Err(e) => Ok(Err(format!(
|
||||||
"type error ({}) for the ndarray",
|
"type error ({}) for the ndarray",
|
||||||
@ -760,7 +776,7 @@ impl InnerResolver {
|
|||||||
// special handling for option type since its class member layout in python side
|
// special handling for option type since its class member layout in python side
|
||||||
// is special and cannot be mapped directly to a nac3 type as below
|
// is special and cannot be mapped directly to a nac3 type as below
|
||||||
(TypeEnum::TObj { obj_id, params, .. }, false)
|
(TypeEnum::TObj { obj_id, params, .. }, false)
|
||||||
if *obj_id == primitives.option.obj_id(unifier).unwrap() =>
|
if *obj_id == primitives.option.obj_id(unifier) =>
|
||||||
{
|
{
|
||||||
let Ok(field_data) = obj.getattr("_nac3_option") else {
|
let Ok(field_data) = obj.getattr("_nac3_option") else {
|
||||||
unreachable!("cannot be None")
|
unreachable!("cannot be None")
|
||||||
@ -785,7 +801,7 @@ impl InnerResolver {
|
|||||||
.map(TypeVar::into)
|
.map(TypeVar::into)
|
||||||
.collect::<VarMap>()
|
.collect::<VarMap>()
|
||||||
});
|
});
|
||||||
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()));
|
return Ok(Ok(unifier.subst(primitives.option.into(), &var_map).unwrap()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
|
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
|
||||||
@ -1038,8 +1054,9 @@ impl InnerResolver {
|
|||||||
} else {
|
} else {
|
||||||
unreachable!("must be ndarray")
|
unreachable!("must be ndarray")
|
||||||
};
|
};
|
||||||
let (ndarray_dtype, ndarray_ndims) =
|
let ndarray_ty = primitive_type::NDArrayType::create(ndarray_ty, &mut ctx.unifier);
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
|
||||||
|
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
|
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||||
@ -1186,7 +1203,7 @@ impl InnerResolver {
|
|||||||
} else if ty_id == self.primitive_ids.option {
|
} else if ty_id == self.primitive_ids.option {
|
||||||
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
|
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
|
||||||
TypeEnum::TObj { obj_id, params, .. }
|
TypeEnum::TObj { obj_id, params, .. }
|
||||||
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier) =>
|
||||||
{
|
{
|
||||||
*params.iter().next().unwrap().1
|
*params.iter().next().unwrap().1
|
||||||
}
|
}
|
||||||
|
@ -8,8 +8,8 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
|||||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||||
use crate::toplevel::helper::PrimDef;
|
use crate::toplevel::helper::PrimDef;
|
||||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
use crate::toplevel::primitive_type;
|
||||||
use crate::typecheck::typedef::Type;
|
use crate::typecheck::typedef::{GenericObjectType, Type};
|
||||||
|
|
||||||
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
||||||
///
|
///
|
||||||
@ -66,7 +66,9 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -128,7 +130,9 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -206,7 +210,9 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -273,7 +279,9 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -328,7 +336,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -374,7 +384,9 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -414,7 +426,9 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -475,7 +489,9 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -529,7 +545,9 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -579,7 +597,9 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -660,7 +680,9 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
@ -751,16 +773,24 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
@ -850,7 +880,9 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
BasicValueEnum::PointerValue(n)
|
BasicValueEnum::PointerValue(n)
|
||||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
|
||||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||||
@ -941,16 +973,24 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
@ -1008,7 +1048,9 @@ where
|
|||||||
if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
let arg_elem_ty = primitive_type::NDArrayType::create(arg_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
|
let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
|
||||||
|
|
||||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
@ -1370,16 +1412,24 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
@ -1437,16 +1487,24 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
@ -1504,16 +1562,24 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
@ -1571,16 +1637,24 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
@ -1637,12 +1711,22 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let is_ndarray2 =
|
let is_ndarray2 =
|
||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype =
|
let dtype = if is_ndarray1 {
|
||||||
if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty };
|
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
|
} else {
|
||||||
|
x1_ty
|
||||||
|
};
|
||||||
|
|
||||||
let x1_scalar_ty = dtype;
|
let x1_scalar_ty = dtype;
|
||||||
let x2_scalar_ty =
|
let x2_scalar_ty = if is_ndarray2 {
|
||||||
if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty };
|
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
|
} else {
|
||||||
|
x2_ty
|
||||||
|
};
|
||||||
|
|
||||||
numpy::ndarray_elementwise_binop_impl(
|
numpy::ndarray_elementwise_binop_impl(
|
||||||
generator,
|
generator,
|
||||||
@ -1694,16 +1778,24 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
@ -1761,16 +1853,24 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
ndarray_dtype1
|
ndarray_dtype1
|
||||||
} else if is_ndarray1 {
|
} else if is_ndarray1 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else if is_ndarray2 {
|
} else if is_ndarray2 {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||||
|
|
||||||
|
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
|
||||||
|
use crate::toplevel::primitive_type;
|
||||||
|
use crate::toplevel::primitive_type::OptionType;
|
||||||
|
use crate::typecheck::typedef::GenericObjectType;
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
classes::{
|
classes::{
|
||||||
@ -15,11 +19,7 @@ use crate::{
|
|||||||
CodeGenContext, CodeGenTask,
|
CodeGenContext, CodeGenTask,
|
||||||
},
|
},
|
||||||
symbol_resolver::{SymbolValue, ValueEnum},
|
symbol_resolver::{SymbolValue, ValueEnum},
|
||||||
toplevel::{
|
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
|
||||||
helper::PrimDef,
|
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
|
||||||
DefinitionId, TopLevelDef,
|
|
||||||
},
|
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::{binop_assign_name, binop_name, unaryop_name},
|
magic_methods::{binop_assign_name, binop_name, unaryop_name},
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||||
@ -36,8 +36,6 @@ use nac3parser::ast::{
|
|||||||
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
|
|
||||||
|
|
||||||
pub fn get_subst_key(
|
pub fn get_subst_key(
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
obj: Option<Type>,
|
obj: Option<Type>,
|
||||||
@ -162,14 +160,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||||||
self.builder.build_load(ptr, "tup_val").unwrap()
|
self.builder.build_load(ptr, "tup_val").unwrap()
|
||||||
}
|
}
|
||||||
SymbolValue::OptionSome(v) => {
|
SymbolValue::OptionSome(v) => {
|
||||||
let ty = match self.unifier.get_ty_immutable(ty).as_ref() {
|
let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
|
||||||
TypeEnum::TObj { obj_id, params, .. }
|
|
||||||
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
|
|
||||||
{
|
|
||||||
*params.iter().next().unwrap().1
|
|
||||||
}
|
|
||||||
_ => unreachable!("must be option type"),
|
|
||||||
};
|
|
||||||
let val = self.gen_symbol_val(generator, v, ty);
|
let val = self.gen_symbol_val(generator, v, ty);
|
||||||
let ptr = generator
|
let ptr = generator
|
||||||
.gen_var_alloc(self, val.get_type(), Some("default_opt_some"))
|
.gen_var_alloc(self, val.get_type(), Some("default_opt_some"))
|
||||||
@ -178,14 +169,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||||||
ptr.into()
|
ptr.into()
|
||||||
}
|
}
|
||||||
SymbolValue::OptionNone => {
|
SymbolValue::OptionNone => {
|
||||||
let ty = match self.unifier.get_ty_immutable(ty).as_ref() {
|
let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
|
||||||
TypeEnum::TObj { obj_id, params, .. }
|
|
||||||
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
|
|
||||||
{
|
|
||||||
*params.iter().next().unwrap().1
|
|
||||||
}
|
|
||||||
_ => unreachable!("must be option type"),
|
|
||||||
};
|
|
||||||
let actual_ptr_type =
|
let actual_ptr_type =
|
||||||
self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default());
|
self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default());
|
||||||
actual_ptr_type.const_null().into()
|
actual_ptr_type.const_null().into()
|
||||||
@ -1206,8 +1190,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
if is_ndarray1 && is_ndarray2 {
|
if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(ty1, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 = primitive_type::NDArrayType::create(ty2, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
@ -1256,8 +1244,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
Ok(Some(res.as_base_value().into()))
|
||||||
} else {
|
} else {
|
||||||
let (ndarray_dtype, _) =
|
let ndarray_dtype = primitive_type::NDArrayType::create(
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
|
if is_ndarray1 { ty1 } else { ty2 },
|
||||||
|
&mut ctx.unifier,
|
||||||
|
)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
let ndarray_val = NDArrayValue::from_ptr_val(
|
let ndarray_val = NDArrayValue::from_ptr_val(
|
||||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||||
llvm_usize,
|
llvm_usize,
|
||||||
@ -1443,7 +1435,9 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
}
|
}
|
||||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let ndarray_dtype = primitive_type::NDArrayType::create(ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
|
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
|
||||||
|
|
||||||
@ -1527,8 +1521,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||||
|
|
||||||
return if is_ndarray1 && is_ndarray2 {
|
return if is_ndarray1 && is_ndarray2 {
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
let ndarray_dtype1 = primitive_type::NDArrayType::create(left_ty, &mut ctx.unifier)
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
let ndarray_dtype2 =
|
||||||
|
primitive_type::NDArrayType::create(right_ty, &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||||
|
|
||||||
@ -1562,10 +1561,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
Ok(Some(res.as_base_value().into()))
|
||||||
} else {
|
} else {
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
let ndarray_dtype = primitive_type::NDArrayType::create(
|
||||||
&mut ctx.unifier,
|
|
||||||
if is_ndarray1 { left_ty } else { right_ty },
|
if is_ndarray1 { left_ty } else { right_ty },
|
||||||
);
|
&mut ctx.unifier,
|
||||||
|
)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty;
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
@ -1788,9 +1789,13 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||||||
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
|
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
let ndarray_ty =
|
let ndarray_ty = primitive_type::NDArrayType::from_primitive(
|
||||||
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
|
&mut ctx.unifier,
|
||||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
&ctx.primitives,
|
||||||
|
Some(ty),
|
||||||
|
Some(ndarray_ndims_ty),
|
||||||
|
);
|
||||||
|
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty.into()).into_pointer_type();
|
||||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||||
|
|
||||||
@ -2082,7 +2087,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||||||
ExprKind::Name { id, .. } if id == &"none".into() => {
|
ExprKind::Name { id, .. } if id == &"none".into() => {
|
||||||
match (
|
match (
|
||||||
ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(),
|
ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(),
|
||||||
ctx.unifier.get_ty(ctx.primitives.option).as_ref(),
|
ctx.unifier.get_ty(ctx.primitives.option.into()).as_ref(),
|
||||||
) {
|
) {
|
||||||
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
|
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
|
||||||
if *obj_id == *opt_id =>
|
if *obj_id == *opt_id =>
|
||||||
@ -2464,8 +2469,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||||||
};
|
};
|
||||||
// directly generate code for option.unwrap
|
// directly generate code for option.unwrap
|
||||||
// since it needs to return static value to optimize for kernel invariant
|
// since it needs to return static value to optimize for kernel invariant
|
||||||
if attr == &"unwrap".into()
|
if attr == &"unwrap".into() && id == ctx.primitives.option.obj_id(&ctx.unifier)
|
||||||
&& id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap()
|
|
||||||
{
|
{
|
||||||
match val {
|
match val {
|
||||||
ValueEnum::Static(v) => {
|
ValueEnum::Static(v) => {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
|
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
|
||||||
symbol_resolver::{StaticValue, SymbolResolver},
|
symbol_resolver::{StaticValue, SymbolResolver},
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
|
toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
type_inferencer::{CodeLocation, PrimitiveStore},
|
type_inferencer::{CodeLocation, PrimitiveStore},
|
||||||
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
||||||
@ -47,6 +47,9 @@ pub mod stmt;
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
||||||
|
use crate::toplevel::primitive_type;
|
||||||
|
use crate::toplevel::primitive_type::OptionType;
|
||||||
|
use crate::typecheck::typedef::GenericObjectType;
|
||||||
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
||||||
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
||||||
|
|
||||||
@ -457,7 +460,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
|
let dtype = primitive_type::NDArrayType::create(ty, unifier)
|
||||||
|
.dtype_tvar(unifier)
|
||||||
|
.ty;
|
||||||
let element_type = get_llvm_type(
|
let element_type = get_llvm_type(
|
||||||
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
||||||
);
|
);
|
||||||
@ -634,7 +639,10 @@ pub fn gen_func_impl<
|
|||||||
range: unifier.get_representative(primitives.range),
|
range: unifier.get_representative(primitives.range),
|
||||||
str: unifier.get_representative(primitives.str),
|
str: unifier.get_representative(primitives.str),
|
||||||
exception: unifier.get_representative(primitives.exception),
|
exception: unifier.get_representative(primitives.exception),
|
||||||
option: unifier.get_representative(primitives.option),
|
option: OptionType::create(
|
||||||
|
unifier.get_representative(primitives.option.into()),
|
||||||
|
&mut unifier,
|
||||||
|
),
|
||||||
..primitives
|
..primitives
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -17,12 +17,8 @@ use crate::{
|
|||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::{
|
toplevel::{helper::PrimDef, primitive_type, DefinitionId},
|
||||||
helper::PrimDef,
|
typecheck::typedef::{FunSignature, GenericObjectType, Type, TypeEnum},
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
|
||||||
DefinitionId,
|
|
||||||
},
|
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
|
||||||
};
|
};
|
||||||
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
@ -38,12 +34,17 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
let ndarray_ty = primitive_type::NDArrayType::from_primitive(
|
||||||
|
&mut ctx.unifier,
|
||||||
|
&ctx.primitives,
|
||||||
|
Some(elem_ty),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let llvm_ndarray_t = ctx
|
let llvm_ndarray_t = ctx
|
||||||
.get_llvm_type(generator, ndarray_ty)
|
.get_llvm_type(generator, ndarray_ty.into())
|
||||||
.into_pointer_type()
|
.into_pointer_type()
|
||||||
.get_element_type()
|
.get_element_type()
|
||||||
.into_struct_type();
|
.into_struct_type();
|
||||||
@ -1799,7 +1800,9 @@ pub fn gen_ndarray_array<'ctx>(
|
|||||||
let obj_ty = fun.0.args[0].ty;
|
let obj_ty = fun.0.args[0].ty;
|
||||||
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
|
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
|
primitive_type::NDArrayType::create(obj_ty, &mut context.unifier)
|
||||||
|
.dtype_tvar(&mut context.unifier)
|
||||||
|
.ty
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeEnum::TList { ty } => {
|
TypeEnum::TList { ty } => {
|
||||||
@ -1939,7 +1942,9 @@ pub fn gen_ndarray_copy<'ctx>(
|
|||||||
let llvm_usize = generator.get_size_type(context.ctx);
|
let llvm_usize = generator.get_size_type(context.ctx);
|
||||||
|
|
||||||
let this_ty = obj.as_ref().unwrap().0;
|
let this_ty = obj.as_ref().unwrap().0;
|
||||||
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
|
let this_elem_ty = primitive_type::NDArrayType::create(this_ty, &mut context.unifier)
|
||||||
|
.dtype_tvar(&mut context.unifier)
|
||||||
|
.ty;
|
||||||
let this_arg =
|
let this_arg =
|
||||||
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
||||||
|
|
||||||
|
@ -4,13 +4,15 @@ use super::{
|
|||||||
irrt::{handle_slice_indices, list_slice_assignment},
|
irrt::{handle_slice_indices, list_slice_assignment},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
use crate::toplevel::primitive_type;
|
||||||
|
use crate::typecheck::typedef::GenericObjectType;
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
||||||
expr::gen_binop_expr,
|
expr::gen_binop_expr,
|
||||||
gen_in_range_check,
|
gen_in_range_check,
|
||||||
},
|
},
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
@ -245,7 +247,9 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
|||||||
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
||||||
TypeEnum::TList { ty } => *ty,
|
TypeEnum::TList { ty } => *ty,
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
|
primitive_type::NDArrayType::create(target.custom.unwrap(), &mut ctx.unifier)
|
||||||
|
.dtype_tvar(&mut ctx.unifier)
|
||||||
|
.ty
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
@ -3,6 +3,7 @@ use std::rc::Rc;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::{collections::HashMap, collections::HashSet, fmt::Display};
|
use std::{collections::HashMap, collections::HashSet, fmt::Display};
|
||||||
|
|
||||||
|
use crate::typecheck::typedef::GenericObjectType;
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{CodeGenContext, CodeGenerator},
|
codegen::{CodeGenContext, CodeGenerator},
|
||||||
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
|
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
|
||||||
@ -43,7 +44,7 @@ impl SymbolValue {
|
|||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
match constant {
|
match constant {
|
||||||
Constant::None => {
|
Constant::None => {
|
||||||
if unifier.unioned(expected_ty, primitives.option) {
|
if unifier.unioned(expected_ty, primitives.option.into()) {
|
||||||
Ok(SymbolValue::OptionNone)
|
Ok(SymbolValue::OptionNone)
|
||||||
} else {
|
} else {
|
||||||
Err(format!("Expected {expected_ty:?}, but got Option"))
|
Err(format!("Expected {expected_ty:?}, but got Option"))
|
||||||
@ -157,7 +158,7 @@ impl SymbolValue {
|
|||||||
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
|
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
|
||||||
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
|
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
|
||||||
}
|
}
|
||||||
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
|
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option.into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -183,13 +184,13 @@ impl SymbolValue {
|
|||||||
TypeAnnotation::Tuple(vs_tys)
|
TypeAnnotation::Tuple(vs_tys)
|
||||||
}
|
}
|
||||||
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
|
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
|
||||||
id: primitives.option.obj_id(unifier).unwrap(),
|
id: primitives.option.obj_id(unifier),
|
||||||
params: Vec::default(),
|
params: Vec::default(),
|
||||||
},
|
},
|
||||||
SymbolValue::OptionSome(v) => {
|
SymbolValue::OptionSome(v) => {
|
||||||
let ty = v.get_type_annotation(primitives, unifier);
|
let ty = v.get_type_annotation(primitives, unifier);
|
||||||
TypeAnnotation::CustomClass {
|
TypeAnnotation::CustomClass {
|
||||||
id: primitives.option.obj_id(unifier).unwrap(),
|
id: primitives.option.obj_id(unifier),
|
||||||
params: vec![ty],
|
params: vec![ty],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,6 @@ use inkwell::{
|
|||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
|
|
||||||
use crate::typecheck::typedef::{GenericObjectType, GenericTypeAdapter};
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
builtin_fns,
|
builtin_fns,
|
||||||
@ -25,7 +24,7 @@ use crate::{
|
|||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
},
|
},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
|
toplevel::helper::PrimDef,
|
||||||
typecheck::typedef::{into_var_map, TypeVar, VarMap},
|
typecheck::typedef::{into_var_map, TypeVar, VarMap},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -304,10 +303,7 @@ struct BuiltinBuilder<'a> {
|
|||||||
|
|
||||||
is_some_ty: (Type, bool),
|
is_some_ty: (Type, bool),
|
||||||
unwrap_ty: (Type, bool),
|
unwrap_ty: (Type, bool),
|
||||||
option_tvar: TypeVar,
|
|
||||||
|
|
||||||
ndarray_dtype_tvar: TypeVar,
|
|
||||||
ndarray_ndims_tvar: TypeVar,
|
|
||||||
ndarray_copy_ty: (Type, bool),
|
ndarray_copy_ty: (Type, bool),
|
||||||
ndarray_fill_ty: (Type, bool),
|
ndarray_fill_ty: (Type, bool),
|
||||||
|
|
||||||
@ -316,9 +312,9 @@ struct BuiltinBuilder<'a> {
|
|||||||
num_ty: TypeVar,
|
num_ty: TypeVar,
|
||||||
num_var_map: VarMap,
|
num_var_map: VarMap,
|
||||||
|
|
||||||
ndarray_float: Type,
|
ndarray_float: primitive_type::NDArrayType,
|
||||||
ndarray_float_2d: Type,
|
ndarray_float_2d: primitive_type::NDArrayType,
|
||||||
ndarray_num_ty: Type,
|
ndarray_num_ty: primitive_type::NDArrayType,
|
||||||
|
|
||||||
float_or_ndarray_ty: TypeVar,
|
float_or_ndarray_ty: TypeVar,
|
||||||
float_or_ndarray_var_map: VarMap,
|
float_or_ndarray_var_map: VarMap,
|
||||||
@ -345,24 +341,19 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
} = *primitives;
|
} = *primitives;
|
||||||
|
|
||||||
// Option-related
|
// Option-related
|
||||||
let (is_some_ty, unwrap_ty, option_tvar) =
|
let (is_some_ty, unwrap_ty) =
|
||||||
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option) {
|
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option.into()) {
|
||||||
let option = GenericTypeAdapter::create(option, unifier);
|
|
||||||
(
|
(
|
||||||
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
|
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
|
||||||
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
|
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
|
||||||
option.get_var_at(unifier, 0).unwrap(),
|
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
|
||||||
let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray) else {
|
let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray.into()) else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
let ndarray = GenericTypeAdapter::create(ndarray, unifier);
|
|
||||||
let ndarray_dtype_tvar = ndarray.get_var_at(unifier, 0).unwrap();
|
|
||||||
let ndarray_ndims_tvar = ndarray.get_var_at(unifier, 1).unwrap();
|
|
||||||
let ndarray_copy_ty =
|
let ndarray_copy_ty =
|
||||||
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
|
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
|
||||||
let ndarray_fill_ty =
|
let ndarray_fill_ty =
|
||||||
@ -375,7 +366,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
);
|
);
|
||||||
let num_var_map = into_var_map([num_ty]);
|
let num_var_map = into_var_map([num_ty]);
|
||||||
|
|
||||||
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None);
|
let ndarray_float =
|
||||||
|
primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(float), None);
|
||||||
let ndarray_float_2d = {
|
let ndarray_float_2d = {
|
||||||
let value = match primitives.size_t {
|
let value = match primitives.size_t {
|
||||||
64 => SymbolValue::U64(2u64),
|
64 => SymbolValue::U64(2u64),
|
||||||
@ -384,16 +376,28 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
};
|
};
|
||||||
let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None });
|
let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None });
|
||||||
|
|
||||||
make_ndarray_ty(unifier, primitives, Some(float), Some(ndims))
|
primitive_type::NDArrayType::from_primitive(
|
||||||
|
unifier,
|
||||||
|
primitives,
|
||||||
|
Some(float),
|
||||||
|
Some(ndims),
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.ty), None);
|
let ndarray_num_ty =
|
||||||
let float_or_ndarray_ty =
|
primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(num_ty.ty), None);
|
||||||
unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
let float_or_ndarray_ty = unifier.get_fresh_var_with_range(
|
||||||
|
&[float, ndarray_float.into()],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
let float_or_ndarray_var_map = into_var_map([float_or_ndarray_ty]);
|
let float_or_ndarray_var_map = into_var_map([float_or_ndarray_ty]);
|
||||||
|
|
||||||
let num_or_ndarray_ty =
|
let num_or_ndarray_ty = unifier.get_fresh_var_with_range(
|
||||||
unifier.get_fresh_var_with_range(&[num_ty.ty, ndarray_num_ty], Some("T".into()), None);
|
&[num_ty.ty, ndarray_num_ty.into()],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
let num_or_ndarray_var_map = into_var_map([num_ty, num_or_ndarray_ty]);
|
let num_or_ndarray_var_map = into_var_map([num_ty, num_or_ndarray_ty]);
|
||||||
|
|
||||||
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
|
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
|
||||||
@ -406,10 +410,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
|
|
||||||
is_some_ty,
|
is_some_ty,
|
||||||
unwrap_ty,
|
unwrap_ty,
|
||||||
option_tvar,
|
|
||||||
|
|
||||||
ndarray_dtype_tvar,
|
|
||||||
ndarray_ndims_tvar,
|
|
||||||
ndarray_copy_ty,
|
ndarray_copy_ty,
|
||||||
ndarray_fill_ty,
|
ndarray_fill_ty,
|
||||||
|
|
||||||
@ -633,7 +634,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
PrimDef::Option => TopLevelDef::Class {
|
PrimDef::Option => TopLevelDef::Class {
|
||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
object_id: prim.id(),
|
object_id: prim.id(),
|
||||||
type_vars: vec![self.option_tvar.ty],
|
type_vars: vec![self.primitives.option.type_tvar(self.unifier).ty],
|
||||||
fields: Vec::default(),
|
fields: Vec::default(),
|
||||||
attributes: Vec::default(),
|
attributes: Vec::default(),
|
||||||
methods: vec![
|
methods: vec![
|
||||||
@ -654,7 +655,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
simple_name: prim.simple_name().into(),
|
simple_name: prim.simple_name().into(),
|
||||||
signature: self.unwrap_ty.0,
|
signature: self.unwrap_ty.0,
|
||||||
var_id: vec![self.option_tvar.id],
|
var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
@ -668,7 +669,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
name: prim.name().to_string(),
|
name: prim.name().to_string(),
|
||||||
simple_name: prim.simple_name().into(),
|
simple_name: prim.simple_name().into(),
|
||||||
signature: self.is_some_ty.0,
|
signature: self.is_some_ty.0,
|
||||||
var_id: vec![self.option_tvar.id],
|
var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
@ -699,19 +700,22 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
loc: None,
|
loc: None,
|
||||||
},
|
},
|
||||||
|
|
||||||
PrimDef::FunSome => TopLevelDef::Function {
|
PrimDef::FunSome => {
|
||||||
|
let option_tvar = self.primitives.option.type_tvar(self.unifier);
|
||||||
|
|
||||||
|
TopLevelDef::Function {
|
||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
simple_name: prim.simple_name().into(),
|
simple_name: prim.simple_name().into(),
|
||||||
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![FuncArg {
|
args: vec![FuncArg {
|
||||||
name: "n".into(),
|
name: "n".into(),
|
||||||
ty: self.option_tvar.ty,
|
ty: option_tvar.ty,
|
||||||
default_value: None,
|
default_value: None,
|
||||||
}],
|
}],
|
||||||
ret: self.primitives.option,
|
ret: self.primitives.option.into(),
|
||||||
vars: into_var_map([self.option_tvar]),
|
vars: into_var_map([option_tvar]),
|
||||||
})),
|
})),
|
||||||
var_id: vec![self.option_tvar.id],
|
var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
@ -728,7 +732,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
},
|
}
|
||||||
|
}
|
||||||
|
|
||||||
_ => {
|
_ => {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
@ -737,7 +742,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Build the class `ndarray` and its associated methods.
|
/// Build the class `ndarray` and its associated methods.
|
||||||
fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef {
|
fn build_ndarray_class_related(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
debug_assert_prim_is_allowed(
|
debug_assert_prim_is_allowed(
|
||||||
prim,
|
prim,
|
||||||
&[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill],
|
&[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill],
|
||||||
@ -747,7 +752,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
PrimDef::NDArray => TopLevelDef::Class {
|
PrimDef::NDArray => TopLevelDef::Class {
|
||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
object_id: prim.id(),
|
object_id: prim.id(),
|
||||||
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
|
type_vars: vec![
|
||||||
|
self.primitives.ndarray.dtype_tvar(self.unifier).ty,
|
||||||
|
self.primitives.ndarray.ndims_tvar(self.unifier).ty,
|
||||||
|
],
|
||||||
fields: Vec::default(),
|
fields: Vec::default(),
|
||||||
attributes: Vec::default(),
|
attributes: Vec::default(),
|
||||||
methods: vec![
|
methods: vec![
|
||||||
@ -764,7 +772,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
simple_name: prim.simple_name().into(),
|
simple_name: prim.simple_name().into(),
|
||||||
signature: self.ndarray_copy_ty.0,
|
signature: self.ndarray_copy_ty.0,
|
||||||
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
|
var_id: vec![
|
||||||
|
self.primitives.ndarray.dtype_tvar(self.unifier).id,
|
||||||
|
self.primitives.ndarray.ndims_tvar(self.unifier).id,
|
||||||
|
],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
@ -781,7 +792,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
simple_name: prim.simple_name().into(),
|
simple_name: prim.simple_name().into(),
|
||||||
signature: self.ndarray_fill_ty.0,
|
signature: self.ndarray_fill_ty.0,
|
||||||
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
|
var_id: vec![
|
||||||
|
self.primitives.ndarray.dtype_tvar(self.unifier).id,
|
||||||
|
self.primitives.ndarray.ndims_tvar(self.unifier).id,
|
||||||
|
],
|
||||||
instance_to_symbol: HashMap::default(),
|
instance_to_symbol: HashMap::default(),
|
||||||
instance_to_stmt: HashMap::default(),
|
instance_to_stmt: HashMap::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
@ -870,15 +884,26 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
// The size variant of the function determines the size of the returned int.
|
// The size variant of the function determines the size of the returned int.
|
||||||
let int_sized = size_variant.of_int(self.primitives);
|
let int_sized = size_variant.of_int(self.primitives);
|
||||||
|
|
||||||
let ndarray_int_sized =
|
let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
|
self.unifier,
|
||||||
let ndarray_float =
|
self.primitives,
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
|
Some(int_sized),
|
||||||
|
Some(common_ndim.ty),
|
||||||
|
);
|
||||||
|
let ndarray_float = primitive_type::NDArrayType::from_primitive(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(float),
|
||||||
|
Some(common_ndim.ty),
|
||||||
|
);
|
||||||
|
|
||||||
let p0_ty =
|
let p0_ty = self.unifier.get_fresh_var_with_range(
|
||||||
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
&[float, ndarray_float.into()],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
let ret_ty = self.unifier.get_fresh_var_with_range(
|
let ret_ty = self.unifier.get_fresh_var_with_range(
|
||||||
&[int_sized, ndarray_int_sized],
|
&[int_sized, ndarray_int_sized.into()],
|
||||||
Some("R".into()),
|
Some("R".into()),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
@ -930,19 +955,30 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
let ndarray_float =
|
let ndarray_float = primitive_type::NDArrayType::from_primitive(
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(float),
|
||||||
|
Some(common_ndim.ty),
|
||||||
|
);
|
||||||
|
|
||||||
// The size variant of the function determines the type of int returned
|
// The size variant of the function determines the type of int returned
|
||||||
let int_sized = size_variant.of_int(self.primitives);
|
let int_sized = size_variant.of_int(self.primitives);
|
||||||
let ndarray_int_sized =
|
let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(int_sized),
|
||||||
|
Some(common_ndim.ty),
|
||||||
|
);
|
||||||
|
|
||||||
let p0_ty =
|
let p0_ty = self.unifier.get_fresh_var_with_range(
|
||||||
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
&[float, ndarray_float.into()],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let ret_ty = self.unifier.get_fresh_var_with_range(
|
let ret_ty = self.unifier.get_fresh_var_with_range(
|
||||||
&[int_sized, ndarray_int_sized],
|
&[int_sized, ndarray_int_sized.into()],
|
||||||
Some("R".into()),
|
Some("R".into()),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
@ -1005,7 +1041,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
self.unifier,
|
self.unifier,
|
||||||
&VarMap::new(),
|
&VarMap::new(),
|
||||||
prim.name(),
|
prim.name(),
|
||||||
self.ndarray_float,
|
self.ndarray_float.into(),
|
||||||
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
Box::new(move |ctx, obj, fun, args, generator| {
|
Box::new(move |ctx, obj, fun, args, generator| {
|
||||||
let func = match prim {
|
let func = match prim {
|
||||||
@ -1051,7 +1087,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
default_value: Some(SymbolValue::U32(0)),
|
default_value: Some(SymbolValue::U32(0)),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
ret: ndarray,
|
ret: ndarray.into(),
|
||||||
vars: into_var_map([tv]),
|
vars: into_var_map([tv]),
|
||||||
})),
|
})),
|
||||||
var_id: vec![tv.id],
|
var_id: vec![tv.id],
|
||||||
@ -1074,7 +1110,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
self.unifier,
|
self.unifier,
|
||||||
&into_var_map([tv]),
|
&into_var_map([tv]),
|
||||||
prim.name(),
|
prim.name(),
|
||||||
self.primitives.ndarray,
|
self.primitives.ndarray.into(),
|
||||||
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||||
// type variable
|
// type variable
|
||||||
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
||||||
@ -1103,7 +1139,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
default_value: Some(SymbolValue::I32(0)),
|
default_value: Some(SymbolValue::I32(0)),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
ret: self.ndarray_float_2d,
|
ret: self.ndarray_float_2d.into(),
|
||||||
vars: VarMap::default(),
|
vars: VarMap::default(),
|
||||||
})),
|
})),
|
||||||
var_id: Vec::default(),
|
var_id: Vec::default(),
|
||||||
@ -1123,7 +1159,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
self.unifier,
|
self.unifier,
|
||||||
&VarMap::new(),
|
&VarMap::new(),
|
||||||
prim.name(),
|
prim.name(),
|
||||||
self.ndarray_float_2d,
|
self.ndarray_float_2d.into(),
|
||||||
&[(int32, "n")],
|
&[(int32, "n")],
|
||||||
Box::new(|ctx, obj, fun, args, generator| {
|
Box::new(|ctx, obj, fun, args, generator| {
|
||||||
gen_ndarray_identity(ctx, &obj, fun, &args, generator)
|
gen_ndarray_identity(ctx, &obj, fun, &args, generator)
|
||||||
@ -1338,10 +1374,15 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
let tvar = self.unifier.get_fresh_var(Some("L".into()), None);
|
let tvar = self.unifier.get_fresh_var(Some("L".into()), None);
|
||||||
let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty });
|
let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty });
|
||||||
let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None);
|
let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None);
|
||||||
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.ty), Some(ndims.ty));
|
let ndarray = primitive_type::NDArrayType::from_primitive(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(tvar.ty),
|
||||||
|
Some(ndims.ty),
|
||||||
|
);
|
||||||
|
|
||||||
let arg_ty = self.unifier.get_fresh_var_with_range(
|
let arg_ty = self.unifier.get_fresh_var_with_range(
|
||||||
&[list, ndarray, self.primitives.range],
|
&[list, ndarray.into(), self.primitives.range],
|
||||||
Some("I".into()),
|
Some("I".into()),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
@ -1799,8 +1840,13 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar {
|
fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar {
|
||||||
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None);
|
let ndarray = primitive_type::NDArrayType::from_primitive(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(scalar_ty),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None)
|
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray.into()], Some("T".into()), None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
use std::convert::TryInto;
|
use std::convert::TryInto;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
use crate::toplevel::primitive_type::{NDArrayType, OptionType};
|
||||||
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap};
|
use crate::typecheck::typedef::{into_var_map, GenericObjectType, Mapping, TypeVarId, VarMap};
|
||||||
use nac3parser::ast::{Constant, Location};
|
use nac3parser::ast::{Constant, Location};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use strum_macros::EnumIter;
|
use strum_macros::EnumIter;
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
/// All primitive types and functions in nac3core.
|
/// All primitive types and functions in nac3core.
|
||||||
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
|
||||||
pub enum PrimDef {
|
pub enum PrimDef {
|
||||||
@ -403,6 +402,7 @@ impl TopLevelComposer {
|
|||||||
.collect::<HashMap<_, _>>(),
|
.collect::<HashMap<_, _>>(),
|
||||||
params: into_var_map([option_type_var]),
|
params: into_var_map([option_type_var]),
|
||||||
});
|
});
|
||||||
|
let option = OptionType::create(option, &mut unifier);
|
||||||
|
|
||||||
let size_t_ty = match size_t {
|
let size_t_ty = match size_t {
|
||||||
32 => uint32,
|
32 => uint32,
|
||||||
@ -436,8 +436,9 @@ impl TopLevelComposer {
|
|||||||
]),
|
]),
|
||||||
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||||
});
|
});
|
||||||
|
let ndarray = NDArrayType::create(ndarray, &mut unifier);
|
||||||
|
|
||||||
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap();
|
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray.into()).unwrap();
|
||||||
|
|
||||||
let primitives = PrimitiveStore {
|
let primitives = PrimitiveStore {
|
||||||
int32,
|
int32,
|
||||||
@ -747,7 +748,7 @@ impl TopLevelComposer {
|
|||||||
TypeAnnotation::CustomClass { id: e_id, params: e_param },
|
TypeAnnotation::CustomClass { id: e_id, params: e_param },
|
||||||
) => {
|
) => {
|
||||||
*f_id == *e_id
|
*f_id == *e_id
|
||||||
&& *f_id == primitive.option.obj_id(unifier).unwrap()
|
&& *f_id == primitive.option.obj_id(unifier)
|
||||||
&& (f_param.is_empty()
|
&& (f_param.is_empty()
|
||||||
|| (f_param.len() == 1
|
|| (f_param.len() == 1
|
||||||
&& e_param.len() == 1
|
&& e_param.len() == 1
|
||||||
@ -885,7 +886,7 @@ pub fn parse_parameter_default_value(
|
|||||||
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
|
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
|
||||||
match &*unifier.get_ty(ty) {
|
match &*unifier.get_ty(ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
unpack_ndarray_var_tys(unifier, ty).0
|
NDArrayType::create(ty, unifier).dtype_tvar(unifier).ty
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
|
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
|
||||||
@ -897,7 +898,7 @@ pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
|
|||||||
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
|
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
|
||||||
match &*unifier.get_ty(ty) {
|
match &*unifier.get_ty(ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
|
let ndims = NDArrayType::create(ty, unifier).ndims_tvar(unifier).ty;
|
||||||
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
|
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
|
||||||
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
|
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
|
||||||
};
|
};
|
||||||
|
@ -30,7 +30,7 @@ pub struct DefinitionId(pub usize);
|
|||||||
pub mod builtins;
|
pub mod builtins;
|
||||||
pub mod composer;
|
pub mod composer;
|
||||||
pub mod helper;
|
pub mod helper;
|
||||||
pub mod numpy;
|
pub mod primitive_type;
|
||||||
pub mod type_annotation;
|
pub mod type_annotation;
|
||||||
use composer::*;
|
use composer::*;
|
||||||
use type_annotation::*;
|
use type_annotation::*;
|
||||||
|
@ -1,85 +0,0 @@
|
|||||||
use crate::{
|
|
||||||
toplevel::helper::PrimDef,
|
|
||||||
typecheck::{
|
|
||||||
type_inferencer::PrimitiveStore,
|
|
||||||
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
/// Creates a `ndarray` [`Type`] with the given type arguments.
|
|
||||||
///
|
|
||||||
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
|
||||||
/// specialized.
|
|
||||||
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
|
||||||
/// specialized.
|
|
||||||
pub fn make_ndarray_ty(
|
|
||||||
unifier: &mut Unifier,
|
|
||||||
primitives: &PrimitiveStore,
|
|
||||||
dtype: Option<Type>,
|
|
||||||
ndims: Option<Type>,
|
|
||||||
) -> Type {
|
|
||||||
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Substitutes type variables in `ndarray`.
|
|
||||||
///
|
|
||||||
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
|
||||||
/// specialized.
|
|
||||||
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
|
||||||
/// specialized.
|
|
||||||
pub fn subst_ndarray_tvars(
|
|
||||||
unifier: &mut Unifier,
|
|
||||||
ndarray: Type,
|
|
||||||
dtype: Option<Type>,
|
|
||||||
ndims: Option<Type>,
|
|
||||||
) -> Type {
|
|
||||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
|
||||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
|
||||||
};
|
|
||||||
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
|
||||||
|
|
||||||
if dtype.is_none() && ndims.is_none() {
|
|
||||||
return ndarray;
|
|
||||||
}
|
|
||||||
|
|
||||||
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
|
|
||||||
debug_assert_eq!(tvar_ids.len(), 2);
|
|
||||||
|
|
||||||
let mut tvar_subst = VarMap::new();
|
|
||||||
if let Some(dtype) = dtype {
|
|
||||||
tvar_subst.insert(tvar_ids[0], dtype);
|
|
||||||
}
|
|
||||||
if let Some(ndims) = ndims {
|
|
||||||
tvar_subst.insert(tvar_ids[1], ndims);
|
|
||||||
}
|
|
||||||
|
|
||||||
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
|
|
||||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
|
||||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
|
||||||
};
|
|
||||||
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
|
||||||
debug_assert_eq!(params.len(), 2);
|
|
||||||
|
|
||||||
params
|
|
||||||
.iter()
|
|
||||||
.sorted_by_key(|(obj_id, _)| *obj_id)
|
|
||||||
.map(|(var_id, ty)| (*var_id, *ty))
|
|
||||||
.collect_vec()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
|
|
||||||
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
|
|
||||||
/// respectively.
|
|
||||||
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
|
|
||||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
|
||||||
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
|
||||||
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
|
|
||||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
|
|
||||||
}
|
|
98
nac3core/src/toplevel/primitive_type.rs
Normal file
98
nac3core/src/toplevel/primitive_type.rs
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
use crate::toplevel::helper::PrimDef;
|
||||||
|
use crate::typecheck::type_inferencer::PrimitiveStore;
|
||||||
|
use crate::typecheck::typedef::{GenericObjectType, Type, TypeVar, Unifier, VarMap};
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct OptionType(Type);
|
||||||
|
|
||||||
|
impl OptionType {
|
||||||
|
pub fn from_primitive(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
type_ty: Option<Type>,
|
||||||
|
) -> Self {
|
||||||
|
primitives.option.subst(unifier, type_ty)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn type_tvar(&self, unifier: &mut Unifier) -> TypeVar {
|
||||||
|
self.get_var_at(unifier, 0).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn subst(&self, unifier: &mut Unifier, type_ty: Option<Type>) -> Self {
|
||||||
|
let new_vars = [(self.type_tvar(unifier).id, type_ty)]
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
|
||||||
|
.collect::<VarMap>();
|
||||||
|
|
||||||
|
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
|
||||||
|
OptionType(new_ty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenericObjectType for OptionType {
|
||||||
|
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
|
||||||
|
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::Option.id()) {
|
||||||
|
Some(OptionType(ty))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_type(&self) -> Type {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct NDArrayType(Type);
|
||||||
|
|
||||||
|
impl NDArrayType {
|
||||||
|
pub fn from_primitive(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
dtype: Option<Type>,
|
||||||
|
ndims: Option<Type>,
|
||||||
|
) -> Self {
|
||||||
|
primitives.ndarray.subst(unifier, dtype, ndims)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype_tvar(&self, unifier: &mut Unifier) -> TypeVar {
|
||||||
|
self.get_var_at(unifier, 0).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ndims_tvar(&self, unifier: &mut Unifier) -> TypeVar {
|
||||||
|
self.get_var_at(unifier, 1).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn subst(
|
||||||
|
&self,
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
dtype_ty: Option<Type>,
|
||||||
|
ndims_ty: Option<Type>,
|
||||||
|
) -> Self {
|
||||||
|
let new_vars =
|
||||||
|
[(self.dtype_tvar(unifier).id, dtype_ty), (self.ndims_tvar(unifier).id, ndims_ty)]
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
|
||||||
|
.collect::<VarMap>();
|
||||||
|
|
||||||
|
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
|
||||||
|
NDArrayType(new_ty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenericObjectType for NDArrayType {
|
||||||
|
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
|
||||||
|
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
|
Some(NDArrayType(ty))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_type(&self) -> Type {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
@ -1,7 +1,7 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::helper::PrimDef;
|
use crate::toplevel::helper::PrimDef;
|
||||||
use crate::typecheck::typedef::VarMap;
|
use crate::typecheck::typedef::{GenericObjectType, VarMap};
|
||||||
use nac3parser::ast::Constant;
|
use nac3parser::ast::Constant;
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
@ -267,12 +267,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
|||||||
slice.as_ref(),
|
slice.as_ref(),
|
||||||
locked,
|
locked,
|
||||||
)?;
|
)?;
|
||||||
let id =
|
let id = primitives.option.obj_id(unifier);
|
||||||
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
|
|
||||||
*obj_id
|
|
||||||
} else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] })
|
Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::helper::PrimDef;
|
use crate::toplevel::helper::PrimDef;
|
||||||
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
use crate::toplevel::primitive_type;
|
||||||
use crate::typecheck::{
|
use crate::typecheck::{
|
||||||
type_inferencer::*,
|
type_inferencer::*,
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, Unifier, VarMap},
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use nac3parser::ast::StrRef;
|
use nac3parser::ast::StrRef;
|
||||||
@ -369,8 +369,12 @@ pub fn typeof_ndarray_broadcast(
|
|||||||
if is_left_ndarray && is_right_ndarray {
|
if is_left_ndarray && is_right_ndarray {
|
||||||
// Perform broadcasting on two ndarray operands.
|
// Perform broadcasting on two ndarray operands.
|
||||||
|
|
||||||
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
|
let left_ty = primitive_type::NDArrayType::create(left, unifier);
|
||||||
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
|
let left_ty_dtype = left_ty.dtype_tvar(unifier).ty;
|
||||||
|
let left_ty_ndims = left_ty.ndims_tvar(unifier).ty;
|
||||||
|
let right_ty = primitive_type::NDArrayType::create(right, unifier);
|
||||||
|
let right_ty_dtype = right_ty.dtype_tvar(unifier).ty;
|
||||||
|
let right_ty_ndims = right_ty.ndims_tvar(unifier).ty;
|
||||||
|
|
||||||
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
|
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
|
||||||
|
|
||||||
@ -397,11 +401,18 @@ pub fn typeof_ndarray_broadcast(
|
|||||||
.collect_vec();
|
.collect_vec();
|
||||||
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
||||||
|
|
||||||
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
|
Ok(primitive_type::NDArrayType::from_primitive(
|
||||||
|
unifier,
|
||||||
|
primitives,
|
||||||
|
Some(left_ty_dtype),
|
||||||
|
Some(res_ndims),
|
||||||
|
)
|
||||||
|
.into())
|
||||||
} else {
|
} else {
|
||||||
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
|
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
|
||||||
|
|
||||||
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
|
let ndarray_ty_dtype =
|
||||||
|
primitive_type::NDArrayType::create(ndarray_ty, unifier).ndims_tvar(unifier).ty;
|
||||||
|
|
||||||
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
||||||
Ok(ndarray_ty)
|
Ok(ndarray_ty)
|
||||||
@ -444,7 +455,8 @@ pub fn typeof_binop(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Operator::MatMult => {
|
Operator::MatMult => {
|
||||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
let lhs_ndims =
|
||||||
|
primitive_type::NDArrayType::create(lhs, unifier).ndims_tvar(unifier).ty;
|
||||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
TypeEnum::TLiteral { values, .. } => {
|
||||||
assert_eq!(values.len(), 1);
|
assert_eq!(values.len(), 1);
|
||||||
@ -452,7 +464,8 @@ pub fn typeof_binop(
|
|||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
let rhs_ndims =
|
||||||
|
primitive_type::NDArrayType::create(rhs, unifier).ndims_tvar(unifier).ty;
|
||||||
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
TypeEnum::TLiteral { values, .. } => {
|
||||||
assert_eq!(values.len(), 1);
|
assert_eq!(values.len(), 1);
|
||||||
@ -526,7 +539,7 @@ pub fn typeof_unaryop(
|
|||||||
let operand_obj_id = operand.obj_id(unifier);
|
let operand_obj_id = operand.obj_id(unifier);
|
||||||
|
|
||||||
if op == Unaryop::Not
|
if op == Unaryop::Not
|
||||||
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap())
|
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier))
|
||||||
{
|
{
|
||||||
return Err(
|
return Err(
|
||||||
"The truth value of an array with more than one element is ambiguous".to_string()
|
"The truth value of an array with more than one element is ambiguous".to_string()
|
||||||
@ -552,7 +565,8 @@ pub fn typeof_unaryop(
|
|||||||
|
|
||||||
Unaryop::UAdd | Unaryop::USub => {
|
Unaryop::UAdd | Unaryop::USub => {
|
||||||
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
|
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
|
let dtype =
|
||||||
|
primitive_type::NDArrayType::create(operand, unifier).dtype_tvar(unifier).ty;
|
||||||
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||||
return Err(if op == Unaryop::UAdd {
|
return Err(if op == Unaryop::UAdd {
|
||||||
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
|
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
|
||||||
@ -586,9 +600,15 @@ pub fn typeof_cmpop(
|
|||||||
|
|
||||||
Ok(Some(if is_left_ndarray || is_right_ndarray {
|
Ok(Some(if is_left_ndarray || is_right_ndarray {
|
||||||
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
|
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
|
let ndims = primitive_type::NDArrayType::create(brd, unifier).ndims_tvar(unifier).ty;
|
||||||
|
|
||||||
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
|
primitive_type::NDArrayType::from_primitive(
|
||||||
|
unifier,
|
||||||
|
primitives,
|
||||||
|
Some(primitives.bool),
|
||||||
|
Some(ndims),
|
||||||
|
)
|
||||||
|
.into()
|
||||||
} else if unifier.unioned(lhs, rhs) {
|
} else if unifier.unioned(lhs, rhs) {
|
||||||
primitives.bool
|
primitives.bool
|
||||||
} else {
|
} else {
|
||||||
@ -611,64 +631,108 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||||||
|
|
||||||
/* int ======== */
|
/* int ======== */
|
||||||
for t in [int32_t, int64_t, uint32_t, uint64_t] {
|
for t in [int32_t, int64_t, uint32_t, uint64_t] {
|
||||||
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
|
let ndarray_int_t =
|
||||||
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
|
primitive_type::NDArrayType::from_primitive(unifier, store, Some(t), None);
|
||||||
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
|
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||||
|
impl_pow(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||||
impl_bitwise_arithmetic(unifier, store, t);
|
impl_bitwise_arithmetic(unifier, store, t);
|
||||||
impl_bitwise_shift(unifier, store, t);
|
impl_bitwise_shift(unifier, store, t);
|
||||||
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
|
impl_div(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||||
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
|
impl_floordiv(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||||
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
|
impl_mod(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||||
impl_invert(unifier, store, t, Some(t));
|
impl_invert(unifier, store, t, Some(t));
|
||||||
impl_not(unifier, store, t, Some(bool_t));
|
impl_not(unifier, store, t, Some(bool_t));
|
||||||
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
|
impl_comparison(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||||
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
|
impl_eq(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||||
}
|
}
|
||||||
for t in [int32_t, int64_t] {
|
for t in [int32_t, int64_t] {
|
||||||
impl_sign(unifier, store, t, Some(t));
|
impl_sign(unifier, store, t, Some(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* float ======== */
|
/* float ======== */
|
||||||
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
|
let ndarray_float_t =
|
||||||
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
|
primitive_type::NDArrayType::from_primitive(unifier, store, Some(float_t), None);
|
||||||
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
let ndarray_int32_t =
|
||||||
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
|
primitive_type::NDArrayType::from_primitive(unifier, store, Some(int32_t), None);
|
||||||
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||||
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
impl_pow(
|
||||||
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
unifier,
|
||||||
|
store,
|
||||||
|
float_t,
|
||||||
|
&[int32_t, float_t, ndarray_int32_t.into(), ndarray_float_t.into()],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||||
|
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||||
|
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||||
impl_sign(unifier, store, float_t, Some(float_t));
|
impl_sign(unifier, store, float_t, Some(float_t));
|
||||||
impl_not(unifier, store, float_t, Some(bool_t));
|
impl_not(unifier, store, float_t, Some(bool_t));
|
||||||
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||||
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||||
|
|
||||||
/* bool ======== */
|
/* bool ======== */
|
||||||
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
|
let ndarray_bool_t =
|
||||||
|
primitive_type::NDArrayType::from_primitive(unifier, store, Some(bool_t), None);
|
||||||
impl_invert(unifier, store, bool_t, Some(int32_t));
|
impl_invert(unifier, store, bool_t, Some(int32_t));
|
||||||
impl_not(unifier, store, bool_t, Some(bool_t));
|
impl_not(unifier, store, bool_t, Some(bool_t));
|
||||||
impl_sign(unifier, store, bool_t, Some(int32_t));
|
impl_sign(unifier, store, bool_t, Some(int32_t));
|
||||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t.into()], None);
|
||||||
|
|
||||||
/* ndarray ===== */
|
/* ndarray ===== */
|
||||||
let ndarray_usized_ndims_tvar =
|
let ndarray_usized_ndims_tvar =
|
||||||
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||||
let ndarray_unsized_t =
|
let ndarray_unsized_t = primitive_type::NDArrayType::from_primitive(
|
||||||
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
|
unifier,
|
||||||
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
|
store,
|
||||||
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
|
None,
|
||||||
|
Some(ndarray_usized_ndims_tvar.ty),
|
||||||
|
);
|
||||||
|
let ndarray_dtype_t = ndarray_t.dtype_tvar(unifier).ty;
|
||||||
|
let ndarray_unsized_dtype_t = ndarray_unsized_t.dtype_tvar(unifier).ty;
|
||||||
impl_basic_arithmetic(
|
impl_basic_arithmetic(
|
||||||
unifier,
|
unifier,
|
||||||
store,
|
store,
|
||||||
ndarray_t,
|
ndarray_t.into(),
|
||||||
&[ndarray_unsized_t, ndarray_unsized_dtype_t],
|
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
impl_pow(
|
||||||
|
unifier,
|
||||||
|
store,
|
||||||
|
ndarray_t.into(),
|
||||||
|
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
impl_div(unifier, store, ndarray_t.into(), &[ndarray_t.into(), ndarray_dtype_t], None);
|
||||||
|
impl_floordiv(
|
||||||
|
unifier,
|
||||||
|
store,
|
||||||
|
ndarray_t.into(),
|
||||||
|
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
impl_mod(
|
||||||
|
unifier,
|
||||||
|
store,
|
||||||
|
ndarray_t.into(),
|
||||||
|
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
impl_matmul(unifier, store, ndarray_t.into(), &[ndarray_t.into()], Some(ndarray_t.into()));
|
||||||
|
impl_sign(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
|
||||||
|
impl_invert(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
|
||||||
|
impl_eq(
|
||||||
|
unifier,
|
||||||
|
store,
|
||||||
|
ndarray_t.into(),
|
||||||
|
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
impl_comparison(
|
||||||
|
unifier,
|
||||||
|
store,
|
||||||
|
ndarray_t.into(),
|
||||||
|
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
|
||||||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
|
||||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
|
||||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
|
||||||
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
|
|
||||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
|
||||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
|
||||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
|
||||||
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
|
||||||
}
|
}
|
||||||
|
@ -4,14 +4,16 @@ use std::iter::once;
|
|||||||
use std::ops::Not;
|
use std::ops::Not;
|
||||||
use std::{cell::RefCell, sync::Arc};
|
use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
|
use super::typedef::{
|
||||||
|
Call, FunSignature, FuncArg, GenericObjectType, RecordField, Type, TypeEnum, Unifier, VarMap,
|
||||||
|
};
|
||||||
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
||||||
|
use crate::toplevel::primitive_type::{NDArrayType, OptionType};
|
||||||
use crate::toplevel::TopLevelDef;
|
use crate::toplevel::TopLevelDef;
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
|
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
|
||||||
TopLevelContext,
|
TopLevelContext,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -49,8 +51,8 @@ pub struct PrimitiveStore {
|
|||||||
pub range: Type,
|
pub range: Type,
|
||||||
pub str: Type,
|
pub str: Type,
|
||||||
pub exception: Type,
|
pub exception: Type,
|
||||||
pub option: Type,
|
pub option: OptionType,
|
||||||
pub ndarray: Type,
|
pub ndarray: NDArrayType,
|
||||||
pub size_t: u32,
|
pub size_t: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,8 +99,8 @@ impl IntoIterator for &PrimitiveStore {
|
|||||||
self.range,
|
self.range,
|
||||||
self.str,
|
self.str,
|
||||||
self.exception,
|
self.exception,
|
||||||
self.option,
|
self.option.into(),
|
||||||
self.ndarray,
|
self.ndarray.into(),
|
||||||
]
|
]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
}
|
}
|
||||||
@ -528,7 +530,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
|||||||
// the name `none` is special since it may have different types
|
// the name `none` is special since it may have different types
|
||||||
if id == &"none".into() {
|
if id == &"none".into() {
|
||||||
if let TypeEnum::TObj { params, .. } =
|
if let TypeEnum::TObj { params, .. } =
|
||||||
self.unifier.get_ty_immutable(self.primitives.option).as_ref()
|
&*self.unifier.get_ty_immutable(self.primitives.option.into())
|
||||||
{
|
{
|
||||||
let var_map = params
|
let var_map = params
|
||||||
.iter()
|
.iter()
|
||||||
@ -543,7 +545,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
|||||||
(*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty)
|
(*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty)
|
||||||
})
|
})
|
||||||
.collect::<VarMap>();
|
.collect::<VarMap>();
|
||||||
Some(self.unifier.subst(self.primitives.option, &var_map).unwrap())
|
Some(self.unifier.subst(self.primitives.option.into(), &var_map).unwrap())
|
||||||
} else {
|
} else {
|
||||||
unreachable!("must be tobj")
|
unreachable!("must be tobj")
|
||||||
}
|
}
|
||||||
@ -1062,9 +1064,16 @@ impl<'a> Inferencer<'a> {
|
|||||||
|
|
||||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
let ndarray_ndims =
|
||||||
|
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
NDArrayType::from_primitive(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(target_ty),
|
||||||
|
Some(ndarray_ndims),
|
||||||
|
)
|
||||||
|
.into()
|
||||||
} else {
|
} else {
|
||||||
target_ty
|
target_ty
|
||||||
};
|
};
|
||||||
@ -1100,9 +1109,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
|
|
||||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
|
||||||
|
|
||||||
ndarray_dtype
|
|
||||||
} else {
|
} else {
|
||||||
arg0_ty
|
arg0_ty
|
||||||
};
|
};
|
||||||
@ -1154,14 +1161,14 @@ impl<'a> Inferencer<'a> {
|
|||||||
|
|
||||||
let arg0_dtype =
|
let arg0_dtype =
|
||||||
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
unpack_ndarray_var_tys(self.unifier, arg0_ty).0
|
NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
|
||||||
} else {
|
} else {
|
||||||
arg0_ty
|
arg0_ty
|
||||||
};
|
};
|
||||||
|
|
||||||
let arg1_dtype =
|
let arg1_dtype =
|
||||||
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
unpack_ndarray_var_tys(self.unifier, arg1_ty).0
|
NDArrayType::create(arg1_ty, self.unifier).dtype_tvar(self.unifier).ty
|
||||||
} else {
|
} else {
|
||||||
arg1_ty
|
arg1_ty
|
||||||
};
|
};
|
||||||
@ -1192,9 +1199,17 @@ impl<'a> Inferencer<'a> {
|
|||||||
// (float, int32), so convert it to align with the dtype of the first arg
|
// (float, int32), so convert it to align with the dtype of the first arg
|
||||||
let arg1_ty = if id == &"np_ldexp".into() {
|
let arg1_ty = if id == &"np_ldexp".into() {
|
||||||
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
|
// let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
|
||||||
|
let ndims =
|
||||||
|
NDArrayType::create(arg1_ty, self.unifier).ndims_tvar(self.unifier).ty;
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
|
NDArrayType::from_primitive(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(target_ty),
|
||||||
|
Some(ndims),
|
||||||
|
)
|
||||||
|
.into()
|
||||||
} else {
|
} else {
|
||||||
target_ty
|
target_ty
|
||||||
}
|
}
|
||||||
@ -1281,9 +1296,16 @@ impl<'a> Inferencer<'a> {
|
|||||||
|
|
||||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
let ndarray_ndims =
|
||||||
|
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
NDArrayType::from_primitive(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(target_ty),
|
||||||
|
Some(ndarray_ndims),
|
||||||
|
)
|
||||||
|
.into()
|
||||||
} else {
|
} else {
|
||||||
target_ty
|
target_ty
|
||||||
};
|
};
|
||||||
@ -1323,7 +1345,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
|
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
|
||||||
|
|
||||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||||
let ret = make_ndarray_ty(
|
let ret = NDArrayType::from_primitive(
|
||||||
self.unifier,
|
self.unifier,
|
||||||
self.primitives,
|
self.primitives,
|
||||||
Some(self.primitives.float),
|
Some(self.primitives.float),
|
||||||
@ -1335,13 +1357,13 @@ impl<'a> Inferencer<'a> {
|
|||||||
ty: shape.custom.unwrap(),
|
ty: shape.custom.unwrap(),
|
||||||
default_value: None,
|
default_value: None,
|
||||||
}],
|
}],
|
||||||
ret,
|
ret: ret.into(),
|
||||||
vars: VarMap::new(),
|
vars: VarMap::new(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
return Ok(Some(Located {
|
return Ok(Some(Located {
|
||||||
location,
|
location,
|
||||||
custom: Some(ret),
|
custom: Some(ret.into()),
|
||||||
node: ExprKind::Call {
|
node: ExprKind::Call {
|
||||||
func: Box::new(Located {
|
func: Box::new(Located {
|
||||||
custom: Some(custom),
|
custom: Some(custom),
|
||||||
@ -1374,7 +1396,8 @@ impl<'a> Inferencer<'a> {
|
|||||||
|
|
||||||
let ty = arg1.custom.unwrap();
|
let ty = arg1.custom.unwrap();
|
||||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||||
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
let ret =
|
||||||
|
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
|
||||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![
|
args: vec![
|
||||||
FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None },
|
FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||||
@ -1384,13 +1407,13 @@ impl<'a> Inferencer<'a> {
|
|||||||
default_value: None,
|
default_value: None,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
ret,
|
ret: ret.into(),
|
||||||
vars: VarMap::new(),
|
vars: VarMap::new(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
return Ok(Some(Located {
|
return Ok(Some(Located {
|
||||||
location,
|
location,
|
||||||
custom: Some(ret),
|
custom: Some(ret.into()),
|
||||||
node: ExprKind::Call {
|
node: ExprKind::Call {
|
||||||
func: Box::new(Located {
|
func: Box::new(Located {
|
||||||
custom: Some(custom),
|
custom: Some(custom),
|
||||||
@ -1428,7 +1451,8 @@ impl<'a> Inferencer<'a> {
|
|||||||
arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
|
arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
|
||||||
};
|
};
|
||||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||||
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
let ret =
|
||||||
|
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
|
||||||
|
|
||||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![
|
args: vec![
|
||||||
@ -1448,13 +1472,13 @@ impl<'a> Inferencer<'a> {
|
|||||||
default_value: Some(SymbolValue::U32(0)),
|
default_value: Some(SymbolValue::U32(0)),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
ret,
|
ret: ret.into(),
|
||||||
vars: VarMap::new(),
|
vars: VarMap::new(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
return Ok(Some(Located {
|
return Ok(Some(Located {
|
||||||
location,
|
location,
|
||||||
custom: Some(ret),
|
custom: Some(ret.into()),
|
||||||
node: ExprKind::Call {
|
node: ExprKind::Call {
|
||||||
func: Box::new(Located {
|
func: Box::new(Located {
|
||||||
custom: Some(custom),
|
custom: Some(custom),
|
||||||
@ -1803,9 +1827,13 @@ impl<'a> Inferencer<'a> {
|
|||||||
TypeEnum::TVar { is_const_generic: false, .. }
|
TypeEnum::TVar { is_const_generic: false, .. }
|
||||||
));
|
));
|
||||||
|
|
||||||
let constrained_ty =
|
let constrained_ty = NDArrayType::from_primitive(
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims));
|
self.unifier,
|
||||||
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
|
self.primitives,
|
||||||
|
Some(dummy_tvar),
|
||||||
|
Some(ndims),
|
||||||
|
);
|
||||||
|
self.constrain(value.custom.unwrap(), constrained_ty.into(), &value.location)?;
|
||||||
|
|
||||||
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
|
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
|
||||||
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
|
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
|
||||||
@ -1871,10 +1899,14 @@ impl<'a> Inferencer<'a> {
|
|||||||
let ndims_ty = self
|
let ndims_ty = self
|
||||||
.unifier
|
.unifier
|
||||||
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
|
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
|
||||||
let subscripted_ty =
|
let subscripted_ty = NDArrayType::from_primitive(
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty));
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(dummy_tvar),
|
||||||
|
Some(ndims_ty),
|
||||||
|
);
|
||||||
|
|
||||||
Ok(subscripted_ty)
|
Ok(subscripted_ty.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1893,10 +1925,17 @@ impl<'a> Inferencer<'a> {
|
|||||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (_, ndims) =
|
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
.ndims_tvar(self.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
|
NDArrayType::from_primitive(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(ty),
|
||||||
|
Some(ndims),
|
||||||
|
)
|
||||||
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
@ -1907,8 +1946,10 @@ impl<'a> Inferencer<'a> {
|
|||||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||||
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (_, ndims) =
|
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
.ndims_tvar(self.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
@ -1951,7 +1992,10 @@ impl<'a> Inferencer<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
|
||||||
|
.ndims_tvar(self.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
@ -1975,8 +2019,9 @@ impl<'a> Inferencer<'a> {
|
|||||||
Ok(ty)
|
Ok(ty)
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (_, ndims) =
|
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
|
||||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
.ndims_tvar(self.unifier)
|
||||||
|
.ty;
|
||||||
|
|
||||||
let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
|
let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -139,6 +139,7 @@ impl TestEnvironment {
|
|||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: VarMap::new(),
|
params: VarMap::new(),
|
||||||
});
|
});
|
||||||
|
let option = OptionType::create(option, &mut unifier);
|
||||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||||
let ndarray_ndims_tvar =
|
let ndarray_ndims_tvar =
|
||||||
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
|
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
|
||||||
@ -147,6 +148,7 @@ impl TestEnvironment {
|
|||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||||
});
|
});
|
||||||
|
let ndarray = NDArrayType::create(ndarray, &mut unifier);
|
||||||
let primitives = PrimitiveStore {
|
let primitives = PrimitiveStore {
|
||||||
int32,
|
int32,
|
||||||
int64,
|
int64,
|
||||||
@ -273,11 +275,13 @@ impl TestEnvironment {
|
|||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: VarMap::new(),
|
params: VarMap::new(),
|
||||||
});
|
});
|
||||||
|
let option = OptionType::create(option, &mut unifier);
|
||||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: PrimDef::NDArray.id(),
|
obj_id: PrimDef::NDArray.id(),
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: VarMap::new(),
|
params: VarMap::new(),
|
||||||
});
|
});
|
||||||
|
let ndarray = NDArrayType::create(ndarray, &mut unifier);
|
||||||
identifier_mapping.insert("None".into(), none);
|
identifier_mapping.insert("None".into(), none);
|
||||||
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
|
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -137,7 +137,8 @@ where
|
|||||||
#[must_use]
|
#[must_use]
|
||||||
fn get_type(&self) -> Type;
|
fn get_type(&self) -> Type;
|
||||||
|
|
||||||
/// See [`Type::obj_id`].
|
/// Similar to [`Type::obj_id`], except that the [`DefinitionId`] is not wrapped within an
|
||||||
|
/// [`Option`].
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn obj_id(&self, unifier: &Unifier) -> DefinitionId {
|
fn obj_id(&self, unifier: &Unifier) -> DefinitionId {
|
||||||
self.get_type().obj_id(unifier).unwrap()
|
self.get_type().obj_id(unifier).unwrap()
|
||||||
|
Loading…
Reference in New Issue
Block a user