forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: rename .value to .instance in *Object

This commit is contained in:
lyken 2024-08-14 11:34:42 +08:00
parent 18dcbf5bbc
commit febe78b6a4
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
14 changed files with 136 additions and 124 deletions

View File

@ -1570,16 +1570,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
gen_binop_expr_with_values( gen_binop_expr_with_values(
generator, generator,
ctx, ctx,
(&Some(left.dtype), left.value), (&Some(left.dtype), left.instance),
op, op,
(&Some(right.dtype), right.value), (&Some(right.dtype), right.instance),
ctx.current_loc, ctx.current_loc,
)? )?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, common_dtype) .to_basic_value_enum(ctx, generator, common_dtype)
}, },
)?; )?;
Ok(Some(ValueEnum::Dynamic(result.value.value.as_basic_value_enum()))) Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
} }
} else { } else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());

View File

@ -125,7 +125,7 @@ pub fn gen_ndarray_empty<'ctx>(
let ndarray_ty = fun.0.ret; let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty); let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
Ok(ndarray.value.value.as_basic_value_enum()) Ok(ndarray.instance.value.as_basic_value_enum())
} }
/// Generates LLVM IR for `np.zero`. /// Generates LLVM IR for `np.zero`.
@ -150,7 +150,7 @@ pub fn gen_ndarray_zeros<'ctx>(
let fill_value = ndarray_zero_value(generator, ctx, ndarray.dtype); let fill_value = ndarray_zero_value(generator, ctx, ndarray.dtype);
ndarray.fill(generator, ctx, fill_value); ndarray.fill(generator, ctx, fill_value);
Ok(ndarray.value.value.as_basic_value_enum()) Ok(ndarray.instance.value.as_basic_value_enum())
} }
/// Generates LLVM IR for `np.ones`. /// Generates LLVM IR for `np.ones`.
@ -175,7 +175,7 @@ pub fn gen_ndarray_ones<'ctx>(
let fill_value = ndarray_one_value(generator, ctx, ndarray.dtype); let fill_value = ndarray_one_value(generator, ctx, ndarray.dtype);
ndarray.fill(generator, ctx, fill_value); ndarray.fill(generator, ctx, fill_value);
Ok(ndarray.value.value.as_basic_value_enum()) Ok(ndarray.instance.value.as_basic_value_enum())
} }
/// Generates LLVM IR for `np.full`. /// Generates LLVM IR for `np.full`.
@ -203,7 +203,7 @@ pub fn gen_ndarray_full<'ctx>(
ndarray.fill(generator, ctx, fill_value); ndarray.fill(generator, ctx, fill_value);
Ok(ndarray.value.value.as_basic_value_enum()) Ok(ndarray.instance.value.as_basic_value_enum())
} }
/// Generates LLVM IR for `np.broadcast_to`. /// Generates LLVM IR for `np.broadcast_to`.
@ -252,7 +252,7 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
let broadcast_ndarray = let broadcast_ndarray =
in_ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape); in_ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape);
Ok(broadcast_ndarray.value.value.as_basic_value_enum()) Ok(broadcast_ndarray.instance.value.as_basic_value_enum())
} }
/// Generates LLVM IR for `np.reshape`. /// Generates LLVM IR for `np.reshape`.
@ -288,7 +288,7 @@ pub fn gen_ndarray_reshape<'ctx>(
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty); let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let reshaped_ndarray = in_ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape); let reshaped_ndarray = in_ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
Ok(reshaped_ndarray.value.value.as_basic_value_enum()) Ok(reshaped_ndarray.instance.value.as_basic_value_enum())
} }
/// Generates LLVM IR for `np.arange`. /// Generates LLVM IR for `np.arange`.
@ -324,7 +324,7 @@ pub fn gen_ndarray_arange<'ctx>(
// `ndarray.shape[0] = input` // `ndarray.shape[0] = input`
let zero = sizet_model.const_0(generator, ctx.ctx); let zero = sizet_model.const_0(generator, ctx.ctx);
ndarray ndarray
.value .instance
.get(generator, ctx, |f| f.shape, "shape") .get(generator, ctx, |f| f.shape, "shape")
.offset(generator, ctx, zero.value, "dim") .offset(generator, ctx, zero.value, "dim")
.store(ctx, input); .store(ctx, input);
@ -338,7 +338,7 @@ pub fn gen_ndarray_arange<'ctx>(
Ok(()) Ok(())
})?; })?;
Ok(ndarray.value.value.as_basic_value_enum()) Ok(ndarray.instance.value.as_basic_value_enum())
} }
/// Generates LLVM IR for `np.size`. /// Generates LLVM IR for `np.size`.
@ -386,8 +386,10 @@ pub fn gen_ndarray_shape<'ctx>(
for i in 0..ndarray.ndims { for i in 0..ndarray.ndims {
let i = sizet_model.constant(generator, ctx.ctx, i); let i = sizet_model.constant(generator, ctx.ctx, i);
let dim = let dim = ndarray
ndarray.value.get(generator, ctx, |f| f.shape, "").ix(generator, ctx, i.value, "dim"); .instance
.get(generator, ctx, |f| f.shape, "")
.ix(generator, ctx, i.value, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32)); items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
@ -424,8 +426,10 @@ pub fn gen_ndarray_strides<'ctx>(
for i in 0..ndarray.ndims { for i in 0..ndarray.ndims {
let i = sizet_model.constant(generator, ctx.ctx, i); let i = sizet_model.constant(generator, ctx.ctx, i);
let dim = let dim = ndarray
ndarray.value.get(generator, ctx, |f| f.strides, "").ix(generator, ctx, i.value, "dim"); .instance
.get(generator, ctx, |f| f.strides, "")
.ix(generator, ctx, i.value, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32)); items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
@ -471,7 +475,7 @@ pub fn gen_ndarray_transpose<'ctx>(
ndarray.transpose(generator, ctx, None) ndarray.transpose(generator, ctx, None)
}; };
Ok(transposed_ndarray.value.value.as_basic_value_enum()) Ok(transposed_ndarray.instance.value.as_basic_value_enum())
} }
pub fn gen_ndarray_array<'ctx>( pub fn gen_ndarray_array<'ctx>(
@ -515,5 +519,5 @@ pub fn gen_ndarray_array<'ctx>(
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims); let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
debug_assert!(ctx.unifier.unioned(ndarray.dtype, dtype)); // Sanity check on `dtype` debug_assert!(ctx.unifier.unioned(ndarray.dtype, dtype)); // Sanity check on `dtype`
Ok(ndarray.value.value.as_basic_value_enum()) Ok(ndarray.instance.value.as_basic_value_enum())
} }

