Compare commits

...

2 Commits

Author SHA1 Message Date
David Mak 12bdf6f77c core: WIP - round works now 2024-04-26 19:36:18 +08:00
David Mak fcb6234bbc core: WIP - float and bool works now 2024-04-26 19:36:16 +08:00
5 changed files with 215 additions and 57 deletions

View File

@ -1,5 +1,5 @@
use inkwell::{FloatPredicate, IntPredicate};
use inkwell::types::{BasicTypeEnum, IntType};
use inkwell::types::BasicTypeEnum;
use inkwell::values::{BasicValueEnum, FloatValue, IntValue};
use itertools::Itertools;
@ -373,15 +373,17 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
}
/// Invokes the `float` builtin function.
pub fn call_float<'ctx>(
pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
) -> FloatValue<'ctx> {
) -> Result<BasicValueEnum<'ctx>, String> {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
match n.get_type() {
Ok(match n.get_type() {
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8 | 32 | 64) => {
debug_assert!([
ctx.primitives.bool,
@ -398,42 +400,104 @@ pub fn call_float<'ctx>(
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)) {
ctx.builder
.build_signed_int_to_float(n.into_int_value(), llvm_f64, "sitofp")
.map(Into::into)
.unwrap()
} else {
ctx.builder
.build_unsigned_int_to_float(n.into_int_value(), llvm_f64, "uitofp")
.unwrap()
.map(Into::into)
.unwrap()
}
}
BasicTypeEnum::FloatType(_) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
n.into_float_value()
n
}
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
ctx,
ctx.primitives.float,
None,
NDArrayValue::from_ptr_val(
n.into_pointer_value(),
llvm_usize,
None,
),
|generator, ctx, val| {
call_float(
generator,
ctx,
(elem_ty, val),
)
},
)?;
ndarray.as_ptr_value().into()
}
_ => unsupported_type(ctx, "float", &[n_ty])
}
})
}
/// Invokes the `round` builtin function.
pub fn call_round<'ctx>(
pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, FloatValue<'ctx>),
llvm_ret_ty: IntType<'ctx>,
) -> IntValue<'ctx> {
n: (Type, BasicValueEnum<'ctx>),
ret_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "round";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty).into_int_type();
if !ctx.unifier.unioned(n_ty, ctx.primitives.float) {
unsupported_type(ctx, FN_NAME, &[n_ty])
}
Ok(match n.get_type() {
BasicTypeEnum::FloatType(_) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
let val = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder
.build_float_to_signed_int(val, llvm_ret_ty, FN_NAME)
.unwrap()
let val = llvm_intrinsics::call_float_round(ctx, n.into_float_value(), None);
ctx.builder
.build_float_to_signed_int(val, llvm_ret_ty, FN_NAME)
.map(Into::into)
.unwrap()
}
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
ctx,
ret_ty,
None,
NDArrayValue::from_ptr_val(
n.into_pointer_value(),
llvm_usize,
None,
),
|generator, ctx, val| {
call_round(
generator,
ctx,
(elem_ty, val),
ret_ty,
)
},
)?;
ndarray.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
})
}
/// Invokes the `np_round` builtin function.
@ -451,19 +515,22 @@ pub fn call_numpy_round<'ctx>(
}
/// Invokes the `bool` builtin function.
pub fn call_bool<'ctx>(
pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
) -> IntValue<'ctx> {
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "bool";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
match n.get_type() {
Ok(match n.get_type() {
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
n.into_int_value()
n
}
BasicTypeEnum::IntType(_) => {
@ -477,6 +544,7 @@ pub fn call_bool<'ctx>(
let val = n.into_int_value();
ctx.builder
.build_int_compare(IntPredicate::NE, val, val.get_type().const_zero(), FN_NAME)
.map(Into::into)
.unwrap()
}
@ -486,11 +554,39 @@ pub fn call_bool<'ctx>(
let val = n.into_float_value();
ctx.builder
.build_float_compare(FloatPredicate::UNE, val, val.get_type().const_zero(), FN_NAME)
.map(Into::into)
.unwrap()
}
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
let ndarray = ndarray_elementwise_unaryop_impl(
generator,
ctx,
ctx.primitives.bool,
None,
NDArrayValue::from_ptr_val(
n.into_pointer_value(),
llvm_usize,
None,
),
|generator, ctx, val| {
let elem = call_bool(
generator,
ctx,
(elem_ty, val),
)?;
Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into())
},
)?;
ndarray.as_ptr_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
}
})
}
/// Invokes the `floor` builtin function.

View File

