1
0
forked from M-Labs/nac3

[core] codegen: Refactor ProxyType and ProxyValue

Accepts generator+context object for generic type checking. Also
implements more default trait impl for easier delegation.
This commit is contained in:
David Mak 2024-11-01 15:17:00 +08:00
parent 9d9ead211e
commit a4f53b6e6b
15 changed files with 235 additions and 109 deletions

View File

@ -461,7 +461,8 @@ fn format_rpc_arg<'ctx>(
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let llvm_arg =
NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None);
let llvm_usize_sizeof = ctx
.builder
@ -1315,7 +1316,8 @@ fn polymorphic_print<'ctx>(
fmt.push('[');
flush(ctx, generator, &mut fmt, &mut args);
let val = ListValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None);
let val =
ListValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None);
let len = val.load_size(ctx, None);
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
@ -1371,7 +1373,8 @@ fn polymorphic_print<'ctx>(
fmt.push_str("array([");
flush(ctx, generator, &mut fmt, &mut args);
let val = NDArrayValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None);
let val =
NDArrayValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None);
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None));
let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
@ -1425,7 +1428,7 @@ fn polymorphic_print<'ctx>(
fmt.push_str("range(");
flush(ctx, generator, &mut fmt, &mut args);
let val = RangeValue::from_ptr_val(value.into_pointer_value(), None);
let val = RangeValue::from_pointer_value(value.into_pointer_value(), None);
let (start, stop, step) = destructure_range(ctx, val);

View File

