1
0
forked from M-Labs/nac3

Compare commits

...

5 Commits

20 changed files with 984 additions and 599 deletions

50
Cargo.lock generated
View File

@ -105,9 +105,9 @@ checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "2.5.0" version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]] [[package]]
name = "byteorder" name = "byteorder"
@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.99" version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96c51067fd44124faa7f870b4b1c969379ad32b2ba805aa959430ceaa384f695" checksum = "c891175c3fb232128f48de6590095e59198bbeb8620c310be349bfc3afd12c7b"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
@ -158,7 +158,7 @@ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.66", "syn 2.0.68",
] ]
[[package]] [[package]]
@ -421,7 +421,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.66", "syn 2.0.68",
] ]
[[package]] [[package]]
@ -501,9 +501,9 @@ dependencies = [
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.4.0" version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]] [[package]]
name = "libc" name = "libc"
@ -513,9 +513,9 @@ checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]] [[package]]
name = "libloading" name = "libloading"
version = "0.8.3" version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"windows-targets", "windows-targets",
@ -749,7 +749,7 @@ dependencies = [
"phf_shared 0.11.2", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.66", "syn 2.0.68",
] ]
[[package]] [[package]]
@ -796,9 +796,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.85" version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@ -850,7 +850,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.66", "syn 2.0.68",
] ]
[[package]] [[package]]
@ -863,7 +863,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.66", "syn 2.0.68",
] ]
[[package]] [[package]]
@ -1044,14 +1044,14 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.66", "syn 2.0.68",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.117" version = "1.0.118"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4"
dependencies = [ dependencies = [
"itoa", "itoa",
"ryu", "ryu",
@ -1120,9 +1120,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]] [[package]]
name = "strum" name = "strum"
version = "0.26.2" version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
[[package]] [[package]]
name = "strum_macros" name = "strum_macros"
@ -1134,7 +1134,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.66", "syn 2.0.68",
] ]
[[package]] [[package]]
@ -1150,9 +1150,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.66" version = "2.0.68"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1218,7 +1218,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.66", "syn 2.0.68",
] ]
[[package]] [[package]]
@ -1486,5 +1486,5 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.66", "syn 2.0.68",
] ]

View File

