forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: checkpoint 13

This commit is contained in:
lyken 2024-08-15 14:38:05 +08:00
parent 4b765cfb27
commit a69a441bdd
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
6 changed files with 101 additions and 98 deletions

View File

@ -1555,9 +1555,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
if op.base == Operator::MatMult {
// Handle `left @ right`
let result = NDArrayObject::matmul(generator, ctx, left, right, out)
.split_unsized(generator, ctx)
.to_basic_value_enum();
Ok(Some(ValueEnum::Dynamic(result)))
.split_unsized(generator, ctx);
Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())))
} else {
// For other operators like +, -, etc...; do them element-wise-ly
let result = NDArrayObject::broadcasting_starmap(
@ -1566,18 +1565,21 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
&[left, right],
out,
|generator, ctx, scalars| {
let left = scalars[0];
let right = scalars[1];
gen_binop_expr_with_values(
let left_scalar = scalars[0];
let right_scalar = scalars[1];
let result = gen_binop_expr_with_values(
generator,
ctx,
(&Some(left.dtype), left.value),
(&Some(left_scalar.ty), left_scalar.value),
op,
(&Some(right.dtype), right.value),
(&Some(right_scalar.ty), right_scalar.value),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, common_dtype)
.to_basic_value_enum(ctx, generator, common_dtype)?;
Ok(AnyObject { value: result, ty: common_dtype })
},
)?;
Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
@ -1671,6 +1673,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
/// Generates LLVM IR for a unary operator expression using the [`Type`] and
/// [LLVM value][`BasicValueEnum`] of the operands.
#[allow(clippy::only_used_in_recursion)]
pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,

View File

@ -63,7 +63,7 @@ impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, ArrayModel<Element>> {
assert!(i < self.model.0.len);
let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(i as u64, false);
let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0.element).check_value(generator, ctx.ctx, ptr).unwrap()
@ -114,7 +114,7 @@ impl<'ctx, const LEN: u32, Element: Model<'ctx>> Ptr<'ctx, NArrayModel<LEN, Elem
assert!(i < LEN);
let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(i as u64, false);
let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap()

View File

@ -86,10 +86,17 @@ impl<'ctx> AnyObject<'ctx> {
matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. })
}
pub fn into_tuple() {}
pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.int32)
}
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Int32> {
assert!(self.is_int32(ctx));
IntModel(Int32).believe_value(self.value.into_int_value())
}
pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.uint32)
}
@ -106,44 +113,33 @@ impl<'ctx> AnyObject<'ctx> {
ctx.unifier.unioned(self.ty, ctx.primitives.bool)
}
/// Returns true if the object type is `bool`, `int32`, `int64`, `uint32`, or `uint64`.
pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, int_like(ctx))
}
/// Returns true if the object type is `int32`, `int64`.
pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, signed_ints(ctx))
}
/// Returns true if the object type is `uint32`, `uint64`.
pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx))
}
/// Convenience function. If object has type `int32`, `int64`, `uint32`, `uint64`, or `bool`,
/// get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
if self.is_int_like(ctx) {
self.value.into_int_value()
} else {
panic!("not an int32 type")
}
assert!(self.is_int_like(ctx));
self.value.into_int_value()
}
pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
self.is_obj(ctx, PrimDef::Float)
}
/// Convenience function. If object has type `float`, get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> {
if self.is_float(ctx) {
// self.value must be a FloatValue
FloatModel(Float64).believe_value(self.value.into_float_value())
} else {
panic!("not a float type")
}
assert!(self.is_float(ctx));
FloatModel(Float64).believe_value(self.value.into_float_value())
}
pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
@ -158,6 +154,15 @@ impl<'ctx> AnyObject<'ctx> {
NDArrayObject::from_object(generator, ctx, *self)
}
/// Create an object from a boolean from an i1.
///
/// NOTE: In NAC3, booleans are i8. This function does converts the input i1 to an i8.
pub fn from_bool(ctx: &mut CodeGenContext<'ctx, '_>, n: Int<'ctx, Bool>) -> AnyObject<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let value = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
AnyObject { value: value.as_basic_value_enum(), ty: ctx.primitives.bool }
}
/// Helper function to compare two scalars.
///
/// Only int-to-int and float-to-float comparisons are allowed.
@ -172,9 +177,7 @@ impl<'ctx> AnyObject<'ctx> {
float_predicate: FloatPredicate,
name: &str,
) -> Int<'ctx, Bool> {
if !ctx.unifier.unioned(lhs.ty, rhs.ty) {
panic!("lhs and rhs type are not the same.")
}
assert!(ctx.unifier.unioned(lhs.ty, rhs.ty), "lhs and rhs type should be the same");
let bool_model = IntModel(Bool);
@ -320,6 +323,7 @@ impl<'ctx> AnyObject<'ctx> {
}
// Get the `len()` of this object.
#[must_use]
pub fn call_len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -378,20 +382,14 @@ impl<'ctx> AnyObject<'ctx> {
}
/// Call `bool()` on this object.
pub fn call_bool<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
#[must_use]
pub fn call_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let n = self.bool(ctx);
// NAC3 booleans are i8
let llvm_i8 = ctx.ctx.i8_type();
let n = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
AnyObject { ty: ctx.primitives.bool, value: n.as_basic_value_enum() }
AnyObject::from_bool(ctx, n)
}
/// Call `float()` on this object.
#[must_use]
pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let f64_model = FloatModel(Float64);
let llvm_f64 = ctx.ctx.f64_type();
@ -449,6 +447,7 @@ impl<'ctx> AnyObject<'ctx> {
// Call `round()` on this object.
//
// It is possible to specify which kind of int type to return.
#[must_use]
pub fn call_round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -497,6 +496,7 @@ impl<'ctx> AnyObject<'ctx> {
}
/// Call `min()` or `max()` on two objects.
#[must_use]
pub fn call_min_or_max(
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,

View File

@ -34,7 +34,7 @@ use indexing::RustNDIndex;
use inkwell::{
context::Context,
types::BasicType,
values::{BasicValue, BasicValueEnum, PointerValue},
values::{BasicValue, PointerValue},
AddressSpace, IntPredicate,
};
use nditer::NDIterHandle;
@ -570,7 +570,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| Ok(nditer.next(generator, ctx)),
|generator, ctx, nditer| {
nditer.next(generator, ctx);
Ok(())
},
)
}

View File

@ -30,6 +30,7 @@ impl<'ctx> NDIterHandle<'ctx> {
NDIterHandle { ndarray, instance: nditer, indices }
}
#[must_use]
pub fn has_next<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -43,9 +44,10 @@ impl<'ctx> NDIterHandle<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_nditer_next(generator, ctx, self.instance)
call_nac3_nditer_next(generator, ctx, self.instance);
}
#[must_use]
pub fn get_pointer<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -59,6 +61,7 @@ impl<'ctx> NDIterHandle<'ctx> {
.unwrap()
}
#[must_use]
pub fn get_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -69,6 +72,7 @@ impl<'ctx> NDIterHandle<'ctx> {
AnyObject { ty: self.ndarray.dtype, value }
}
#[must_use]
pub fn get_index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -77,6 +81,7 @@ impl<'ctx> NDIterHandle<'ctx> {
self.instance.get(generator, ctx, |f| f.nth, "index")
}
#[must_use]
pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> {
self.indices
}

View File

@ -1,4 +1,4 @@
use std::{any::Any, iter::once};
use std::iter::once;
use helper::{
create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields,
@ -26,14 +26,13 @@ use crate::{
numpy_new::{self},
object::{
ndarray::{
functions::{FloorOrCeil, MinOrMax},
nalgebra::perform_nalgebra_call,
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
scalar::{split_scalar_or_ndarray, ScalarOrNDArray},
shape_util::parse_numpy_int_sequence,
NDArrayObject,
},
tuple::TupleObject,
AnyObject,
AnyObject, FloorOrCeil, MinOrMax,
},
stmt::exn_constructor,
},
@ -1077,7 +1076,7 @@ impl<'a> BuiltinBuilder<'a> {
let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, ctx, this);
let value = ScalarObject { value: value_arg, dtype: value_ty };
let value = AnyObject { value: value_arg, ty: value_ty };
this.fill(generator, ctx, value);
Ok(None)
@ -1142,7 +1141,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx,
ret_dtype,
|generator, ctx, scalar| {
let result = match prim {
Ok(match prim {
PrimDef::FunInt32 => scalar.call_int32(generator, ctx),
PrimDef::FunInt64 => scalar.call_int64(generator, ctx),
PrimDef::FunUInt32 => scalar.call_uint32(generator, ctx),
@ -1150,8 +1149,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunFloat => scalar.call_float(ctx),
PrimDef::FunBool => scalar.call_bool(ctx),
_ => unreachable!(),
};
Ok(result.value)
})
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -1206,13 +1204,13 @@ impl<'a> BuiltinBuilder<'a> {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { ty: arg_ty, value: arg };
let ret_int_dtype = size_variant.of_int(&ctx.primitives);
let ret_int_ty = size_variant.of_int(&ctx.primitives);
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
generator,
ctx,
ret_int_dtype,
|generator, ctx, scalar| Ok(scalar.round(generator, ctx, ret_int_dtype).value),
ret_int_ty,
|generator, ctx, scalar| Ok(scalar.call_round(generator, ctx, ret_int_ty)),
)?;
Ok(Some(result.to_basic_value_enum()))
}),
@ -1277,7 +1275,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx,
int_sized,
|generator, ctx, scalar| {
Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized).value)
Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized))
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -1477,7 +1475,7 @@ impl<'a> BuiltinBuilder<'a> {
let fill_value_ty = fun.0.args[1].ty;
let fill_value =
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
let fill_value = ScalarObject { dtype: fill_value_ty, value: fill_value };
let fill_value = AnyObject { ty: fill_value_ty, value: fill_value };
// Implementation
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
@ -1858,10 +1856,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
ctx.primitives.float,
move |_generator, ctx, scalar| {
let result = scalar.np_floor_or_ceil(ctx, kind);
Ok(result.value)
},
|generator, ctx, scalar| Ok(scalar.call_np_floor_or_ceil(generator, ctx, kind)),
)?;
Ok(Some(result.to_basic_value_enum()))
}),
@ -1887,10 +1882,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
ctx.primitives.float,
|_generator, ctx, scalar| {
let result = scalar.np_round(ctx);
Ok(result.value)
},
|generator, ctx, scalar| Ok(scalar.call_np_round(generator, ctx)),
)?;
Ok(Some(result.to_basic_value_enum()))
}),
@ -1977,9 +1969,9 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(),
};
let m = ScalarObject { dtype: m_ty, value: m_val };
let n = ScalarObject { dtype: n_ty, value: n_val };
let result = ScalarObject::min_or_max(ctx, kind, m, n);
let m = AnyObject { ty: m_ty, value: m_val };
let n = AnyObject { ty: n_ty, value: n_val };
let result = AnyObject::call_min_or_max(ctx, kind, m, n);
Ok(Some(result.value))
},
)))),
@ -2104,9 +2096,7 @@ impl<'a> BuiltinBuilder<'a> {
|_generator, ctx, scalars| {
let x1 = scalars[0];
let x2 = scalars[1];
let result = ScalarObject::min_or_max(ctx, kind, x1, x2);
Ok(result.value)
Ok(AnyObject::call_min_or_max(ctx, kind, x1, x2))
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -2148,7 +2138,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
num_ty.ty,
|_generator, ctx, scalar| Ok(scalar.abs(ctx).value),
|generator, ctx, scalar| Ok(scalar.call_abs(generator, ctx)),
)?;
Ok(Some(result.to_basic_value_enum()))
},
@ -2185,9 +2175,10 @@ impl<'a> BuiltinBuilder<'a> {
ctx,
ctx.primitives.bool,
|generator, ctx, scalar| {
let n = scalar.into_float64(ctx);
let n = scalar.into_float(ctx).value;
let n = function(generator, ctx, n);
Ok(n.as_basic_value_enum())
let n = IntModel(Bool).believe_value(n);
Ok(AnyObject::from_bool(ctx, n))
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -2254,8 +2245,9 @@ impl<'a> BuiltinBuilder<'a> {
ctx,
ctx.primitives.float,
|_generator, ctx, scalar| {
let n = scalar.into_float64(ctx);
let n = match prim {
let n = scalar.into_float(ctx).value;
let result = match prim {
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
PrimDef::FunNpCos => llvm_intrinsics::call_float_cos(ctx, n, None),
PrimDef::FunNpTan => extern_fns::call_tan(ctx, n, None),
@ -2297,7 +2289,11 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(),
};
Ok(n.as_basic_value_enum())
Ok(AnyObject {
ty: ctx.primitives.float,
value: result.as_basic_value_enum(),
})
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -2385,48 +2381,48 @@ impl<'a> BuiltinBuilder<'a> {
// TODO: This looks ugly
let result = match prim {
PrimDef::FunNpArctan2 => {
let x1 = x1.into_float64(ctx);
let x2 = x2.into_float64(ctx);
let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float(ctx).value;
extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum()
}
PrimDef::FunNpCopysign => {
let x1 = x1.into_float64(ctx);
let x2 = x2.into_float64(ctx);
let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float(ctx).value;
llvm_intrinsics::call_float_copysign(ctx, x1, x2, None)
.as_basic_value_enum()
}
PrimDef::FunNpFmax => {
let x1 = x1.into_float64(ctx);
let x2 = x2.into_float64(ctx);
let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float(ctx).value;
llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None)
.as_basic_value_enum()
}
PrimDef::FunNpFmin => {
let x1 = x1.into_float64(ctx);
let x2 = x2.into_float64(ctx);
let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float(ctx).value;
llvm_intrinsics::call_float_minnum(ctx, x1, x2, None)
.as_basic_value_enum()
}
PrimDef::FunNpHypot => {
let x1 = x1.into_float64(ctx);
let x2 = x2.into_float64(ctx);
let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float(ctx).value;
extern_fns::call_hypot(ctx, x1, x2, None).as_basic_value_enum()
}
PrimDef::FunNpNextAfter => {
let x1 = x1.into_float64(ctx);
let x2 = x2.into_float64(ctx);
let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float(ctx).value;
extern_fns::call_nextafter(ctx, x1, x2, None)
.as_basic_value_enum()
}
PrimDef::FunNpLdExp => {
let x1 = x1.into_float64(ctx);
let x2 = x2.into_int32(ctx);
let x1 = x1.into_float(ctx).value;
let x2 = x2.into_int32(ctx).value;
extern_fns::call_ldexp(ctx, x1, x2, None).as_basic_value_enum()
}
_ => unreachable!(),
};
Ok(result)
Ok(AnyObject { ty: ret_dtype, value: result })
},
)?;
@ -2628,13 +2624,9 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
zero,
ScalarObject { dtype: x2_ty, value: x2_val },
AnyObject { ty: x2_ty, value: x2_val },
);
// alloca_constant_shape
// let x2 = x2.as_ndarray(generator, ctx);
let [out] = perform_nalgebra_call(
generator,
ctx,