forked from M-Labs/nac3
Compare commits
5 Commits
ndarray-st
...
refactor-p
Author | SHA1 | Date | |
---|---|---|---|
6892a4848e | |||
da4dec08a5 | |||
10a88e1799 | |||
c78accce70 | |||
91e3824517 |
50
Cargo.lock
generated
50
Cargo.lock
generated
@ -105,9 +105,9 @@ checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.5.0"
|
||||
version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
|
||||
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.99"
|
||||
version = "1.0.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96c51067fd44124faa7f870b4b1c969379ad32b2ba805aa959430ceaa384f695"
|
||||
checksum = "c891175c3fb232128f48de6590095e59198bbeb8620c310be349bfc3afd12c7b"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
@ -158,7 +158,7 @@ dependencies = [
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
"syn 2.0.68",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -421,7 +421,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
"syn 2.0.68",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -501,9 +501,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
@ -513,9 +513,9 @@ checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
version = "0.8.3"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
|
||||
checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets",
|
||||
@ -749,7 +749,7 @@ dependencies = [
|
||||
"phf_shared 0.11.2",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
"syn 2.0.68",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -796,9 +796,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.85"
|
||||
version = "1.0.86"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23"
|
||||
checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
@ -850,7 +850,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
"syn 2.0.68",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -863,7 +863,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-build-config",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
"syn 2.0.68",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1044,14 +1044,14 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
"syn 2.0.68",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.117"
|
||||
version = "1.0.118"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3"
|
||||
checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
@ -1120,9 +1120,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "strum"
|
||||
version = "0.26.2"
|
||||
version = "0.26.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29"
|
||||
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
|
||||
|
||||
[[package]]
|
||||
name = "strum_macros"
|
||||
@ -1134,7 +1134,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 2.0.66",
|
||||
"syn 2.0.68",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1150,9 +1150,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.66"
|
||||
version = "2.0.68"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5"
|
||||
checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -1218,7 +1218,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
"syn 2.0.68",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1486,5 +1486,5 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
"syn 2.0.68",
|
||||
]
|
||||
|
@ -7,7 +7,7 @@ use nac3core::{
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{helper::PrimDef, DefinitionId, GenCall},
|
||||
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap},
|
||||
typecheck::typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, VarMap},
|
||||
};
|
||||
|
||||
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
||||
@ -23,7 +23,7 @@ use pyo3::{
|
||||
|
||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||
|
||||
use nac3core::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use nac3core::toplevel::primitive_type;
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
collections::HashMap,
|
||||
@ -399,7 +399,9 @@ fn gen_rpc_tag(
|
||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||
}
|
||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
let ndarray_ty = primitive_type::NDArrayType::create(ty, &mut ctx.unifier);
|
||||
let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
|
||||
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
|
||||
let ndarray_ndims = if let TLiteral { values, .. } =
|
||||
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
||||
{
|
||||
@ -645,7 +647,7 @@ pub fn attributes_writeback(
|
||||
let ty = ty.unwrap();
|
||||
match &*ctx.unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { fields, obj_id, .. }
|
||||
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
||||
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier) =>
|
||||
{
|
||||
// we only care about primitive attributes
|
||||
// for non-primitive attributes, they should be in another global
|
||||
|
@ -4,20 +4,17 @@ use inkwell::{
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use nac3core::typecheck::typedef::{GenericObjectType, GenericTypeAdapter};
|
||||
use nac3core::{
|
||||
codegen::{
|
||||
classes::{NDArrayType, ProxyType},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
||||
toplevel::{
|
||||
helper::PrimDef,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
DefinitionId, TopLevelDef,
|
||||
},
|
||||
toplevel::{helper::PrimDef, primitive_type, DefinitionId, TopLevelDef},
|
||||
typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap},
|
||||
typedef::{Type, TypeEnum, TypeVar, Unifier, VarMap},
|
||||
},
|
||||
};
|
||||
use nac3parser::ast::{self, StrRef};
|
||||
@ -336,13 +333,18 @@ impl InnerResolver {
|
||||
// do not handle type var param and concrete check here
|
||||
let var = unifier.get_dummy_var().ty;
|
||||
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty;
|
||||
let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims));
|
||||
Ok(Ok((ndarray, false)))
|
||||
let ndarray = primitive_type::NDArrayType::from_primitive(
|
||||
unifier,
|
||||
primitives,
|
||||
Some(var),
|
||||
Some(ndims),
|
||||
);
|
||||
Ok(Ok((ndarray.into(), false)))
|
||||
} else if ty_id == self.primitive_ids.tuple {
|
||||
// do not handle type var param and concrete check here
|
||||
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
||||
} else if ty_id == self.primitive_ids.option {
|
||||
Ok(Ok((primitives.option, false)))
|
||||
Ok(Ok((primitives.option.into(), false)))
|
||||
} else if ty_id == self.primitive_ids.none {
|
||||
unreachable!("none cannot be typeid")
|
||||
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() {
|
||||
@ -509,7 +511,16 @@ impl InnerResolver {
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Ok((make_ndarray_ty(unifier, primitives, Some(ty.0), None), true)))
|
||||
Ok(Ok((
|
||||
primitive_type::NDArrayType::from_primitive(
|
||||
unifier,
|
||||
primitives,
|
||||
Some(ty.0),
|
||||
None,
|
||||
)
|
||||
.into(),
|
||||
true,
|
||||
)))
|
||||
}
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let args = match args
|
||||
@ -718,7 +729,9 @@ impl InnerResolver {
|
||||
}
|
||||
}
|
||||
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
|
||||
let ndarray = primitive_type::NDArrayType::create(extracted_ty, unifier);
|
||||
let ty = ndarray.dtype_tvar(unifier).ty;
|
||||
let ndims = ndarray.ndims_tvar(unifier).ty;
|
||||
let len: usize = obj.getattr("ndim")?.extract()?;
|
||||
if len == 0 {
|
||||
assert!(matches!(
|
||||
@ -733,10 +746,14 @@ impl InnerResolver {
|
||||
match dtype_ty {
|
||||
Ok((t, _)) => match unifier.unify(ty, t) {
|
||||
Ok(()) => {
|
||||
let ndarray_ty =
|
||||
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
|
||||
let ndarray_ty = primitive_type::NDArrayType::from_primitive(
|
||||
unifier,
|
||||
primitives,
|
||||
Some(ty),
|
||||
Some(ndims),
|
||||
);
|
||||
|
||||
Ok(Ok(ndarray_ty))
|
||||
Ok(Ok(ndarray_ty.into()))
|
||||
}
|
||||
Err(e) => Ok(Err(format!(
|
||||
"type error ({}) for the ndarray",
|
||||
@ -759,7 +776,7 @@ impl InnerResolver {
|
||||
// special handling for option type since its class member layout in python side
|
||||
// is special and cannot be mapped directly to a nac3 type as below
|
||||
(TypeEnum::TObj { obj_id, params, .. }, false)
|
||||
if *obj_id == primitives.option.obj_id(unifier).unwrap() =>
|
||||
if *obj_id == primitives.option.obj_id(unifier) =>
|
||||
{
|
||||
let Ok(field_data) = obj.getattr("_nac3_option") else {
|
||||
unreachable!("cannot be None")
|
||||
@ -767,22 +784,24 @@ impl InnerResolver {
|
||||
// if is `none`
|
||||
let zelf_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
|
||||
if zelf_id == self.primitive_ids.none {
|
||||
let ty_enum = unifier.get_ty_immutable(primitives.option);
|
||||
let TypeEnum::TObj { params, .. } = ty_enum.as_ref() else {
|
||||
unreachable!("must be tobj")
|
||||
};
|
||||
let extracted_ty = GenericTypeAdapter::create(extracted_ty, unifier);
|
||||
let var_map = extracted_ty.iter_var_map(unifier, |tvar_iter, unifier| {
|
||||
tvar_iter
|
||||
.map(|tvar| {
|
||||
let TypeEnum::TVar { id, range, name, loc, .. } =
|
||||
&*unifier.get_ty(tvar.ty)
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let var_map = into_var_map(iter_type_vars(params).map(|tvar| {
|
||||
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty)
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
assert_eq!(*id, tvar.id);
|
||||
let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty;
|
||||
TypeVar { id: *id, ty }
|
||||
}));
|
||||
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()));
|
||||
assert_eq!(*id, tvar.id);
|
||||
let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty;
|
||||
TypeVar { id: *id, ty }
|
||||
})
|
||||
.map(TypeVar::into)
|
||||
.collect::<VarMap>()
|
||||
});
|
||||
return Ok(Ok(unifier.subst(primitives.option.into(), &var_map).unwrap()));
|
||||
}
|
||||
|
||||
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
|
||||
@ -797,19 +816,26 @@ impl InnerResolver {
|
||||
let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty);
|
||||
Ok(Ok(res))
|
||||
}
|
||||
(TypeEnum::TObj { params, fields, .. }, false) => {
|
||||
(TypeEnum::TObj { fields, .. }, false) => {
|
||||
self.pyid_to_type.write().insert(py_obj_id, extracted_ty);
|
||||
let var_map = into_var_map(iter_type_vars(params).map(|tvar| {
|
||||
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty)
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
let extracted_ty = GenericTypeAdapter::create(extracted_ty, unifier);
|
||||
let var_map = extracted_ty.iter_var_map(unifier, |tvar_iter, unifier| {
|
||||
tvar_iter
|
||||
.map(|tvar| {
|
||||
let TypeEnum::TVar { id, range, name, loc, .. } =
|
||||
&*unifier.get_ty(tvar.ty)
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
assert_eq!(*id, tvar.id);
|
||||
let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty;
|
||||
TypeVar { id: *id, ty }
|
||||
}));
|
||||
let mut instantiate_obj = || {
|
||||
assert_eq!(*id, tvar.id);
|
||||
let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty;
|
||||
TypeVar { id: *id, ty }
|
||||
})
|
||||
.map(TypeVar::into)
|
||||
.collect::<VarMap>()
|
||||
});
|
||||
let instantiate_obj = || {
|
||||
// loop through non-function fields of the class to get the instantiated value
|
||||
for field in fields {
|
||||
let name: String = (*field.0).into();
|
||||
@ -844,6 +870,7 @@ impl InnerResolver {
|
||||
return Ok(Err("object is not of concrete type".into()));
|
||||
}
|
||||
}
|
||||
let extracted_ty = extracted_ty.into();
|
||||
let extracted_ty =
|
||||
unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty);
|
||||
Ok(Ok(extracted_ty))
|
||||
@ -1027,8 +1054,9 @@ impl InnerResolver {
|
||||
} else {
|
||||
unreachable!("must be ndarray")
|
||||
};
|
||||
let (ndarray_dtype, ndarray_ndims) =
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||
let ndarray_ty = primitive_type::NDArrayType::create(ndarray_ty, &mut ctx.unifier);
|
||||
let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
|
||||
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||
@ -1175,7 +1203,7 @@ impl InnerResolver {
|
||||
} else if ty_id == self.primitive_ids.option {
|
||||
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
|
||||
TypeEnum::TObj { obj_id, params, .. }
|
||||
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
|
||||
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier) =>
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ constant-optimization = ["fold"]
|
||||
fold = []
|
||||
|
||||
[dependencies]
|
||||
lazy_static = "1.4"
|
||||
lazy_static = "1.5"
|
||||
parking_lot = "0.12"
|
||||
string-interner = "0.17"
|
||||
fxhash = "0.2"
|
||||
|
@ -8,8 +8,8 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::Type;
|
||||
use crate::toplevel::primitive_type;
|
||||
use crate::typecheck::typedef::{GenericObjectType, Type};
|
||||
|
||||
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
||||
///
|
||||
@ -66,7 +66,9 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -101,7 +103,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
|
||||
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) {
|
||||
if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
|
||||
ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap()
|
||||
@ -128,7 +130,9 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -206,7 +210,9 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -241,7 +247,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
|
||||
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) {
|
||||
if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
|
||||
ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap()
|
||||
@ -273,7 +279,9 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -304,20 +312,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
debug_assert!(n_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
if [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(n_ty, *ty))
|
||||
{
|
||||
if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
|
||||
ctx.builder
|
||||
.build_signed_int_to_float(n, llvm_f64, "sitofp")
|
||||
.map(Into::into)
|
||||
@ -331,7 +328,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
n.into()
|
||||
}
|
||||
@ -339,7 +336,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -373,7 +372,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
let val = llvm_intrinsics::call_float_round(ctx, n, None);
|
||||
ctx.builder
|
||||
@ -385,7 +384,9 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -417,7 +418,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
llvm_intrinsics::call_float_roundeven(ctx, n, None).into()
|
||||
}
|
||||
@ -425,7 +426,9 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -463,14 +466,10 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
|
||||
BasicValueEnum::IntValue(n) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
debug_assert!(
|
||||
n_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
|
||||
&& n_ty.is_arithmetic(&mut ctx.unifier, &ctx.primitives)
|
||||
);
|
||||
|
||||
ctx.builder
|
||||
.build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME)
|
||||
@ -479,7 +478,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
ctx.builder
|
||||
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME)
|
||||
@ -490,7 +489,9 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -528,7 +529,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
let val = llvm_intrinsics::call_float_floor(ctx, n, None);
|
||||
if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty {
|
||||
@ -544,7 +545,9 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -578,7 +581,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
let val = llvm_intrinsics::call_float_ceil(ctx, n, None);
|
||||
if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty {
|
||||
@ -594,7 +597,9 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
@ -631,20 +636,9 @@ pub fn call_min<'ctx>(
|
||||
|
||||
match (m, n) {
|
||||
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(common_ty, *ty)));
|
||||
debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
if [ctx.primitives.int32, ctx.primitives.int64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(common_ty, *ty))
|
||||
{
|
||||
if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
|
||||
llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into()
|
||||
} else {
|
||||
llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into()
|
||||
@ -652,7 +646,7 @@ pub fn call_min<'ctx>(
|
||||
}
|
||||
|
||||
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
|
||||
debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float));
|
||||
debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into()
|
||||
}
|
||||
@ -675,16 +669,10 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match a {
|
||||
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
|
||||
debug_assert!(
|
||||
a_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
|
||||
|| a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)
|
||||
);
|
||||
|
||||
a
|
||||
}
|
||||
@ -692,7 +680,9 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
@ -761,22 +751,13 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
|
||||
debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||
}
|
||||
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
|
||||
debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||
}
|
||||
@ -792,16 +773,24 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
@ -847,20 +836,9 @@ pub fn call_max<'ctx>(
|
||||
|
||||
match (m, n) {
|
||||
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(common_ty, *ty)));
|
||||
debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
if [ctx.primitives.int32, ctx.primitives.int64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(common_ty, *ty))
|
||||
{
|
||||
if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
|
||||
llvm_intrinsics::call_int_smax(ctx, m, n, Some(FN_NAME)).into()
|
||||
} else {
|
||||
llvm_intrinsics::call_int_umax(ctx, m, n, Some(FN_NAME)).into()
|
||||
@ -868,7 +846,7 @@ pub fn call_max<'ctx>(
|
||||
}
|
||||
|
||||
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
|
||||
debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float));
|
||||
debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into()
|
||||
}
|
||||
@ -891,16 +869,10 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match a {
|
||||
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
|
||||
debug_assert!(
|
||||
a_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
|
||||
|| a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)
|
||||
);
|
||||
|
||||
a
|
||||
}
|
||||
@ -908,7 +880,9 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicValueEnum::PointerValue(n)
|
||||
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
|
||||
let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
|
||||
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
|
||||
@ -977,22 +951,13 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
|
||||
debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||
}
|
||||
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float));
|
||||
debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
|
||||
}
|
||||
@ -1008,16 +973,24 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
@ -1075,7 +1048,9 @@ where
|
||||
if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||
let arg_elem_ty = primitive_type::NDArrayType::create(arg_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
@ -1117,22 +1092,11 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
n,
|
||||
FN_NAME,
|
||||
&|_ctx, elem_ty| elem_ty,
|
||||
&|_generator, ctx, val_ty, val| match val {
|
||||
&|_, ctx, val_ty, val| match val {
|
||||
BasicValueEnum::IntValue(n) => Some({
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint64,
|
||||
]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(val_ty, *ty)));
|
||||
debug_assert!(val_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
if [ctx.primitives.int32, ctx.primitives.int64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(val_ty, *ty))
|
||||
{
|
||||
if val_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
|
||||
llvm_intrinsics::call_int_abs(
|
||||
ctx,
|
||||
n,
|
||||
@ -1146,7 +1110,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}),
|
||||
|
||||
BasicValueEnum::FloatValue(n) => Some({
|
||||
debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float));
|
||||
debug_assert!(val_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into()
|
||||
}),
|
||||
@ -1431,8 +1395,8 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
|
||||
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
extern_fns::call_atan2(ctx, x1, x2, None).into()
|
||||
}
|
||||
@ -1448,16 +1412,24 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
@ -1498,8 +1470,8 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
|
||||
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into()
|
||||
}
|
||||
@ -1515,16 +1487,24 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|
||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
@ -1565,8 +1545,8 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
|
||||
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into()
|
||||
}
|
||||
@ -1582,16 +1562,24 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|
||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
@ -1632,8 +1620,8 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
|
||||
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into()
|
||||
}
|
||||
@ -1649,16 +1637,24 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|
||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
@ -1699,7 +1695,7 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32));
|
||||
|
||||
extern_fns::call_ldexp(ctx, x1, x2, None).into()
|
||||
@ -1715,12 +1711,22 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let is_ndarray2 =
|
||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
let dtype =
|
||||
if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty };
|
||||
let dtype = if is_ndarray1 {
|
||||
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
x1_ty
|
||||
};
|
||||
|
||||
let x1_scalar_ty = dtype;
|
||||
let x2_scalar_ty =
|
||||
if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty };
|
||||
let x2_scalar_ty = if is_ndarray2 {
|
||||
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
x2_ty
|
||||
};
|
||||
|
||||
numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
@ -1755,8 +1761,8 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
|
||||
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
extern_fns::call_hypot(ctx, x1, x2, None).into()
|
||||
}
|
||||
@ -1772,16 +1778,24 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
@ -1822,8 +1836,8 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float));
|
||||
debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
|
||||
|
||||
extern_fns::call_nextafter(ctx, x1, x2, None).into()
|
||||
}
|
||||
@ -1839,16 +1853,24 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
let dtype = if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
ndarray_dtype1
|
||||
} else if is_ndarray1 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
|
||||
primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
|
@ -1,5 +1,9 @@
|
||||
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||
|
||||
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
|
||||
use crate::toplevel::primitive_type;
|
||||
use crate::toplevel::primitive_type::OptionType;
|
||||
use crate::typecheck::typedef::GenericObjectType;
|
||||
use crate::{
|
||||
codegen::{
|
||||
classes::{
|
||||
@ -15,11 +19,7 @@ use crate::{
|
||||
CodeGenContext, CodeGenTask,
|
||||
},
|
||||
symbol_resolver::{SymbolValue, ValueEnum},
|
||||
toplevel::{
|
||||
helper::PrimDef,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
DefinitionId, TopLevelDef,
|
||||
},
|
||||
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
|
||||
typecheck::{
|
||||
magic_methods::{binop_assign_name, binop_name, unaryop_name},
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||
@ -36,8 +36,6 @@ use nac3parser::ast::{
|
||||
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||
};
|
||||
|
||||
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
|
||||
|
||||
pub fn get_subst_key(
|
||||
unifier: &mut Unifier,
|
||||
obj: Option<Type>,
|
||||
@ -162,14 +160,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
self.builder.build_load(ptr, "tup_val").unwrap()
|
||||
}
|
||||
SymbolValue::OptionSome(v) => {
|
||||
let ty = match self.unifier.get_ty_immutable(ty).as_ref() {
|
||||
TypeEnum::TObj { obj_id, params, .. }
|
||||
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
}
|
||||
_ => unreachable!("must be option type"),
|
||||
};
|
||||
let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
|
||||
let val = self.gen_symbol_val(generator, v, ty);
|
||||
let ptr = generator
|
||||
.gen_var_alloc(self, val.get_type(), Some("default_opt_some"))
|
||||
@ -178,14 +169,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
ptr.into()
|
||||
}
|
||||
SymbolValue::OptionNone => {
|
||||
let ty = match self.unifier.get_ty_immutable(ty).as_ref() {
|
||||
TypeEnum::TObj { obj_id, params, .. }
|
||||
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
}
|
||||
_ => unreachable!("must be option type"),
|
||||
};
|
||||
let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
|
||||
let actual_ptr_type =
|
||||
self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default());
|
||||
actual_ptr_type.const_null().into()
|
||||
@ -1206,8 +1190,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(ty1, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 = primitive_type::NDArrayType::create(ty2, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
@ -1256,8 +1244,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
|
||||
Ok(Some(res.as_base_value().into()))
|
||||
} else {
|
||||
let (ndarray_dtype, _) =
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
|
||||
let ndarray_dtype = primitive_type::NDArrayType::create(
|
||||
if is_ndarray1 { ty1 } else { ty2 },
|
||||
&mut ctx.unifier,
|
||||
)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_val = NDArrayValue::from_ptr_val(
|
||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||
llvm_usize,
|
||||
@ -1443,7 +1435,9 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
let ndarray_dtype = primitive_type::NDArrayType::create(ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
|
||||
|
||||
@ -1527,8 +1521,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
return if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
||||
let ndarray_dtype1 = primitive_type::NDArrayType::create(left_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let ndarray_dtype2 =
|
||||
primitive_type::NDArrayType::create(right_ty, &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
|
||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
@ -1562,10 +1561,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
|
||||
Ok(Some(res.as_base_value().into()))
|
||||
} else {
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
||||
&mut ctx.unifier,
|
||||
let ndarray_dtype = primitive_type::NDArrayType::create(
|
||||
if is_ndarray1 { left_ty } else { right_ty },
|
||||
);
|
||||
&mut ctx.unifier,
|
||||
)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty;
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
@ -1788,9 +1789,13 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
|
||||
None,
|
||||
);
|
||||
let ndarray_ty =
|
||||
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let ndarray_ty = primitive_type::NDArrayType::from_primitive(
|
||||
&mut ctx.unifier,
|
||||
&ctx.primitives,
|
||||
Some(ty),
|
||||
Some(ndarray_ndims_ty),
|
||||
);
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty.into()).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||
|
||||
@ -2082,7 +2087,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
ExprKind::Name { id, .. } if id == &"none".into() => {
|
||||
match (
|
||||
ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(),
|
||||
ctx.unifier.get_ty(ctx.primitives.option).as_ref(),
|
||||
ctx.unifier.get_ty(ctx.primitives.option.into()).as_ref(),
|
||||
) {
|
||||
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
|
||||
if *obj_id == *opt_id =>
|
||||
@ -2464,8 +2469,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
};
|
||||
// directly generate code for option.unwrap
|
||||
// since it needs to return static value to optimize for kernel invariant
|
||||
if attr == &"unwrap".into()
|
||||
&& id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap()
|
||||
if attr == &"unwrap".into() && id == ctx.primitives.option.obj_id(&ctx.unifier)
|
||||
{
|
||||
match val {
|
||||
ValueEnum::Static(v) => {
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
|
||||
symbol_resolver::{StaticValue, SymbolResolver},
|
||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
|
||||
toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef},
|
||||
typecheck::{
|
||||
type_inferencer::{CodeLocation, PrimitiveStore},
|
||||
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
||||
@ -47,6 +47,9 @@ pub mod stmt;
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
use crate::toplevel::primitive_type;
|
||||
use crate::toplevel::primitive_type::OptionType;
|
||||
use crate::typecheck::typedef::GenericObjectType;
|
||||
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
||||
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
||||
|
||||
@ -457,7 +460,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
|
||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
|
||||
let dtype = primitive_type::NDArrayType::create(ty, unifier)
|
||||
.dtype_tvar(unifier)
|
||||
.ty;
|
||||
let element_type = get_llvm_type(
|
||||
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
||||
);
|
||||
@ -634,7 +639,10 @@ pub fn gen_func_impl<
|
||||
range: unifier.get_representative(primitives.range),
|
||||
str: unifier.get_representative(primitives.str),
|
||||
exception: unifier.get_representative(primitives.exception),
|
||||
option: unifier.get_representative(primitives.option),
|
||||
option: OptionType::create(
|
||||
unifier.get_representative(primitives.option.into()),
|
||||
&mut unifier,
|
||||
),
|
||||
..primitives
|
||||
};
|
||||
|
||||
|
@ -17,12 +17,8 @@ use crate::{
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{
|
||||
helper::PrimDef,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
DefinitionId,
|
||||
},
|
||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||
toplevel::{helper::PrimDef, primitive_type, DefinitionId},
|
||||
typecheck::typedef::{FunSignature, GenericObjectType, Type, TypeEnum},
|
||||
};
|
||||
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
||||
use inkwell::{
|
||||
@ -38,12 +34,17 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||
let ndarray_ty = primitive_type::NDArrayType::from_primitive(
|
||||
&mut ctx.unifier,
|
||||
&ctx.primitives,
|
||||
Some(elem_ty),
|
||||
None,
|
||||
);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_ndarray_t = ctx
|
||||
.get_llvm_type(generator, ndarray_ty)
|
||||
.get_llvm_type(generator, ndarray_ty.into())
|
||||
.into_pointer_type()
|
||||
.get_element_type()
|
||||
.into_struct_type();
|
||||
@ -1799,7 +1800,9 @@ pub fn gen_ndarray_array<'ctx>(
|
||||
let obj_ty = fun.0.args[0].ty;
|
||||
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0
|
||||
primitive_type::NDArrayType::create(obj_ty, &mut context.unifier)
|
||||
.dtype_tvar(&mut context.unifier)
|
||||
.ty
|
||||
}
|
||||
|
||||
TypeEnum::TList { ty } => {
|
||||
@ -1939,7 +1942,9 @@ pub fn gen_ndarray_copy<'ctx>(
|
||||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
|
||||
let this_ty = obj.as_ref().unwrap().0;
|
||||
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
|
||||
let this_elem_ty = primitive_type::NDArrayType::create(this_ty, &mut context.unifier)
|
||||
.dtype_tvar(&mut context.unifier)
|
||||
.ty;
|
||||
let this_arg =
|
||||
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
|
||||
|
||||
|
@ -4,13 +4,15 @@ use super::{
|
||||
irrt::{handle_slice_indices, list_slice_assignment},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::toplevel::primitive_type;
|
||||
use crate::typecheck::typedef::GenericObjectType;
|
||||
use crate::{
|
||||
codegen::{
|
||||
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
||||
expr::gen_binop_expr,
|
||||
gen_in_range_check,
|
||||
},
|
||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
||||
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
|
||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||
};
|
||||
use inkwell::{
|
||||
@ -245,7 +247,9 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
||||
TypeEnum::TList { ty } => *ty,
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
|
||||
primitive_type::NDArrayType::create(target.custom.unwrap(), &mut ctx.unifier)
|
||||
.dtype_tvar(&mut ctx.unifier)
|
||||
.ty
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
@ -3,6 +3,7 @@ use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::{collections::HashMap, collections::HashSet, fmt::Display};
|
||||
|
||||
use crate::typecheck::typedef::GenericObjectType;
|
||||
use crate::{
|
||||
codegen::{CodeGenContext, CodeGenerator},
|
||||
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
|
||||
@ -43,7 +44,7 @@ impl SymbolValue {
|
||||
) -> Result<Self, String> {
|
||||
match constant {
|
||||
Constant::None => {
|
||||
if unifier.unioned(expected_ty, primitives.option) {
|
||||
if unifier.unioned(expected_ty, primitives.option.into()) {
|
||||
Ok(SymbolValue::OptionNone)
|
||||
} else {
|
||||
Err(format!("Expected {expected_ty:?}, but got Option"))
|
||||
@ -157,7 +158,7 @@ impl SymbolValue {
|
||||
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
|
||||
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
|
||||
}
|
||||
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
|
||||
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option.into(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -183,13 +184,13 @@ impl SymbolValue {
|
||||
TypeAnnotation::Tuple(vs_tys)
|
||||
}
|
||||
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
|
||||
id: primitives.option.obj_id(unifier).unwrap(),
|
||||
id: primitives.option.obj_id(unifier),
|
||||
params: Vec::default(),
|
||||
},
|
||||
SymbolValue::OptionSome(v) => {
|
||||
let ty = v.get_type_annotation(primitives, unifier);
|
||||
TypeAnnotation::CustomClass {
|
||||
id: primitives.option.obj_id(unifier).unwrap(),
|
||||
id: primitives.option.obj_id(unifier),
|
||||
params: vec![ty],
|
||||
}
|
||||
}
|
||||
|
@ -24,8 +24,8 @@ use crate::{
|
||||
stmt::exn_constructor,
|
||||
},
|
||||
symbol_resolver::SymbolValue,
|
||||
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
|
||||
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
||||
toplevel::helper::PrimDef,
|
||||
typecheck::typedef::{into_var_map, TypeVar, VarMap},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@ -303,10 +303,7 @@ struct BuiltinBuilder<'a> {
|
||||
|
||||
is_some_ty: (Type, bool),
|
||||
unwrap_ty: (Type, bool),
|
||||
option_tvar: TypeVar,
|
||||
|
||||
ndarray_dtype_tvar: TypeVar,
|
||||
ndarray_ndims_tvar: TypeVar,
|
||||
ndarray_copy_ty: (Type, bool),
|
||||
ndarray_fill_ty: (Type, bool),
|
||||
|
||||
@ -315,9 +312,9 @@ struct BuiltinBuilder<'a> {
|
||||
num_ty: TypeVar,
|
||||
num_var_map: VarMap,
|
||||
|
||||
ndarray_float: Type,
|
||||
ndarray_float_2d: Type,
|
||||
ndarray_num_ty: Type,
|
||||
ndarray_float: primitive_type::NDArrayType,
|
||||
ndarray_float_2d: primitive_type::NDArrayType,
|
||||
ndarray_num_ty: primitive_type::NDArrayType,
|
||||
|
||||
float_or_ndarray_ty: TypeVar,
|
||||
float_or_ndarray_var_map: VarMap,
|
||||
@ -344,24 +341,19 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
} = *primitives;
|
||||
|
||||
// Option-related
|
||||
let (is_some_ty, unwrap_ty, option_tvar) =
|
||||
if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() {
|
||||
let (is_some_ty, unwrap_ty) =
|
||||
if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option.into()) {
|
||||
(
|
||||
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
|
||||
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
|
||||
iter_type_vars(params).next().unwrap(),
|
||||
)
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } =
|
||||
&*unifier.get_ty(ndarray)
|
||||
else {
|
||||
let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray.into()) else {
|
||||
unreachable!()
|
||||
};
|
||||
let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
|
||||
let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
|
||||
let ndarray_copy_ty =
|
||||
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
|
||||
let ndarray_fill_ty =
|
||||
@ -374,7 +366,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
);
|
||||
let num_var_map = into_var_map([num_ty]);
|
||||
|
||||
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None);
|
||||
let ndarray_float =
|
||||
primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(float), None);
|
||||
let ndarray_float_2d = {
|
||||
let value = match primitives.size_t {
|
||||
64 => SymbolValue::U64(2u64),
|
||||
@ -383,16 +376,28 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
};
|
||||
let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None });
|
||||
|
||||
make_ndarray_ty(unifier, primitives, Some(float), Some(ndims))
|
||||
primitive_type::NDArrayType::from_primitive(
|
||||
unifier,
|
||||
primitives,
|
||||
Some(float),
|
||||
Some(ndims),
|
||||
)
|
||||
};
|
||||
|
||||
let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.ty), None);
|
||||
let float_or_ndarray_ty =
|
||||
unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
||||
let ndarray_num_ty =
|
||||
primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(num_ty.ty), None);
|
||||
let float_or_ndarray_ty = unifier.get_fresh_var_with_range(
|
||||
&[float, ndarray_float.into()],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
let float_or_ndarray_var_map = into_var_map([float_or_ndarray_ty]);
|
||||
|
||||
let num_or_ndarray_ty =
|
||||
unifier.get_fresh_var_with_range(&[num_ty.ty, ndarray_num_ty], Some("T".into()), None);
|
||||
let num_or_ndarray_ty = unifier.get_fresh_var_with_range(
|
||||
&[num_ty.ty, ndarray_num_ty.into()],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
let num_or_ndarray_var_map = into_var_map([num_ty, num_or_ndarray_ty]);
|
||||
|
||||
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
|
||||
@ -405,10 +410,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
|
||||
is_some_ty,
|
||||
unwrap_ty,
|
||||
option_tvar,
|
||||
|
||||
ndarray_dtype_tvar,
|
||||
ndarray_ndims_tvar,
|
||||
ndarray_copy_ty,
|
||||
ndarray_fill_ty,
|
||||
|
||||
@ -632,7 +634,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
PrimDef::Option => TopLevelDef::Class {
|
||||
name: prim.name().into(),
|
||||
object_id: prim.id(),
|
||||
type_vars: vec![self.option_tvar.ty],
|
||||
type_vars: vec![self.primitives.option.type_tvar(self.unifier).ty],
|
||||
fields: Vec::default(),
|
||||
attributes: Vec::default(),
|
||||
methods: vec![
|
||||
@ -653,7 +655,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
name: prim.name().into(),
|
||||
simple_name: prim.simple_name().into(),
|
||||
signature: self.unwrap_ty.0,
|
||||
var_id: vec![self.option_tvar.id],
|
||||
var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
@ -667,7 +669,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
name: prim.name().to_string(),
|
||||
simple_name: prim.simple_name().into(),
|
||||
signature: self.is_some_ty.0,
|
||||
var_id: vec![self.option_tvar.id],
|
||||
var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
@ -698,36 +700,40 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
loc: None,
|
||||
},
|
||||
|
||||
PrimDef::FunSome => TopLevelDef::Function {
|
||||
name: prim.name().into(),
|
||||
simple_name: prim.simple_name().into(),
|
||||
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg {
|
||||
name: "n".into(),
|
||||
ty: self.option_tvar.ty,
|
||||
default_value: None,
|
||||
}],
|
||||
ret: self.primitives.option,
|
||||
vars: into_var_map([self.option_tvar]),
|
||||
})),
|
||||
var_id: vec![self.option_tvar.id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let alloca = generator
|
||||
.gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some"))
|
||||
.unwrap();
|
||||
ctx.builder.build_store(alloca, arg_val).unwrap();
|
||||
Ok(Some(alloca.into()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
},
|
||||
PrimDef::FunSome => {
|
||||
let option_tvar = self.primitives.option.type_tvar(self.unifier);
|
||||
|
||||
TopLevelDef::Function {
|
||||
name: prim.name().into(),
|
||||
simple_name: prim.simple_name().into(),
|
||||
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg {
|
||||
name: "n".into(),
|
||||
ty: option_tvar.ty,
|
||||
default_value: None,
|
||||
}],
|
||||
ret: self.primitives.option.into(),
|
||||
vars: into_var_map([option_tvar]),
|
||||
})),
|
||||
var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let alloca = generator
|
||||
.gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some"))
|
||||
.unwrap();
|
||||
ctx.builder.build_store(alloca, arg_val).unwrap();
|
||||
Ok(Some(alloca.into()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
unreachable!()
|
||||
@ -736,7 +742,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}
|
||||
|
||||
/// Build the class `ndarray` and its associated methods.
|
||||
fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef {
|
||||
fn build_ndarray_class_related(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(
|
||||
prim,
|
||||
&[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill],
|
||||
@ -746,7 +752,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
PrimDef::NDArray => TopLevelDef::Class {
|
||||
name: prim.name().into(),
|
||||
object_id: prim.id(),
|
||||
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
|
||||
type_vars: vec![
|
||||
self.primitives.ndarray.dtype_tvar(self.unifier).ty,
|
||||
self.primitives.ndarray.ndims_tvar(self.unifier).ty,
|
||||
],
|
||||
fields: Vec::default(),
|
||||
attributes: Vec::default(),
|
||||
methods: vec![
|
||||
@ -763,7 +772,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
name: prim.name().into(),
|
||||
simple_name: prim.simple_name().into(),
|
||||
signature: self.ndarray_copy_ty.0,
|
||||
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
|
||||
var_id: vec![
|
||||
self.primitives.ndarray.dtype_tvar(self.unifier).id,
|
||||
self.primitives.ndarray.ndims_tvar(self.unifier).id,
|
||||
],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
@ -780,7 +792,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
name: prim.name().into(),
|
||||
simple_name: prim.simple_name().into(),
|
||||
signature: self.ndarray_fill_ty.0,
|
||||
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id],
|
||||
var_id: vec![
|
||||
self.primitives.ndarray.dtype_tvar(self.unifier).id,
|
||||
self.primitives.ndarray.ndims_tvar(self.unifier).id,
|
||||
],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
@ -869,15 +884,26 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
// The size variant of the function determines the size of the returned int.
|
||||
let int_sized = size_variant.of_int(self.primitives);
|
||||
|
||||
let ndarray_int_sized =
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
|
||||
let ndarray_float =
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
|
||||
let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(int_sized),
|
||||
Some(common_ndim.ty),
|
||||
);
|
||||
let ndarray_float = primitive_type::NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(float),
|
||||
Some(common_ndim.ty),
|
||||
);
|
||||
|
||||
let p0_ty =
|
||||
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
||||
let p0_ty = self.unifier.get_fresh_var_with_range(
|
||||
&[float, ndarray_float.into()],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
let ret_ty = self.unifier.get_fresh_var_with_range(
|
||||
&[int_sized, ndarray_int_sized],
|
||||
&[int_sized, ndarray_int_sized.into()],
|
||||
Some("R".into()),
|
||||
None,
|
||||
);
|
||||
@ -929,19 +955,30 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
None,
|
||||
);
|
||||
|
||||
let ndarray_float =
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
|
||||
let ndarray_float = primitive_type::NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(float),
|
||||
Some(common_ndim.ty),
|
||||
);
|
||||
|
||||
// The size variant of the function determines the type of int returned
|
||||
let int_sized = size_variant.of_int(self.primitives);
|
||||
let ndarray_int_sized =
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
|
||||
let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(int_sized),
|
||||
Some(common_ndim.ty),
|
||||
);
|
||||
|
||||
let p0_ty =
|
||||
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
||||
let p0_ty = self.unifier.get_fresh_var_with_range(
|
||||
&[float, ndarray_float.into()],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
|
||||
let ret_ty = self.unifier.get_fresh_var_with_range(
|
||||
&[int_sized, ndarray_int_sized],
|
||||
&[int_sized, ndarray_int_sized.into()],
|
||||
Some("R".into()),
|
||||
None,
|
||||
);
|
||||
@ -1004,7 +1041,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.ndarray_float,
|
||||
self.ndarray_float.into(),
|
||||
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||
Box::new(move |ctx, obj, fun, args, generator| {
|
||||
let func = match prim {
|
||||
@ -1050,7 +1087,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
default_value: Some(SymbolValue::U32(0)),
|
||||
},
|
||||
],
|
||||
ret: ndarray,
|
||||
ret: ndarray.into(),
|
||||
vars: into_var_map([tv]),
|
||||
})),
|
||||
var_id: vec![tv.id],
|
||||
@ -1073,7 +1110,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
self.unifier,
|
||||
&into_var_map([tv]),
|
||||
prim.name(),
|
||||
self.primitives.ndarray,
|
||||
self.primitives.ndarray.into(),
|
||||
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||
// type variable
|
||||
&[(self.list_int32, "shape"), (tv.ty, "fill_value")],
|
||||
@ -1102,7 +1139,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
default_value: Some(SymbolValue::I32(0)),
|
||||
},
|
||||
],
|
||||
ret: self.ndarray_float_2d,
|
||||
ret: self.ndarray_float_2d.into(),
|
||||
vars: VarMap::default(),
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
@ -1122,7 +1159,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.ndarray_float_2d,
|
||||
self.ndarray_float_2d.into(),
|
||||
&[(int32, "n")],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_identity(ctx, &obj, fun, &args, generator)
|
||||
@ -1337,10 +1374,15 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let tvar = self.unifier.get_fresh_var(Some("L".into()), None);
|
||||
let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty });
|
||||
let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None);
|
||||
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.ty), Some(ndims.ty));
|
||||
let ndarray = primitive_type::NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(tvar.ty),
|
||||
Some(ndims.ty),
|
||||
);
|
||||
|
||||
let arg_ty = self.unifier.get_fresh_var_with_range(
|
||||
&[list, ndarray, self.primitives.range],
|
||||
&[list, ndarray.into(), self.primitives.range],
|
||||
Some("I".into()),
|
||||
None,
|
||||
);
|
||||
@ -1798,8 +1840,13 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}
|
||||
|
||||
fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar {
|
||||
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None);
|
||||
let ndarray = primitive_type::NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(scalar_ty),
|
||||
None,
|
||||
);
|
||||
|
||||
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None)
|
||||
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray.into()], Some("T".into()), None)
|
||||
}
|
||||
}
|
||||
|
@ -1,14 +1,13 @@
|
||||
use std::convert::TryInto;
|
||||
|
||||
use super::*;
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap};
|
||||
use crate::toplevel::primitive_type::{NDArrayType, OptionType};
|
||||
use crate::typecheck::typedef::{into_var_map, GenericObjectType, Mapping, TypeVarId, VarMap};
|
||||
use nac3parser::ast::{Constant, Location};
|
||||
use strum::IntoEnumIterator;
|
||||
use strum_macros::EnumIter;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// All primitive types and functions in nac3core.
|
||||
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
|
||||
pub enum PrimDef {
|
||||
@ -403,6 +402,7 @@ impl TopLevelComposer {
|
||||
.collect::<HashMap<_, _>>(),
|
||||
params: into_var_map([option_type_var]),
|
||||
});
|
||||
let option = OptionType::create(option, &mut unifier);
|
||||
|
||||
let size_t_ty = match size_t {
|
||||
32 => uint32,
|
||||
@ -436,8 +436,9 @@ impl TopLevelComposer {
|
||||
]),
|
||||
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||
});
|
||||
let ndarray = NDArrayType::create(ndarray, &mut unifier);
|
||||
|
||||
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap();
|
||||
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray.into()).unwrap();
|
||||
|
||||
let primitives = PrimitiveStore {
|
||||
int32,
|
||||
@ -747,7 +748,7 @@ impl TopLevelComposer {
|
||||
TypeAnnotation::CustomClass { id: e_id, params: e_param },
|
||||
) => {
|
||||
*f_id == *e_id
|
||||
&& *f_id == primitive.option.obj_id(unifier).unwrap()
|
||||
&& *f_id == primitive.option.obj_id(unifier)
|
||||
&& (f_param.is_empty()
|
||||
|| (f_param.len() == 1
|
||||
&& e_param.len() == 1
|
||||
@ -885,7 +886,7 @@ pub fn parse_parameter_default_value(
|
||||
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
|
||||
match &*unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
unpack_ndarray_var_tys(unifier, ty).0
|
||||
NDArrayType::create(ty, unifier).dtype_tvar(unifier).ty
|
||||
}
|
||||
|
||||
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
|
||||
@ -897,7 +898,7 @@ pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
|
||||
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
|
||||
match &*unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
|
||||
let ndims = NDArrayType::create(ty, unifier).ndims_tvar(unifier).ty;
|
||||
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
|
||||
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
|
||||
};
|
||||
|
@ -30,7 +30,7 @@ pub struct DefinitionId(pub usize);
|
||||
pub mod builtins;
|
||||
pub mod composer;
|
||||
pub mod helper;
|
||||
pub mod numpy;
|
||||
pub mod primitive_type;
|
||||
pub mod type_annotation;
|
||||
use composer::*;
|
||||
use type_annotation::*;
|
||||
|
@ -1,85 +0,0 @@
|
||||
use crate::{
|
||||
toplevel::helper::PrimDef,
|
||||
typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||
},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
/// Creates a `ndarray` [`Type`] with the given type arguments.
|
||||
///
|
||||
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
pub fn make_ndarray_ty(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
dtype: Option<Type>,
|
||||
ndims: Option<Type>,
|
||||
) -> Type {
|
||||
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
|
||||
}
|
||||
|
||||
/// Substitutes type variables in `ndarray`.
|
||||
///
|
||||
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
pub fn subst_ndarray_tvars(
|
||||
unifier: &mut Unifier,
|
||||
ndarray: Type,
|
||||
dtype: Option<Type>,
|
||||
ndims: Option<Type>,
|
||||
) -> Type {
|
||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||
};
|
||||
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
||||
|
||||
if dtype.is_none() && ndims.is_none() {
|
||||
return ndarray;
|
||||
}
|
||||
|
||||
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
|
||||
debug_assert_eq!(tvar_ids.len(), 2);
|
||||
|
||||
let mut tvar_subst = VarMap::new();
|
||||
if let Some(dtype) = dtype {
|
||||
tvar_subst.insert(tvar_ids[0], dtype);
|
||||
}
|
||||
if let Some(ndims) = ndims {
|
||||
tvar_subst.insert(tvar_ids[1], ndims);
|
||||
}
|
||||
|
||||
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
||||
}
|
||||
|
||||
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
|
||||
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||
};
|
||||
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
|
||||
debug_assert_eq!(params.len(), 2);
|
||||
|
||||
params
|
||||
.iter()
|
||||
.sorted_by_key(|(obj_id, _)| *obj_id)
|
||||
.map(|(var_id, ty)| (*var_id, *ty))
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
|
||||
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
|
||||
/// respectively.
|
||||
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
|
||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
|
||||
}
|
||||
|
||||
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
||||
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
||||
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
|
||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
|
||||
}
|
98
nac3core/src/toplevel/primitive_type.rs
Normal file
98
nac3core/src/toplevel/primitive_type.rs
Normal file
@ -0,0 +1,98 @@
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::typecheck::type_inferencer::PrimitiveStore;
|
||||
use crate::typecheck::typedef::{GenericObjectType, Type, TypeVar, Unifier, VarMap};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct OptionType(Type);
|
||||
|
||||
impl OptionType {
|
||||
pub fn from_primitive(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
type_ty: Option<Type>,
|
||||
) -> Self {
|
||||
primitives.option.subst(unifier, type_ty)
|
||||
}
|
||||
|
||||
pub fn type_tvar(&self, unifier: &mut Unifier) -> TypeVar {
|
||||
self.get_var_at(unifier, 0).unwrap()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn subst(&self, unifier: &mut Unifier, type_ty: Option<Type>) -> Self {
|
||||
let new_vars = [(self.type_tvar(unifier).id, type_ty)]
|
||||
.into_iter()
|
||||
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
|
||||
.collect::<VarMap>();
|
||||
|
||||
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
|
||||
OptionType(new_ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl GenericObjectType for OptionType {
|
||||
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
|
||||
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::Option.id()) {
|
||||
Some(OptionType(ty))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Type {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct NDArrayType(Type);
|
||||
|
||||
impl NDArrayType {
|
||||
pub fn from_primitive(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
dtype: Option<Type>,
|
||||
ndims: Option<Type>,
|
||||
) -> Self {
|
||||
primitives.ndarray.subst(unifier, dtype, ndims)
|
||||
}
|
||||
|
||||
pub fn dtype_tvar(&self, unifier: &mut Unifier) -> TypeVar {
|
||||
self.get_var_at(unifier, 0).unwrap()
|
||||
}
|
||||
|
||||
pub fn ndims_tvar(&self, unifier: &mut Unifier) -> TypeVar {
|
||||
self.get_var_at(unifier, 1).unwrap()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn subst(
|
||||
&self,
|
||||
unifier: &mut Unifier,
|
||||
dtype_ty: Option<Type>,
|
||||
ndims_ty: Option<Type>,
|
||||
) -> Self {
|
||||
let new_vars =
|
||||
[(self.dtype_tvar(unifier).id, dtype_ty), (self.ndims_tvar(unifier).id, ndims_ty)]
|
||||
.into_iter()
|
||||
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
|
||||
.collect::<VarMap>();
|
||||
|
||||
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
|
||||
NDArrayType(new_ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl GenericObjectType for NDArrayType {
|
||||
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
|
||||
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||
Some(NDArrayType(ty))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Type {
|
||||
self.0
|
||||
}
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::typecheck::typedef::VarMap;
|
||||
use crate::typecheck::typedef::{GenericObjectType, VarMap};
|
||||
use nac3parser::ast::Constant;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@ -267,12 +267,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
slice.as_ref(),
|
||||
locked,
|
||||
)?;
|
||||
let id =
|
||||
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
|
||||
*obj_id
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let id = primitives.option.obj_id(unifier);
|
||||
Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] })
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
|
||||
use crate::toplevel::primitive_type;
|
||||
use crate::typecheck::{
|
||||
type_inferencer::*,
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
|
||||
typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, Unifier, VarMap},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use nac3parser::ast::StrRef;
|
||||
@ -369,8 +369,12 @@ pub fn typeof_ndarray_broadcast(
|
||||
if is_left_ndarray && is_right_ndarray {
|
||||
// Perform broadcasting on two ndarray operands.
|
||||
|
||||
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
|
||||
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
|
||||
let left_ty = primitive_type::NDArrayType::create(left, unifier);
|
||||
let left_ty_dtype = left_ty.dtype_tvar(unifier).ty;
|
||||
let left_ty_ndims = left_ty.ndims_tvar(unifier).ty;
|
||||
let right_ty = primitive_type::NDArrayType::create(right, unifier);
|
||||
let right_ty_dtype = right_ty.dtype_tvar(unifier).ty;
|
||||
let right_ty_ndims = right_ty.ndims_tvar(unifier).ty;
|
||||
|
||||
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
|
||||
|
||||
@ -397,11 +401,18 @@ pub fn typeof_ndarray_broadcast(
|
||||
.collect_vec();
|
||||
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
||||
|
||||
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
|
||||
Ok(primitive_type::NDArrayType::from_primitive(
|
||||
unifier,
|
||||
primitives,
|
||||
Some(left_ty_dtype),
|
||||
Some(res_ndims),
|
||||
)
|
||||
.into())
|
||||
} else {
|
||||
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
|
||||
|
||||
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
|
||||
let ndarray_ty_dtype =
|
||||
primitive_type::NDArrayType::create(ndarray_ty, unifier).ndims_tvar(unifier).ty;
|
||||
|
||||
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
||||
Ok(ndarray_ty)
|
||||
@ -444,7 +455,8 @@ pub fn typeof_binop(
|
||||
}
|
||||
|
||||
Operator::MatMult => {
|
||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||
let lhs_ndims =
|
||||
primitive_type::NDArrayType::create(lhs, unifier).ndims_tvar(unifier).ty;
|
||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
assert_eq!(values.len(), 1);
|
||||
@ -452,7 +464,8 @@ pub fn typeof_binop(
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
|
||||
let rhs_ndims =
|
||||
primitive_type::NDArrayType::create(rhs, unifier).ndims_tvar(unifier).ty;
|
||||
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
assert_eq!(values.len(), 1);
|
||||
@ -526,7 +539,7 @@ pub fn typeof_unaryop(
|
||||
let operand_obj_id = operand.obj_id(unifier);
|
||||
|
||||
if op == Unaryop::Not
|
||||
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap())
|
||||
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier))
|
||||
{
|
||||
return Err(
|
||||
"The truth value of an array with more than one element is ambiguous".to_string()
|
||||
@ -552,7 +565,8 @@ pub fn typeof_unaryop(
|
||||
|
||||
Unaryop::UAdd | Unaryop::USub => {
|
||||
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
|
||||
let dtype =
|
||||
primitive_type::NDArrayType::create(operand, unifier).dtype_tvar(unifier).ty;
|
||||
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||
return Err(if op == Unaryop::UAdd {
|
||||
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
|
||||
@ -586,9 +600,15 @@ pub fn typeof_cmpop(
|
||||
|
||||
Ok(Some(if is_left_ndarray || is_right_ndarray {
|
||||
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
|
||||
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
|
||||
let ndims = primitive_type::NDArrayType::create(brd, unifier).ndims_tvar(unifier).ty;
|
||||
|
||||
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
|
||||
primitive_type::NDArrayType::from_primitive(
|
||||
unifier,
|
||||
primitives,
|
||||
Some(primitives.bool),
|
||||
Some(ndims),
|
||||
)
|
||||
.into()
|
||||
} else if unifier.unioned(lhs, rhs) {
|
||||
primitives.bool
|
||||
} else {
|
||||
@ -611,64 +631,108 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||
|
||||
/* int ======== */
|
||||
for t in [int32_t, int64_t, uint32_t, uint64_t] {
|
||||
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
|
||||
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
let ndarray_int_t =
|
||||
primitive_type::NDArrayType::from_primitive(unifier, store, Some(t), None);
|
||||
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||
impl_pow(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||
impl_bitwise_arithmetic(unifier, store, t);
|
||||
impl_bitwise_shift(unifier, store, t);
|
||||
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_div(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||
impl_floordiv(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||
impl_mod(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||
impl_invert(unifier, store, t, Some(t));
|
||||
impl_not(unifier, store, t, Some(bool_t));
|
||||
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_comparison(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||
impl_eq(unifier, store, t, &[t, ndarray_int_t.into()], None);
|
||||
}
|
||||
for t in [int32_t, int64_t] {
|
||||
impl_sign(unifier, store, t, Some(t));
|
||||
}
|
||||
|
||||
/* float ======== */
|
||||
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
|
||||
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
|
||||
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
|
||||
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
let ndarray_float_t =
|
||||
primitive_type::NDArrayType::from_primitive(unifier, store, Some(float_t), None);
|
||||
let ndarray_int32_t =
|
||||
primitive_type::NDArrayType::from_primitive(unifier, store, Some(int32_t), None);
|
||||
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||
impl_pow(
|
||||
unifier,
|
||||
store,
|
||||
float_t,
|
||||
&[int32_t, float_t, ndarray_int32_t.into(), ndarray_float_t.into()],
|
||||
None,
|
||||
);
|
||||
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||
impl_sign(unifier, store, float_t, Some(float_t));
|
||||
impl_not(unifier, store, float_t, Some(bool_t));
|
||||
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
|
||||
|
||||
/* bool ======== */
|
||||
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
|
||||
let ndarray_bool_t =
|
||||
primitive_type::NDArrayType::from_primitive(unifier, store, Some(bool_t), None);
|
||||
impl_invert(unifier, store, bool_t, Some(int32_t));
|
||||
impl_not(unifier, store, bool_t, Some(bool_t));
|
||||
impl_sign(unifier, store, bool_t, Some(int32_t));
|
||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t.into()], None);
|
||||
|
||||
/* ndarray ===== */
|
||||
let ndarray_usized_ndims_tvar =
|
||||
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||
let ndarray_unsized_t =
|
||||
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
|
||||
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
|
||||
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
|
||||
let ndarray_unsized_t = primitive_type::NDArrayType::from_primitive(
|
||||
unifier,
|
||||
store,
|
||||
None,
|
||||
Some(ndarray_usized_ndims_tvar.ty),
|
||||
);
|
||||
let ndarray_dtype_t = ndarray_t.dtype_tvar(unifier).ty;
|
||||
let ndarray_unsized_dtype_t = ndarray_unsized_t.dtype_tvar(unifier).ty;
|
||||
impl_basic_arithmetic(
|
||||
unifier,
|
||||
store,
|
||||
ndarray_t,
|
||||
&[ndarray_unsized_t, ndarray_unsized_dtype_t],
|
||||
ndarray_t.into(),
|
||||
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||
None,
|
||||
);
|
||||
impl_pow(
|
||||
unifier,
|
||||
store,
|
||||
ndarray_t.into(),
|
||||
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||
None,
|
||||
);
|
||||
impl_div(unifier, store, ndarray_t.into(), &[ndarray_t.into(), ndarray_dtype_t], None);
|
||||
impl_floordiv(
|
||||
unifier,
|
||||
store,
|
||||
ndarray_t.into(),
|
||||
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||
None,
|
||||
);
|
||||
impl_mod(
|
||||
unifier,
|
||||
store,
|
||||
ndarray_t.into(),
|
||||
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||
None,
|
||||
);
|
||||
impl_matmul(unifier, store, ndarray_t.into(), &[ndarray_t.into()], Some(ndarray_t.into()));
|
||||
impl_sign(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
|
||||
impl_invert(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
|
||||
impl_eq(
|
||||
unifier,
|
||||
store,
|
||||
ndarray_t.into(),
|
||||
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||
None,
|
||||
);
|
||||
impl_comparison(
|
||||
unifier,
|
||||
store,
|
||||
ndarray_t.into(),
|
||||
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
|
||||
None,
|
||||
);
|
||||
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
|
||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
}
|
||||
|
@ -4,14 +4,16 @@ use std::iter::once;
|
||||
use std::ops::Not;
|
||||
use std::{cell::RefCell, sync::Arc};
|
||||
|
||||
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
|
||||
use super::typedef::{
|
||||
Call, FunSignature, FuncArg, GenericObjectType, RecordField, Type, TypeEnum, Unifier, VarMap,
|
||||
};
|
||||
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
||||
use crate::toplevel::primitive_type::{NDArrayType, OptionType};
|
||||
use crate::toplevel::TopLevelDef;
|
||||
use crate::{
|
||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||
toplevel::{
|
||||
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
TopLevelContext,
|
||||
},
|
||||
};
|
||||
@ -49,8 +51,8 @@ pub struct PrimitiveStore {
|
||||
pub range: Type,
|
||||
pub str: Type,
|
||||
pub exception: Type,
|
||||
pub option: Type,
|
||||
pub ndarray: Type,
|
||||
pub option: OptionType,
|
||||
pub ndarray: NDArrayType,
|
||||
pub size_t: u32,
|
||||
}
|
||||
|
||||
@ -74,6 +76,34 @@ impl PrimitiveStore {
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over all primitive types in this store.
|
||||
fn iter(&self) -> impl Iterator<Item = Type> {
|
||||
self.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for &PrimitiveStore {
|
||||
type Item = Type;
|
||||
type IntoIter = <Vec<Type> as IntoIterator>::IntoIter;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
vec![
|
||||
self.int32,
|
||||
self.int64,
|
||||
self.uint32,
|
||||
self.uint64,
|
||||
self.float,
|
||||
self.bool,
|
||||
self.none,
|
||||
self.range,
|
||||
self.str,
|
||||
self.exception,
|
||||
self.option.into(),
|
||||
self.ndarray.into(),
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FunctionData {
|
||||
@ -500,7 +530,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||
// the name `none` is special since it may have different types
|
||||
if id == &"none".into() {
|
||||
if let TypeEnum::TObj { params, .. } =
|
||||
self.unifier.get_ty_immutable(self.primitives.option).as_ref()
|
||||
&*self.unifier.get_ty_immutable(self.primitives.option.into())
|
||||
{
|
||||
let var_map = params
|
||||
.iter()
|
||||
@ -515,7 +545,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||
(*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty)
|
||||
})
|
||||
.collect::<VarMap>();
|
||||
Some(self.unifier.subst(self.primitives.option, &var_map).unwrap())
|
||||
Some(self.unifier.subst(self.primitives.option.into(), &var_map).unwrap())
|
||||
} else {
|
||||
unreachable!("must be tobj")
|
||||
}
|
||||
@ -1034,9 +1064,16 @@ impl<'a> Inferencer<'a> {
|
||||
|
||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
{
|
||||
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||
let ndarray_ndims =
|
||||
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
|
||||
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
||||
NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(target_ty),
|
||||
Some(ndarray_ndims),
|
||||
)
|
||||
.into()
|
||||
} else {
|
||||
target_ty
|
||||
};
|
||||
@ -1072,9 +1109,7 @@ impl<'a> Inferencer<'a> {
|
||||
|
||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
{
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||
|
||||
ndarray_dtype
|
||||
NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
|
||||
} else {
|
||||
arg0_ty
|
||||
};
|
||||
@ -1126,14 +1161,14 @@ impl<'a> Inferencer<'a> {
|
||||
|
||||
let arg0_dtype =
|
||||
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||
unpack_ndarray_var_tys(self.unifier, arg0_ty).0
|
||||
NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
|
||||
} else {
|
||||
arg0_ty
|
||||
};
|
||||
|
||||
let arg1_dtype =
|
||||
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||
unpack_ndarray_var_tys(self.unifier, arg1_ty).0
|
||||
NDArrayType::create(arg1_ty, self.unifier).dtype_tvar(self.unifier).ty
|
||||
} else {
|
||||
arg1_ty
|
||||
};
|
||||
@ -1164,9 +1199,17 @@ impl<'a> Inferencer<'a> {
|
||||
// (float, int32), so convert it to align with the dtype of the first arg
|
||||
let arg1_ty = if id == &"np_ldexp".into() {
|
||||
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
|
||||
// let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
|
||||
let ndims =
|
||||
NDArrayType::create(arg1_ty, self.unifier).ndims_tvar(self.unifier).ty;
|
||||
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
|
||||
NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(target_ty),
|
||||
Some(ndims),
|
||||
)
|
||||
.into()
|
||||
} else {
|
||||
target_ty
|
||||
}
|
||||
@ -1253,9 +1296,16 @@ impl<'a> Inferencer<'a> {
|
||||
|
||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
{
|
||||
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||
let ndarray_ndims =
|
||||
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
|
||||
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
||||
NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(target_ty),
|
||||
Some(ndarray_ndims),
|
||||
)
|
||||
.into()
|
||||
} else {
|
||||
target_ty
|
||||
};
|
||||
@ -1295,7 +1345,7 @@ impl<'a> Inferencer<'a> {
|
||||
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
|
||||
|
||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||
let ret = make_ndarray_ty(
|
||||
let ret = NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(self.primitives.float),
|
||||
@ -1307,13 +1357,13 @@ impl<'a> Inferencer<'a> {
|
||||
ty: shape.custom.unwrap(),
|
||||
default_value: None,
|
||||
}],
|
||||
ret,
|
||||
ret: ret.into(),
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
custom: Some(ret.into()),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
@ -1346,7 +1396,8 @@ impl<'a> Inferencer<'a> {
|
||||
|
||||
let ty = arg1.custom.unwrap();
|
||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
||||
let ret =
|
||||
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||
@ -1356,13 +1407,13 @@ impl<'a> Inferencer<'a> {
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret,
|
||||
ret: ret.into(),
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
custom: Some(ret.into()),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
@ -1400,7 +1451,8 @@ impl<'a> Inferencer<'a> {
|
||||
arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
|
||||
};
|
||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
|
||||
let ret =
|
||||
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
|
||||
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
@ -1420,13 +1472,13 @@ impl<'a> Inferencer<'a> {
|
||||
default_value: Some(SymbolValue::U32(0)),
|
||||
},
|
||||
],
|
||||
ret,
|
||||
ret: ret.into(),
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
custom: Some(ret.into()),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
@ -1775,9 +1827,13 @@ impl<'a> Inferencer<'a> {
|
||||
TypeEnum::TVar { is_const_generic: false, .. }
|
||||
));
|
||||
|
||||
let constrained_ty =
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims));
|
||||
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
|
||||
let constrained_ty = NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(dummy_tvar),
|
||||
Some(ndims),
|
||||
);
|
||||
self.constrain(value.custom.unwrap(), constrained_ty.into(), &value.location)?;
|
||||
|
||||
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
|
||||
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
|
||||
@ -1843,10 +1899,14 @@ impl<'a> Inferencer<'a> {
|
||||
let ndims_ty = self
|
||||
.unifier
|
||||
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
|
||||
let subscripted_ty =
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty));
|
||||
let subscripted_ty = NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(dummy_tvar),
|
||||
Some(ndims_ty),
|
||||
);
|
||||
|
||||
Ok(subscripted_ty)
|
||||
Ok(subscripted_ty.into())
|
||||
}
|
||||
}
|
||||
|
||||
@ -1865,10 +1925,17 @@ impl<'a> Inferencer<'a> {
|
||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (_, ndims) =
|
||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
|
||||
.ndims_tvar(self.unifier)
|
||||
.ty;
|
||||
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
|
||||
NDArrayType::from_primitive(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
Some(ty),
|
||||
Some(ndims),
|
||||
)
|
||||
.into()
|
||||
}
|
||||
|
||||
_ => unreachable!(),
|
||||
@ -1879,8 +1946,10 @@ impl<'a> Inferencer<'a> {
|
||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (_, ndims) =
|
||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
|
||||
.ndims_tvar(self.unifier)
|
||||
.ty;
|
||||
|
||||
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
||||
}
|
||||
_ => {
|
||||
@ -1923,7 +1992,10 @@ impl<'a> Inferencer<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
|
||||
.ndims_tvar(self.unifier)
|
||||
.ty;
|
||||
|
||||
self.infer_subscript_ndarray(value, slice, ty, ndims)
|
||||
}
|
||||
_ => {
|
||||
@ -1947,8 +2019,9 @@ impl<'a> Inferencer<'a> {
|
||||
Ok(ty)
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (_, ndims) =
|
||||
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
|
||||
let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
|
||||
.ndims_tvar(self.unifier)
|
||||
.ty;
|
||||
|
||||
let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
|
||||
.into_iter()
|
||||
|
@ -139,6 +139,7 @@ impl TestEnvironment {
|
||||
fields: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let option = OptionType::create(option, &mut unifier);
|
||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||
let ndarray_ndims_tvar =
|
||||
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
|
||||
@ -147,6 +148,7 @@ impl TestEnvironment {
|
||||
fields: HashMap::new(),
|
||||
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||
});
|
||||
let ndarray = NDArrayType::create(ndarray, &mut unifier);
|
||||
let primitives = PrimitiveStore {
|
||||
int32,
|
||||
int64,
|
||||
@ -273,11 +275,13 @@ impl TestEnvironment {
|
||||
fields: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let option = OptionType::create(option, &mut unifier);
|
||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PrimDef::NDArray.id(),
|
||||
fields: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let ndarray = NDArrayType::create(ndarray, &mut unifier);
|
||||
identifier_mapping.insert("None".into(), none);
|
||||
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
|
||||
.iter()
|
||||
|
@ -22,6 +22,40 @@ mod test;
|
||||
/// Handle for a type, implemented as a key in the unification table.
|
||||
pub type Type = UnificationKey;
|
||||
|
||||
/// Macro for generating functions related to type traits, e.g. whether the type is integral.
|
||||
macro_rules! primitive_type_trait_fn {
|
||||
($id:ident, $( $matches:ident ),*) => {
|
||||
#[must_use]
|
||||
pub fn $id(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool {
|
||||
[$(store.$matches,)*].into_iter().any(|ty| unifier.unioned(self, ty))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl Type {
|
||||
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
|
||||
/// just to get the field `obj_id`.
|
||||
#[must_use]
|
||||
pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> {
|
||||
if let TypeEnum::TObj { obj_id, .. } = &*unifier.get_ty_immutable(self) {
|
||||
Some(*obj_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_primitive(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool {
|
||||
store.into_iter().any(|ty| unifier.unioned(self, ty))
|
||||
}
|
||||
|
||||
primitive_type_trait_fn!(is_integral, bool, int32, int64, uint32, uint64);
|
||||
primitive_type_trait_fn!(is_floating_point, float);
|
||||
primitive_type_trait_fn!(is_arithmetic, int32, int64, uint32, uint64, float);
|
||||
primitive_type_trait_fn!(is_signed, int32, uint32, float);
|
||||
primitive_type_trait_fn!(is_unsigned, uint32, uint64);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub struct CallId(pub(super) usize);
|
||||
|
||||
@ -55,6 +89,24 @@ pub struct TypeVar {
|
||||
pub ty: Type,
|
||||
}
|
||||
|
||||
impl From<(TypeVarId, Type)> for TypeVar {
|
||||
fn from((id, ty): (TypeVarId, Type)) -> Self {
|
||||
TypeVar { id, ty }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(&TypeVarId, &Type)> for TypeVar {
|
||||
fn from((id, ty): (&TypeVarId, &Type)) -> Self {
|
||||
TypeVar { id: *id, ty: *ty }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TypeVar> for (TypeVarId, Type) {
|
||||
fn from(value: TypeVar) -> Self {
|
||||
(value.id, value.ty)
|
||||
}
|
||||
}
|
||||
|
||||
/// The mapping between [`TypeVarId`] and [unifier type][`Type`].
|
||||
pub type VarMap = IndexMapping<TypeVarId>;
|
||||
|
||||
@ -68,9 +120,84 @@ where
|
||||
vars.into_iter().map(|var| (var.id, var.ty)).collect()
|
||||
}
|
||||
|
||||
/// Get an iterator of [`TypeVar`]s from a [`VarMap`]
|
||||
pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ {
|
||||
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
|
||||
/// A trait representing a possibly generic object type.
|
||||
pub trait GenericObjectType
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self>;
|
||||
|
||||
/// Creates an instance from a [`Type`].
|
||||
#[must_use]
|
||||
fn create(ty: Type, unifier: &mut Unifier) -> Self {
|
||||
Self::try_create(ty, unifier).unwrap()
|
||||
}
|
||||
|
||||
/// Returns the [`Type`] underlying this instance.
|
||||
#[must_use]
|
||||
fn get_type(&self) -> Type;
|
||||
|
||||
/// Similar to [`Type::obj_id`], except that the [`DefinitionId`] is not wrapped within an
|
||||
/// [`Option`].
|
||||
#[must_use]
|
||||
fn obj_id(&self, unifier: &Unifier) -> DefinitionId {
|
||||
self.get_type().obj_id(unifier).unwrap()
|
||||
}
|
||||
|
||||
/// Returns a copy of the [`VarMap`] of this object type.
|
||||
#[must_use]
|
||||
fn var_map(&self, unifier: &mut Unifier) -> VarMap {
|
||||
let TypeEnum::TObj { params, .. } = &*unifier.get_ty(self.get_type()) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
params.clone()
|
||||
}
|
||||
|
||||
/// Creates an iterator over the [`VarMap`] of this object type, applying `iter_fn` on the
|
||||
/// created [`Iterator`].
|
||||
#[must_use]
|
||||
fn iter_var_map<R, IterFn: FnOnce(&mut dyn Iterator<Item = TypeVar>, &mut Unifier) -> R>(
|
||||
&self,
|
||||
unifier: &mut Unifier,
|
||||
iter_fn: IterFn,
|
||||
) -> R {
|
||||
let TypeEnum::TObj { params, .. } = &*unifier.get_ty(self.get_type()) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let res = iter_fn(&mut params.iter().map(TypeVar::from), unifier);
|
||||
res
|
||||
}
|
||||
|
||||
/// Returns the [`TypeVar`] instance at the given index.
|
||||
#[must_use]
|
||||
fn get_var_at(&self, unifier: &mut Unifier, i: usize) -> Option<TypeVar> {
|
||||
self.iter_var_map(unifier, |iter, _| iter.nth(i))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: GenericObjectType> From<T> for Type {
|
||||
fn from(value: T) -> Self {
|
||||
value.get_type()
|
||||
}
|
||||
}
|
||||
|
||||
/// An adapter that converts [`Type`] into
|
||||
pub struct GenericTypeAdapter(Type);
|
||||
|
||||
impl GenericObjectType for GenericTypeAdapter {
|
||||
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
|
||||
if let TypeEnum::TObj { .. } = &*unifier.get_ty_immutable(ty) {
|
||||
Some(GenericTypeAdapter(ty))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Type {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -109,19 +236,6 @@ pub enum RecordKey {
|
||||
Int(i32),
|
||||
}
|
||||
|
||||
impl Type {
|
||||
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
|
||||
/// just to get the field `obj_id`.
|
||||
#[must_use]
|
||||
pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> {
|
||||
if let TypeEnum::TObj { obj_id, .. } = &*unifier.get_ty_immutable(self) {
|
||||
Some(*obj_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RecordKey> for StrRef {
|
||||
fn from(r: &RecordKey) -> Self {
|
||||
match r {
|
||||
|
Loading…
Reference in New Issue
Block a user