@ -7,7 +7,7 @@ use nac3core::{
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, DefinitionId, GenCall}, toplevel::{helper::PrimDef, DefinitionId, GenCall},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap}, typecheck::typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, VarMap},
}; };
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
@ -23,7 +23,7 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use nac3core::toplevel::numpy::unpack_ndarray_var_tys; use nac3core::toplevel::primitive_type;
use std::{ use std::{
collections::hash_map::DefaultHasher, collections::hash_map::DefaultHasher,
collections::HashMap, collections::HashMap,
@ -399,7 +399,9 @@ fn gen_rpc_tag(
gen_rpc_tag(ctx, *ty, buffer)?; gen_rpc_tag(ctx, *ty, buffer)?;
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray_ty = primitive_type::NDArrayType::create(ty, &mut ctx.unifier);
let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
let ndarray_ndims = if let TLiteral { values, .. } = let ndarray_ndims = if let TLiteral { values, .. } =
&*ctx.unifier.get_ty_immutable(ndarray_ndims) &*ctx.unifier.get_ty_immutable(ndarray_ndims)
{ {
@ -645,7 +647,7 @@ pub fn attributes_writeback(
let ty = ty.unwrap(); let ty = ty.unwrap();
match &*ctx.unifier.get_ty(ty) { match &*ctx.unifier.get_ty(ty) {
TypeEnum::TObj { fields, obj_id, .. } TypeEnum::TObj { fields, obj_id, .. }
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier) =>
{ {
// we only care about primitive attributes // we only care about primitive attributes
// for non-primitive attributes, they should be in another global // for non-primitive attributes, they should be in another global

View File

@ -4,20 +4,17 @@ use inkwell::{
AddressSpace, AddressSpace,
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3core::typecheck::typedef::{GenericObjectType, GenericTypeAdapter};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
classes::{NDArrayType, ProxyType}, classes::{NDArrayType, ProxyType},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{ toplevel::{helper::PrimDef, primitive_type, DefinitionId, TopLevelDef},
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap}, typedef::{Type, TypeEnum, TypeVar, Unifier, VarMap},
}, },
}; };
use nac3parser::ast::{self, StrRef}; use nac3parser::ast::{self, StrRef};
@ -336,13 +333,18 @@ impl InnerResolver {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
let var = unifier.get_dummy_var().ty; let var = unifier.get_dummy_var().ty;
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty; let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).ty;
let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims)); let ndarray = primitive_type::NDArrayType::from_primitive(
Ok(Ok((ndarray, false))) unifier,
primitives,
Some(var),
Some(ndims),
);
Ok(Ok((ndarray.into(), false)))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if ty_id == self.primitive_ids.option { } else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false))) Ok(Ok((primitives.option.into(), false)))
} else if ty_id == self.primitive_ids.none { } else if ty_id == self.primitive_ids.none {
unreachable!("none cannot be typeid") unreachable!("none cannot be typeid")
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() { } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() {
@ -509,7 +511,16 @@ impl InnerResolver {
)); ));
} }
Ok(Ok((make_ndarray_ty(unifier, primitives, Some(ty.0), None), true))) Ok(Ok((
primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(ty.0),
None,
)
.into(),
true,
)))
} }
TypeEnum::TTuple { .. } => { TypeEnum::TTuple { .. } => {
let args = match args let args = match args
@ -718,7 +729,9 @@ impl InnerResolver {
} }
} }
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => { (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); let ndarray = primitive_type::NDArrayType::create(extracted_ty, unifier);
let ty = ndarray.dtype_tvar(unifier).ty;
let ndims = ndarray.ndims_tvar(unifier).ty;
let len: usize = obj.getattr("ndim")?.extract()?; let len: usize = obj.getattr("ndim")?.extract()?;
if len == 0 { if len == 0 {
assert!(matches!( assert!(matches!(
@ -733,10 +746,14 @@ impl InnerResolver {
match dtype_ty { match dtype_ty {
Ok((t, _)) => match unifier.unify(ty, t) { Ok((t, _)) => match unifier.unify(ty, t) {
Ok(()) => { Ok(()) => {
let ndarray_ty = let ndarray_ty = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims)); unifier,
primitives,
Some(ty),
Some(ndims),
);
Ok(Ok(ndarray_ty)) Ok(Ok(ndarray_ty.into()))
} }
Err(e) => Ok(Err(format!( Err(e) => Ok(Err(format!(
"type error ({}) for the ndarray", "type error ({}) for the ndarray",
@ -759,7 +776,7 @@ impl InnerResolver {
// special handling for option type since its class member layout in python side // special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below // is special and cannot be mapped directly to a nac3 type as below
(TypeEnum::TObj { obj_id, params, .. }, false) (TypeEnum::TObj { obj_id, params, .. }, false)
if *obj_id == primitives.option.obj_id(unifier).unwrap() => if *obj_id == primitives.option.obj_id(unifier) =>
{ {
let Ok(field_data) = obj.getattr("_nac3_option") else { let Ok(field_data) = obj.getattr("_nac3_option") else {
unreachable!("cannot be None") unreachable!("cannot be None")
@ -767,13 +784,12 @@ impl InnerResolver {
// if is `none` // if is `none`
let zelf_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let zelf_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
if zelf_id == self.primitive_ids.none { if zelf_id == self.primitive_ids.none {
let ty_enum = unifier.get_ty_immutable(primitives.option); let extracted_ty = GenericTypeAdapter::create(extracted_ty, unifier);
let TypeEnum::TObj { params, .. } = ty_enum.as_ref() else { let var_map = extracted_ty.iter_var_map(unifier, |tvar_iter, unifier| {
unreachable!("must be tobj") tvar_iter
}; .map(|tvar| {
let TypeEnum::TVar { id, range, name, loc, .. } =
let var_map = into_var_map(iter_type_vars(params).map(|tvar| { &*unifier.get_ty(tvar.ty)
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty)
else { else {
unreachable!() unreachable!()
}; };
@ -781,8 +797,11 @@ impl InnerResolver {
assert_eq!(*id, tvar.id); assert_eq!(*id, tvar.id);
let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty; let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty;
TypeVar { id: *id, ty } TypeVar { id: *id, ty }
})); })
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())); .map(TypeVar::into)
.collect::<VarMap>()
});
return Ok(Ok(unifier.subst(primitives.option.into(), &var_map).unwrap()));
} }
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
@ -797,10 +816,14 @@ impl InnerResolver {
let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty); let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty);
Ok(Ok(res)) Ok(Ok(res))
} }
(TypeEnum::TObj { params, fields, .. }, false) => { (TypeEnum::TObj { fields, .. }, false) => {
self.pyid_to_type.write().insert(py_obj_id, extracted_ty); self.pyid_to_type.write().insert(py_obj_id, extracted_ty);
let var_map = into_var_map(iter_type_vars(params).map(|tvar| { let extracted_ty = GenericTypeAdapter::create(extracted_ty, unifier);
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty) let var_map = extracted_ty.iter_var_map(unifier, |tvar_iter, unifier| {
tvar_iter
.map(|tvar| {
let TypeEnum::TVar { id, range, name, loc, .. } =
&*unifier.get_ty(tvar.ty)
else { else {
unreachable!() unreachable!()
}; };
@ -808,8 +831,11 @@ impl InnerResolver {
assert_eq!(*id, tvar.id); assert_eq!(*id, tvar.id);
let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty; let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty;
TypeVar { id: *id, ty } TypeVar { id: *id, ty }
})); })
let mut instantiate_obj = || { .map(TypeVar::into)
.collect::<VarMap>()
});
let instantiate_obj = || {
// loop through non-function fields of the class to get the instantiated value // loop through non-function fields of the class to get the instantiated value
for field in fields { for field in fields {
let name: String = (*field.0).into(); let name: String = (*field.0).into();
@ -844,6 +870,7 @@ impl InnerResolver {
return Ok(Err("object is not of concrete type".into())); return Ok(Err("object is not of concrete type".into()));
} }
} }
let extracted_ty = extracted_ty.into();
let extracted_ty = let extracted_ty =
unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty); unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty);
Ok(Ok(extracted_ty)) Ok(Ok(extracted_ty))
@ -1027,8 +1054,9 @@ impl InnerResolver {
} else { } else {
unreachable!("must be ndarray") unreachable!("must be ndarray")
}; };
let (ndarray_dtype, ndarray_ndims) = let ndarray_ty = primitive_type::NDArrayType::create(ndarray_ty, &mut ctx.unifier);
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); let ndarray_dtype = ndarray_ty.dtype_tvar(&mut ctx.unifier).ty;
let ndarray_ndims = ndarray_ty.ndims_tvar(&mut ctx.unifier).ty;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
@ -1175,7 +1203,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.option { } else if ty_id == self.primitive_ids.option {
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() { let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. } TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier) =>
{ {
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} }

View File

@ -10,7 +10,7 @@ constant-optimization = ["fold"]
fold = [] fold = []
[dependencies] [dependencies]
lazy_static = "1.4" lazy_static = "1.5"
parking_lot = "0.12" parking_lot = "0.12"
string-interner = "0.17" string-interner = "0.17"
fxhash = "0.2" fxhash = "0.2"

View File

@ -8,8 +8,8 @@ use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::primitive_type;
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::{GenericObjectType, Type};
/// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// Shorthand for [`unreachable!()`] when a type of argument is not supported.
/// ///
@ -66,7 +66,9 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -101,7 +103,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
.iter() .iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty))); .any(|ty| ctx.unifier.unioned(n_ty, *ty)));
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap()
} else { } else {
ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap()
@ -128,7 +130,9 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -206,7 +210,9 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -241,7 +247,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
.iter() .iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty))); .any(|ty| ctx.unifier.unioned(n_ty, *ty)));
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) { if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap() ctx.builder.build_int_s_extend(n, llvm_i64, "sext").map(Into::into).unwrap()
} else { } else {
ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap() ctx.builder.build_int_z_extend(n, llvm_i64, "zext").map(Into::into).unwrap()
@ -273,7 +279,9 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -304,20 +312,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => {
debug_assert!([ debug_assert!(n_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
if [ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.int64] if n_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty))
{
ctx.builder ctx.builder
.build_signed_int_to_float(n, llvm_f64, "sitofp") .build_signed_int_to_float(n, llvm_f64, "sitofp")
.map(Into::into) .map(Into::into)
@ -331,7 +328,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
} }
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
n.into() n.into()
} }
@ -339,7 +336,9 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -373,7 +372,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
let val = llvm_intrinsics::call_float_round(ctx, n, None); let val = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder ctx.builder
@ -385,7 +384,9 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -417,7 +418,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_roundeven(ctx, n, None).into() llvm_intrinsics::call_float_roundeven(ctx, n, None).into()
} }
@ -425,7 +426,9 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -463,14 +466,10 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
} }
BasicValueEnum::IntValue(n) => { BasicValueEnum::IntValue(n) => {
debug_assert!([ debug_assert!(
ctx.primitives.int32, n_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.uint32, && n_ty.is_arithmetic(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.int64, );
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(n_ty, *ty)));
ctx.builder ctx.builder
.build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME) .build_int_compare(IntPredicate::NE, n, n.get_type().const_zero(), FN_NAME)
@ -479,7 +478,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
} }
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
ctx.builder ctx.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME) .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), FN_NAME)
@ -490,7 +489,9 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -528,7 +529,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
let val = llvm_intrinsics::call_float_floor(ctx, n, None); let val = llvm_intrinsics::call_float_floor(ctx, n, None);
if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty {
@ -544,7 +545,9 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -578,7 +581,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
Ok(match n { Ok(match n {
BasicValueEnum::FloatValue(n) => { BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); debug_assert!(n_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
let val = llvm_intrinsics::call_float_ceil(ctx, n, None); let val = llvm_intrinsics::call_float_ceil(ctx, n, None);
if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty { if let BasicTypeEnum::IntType(llvm_ret_elem_ty) = llvm_ret_elem_ty {
@ -594,7 +597,9 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let elem_ty = primitive_type::NDArrayType::create(n_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
generator, generator,
@ -631,20 +636,9 @@ pub fn call_min<'ctx>(
match (m, n) { match (m, n) {
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
debug_assert!([ debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty)));
if [ctx.primitives.int32, ctx.primitives.int64] if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty))
{
llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_int_smin(ctx, m, n, Some(FN_NAME)).into()
} else { } else {
llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_int_umin(ctx, m, n, Some(FN_NAME)).into()
@ -652,7 +646,7 @@ pub fn call_min<'ctx>(
} }
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_float_minnum(ctx, m, n, Some(FN_NAME)).into()
} }
@ -675,16 +669,10 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
Ok(match a { Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([ debug_assert!(
ctx.primitives.bool, a_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.int32, || a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.uint32, );
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a a
} }
@ -692,7 +680,9 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -761,22 +751,13 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([ debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
} }
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives));
call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
} }
@ -792,16 +773,24 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -847,20 +836,9 @@ pub fn call_max<'ctx>(
match (m, n) { match (m, n) {
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => {
debug_assert!([ debug_assert!(common_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty)));
if [ctx.primitives.int32, ctx.primitives.int64] if common_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
.iter()
.any(|ty| ctx.unifier.unioned(common_ty, *ty))
{
llvm_intrinsics::call_int_smax(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_int_smax(ctx, m, n, Some(FN_NAME)).into()
} else { } else {
llvm_intrinsics::call_int_umax(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_int_umax(ctx, m, n, Some(FN_NAME)).into()
@ -868,7 +846,7 @@ pub fn call_max<'ctx>(
} }
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => { (BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => {
debug_assert!(ctx.unifier.unioned(common_ty, ctx.primitives.float)); debug_assert!(common_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into() llvm_intrinsics::call_float_maxnum(ctx, m, n, Some(FN_NAME)).into()
} }
@ -891,16 +869,10 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
Ok(match a { Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([ debug_assert!(
ctx.primitives.bool, a_ty.is_integral(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.int32, || a_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives)
ctx.primitives.uint32, );
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a a
} }
@ -908,7 +880,9 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let elem_ty = primitive_type::NDArrayType::create(a_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
@ -977,22 +951,13 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => { (BasicValueEnum::IntValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!([ debug_assert!(common_ty.unwrap().is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(common_ty.unwrap(), *ty)));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
} }
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(common_ty.unwrap(), ctx.primitives.float)); debug_assert!(common_ty.unwrap().is_floating_point(&mut ctx.unifier, &ctx.primitives));
call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into()))
} }
@ -1008,16 +973,24 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1075,7 +1048,9 @@ where
if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let arg_elem_ty = primitive_type::NDArrayType::create(arg_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty);
let ndarray = ndarray_elementwise_unaryop_impl( let ndarray = ndarray_elementwise_unaryop_impl(
@ -1117,22 +1092,11 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
n, n,
FN_NAME, FN_NAME,
&|_ctx, elem_ty| elem_ty, &|_ctx, elem_ty| elem_ty,
&|_generator, ctx, val_ty, val| match val { &|_, ctx, val_ty, val| match val {
BasicValueEnum::IntValue(n) => Some({ BasicValueEnum::IntValue(n) => Some({
debug_assert!([ debug_assert!(val_ty.is_integral(&mut ctx.unifier, &ctx.primitives));
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
]
.iter()
.any(|ty| ctx.unifier.unioned(val_ty, *ty)));
if [ctx.primitives.int32, ctx.primitives.int64] if val_ty.is_signed(&mut ctx.unifier, &ctx.primitives) {
.iter()
.any(|ty| ctx.unifier.unioned(val_ty, *ty))
{
llvm_intrinsics::call_int_abs( llvm_intrinsics::call_int_abs(
ctx, ctx,
n, n,
@ -1146,7 +1110,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
}), }),
BasicValueEnum::FloatValue(n) => Some({ BasicValueEnum::FloatValue(n) => Some({
debug_assert!(ctx.unifier.unioned(val_ty, ctx.primitives.float)); debug_assert!(val_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into() llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into()
}), }),
@ -1431,8 +1395,8 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
extern_fns::call_atan2(ctx, x1, x2, None).into() extern_fns::call_atan2(ctx, x1, x2, None).into()
} }
@ -1448,16 +1412,24 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1498,8 +1470,8 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into()
} }
@ -1515,16 +1487,24 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1565,8 +1545,8 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into()
} }
@ -1582,16 +1562,24 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1632,8 +1620,8 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into()
} }
@ -1649,16 +1637,24 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1699,7 +1695,7 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32));
extern_fns::call_ldexp(ctx, x1, x2, None).into() extern_fns::call_ldexp(ctx, x1, x2, None).into()
@ -1715,12 +1711,22 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
let is_ndarray2 = let is_ndarray2 =
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = let dtype = if is_ndarray1 {
if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
x1_ty
};
let x1_scalar_ty = dtype; let x1_scalar_ty = dtype;
let x2_scalar_ty = let x2_scalar_ty = if is_ndarray2 {
if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else {
x2_ty
};
numpy::ndarray_elementwise_binop_impl( numpy::ndarray_elementwise_binop_impl(
generator, generator,
@ -1755,8 +1761,8 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
extern_fns::call_hypot(ctx, x1, x2, None).into() extern_fns::call_hypot(ctx, x1, x2, None).into()
} }
@ -1772,16 +1778,24 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };
@ -1822,8 +1836,8 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
Ok(match (x1, x2) { Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); debug_assert!(x1_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); debug_assert!(x2_ty.is_floating_point(&mut ctx.unifier, &ctx.primitives));
extern_fns::call_nextafter(ctx, x1, x2, None).into() extern_fns::call_nextafter(ctx, x1, x2, None).into()
} }
@ -1839,16 +1853,24 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 primitive_type::NDArrayType::create(x1_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 primitive_type::NDArrayType::create(x2_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} else { } else {
unreachable!() unreachable!()
}; };

View File

@ -1,5 +1,9 @@
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
use crate::toplevel::primitive_type;
use crate::toplevel::primitive_type::OptionType;
use crate::typecheck::typedef::GenericObjectType;
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ classes::{
@ -15,11 +19,7 @@ use crate::{
CodeGenContext, CodeGenTask, CodeGenContext, CodeGenTask,
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{ toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, TopLevelDef,
},
typecheck::{ typecheck::{
magic_methods::{binop_assign_name, binop_name, unaryop_name}, magic_methods::{binop_assign_name, binop_name, unaryop_name},
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
@ -36,8 +36,6 @@ use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
}; };
use super::{llvm_intrinsics::call_memcpy_generic, need_sret, CodeGenerator};
pub fn get_subst_key( pub fn get_subst_key(
unifier: &mut Unifier, unifier: &mut Unifier,
obj: Option<Type>, obj: Option<Type>,
@ -162,14 +160,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
self.builder.build_load(ptr, "tup_val").unwrap() self.builder.build_load(ptr, "tup_val").unwrap()
} }
SymbolValue::OptionSome(v) => { SymbolValue::OptionSome(v) => {
let ty = match self.unifier.get_ty_immutable(ty).as_ref() { let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type"),
};
let val = self.gen_symbol_val(generator, v, ty); let val = self.gen_symbol_val(generator, v, ty);
let ptr = generator let ptr = generator
.gen_var_alloc(self, val.get_type(), Some("default_opt_some")) .gen_var_alloc(self, val.get_type(), Some("default_opt_some"))
@ -178,14 +169,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
ptr.into() ptr.into()
} }
SymbolValue::OptionNone => { SymbolValue::OptionNone => {
let ty = match self.unifier.get_ty_immutable(ty).as_ref() { let ty = OptionType::create(ty, &mut self.unifier).type_tvar(&mut self.unifier).ty;
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == self.primitives.option.obj_id(&self.unifier).unwrap() =>
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type"),
};
let actual_ptr_type = let actual_ptr_type =
self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default()); self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default());
actual_ptr_type.const_null().into() actual_ptr_type.const_null().into()
@ -1206,8 +1190,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
if is_ndarray1 && is_ndarray2 { if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); let ndarray_dtype1 = primitive_type::NDArrayType::create(ty1, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 = primitive_type::NDArrayType::create(ty2, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1256,8 +1244,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into())) Ok(Some(res.as_base_value().into()))
} else { } else {
let (ndarray_dtype, _) = let ndarray_dtype = primitive_type::NDArrayType::create(
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); if is_ndarray1 { ty1 } else { ty2 },
&mut ctx.unifier,
)
.dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_val = NDArrayValue::from_ptr_val( let ndarray_val = NDArrayValue::from_ptr_val(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_usize, llvm_usize,
@ -1443,7 +1435,9 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
} }
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let ndarray_dtype = primitive_type::NDArrayType::create(ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
@ -1527,8 +1521,13 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
return if is_ndarray1 && is_ndarray2 { return if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); let ndarray_dtype1 = primitive_type::NDArrayType::create(left_ty, &mut ctx.unifier)
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); .dtype_tvar(&mut ctx.unifier)
.ty;
let ndarray_dtype2 =
primitive_type::NDArrayType::create(right_ty, &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty;
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
@ -1562,10 +1561,12 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
Ok(Some(res.as_base_value().into())) Ok(Some(res.as_base_value().into()))
} else { } else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys( let ndarray_dtype = primitive_type::NDArrayType::create(
&mut ctx.unifier,
if is_ndarray1 { left_ty } else { right_ty }, if is_ndarray1 { left_ty } else { right_ty },
); &mut ctx.unifier,
)
.dtype_tvar(&mut ctx.unifier)
.ty;
let res = numpy::ndarray_elementwise_binop_impl( let res = numpy::ndarray_elementwise_binop_impl(
generator, generator,
ctx, ctx,
@ -1788,9 +1789,13 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(), ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None, None,
); );
let ndarray_ty = let ndarray_ty = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty)); &mut ctx.unifier,
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); &ctx.primitives,
Some(ty),
Some(ndarray_ndims_ty),
);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty.into()).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
@ -2082,7 +2087,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ExprKind::Name { id, .. } if id == &"none".into() => { ExprKind::Name { id, .. } if id == &"none".into() => {
match ( match (
ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(), ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(),
ctx.unifier.get_ty(ctx.primitives.option).as_ref(), ctx.unifier.get_ty(ctx.primitives.option.into()).as_ref(),
) { ) {
(TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. }) (TypeEnum::TObj { obj_id, params, .. }, TypeEnum::TObj { obj_id: opt_id, .. })
if *obj_id == *opt_id => if *obj_id == *opt_id =>
@ -2464,8 +2469,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}; };
// directly generate code for option.unwrap // directly generate code for option.unwrap
// since it needs to return static value to optimize for kernel invariant // since it needs to return static value to optimize for kernel invariant
if attr == &"unwrap".into() if attr == &"unwrap".into() && id == ctx.primitives.option.obj_id(&ctx.unifier)
&& id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap()
{ {
match val { match val {
ValueEnum::Static(v) => { ValueEnum::Static(v) => {

View File

@ -1,7 +1,7 @@
use crate::{ use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver}, symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -47,6 +47,9 @@ pub mod stmt;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
use crate::toplevel::primitive_type;
use crate::toplevel::primitive_type::OptionType;
use crate::typecheck::typedef::GenericObjectType;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator}; pub use generator::{CodeGenerator, DefaultCodeGenerator};
@ -457,7 +460,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let dtype = primitive_type::NDArrayType::create(ty, unifier)
.dtype_tvar(unifier)
.ty;
let element_type = get_llvm_type( let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype, ctx, module, generator, unifier, top_level, type_cache, dtype,
); );
@ -634,7 +639,10 @@ pub fn gen_func_impl<
range: unifier.get_representative(primitives.range), range: unifier.get_representative(primitives.range),
str: unifier.get_representative(primitives.str), str: unifier.get_representative(primitives.str),
exception: unifier.get_representative(primitives.exception), exception: unifier.get_representative(primitives.exception),
option: unifier.get_representative(primitives.option), option: OptionType::create(
unifier.get_representative(primitives.option.into()),
&mut unifier,
),
..primitives ..primitives
}; };

View File

@ -17,12 +17,8 @@ use crate::{
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{helper::PrimDef, primitive_type, DefinitionId},
helper::PrimDef, typecheck::typedef::{FunSignature, GenericObjectType, Type, TypeEnum},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId,
},
typecheck::typedef::{FunSignature, Type, TypeEnum},
}; };
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{ use inkwell::{
@ -38,12 +34,17 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); let ndarray_ty = primitive_type::NDArrayType::from_primitive(
&mut ctx.unifier,
&ctx.primitives,
Some(elem_ty),
None,
);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray_t = ctx let llvm_ndarray_t = ctx
.get_llvm_type(generator, ndarray_ty) .get_llvm_type(generator, ndarray_ty.into())
.into_pointer_type() .into_pointer_type()
.get_element_type() .get_element_type()
.into_struct_type(); .into_struct_type();
@ -1799,7 +1800,9 @@ pub fn gen_ndarray_array<'ctx>(
let obj_ty = fun.0.args[0].ty; let obj_ty = fun.0.args[0].ty;
let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 primitive_type::NDArrayType::create(obj_ty, &mut context.unifier)
.dtype_tvar(&mut context.unifier)
.ty
} }
TypeEnum::TList { ty } => { TypeEnum::TList { ty } => {
@ -1939,7 +1942,9 @@ pub fn gen_ndarray_copy<'ctx>(
let llvm_usize = generator.get_size_type(context.ctx); let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_elem_ty = primitive_type::NDArrayType::create(this_ty, &mut context.unifier)
.dtype_tvar(&mut context.unifier)
.ty;
let this_arg = let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;

View File

@ -4,13 +4,15 @@ use super::{
irrt::{handle_slice_indices, list_slice_assignment}, irrt::{handle_slice_indices, list_slice_assignment},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
use crate::toplevel::primitive_type;
use crate::typecheck::typedef::GenericObjectType;
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::gen_binop_expr, expr::gen_binop_expr,
gen_in_range_check, gen_in_range_check,
}, },
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::typedef::{FunSignature, Type, TypeEnum},
}; };
use inkwell::{ use inkwell::{
@ -245,7 +247,9 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty, TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 primitive_type::NDArrayType::create(target.custom.unwrap(), &mut ctx.unifier)
.dtype_tvar(&mut ctx.unifier)
.ty
} }
_ => unreachable!(), _ => unreachable!(),
}; };