View File

@ -10,7 +10,7 @@ use crate::{
pub struct ListObject<'ctx> { pub struct ListObject<'ctx> {
/// Typechecker type of the list items /// Typechecker type of the list items
pub item_type: Type, pub item_type: Type,
pub value: Ptr<'ctx, StructModel<List<AnyModel<'ctx>>>>, pub instance: Ptr<'ctx, StructModel<List<AnyModel<'ctx>>>>,
} }
impl<'ctx> ListObject<'ctx> { impl<'ctx> ListObject<'ctx> {
@ -38,7 +38,7 @@ impl<'ctx> ListObject<'ctx> {
// Create object // Create object
let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap(); let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap();
ListObject { item_type, value } ListObject { item_type, instance: value }
} }
/// Get the `items` field as an opaque pointer. /// Get the `items` field as an opaque pointer.
@ -47,7 +47,7 @@ impl<'ctx> ListObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> Ptr<'ctx, IntModel<Byte>> { ) -> Ptr<'ctx, IntModel<Byte>> {
self.value.get(generator, ctx, |f| f.items, "items").pointer_cast( self.instance.get(generator, ctx, |f| f.items, "items").pointer_cast(
generator, generator,
ctx, ctx,
IntModel(Byte), IntModel(Byte),
@ -71,7 +71,7 @@ impl<'ctx> ListObject<'ctx> {
opaque_list_ptr.set(ctx, |f| f.items, items); opaque_list_ptr.set(ctx, |f| f.items, items);
// Copy len // Copy len
let len = self.value.get(generator, ctx, |f| f.len, "len"); let len = self.instance.get(generator, ctx, |f| f.len, "len");
opaque_list_ptr.set(ctx, |f| f.len, len); opaque_list_ptr.set(ctx, |f| f.len, len);
opaque_list_ptr opaque_list_ptr

View File

@ -54,7 +54,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ndarray.create_data(generator, ctx); ndarray.create_data(generator, ctx);
// Copy all contents from the list. // Copy all contents from the list.
call_nac3_array_write_list_to_array(generator, ctx, list_value, ndarray.value); call_nac3_array_write_list_to_array(generator, ctx, list_value, ndarray.instance);
ndarray ndarray
} }
@ -85,13 +85,13 @@ impl<'ctx> NDArrayObject<'ctx> {
// Set data // Set data
let data = list.get_opaque_items_ptr(generator, ctx); let data = list.get_opaque_items_ptr(generator, ctx);
ndarray.value.set(ctx, |f| f.data, data); ndarray.instance.set(ctx, |f| f.data, data);
// Set shape // Set shape
// dim = list->len; // dim = list->len;
// shape[0] = dim; // shape[0] = dim;
let shape = ndarray.value.get(generator, ctx, |f| f.shape, "shape"); let shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
let dim = list.value.get(generator, ctx, |f| f.len, "dim"); let dim = list.instance.get(generator, ctx, |f| f.len, "dim");
shape.offset(generator, ctx, zero.value, "pdim").store(ctx, dim); shape.offset(generator, ctx, zero.value, "pdim").store(ctx, dim);
// Set strides, the `data` is contiguous // Set strides, the `data` is contiguous
@ -119,11 +119,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|_generator, _ctx| Ok(copy.value), |_generator, _ctx| Ok(copy.value),
|generator, ctx| { |generator, ctx| {
let ndarray = NDArrayObject::from_np_array_list_copy(generator, ctx, list); let ndarray = NDArrayObject::from_np_array_list_copy(generator, ctx, list);
Ok(Some(ndarray.value.value)) Ok(Some(ndarray.instance.value))
}, },
|generator, ctx| { |generator, ctx| {
let ndarray = NDArrayObject::from_np_array_list_try_no_copy(generator, ctx, list); let ndarray = NDArrayObject::from_np_array_list_try_no_copy(generator, ctx, list);
Ok(Some(ndarray.value.value)) Ok(Some(ndarray.instance.value))
}, },
) )
.unwrap() .unwrap()
@ -144,11 +144,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|_generator, _ctx| Ok(copy.value), |_generator, _ctx| Ok(copy.value),
|generator, ctx| { |generator, ctx| {
let ndarray = ndarray.make_clone(generator, ctx, "np_array_copied_ndarray"); // Force copy let ndarray = ndarray.make_clone(generator, ctx, "np_array_copied_ndarray"); // Force copy
Ok(Some(ndarray.value.value)) Ok(Some(ndarray.instance.value))
}, },
|_generator, _ctx| { |_generator, _ctx| {
// No need to copy. Return `ndarray` itself. // No need to copy. Return `ndarray` itself.
Ok(Some(ndarray.value.value)) Ok(Some(ndarray.instance.value))
}, },
) )
.unwrap() .unwrap()

