Compare commits
5 Commits
12bdf6f77c
...
b804d2c995
Author | SHA1 | Date |
---|---|---|
David Mak | b804d2c995 | |
David Mak | eb829b9396 | |
David Mak | 039080303a | |
David Mak | edea07fb76 | |
David Mak | 1fcbd5b7ae |
|
@ -1,9 +1,13 @@
|
|||
use inkwell::{FloatPredicate, IntPredicate};
|
||||
use inkwell::types::{BasicTypeEnum, IntType};
|
||||
use inkwell::types::BasicTypeEnum;
|
||||
use inkwell::values::{BasicValueEnum, FloatValue, IntValue};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics};
|
||||
use crate::codegen::classes::NDArrayValue;
|
||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::Type;
|
||||
|
||||
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
||||
|
@ -21,20 +25,23 @@ fn unsupported_type(
|
|||
}
|
||||
|
||||
/// Invokes the `int32` builtin function.
|
||||
pub fn call_int32<'ctx>(
|
||||
pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
) -> IntValue<'ctx> {
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
match n.get_type() {
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||
|
||||
ctx.builder
|
||||
.build_int_z_extend(n.into_int_value(), llvm_i32, "zext")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
|
@ -44,7 +51,7 @@ pub fn call_int32<'ctx>(
|
|||
ctx.primitives.uint32,
|
||||
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
|
||||
n.into_int_value()
|
||||
n
|
||||
}
|
||||
|
||||
BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 64 => {
|
||||
|
@ -55,6 +62,7 @@ pub fn call_int32<'ctx>(
|
|||
|
||||
ctx.builder
|
||||
.build_int_truncate(n.into_int_value(), llvm_i32, "trunc")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
|
@ -66,23 +74,43 @@ pub fn call_int32<'ctx>(
|
|||
.unwrap();
|
||||
ctx.builder
|
||||
.build_int_truncate(to_int64, llvm_i32, "conv")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.int32,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_int32(generator, ctx, (elem_ty, val))
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, "int32", &[n_ty])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `int64` builtin function.
|
||||
pub fn call_int64<'ctx>(
|
||||
pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
) -> IntValue<'ctx> {
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
match n.get_type() {
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8 | 32) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
|
@ -93,10 +121,12 @@ pub fn call_int64<'ctx>(
|
|||
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) {
|
||||
ctx.builder
|
||||
.build_int_s_extend(n.into_int_value(), llvm_i64, "sext")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
} else {
|
||||
ctx.builder
|
||||
.build_int_z_extend(n.into_int_value(), llvm_i64, "zext")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
@ -107,7 +137,7 @@ pub fn call_int64<'ctx>(
|
|||
ctx.primitives.uint64,
|
||||
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
|
||||
n.into_int_value()
|
||||
n
|
||||
}
|
||||
|
||||
BasicTypeEnum::FloatType(_) => {
|
||||
|
@ -115,28 +145,49 @@ pub fn call_int64<'ctx>(
|
|||
|
||||
ctx.builder
|
||||
.build_float_to_signed_int(n.into_float_value(), ctx.ctx.i64_type(), "fptosi")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.int64,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_int64(generator, ctx, (elem_ty, val))
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, "int64", &[n_ty])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `uint32` builtin function.
|
||||
pub fn call_uint32<'ctx>(
|
||||
pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
) -> IntValue<'ctx> {
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
match n.get_type() {
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||
|
||||
ctx.builder
|
||||
.build_int_z_extend(n.into_int_value(), llvm_i32, "zext")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
|
@ -146,7 +197,7 @@ pub fn call_uint32<'ctx>(
|
|||
ctx.primitives.uint32,
|
||||
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
|
||||
n.into_int_value()
|
||||
n
|
||||
}
|
||||
|
||||
BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 64 => {
|
||||
|
@ -157,6 +208,7 @@ pub fn call_uint32<'ctx>(
|
|||
|
||||
ctx.builder
|
||||
.build_int_truncate(n.into_int_value(), llvm_i32, "trunc")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
|
@ -182,24 +234,42 @@ pub fn call_uint32<'ctx>(
|
|||
to_int32,
|
||||
"conv",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.uint32,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_uint32(generator, ctx, (elem_ty, val))
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, "uint32", &[n_ty])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `uint64` builtin function.
|
||||
pub fn call_uint64<'ctx>(
|
||||
pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
) -> IntValue<'ctx> {
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
match n.get_type() {
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8 | 32) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
|
@ -210,10 +280,12 @@ pub fn call_uint64<'ctx>(
|
|||
if ctx.unifier.unioned(n_ty, ctx.primitives.int32) {
|
||||
ctx.builder
|
||||
.build_int_s_extend(n.into_int_value(), llvm_i64, "sext")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
} else {
|
||||
ctx.builder
|
||||
.build_int_z_extend(n.into_int_value(), llvm_i64, "zext")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
@ -224,7 +296,7 @@ pub fn call_uint64<'ctx>(
|
|||
ctx.primitives.uint64,
|
||||
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||
|
||||
n.into_int_value()
|
||||
n
|
||||
}
|
||||
|
||||
BasicTypeEnum::FloatType(_) => {
|
||||
|
@ -244,24 +316,42 @@ pub fn call_uint64<'ctx>(
|
|||
|
||||
ctx.builder
|
||||
.build_select(val_gez, to_uint64, to_int64, "conv")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.uint64,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_uint64(generator, ctx, (elem_ty, val))
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, "uint64", &[n_ty])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `float` builtin function.
|
||||
pub fn call_float<'ctx>(
|
||||
pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
) -> FloatValue<'ctx> {
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
match n.get_type() {
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8 | 32 | 64) => {
|
||||
debug_assert!([
|
||||
ctx.primitives.bool,
|
||||
|
@ -278,72 +368,146 @@ pub fn call_float<'ctx>(
|
|||
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)) {
|
||||
ctx.builder
|
||||
.build_signed_int_to_float(n.into_int_value(), llvm_f64, "sitofp")
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
} else {
|
||||
ctx.builder
|
||||
.build_unsigned_int_to_float(n.into_int_value(), llvm_f64, "uitofp")
|
||||
.unwrap()
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
BasicTypeEnum::FloatType(_) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
|
||||
n.into_float_value()
|
||||
n
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_float(generator, ctx, (elem_ty, val))
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, "float", &[n_ty])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `round` builtin function.
|
||||
pub fn call_round<'ctx>(
|
||||
pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, FloatValue<'ctx>),
|
||||
llvm_ret_ty: IntType<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
ret_ty: Type,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "round";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty).into_int_type();
|
||||
|
||||
if !ctx.unifier.unioned(n_ty, ctx.primitives.float) {
|
||||
unsupported_type(ctx, FN_NAME, &[n_ty])
|
||||
}
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::FloatType(_) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
|
||||
let val = llvm_intrinsics::call_float_round(ctx, n, None);
|
||||
ctx.builder
|
||||
.build_float_to_signed_int(val, llvm_ret_ty, FN_NAME)
|
||||
.unwrap()
|
||||
let val = llvm_intrinsics::call_float_round(ctx, n.into_float_value(), None);
|
||||
ctx.builder
|
||||
.build_float_to_signed_int(val, llvm_ret_ty, FN_NAME)
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ret_ty,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_round(generator, ctx, (elem_ty, val), ret_ty)
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `np_round` builtin function.
|
||||
pub fn call_numpy_round<'ctx>(
|
||||
pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, FloatValue<'ctx>),
|
||||
) -> FloatValue<'ctx> {
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_round";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
if !ctx.unifier.unioned(n_ty, ctx.primitives.float) {
|
||||
unsupported_type(ctx, "np_round", &[n_ty])
|
||||
}
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::FloatType(_) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
|
||||
llvm_intrinsics::call_float_roundeven(ctx, n, None)
|
||||
llvm_intrinsics::call_float_roundeven(ctx, n.into_float_value(), None).into()
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_numpy_round(generator, ctx, (elem_ty, val))
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `bool` builtin function.
|
||||
pub fn call_bool<'ctx>(
|
||||
pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
) -> IntValue<'ctx> {
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "bool";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
match n.get_type() {
|
||||
Ok(match n.get_type() {
|
||||
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||
|
||||
n.into_int_value()
|
||||
n
|
||||
}
|
||||
|
||||
BasicTypeEnum::IntType(_) => {
|
||||
|
@ -357,6 +521,7 @@ pub fn call_bool<'ctx>(
|
|||
let val = n.into_int_value();
|
||||
ctx.builder
|
||||
.build_int_compare(IntPredicate::NE, val, val.get_type().const_zero(), FN_NAME)
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
|
@ -366,11 +531,39 @@ pub fn call_bool<'ctx>(
|
|||
let val = n.into_float_value();
|
||||
ctx.builder
|
||||
.build_float_compare(FloatPredicate::UNE, val, val.get_type().const_zero(), FN_NAME)
|
||||
.map(Into::into)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(
|
||||
n.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None,
|
||||
),
|
||||
|generator, ctx, val| {
|
||||
let elem = call_bool(
|
||||
generator,
|
||||
ctx,
|
||||
(elem_ty, val),
|
||||
)?;
|
||||
|
||||
Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into())
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `floor` builtin function.
|
||||
|
|
|
@ -693,7 +693,7 @@ pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>(
|
|||
map_fn: MapFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator,
|
||||
G: CodeGenerator + ?Sized,
|
||||
MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let res = res.unwrap_or_else(|| {
|
||||
|
@ -755,7 +755,7 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
|||
value_fn: ValueFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator,
|
||||
G: CodeGenerator + ?Sized,
|
||||
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
|
|
@ -298,6 +298,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_num_ty = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(num_ty.0), None);
|
||||
let float_or_ndarray_ty = primitives.1
|
||||
.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
||||
let num_or_ndarray_ty = primitives.1
|
||||
.get_fresh_var_with_range(&[num_ty.0, ndarray_num_ty], Some("T".into()), None);
|
||||
|
||||
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
|
||||
let exception_fields = vec![
|
||||
|
@ -564,8 +569,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
name: "int32".into(),
|
||||
simple_name: "int32".into(),
|
||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
||||
ret: int32,
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }],
|
||||
ret: num_or_ndarray_ty.0,
|
||||
vars: var_map.clone(),
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
|
@ -577,7 +582,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_int32(ctx, (arg_ty, arg)).into()))
|
||||
Ok(Some(builtin_fns::call_int32(generator, ctx, (arg_ty, arg))?))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
@ -586,8 +591,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
name: "int64".into(),
|
||||
simple_name: "int64".into(),
|
||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
||||
ret: int64,
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }],
|
||||
ret: num_or_ndarray_ty.0,
|
||||
vars: var_map.clone(),
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
|
@ -599,7 +604,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_int64(ctx, (arg_ty, arg)).into()))
|
||||
Ok(Some(builtin_fns::call_int64(generator, ctx, (arg_ty, arg))?))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
@ -608,8 +613,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
name: "uint32".into(),
|
||||
simple_name: "uint32".into(),
|
||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
||||
ret: uint32,
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }],
|
||||
ret: num_or_ndarray_ty.0,
|
||||
vars: var_map.clone(),
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
|
@ -621,7 +626,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_uint32(ctx, (arg_ty, arg)).into()))
|
||||
Ok(Some(builtin_fns::call_uint32(generator, ctx, (arg_ty, arg))?))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
@ -630,8 +635,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
name: "uint64".into(),
|
||||
simple_name: "uint64".into(),
|
||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
||||
ret: uint64,
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }],
|
||||
ret: num_or_ndarray_ty.0,
|
||||
vars: var_map.clone(),
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
|
@ -643,7 +648,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_uint64(ctx, (arg_ty, arg)).into()))
|
||||
Ok(Some(builtin_fns::call_uint64(generator, ctx, (arg_ty, arg))?))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
@ -652,8 +657,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
name: "float".into(),
|
||||
simple_name: "float".into(),
|
||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
||||
ret: float,
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }],
|
||||
ret: num_or_ndarray_ty.0,
|
||||
vars: var_map.clone(),
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
|
@ -665,7 +670,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_float(ctx, (arg_ty, arg)).into()))
|
||||
Ok(Some(builtin_fns::call_float(generator, ctx, (arg_ty, arg))?))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
@ -779,53 +784,76 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"round",
|
||||
int32,
|
||||
&[(float, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
{
|
||||
let common_ndim = primitives.1.get_fresh_const_generic_var(
|
||||
primitives.0.usize(),
|
||||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_int32 = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(int32), Some(common_ndim.0));
|
||||
let ndarray_float = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(float), Some(common_ndim.0));
|
||||
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
let p0_ty = primitives.1
|
||||
.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
||||
let ret_ty = primitives.1
|
||||
.get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None);
|
||||
|
||||
Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i32).into()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"round64",
|
||||
int64,
|
||||
&[(float, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"round",
|
||||
ret_ty.0,
|
||||
&[(p0_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?))
|
||||
}),
|
||||
)
|
||||
},
|
||||
{
|
||||
let common_ndim = primitives.1.get_fresh_const_generic_var(
|
||||
primitives.0.usize(),
|
||||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_int64 = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(int64), Some(common_ndim.0));
|
||||
let ndarray_float = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(float), Some(common_ndim.0));
|
||||
|
||||
Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i64).into()))
|
||||
}),
|
||||
),
|
||||
let p0_ty = primitives.1
|
||||
.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
|
||||
let ret_ty = primitives.1
|
||||
.get_fresh_var_with_range(&[int64, ndarray_int64], Some("R".into()), None);
|
||||
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"round64",
|
||||
ret_ty.0,
|
||||
&[(p0_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?))
|
||||
}),
|
||||
)
|
||||
},
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_round",
|
||||
float,
|
||||
&[(float, "n")],
|
||||
float_or_ndarray_ty.0,
|
||||
&[(float_or_ndarray_ty.0, "n")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_numpy_round(ctx, (arg_ty, arg)).into()))
|
||||
Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?))
|
||||
}),
|
||||
),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
|
@ -957,8 +985,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
name: "bool".into(),
|
||||
simple_name: "bool".into(),
|
||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
||||
ret: primitives.0.bool,
|
||||
args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }],
|
||||
ret: num_or_ndarray_ty.0,
|
||||
vars: var_map.clone(),
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
|
@ -970,7 +998,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_bool(ctx, (arg_ty, arg)).into()))
|
||||
Ok(Some(builtin_fns::call_bool(generator, ctx, (arg_ty, arg))?))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
|
|
@ -14,17 +14,7 @@ use crate::{
|
|||
},
|
||||
};
|
||||
use itertools::{Itertools, izip};
|
||||
use nac3parser::ast::{
|
||||
self,
|
||||
fold::{self, Fold},
|
||||
Arguments,
|
||||
Comprehension,
|
||||
ExprContext,
|
||||
ExprKind,
|
||||
Located,
|
||||
Location,
|
||||
StrRef
|
||||
};
|
||||
use nac3parser::ast::{self, fold::{self, Fold}, Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
@ -860,66 +850,136 @@ impl<'a> Inferencer<'a> {
|
|||
},
|
||||
}))
|
||||
}
|
||||
// int64 is special because its argument can be a constant larger than int32
|
||||
if id == &"int64".into() && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let custom = Some(self.primitives.int64);
|
||||
let v: Result<i64, _> = (*val).try_into();
|
||||
return if v.is_ok() {
|
||||
Ok(Some(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
report_error("Integer out of bound", args[0].location)
|
||||
}
|
||||
}
|
||||
|
||||
if [
|
||||
"int32",
|
||||
"float",
|
||||
"bool",
|
||||
"round",
|
||||
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||
let target_ty = if id == &"int32".into() || id == &"round".into() {
|
||||
self.primitives.int32
|
||||
} else if id == &"float".into() {
|
||||
self.primitives.float
|
||||
} else if id == &"bool".into() {
|
||||
self.primitives.bool
|
||||
} else { unreachable!() };
|
||||
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let arg0_ty = arg0.custom.unwrap();
|
||||
|
||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
||||
} else {
|
||||
target_ty
|
||||
};
|
||||
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
name: "n".into(),
|
||||
ty: arg0.custom.unwrap(),
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||
}),
|
||||
args: vec![arg0],
|
||||
keywords: vec![],
|
||||
},
|
||||
}))
|
||||
}
|
||||
if id == &"uint32".into() && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let custom = Some(self.primitives.uint32);
|
||||
let v: Result<u32, _> = (*val).try_into();
|
||||
return if v.is_ok() {
|
||||
Ok(Some(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
report_error("Integer out of bound", args[0].location)
|
||||
|
||||
// int64, uint32 and uint64 are special because their argument can be a constant outside the
|
||||
// range of int32s
|
||||
if [
|
||||
"int64",
|
||||
"uint32",
|
||||
"uint64",
|
||||
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||
let target_ty = if id == &"int64".into() {
|
||||
self.primitives.int64
|
||||
} else if id == &"uint32".into() {
|
||||
self.primitives.uint32
|
||||
} else if id == &"uint64".into() {
|
||||
self.primitives.uint64
|
||||
} else { unreachable!() };
|
||||
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let arg0_ty = arg0.custom.unwrap();
|
||||
|
||||
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
|
||||
} else {
|
||||
if let ExprKind::Constant {
|
||||
value: ast::Constant::Int(val),
|
||||
kind
|
||||
} = &arg0.node {
|
||||
let conv_is_ok = if self.unifier.unioned(target_ty, self.primitives.int64) {
|
||||
i64::try_from(*val).is_ok()
|
||||
} else if self.unifier.unioned(target_ty, self.primitives.uint32) {
|
||||
u32::try_from(*val).is_ok()
|
||||
} else if self.unifier.unioned(target_ty, self.primitives.uint64) {
|
||||
u64::try_from(*val).is_ok()
|
||||
} else { unreachable!() };
|
||||
|
||||
return if conv_is_ok {
|
||||
Ok(Some(Located {
|
||||
location: arg0.location,
|
||||
custom: Some(target_ty),
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
report_error("Integer out of bound", arg0.location)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if id == &"uint64".into() && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let custom = Some(self.primitives.uint64);
|
||||
let v: Result<u64, _> = (*val).try_into();
|
||||
return if v.is_ok() {
|
||||
Ok(Some(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
report_error("Integer out of bound", args[0].location)
|
||||
}
|
||||
}
|
||||
|
||||
target_ty
|
||||
};
|
||||
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
name: "n".into(),
|
||||
ty: arg0.custom.unwrap(),
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||
}),
|
||||
args: vec![arg0],
|
||||
keywords: vec![],
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// 1-argument ndarray n-dimensional creation functions
|
||||
|
|
|
@ -58,11 +58,26 @@ class _NDArrayDummy(Generic[T, N]):
|
|||
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
|
||||
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
|
||||
|
||||
def round_away_zero(x):
|
||||
if x >= 0.0:
|
||||
return math.floor(x + 0.5)
|
||||
def _bool(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.bool_(x)
|
||||
else:
|
||||
return math.ceil(x - 0.5)
|
||||
return bool(x)
|
||||
|
||||
def _float(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.float_(x)
|
||||
else:
|
||||
return float(x)
|
||||
|
||||
def round_away_zero(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.vectorize(round_away_zero)(x)
|
||||
else:
|
||||
if x >= 0.0:
|
||||
return math.floor(x + 0.5)
|
||||
else:
|
||||
return math.ceil(x - 0.5)
|
||||
|
||||
def patch(module):
|
||||
def dbl_nan():
|
||||
|
@ -112,6 +127,8 @@ def patch(module):
|
|||
module.int64 = int64
|
||||
module.uint32 = uint32
|
||||
module.uint64 = uint64
|
||||
module.bool = _bool
|
||||
module.float = _float
|
||||
module.TypeVar = TypeVar
|
||||
module.ConstGeneric = ConstGeneric
|
||||
module.Generic = Generic
|
||||
|
|
|
@ -6,6 +6,18 @@ def output_bool(x: bool):
|
|||
def output_int32(x: int32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int64(x: int64):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_uint32(x: uint32):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_uint64(x: uint64):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_float64(x: float):
|
||||
...
|
||||
|
@ -24,6 +36,21 @@ def output_ndarray_int32_2(n: ndarray[int32, Literal[2]]):
|
|||
for c in range(len(n[r])):
|
||||
output_int32(n[r][c])
|
||||
|
||||
def output_ndarray_int64_2(n: ndarray[int64, Literal[2]]):
|
||||
for r in range(len(n)):
|
||||
for c in range(len(n[r])):
|
||||
output_int64(n[r][c])
|
||||
|
||||
def output_ndarray_uint32_2(n: ndarray[uint32, Literal[2]]):
|
||||
for r in range(len(n)):
|
||||
for c in range(len(n[r])):
|
||||
output_uint32(n[r][c])
|
||||
|
||||
def output_ndarray_uint64_2(n: ndarray[uint64, Literal[2]]):
|
||||
for r in range(len(n)):
|
||||
for c in range(len(n[r])):
|
||||
output_uint64(n[r][c])
|
||||
|
||||
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
|
||||
for i in range(len(n)):
|
||||
output_float64(n[i])
|
||||
|
@ -649,6 +676,64 @@ def test_ndarray_ge_broadcast_rhs_scalar():
|
|||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_int32():
|
||||
x = np_identity(2)
|
||||
y = int32(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_int32_2(y)
|
||||
|
||||
def test_ndarray_int64():
|
||||
x = np_identity(2)
|
||||
y = int64(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_int64_2(y)
|
||||
|
||||
def test_ndarray_uint32():
|
||||
x = np_identity(2)
|
||||
y = uint32(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_uint32_2(y)
|
||||
|
||||
def test_ndarray_uint64():
|
||||
x = np_identity(2)
|
||||
y = uint64(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_uint64_2(y)
|
||||
|
||||
def test_ndarray_float():
|
||||
x = np_full([2, 2], 1)
|
||||
y = float(x)
|
||||
|
||||
output_ndarray_int32_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_bool():
|
||||
x = np_identity(2)
|
||||
y = bool(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_round():
|
||||
x = np_identity(2)
|
||||
y = round(x)
|
||||
z = round64(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_int32_2(y)
|
||||
output_ndarray_int64_2(z)
|
||||
|
||||
def test_ndarray_numpy_round():
|
||||
x = np_identity(2)
|
||||
y = np_round(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
|
@ -739,4 +824,14 @@ def run() -> int32:
|
|||
test_ndarray_ge_broadcast_lhs_scalar()
|
||||
test_ndarray_ge_broadcast_rhs_scalar()
|
||||
|
||||
test_ndarray_int32()
|
||||
test_ndarray_int64()
|
||||
test_ndarray_uint32()
|
||||
test_ndarray_uint64()
|
||||
test_ndarray_float()
|
||||
test_ndarray_bool()
|
||||
|
||||
test_ndarray_round()
|
||||
test_ndarray_numpy_round()
|
||||
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue