forked from M-Labs/nac3
1
0
Fork 0

core: Refactor class abstractions

- Introduce new Type abstractions
- Rearrange some functions
This commit is contained in:
David Mak 2024-06-06 12:16:09 +08:00
parent 08129cc635
commit f0ab1b858a
4 changed files with 405 additions and 175 deletions

View File

@ -4,7 +4,7 @@ use inkwell::values::BasicValueEnum;
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy}; use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy};
use crate::codegen::classes::{NDArrayValue, UntypedArrayLikeAccessor}; use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS; use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
@ -93,7 +93,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, "int32", &[n_ty]) _ => unsupported_type(ctx, "int32", &[n_ty])
@ -123,7 +123,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder ctx.builder
.build_int_s_extend(n, llvm_i64, "sext") .build_int_s_extend(n, llvm_i64, "sext")
.map(Into::into) .map(Into::into)
.unwrap() .unwrap()
} else { } else {
ctx.builder ctx.builder
.build_int_z_extend(n, llvm_i64, "zext") .build_int_z_extend(n, llvm_i64, "zext")
@ -164,7 +164,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, "int64", &[n_ty]) _ => unsupported_type(ctx, "int64", &[n_ty])
@ -251,7 +251,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, "uint32", &[n_ty]) _ => unsupported_type(ctx, "uint32", &[n_ty])
@ -332,7 +332,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, "uint64", &[n_ty]) _ => unsupported_type(ctx, "uint64", &[n_ty])
@ -397,7 +397,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, "float", &[n_ty]) _ => unsupported_type(ctx, "float", &[n_ty])
@ -426,7 +426,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder ctx.builder
.build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME) .build_float_to_signed_int(val, llvm_ret_elem_ty, FN_NAME)
.map(Into::into) .map(Into::into)
.unwrap() .unwrap()
} }
BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
@ -443,7 +443,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[n_ty]) _ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -483,7 +483,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[n_ty]) _ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -552,7 +552,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[n_ty]) _ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -602,7 +602,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[n_ty]) _ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -652,7 +652,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[n_ty]) _ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -772,7 +772,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
(n_sz, false), (n_sz, false),
|generator, ctx, idx| { |generator, ctx, idx| {
let elem = unsafe { let elem = unsafe {
n.data().get_unchecked(ctx, generator, &idx, None) n.data().get_unchecked(ctx, generator, &idx, None)
}; };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
@ -870,7 +870,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
}, },
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -1088,7 +1088,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
}, },
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -1153,7 +1153,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[n_ty]) _ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -1195,7 +1195,7 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1237,7 +1237,7 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1277,7 +1277,7 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1317,7 +1317,7 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1357,7 +1357,7 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1397,7 +1397,7 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1437,7 +1437,7 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1477,7 +1477,7 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1517,7 +1517,7 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1557,7 +1557,7 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1597,7 +1597,7 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1637,7 +1637,7 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1677,7 +1677,7 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1717,7 +1717,7 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1739,7 +1739,7 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>(
Ok(match x { Ok(match x {
BasicValueEnum::FloatValue(x) => { BasicValueEnum::FloatValue(x) => {
debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float)); debug_assert!(ctx.unifier.unioned(x_ty, ctx.primitives.float));
extern_fns::call_acos(ctx, x, None).into() extern_fns::call_acos(ctx, x, None).into()
} }
@ -1757,7 +1757,7 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1797,7 +1797,7 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1837,7 +1837,7 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1877,7 +1877,7 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1917,7 +1917,7 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1957,7 +1957,7 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1997,7 +1997,7 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2037,7 +2037,7 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2077,7 +2077,7 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2117,7 +2117,7 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2157,7 +2157,7 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[z_ty]) _ => unsupported_type(ctx, FN_NAME, &[z_ty])
@ -2197,7 +2197,7 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2237,7 +2237,7 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[z_ty]) _ => unsupported_type(ctx, FN_NAME, &[z_ty])
@ -2277,7 +2277,7 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2317,7 +2317,7 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2357,7 +2357,7 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x_ty]) _ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2392,13 +2392,13 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
let dtype = if is_ndarray1 && is_ndarray2 { let dtype = if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
ndarray_dtype1 ndarray_dtype1
} else if is_ndarray1 { } else if is_ndarray1 {
unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0
} else if is_ndarray2 { } else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else { unreachable!() }; } else { unreachable!() };
@ -2424,7 +2424,7 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
}, },
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2491,7 +2491,7 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
}, },
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2558,7 +2558,7 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
}, },
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2625,7 +2625,7 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
}, },
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2681,7 +2681,7 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
}, },
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2748,7 +2748,7 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
}, },
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2815,7 +2815,7 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
}, },
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])