View File

@ -49,7 +49,7 @@ impl<'ctx> NDArrayObject<'ctx> {
); );
broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape); broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape);
call_nac3_ndarray_broadcast_to(generator, ctx, self.value, broadcast_ndarray.value); call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance);
broadcast_ndarray broadcast_ndarray
} }
} }
@ -125,7 +125,9 @@ impl<'ctx> NDArrayObject<'ctx> {
let shape_entries = ndarrays let shape_entries = ndarrays
.iter() .iter()
.map(|ndarray| (ndarray.value.get(generator, ctx, |f| f.shape, "shape"), ndarray.ndims)) .map(|ndarray| {
(ndarray.instance.get(generator, ctx, |f| f.shape, "shape"), ndarray.ndims)
})
.collect_vec(); .collect_vec();
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape); broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape);

View File

@ -80,10 +80,10 @@ where
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) { let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
// Special handling for floats // Special handling for floats
let n = scalar.value.into_float_value(); let n = scalar.instance.into_float_value();
handle_float(generator, ctx, n) handle_float(generator, ctx, n)
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) { } else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
let n = scalar.value.into_int_value(); let n = scalar.instance.into_int_value();
if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() { if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap() ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap()
@ -95,7 +95,7 @@ where
}; };
assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
ScalarObject { value: result.into(), dtype: ret_int_dtype } ScalarObject { instance: result.into(), dtype: ret_int_dtype }
} }
impl<'ctx> ScalarObject<'ctx> { impl<'ctx> ScalarObject<'ctx> {
@ -104,7 +104,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// Panic if the type is wrong. /// Panic if the type is wrong.
pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> { pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.value.into_float_value() // self.value must be a FloatValue self.instance.into_float_value() // self.value must be a FloatValue
} else { } else {
panic!("not a float type") panic!("not a float type")
} }
@ -115,7 +115,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// Panic if the type is wrong. /// Panic if the type is wrong.
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) { if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) {
let value = self.value.into_int_value(); let value = self.instance.into_int_value();
debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check
value value
} else { } else {
@ -142,12 +142,12 @@ impl<'ctx> ScalarObject<'ctx> {
let common_ty = lhs.dtype; let common_ty = lhs.dtype;
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) { let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
let lhs = lhs.value.into_float_value(); let lhs = lhs.instance.into_float_value();
let rhs = rhs.value.into_float_value(); let rhs = rhs.instance.into_float_value();
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap() ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) { } else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
let lhs = lhs.value.into_int_value(); let lhs = lhs.instance.into_int_value();
let rhs = rhs.value.into_int_value(); let rhs = rhs.instance.into_int_value();
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap() ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
} else { } else {
unsupported_type(ctx, [lhs.dtype, rhs.dtype]); unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
@ -266,14 +266,14 @@ impl<'ctx> ScalarObject<'ctx> {
pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type? // TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
self.value.into_int_value() self.instance.into_int_value()
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.value.into_int_value(); let n = self.instance.into_int_value();
ctx.builder ctx.builder
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool") .build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
.unwrap() .unwrap()
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { } else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value(); let n = self.instance.into_float_value();
ctx.builder ctx.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool") .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
.unwrap() .unwrap()
@ -281,7 +281,7 @@ impl<'ctx> ScalarObject<'ctx> {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
}; };
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() } ScalarObject { dtype: ctx.primitives.bool, instance: result.as_basic_value_enum() }
} }
/// Invoke NAC3's builtin `float()`. /// Invoke NAC3's builtin `float()`.
@ -290,21 +290,21 @@ impl<'ctx> ScalarObject<'ctx> {
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.value.into_float_value() self.instance.into_float_value()
} else if ctx } else if ctx
.unifier .unifier
.unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat()) .unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat())
{ {
let n = self.value.into_int_value(); let n = self.instance.into_int_value();
ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap() ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap()
} else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) { } else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) {
let n = self.value.into_int_value(); let n = self.instance.into_int_value();
ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap() ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap()
} else { } else {
unsupported_type(ctx, [self.dtype]); unsupported_type(ctx, [self.dtype]);
}; };
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float } ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float }
} }
/// Invoke NAC3's builtin `round()`. /// Invoke NAC3's builtin `round()`.
@ -318,13 +318,13 @@ impl<'ctx> ScalarObject<'ctx> {
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type(); let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value(); let n = self.instance.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None); let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap() ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap()
} else { } else {
unsupported_type(ctx, [self.dtype, ret_int_dtype]) unsupported_type(ctx, [self.dtype, ret_int_dtype])
}; };
ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() } ScalarObject { dtype: ret_int_dtype, instance: result.as_basic_value_enum() }
} }
/// Invoke NAC3's builtin `np_round()`. /// Invoke NAC3's builtin `np_round()`.
@ -333,12 +333,12 @@ impl<'ctx> ScalarObject<'ctx> {
#[must_use] #[must_use]
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value(); let n = self.instance.into_float_value();
llvm_intrinsics::call_float_rint(ctx, n, None) llvm_intrinsics::call_float_rint(ctx, n, None)
} else { } else {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
}; };
ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() } ScalarObject { dtype: ctx.primitives.float, instance: result.as_basic_value_enum() }
} }
/// Invoke NAC3's builtin `min()` or `max()`. /// Invoke NAC3's builtin `min()` or `max()`.
@ -360,8 +360,8 @@ impl<'ctx> ScalarObject<'ctx> {
MinOrMax::Max => llvm_intrinsics::call_float_maxnum, MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
}; };
let result = let result =
function(ctx, a.value.into_float_value(), b.value.into_float_value(), None); function(ctx, a.instance.into_float_value(), b.instance.into_float_value(), None);
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float } ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any( } else if ctx.unifier.unioned_any(
common_dtype, common_dtype,
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(), [unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
@ -371,8 +371,9 @@ impl<'ctx> ScalarObject<'ctx> {
MinOrMax::Min => llvm_intrinsics::call_int_umin, MinOrMax::Min => llvm_intrinsics::call_int_umin,
MinOrMax::Max => llvm_intrinsics::call_int_umax, MinOrMax::Max => llvm_intrinsics::call_int_umax,
}; };
let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None); let result =
ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype } function(ctx, a.instance.into_int_value(), b.instance.into_int_value(), None);
ScalarObject { instance: result.as_basic_value_enum(), dtype: common_dtype }
} else { } else {
unsupported_type(ctx, [common_dtype]) unsupported_type(ctx, [common_dtype])
} }
@ -398,11 +399,11 @@ impl<'ctx> ScalarObject<'ctx> {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
}; };
let n = self.value.into_float_value(); let n = self.instance.into_float_value();
let n = function(ctx, n, None); let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap(); let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap();
ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() } ScalarObject { dtype: ret_int_dtype, instance: n.as_basic_value_enum() }
} else { } else {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
} }
@ -418,9 +419,9 @@ impl<'ctx> ScalarObject<'ctx> {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
}; };
let n = self.value.into_float_value(); let n = self.instance.into_float_value();
let n = function(ctx, n, None); let n = function(ctx, n, None);
ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() } ScalarObject { dtype: ctx.primitives.float, instance: n.as_basic_value_enum() }
} else { } else {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
} }
@ -430,16 +431,16 @@ impl<'ctx> ScalarObject<'ctx> {
#[must_use] #[must_use]
pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value(); let n = self.instance.into_float_value();
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs")); let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
ScalarObject { value: n.into(), dtype: ctx.primitives.float } ScalarObject { instance: n.into(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.value.into_int_value(); let n = self.instance.into_int_value();
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs")); let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
ScalarObject { value: n.into(), dtype: self.dtype } ScalarObject { instance: n.into(), dtype: self.dtype }
} else { } else {
unsupported_type(ctx, [self.dtype]) unsupported_type(ctx, [self.dtype])
} }
@ -481,7 +482,7 @@ impl<'ctx> NDArrayObject<'ctx> {
pextremum_index.store(ctx, zero); pextremum_index.store(ctx, zero);
let first_scalar = self.get_nth(generator, ctx, zero); let first_scalar = self.get_nth(generator, ctx, zero);
ctx.builder.build_store(pextremum, first_scalar.value).unwrap(); ctx.builder.build_store(pextremum, first_scalar.instance).unwrap();
// Find extremum // Find extremum
let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1 let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1
@ -494,7 +495,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let scalar = self.get_nth(generator, ctx, i); let scalar = self.get_nth(generator, ctx, i);
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap(); let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum }; let old_extremum = ScalarObject { dtype: self.dtype, instance: old_extremum };
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar); let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
@ -522,7 +523,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index"); let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap(); let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
let extremum = ScalarObject { dtype: self.dtype, value: extremum }; let extremum = ScalarObject { dtype: self.dtype, instance: extremum };
(extremum, extremum_index) (extremum, extremum_index)
} }

View File

@ -224,8 +224,8 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx, ctx,
num_indexes, num_indexes,
indexes, indexes,
self.value, self.instance,
dst_ndarray.value, dst_ndarray.instance,
); );
dst_ndarray dst_ndarray

View File

@ -137,7 +137,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
if let Some(scalars) = all_scalars { if let Some(scalars) = all_scalars {
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
let scalar = let scalar =
ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype }; ScalarObject { instance: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
Ok(ScalarOrNDArray::Scalar(scalar)) Ok(ScalarOrNDArray::Scalar(scalar))
} else { } else {
// Promote all input to ndarrays and map through them. // Promote all input to ndarrays and map through them.

View File

@ -40,7 +40,7 @@ use util::{call_memcpy_model, gen_for_model_auto};
pub struct NDArrayObject<'ctx> { pub struct NDArrayObject<'ctx> {
pub dtype: Type, pub dtype: Type,
pub ndims: u64, pub ndims: u64,
pub value: Ptr<'ctx, StructModel<NDArray>>, pub instance: Ptr<'ctx, StructModel<NDArray>>,
} }
impl<'ctx> NDArrayObject<'ctx> { impl<'ctx> NDArrayObject<'ctx> {
@ -67,7 +67,7 @@ impl<'ctx> NDArrayObject<'ctx> {
) -> Self { ) -> Self {
let pndarray_model = PtrModel(StructModel(NDArray)); let pndarray_model = PtrModel(StructModel(NDArray));
let value = pndarray_model.check_value(generator, ctx.ctx, value).unwrap(); let value = pndarray_model.check_value(generator, ctx.ctx, value).unwrap();
NDArrayObject { dtype, ndims, value } NDArrayObject { dtype, ndims, instance: value }
} }
/// Create a [`SimpleNDArray`] from the contents of this ndarray. /// Create a [`SimpleNDArray`] from the contents of this ndarray.
@ -106,7 +106,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let ndims = self.get_ndims(generator, ctx.ctx); let ndims = self.get_ndims(generator, ctx.ctx);
result.set(ctx, |f| f.ndims, ndims); result.set(ctx, |f| f.ndims, ndims);
let shape = self.value.get(generator, ctx, |f| f.shape, "shape"); let shape = self.instance.get(generator, ctx, |f| f.shape, "shape");
result.set(ctx, |f| f.shape, shape); result.set(ctx, |f| f.shape, shape);
// Set data, but we do things differently if this ndarray is contiguous. // Set data, but we do things differently if this ndarray is contiguous.
@ -114,7 +114,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx.builder.build_conditional_branch(is_contiguous.value, then_bb, else_bb).unwrap(); ctx.builder.build_conditional_branch(is_contiguous.value, then_bb, else_bb).unwrap();
// Inserting into then_bb; This ndarray is contiguous. // Inserting into then_bb; This ndarray is contiguous.
let data = self.value.get(generator, ctx, |f| f.data, ""); let data = self.instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, ""); let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data); result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap(); ctx.builder.build_unconditional_branch(end_bb).unwrap();
@ -123,7 +123,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// TODO: Reimplement this? This method does give us the contiguous `data`, but // TODO: Reimplement this? This method does give us the contiguous `data`, but
// this creates a few extra bytes of useless information because an entire NDArray // this creates a few extra bytes of useless information because an entire NDArray
// is allocated. Though this is super convenient. // is allocated. Though this is super convenient.
let data = self.make_clone(generator, ctx, "").value.get(generator, ctx, |f| f.data, ""); let data = self.make_clone(generator, ctx, "").instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, ""); let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data); result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap(); ctx.builder.build_unconditional_branch(end_bb).unwrap();
@ -166,10 +166,10 @@ impl<'ctx> NDArrayObject<'ctx> {
let data = simple_ndarray let data = simple_ndarray
.get(generator, ctx, |f| f.data, "") .get(generator, ctx, |f| f.data, "")
.pointer_cast(generator, ctx, byte_model, "data"); .pointer_cast(generator, ctx, byte_model, "data");
ndarray.value.set(ctx, |f| f.data, data); ndarray.instance.set(ctx, |f| f.data, data);
let shape = simple_ndarray.get(generator, ctx, |f| f.shape, "shape"); let shape = simple_ndarray.get(generator, ctx, |f| f.shape, "shape");
ndarray.value.set(ctx, |f| f.shape, shape); ndarray.instance.set(ctx, |f| f.shape, shape);
// Set strides. We know `data` is contiguous. // Set strides. We know `data` is contiguous.
ndarray.update_strides_by_shape(generator, ctx); ndarray.update_strides_by_shape(generator, ctx);
@ -183,7 +183,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> { ) -> Int<'ctx, SizeT> {
call_nac3_ndarray_size(generator, ctx, self.value) call_nac3_ndarray_size(generator, ctx, self.instance)
} }
/// Get the `ndarray.nbytes` of this ndarray. /// Get the `ndarray.nbytes` of this ndarray.
@ -192,7 +192,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> { ) -> Int<'ctx, SizeT> {
call_nac3_ndarray_nbytes(generator, ctx, self.value) call_nac3_ndarray_nbytes(generator, ctx, self.instance)
} }
/// Get the `len()` of this ndarray. /// Get the `len()` of this ndarray.
@ -201,7 +201,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> { ) -> Int<'ctx, SizeT> {
call_nac3_ndarray_len(generator, ctx, self.value) call_nac3_ndarray_len(generator, ctx, self.instance)
} }
/// Check if this ndarray is C-contiguous. /// Check if this ndarray is C-contiguous.
@ -212,7 +212,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> { ) -> Int<'ctx, Bool> {
call_nac3_ndarray_is_c_contiguous(generator, ctx, self.value) call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance)
} }
/// Get the pointer to the n-th (0-based) element. /// Get the pointer to the n-th (0-based) element.
@ -227,7 +227,7 @@ impl<'ctx> NDArrayObject<'ctx> {
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.dtype); let elem_ty = ctx.get_llvm_type(generator, self.dtype);
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.value, nth); let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, nth);
ctx.builder ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name) .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name)
.unwrap() .unwrap()
@ -242,7 +242,7 @@ impl<'ctx> NDArrayObject<'ctx> {
) -> ScalarObject<'ctx> { ) -> ScalarObject<'ctx> {
let p = self.get_nth_pointer(generator, ctx, nth, "value"); let p = self.get_nth_pointer(generator, ctx, nth, "value");
let value = ctx.builder.build_load(p, "value").unwrap(); let value = ctx.builder.build_load(p, "value").unwrap();
ScalarObject { dtype: self.dtype, value } ScalarObject { dtype: self.dtype, instance: value }
} }
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
@ -253,7 +253,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) { ) {
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.value); call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
} }
/// Copy data from another ndarray. /// Copy data from another ndarray.
@ -269,7 +269,7 @@ impl<'ctx> NDArrayObject<'ctx> {
src: NDArrayObject<'ctx>, src: NDArrayObject<'ctx>,
) { ) {
assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match"); assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match");
call_nac3_ndarray_copy_data(generator, ctx, src.value, self.value); call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance);
} }
/// Allocate an ndarray on the stack given its `ndims` and `dtype`. /// Allocate an ndarray on the stack given its `ndims` and `dtype`.
@ -312,7 +312,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let strides = sizet_model.array_alloca(generator, ctx, ndims_val.value, "alloca_strides"); let strides = sizet_model.array_alloca(generator, ctx, ndims_val.value, "alloca_strides");
pndarray.set(ctx, |f| f.strides, strides); pndarray.set(ctx, |f| f.strides, strides);
NDArrayObject { dtype, ndims, value: pndarray } NDArrayObject { dtype, ndims, instance: pndarray }
} }
/// Convenience function. /// Convenience function.
@ -342,7 +342,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let clone = let clone =
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, self.ndims, name); NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, self.ndims, name);
let shape = self.value.gep(ctx, |f| f.shape).load(generator, ctx, "shape"); let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx, "shape");
clone.copy_shape_from_array(generator, ctx, shape); clone.copy_shape_from_array(generator, ctx, shape);
clone.create_data(generator, ctx); clone.create_data(generator, ctx);
clone.copy_data_from(generator, ctx, *self); clone.copy_data_from(generator, ctx, *self);
@ -412,7 +412,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let nbytes = self.nbytes(generator, ctx); let nbytes = self.nbytes(generator, ctx);
let data = byte_model.array_alloca(generator, ctx, nbytes.value, "data"); let data = byte_model.array_alloca(generator, ctx, nbytes.value, "data");
self.value.set(ctx, |f| f.data, data); self.instance.set(ctx, |f| f.data, data);
self.update_strides_by_shape(generator, ctx); self.update_strides_by_shape(generator, ctx);
} }
@ -424,7 +424,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
src_shape: Ptr<'ctx, IntModel<SizeT>>, src_shape: Ptr<'ctx, IntModel<SizeT>>,
) { ) {
let dst_shape = self.value.get(generator, ctx, |f| f.shape, "dst_shape"); let dst_shape = self.instance.get(generator, ctx, |f| f.shape, "dst_shape");
let num_items = self.get_ndims(generator, ctx.ctx).value; let num_items = self.get_ndims(generator, ctx.ctx).value;
call_memcpy_model(generator, ctx, dst_shape, src_shape, num_items); call_memcpy_model(generator, ctx, dst_shape, src_shape, num_items);
} }
@ -438,7 +438,7 @@ impl<'ctx> NDArrayObject<'ctx> {
src_ndarray: NDArrayObject<'ctx>, src_ndarray: NDArrayObject<'ctx>,
) { ) {
assert_eq!(self.ndims, src_ndarray.ndims); assert_eq!(self.ndims, src_ndarray.ndims);
let src_shape = src_ndarray.value.get(generator, ctx, |f| f.shape, "src_shape"); let src_shape = src_ndarray.instance.get(generator, ctx, |f| f.shape, "src_shape");
self.copy_shape_from_array(generator, ctx, src_shape); self.copy_shape_from_array(generator, ctx, src_shape);
} }
@ -449,7 +449,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
src_strides: Ptr<'ctx, IntModel<SizeT>>, src_strides: Ptr<'ctx, IntModel<SizeT>>,
) { ) {
let dst_strides = self.value.get(generator, ctx, |f| f.strides, "dst_strides"); let dst_strides = self.instance.get(generator, ctx, |f| f.strides, "dst_strides");
let num_items = self.get_ndims(generator, ctx.ctx).value; let num_items = self.get_ndims(generator, ctx.ctx).value;
call_memcpy_model(generator, ctx, dst_strides, src_strides, num_items); call_memcpy_model(generator, ctx, dst_strides, src_strides, num_items);
} }
@ -463,7 +463,7 @@ impl<'ctx> NDArrayObject<'ctx> {
src_ndarray: NDArrayObject<'ctx>, src_ndarray: NDArrayObject<'ctx>,
) { ) {
assert_eq!(self.ndims, src_ndarray.ndims); assert_eq!(self.ndims, src_ndarray.ndims);
let src_strides = src_ndarray.value.get(generator, ctx, |f| f.strides, "src_strides"); let src_strides = src_ndarray.instance.get(generator, ctx, |f| f.strides, "src_strides");
self.copy_strides_from_array(generator, ctx, src_strides); self.copy_strides_from_array(generator, ctx, src_strides);
} }
@ -518,7 +518,7 @@ impl<'ctx> NDArrayObject<'ctx> {
{ {
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| { self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
let value = ctx.builder.build_load(p, "value").unwrap(); let value = ctx.builder.build_load(p, "value").unwrap();
let scalar = ScalarObject { dtype: self.dtype, value }; let scalar = ScalarObject { dtype: self.dtype, instance: value };
body(generator, ctx, hooks, i, scalar) body(generator, ctx, hooks, i, scalar)
}) })
} }
@ -606,7 +606,11 @@ impl<'ctx> NDArrayObject<'ctx> {
// Inserting into then_bb: reshape is possible without copying // Inserting into then_bb: reshape is possible without copying
ctx.builder.position_at_end(then_bb); ctx.builder.position_at_end(then_bb);
dst_ndarray.update_strides_by_shape(generator, ctx); dst_ndarray.update_strides_by_shape(generator, ctx);
dst_ndarray.value.set(ctx, |f| f.data, self.value.get(generator, ctx, |f| f.data, "data")); dst_ndarray.instance.set(
ctx,
|f| f.data,
self.instance.get(generator, ctx, |f| f.data, "data"),
);
ctx.builder.build_unconditional_branch(end_bb).unwrap(); ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Inserting into else_bb: reshape is impossible without copying // Inserting into else_bb: reshape is impossible without copying
@ -672,8 +676,8 @@ impl<'ctx> NDArrayObject<'ctx> {
call_nac3_ndarray_transpose( call_nac3_ndarray_transpose(
generator, generator,
ctx, ctx,
self.value, self.instance,
transposed_ndarray.value, transposed_ndarray.instance,
num_axes, num_axes,
axes, axes,
); );
@ -694,7 +698,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let sizet_model = IntModel(SizeT); let sizet_model = IntModel(SizeT);
let ndarray_ndims = self.get_ndims(generator, ctx.ctx); let ndarray_ndims = self.get_ndims(generator, ctx.ctx);
let ndarray_shape = self.value.get(generator, ctx, |f| f.shape, "shape"); let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape, "shape");
let output_ndims = sizet_model.constant(generator, ctx.ctx, out_ndims); let output_ndims = sizet_model.constant(generator, ctx.ctx, out_ndims);
let output_shape = out_shape; let output_shape = out_shape;

View File

@ -30,9 +30,9 @@ impl<'ctx> NDArrayObject<'ctx> {
let final_ndims_int = max(a.ndims, b.ndims); let final_ndims_int = max(a.ndims, b.ndims);
let a_ndims = a.get_ndims(generator, ctx.ctx); let a_ndims = a.get_ndims(generator, ctx.ctx);
let a_shape = a.value.get(generator, ctx, |f| f.shape, "a_shape"); let a_shape = a.instance.get(generator, ctx, |f| f.shape, "a_shape");
let b_ndims = b.get_ndims(generator, ctx.ctx); let b_ndims = b.get_ndims(generator, ctx.ctx);
let b_shape = b.value.get(generator, ctx, |f| f.shape, "b_shape"); let b_shape = b.instance.get(generator, ctx, |f| f.shape, "b_shape");
let final_ndims = sizet_model.constant(generator, ctx.ctx, final_ndims_int); let final_ndims = sizet_model.constant(generator, ctx.ctx, final_ndims_int);
let new_a_shape = let new_a_shape =
sizet_model.array_alloca(generator, ctx, final_ndims.value, "new_a_shape"); sizet_model.array_alloca(generator, ctx, final_ndims.value, "new_a_shape");
@ -68,9 +68,9 @@ impl<'ctx> NDArrayObject<'ctx> {
call_nac3_ndarray_float64_matmul_at_least_2d( call_nac3_ndarray_float64_matmul_at_least_2d(
generator, generator,
ctx, ctx,
new_a.value, new_a.instance,
new_b.value, new_b.instance,
dst.value, dst.instance,
); );
dst dst
@ -147,7 +147,7 @@ impl<'ctx> NDArrayObject<'ctx> {
} }
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
// TODO: It is possible to check the shapes before computing the matmul to save resources. // TODO: It is possible to check the shapes before computing the matmul to save resources.
let result_shape = result.value.get(generator, ctx, |f| f.shape, "result_shape"); let result_shape = result.instance.get(generator, ctx, |f| f.shape, "result_shape");
out_ndarray.check_can_be_written_by_out(generator, ctx, result.ndims, result_shape); out_ndarray.check_can_be_written_by_out(generator, ctx, result.ndims, result_shape);
// TODO: We can just set `out_ndarray.data` to `result.data`. Should we? // TODO: We can just set `out_ndarray.data` to `result.data`. Should we?

View File

@ -15,7 +15,7 @@ use super::NDArrayObject;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct ScalarObject<'ctx> { pub struct ScalarObject<'ctx> {
pub dtype: Type, pub dtype: Type,
pub value: BasicValueEnum<'ctx>, pub instance: BasicValueEnum<'ctx>,
} }
impl<'ctx> ScalarObject<'ctx> { impl<'ctx> ScalarObject<'ctx> {
@ -31,13 +31,13 @@ impl<'ctx> ScalarObject<'ctx> {
let pbyte_model = PtrModel(IntModel(Byte)); let pbyte_model = PtrModel(IntModel(Byte));
// We have to put the value on the stack to get a data pointer. // We have to put the value on the stack to get a data pointer.
let data = ctx.builder.build_alloca(self.value.get_type(), "as_ndarray_scalar").unwrap(); let data = ctx.builder.build_alloca(self.instance.get_type(), "as_ndarray_scalar").unwrap();
ctx.builder.build_store(data, self.value).unwrap(); ctx.builder.build_store(data, self.instance).unwrap();
let data = pbyte_model.pointer_cast(generator, ctx, data, "data"); let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
let ndarray = let ndarray =
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, 0, "scalar_ndarray"); NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, 0, "scalar_ndarray");
ndarray.value.set(ctx, |f| f.data, data); ndarray.instance.set(ctx, |f| f.data, data);
ndarray ndarray
} }
} }
@ -54,8 +54,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
#[must_use] #[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self { match self {
ScalarOrNDArray::Scalar(scalar) => scalar.value, ScalarOrNDArray::Scalar(scalar) => scalar.instance,
ScalarOrNDArray::NDArray(ndarray) => ndarray.value.value.as_basic_value_enum(), ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
} }
} }
@ -136,7 +136,7 @@ pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
ScalarOrNDArray::NDArray(ndarray) ScalarOrNDArray::NDArray(ndarray)
} }
_ => { _ => {
let scalar = ScalarObject { dtype: input_ty, value: input }; let scalar = ScalarObject { dtype: input_ty, instance: input };
ScalarOrNDArray::Scalar(scalar) ScalarOrNDArray::Scalar(scalar)
} }
} }

View File

@ -39,14 +39,14 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
let input_sequence = let input_sequence =
ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty); ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty);
let len = input_sequence.value.gep(ctx, |f| f.len).load(generator, ctx, "len"); let len = input_sequence.instance.gep(ctx, |f| f.len).load(generator, ctx, "len");
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence"); let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
gen_for_model_auto(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| { gen_for_model_auto(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| {
// Load the i-th int32 in the input sequence // Load the i-th int32 in the input sequence
let int = input_sequence let int = input_sequence
.value .instance
.get(generator, ctx, |f| f.items, "int") .get(generator, ctx, |f| f.items, "int")
.ix(generator, ctx, i.value, "int") .ix(generator, ctx, i.value, "int")
.value .value

View File

@ -1105,7 +1105,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunBool => scalar.cast_to_bool(ctx), PrimDef::FunBool => scalar.cast_to_bool(ctx),
_ => unreachable!(), _ => unreachable!(),
}; };
Ok(result.value) Ok(result.instance)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1166,7 +1166,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx, ctx,
ret_int_dtype, ret_int_dtype,
|generator, ctx, _i, scalar| { |generator, ctx, _i, scalar| {
Ok(scalar.round(generator, ctx, ret_int_dtype).value) Ok(scalar.round(generator, ctx, ret_int_dtype).instance)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1231,7 +1231,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx, ctx,
int_sized, int_sized,
|generator, ctx, _i, scalar| { |generator, ctx, _i, scalar| {
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value) Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).instance)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1631,7 +1631,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx.primitives.float, ctx.primitives.float,
move |_generator, ctx, _i, scalar| { move |_generator, ctx, _i, scalar| {
let result = scalar.np_floor_or_ceil(ctx, kind); let result = scalar.np_floor_or_ceil(ctx, kind);
Ok(result.value) Ok(result.instance)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1659,7 +1659,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx.primitives.float, ctx.primitives.float,
|_generator, ctx, _i, scalar| { |_generator, ctx, _i, scalar| {
let result = scalar.np_round(ctx); let result = scalar.np_round(ctx);
Ok(result.value) Ok(result.instance)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1746,10 +1746,10 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(), _ => unreachable!(),
}; };
let m = ScalarObject { dtype: m_ty, value: m_val }; let m = ScalarObject { dtype: m_ty, instance: m_val };
let n = ScalarObject { dtype: n_ty, value: n_val }; let n = ScalarObject { dtype: n_ty, instance: n_val };
let result = ScalarObject::min_or_max(ctx, kind, m, n); let result = ScalarObject::min_or_max(ctx, kind, m, n);
Ok(Some(result.value)) Ok(Some(result.instance))
}, },
)))), )))),
loc: None, loc: None,
@ -1802,10 +1802,10 @@ impl<'a> BuiltinBuilder<'a> {
.value .value
.as_basic_value_enum(), .as_basic_value_enum(),
PrimDef::FunNpMin => { PrimDef::FunNpMin => {
a.min_or_max(generator, ctx, MinOrMax::Min).value.as_basic_value_enum() a.min_or_max(generator, ctx, MinOrMax::Min).instance.as_basic_value_enum()
} }
PrimDef::FunNpMax => { PrimDef::FunNpMax => {
a.min_or_max(generator, ctx, MinOrMax::Max).value.as_basic_value_enum() a.min_or_max(generator, ctx, MinOrMax::Max).instance.as_basic_value_enum()
} }
_ => unreachable!(), _ => unreachable!(),
}; };
@ -1871,7 +1871,7 @@ impl<'a> BuiltinBuilder<'a> {
let x2 = scalars[1]; let x2 = scalars[1];
let result = ScalarObject::min_or_max(ctx, kind, x1, x2); let result = ScalarObject::min_or_max(ctx, kind, x1, x2);
Ok(result.value) Ok(result.instance)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1912,7 +1912,7 @@ impl<'a> BuiltinBuilder<'a> {
generator, generator,
ctx, ctx,
num_ty.ty, num_ty.ty,
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value), |_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).instance),
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
}, },