View File

@ -3,6 +3,7 @@ use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display}; use std::{collections::HashMap, collections::HashSet, fmt::Display};
use crate::typecheck::typedef::GenericObjectType;
use crate::{ use crate::{
codegen::{CodeGenContext, CodeGenerator}, codegen::{CodeGenContext, CodeGenerator},
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef}, toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
@ -43,7 +44,7 @@ impl SymbolValue {
) -> Result<Self, String> { ) -> Result<Self, String> {
match constant { match constant {
Constant::None => { Constant::None => {
if unifier.unioned(expected_ty, primitives.option) { if unifier.unioned(expected_ty, primitives.option.into()) {
Ok(SymbolValue::OptionNone) Ok(SymbolValue::OptionNone)
} else { } else {
Err(format!("Expected {expected_ty:?}, but got Option")) Err(format!("Expected {expected_ty:?}, but got Option"))
@ -157,7 +158,7 @@ impl SymbolValue {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>(); let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys }) unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
} }
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option.into(),
} }
} }
@ -183,13 +184,13 @@ impl SymbolValue {
TypeAnnotation::Tuple(vs_tys) TypeAnnotation::Tuple(vs_tys)
} }
SymbolValue::OptionNone => TypeAnnotation::CustomClass { SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(), id: primitives.option.obj_id(unifier),
params: Vec::default(), params: Vec::default(),
}, },
SymbolValue::OptionSome(v) => { SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier); let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass { TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(), id: primitives.option.obj_id(unifier),
params: vec![ty], params: vec![ty],
} }
} }

