core: WIP - fabs, sqrt and rint works now
This commit is contained in:
parent
ffcae46bd8
commit
11a0bcc443
|
@ -1217,46 +1217,124 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>(
|
|||
})
|
||||
}
|
||||
|
||||
/// Invokes the `np_sqrt` builtin function.
|
||||
pub fn call_numpy_fabs<'ctx>(
|
||||
/// Invokes the `np_fabs` builtin function.
|
||||
pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x: (Type, FloatValue<'ctx>),
|
||||
) -> FloatValue<'ctx> {
|
||||
x: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_fabs";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (x_ty, x) = x;
|
||||
|
||||
if !ctx.unifier.unioned(x_ty, ctx.primitives.float) {
|
||||
unsupported_type(ctx, "np_fabs", &[x_ty])
|
||||
Ok(match x.get_type() {
|
||||
BasicTypeEnum::FloatType(_) => {
|
||||
debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float));
|
||||
|
||||
llvm_intrinsics::call_float_fabs(ctx, x.into_float_value(), None).into()
|
||||
}
|
||||
|
||||
llvm_intrinsics::call_float_fabs(ctx, x, None)
|
||||
BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_numpy_fabs(generator, ctx, (elem_ty, val))
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `np_sqrt` builtin function.
|
||||
pub fn call_numpy_sqrt<'ctx>(
|
||||
pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x: (Type, FloatValue<'ctx>),
|
||||
) -> FloatValue<'ctx> {
|
||||
x: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_sqrt";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (x_ty, x) = x;
|
||||
|
||||
if !ctx.unifier.unioned(x_ty, ctx.primitives.float) {
|
||||
unsupported_type(ctx, "np_sqrt", &[x_ty])
|
||||
Ok(match x.get_type() {
|
||||
BasicTypeEnum::FloatType(_) => {
|
||||
debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float));
|
||||
|
||||
llvm_intrinsics::call_float_sqrt(ctx, x.into_float_value(), None).into()
|
||||
}
|
||||
|
||||
llvm_intrinsics::call_float_sqrt(ctx, x, None)
|
||||
BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_numpy_sqrt(generator, ctx, (elem_ty, val))
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `np_rint` builtin function.
|
||||
pub fn call_numpy_rint<'ctx>(
|
||||
pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x: (Type, FloatValue<'ctx>),
|
||||
) -> FloatValue<'ctx> {
|
||||
x: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_rint";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (x_ty, x) = x;
|
||||
|
||||
if !ctx.unifier.unioned(x_ty, ctx.primitives.float) {
|
||||
unsupported_type(ctx, "np_rint", &[x_ty])
|
||||
Ok(match x.get_type() {
|
||||
BasicTypeEnum::FloatType(_) => {
|
||||
debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float));
|
||||
|
||||
llvm_intrinsics::call_float_roundeven(ctx, x.into_float_value(), None).into()
|
||||
}
|
||||
|
||||
llvm_intrinsics::call_float_roundeven(ctx, x, None)
|
||||
BasicTypeEnum::PointerType(_) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty);
|
||||
|
||||
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
None,
|
||||
NDArrayValue::from_ptr_val(x.into_pointer_value(), llvm_usize, None),
|
||||
|generator, ctx, val| {
|
||||
call_numpy_rint(generator, ctx, (elem_ty, val))
|
||||
},
|
||||
)?;
|
||||
|
||||
ndarray.as_ptr_value().into()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
|
||||
})
|
||||
}
|
||||
|
||||
/// Invokes the `np_tan` builtin function.
|
||||
|
|
|
@ -1514,45 +1514,42 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
|||
unifier,
|
||||
&var_map,
|
||||
"np_fabs",
|
||||
float,
|
||||
&[(float, "x")],
|
||||
float_or_ndarray_ty.0,
|
||||
&[(float_or_ndarray_ty.0, "x")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, x_ty)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_numpy_fabs(ctx, (x_ty, x_val)).into()))
|
||||
Ok(Some(builtin_fns::call_numpy_fabs(generator, ctx, (x_ty, x_val))?))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"np_sqrt",
|
||||
float,
|
||||
&[(float, "x")],
|
||||
float_or_ndarray_ty.0,
|
||||
&[(float_or_ndarray_ty.0, "x")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, x_ty)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_numpy_sqrt(ctx, (x_ty, x_val)).into()))
|
||||
Ok(Some(builtin_fns::call_numpy_sqrt(generator, ctx, (x_ty, x_val))?))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
&var_map,
|
||||
"np_rint",
|
||||
float,
|
||||
&[(float, "x")],
|
||||
float_or_ndarray_ty.0,
|
||||
&[(float_or_ndarray_ty.0, "x")],
|
||||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone()
|
||||
.to_basic_value_enum(ctx, generator, x_ty)?
|
||||
.into_float_value();
|
||||
.to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_numpy_rint(ctx, (x_ty, x_val)).into()))
|
||||
Ok(Some(builtin_fns::call_numpy_rint(generator, ctx, (x_ty, x_val))?))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
|
|
|
@ -837,6 +837,27 @@ def test_ndarray_log2():
|
|||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_fabs():
|
||||
x = -np_identity(2)
|
||||
y = np_fabs(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_sqrt():
|
||||
x = np_identity(2)
|
||||
y = np_sqrt(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_rint():
|
||||
x = np_identity(2)
|
||||
y = np_rint(x)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
|
@ -947,5 +968,8 @@ def run() -> int32:
|
|||
test_ndarray_log()
|
||||
test_ndarray_log10()
|
||||
test_ndarray_log2()
|
||||
test_ndarray_fabs()
|
||||
test_ndarray_sqrt()
|
||||
test_ndarray_rint()
|
||||
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue