forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: rename .value to .instance in *Object

This commit is contained in:
lyken 2024-08-14 11:34:42 +08:00
parent 18dcbf5bbc
commit febe78b6a4
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
14 changed files with 136 additions and 124 deletions

View File

@ -1570,16 +1570,16 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
gen_binop_expr_with_values(
generator,
ctx,
(&Some(left.dtype), left.value),
(&Some(left.dtype), left.instance),
op,
(&Some(right.dtype), right.value),
(&Some(right.dtype), right.instance),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(ctx, generator, common_dtype)
},
)?;
Ok(Some(ValueEnum::Dynamic(result.value.value.as_basic_value_enum())))
Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
}
} else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());

View File

@ -125,7 +125,7 @@ pub fn gen_ndarray_empty<'ctx>(
let ndarray_ty = fun.0.ret;
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
Ok(ndarray.value.value.as_basic_value_enum())
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.zero`.
@ -150,7 +150,7 @@ pub fn gen_ndarray_zeros<'ctx>(
let fill_value = ndarray_zero_value(generator, ctx, ndarray.dtype);
ndarray.fill(generator, ctx, fill_value);
Ok(ndarray.value.value.as_basic_value_enum())
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.ones`.
@ -175,7 +175,7 @@ pub fn gen_ndarray_ones<'ctx>(
let fill_value = ndarray_one_value(generator, ctx, ndarray.dtype);
ndarray.fill(generator, ctx, fill_value);
Ok(ndarray.value.value.as_basic_value_enum())
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.full`.
@ -203,7 +203,7 @@ pub fn gen_ndarray_full<'ctx>(
ndarray.fill(generator, ctx, fill_value);
Ok(ndarray.value.value.as_basic_value_enum())
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.broadcast_to`.
@ -252,7 +252,7 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
let broadcast_ndarray =
in_ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape);
Ok(broadcast_ndarray.value.value.as_basic_value_enum())
Ok(broadcast_ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.reshape`.
@ -288,7 +288,7 @@ pub fn gen_ndarray_reshape<'ctx>(
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
let reshaped_ndarray = in_ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
Ok(reshaped_ndarray.value.value.as_basic_value_enum())
Ok(reshaped_ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.arange`.
@ -324,7 +324,7 @@ pub fn gen_ndarray_arange<'ctx>(
// `ndarray.shape[0] = input`
let zero = sizet_model.const_0(generator, ctx.ctx);
ndarray
.value
.instance
.get(generator, ctx, |f| f.shape, "shape")
.offset(generator, ctx, zero.value, "dim")
.store(ctx, input);
@ -338,7 +338,7 @@ pub fn gen_ndarray_arange<'ctx>(
Ok(())
})?;
Ok(ndarray.value.value.as_basic_value_enum())
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.size`.
@ -386,8 +386,10 @@ pub fn gen_ndarray_shape<'ctx>(
for i in 0..ndarray.ndims {
let i = sizet_model.constant(generator, ctx.ctx, i);
let dim =
ndarray.value.get(generator, ctx, |f| f.shape, "").ix(generator, ctx, i.value, "dim");
let dim = ndarray
.instance
.get(generator, ctx, |f| f.shape, "")
.ix(generator, ctx, i.value, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
@ -424,8 +426,10 @@ pub fn gen_ndarray_strides<'ctx>(
for i in 0..ndarray.ndims {
let i = sizet_model.constant(generator, ctx.ctx, i);
let dim =
ndarray.value.get(generator, ctx, |f| f.strides, "").ix(generator, ctx, i.value, "dim");
let dim = ndarray
.instance
.get(generator, ctx, |f| f.strides, "")
.ix(generator, ctx, i.value, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
@ -471,7 +475,7 @@ pub fn gen_ndarray_transpose<'ctx>(
ndarray.transpose(generator, ctx, None)
};
Ok(transposed_ndarray.value.value.as_basic_value_enum())
Ok(transposed_ndarray.instance.value.as_basic_value_enum())
}
pub fn gen_ndarray_array<'ctx>(
@ -515,5 +519,5 @@ pub fn gen_ndarray_array<'ctx>(
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
debug_assert!(ctx.unifier.unioned(ndarray.dtype, dtype)); // Sanity check on `dtype`
Ok(ndarray.value.value.as_basic_value_enum())
Ok(ndarray.instance.value.as_basic_value_enum())
}

View File

@ -10,7 +10,7 @@ use crate::{
pub struct ListObject<'ctx> {
/// Typechecker type of the list items
pub item_type: Type,
pub value: Ptr<'ctx, StructModel<List<AnyModel<'ctx>>>>,
pub instance: Ptr<'ctx, StructModel<List<AnyModel<'ctx>>>>,
}
impl<'ctx> ListObject<'ctx> {
@ -38,7 +38,7 @@ impl<'ctx> ListObject<'ctx> {
// Create object
let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap();
ListObject { item_type, value }
ListObject { item_type, instance: value }
}
/// Get the `items` field as an opaque pointer.
@ -47,7 +47,7 @@ impl<'ctx> ListObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Ptr<'ctx, IntModel<Byte>> {
self.value.get(generator, ctx, |f| f.items, "items").pointer_cast(
self.instance.get(generator, ctx, |f| f.items, "items").pointer_cast(
generator,
ctx,
IntModel(Byte),
@ -71,7 +71,7 @@ impl<'ctx> ListObject<'ctx> {
opaque_list_ptr.set(ctx, |f| f.items, items);
// Copy len
let len = self.value.get(generator, ctx, |f| f.len, "len");
let len = self.instance.get(generator, ctx, |f| f.len, "len");
opaque_list_ptr.set(ctx, |f| f.len, len);
opaque_list_ptr

View File

@ -54,7 +54,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ndarray.create_data(generator, ctx);
// Copy all contents from the list.
call_nac3_array_write_list_to_array(generator, ctx, list_value, ndarray.value);
call_nac3_array_write_list_to_array(generator, ctx, list_value, ndarray.instance);
ndarray
}
@ -85,13 +85,13 @@ impl<'ctx> NDArrayObject<'ctx> {
// Set data
let data = list.get_opaque_items_ptr(generator, ctx);
ndarray.value.set(ctx, |f| f.data, data);
ndarray.instance.set(ctx, |f| f.data, data);
// Set shape
// dim = list->len;
// shape[0] = dim;
let shape = ndarray.value.get(generator, ctx, |f| f.shape, "shape");
let dim = list.value.get(generator, ctx, |f| f.len, "dim");
let shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
let dim = list.instance.get(generator, ctx, |f| f.len, "dim");
shape.offset(generator, ctx, zero.value, "pdim").store(ctx, dim);
// Set strides, the `data` is contiguous
@ -119,11 +119,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = NDArrayObject::from_np_array_list_copy(generator, ctx, list);
Ok(Some(ndarray.value.value))
Ok(Some(ndarray.instance.value))
},
|generator, ctx| {
let ndarray = NDArrayObject::from_np_array_list_try_no_copy(generator, ctx, list);
Ok(Some(ndarray.value.value))
Ok(Some(ndarray.instance.value))
},
)
.unwrap()
@ -144,11 +144,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = ndarray.make_clone(generator, ctx, "np_array_copied_ndarray"); // Force copy
Ok(Some(ndarray.value.value))
Ok(Some(ndarray.instance.value))
},
|_generator, _ctx| {
// No need to copy. Return `ndarray` itself.
Ok(Some(ndarray.value.value))
Ok(Some(ndarray.instance.value))
},
)
.unwrap()

View File

@ -49,7 +49,7 @@ impl<'ctx> NDArrayObject<'ctx> {
);
broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape);
call_nac3_ndarray_broadcast_to(generator, ctx, self.value, broadcast_ndarray.value);
call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance);
broadcast_ndarray
}
}
@ -125,7 +125,9 @@ impl<'ctx> NDArrayObject<'ctx> {
let shape_entries = ndarrays
.iter()
.map(|ndarray| (ndarray.value.get(generator, ctx, |f| f.shape, "shape"), ndarray.ndims))
.map(|ndarray| {
(ndarray.instance.get(generator, ctx, |f| f.shape, "shape"), ndarray.ndims)
})
.collect_vec();
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape);

View File

@ -80,10 +80,10 @@ where
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
// Special handling for floats
let n = scalar.value.into_float_value();
let n = scalar.instance.into_float_value();
handle_float(generator, ctx, n)
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
let n = scalar.value.into_int_value();
let n = scalar.instance.into_int_value();
if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap()
@ -95,7 +95,7 @@ where
};
assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
ScalarObject { value: result.into(), dtype: ret_int_dtype }
ScalarObject { instance: result.into(), dtype: ret_int_dtype }
}
impl<'ctx> ScalarObject<'ctx> {
@ -104,7 +104,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// Panic if the type is wrong.
pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.value.into_float_value() // self.value must be a FloatValue
self.instance.into_float_value() // self.value must be a FloatValue
} else {
panic!("not a float type")
}
@ -115,7 +115,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// Panic if the type is wrong.
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) {
let value = self.value.into_int_value();
let value = self.instance.into_int_value();
debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check
value
} else {
@ -142,12 +142,12 @@ impl<'ctx> ScalarObject<'ctx> {
let common_ty = lhs.dtype;
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
let lhs = lhs.value.into_float_value();
let rhs = rhs.value.into_float_value();
let lhs = lhs.instance.into_float_value();
let rhs = rhs.instance.into_float_value();
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
let lhs = lhs.value.into_int_value();
let rhs = rhs.value.into_int_value();
let lhs = lhs.instance.into_int_value();
let rhs = rhs.instance.into_int_value();
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
} else {
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
@ -266,14 +266,14 @@ impl<'ctx> ScalarObject<'ctx> {
pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
self.value.into_int_value()
self.instance.into_int_value()
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.value.into_int_value();
let n = self.instance.into_int_value();
ctx.builder
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
.unwrap()
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = self.instance.into_float_value();
ctx.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
.unwrap()
@ -281,7 +281,7 @@ impl<'ctx> ScalarObject<'ctx> {
unsupported_type(ctx, [self.dtype])
};
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
ScalarObject { dtype: ctx.primitives.bool, instance: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `float()`.
@ -290,21 +290,21 @@ impl<'ctx> ScalarObject<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.value.into_float_value()
self.instance.into_float_value()
} else if ctx
.unifier
.unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat())
{
let n = self.value.into_int_value();
let n = self.instance.into_int_value();
ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap()
} else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) {
let n = self.value.into_int_value();
let n = self.instance.into_int_value();
ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap()
} else {
unsupported_type(ctx, [self.dtype]);
};
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float }
}
/// Invoke NAC3's builtin `round()`.
@ -318,13 +318,13 @@ impl<'ctx> ScalarObject<'ctx> {
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = self.instance.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap()
} else {
unsupported_type(ctx, [self.dtype, ret_int_dtype])
};
ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() }
ScalarObject { dtype: ret_int_dtype, instance: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `np_round()`.
@ -333,12 +333,12 @@ impl<'ctx> ScalarObject<'ctx> {
#[must_use]
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = self.instance.into_float_value();
llvm_intrinsics::call_float_rint(ctx, n, None)
} else {
unsupported_type(ctx, [self.dtype])
};
ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() }
ScalarObject { dtype: ctx.primitives.float, instance: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `min()` or `max()`.
@ -360,8 +360,8 @@ impl<'ctx> ScalarObject<'ctx> {
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
};
let result =
function(ctx, a.value.into_float_value(), b.value.into_float_value(), None);
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
function(ctx, a.instance.into_float_value(), b.instance.into_float_value(), None);
ScalarObject { instance: result.as_basic_value_enum(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any(
common_dtype,
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
@ -371,8 +371,9 @@ impl<'ctx> ScalarObject<'ctx> {
MinOrMax::Min => llvm_intrinsics::call_int_umin,
MinOrMax::Max => llvm_intrinsics::call_int_umax,
};
let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None);
ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype }
let result =
function(ctx, a.instance.into_int_value(), b.instance.into_int_value(), None);
ScalarObject { instance: result.as_basic_value_enum(), dtype: common_dtype }
} else {
unsupported_type(ctx, [common_dtype])
}
@ -398,11 +399,11 @@ impl<'ctx> ScalarObject<'ctx> {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.value.into_float_value();
let n = self.instance.into_float_value();
let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap();
ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() }
ScalarObject { dtype: ret_int_dtype, instance: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.dtype])
}
@ -418,9 +419,9 @@ impl<'ctx> ScalarObject<'ctx> {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.value.into_float_value();
let n = self.instance.into_float_value();
let n = function(ctx, n, None);
ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() }
ScalarObject { dtype: ctx.primitives.float, instance: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.dtype])
}
@ -430,16 +431,16 @@ impl<'ctx> ScalarObject<'ctx> {
#[must_use]
pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = self.instance.into_float_value();
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
ScalarObject { value: n.into(), dtype: ctx.primitives.float }
ScalarObject { instance: n.into(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.value.into_int_value();
let n = self.instance.into_int_value();
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
ScalarObject { value: n.into(), dtype: self.dtype }
ScalarObject { instance: n.into(), dtype: self.dtype }
} else {
unsupported_type(ctx, [self.dtype])
}
@ -481,7 +482,7 @@ impl<'ctx> NDArrayObject<'ctx> {
pextremum_index.store(ctx, zero);
let first_scalar = self.get_nth(generator, ctx, zero);
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
ctx.builder.build_store(pextremum, first_scalar.instance).unwrap();
// Find extremum
let start = sizet_model.const_1(generator, ctx.ctx); // Start on 1
@ -494,7 +495,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let scalar = self.get_nth(generator, ctx, i);
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
let old_extremum = ScalarObject { dtype: self.dtype, instance: old_extremum };
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
@ -522,7 +523,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
let extremum = ScalarObject { dtype: self.dtype, value: extremum };
let extremum = ScalarObject { dtype: self.dtype, instance: extremum };
(extremum, extremum_index)
}

View File

@ -224,8 +224,8 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx,
num_indexes,
indexes,
self.value,
dst_ndarray.value,
self.instance,
dst_ndarray.instance,
);
dst_ndarray

View File

@ -137,7 +137,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
if let Some(scalars) = all_scalars {
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
let scalar =
ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
ScalarObject { instance: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
Ok(ScalarOrNDArray::Scalar(scalar))
} else {
// Promote all input to ndarrays and map through them.

View File

@ -40,7 +40,7 @@ use util::{call_memcpy_model, gen_for_model_auto};
pub struct NDArrayObject<'ctx> {
pub dtype: Type,
pub ndims: u64,
pub value: Ptr<'ctx, StructModel<NDArray>>,
pub instance: Ptr<'ctx, StructModel<NDArray>>,
}
impl<'ctx> NDArrayObject<'ctx> {
@ -67,7 +67,7 @@ impl<'ctx> NDArrayObject<'ctx> {
) -> Self {
let pndarray_model = PtrModel(StructModel(NDArray));
let value = pndarray_model.check_value(generator, ctx.ctx, value).unwrap();
NDArrayObject { dtype, ndims, value }
NDArrayObject { dtype, ndims, instance: value }
}
/// Create a [`SimpleNDArray`] from the contents of this ndarray.
@ -106,7 +106,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let ndims = self.get_ndims(generator, ctx.ctx);
result.set(ctx, |f| f.ndims, ndims);
let shape = self.value.get(generator, ctx, |f| f.shape, "shape");
let shape = self.instance.get(generator, ctx, |f| f.shape, "shape");
result.set(ctx, |f| f.shape, shape);
// Set data, but we do things differently if this ndarray is contiguous.
@ -114,7 +114,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx.builder.build_conditional_branch(is_contiguous.value, then_bb, else_bb).unwrap();
// Inserting into then_bb; This ndarray is contiguous.
let data = self.value.get(generator, ctx, |f| f.data, "");
let data = self.instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
@ -123,7 +123,7 @@ impl<'ctx> NDArrayObject<'ctx> {
// TODO: Reimplement this? This method does give us the contiguous `data`, but
// this creates a few extra bytes of useless information because an entire NDArray
// is allocated. Though this is super convenient.
let data = self.make_clone(generator, ctx, "").value.get(generator, ctx, |f| f.data, "");
let data = self.make_clone(generator, ctx, "").instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
@ -166,10 +166,10 @@ impl<'ctx> NDArrayObject<'ctx> {
let data = simple_ndarray
.get(generator, ctx, |f| f.data, "")
.pointer_cast(generator, ctx, byte_model, "data");
ndarray.value.set(ctx, |f| f.data, data);
ndarray.instance.set(ctx, |f| f.data, data);
let shape = simple_ndarray.get(generator, ctx, |f| f.shape, "shape");
ndarray.value.set(ctx, |f| f.shape, shape);
ndarray.instance.set(ctx, |f| f.shape, shape);
// Set strides. We know `data` is contiguous.
ndarray.update_strides_by_shape(generator, ctx);
@ -183,7 +183,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_size(generator, ctx, self.value)
call_nac3_ndarray_size(generator, ctx, self.instance)
}
/// Get the `ndarray.nbytes` of this ndarray.
@ -192,7 +192,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_nbytes(generator, ctx, self.value)
call_nac3_ndarray_nbytes(generator, ctx, self.instance)
}
/// Get the `len()` of this ndarray.
@ -201,7 +201,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_len(generator, ctx, self.value)
call_nac3_ndarray_len(generator, ctx, self.instance)
}
/// Check if this ndarray is C-contiguous.
@ -212,7 +212,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
call_nac3_ndarray_is_c_contiguous(generator, ctx, self.value)
call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance)
}
/// Get the pointer to the n-th (0-based) element.
@ -227,7 +227,7 @@ impl<'ctx> NDArrayObject<'ctx> {
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.value, nth);
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, nth);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name)
.unwrap()
@ -242,7 +242,7 @@ impl<'ctx> NDArrayObject<'ctx> {
) -> ScalarObject<'ctx> {
let p = self.get_nth_pointer(generator, ctx, nth, "value");
let value = ctx.builder.build_load(p, "value").unwrap();
ScalarObject { dtype: self.dtype, value }
ScalarObject { dtype: self.dtype, instance: value }
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
@ -253,7 +253,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.value);
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
}
/// Copy data from another ndarray.
@ -269,7 +269,7 @@ impl<'ctx> NDArrayObject<'ctx> {
src: NDArrayObject<'ctx>,
) {
assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match");
call_nac3_ndarray_copy_data(generator, ctx, src.value, self.value);
call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance);
}
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
@ -312,7 +312,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let strides = sizet_model.array_alloca(generator, ctx, ndims_val.value, "alloca_strides");
pndarray.set(ctx, |f| f.strides, strides);
NDArrayObject { dtype, ndims, value: pndarray }
NDArrayObject { dtype, ndims, instance: pndarray }
}
/// Convenience function.
@ -342,7 +342,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let clone =
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, self.ndims, name);
let shape = self.value.gep(ctx, |f| f.shape).load(generator, ctx, "shape");
let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx, "shape");
clone.copy_shape_from_array(generator, ctx, shape);
clone.create_data(generator, ctx);
clone.copy_data_from(generator, ctx, *self);
@ -412,7 +412,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let nbytes = self.nbytes(generator, ctx);
let data = byte_model.array_alloca(generator, ctx, nbytes.value, "data");
self.value.set(ctx, |f| f.data, data);
self.instance.set(ctx, |f| f.data, data);
self.update_strides_by_shape(generator, ctx);
}
@ -424,7 +424,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
src_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let dst_shape = self.value.get(generator, ctx, |f| f.shape, "dst_shape");
let dst_shape = self.instance.get(generator, ctx, |f| f.shape, "dst_shape");
let num_items = self.get_ndims(generator, ctx.ctx).value;
call_memcpy_model(generator, ctx, dst_shape, src_shape, num_items);
}
@ -438,7 +438,7 @@ impl<'ctx> NDArrayObject<'ctx> {
src_ndarray: NDArrayObject<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_shape = src_ndarray.value.get(generator, ctx, |f| f.shape, "src_shape");
let src_shape = src_ndarray.instance.get(generator, ctx, |f| f.shape, "src_shape");
self.copy_shape_from_array(generator, ctx, src_shape);
}
@ -449,7 +449,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
src_strides: Ptr<'ctx, IntModel<SizeT>>,
) {
let dst_strides = self.value.get(generator, ctx, |f| f.strides, "dst_strides");
let dst_strides = self.instance.get(generator, ctx, |f| f.strides, "dst_strides");
let num_items = self.get_ndims(generator, ctx.ctx).value;
call_memcpy_model(generator, ctx, dst_strides, src_strides, num_items);
}
@ -463,7 +463,7 @@ impl<'ctx> NDArrayObject<'ctx> {
src_ndarray: NDArrayObject<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_strides = src_ndarray.value.get(generator, ctx, |f| f.strides, "src_strides");
let src_strides = src_ndarray.instance.get(generator, ctx, |f| f.strides, "src_strides");
self.copy_strides_from_array(generator, ctx, src_strides);
}
@ -518,7 +518,7 @@ impl<'ctx> NDArrayObject<'ctx> {
{
self.foreach_pointer(generator, ctx, |generator, ctx, hooks, i, p| {
let value = ctx.builder.build_load(p, "value").unwrap();
let scalar = ScalarObject { dtype: self.dtype, value };
let scalar = ScalarObject { dtype: self.dtype, instance: value };
body(generator, ctx, hooks, i, scalar)
})
}
@ -606,7 +606,11 @@ impl<'ctx> NDArrayObject<'ctx> {
// Inserting into then_bb: reshape is possible without copying
ctx.builder.position_at_end(then_bb);
dst_ndarray.update_strides_by_shape(generator, ctx);
dst_ndarray.value.set(ctx, |f| f.data, self.value.get(generator, ctx, |f| f.data, "data"));
dst_ndarray.instance.set(
ctx,
|f| f.data,
self.instance.get(generator, ctx, |f| f.data, "data"),
);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Inserting into else_bb: reshape is impossible without copying
@ -672,8 +676,8 @@ impl<'ctx> NDArrayObject<'ctx> {
call_nac3_ndarray_transpose(
generator,
ctx,
self.value,
transposed_ndarray.value,
self.instance,
transposed_ndarray.instance,
num_axes,
axes,
);
@ -694,7 +698,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let sizet_model = IntModel(SizeT);
let ndarray_ndims = self.get_ndims(generator, ctx.ctx);
let ndarray_shape = self.value.get(generator, ctx, |f| f.shape, "shape");
let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape, "shape");
let output_ndims = sizet_model.constant(generator, ctx.ctx, out_ndims);
let output_shape = out_shape;

View File

@ -30,9 +30,9 @@ impl<'ctx> NDArrayObject<'ctx> {
let final_ndims_int = max(a.ndims, b.ndims);
let a_ndims = a.get_ndims(generator, ctx.ctx);
let a_shape = a.value.get(generator, ctx, |f| f.shape, "a_shape");
let a_shape = a.instance.get(generator, ctx, |f| f.shape, "a_shape");
let b_ndims = b.get_ndims(generator, ctx.ctx);
let b_shape = b.value.get(generator, ctx, |f| f.shape, "b_shape");
let b_shape = b.instance.get(generator, ctx, |f| f.shape, "b_shape");
let final_ndims = sizet_model.constant(generator, ctx.ctx, final_ndims_int);
let new_a_shape =
sizet_model.array_alloca(generator, ctx, final_ndims.value, "new_a_shape");
@ -68,9 +68,9 @@ impl<'ctx> NDArrayObject<'ctx> {
call_nac3_ndarray_float64_matmul_at_least_2d(
generator,
ctx,
new_a.value,
new_b.value,
dst.value,
new_a.instance,
new_b.instance,
dst.instance,
);
dst
@ -147,7 +147,7 @@ impl<'ctx> NDArrayObject<'ctx> {
}
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
// TODO: It is possible to check the shapes before computing the matmul to save resources.
let result_shape = result.value.get(generator, ctx, |f| f.shape, "result_shape");
let result_shape = result.instance.get(generator, ctx, |f| f.shape, "result_shape");
out_ndarray.check_can_be_written_by_out(generator, ctx, result.ndims, result_shape);
// TODO: We can just set `out_ndarray.data` to `result.data`. Should we?

View File

@ -15,7 +15,7 @@ use super::NDArrayObject;
#[derive(Debug, Clone, Copy)]
pub struct ScalarObject<'ctx> {
pub dtype: Type,
pub value: BasicValueEnum<'ctx>,
pub instance: BasicValueEnum<'ctx>,
}
impl<'ctx> ScalarObject<'ctx> {
@ -31,13 +31,13 @@ impl<'ctx> ScalarObject<'ctx> {
let pbyte_model = PtrModel(IntModel(Byte));
// We have to put the value on the stack to get a data pointer.
let data = ctx.builder.build_alloca(self.value.get_type(), "as_ndarray_scalar").unwrap();
ctx.builder.build_store(data, self.value).unwrap();
let data = ctx.builder.build_alloca(self.instance.get_type(), "as_ndarray_scalar").unwrap();
ctx.builder.build_store(data, self.instance).unwrap();
let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
let ndarray =
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, 0, "scalar_ndarray");
ndarray.value.set(ctx, |f| f.data, data);
ndarray.instance.set(ctx, |f| f.data, data);
ndarray
}
}
@ -54,8 +54,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.value,
ScalarOrNDArray::NDArray(ndarray) => ndarray.value.value.as_basic_value_enum(),
ScalarOrNDArray::Scalar(scalar) => scalar.instance,
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
}
}
@ -136,7 +136,7 @@ pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
ScalarOrNDArray::NDArray(ndarray)
}
_ => {
let scalar = ScalarObject { dtype: input_ty, value: input };
let scalar = ScalarObject { dtype: input_ty, instance: input };
ScalarOrNDArray::Scalar(scalar)
}
}

View File

@ -39,14 +39,14 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
let input_sequence =
ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty);
let len = input_sequence.value.gep(ctx, |f| f.len).load(generator, ctx, "len");
let len = input_sequence.instance.gep(ctx, |f| f.len).load(generator, ctx, "len");
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
gen_for_model_auto(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| {
// Load the i-th int32 in the input sequence
let int = input_sequence
.value
.instance
.get(generator, ctx, |f| f.items, "int")
.ix(generator, ctx, i.value, "int")
.value

View File

@ -1105,7 +1105,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunBool => scalar.cast_to_bool(ctx),
_ => unreachable!(),
};
Ok(result.value)
Ok(result.instance)
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -1166,7 +1166,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx,
ret_int_dtype,
|generator, ctx, _i, scalar| {
Ok(scalar.round(generator, ctx, ret_int_dtype).value)
Ok(scalar.round(generator, ctx, ret_int_dtype).instance)
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -1231,7 +1231,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx,
int_sized,
|generator, ctx, _i, scalar| {
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).instance)
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -1631,7 +1631,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx.primitives.float,
move |_generator, ctx, _i, scalar| {
let result = scalar.np_floor_or_ceil(ctx, kind);
Ok(result.value)
Ok(result.instance)
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -1659,7 +1659,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx.primitives.float,
|_generator, ctx, _i, scalar| {
let result = scalar.np_round(ctx);
Ok(result.value)
Ok(result.instance)
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -1746,10 +1746,10 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(),
};
let m = ScalarObject { dtype: m_ty, value: m_val };
let n = ScalarObject { dtype: n_ty, value: n_val };
let m = ScalarObject { dtype: m_ty, instance: m_val };
let n = ScalarObject { dtype: n_ty, instance: n_val };
let result = ScalarObject::min_or_max(ctx, kind, m, n);
Ok(Some(result.value))
Ok(Some(result.instance))
},
)))),
loc: None,
@ -1802,10 +1802,10 @@ impl<'a> BuiltinBuilder<'a> {
.value
.as_basic_value_enum(),
PrimDef::FunNpMin => {
a.min_or_max(generator, ctx, MinOrMax::Min).value.as_basic_value_enum()
a.min_or_max(generator, ctx, MinOrMax::Min).instance.as_basic_value_enum()
}
PrimDef::FunNpMax => {
a.min_or_max(generator, ctx, MinOrMax::Max).value.as_basic_value_enum()
a.min_or_max(generator, ctx, MinOrMax::Max).instance.as_basic_value_enum()
}
_ => unreachable!(),
};
@ -1871,7 +1871,7 @@ impl<'a> BuiltinBuilder<'a> {
let x2 = scalars[1];
let result = ScalarObject::min_or_max(ctx, kind, x1, x2);
Ok(result.value)
Ok(result.instance)
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -1912,7 +1912,7 @@ impl<'a> BuiltinBuilder<'a> {
generator,
ctx,
num_ty.ty,
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value),
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).instance),
)?;
Ok(Some(result.to_basic_value_enum()))
},