View File

@ -24,8 +24,8 @@ use crate::{
stmt::exn_constructor, stmt::exn_constructor,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, toplevel::helper::PrimDef,
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, typecheck::typedef::{into_var_map, TypeVar, VarMap},
}; };
use super::*; use super::*;
@ -303,10 +303,7 @@ struct BuiltinBuilder<'a> {
is_some_ty: (Type, bool), is_some_ty: (Type, bool),
unwrap_ty: (Type, bool), unwrap_ty: (Type, bool),
option_tvar: TypeVar,
ndarray_dtype_tvar: TypeVar,
ndarray_ndims_tvar: TypeVar,
ndarray_copy_ty: (Type, bool), ndarray_copy_ty: (Type, bool),
ndarray_fill_ty: (Type, bool), ndarray_fill_ty: (Type, bool),
@ -315,9 +312,9 @@ struct BuiltinBuilder<'a> {
num_ty: TypeVar, num_ty: TypeVar,
num_var_map: VarMap, num_var_map: VarMap,
ndarray_float: Type, ndarray_float: primitive_type::NDArrayType,
ndarray_float_2d: Type, ndarray_float_2d: primitive_type::NDArrayType,
ndarray_num_ty: Type, ndarray_num_ty: primitive_type::NDArrayType,
float_or_ndarray_ty: TypeVar, float_or_ndarray_ty: TypeVar,
float_or_ndarray_var_map: VarMap, float_or_ndarray_var_map: VarMap,
@ -344,24 +341,19 @@ impl<'a> BuiltinBuilder<'a> {
} = *primitives; } = *primitives;
// Option-related // Option-related
let (is_some_ty, unwrap_ty, option_tvar) = let (is_some_ty, unwrap_ty) =
if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option.into()) {
( (
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(),
iter_type_vars(params).next().unwrap(),
) )
} else { } else {
unreachable!() unreachable!()
}; };
let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } = let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray.into()) else {
&*unifier.get_ty(ndarray)
else {
unreachable!() unreachable!()
}; };
let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
let ndarray_copy_ty = let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty = let ndarray_fill_ty =
@ -374,7 +366,8 @@ impl<'a> BuiltinBuilder<'a> {
); );
let num_var_map = into_var_map([num_ty]); let num_var_map = into_var_map([num_ty]);
let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); let ndarray_float =
primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(float), None);
let ndarray_float_2d = { let ndarray_float_2d = {
let value = match primitives.size_t { let value = match primitives.size_t {
64 => SymbolValue::U64(2u64), 64 => SymbolValue::U64(2u64),
@ -383,16 +376,28 @@ impl<'a> BuiltinBuilder<'a> {
}; };
let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None }); let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None });
make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(float),
Some(ndims),
)
}; };
let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.ty), None); let ndarray_num_ty =
let float_or_ndarray_ty = primitive_type::NDArrayType::from_primitive(unifier, primitives, Some(num_ty.ty), None);
unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); let float_or_ndarray_ty = unifier.get_fresh_var_with_range(
&[float, ndarray_float.into()],
Some("T".into()),
None,
);
let float_or_ndarray_var_map = into_var_map([float_or_ndarray_ty]); let float_or_ndarray_var_map = into_var_map([float_or_ndarray_ty]);
let num_or_ndarray_ty = let num_or_ndarray_ty = unifier.get_fresh_var_with_range(
unifier.get_fresh_var_with_range(&[num_ty.ty, ndarray_num_ty], Some("T".into()), None); &[num_ty.ty, ndarray_num_ty.into()],
Some("T".into()),
None,
);
let num_or_ndarray_var_map = into_var_map([num_ty, num_or_ndarray_ty]); let num_or_ndarray_var_map = into_var_map([num_ty, num_or_ndarray_ty]);
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
@ -405,10 +410,7 @@ impl<'a> BuiltinBuilder<'a> {
is_some_ty, is_some_ty,
unwrap_ty, unwrap_ty,
option_tvar,
ndarray_dtype_tvar,
ndarray_ndims_tvar,
ndarray_copy_ty, ndarray_copy_ty,
ndarray_fill_ty, ndarray_fill_ty,
@ -632,7 +634,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Option => TopLevelDef::Class { PrimDef::Option => TopLevelDef::Class {
name: prim.name().into(), name: prim.name().into(),
object_id: prim.id(), object_id: prim.id(),
type_vars: vec![self.option_tvar.ty], type_vars: vec![self.primitives.option.type_tvar(self.unifier).ty],
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
@ -653,7 +655,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0, signature: self.unwrap_ty.0,
var_id: vec![self.option_tvar.id], var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -667,7 +669,7 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().to_string(), name: prim.name().to_string(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0, signature: self.is_some_ty.0,
var_id: vec![self.option_tvar.id], var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -698,19 +700,22 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::FunSome => TopLevelDef::Function { PrimDef::FunSome => {
let option_tvar = self.primitives.option.type_tvar(self.unifier);
TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg {
name: "n".into(), name: "n".into(),
ty: self.option_tvar.ty, ty: option_tvar.ty,
default_value: None, default_value: None,
}], }],
ret: self.primitives.option, ret: self.primitives.option.into(),
vars: into_var_map([self.option_tvar]), vars: into_var_map([option_tvar]),
})), })),
var_id: vec![self.option_tvar.id], var_id: vec![self.primitives.option.type_tvar(self.unifier).id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -727,7 +732,8 @@ impl<'a> BuiltinBuilder<'a> {
}, },
)))), )))),
loc: None, loc: None,
}, }
}
_ => { _ => {
unreachable!() unreachable!()
@ -736,7 +742,7 @@ impl<'a> BuiltinBuilder<'a> {
} }
/// Build the class `ndarray` and its associated methods. /// Build the class `ndarray` and its associated methods.
fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef { fn build_ndarray_class_related(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed( debug_assert_prim_is_allowed(
prim, prim,
&[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill], &[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill],
@ -746,7 +752,10 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::NDArray => TopLevelDef::Class { PrimDef::NDArray => TopLevelDef::Class {
name: prim.name().into(), name: prim.name().into(),
object_id: prim.id(), object_id: prim.id(),
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty], type_vars: vec![
self.primitives.ndarray.dtype_tvar(self.unifier).ty,
self.primitives.ndarray.ndims_tvar(self.unifier).ty,
],
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
@ -763,7 +772,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0, signature: self.ndarray_copy_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], var_id: vec![
self.primitives.ndarray.dtype_tvar(self.unifier).id,
self.primitives.ndarray.ndims_tvar(self.unifier).id,
],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -780,7 +792,10 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0, signature: self.ndarray_fill_ty.0,
var_id: vec![self.ndarray_dtype_tvar.id, self.ndarray_ndims_tvar.id], var_id: vec![
self.primitives.ndarray.dtype_tvar(self.unifier).id,
self.primitives.ndarray.ndims_tvar(self.unifier).id,
],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,
@ -869,15 +884,26 @@ impl<'a> BuiltinBuilder<'a> {
// The size variant of the function determines the size of the returned int. // The size variant of the function determines the size of the returned int.
let int_sized = size_variant.of_int(self.primitives); let int_sized = size_variant.of_int(self.primitives);
let ndarray_int_sized = let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); self.unifier,
let ndarray_float = self.primitives,
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); Some(int_sized),
Some(common_ndim.ty),
);
let ndarray_float = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(float),
Some(common_ndim.ty),
);
let p0_ty = let p0_ty = self.unifier.get_fresh_var_with_range(
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); &[float, ndarray_float.into()],
Some("T".into()),
None,
);
let ret_ty = self.unifier.get_fresh_var_with_range( let ret_ty = self.unifier.get_fresh_var_with_range(
&[int_sized, ndarray_int_sized], &[int_sized, ndarray_int_sized.into()],
Some("R".into()), Some("R".into()),
None, None,
); );
@ -929,19 +955,30 @@ impl<'a> BuiltinBuilder<'a> {
None, None,
); );
let ndarray_float = let ndarray_float = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); self.unifier,
self.primitives,
Some(float),
Some(common_ndim.ty),
);
// The size variant of the function determines the type of int returned // The size variant of the function determines the type of int returned
let int_sized = size_variant.of_int(self.primitives); let int_sized = size_variant.of_int(self.primitives);
let ndarray_int_sized = let ndarray_int_sized = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); self.unifier,
self.primitives,
Some(int_sized),
Some(common_ndim.ty),
);
let p0_ty = let p0_ty = self.unifier.get_fresh_var_with_range(
self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); &[float, ndarray_float.into()],
Some("T".into()),
None,
);
let ret_ty = self.unifier.get_fresh_var_with_range( let ret_ty = self.unifier.get_fresh_var_with_range(
&[int_sized, ndarray_int_sized], &[int_sized, ndarray_int_sized.into()],
Some("R".into()), Some("R".into()),
None, None,
); );
@ -1004,7 +1041,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&VarMap::new(), &VarMap::new(),
prim.name(), prim.name(),
self.ndarray_float, self.ndarray_float.into(),
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim { let func = match prim {
@ -1050,7 +1087,7 @@ impl<'a> BuiltinBuilder<'a> {
default_value: Some(SymbolValue::U32(0)), default_value: Some(SymbolValue::U32(0)),
}, },
], ],
ret: ndarray, ret: ndarray.into(),
vars: into_var_map([tv]), vars: into_var_map([tv]),
})), })),
var_id: vec![tv.id], var_id: vec![tv.id],
@ -1073,7 +1110,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&into_var_map([tv]), &into_var_map([tv]),
prim.name(), prim.name(),
self.primitives.ndarray, self.primitives.ndarray.into(),
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable // type variable
&[(self.list_int32, "shape"), (tv.ty, "fill_value")], &[(self.list_int32, "shape"), (tv.ty, "fill_value")],
@ -1102,7 +1139,7 @@ impl<'a> BuiltinBuilder<'a> {
default_value: Some(SymbolValue::I32(0)), default_value: Some(SymbolValue::I32(0)),
}, },
], ],
ret: self.ndarray_float_2d, ret: self.ndarray_float_2d.into(),
vars: VarMap::default(), vars: VarMap::default(),
})), })),
var_id: Vec::default(), var_id: Vec::default(),
@ -1122,7 +1159,7 @@ impl<'a> BuiltinBuilder<'a> {
self.unifier, self.unifier,
&VarMap::new(), &VarMap::new(),
prim.name(), prim.name(),
self.ndarray_float_2d, self.ndarray_float_2d.into(),
&[(int32, "n")], &[(int32, "n")],
Box::new(|ctx, obj, fun, args, generator| { Box::new(|ctx, obj, fun, args, generator| {
gen_ndarray_identity(ctx, &obj, fun, &args, generator) gen_ndarray_identity(ctx, &obj, fun, &args, generator)
@ -1337,10 +1374,15 @@ impl<'a> BuiltinBuilder<'a> {
let tvar = self.unifier.get_fresh_var(Some("L".into()), None); let tvar = self.unifier.get_fresh_var(Some("L".into()), None);
let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty }); let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.ty });
let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None); let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None);
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.ty), Some(ndims.ty)); let ndarray = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(tvar.ty),
Some(ndims.ty),
);
let arg_ty = self.unifier.get_fresh_var_with_range( let arg_ty = self.unifier.get_fresh_var_with_range(
&[list, ndarray, self.primitives.range], &[list, ndarray.into(), self.primitives.range],
Some("I".into()), Some("I".into()),
None, None,
); );
@ -1798,8 +1840,13 @@ impl<'a> BuiltinBuilder<'a> {
} }
fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar { fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> TypeVar {
let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None); let ndarray = primitive_type::NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(scalar_ty),
None,
);
self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None) self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray.into()], Some("T".into()), None)
} }
} }

