[core] Refactor/Remove redundant and unused constructs
- Use ProxyValue.name where necessary - Remove NDArrayValue::ptr_to_{shape,strides} - Remove functions made obsolete by ndstrides - Remove use statement for ndarray::views as it only contain an impl block. - Remove class_names field in Resolvers of test sources
This commit is contained in:
parent
08b717d640
commit
f62babbace
@ -34,19 +34,14 @@ use super::{
|
||||
},
|
||||
types::{ndarray::NDArrayType, ListType},
|
||||
values::{
|
||||
ndarray::{NDArrayValue, RustNDIndex},
|
||||
ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
|
||||
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
|
||||
ndarray::RustNDIndex, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
|
||||
UntypedArrayLikeAccessor,
|
||||
},
|
||||
CodeGenContext, CodeGenTask, CodeGenerator,
|
||||
};
|
||||
use crate::{
|
||||
symbol_resolver::{SymbolValue, ValueEnum},
|
||||
toplevel::{
|
||||
helper::{extract_ndims, PrimDef},
|
||||
numpy::unpack_ndarray_var_tys,
|
||||
DefinitionId, TopLevelDef,
|
||||
},
|
||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
||||
typecheck::{
|
||||
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||
@ -2512,319 +2507,6 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates code for a subscript expression on an `ndarray`.
|
||||
///
|
||||
/// * `ty` - The `Type` of the `NDArray` elements.
|
||||
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||
/// * `v` - The `NDArray` value.
|
||||
/// * `slice` - The slice expression used to subscript into the `ndarray`.
|
||||
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
ndims_ty: Type,
|
||||
v: NDArrayValue<'ctx>,
|
||||
slice: &Expr<Option<Type>>,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims_ty) else {
|
||||
codegen_unreachable!(ctx)
|
||||
};
|
||||
|
||||
let ndims = values
|
||||
.iter()
|
||||
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|val| {
|
||||
format!(
|
||||
"Expected non-negative literal for ndarray.ndims, got {}",
|
||||
i128::try_from(val).unwrap()
|
||||
)
|
||||
})?;
|
||||
|
||||
assert!(!ndims.is_empty());
|
||||
|
||||
// The number of dimensions subscripted by the index expression.
|
||||
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
|
||||
// dimension will remove a dimension.
|
||||
let subscripted_dims = match &slice.node {
|
||||
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
|
||||
if let ExprKind::Slice { .. } = &value_subexpr.node {
|
||||
acc
|
||||
} else {
|
||||
acc + 1
|
||||
}
|
||||
}),
|
||||
|
||||
ExprKind::Slice { .. } => 0,
|
||||
_ => 1,
|
||||
};
|
||||
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||
let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap();
|
||||
|
||||
// Check that len is non-zero
|
||||
let len = v.load_ndims(ctx);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(),
|
||||
"0:IndexError",
|
||||
"too many indices for array: array is {0}-dimensional but 1 were indexed",
|
||||
[Some(len), None, None],
|
||||
slice.location,
|
||||
);
|
||||
|
||||
// Normalizes a possibly-negative index to its corresponding positive index
|
||||
let normalize_index = |generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
index: IntValue<'ctx>,
|
||||
dim: u64| {
|
||||
gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "")
|
||||
.unwrap())
|
||||
},
|
||||
|_, _| Ok(Some(index)),
|
||||
|generator, ctx| {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
let len = unsafe {
|
||||
v.shape().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(dim, true),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_int_add(
|
||||
len,
|
||||
ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
|
||||
},
|
||||
)
|
||||
.map(|v| v.map(BasicValueEnum::into_int_value))
|
||||
};
|
||||
|
||||
// Converts a slice expression into a slice-range tuple
|
||||
let expr_to_slice = |generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
node: &ExprKind<Option<Type>>,
|
||||
dim: u64| {
|
||||
match node {
|
||||
ExprKind::Constant { value: Constant::Int(v), .. } => {
|
||||
let Some(index) =
|
||||
normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some((index, index, llvm_i32.const_int(1, true))))
|
||||
}
|
||||
|
||||
ExprKind::Slice { lower, upper, step } => {
|
||||
let dim_sz = unsafe {
|
||||
v.shape().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(dim, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
|
||||
}
|
||||
|
||||
_ => {
|
||||
let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) };
|
||||
let index = index
|
||||
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
|
||||
.into_int_value();
|
||||
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some((index, index, llvm_i32.const_int(1, true))))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let make_indices_arr = |generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>|
|
||||
-> Result<_, String> {
|
||||
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
|
||||
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
|
||||
let index_addr = generator.gen_array_var_alloc(
|
||||
ctx,
|
||||
llvm_int_ty,
|
||||
llvm_usize.const_int(elts.len() as u64, false),
|
||||
None,
|
||||
)?;
|
||||
|
||||
for (i, elt) in elts.iter().enumerate() {
|
||||
let Some(index) = generator.gen_expr(ctx, elt)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let index = index
|
||||
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
|
||||
.into_int_value();
|
||||
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let store_ptr = unsafe {
|
||||
index_addr.ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(i as u64, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(store_ptr, index).unwrap();
|
||||
}
|
||||
|
||||
Some(index_addr)
|
||||
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
|
||||
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
|
||||
let index_addr = generator.gen_array_var_alloc(
|
||||
ctx,
|
||||
llvm_int_ty,
|
||||
llvm_usize.const_int(1u64, false),
|
||||
None,
|
||||
)?;
|
||||
|
||||
let index =
|
||||
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
|
||||
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
|
||||
|
||||
let store_ptr = unsafe {
|
||||
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
ctx.builder.build_store(store_ptr, index).unwrap();
|
||||
|
||||
Some(index_addr)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
};
|
||||
|
||||
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
|
||||
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
||||
|
||||
v.data().get(ctx, generator, &index_addr, None).into()
|
||||
} else {
|
||||
match &slice.node {
|
||||
ExprKind::Tuple { elts, .. } => {
|
||||
let slices = elts
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
|
||||
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
if slices.len() < elts.len() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
|
||||
|
||||
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
|
||||
}
|
||||
|
||||
ExprKind::Slice { .. } => {
|
||||
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
|
||||
}
|
||||
|
||||
_ => {
|
||||
// Accessing an element from a multi-dimensional `ndarray`
|
||||
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
|
||||
|
||||
let num_dims = extract_ndims(&ctx.unifier, ndims_ty) - 1;
|
||||
|
||||
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
|
||||
// elements over
|
||||
let ndarray =
|
||||
NDArrayType::new(generator, ctx.ctx, llvm_ndarray_data_t, Some(num_dims))
|
||||
.construct_uninitialized(generator, ctx, None);
|
||||
|
||||
let ndarray_num_dims = ctx
|
||||
.builder
|
||||
.build_int_z_extend_or_bit_cast(
|
||||
ndarray.load_ndims(ctx),
|
||||
llvm_usize.size_of().get_type(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let v_dims_src_ptr = unsafe {
|
||||
v.shape().ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
ndarray.shape().base_ptr(ctx, generator),
|
||||
v_dims_src_ptr,
|
||||
ctx.builder
|
||||
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
|
||||
.map(Into::into)
|
||||
.unwrap(),
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
let ndarray_num_elems = ndarray::call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
&ndarray.shape().as_slice_value(ctx, generator),
|
||||
(None, None),
|
||||
);
|
||||
let ndarray_num_elems = ctx
|
||||
.builder
|
||||
.build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "")
|
||||
.unwrap();
|
||||
unsafe { ndarray.create_data(generator, ctx) };
|
||||
|
||||
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
ndarray.data().base_ptr(ctx, generator),
|
||||
v_data_src_ptr,
|
||||
ctx.builder
|
||||
.build_int_mul(
|
||||
ndarray_num_elems,
|
||||
llvm_ndarray_data_t.size_of().unwrap(),
|
||||
"",
|
||||
)
|
||||
.map(Into::into)
|
||||
.unwrap(),
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
ndarray.as_base_value().into()
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_expr`].
|
||||
pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
|
@ -1,7 +1,6 @@
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
intrinsics::Intrinsic,
|
||||
types::{AnyTypeEnum::IntType, FloatType},
|
||||
types::AnyTypeEnum::IntType,
|
||||
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
};
|
||||
@ -9,34 +8,6 @@ use itertools::Either;
|
||||
|
||||
use super::CodeGenContext;
|
||||
|
||||
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
|
||||
/// functions.
|
||||
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
|
||||
// Standard LLVM floating-point types
|
||||
if ft == ctx.f16_type() {
|
||||
return "f16";
|
||||
}
|
||||
if ft == ctx.f32_type() {
|
||||
return "f32";
|
||||
}
|
||||
if ft == ctx.f64_type() {
|
||||
return "f64";
|
||||
}
|
||||
if ft == ctx.f128_type() {
|
||||
return "f128";
|
||||
}
|
||||
|
||||
// Non-standard floating-point types
|
||||
if ft == ctx.x86_f80_type() {
|
||||
return "f80";
|
||||
}
|
||||
if ft == ctx.ppc_f128_type() {
|
||||
return "ppcf128";
|
||||
}
|
||||
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
|
||||
/// intrinsic.
|
||||
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
|
||||
@ -54,7 +25,7 @@ pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue
|
||||
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
|
||||
/// Invokes the [`llvm.va_end`](https://llvm.org/docs/LangRef.html#llvm-va-end-intrinsic)
|
||||
/// intrinsic.
|
||||
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
|
||||
const FN_NAME: &str = "llvm.va_end";
|
||||
|
@ -604,29 +604,6 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of dimensions for an array-like object as an [`IntValue`].
|
||||
fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(ty, value): (Type, BasicValueEnum<'ctx>),
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
match value {
|
||||
BasicValueEnum::PointerValue(v)
|
||||
if NDArrayValue::is_representable(v, llvm_usize).is_ok() =>
|
||||
{
|
||||
NDArrayType::from_unifier_type(generator, ctx, ty).map_value(v, None).load_ndims(ctx)
|
||||
}
|
||||
|
||||
BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
|
||||
llvm_ndlist_get_ndims(generator, ctx, v.get_type())
|
||||
}
|
||||
|
||||
_ => llvm_usize.const_zero(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`].
|
||||
fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
|
@ -36,7 +36,6 @@ use crate::{
|
||||
struct Resolver {
|
||||
id_to_type: HashMap<StrRef, Type>,
|
||||
id_to_def: RwLock<HashMap<StrRef, DefinitionId>>,
|
||||
class_names: HashMap<StrRef, Type>,
|
||||
}
|
||||
|
||||
impl Resolver {
|
||||
@ -104,11 +103,9 @@ fn test_primitives() {
|
||||
let top_level = Arc::new(composer.make_top_level_context());
|
||||
unifier.top_level = Some(top_level.clone());
|
||||
|
||||
let resolver = Arc::new(Resolver {
|
||||
id_to_type: HashMap::new(),
|
||||
id_to_def: RwLock::new(HashMap::new()),
|
||||
class_names: HashMap::default(),
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
let resolver =
|
||||
Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) })
|
||||
as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
|
||||
let signature = FunSignature {
|
||||
@ -298,11 +295,7 @@ fn test_simple_call() {
|
||||
loc: None,
|
||||
})));
|
||||
|
||||
let resolver = Resolver {
|
||||
id_to_type: HashMap::new(),
|
||||
id_to_def: RwLock::new(HashMap::new()),
|
||||
class_names: HashMap::default(),
|
||||
};
|
||||
let resolver = Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) };
|
||||
resolver.add_id_def("foo".into(), DefinitionId(foo_id));
|
||||
let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
|
@ -389,7 +389,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||
let var_name = name.or(self.2).map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder
|
||||
|
@ -19,7 +19,6 @@ use crate::codegen::{
|
||||
pub use contiguous::*;
|
||||
pub use indexing::*;
|
||||
pub use nditer::*;
|
||||
pub use view::*;
|
||||
|
||||
mod contiguous;
|
||||
mod indexing;
|
||||
@ -113,12 +112,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
self.get_type().get_fields(ctx.ctx).shape
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `shape` array, as if by calling
|
||||
/// `getelementptr` on the field.
|
||||
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
/// Stores the array of dimension sizes `dims` into this instance.
|
||||
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||
self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name);
|
||||
@ -147,12 +140,6 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
self.get_type().get_fields(ctx.ctx).strides
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `strides` array, as if by calling
|
||||
/// `getelementptr` on the field.
|
||||
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.strides_field(ctx).ptr_by_gep(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
/// Stores the array of stride sizes `strides` into this instance.
|
||||
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) {
|
||||
self.strides_field(ctx).set(ctx, self.as_base_value(), strides, self.name);
|
||||
|
@ -78,7 +78,7 @@ impl<'ctx> NDIterValue<'ctx> {
|
||||
pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let elem_ty = self.parent.dtype;
|
||||
|
||||
let p = self.element(ctx).get(ctx, self.as_base_value(), None);
|
||||
let p = self.element(ctx).get(ctx, self.as_base_value(), self.name);
|
||||
ctx.builder
|
||||
.build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element")
|
||||
.unwrap()
|
||||
@ -98,7 +98,7 @@ impl<'ctx> NDIterValue<'ctx> {
|
||||
/// Get the index of the current element if this ndarray were a flat ndarray.
|
||||
#[must_use]
|
||||
pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
self.nth(ctx).get(ctx, self.as_base_value(), None)
|
||||
self.nth(ctx).get(ctx, self.as_base_value(), self.name)
|
||||
}
|
||||
|
||||
/// Get the indices of the current element.
|
||||
|
@ -1,13 +1,7 @@
|
||||
use std::iter::once;
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{BasicMetadataTypeEnum, BasicType},
|
||||
values::{BasicMetadataValueEnum, BasicValue, CallSiteValue},
|
||||
IntPredicate,
|
||||
};
|
||||
use itertools::Either;
|
||||
use inkwell::{values::BasicValue, IntPredicate};
|
||||
use strum::IntoEnumIterator;
|
||||
|
||||
use super::{
|
||||
@ -148,144 +142,6 @@ fn create_fn_by_codegen(
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a NumPy [`TopLevelDef`] function using an LLVM intrinsic.
|
||||
///
|
||||
/// * `name`: The name of the implemented NumPy function.
|
||||
/// * `ret_ty`: The return type of this function.
|
||||
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function.
|
||||
fn create_fn_by_intrinsic(
|
||||
unifier: &mut Unifier,
|
||||
var_map: &VarMap,
|
||||
name: &'static str,
|
||||
ret_ty: Type,
|
||||
params: &[(Type, &'static str)],
|
||||
intrinsic_fn: &'static str,
|
||||
) -> TopLevelDef {
|
||||
let param_tys = params.iter().map(|p| p.0).collect_vec();
|
||||
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
var_map,
|
||||
name,
|
||||
ret_ty,
|
||||
params,
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec();
|
||||
|
||||
assert!(param_tys
|
||||
.iter()
|
||||
.zip(&args_ty)
|
||||
.all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual)));
|
||||
|
||||
let args_val = args_ty
|
||||
.iter()
|
||||
.zip_eq(args.iter())
|
||||
.map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap())
|
||||
.map_into::<BasicMetadataValueEnum>()
|
||||
.collect_vec();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function(intrinsic_fn).unwrap_or_else(|| {
|
||||
let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty);
|
||||
let param_llvm_ty = param_tys
|
||||
.iter()
|
||||
.map(|p| ctx.get_llvm_abi_type(generator, *p))
|
||||
.map_into::<BasicMetadataTypeEnum>()
|
||||
.collect_vec();
|
||||
let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false);
|
||||
|
||||
ctx.module.add_function(intrinsic_fn, fn_type, None)
|
||||
});
|
||||
|
||||
let val = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, args_val.as_slice(), name)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
Ok(val.into())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a unary NumPy [`TopLevelDef`] function using an extern function (e.g. from `libc` or
|
||||
/// `libm`).
|
||||
///
|
||||
/// * `name`: The name of the implemented NumPy function.
|
||||
/// * `ret_ty`: The return type of this function.
|
||||
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// * `extern_fn`: The fully-qualified name of the extern function used as the implementation.
|
||||
/// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is
|
||||
/// already implied by the C ABI.
|
||||
fn create_fn_by_extern(
|
||||
unifier: &mut Unifier,
|
||||
var_map: &VarMap,
|
||||
name: &'static str,
|
||||
ret_ty: Type,
|
||||
params: &[(Type, &'static str)],
|
||||
extern_fn: &'static str,
|
||||
attrs: &'static [&str],
|
||||
) -> TopLevelDef {
|
||||
let param_tys = params.iter().map(|p| p.0).collect_vec();
|
||||
|
||||
create_fn_by_codegen(
|
||||
unifier,
|
||||
var_map,
|
||||
name,
|
||||
ret_ty,
|
||||
params,
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec();
|
||||
|
||||
assert!(param_tys
|
||||
.iter()
|
||||
.zip(&args_ty)
|
||||
.all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual)));
|
||||
|
||||
let args_val = args_ty
|
||||
.iter()
|
||||
.zip_eq(args.iter())
|
||||
.map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap())
|
||||
.map_into::<BasicMetadataValueEnum>()
|
||||
.collect_vec();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| {
|
||||
let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty);
|
||||
let param_llvm_ty = param_tys
|
||||
.iter()
|
||||
.map(|p| ctx.get_llvm_abi_type(generator, *p))
|
||||
.map_into::<BasicMetadataTypeEnum>()
|
||||
.collect_vec();
|
||||
let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false);
|
||||
let func = ctx.module.add_function(extern_fn, fn_type, None);
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
|
||||
);
|
||||
|
||||
for attr in attrs {
|
||||
func.add_attribute(
|
||||
AttributeLoc::Function,
|
||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
||||
);
|
||||
}
|
||||
|
||||
func
|
||||
});
|
||||
|
||||
let val = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &args_val, name)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
Ok(val.into())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> BuiltinInfo {
|
||||
BuiltinBuilder::new(unifier, primitives)
|
||||
.build_all_builtins()
|
||||
|
@ -8,5 +8,5 @@ expression: res_vec
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
|
||||
]
|
||||
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n",
|
||||
]
|
||||
|
@ -15,14 +15,13 @@ use crate::{
|
||||
symbol_resolver::{SymbolResolver, ValueEnum},
|
||||
typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
typedef::{into_var_map, Type, Unifier},
|
||||
typedef::{Type, Unifier},
|
||||
},
|
||||
};
|
||||
|
||||
struct ResolverInternal {
|
||||
id_to_type: Mutex<HashMap<StrRef, Type>>,
|
||||
id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
|
||||
class_names: Mutex<HashMap<StrRef, Type>>,
|
||||
}
|
||||
|
||||
impl ResolverInternal {
|
||||
@ -179,11 +178,8 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
|
||||
let mut composer =
|
||||
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
|
||||
|
||||
let internal_resolver = Arc::new(ResolverInternal {
|
||||
id_to_def: Mutex::default(),
|
||||
id_to_type: Mutex::default(),
|
||||
class_names: Mutex::default(),
|
||||
});
|
||||
let internal_resolver =
|
||||
Arc::new(ResolverInternal { id_to_def: Mutex::default(), id_to_type: Mutex::default() });
|
||||
let resolver =
|
||||
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
@ -784,13 +780,6 @@ fn make_internal_resolver_with_tvar(
|
||||
unifier: &mut Unifier,
|
||||
print: bool,
|
||||
) -> Arc<ResolverInternal> {
|
||||
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
|
||||
let list = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PrimDef::List.id(),
|
||||
fields: HashMap::new(),
|
||||
params: into_var_map([list_elem_tvar]),
|
||||
});
|
||||
|
||||
let res: Arc<ResolverInternal> = ResolverInternal {
|
||||
id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])),
|
||||
id_to_type: tvars
|
||||
@ -806,7 +795,6 @@ fn make_internal_resolver_with_tvar(
|
||||
})
|
||||
.collect::<HashMap<_, _>>()
|
||||
.into(),
|
||||
class_names: Mutex::new(HashMap::from([("list".into(), list)])),
|
||||
}
|
||||
.into();
|
||||
if print {
|
||||
|
@ -18,7 +18,6 @@ use crate::{
|
||||
struct Resolver {
|
||||
id_to_type: HashMap<StrRef, Type>,
|
||||
id_to_def: HashMap<StrRef, DefinitionId>,
|
||||
class_names: HashMap<StrRef, Type>,
|
||||
}
|
||||
|
||||
impl SymbolResolver for Resolver {
|
||||
@ -198,7 +197,6 @@ impl TestEnvironment {
|
||||
let resolver = Arc::new(Resolver {
|
||||
id_to_type: identifier_mapping.clone(),
|
||||
id_to_def: HashMap::default(),
|
||||
class_names: HashMap::default(),
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
TestEnvironment {
|
||||
@ -454,7 +452,6 @@ impl TestEnvironment {
|
||||
vars: IndexMap::default(),
|
||||
})),
|
||||
);
|
||||
let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into();
|
||||
|
||||
let id_to_name = [
|
||||
"int32".into(),
|
||||
@ -492,7 +489,6 @@ impl TestEnvironment {
|
||||
("Bar2".into(), DefinitionId(defs + 3)),
|
||||
]
|
||||
.into(),
|
||||
class_names,
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
TestEnvironment {
|
||||
|
Loading…
Reference in New Issue
Block a user