View File

@ -3,6 +3,8 @@ use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}, types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
}; };
use inkwell::types::BasicType;
use inkwell::values::BasicValue;
use crate::codegen::{ use crate::codegen::{
CodeGenContext, CodeGenContext,
CodeGenerator, CodeGenerator,
@ -11,6 +13,40 @@ use crate::codegen::{
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
}; };
/// A LLVM type that is used to represent a non-primitive type in NAC3.
pub trait ProxyType<'ctx> {
/// The underlying type as represented by an LLVM type.
type Base: BasicType<'ctx>;
/// The type of values represented by this type.
type Value: ProxyValue<'ctx>;
/// Creates a [`value`][ProxyValue] with this as its type.
fn create_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value;
/// Returns the base type of this proxy.
fn as_base_type(&self) -> Self::Base;
}
/// A LLVM type that is used to represent a non-primitive value in NAC3.
pub trait ProxyValue<'ctx> {
/// The underlying type as represented by an LLVM value.
type Base: BasicValue<'ctx>;
/// The type of this value.
type Type: ProxyType<'ctx>;
/// Returns the [type][ProxyType] of this value.
fn get_type(&self) -> Self::Type;
/// Returns the base value of this proxy.
fn as_base_value(&self) -> Self::Base;
}
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of /// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
/// elements. /// elements.
pub trait ArrayLikeValue<'ctx> { pub trait ArrayLikeValue<'ctx> {
@ -388,26 +424,20 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {} impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {} impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {}
#[cfg(not(debug_assertions))] /// Proxy type for a `list` type in LLVM.
pub fn assert_is_list<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {} #[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct ListType<'ctx> {
#[cfg(debug_assertions)] ty: PointerType<'ctx>,
pub fn assert_is_list<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) { llvm_usize: IntType<'ctx>,
ListValue::is_instance(value, llvm_usize).unwrap();
} }
/// Proxy type for accessing a `list` value in LLVM. impl<'ctx> ListType<'ctx> {
#[derive(Copy, Clone)] /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not.
pub struct ListValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>); pub fn is_type(
llvm_ty: PointerType<'ctx>,
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> Result<(), String> { ) -> Result<(), String> {
let llvm_list_ty = value.get_type().get_element_type(); let llvm_list_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else {
return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")) return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}"))
}; };
@ -433,28 +463,97 @@ impl<'ctx> ListValue<'ctx> {
Ok(()) Ok(())
} }
/// Creates an [`ListType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok());
ListType { ty: ptr_ty, llvm_usize }
}
/// Returns the type of the `size` field of this `list` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `list` type.
#[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(1)
.unwrap()
}
}
impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
type Base = PointerType<'ctx>;
type Value = ListValue<'ctx>;
fn create_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
ListValue { value, llvm_usize: self.llvm_usize, name }
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<ListType<'ctx>> for PointerType<'ctx> {
fn from(value: ListType<'ctx>) -> Self {
value.as_base_type()
}
}
/// Proxy type for accessing a `list` value in LLVM.
#[derive(Copy, Clone)]
pub struct ListValue<'ctx> {
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
ListType::is_type(value.get_type(), llvm_usize)
}
/// Creates an [`ListValue`] from a [`PointerValue`]. /// Creates an [`ListValue`] from a [`PointerValue`].
#[must_use] #[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self { pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
assert_is_list(ptr, llvm_usize); debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
ListValue(ptr, name)
}
/// Returns the underlying [`PointerValue`] pointing to the `list` instance. <Self as ProxyValue<'ctx>>::Type::from_type(ptr.get_type(), llvm_usize)
#[must_use] .create_value(ptr, name)
pub fn as_ptr_value(&self) -> PointerValue<'ctx> {
self.0
} }
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
self.as_ptr_value(), self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()], &[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(), var_name.as_str(),
).unwrap() ).unwrap()
@ -464,11 +563,11 @@ impl<'ctx> ListValue<'ctx> {
/// Returns the pointer to the field storing the size of this `list`. /// Returns the pointer to the field storing the size of this `list`.
fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.size.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default();
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
self.0, self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(), var_name.as_str(),
).unwrap() ).unwrap()
@ -519,7 +618,7 @@ impl<'ctx> ListValue<'ctx> {
let psize = self.ptr_to_size(ctx); let psize = self.ptr_to_size(ctx);
let var_name = name let var_name = name
.map(ToString::to_string) .map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.size"))) .or_else(|| self.name.map(|v| format!("{v}.size")))
.unwrap_or_default(); .unwrap_or_default();
ctx.builder.build_load(psize, var_name.as_str()) ctx.builder.build_load(psize, var_name.as_str())
@ -528,9 +627,22 @@ impl<'ctx> ListValue<'ctx> {
} }
} }
impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = ListType<'ctx>;
fn get_type(&self) -> Self::Type {
ListType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<ListValue<'ctx>> for PointerValue<'ctx> { impl<'ctx> From<ListValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ListValue<'ctx>) -> Self { fn from(value: ListValue<'ctx>) -> Self {
value.as_ptr_value() value.as_base_value()
} }
} }
@ -544,7 +656,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
_: &CodeGenContext<'ctx, '_>, _: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> AnyTypeEnum<'ctx> { ) -> AnyTypeEnum<'ctx> {
self.0.0.get_type().get_element_type() self.0.value.get_type().get_element_type()
} }
fn base_ptr<G: CodeGenerator + ?Sized>( fn base_ptr<G: CodeGenerator + ?Sized>(
@ -552,7 +664,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default(); let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.pptr_to_data(ctx), var_name.as_str()) ctx.builder.build_load(self.0.pptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
@ -616,22 +728,16 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {} impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {}
#[cfg(not(debug_assertions))] /// Proxy type for a `range` type in LLVM.
pub fn assert_is_range(_value: PointerValue) {} #[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct RangeType<'ctx> {
#[cfg(debug_assertions)] ty: PointerType<'ctx>,
pub fn assert_is_range(value: PointerValue) {
RangeValue::is_instance(value).unwrap();
} }
/// Proxy type for accessing a `range` value in LLVM. impl<'ctx> RangeType<'ctx> {
#[derive(Copy, Clone)] /// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not.
pub struct RangeValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>); pub fn is_type(llvm_ty: PointerType<'ctx>) -> Result<(), String> {
let llvm_range_ty = llvm_ty.get_element_type();
impl<'ctx> RangeValue<'ctx> {
/// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance.
pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> {
let llvm_range_ty = value.get_type().get_element_type();
let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else {
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")) return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"))
}; };
@ -651,37 +757,74 @@ impl<'ctx> RangeValue<'ctx> {
Ok(()) Ok(())
} }
/// Creates an [`RangeValue`] from a [`PointerValue`]. /// Creates an [`RangeType`] from a [`PointerType`].
#[must_use] #[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self {
assert_is_range(ptr); debug_assert!(Self::is_type(ptr_ty).is_ok());
RangeValue(ptr, name)
RangeType { ty: ptr_ty }
} }
/// Returns the element type of this `range` object. /// Returns the type of all fields of this `range` type.
#[must_use] #[must_use]
pub fn element_type(&self) -> IntType<'ctx> { pub fn value_type(&self) -> IntType<'ctx> {
self.as_ptr_value() self.as_base_type()
.get_type()
.get_element_type() .get_element_type()
.into_array_type() .into_array_type()
.get_element_type() .get_element_type()
.into_int_type() .into_int_type()
} }
}
/// Returns the underlying [`PointerValue`] pointing to the `range` instance. impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
type Base = PointerType<'ctx>;
type Value = RangeValue<'ctx>;
fn create_value(&self, value: <Self::Value as ProxyValue<'ctx>>::Base, name: Option<&'ctx str>) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
RangeValue { value, name }
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<RangeType<'ctx>> for PointerType<'ctx> {
fn from(value: RangeType<'ctx>) -> Self {
value.as_base_type()
}
}
/// Proxy type for accessing a `range` value in LLVM.
#[derive(Copy, Clone)]
pub struct RangeValue<'ctx> {
value: PointerValue<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> RangeValue<'ctx> {
/// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance.
pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> {
RangeType::is_type(value.get_type())
}
/// Creates an [`RangeValue`] from a [`PointerValue`].
#[must_use] #[must_use]
pub fn as_ptr_value(&self) -> PointerValue<'ctx> { pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
self.0 debug_assert!(Self::is_instance(ptr).is_ok());
<Self as ProxyValue<'ctx>>::Type::from_type(ptr.get_type()).create_value(ptr, name)
} }
fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.start.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default();
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
self.0, self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)],
var_name.as_str(), var_name.as_str(),
).unwrap() ).unwrap()
@ -690,11 +833,11 @@ impl<'ctx> RangeValue<'ctx> {
fn ptr_to_end(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_end(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.end.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default();
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
self.0, self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
var_name.as_str(), var_name.as_str(),
).unwrap() ).unwrap()
@ -703,11 +846,11 @@ impl<'ctx> RangeValue<'ctx> {
fn ptr_to_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.step.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default();
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
self.0, self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)],
var_name.as_str(), var_name.as_str(),
).unwrap() ).unwrap()
@ -731,7 +874,7 @@ impl<'ctx> RangeValue<'ctx> {
let pstart = self.ptr_to_start(ctx); let pstart = self.ptr_to_start(ctx);
let var_name = name let var_name = name
.map(ToString::to_string) .map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.start"))) .or_else(|| self.name.map(|v| format!("{v}.start")))
.unwrap_or_default(); .unwrap_or_default();
ctx.builder.build_load(pstart, var_name.as_str()) ctx.builder.build_load(pstart, var_name.as_str())
@ -756,7 +899,7 @@ impl<'ctx> RangeValue<'ctx> {
let pend = self.ptr_to_end(ctx); let pend = self.ptr_to_end(ctx);
let var_name = name let var_name = name
.map(ToString::to_string) .map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.end"))) .or_else(|| self.name.map(|v| format!("{v}.end")))
.unwrap_or_default(); .unwrap_or_default();
ctx.builder.build_load(pend, var_name.as_str()) ctx.builder.build_load(pend, var_name.as_str())
@ -781,7 +924,7 @@ impl<'ctx> RangeValue<'ctx> {
let pstep = self.ptr_to_step(ctx); let pstep = self.ptr_to_step(ctx);
let var_name = name let var_name = name
.map(ToString::to_string) .map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.step"))) .or_else(|| self.name.map(|v| format!("{v}.step")))
.unwrap_or_default(); .unwrap_or_default();
ctx.builder.build_load(pstep, var_name.as_str()) ctx.builder.build_load(pstep, var_name.as_str())
@ -790,32 +933,39 @@ impl<'ctx> RangeValue<'ctx> {
} }
} }
impl<'ctx> From<RangeValue<'ctx>> for PointerValue<'ctx> { impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> {
fn from(value: RangeValue<'ctx>) -> Self { type Base = PointerValue<'ctx>;
value.as_ptr_value() type Type = RangeType<'ctx>;
fn get_type(&self) -> Self::Type {
RangeType::from_type(self.value.get_type())
}
fn as_base_value(&self) -> Self::Base {
self.value
} }
} }
#[cfg(not(debug_assertions))] impl<'ctx> From<RangeValue<'ctx>> for PointerValue<'ctx> {
pub fn assert_is_ndarray<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {} fn from(value: RangeValue<'ctx>) -> Self {
value.as_base_value()
#[cfg(debug_assertions)] }
pub fn assert_is_ndarray<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) {
NDArrayValue::is_instance(value, llvm_usize).unwrap();
} }
/// Proxy type for accessing an `NDArray` value in LLVM. /// Proxy type for a `ndarray` type in LLVM.
#[derive(Copy, Clone)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDArrayValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>); pub struct NDArrayType<'ctx> {
ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
}
impl<'ctx> NDArrayValue<'ctx> { impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
/// instance. pub fn is_type(
pub fn is_instance( llvm_ty: PointerType<'ctx>,
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> Result<(), String> { ) -> Result<(), String> {
let llvm_ndarray_ty = value.get_type().get_element_type(); let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")) return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"))
}; };
@ -855,31 +1005,96 @@ impl<'ctx> NDArrayValue<'ctx> {
Ok(()) Ok(())
} }
/// Creates an [`NDArrayValue`] from a [`PointerValue`]. /// Creates an [`NDArrayType`] from a [`PointerType`].
#[must_use] #[must_use]
pub fn from_ptr_val( pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
ptr: PointerValue<'ctx>, debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok());
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>, NDArrayType { ty: ptr_ty, llvm_usize }
) -> Self {
assert_is_ndarray(ptr, llvm_usize);
NDArrayValue(ptr, name)
} }
/// Returns the underlying [`PointerValue`] pointing to the `NDArray` instance. /// Returns the type of the `size` field of this `ndarray` type.
#[must_use] #[must_use]
pub fn as_ptr_value(&self) -> PointerValue<'ctx> { pub fn size_type(&self) -> IntType<'ctx> {
self.0 self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `ndarray` type.
#[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(2)
.unwrap()
}
}
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
type Base = PointerType<'ctx>;
type Value = NDArrayValue<'ctx>;
fn create_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
NDArrayValue { value, llvm_usize: self.llvm_usize, name }
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
fn from(value: NDArrayType<'ctx>) -> Self {
value.as_base_type()
}
}
/// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayValue<'ctx> {
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> NDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
NDArrayType::is_type(value.get_type(), llvm_usize)
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
<Self as ProxyValue<'ctx>>::Type::from_type(ptr.get_type(), llvm_usize)
.create_value(ptr, name)
} }
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`. /// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
self.0, self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()], &[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(), var_name.as_str(),
).unwrap() ).unwrap()
@ -911,11 +1126,11 @@ impl<'ctx> NDArrayValue<'ctx> {
/// on the field. /// on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
self.as_ptr_value(), self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(), var_name.as_str(),
).unwrap() ).unwrap()
@ -947,11 +1162,11 @@ impl<'ctx> NDArrayValue<'ctx> {
/// on the field. /// on the field.
fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
self.as_ptr_value(), self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
var_name.as_str(), var_name.as_str(),
).unwrap() ).unwrap()
@ -981,9 +1196,22 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
} }
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = NDArrayType<'ctx>;
fn get_type(&self) -> Self::Type {
NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> { impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
fn from(value: NDArrayValue<'ctx>) -> Self { fn from(value: NDArrayValue<'ctx>) -> Self {
value.as_ptr_value() value.as_base_value()
} }
} }
@ -1005,7 +1233,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default(); let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) ctx.builder.build_load(self.0.ptr_to_dims(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
@ -1110,7 +1338,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, _: &G,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default(); let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.ptr_to_data(ctx), var_name.as_str()) ctx.builder.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)

View File

@ -8,6 +8,7 @@ use crate::{
ArraySliceValue, ArraySliceValue,
ListValue, ListValue,
NDArrayValue, NDArrayValue,
ProxyValue,
RangeValue, RangeValue,
TypedArrayLikeAccessor, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeAccessor,
@ -1090,7 +1091,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
emit_cont_bb(ctx, generator, list); emit_cont_bb(ctx, generator, list);
Ok(Some(list.as_ptr_value().into())) Ok(Some(list.as_base_value().into()))
} }
/// Generates LLVM IR for a binary operator expression using the [`Type`] and /// Generates LLVM IR for a binary operator expression using the [`Type`] and
@ -1173,8 +1174,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
ctx, ctx,
ndarray_dtype1, ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None }, if is_aug_assign { Some(left_val) } else { None },
(left_val.as_ptr_value().into(), false), (left_val.as_base_value().into(), false),
(right_val.as_ptr_value().into(), false), (right_val.as_base_value().into(), false),
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
gen_binop_expr_with_values( gen_binop_expr_with_values(
generator, generator,
@ -1189,7 +1190,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
)? )?
}; };
Ok(Some(res.as_ptr_value().into())) Ok(Some(res.as_base_value().into()))
} else { } else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys( let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier, &mut ctx.unifier,
@ -1220,7 +1221,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
}, },
)?; )?;
Ok(Some(res.as_ptr_value().into())) Ok(Some(res.as_base_value().into()))
} }
} else { } else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
@ -1410,7 +1411,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
}, },
)?; )?;
res.as_ptr_value().into() res.as_base_value().into()
} else { } else {
unimplemented!() unimplemented!()
})) }))
@ -1478,7 +1479,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
ctx, ctx,
ctx.primitives.bool, ctx.primitives.bool,
None, None,
(left_val.as_ptr_value().into(), false), (left_val.as_base_value().into(), false),
(rhs, false), (rhs, false),
|generator, ctx, (lhs, rhs)| { |generator, ctx, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values( let val = gen_cmpop_expr_with_values(
@ -1493,7 +1494,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
}, },
)?; )?;
Ok(Some(res.as_ptr_value().into())) Ok(Some(res.as_base_value().into()))
} else { } else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys( let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier, &mut ctx.unifier,
@ -1519,7 +1520,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
}, },
)?; )?;
Ok(Some(res.as_ptr_value().into())) Ok(Some(res.as_base_value().into()))
} }
} }
} }
@ -1819,7 +1820,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ty, ty,
v, v,
&slices, &slices,
)?.as_ptr_value().into() )?.as_base_value().into()
} }
ExprKind::Slice { .. } => { ExprKind::Slice { .. } => {
@ -1833,7 +1834,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ty, ty,
v, v,
&[slice], &[slice],
)?.as_ptr_value().into() )?.as_base_value().into()
} }
_ => { _ => {
@ -1935,7 +1936,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
llvm_i1.const_zero(), llvm_i1.const_zero(),
); );
ndarray.as_ptr_value().into() ndarray.as_base_value().into()
} }
})) }))
} }
@ -2025,7 +2026,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.ptr_offset(ctx, generator, &usize.const_int(i as u64, false), Some("elem_ptr")); .ptr_offset(ctx, generator, &usize.const_int(i as u64, false), Some("elem_ptr"));
ctx.builder.build_store(elem_ptr, *v).unwrap(); ctx.builder.build_store(elem_ptr, *v).unwrap();
} }
arr_str_ptr.as_ptr_value().into() arr_str_ptr.as_base_value().into()
} }
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
let elements_val = elts let elements_val = elts
@ -2406,7 +2407,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
v, v,
(start, end, step), (start, end, step),
); );
res_array_ret.as_ptr_value().into() res_array_ret.as_base_value().into()
} else { } else {
let len = v.load_size(ctx, Some("len")); let len = v.load_size(ctx, Some("len"));
let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? {

View File

@ -7,6 +7,7 @@ use crate::{
ArrayLikeValue, ArrayLikeValue,
ListValue, ListValue,
NDArrayValue, NDArrayValue,
ProxyValue,
TypedArrayLikeAccessor, TypedArrayLikeAccessor,
TypedArrayLikeAdapter, TypedArrayLikeAdapter,
TypedArrayLikeMutator, TypedArrayLikeMutator,
@ -1172,7 +1173,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
); );
} }
let lhs = if res.is_some_and(|res| res.as_ptr_value() == lhs.as_ptr_value()) { let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
ndarray_copy_impl(generator, ctx, elem_ty, lhs)? ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
} else { } else {
lhs lhs