@ -47,7 +47,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
let (arg_ty, arg) = n;
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
let arg = RangeValue::from_pointer_value(arg.into_pointer_value(), Some("range"));
let (start, end, step) = destructure_range(ctx, arg);
calculate_len_for_slice_range(generator, ctx, start, end, step)
} else {
@ -67,7 +67,8 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_usize = generator.get_size_type(ctx.ctx);
let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let arg =
NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
@ -148,7 +149,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ctx.primitives.int32,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)),
)?;
@ -210,7 +211,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ctx.primitives.int64,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)),
)?;
@ -288,7 +289,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ctx.primitives.uint32,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)),
)?;
@ -355,7 +356,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ctx.primitives.uint64,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)),
)?;
@ -421,7 +422,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ctx.primitives.float,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| call_float(generator, ctx, (elem_ty, val)),
)?;
@ -467,7 +468,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ret_elem_ty,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty),
)?;
@ -507,7 +508,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ctx.primitives.float,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)),
)?;
@ -572,7 +573,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ctx.primitives.bool,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| {
let elem = call_bool(generator, ctx, (elem_ty, val))?;
@ -626,7 +627,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ret_elem_ty,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
)?;
@ -676,7 +677,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
ctx,
ret_elem_ty,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
NDArrayValue::from_pointer_value(n, llvm_usize, None),
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
)?;
@ -907,7 +908,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n = NDArrayValue::from_pointer_value(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
@ -1120,7 +1121,7 @@ where
ctx,
ret_elem_ty,
None,
NDArrayValue::from_ptr_val(x, llvm_usize, None),
NDArrayValue::from_pointer_value(x, llvm_usize, None),
|generator, ctx, elem_val| {
helper_call_numpy_unary_elementwise(
generator,
@ -1959,7 +1960,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
@ -2001,7 +2002,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
@ -2051,7 +2052,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
@ -2106,7 +2107,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
@ -2148,7 +2149,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
@ -2191,7 +2192,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
@ -2244,7 +2245,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
// Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape(
generator,
@ -2339,7 +2340,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
@ -2382,7 +2383,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()

View File

@ -1168,7 +1168,8 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
let iter_val =
RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range"));
let (start, stop, step) = destructure_range(ctx, iter_val);
let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap();
// add 1 to the length as the value is rounded to zero
@ -1399,8 +1400,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1);
let sizeof_elem = llvm_elem_ty.size_of().unwrap();
let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
let lhs =
ListValue::from_pointer_value(left_val.into_pointer_value(), llvm_usize, None);
let rhs =
ListValue::from_pointer_value(right_val.into_pointer_value(), llvm_usize, None);
let size = ctx
.builder
@ -1483,7 +1486,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
codegen_unreachable!(ctx)
};
let list_val =
ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
ListValue::from_pointer_value(list_val.into_pointer_value(), llvm_usize, None);
let int_val = ctx
.builder
.build_int_s_extend(int_val.into_int_value(), llvm_usize, "")
@ -1562,9 +1565,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
let left_val =
NDArrayValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
NDArrayValue::from_pointer_value(left_val.into_pointer_value(), llvm_usize, None);
let right_val =
NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
NDArrayValue::from_pointer_value(right_val.into_pointer_value(), llvm_usize, None);
let res = if op.base == Operator::MatMult {
// MatMult is the only binop which is not an elementwise op
@ -1613,7 +1616,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else {
let (ndarray_dtype, _) =
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
let ndarray_val = NDArrayValue::from_ptr_val(
let ndarray_val = NDArrayValue::from_pointer_value(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_usize,
None,
@ -1808,7 +1811,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
let val = NDArrayValue::from_pointer_value(val.into_pointer_value(), llvm_usize, None);
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
// passing it to the elementwise codegen function
@ -1900,7 +1903,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
let left_val =
NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None);
NDArrayValue::from_pointer_value(lhs.into_pointer_value(), llvm_usize, None);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
@ -2202,9 +2205,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
}
let left_val =
ListValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None);
ListValue::from_pointer_value(lhs.into_pointer_value(), llvm_usize, None);
let right_val =
ListValue::from_ptr_val(rhs.into_pointer_value(), llvm_usize, None);
ListValue::from_pointer_value(rhs.into_pointer_value(), llvm_usize, None);
Ok(gen_if_else_expr_callback(
generator,
@ -2768,7 +2771,8 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
// elements over
let subscripted_ndarray =
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
let ndarray =
NDArrayValue::from_pointer_value(subscripted_ndarray, llvm_usize, None);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
@ -3403,7 +3407,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else {
return Ok(None);
};
let v = ListValue::from_ptr_val(v, usize, Some("arr"));
let v = ListValue::from_pointer_value(v, usize, Some("arr"));
let ty = ctx.get_llvm_type(generator, *ty);
if let ExprKind::Slice { lower, upper, step } = &slice.node {
let one = int32.const_int(1, false);
@ -3513,7 +3517,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else {
return Ok(None);
};
let v = NDArrayValue::from_ptr_val(v, usize, None);
let v = NDArrayValue::from_pointer_value(v, usize, None);
return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice);
}

View File

@ -54,7 +54,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None))
Ok(NDArrayValue::from_pointer_value(ndarray, llvm_usize, None))
}
/// Creates an `NDArray` instance from a dynamic shape.
@ -314,11 +314,11 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
create_ndarray_dyn_shape(
generator,
ctx,
@ -499,12 +499,14 @@ where
// Assert that all ndarray operands are broadcastable to the target size
if !lhs_scalar {
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
let lhs_val =
NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None);
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
}
if !rhs_scalar {
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
let rhs_val =
NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None);
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
}
@ -512,7 +514,8 @@ where
let lhs_elem = if lhs_scalar {
lhs_val
} else {
let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
let lhs =
NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None);
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) }
@ -521,7 +524,8 @@ where
let rhs_elem = if rhs_scalar {
rhs_val
} else {
let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
let rhs =
NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None);
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) }
@ -647,11 +651,15 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
let ndims = llvm_usize.const_int(1, false);
match list_elem_ty {
AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => {
AnyTypeEnum::PointerType(ptr_ty)
if ListType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty))
}
AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => {
AnyTypeEnum::PointerType(ptr_ty)
if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
todo!("Getting ndims for list[ndarray] not supported")
}
@ -668,11 +676,13 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx);
match value {
BasicValueEnum::PointerValue(v) if NDArrayValue::is_instance(v, llvm_usize).is_ok() => {
NDArrayValue::from_ptr_val(v, llvm_usize, None).load_ndims(ctx)
BasicValueEnum::PointerValue(v)
if NDArrayValue::is_representable(v, llvm_usize).is_ok() =>
{
NDArrayValue::from_pointer_value(v, llvm_usize, None).load_ndims(ctx)
}
BasicValueEnum::PointerValue(v) if ListValue::is_instance(v, llvm_usize).is_ok() => {
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
llvm_ndlist_get_ndims(generator, ctx, v.get_type())
}
@ -695,7 +705,9 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
let list_elem_ty = src_lst.get_type().element_type();
match list_elem_ty {
AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => {
AnyTypeEnum::PointerType(ptr_ty)
if ListType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
// The stride of elements in this dimension, i.e. the number of elements between arr[i]
// and arr[i + 1] in this dimension
let stride = call_ndarray_calc_size(
@ -719,7 +731,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
let dst_ptr =
unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() };
let nested_lst_elem = ListValue::from_ptr_val(
let nested_lst_elem = ListValue::from_pointer_value(
unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) }
.into_pointer_value(),
llvm_usize,
@ -740,7 +752,9 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
)?;
}
AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => {
AnyTypeEnum::PointerType(ptr_ty)
if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
todo!("Not implemented for list[ndarray]")
}
@ -801,8 +815,8 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
let object = object.into_pointer_value();
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
if NDArrayValue::is_instance(object, llvm_usize).is_ok() {
let object = NDArrayValue::from_ptr_val(object, llvm_usize, None);
if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
let object = NDArrayValue::from_pointer_value(object, llvm_usize, None);
let ndarray = gen_if_else_expr_callback(
generator,
@ -876,7 +890,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|_, _| Ok(Some(object.as_base_value())),
)?;
return Ok(NDArrayValue::from_ptr_val(
return Ok(NDArrayValue::from_pointer_value(
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
llvm_usize,
None,
@ -884,8 +898,8 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
}
// Remaining case: TList
assert!(ListValue::is_instance(object, llvm_usize).is_ok());
let object = ListValue::from_ptr_val(object, llvm_usize, None);
assert!(ListValue::is_representable(object, llvm_usize).is_ok());
let object = ListValue::from_pointer_value(object, llvm_usize, None);
// The number of dimensions to prepend 1's to
let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
@ -965,7 +979,8 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
.map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap())
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let this_dim = ListValue::from_ptr_val(this_dim, llvm_usize, None);
let this_dim =
ListValue::from_pointer_value(this_dim, llvm_usize, None);
// TODO: Assert this_dim.sz != 0
@ -991,7 +1006,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
let lst = ListValue::from_ptr_val(
let lst = ListValue::from_pointer_value(
ctx.builder
.build_load(lst, "")
.map(BasicValueEnum::into_pointer_value)
@ -1388,9 +1403,9 @@ where
let ndarray = res.unwrap_or_else(|| {
if lhs_scalar && rhs_scalar {
let lhs_val =
NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None);
let rhs_val =
NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None);
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
@ -1406,7 +1421,7 @@ where
)
.unwrap()
} else {
let ndarray = NDArrayValue::from_ptr_val(
let ndarray = NDArrayValue::from_pointer_value(
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
llvm_usize,
None,
@ -1970,7 +1985,7 @@ pub fn gen_ndarray_copy<'ctx>(
generator,
context,
this_elem_ty,
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None),
NDArrayValue::from_pointer_value(this_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
}
@ -2002,7 +2017,7 @@ pub fn gen_ndarray_fill<'ctx>(
ndarray_fill_flattened(
generator,
context,
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None),
NDArrayValue::from_pointer_value(this_arg, llvm_usize, None),
|generator, ctx, _| {
let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type();
@ -2043,7 +2058,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
// Dimensions are reversed in the transposed array
@ -2162,7 +2177,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
@ -2172,11 +2187,11 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions
gen_for_callback_incrementing(
generator,
@ -2445,8 +2460,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None);
let n2 = NDArrayValue::from_pointer_value(n2, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));

View File

@ -310,7 +310,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
.unwrap()
.to_basic_value_enum(ctx, generator, target_ty)?
.into_pointer_value();
let target = ListValue::from_ptr_val(target, llvm_usize, None);
let target = ListValue::from_pointer_value(target, llvm_usize, None);
if let ExprKind::Slice { .. } = &key.node {
// Handle assigning to a slice
@ -331,7 +331,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
let value =
value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value();
let value = ListValue::from_ptr_val(value, llvm_usize, None);
let value = ListValue::from_pointer_value(value, llvm_usize, None);
let target_item_ty = ctx.get_llvm_type(generator, target_item_ty);
let Some(src_ind) = handle_slice_indices(
@ -463,7 +463,8 @@ pub fn gen_for<G: CodeGenerator>(
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
let iter_val =
RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range"));
// Internal variable for loop; Cannot be assigned
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed

View File

@ -452,7 +452,7 @@ fn test_classes_list_type_new() {
let llvm_usize = generator.get_size_type(&ctx);
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
assert!(ListType::is_type(llvm_list.as_base_type(), llvm_usize).is_ok());
assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok());
}
#[test]
@ -460,7 +460,7 @@ fn test_classes_range_type_new() {
let ctx = inkwell::context::Context::create();
let llvm_range = RangeType::new(&ctx);
assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok());
assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok());
}
#[test]
@ -472,5 +472,5 @@ fn test_classes_ndarray_type_new() {
let llvm_usize = generator.get_size_type(&ctx);
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
}

View File

@ -20,7 +20,10 @@ pub struct ListType<'ctx> {
impl<'ctx> ListType<'ctx> {
/// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not.
pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_list_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else {
return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}"));
@ -73,7 +76,7 @@ impl<'ctx> ListType<'ctx> {
/// Creates an [`ListType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok());
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
ListType { ty: ptr_ty, llvm_usize }
}
@ -106,6 +109,26 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
type Base = PointerType<'ctx>;
type Value = ListValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -146,9 +169,7 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
ListValue::from_ptr_val(value, self.llvm_usize, name)
Self::Value::from_pointer_value(value, self.llvm_usize, name)
}
fn as_base_type(&self) -> Self::Base {

View File

@ -1,4 +1,4 @@
use inkwell::{types::BasicType, values::IntValue};
use inkwell::{context::Context, types::BasicType, values::IntValue};
use super::{
values::{ArraySliceValue, ProxyValue},
@ -19,7 +19,20 @@ pub trait ProxyType<'ctx>: Into<Self::Base> {
type Base: BasicType<'ctx>;
/// The type of values represented by this type.
type Value: ProxyValue<'ctx>;
type Value: ProxyValue<'ctx, Type = Self>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String>;
/// Checks whether `llvm_ty` can be represented by this [`ProxyType`].
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String>;
/// Creates a new value of this type.
fn new_value<G: CodeGenerator + ?Sized>(

View File

@ -20,7 +20,10 @@ pub struct NDArrayType<'ctx> {
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
@ -101,7 +104,7 @@ impl<'ctx> NDArrayType<'ctx> {
/// Creates an [`NDArrayType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok());
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
NDArrayType { ty: ptr_ty, llvm_usize }
}
@ -134,6 +137,26 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
type Base = PointerType<'ctx>;
type Value = NDArrayValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -176,7 +199,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
NDArrayValue::from_ptr_val(value, self.llvm_usize, name)
NDArrayValue::from_pointer_value(value, self.llvm_usize, name)
}
fn as_base_type(&self) -> Self::Base {

View File

@ -1,6 +1,6 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, IntType, PointerType},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
AddressSpace,
};
@ -19,7 +19,7 @@ pub struct RangeType<'ctx> {
impl<'ctx> RangeType<'ctx> {
/// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not.
pub fn is_type(llvm_ty: PointerType<'ctx>) -> Result<(), String> {
pub fn is_representable(llvm_ty: PointerType<'ctx>) -> Result<(), String> {
let llvm_range_ty = llvm_ty.get_element_type();
let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else {
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"));
@ -59,7 +59,7 @@ impl<'ctx> RangeType<'ctx> {
/// Creates an [`RangeType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self {
debug_assert!(Self::is_type(ptr_ty).is_ok());
debug_assert!(Self::is_representable(ptr_ty).is_ok());
RangeType { ty: ptr_ty }
}
@ -75,6 +75,26 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
type Base = PointerType<'ctx>;
type Value = RangeValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
_: &G,
_: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty)
}
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -117,7 +137,7 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
RangeValue::from_ptr_val(value, name)
RangeValue::from_pointer_value(value, name)
}
fn as_base_type(&self) -> Self::Base {

View File

@ -23,18 +23,21 @@ pub struct ListValue<'ctx> {
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
ListType::is_type(value.get_type(), llvm_usize)
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
ListType::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`ListValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
ListValue { value: ptr, llvm_usize, name }
}

View File

@ -1,6 +1,7 @@
use inkwell::values::BasicValue;
use inkwell::{context::Context, values::BasicValue};
use super::types::ProxyType;
use crate::codegen::CodeGenerator;
pub use array::*;
pub use list::*;
pub use ndarray::*;
@ -18,7 +19,25 @@ pub trait ProxyValue<'ctx>: Into<Self::Base> {
type Base: BasicValue<'ctx>;
/// The type of this value.
type Type: ProxyType<'ctx>;
type Type: ProxyType<'ctx, Value = Self>;
/// Checks whether `value` can be represented by this [`ProxyValue`].
fn is_instance<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
value: impl BasicValue<'ctx>,
) -> Result<(), String> {
Self::Type::is_type(generator, ctx, value.as_basic_value_enum().get_type())
}
/// Checks whether `value` can be represented by this [`ProxyValue`].
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
value: Self::Base,
) -> Result<(), String> {
Self::is_instance(generator, ctx, value.as_basic_value_enum())
}
/// Returns the [type][ProxyType] of this value.
fn get_type(&self) -> Self::Type;

View File

@ -27,18 +27,21 @@ pub struct NDArrayValue<'ctx> {
impl<'ctx> NDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
NDArrayType::is_type(value.get_type(), llvm_usize)
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
NDArrayType::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
NDArrayValue { value: ptr, llvm_usize, name }
}

View File

@ -12,14 +12,14 @@ pub struct RangeValue<'ctx> {
impl<'ctx> RangeValue<'ctx> {
/// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance.
pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> {
RangeType::is_type(value.get_type())
pub fn is_representable(value: PointerValue<'ctx>) -> Result<(), String> {
RangeType::is_representable(value.get_type())
}
/// Creates an [`RangeValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
debug_assert!(Self::is_instance(ptr).is_ok());
pub fn from_pointer_value(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
debug_assert!(Self::is_representable(ptr).is_ok());
RangeValue { value: ptr, name }
}

View File

@ -710,7 +710,7 @@ impl<'a> BuiltinBuilder<'a> {
let (zelf_ty, zelf) = obj.unwrap();
let zelf =
zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value();
let zelf = RangeValue::from_ptr_val(zelf, Some("range"));
let zelf = RangeValue::from_pointer_value(zelf, Some("range"));
let mut start = None;
let mut stop = None;