View File

@ -1,14 +1,13 @@
use std::convert::TryInto; use std::convert::TryInto;
use super::*;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::primitive_type::{NDArrayType, OptionType};
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap}; use crate::typecheck::typedef::{into_var_map, GenericObjectType, Mapping, TypeVarId, VarMap};
use nac3parser::ast::{Constant, Location}; use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use strum_macros::EnumIter; use strum_macros::EnumIter;
use super::*;
/// All primitive types and functions in nac3core. /// All primitive types and functions in nac3core.
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)] #[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
pub enum PrimDef { pub enum PrimDef {
@ -403,6 +402,7 @@ impl TopLevelComposer {
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
params: into_var_map([option_type_var]), params: into_var_map([option_type_var]),
}); });
let option = OptionType::create(option, &mut unifier);
let size_t_ty = match size_t { let size_t_ty = match size_t {
32 => uint32, 32 => uint32,
@ -436,8 +436,9 @@ impl TopLevelComposer {
]), ]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });
let ndarray = NDArrayType::create(ndarray, &mut unifier);
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap(); unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray.into()).unwrap();
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
@ -747,7 +748,7 @@ impl TopLevelComposer {
TypeAnnotation::CustomClass { id: e_id, params: e_param }, TypeAnnotation::CustomClass { id: e_id, params: e_param },
) => { ) => {
*f_id == *e_id *f_id == *e_id
&& *f_id == primitive.option.obj_id(unifier).unwrap() && *f_id == primitive.option.obj_id(unifier)
&& (f_param.is_empty() && (f_param.is_empty()
|| (f_param.len() == 1 || (f_param.len() == 1
&& e_param.len() == 1 && e_param.len() == 1
@ -885,7 +886,7 @@ pub fn parse_parameter_default_value(
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
match &*unifier.get_ty(ty) { match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(unifier, ty).0 NDArrayType::create(ty, unifier).dtype_tvar(unifier).ty
} }
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
@ -897,7 +898,7 @@ pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
match &*unifier.get_ty(ty) { match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1; let ndims = NDArrayType::create(ty, unifier).ndims_tvar(unifier).ty;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
}; };

View File

@ -30,7 +30,7 @@ pub struct DefinitionId(pub usize);
pub mod builtins; pub mod builtins;
pub mod composer; pub mod composer;
pub mod helper; pub mod helper;
pub mod numpy; pub mod primitive_type;
pub mod type_annotation; pub mod type_annotation;
use composer::*; use composer::*;
use type_annotation::*; use type_annotation::*;

View File

@ -1,85 +0,0 @@
use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
};
use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
if dtype.is_none() && ndims.is_none() {
return ndarray;
}
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
if let Some(ndims) = ndims {
tvar_subst.insert(tvar_ids[1], ndims);
}
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2);
params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
}

View File

@ -0,0 +1,98 @@
use crate::toplevel::helper::PrimDef;
use crate::typecheck::type_inferencer::PrimitiveStore;
use crate::typecheck::typedef::{GenericObjectType, Type, TypeVar, Unifier, VarMap};
#[derive(Clone, Copy)]
pub struct OptionType(Type);
impl OptionType {
pub fn from_primitive(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
type_ty: Option<Type>,
) -> Self {
primitives.option.subst(unifier, type_ty)
}
pub fn type_tvar(&self, unifier: &mut Unifier) -> TypeVar {
self.get_var_at(unifier, 0).unwrap()
}
#[must_use]
pub fn subst(&self, unifier: &mut Unifier, type_ty: Option<Type>) -> Self {
let new_vars = [(self.type_tvar(unifier).id, type_ty)]
.into_iter()
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
.collect::<VarMap>();
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
OptionType(new_ty)
}
}
impl GenericObjectType for OptionType {
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::Option.id()) {
Some(OptionType(ty))
} else {
None
}
}
fn get_type(&self) -> Type {
self.0
}
}
#[derive(Clone, Copy)]
pub struct NDArrayType(Type);
impl NDArrayType {
pub fn from_primitive(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Self {
primitives.ndarray.subst(unifier, dtype, ndims)
}
pub fn dtype_tvar(&self, unifier: &mut Unifier) -> TypeVar {
self.get_var_at(unifier, 0).unwrap()
}
pub fn ndims_tvar(&self, unifier: &mut Unifier) -> TypeVar {
self.get_var_at(unifier, 1).unwrap()
}
#[must_use]
pub fn subst(
&self,
unifier: &mut Unifier,
dtype_ty: Option<Type>,
ndims_ty: Option<Type>,
) -> Self {
let new_vars =
[(self.dtype_tvar(unifier).id, dtype_ty), (self.ndims_tvar(unifier).id, ndims_ty)]
.into_iter()
.filter_map(|(id, ty)| ty.map(|ty| (id, ty)))
.collect::<VarMap>();
let new_ty = unifier.subst(self.get_type(), &new_vars).unwrap_or(self.get_type());
NDArrayType(new_ty)
}
}
impl GenericObjectType for NDArrayType {
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
if ty.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
Some(NDArrayType(ty))
} else {
None
}
}
fn get_type(&self) -> Type {
self.0
}
}

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::VarMap; use crate::typecheck::typedef::{GenericObjectType, VarMap};
use nac3parser::ast::Constant; use nac3parser::ast::Constant;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -267,12 +267,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
slice.as_ref(), slice.as_ref(),
locked, locked,
)?; )?;
let id = let id = primitives.option.obj_id(unifier);
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
*obj_id
} else {
unreachable!()
};
Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] }) Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] })
} }