@ -655,8 +655,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
name: "float".into(),
simple_name: "float".into(),
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
ret: float,
args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }],
ret: num_or_ndarray_ty.0,
vars: var_map.clone(),
})),
var_id: Vec::default(),
@ -668,7 +668,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(builtin_fns::call_float(ctx, (arg_ty, arg)).into()))
Ok(Some(builtin_fns::call_float(generator, ctx, (arg_ty, arg))?))
},
)))),
loc: None,
@ -782,23 +782,35 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.map(|val| Some(val.as_basic_value_enum()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"round",
int32,
&[(float, "n")],
Box::new(|ctx, _, fun, args, generator| {
let llvm_i32 = ctx.ctx.i32_type();
{
let common_ndim = primitives.1.get_fresh_const_generic_var(
primitives.0.usize(),
Some("N".into()),
None,
);
let ndarray_int32 = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(int32), Some(common_ndim.0));
let ndarray_float = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(float), Some(common_ndim.0));
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
let p0_ty = primitives.1
.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None);
let ret_ty = primitives.1
.get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None);
Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i32).into()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"round",
ret_ty.0,
&[(p0_ty.0, "n")],
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?))
}),
)
},
create_fn_by_codegen(
primitives,
&var_map,
@ -806,14 +818,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
int64,
&[(float, "n")],
Box::new(|ctx, _, fun, args, generator| {
let llvm_i64 = ctx.ctx.i64_type();
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?
.into_float_value();
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
Ok(Some(builtin_fns::call_round(ctx, (arg_ty, arg), llvm_i64).into()))
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?))
}),
),
create_fn_by_codegen(
@ -960,8 +969,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
name: "bool".into(),
simple_name: "bool".into(),
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
ret: primitives.0.bool,
args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }],
ret: num_or_ndarray_ty.0,
vars: var_map.clone(),
})),
var_id: Vec::default(),
@ -973,7 +982,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(builtin_fns::call_bool(ctx, (arg_ty, arg)).into()))
Ok(Some(builtin_fns::call_bool(generator, ctx, (arg_ty, arg))?))
},
)))),
loc: None,

View File

@ -851,8 +851,19 @@ impl<'a> Inferencer<'a> {
}))
}
if id == &"int32".into() && args.len() == 1 {
let int32_ty = self.primitives.int32;
if [
"int32",
"float",
"bool",
"round",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let target_ty = if id == &"int32".into() || id == &"round".into() {
self.primitives.int32
} else if id == &"float".into() {
self.primitives.float
} else if id == &"bool".into() {
self.primitives.bool
} else { unreachable!() };
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
@ -860,9 +871,9 @@ impl<'a> Inferencer<'a> {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
make_ndarray_ty(self.unifier, self.primitives, Some(int32_ty), Some(ndarray_ndims))
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
} else {
int32_ty
target_ty
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {

View File

@ -58,11 +58,26 @@ class _NDArrayDummy(Generic[T, N]):
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
def round_away_zero(x):
if x >= 0.0:
return math.floor(x + 0.5)
def _bool(x):
if isinstance(x, np.ndarray):
return np.bool_(x)
else:
return math.ceil(x - 0.5)
return bool(x)
def _float(x):
if isinstance(x, np.ndarray):
return np.float_(x)
else:
return float(x)
def round_away_zero(x):
if isinstance(x, np.ndarray):
return np.vectorize(round_away_zero)(x)
else:
if x >= 0.0:
return math.floor(x + 0.5)
else:
return math.ceil(x - 0.5)
def patch(module):
def dbl_nan():
@ -112,6 +127,8 @@ def patch(module):
module.int64 = int64
module.uint32 = uint32
module.uint64 = uint64
module.bool = _bool
module.float = _float
module.TypeVar = TypeVar
module.ConstGeneric = ConstGeneric
module.Generic = Generic

View File

@ -704,6 +704,27 @@ def test_ndarray_uint64():
output_ndarray_float_2(x)
output_ndarray_uint64_2(y)
def test_ndarray_float():
x = np_full([2, 2], 1)
y = float(x)
output_ndarray_int32_2(x)
output_ndarray_float_2(y)
def test_ndarray_bool():
x = np_identity(2)
y = bool(x)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_round():
x = np_identity(2)
y = round(x)
output_ndarray_float_2(x)
output_ndarray_int32_2(y)
def run() -> int32:
test_ndarray_ctor()
test_ndarray_empty()
@ -798,5 +819,9 @@ def run() -> int32:
test_ndarray_int64()
test_ndarray_uint32()
test_ndarray_uint64()
test_ndarray_float()
test_ndarray_bool()
test_ndarray_round()
return 0