View File

@ -1,9 +1,9 @@
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::toplevel::primitive_type;
use crate::typecheck::{ use crate::typecheck::{
type_inferencer::*, type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, GenericObjectType, Type, TypeEnum, Unifier, VarMap},
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
@ -369,8 +369,12 @@ pub fn typeof_ndarray_broadcast(
if is_left_ndarray && is_right_ndarray { if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands. // Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left); let left_ty = primitive_type::NDArrayType::create(left, unifier);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right); let left_ty_dtype = left_ty.dtype_tvar(unifier).ty;
let left_ty_ndims = left_ty.ndims_tvar(unifier).ty;
let right_ty = primitive_type::NDArrayType::create(right, unifier);
let right_ty_dtype = right_ty.dtype_tvar(unifier).ty;
let right_ty_ndims = right_ty.ndims_tvar(unifier).ty;
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype)); assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
@ -397,11 +401,18 @@ pub fn typeof_ndarray_broadcast(
.collect_vec(); .collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None); let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims))) Ok(primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(left_ty_dtype),
Some(res_ndims),
)
.into())
} else { } else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) }; let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty); let ndarray_ty_dtype =
primitive_type::NDArrayType::create(ndarray_ty, unifier).ndims_tvar(unifier).ty;
if unifier.unioned(ndarray_ty_dtype, scalar_ty) { if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty) Ok(ndarray_ty)
@ -444,7 +455,8 @@ pub fn typeof_binop(
} }
Operator::MatMult => { Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); let lhs_ndims =
primitive_type::NDArrayType::create(lhs, unifier).ndims_tvar(unifier).ty;
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1); assert_eq!(values.len(), 1);
@ -452,7 +464,8 @@ pub fn typeof_binop(
} }
_ => unreachable!(), _ => unreachable!(),
}; };
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); let rhs_ndims =
primitive_type::NDArrayType::create(rhs, unifier).ndims_tvar(unifier).ty;
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1); assert_eq!(values.len(), 1);
@ -526,7 +539,7 @@ pub fn typeof_unaryop(
let operand_obj_id = operand.obj_id(unifier); let operand_obj_id = operand.obj_id(unifier);
if op == Unaryop::Not if op == Unaryop::Not
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier))
{ {
return Err( return Err(
"The truth value of an array with more than one element is ambiguous".to_string() "The truth value of an array with more than one element is ambiguous".to_string()
@ -552,7 +565,8 @@ pub fn typeof_unaryop(
Unaryop::UAdd | Unaryop::USub => { Unaryop::UAdd | Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) { if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); let dtype =
primitive_type::NDArrayType::create(operand, unifier).dtype_tvar(unifier).ty;
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) { if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if op == Unaryop::UAdd { return Err(if op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
@ -586,9 +600,15 @@ pub fn typeof_cmpop(
Ok(Some(if is_left_ndarray || is_right_ndarray { Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd); let ndims = primitive_type::NDArrayType::create(brd, unifier).ndims_tvar(unifier).ty;
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims)) primitive_type::NDArrayType::from_primitive(
unifier,
primitives,
Some(primitives.bool),
Some(ndims),
)
.into()
} else if unifier.unioned(lhs, rhs) { } else if unifier.unioned(lhs, rhs) {
primitives.bool primitives.bool
} else { } else {
@ -611,64 +631,108 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* int ======== */ /* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] { for t in [int32_t, int64_t, uint32_t, uint64_t] {
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None); let ndarray_int_t =
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None); primitive_type::NDArrayType::from_primitive(unifier, store, Some(t), None);
impl_pow(unifier, store, t, &[t, ndarray_int_t], None); impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t); impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t, ndarray_int_t], None); impl_div(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None); impl_floordiv(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_mod(unifier, store, t, &[t, ndarray_int_t], None); impl_mod(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_invert(unifier, store, t, Some(t)); impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t, Some(bool_t)); impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None); impl_comparison(unifier, store, t, &[t, ndarray_int_t.into()], None);
impl_eq(unifier, store, t, &[t, ndarray_int_t], None); impl_eq(unifier, store, t, &[t, ndarray_int_t.into()], None);
} }
for t in [int32_t, int64_t] { for t in [int32_t, int64_t] {
impl_sign(unifier, store, t, Some(t)); impl_sign(unifier, store, t, Some(t));
} }
/* float ======== */ /* float ======== */
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None); let ndarray_float_t =
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None); primitive_type::NDArrayType::from_primitive(unifier, store, Some(float_t), None);
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None); let ndarray_int32_t =
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None); primitive_type::NDArrayType::from_primitive(unifier, store, Some(int32_t), None);
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_pow(
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None); unifier,
store,
float_t,
&[int32_t, float_t, ndarray_int32_t.into(), ndarray_float_t.into()],
None,
);
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_sign(unifier, store, float_t, Some(float_t)); impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t, Some(bool_t)); impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t.into()], None);
/* bool ======== */ /* bool ======== */
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None); let ndarray_bool_t =
primitive_type::NDArrayType::from_primitive(unifier, store, Some(bool_t), None);
impl_invert(unifier, store, bool_t, Some(int32_t)); impl_invert(unifier, store, bool_t, Some(int32_t));
impl_not(unifier, store, bool_t, Some(bool_t)); impl_not(unifier, store, bool_t, Some(bool_t));
impl_sign(unifier, store, bool_t, Some(int32_t)); impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t.into()], None);
/* ndarray ===== */ /* ndarray ===== */
let ndarray_usized_ndims_tvar = let ndarray_usized_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t = let ndarray_unsized_t = primitive_type::NDArrayType::from_primitive(
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty)); unifier,
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); store,
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); None,
Some(ndarray_usized_ndims_tvar.ty),
);
let ndarray_dtype_t = ndarray_t.dtype_tvar(unifier).ty;
let ndarray_unsized_dtype_t = ndarray_unsized_t.dtype_tvar(unifier).ty;
impl_basic_arithmetic( impl_basic_arithmetic(
unifier, unifier,
store, store,
ndarray_t, ndarray_t.into(),
&[ndarray_unsized_t, ndarray_unsized_dtype_t], &[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_pow(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_div(unifier, store, ndarray_t.into(), &[ndarray_t.into(), ndarray_dtype_t], None);
impl_floordiv(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_mod(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_matmul(unifier, store, ndarray_t.into(), &[ndarray_t.into()], Some(ndarray_t.into()));
impl_sign(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
impl_invert(unifier, store, ndarray_t.into(), Some(ndarray_t.into()));
impl_eq(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None,
);
impl_comparison(
unifier,
store,
ndarray_t.into(),
&[ndarray_unsized_t.into(), ndarray_unsized_dtype_t],
None, None,
); );
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
} }

View File

@ -4,14 +4,16 @@ use std::iter::once;
use std::ops::Not; use std::ops::Not;
use std::{cell::RefCell, sync::Arc}; use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::typedef::{
Call, FunSignature, FuncArg, GenericObjectType, RecordField, Type, TypeEnum, Unifier, VarMap,
};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::toplevel::primitive_type::{NDArrayType, OptionType};
use crate::toplevel::TopLevelDef; use crate::toplevel::TopLevelDef;
use crate::{ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue}, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{ toplevel::{
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef}, helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext, TopLevelContext,
}, },
}; };
@ -49,8 +51,8 @@ pub struct PrimitiveStore {
pub range: Type, pub range: Type,
pub str: Type, pub str: Type,
pub exception: Type, pub exception: Type,
pub option: Type, pub option: OptionType,
pub ndarray: Type, pub ndarray: NDArrayType,
pub size_t: u32, pub size_t: u32,
} }
@ -74,6 +76,34 @@ impl PrimitiveStore {
_ => unreachable!(), _ => unreachable!(),
} }
} }
/// Returns an iterator over all primitive types in this store.
fn iter(&self) -> impl Iterator<Item = Type> {
self.into_iter()
}
}
impl IntoIterator for &PrimitiveStore {
type Item = Type;
type IntoIter = <Vec<Type> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
vec![
self.int32,
self.int64,
self.uint32,
self.uint64,
self.float,
self.bool,
self.none,
self.range,
self.str,
self.exception,
self.option.into(),
self.ndarray.into(),
]
.into_iter()
}
} }
pub struct FunctionData { pub struct FunctionData {
@ -500,7 +530,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
// the name `none` is special since it may have different types // the name `none` is special since it may have different types
if id == &"none".into() { if id == &"none".into() {
if let TypeEnum::TObj { params, .. } = if let TypeEnum::TObj { params, .. } =
self.unifier.get_ty_immutable(self.primitives.option).as_ref() &*self.unifier.get_ty_immutable(self.primitives.option.into())
{ {
let var_map = params let var_map = params
.iter() .iter()
@ -515,7 +545,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
(*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty) (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).ty)
}) })
.collect::<VarMap>(); .collect::<VarMap>();
Some(self.unifier.subst(self.primitives.option, &var_map).unwrap()) Some(self.unifier.subst(self.primitives.option.into(), &var_map).unwrap())
} else { } else {
unreachable!("must be tobj") unreachable!("must be tobj")
} }
@ -1034,9 +1064,16 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); let ndarray_ndims =
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(target_ty),
Some(ndarray_ndims),
)
.into()
} else { } else {
target_ty target_ty
}; };
@ -1072,9 +1109,7 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
ndarray_dtype
} else { } else {
arg0_ty arg0_ty
}; };
@ -1126,14 +1161,14 @@ impl<'a> Inferencer<'a> {
let arg0_dtype = let arg0_dtype =
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg0_ty).0 NDArrayType::create(arg0_ty, self.unifier).dtype_tvar(self.unifier).ty
} else { } else {
arg0_ty arg0_ty
}; };
let arg1_dtype = let arg1_dtype =
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg1_ty).0 NDArrayType::create(arg1_ty, self.unifier).dtype_tvar(self.unifier).ty
} else { } else {
arg1_ty arg1_ty
}; };
@ -1164,9 +1199,17 @@ impl<'a> Inferencer<'a> {
// (float, int32), so convert it to align with the dtype of the first arg // (float, int32), so convert it to align with the dtype of the first arg
let arg1_ty = if id == &"np_ldexp".into() { let arg1_ty = if id == &"np_ldexp".into() {
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); // let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
let ndims =
NDArrayType::create(arg1_ty, self.unifier).ndims_tvar(self.unifier).ty;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims)) NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(target_ty),
Some(ndims),
)
.into()
} else { } else {
target_ty target_ty
} }
@ -1253,9 +1296,16 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); let ndarray_ndims =
NDArrayType::create(arg0_ty, self.unifier).ndims_tvar(self.unifier).ty;
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(target_ty),
Some(ndarray_ndims),
)
.into()
} else { } else {
target_ty target_ty
}; };
@ -1295,7 +1345,7 @@ impl<'a> Inferencer<'a> {
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape` self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty( let ret = NDArrayType::from_primitive(
self.unifier, self.unifier,
self.primitives, self.primitives,
Some(self.primitives.float), Some(self.primitives.float),
@ -1307,13 +1357,13 @@ impl<'a> Inferencer<'a> {
ty: shape.custom.unwrap(), ty: shape.custom.unwrap(),
default_value: None, default_value: None,
}], }],
ret, ret: ret.into(),
vars: VarMap::new(), vars: VarMap::new(),
})); }));
return Ok(Some(Located { return Ok(Some(Located {
location, location,
custom: Some(ret), custom: Some(ret.into()),
node: ExprKind::Call { node: ExprKind::Call {
func: Box::new(Located { func: Box::new(Located {
custom: Some(custom), custom: Some(custom),
@ -1346,7 +1396,8 @@ impl<'a> Inferencer<'a> {
let ty = arg1.custom.unwrap(); let ty = arg1.custom.unwrap();
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); let ret =
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None }, FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None },
@ -1356,13 +1407,13 @@ impl<'a> Inferencer<'a> {
default_value: None, default_value: None,
}, },
], ],
ret, ret: ret.into(),
vars: VarMap::new(), vars: VarMap::new(),
})); }));
return Ok(Some(Located { return Ok(Some(Located {
location, location,
custom: Some(ret), custom: Some(ret.into()),
node: ExprKind::Call { node: ExprKind::Call {
func: Box::new(Located { func: Box::new(Located {
custom: Some(custom), custom: Some(custom),
@ -1400,7 +1451,8 @@ impl<'a> Inferencer<'a> {
arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
}; };
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); let ret =
NDArrayType::from_primitive(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
@ -1420,13 +1472,13 @@ impl<'a> Inferencer<'a> {
default_value: Some(SymbolValue::U32(0)), default_value: Some(SymbolValue::U32(0)),
}, },
], ],
ret, ret: ret.into(),
vars: VarMap::new(), vars: VarMap::new(),
})); }));
return Ok(Some(Located { return Ok(Some(Located {
location, location,
custom: Some(ret), custom: Some(ret.into()),
node: ExprKind::Call { node: ExprKind::Call {
func: Box::new(Located { func: Box::new(Located {
custom: Some(custom), custom: Some(custom),
@ -1775,9 +1827,13 @@ impl<'a> Inferencer<'a> {
TypeEnum::TVar { is_const_generic: false, .. } TypeEnum::TVar { is_const_generic: false, .. }
)); ));
let constrained_ty = let constrained_ty = NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims)); self.unifier,
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?; self.primitives,
Some(dummy_tvar),
Some(ndims),
);
self.constrain(value.custom.unwrap(), constrained_ty.into(), &value.location)?;
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else { let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims)) panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
@ -1843,10 +1899,14 @@ impl<'a> Inferencer<'a> {
let ndims_ty = self let ndims_ty = self
.unifier .unifier
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None); .get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
let subscripted_ty = let subscripted_ty = NDArrayType::from_primitive(
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty)); self.unifier,
self.primitives,
Some(dummy_tvar),
Some(ndims_ty),
);
Ok(subscripted_ty) Ok(subscripted_ty.into())
} }
} }
@ -1865,10 +1925,17 @@ impl<'a> Inferencer<'a> {
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); .ndims_tvar(self.unifier)
.ty;
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) NDArrayType::from_primitive(
self.unifier,
self.primitives,
Some(ty),
Some(ndims),
)
.into()
} }
_ => unreachable!(), _ => unreachable!(),
@ -1879,8 +1946,10 @@ impl<'a> Inferencer<'a> {
ExprKind::Constant { value: ast::Constant::Int(val), .. } => { ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
match &*self.unifier.get_ty(value.custom.unwrap()) { match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); .ndims_tvar(self.unifier)
.ty;
self.infer_subscript_ndarray(value, slice, ty, ndims) self.infer_subscript_ndarray(value, slice, ty, ndims)
} }
_ => { _ => {
@ -1923,7 +1992,10 @@ impl<'a> Inferencer<'a> {
} }
} }
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
.ndims_tvar(self.unifier)
.ty;
self.infer_subscript_ndarray(value, slice, ty, ndims) self.infer_subscript_ndarray(value, slice, ty, ndims)
} }
_ => { _ => {
@ -1947,8 +2019,9 @@ impl<'a> Inferencer<'a> {
Ok(ty) Ok(ty)
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) = let ndims = NDArrayType::create(value.custom.unwrap(), self.unifier)
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); .ndims_tvar(self.unifier)
.ty;
let valid_index_tys = [self.primitives.int32, self.primitives.isize()] let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
.into_iter() .into_iter()

View File

@ -139,6 +139,7 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let option = OptionType::create(option, &mut unifier);
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
@ -147,6 +148,7 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });
let ndarray = NDArrayType::create(ndarray, &mut unifier);
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
int64, int64,
@ -273,11 +275,13 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let option = OptionType::create(option, &mut unifier);
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let ndarray = NDArrayType::create(ndarray, &mut unifier);
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
.iter() .iter()

View File

@ -22,6 +22,40 @@ mod test;
/// Handle for a type, implemented as a key in the unification table. /// Handle for a type, implemented as a key in the unification table.
pub type Type = UnificationKey; pub type Type = UnificationKey;
/// Macro for generating functions related to type traits, e.g. whether the type is integral.
macro_rules! primitive_type_trait_fn {
($id:ident, $( $matches:ident ),*) => {
#[must_use]
pub fn $id(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool {
[$(store.$matches,)*].into_iter().any(|ty| unifier.unioned(self, ty))
}
};
}
impl Type {
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
/// just to get the field `obj_id`.
#[must_use]
pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> {
if let TypeEnum::TObj { obj_id, .. } = &*unifier.get_ty_immutable(self) {
Some(*obj_id)
} else {
None
}
}
#[must_use]
pub fn is_primitive(self, unifier: &mut Unifier, store: &PrimitiveStore) -> bool {
store.into_iter().any(|ty| unifier.unioned(self, ty))
}
primitive_type_trait_fn!(is_integral, bool, int32, int64, uint32, uint64);
primitive_type_trait_fn!(is_floating_point, float);
primitive_type_trait_fn!(is_arithmetic, int32, int64, uint32, uint64, float);
primitive_type_trait_fn!(is_signed, int32, uint32, float);
primitive_type_trait_fn!(is_unsigned, uint32, uint64);
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)] #[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct CallId(pub(super) usize); pub struct CallId(pub(super) usize);
@ -55,6 +89,24 @@ pub struct TypeVar {
pub ty: Type, pub ty: Type,
} }
impl From<(TypeVarId, Type)> for TypeVar {
fn from((id, ty): (TypeVarId, Type)) -> Self {
TypeVar { id, ty }
}
}
impl From<(&TypeVarId, &Type)> for TypeVar {
fn from((id, ty): (&TypeVarId, &Type)) -> Self {
TypeVar { id: *id, ty: *ty }
}
}
impl From<TypeVar> for (TypeVarId, Type) {
fn from(value: TypeVar) -> Self {
(value.id, value.ty)
}
}
/// The mapping between [`TypeVarId`] and [unifier type][`Type`]. /// The mapping between [`TypeVarId`] and [unifier type][`Type`].
pub type VarMap = IndexMapping<TypeVarId>; pub type VarMap = IndexMapping<TypeVarId>;
@ -68,9 +120,84 @@ where
vars.into_iter().map(|var| (var.id, var.ty)).collect() vars.into_iter().map(|var| (var.id, var.ty)).collect()
} }
/// Get an iterator of [`TypeVar`]s from a [`VarMap`] /// A trait representing a possibly generic object type.
pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ { pub trait GenericObjectType
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty }) where
Self: Sized,
{
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self>;
/// Creates an instance from a [`Type`].
#[must_use]
fn create(ty: Type, unifier: &mut Unifier) -> Self {
Self::try_create(ty, unifier).unwrap()
}
/// Returns the [`Type`] underlying this instance.
#[must_use]
fn get_type(&self) -> Type;
/// Similar to [`Type::obj_id`], except that the [`DefinitionId`] is not wrapped within an
/// [`Option`].
#[must_use]
fn obj_id(&self, unifier: &Unifier) -> DefinitionId {
self.get_type().obj_id(unifier).unwrap()
}
/// Returns a copy of the [`VarMap`] of this object type.
#[must_use]
fn var_map(&self, unifier: &mut Unifier) -> VarMap {
let TypeEnum::TObj { params, .. } = &*unifier.get_ty(self.get_type()) else {
unreachable!()
};
params.clone()
}
/// Creates an iterator over the [`VarMap`] of this object type, applying `iter_fn` on the
/// created [`Iterator`].
#[must_use]
fn iter_var_map<R, IterFn: FnOnce(&mut dyn Iterator<Item = TypeVar>, &mut Unifier) -> R>(
&self,
unifier: &mut Unifier,
iter_fn: IterFn,
) -> R {
let TypeEnum::TObj { params, .. } = &*unifier.get_ty(self.get_type()) else {
unreachable!()
};
let res = iter_fn(&mut params.iter().map(TypeVar::from), unifier);
res
}
/// Returns the [`TypeVar`] instance at the given index.
#[must_use]
fn get_var_at(&self, unifier: &mut Unifier, i: usize) -> Option<TypeVar> {
self.iter_var_map(unifier, |iter, _| iter.nth(i))
}
}
impl<T: GenericObjectType> From<T> for Type {
fn from(value: T) -> Self {
value.get_type()
}
}
/// An adapter that converts [`Type`] into
pub struct GenericTypeAdapter(Type);
impl GenericObjectType for GenericTypeAdapter {
fn try_create(ty: Type, unifier: &mut Unifier) -> Option<Self> {
if let TypeEnum::TObj { .. } = &*unifier.get_ty_immutable(ty) {
Some(GenericTypeAdapter(ty))
} else {
None
}
}
fn get_type(&self) -> Type {
self.0
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -109,19 +236,6 @@ pub enum RecordKey {
Int(i32), Int(i32),
} }
impl Type {
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
/// just to get the field `obj_id`.
#[must_use]
pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> {
if let TypeEnum::TObj { obj_id, .. } = &*unifier.get_ty_immutable(self) {
Some(*obj_id)
} else {
None
}
}
}
impl From<&RecordKey> for StrRef { impl From<&RecordKey> for StrRef {
fn from(r: &RecordKey) -> Self { fn from(r: &RecordKey) -> Self {
match r { match r {