forked from M-Labs/nac3
1
0
Fork 0
This commit is contained in:
lyken 2024-07-10 17:27:10 +08:00
parent 9aae290727
commit 87d2a4ed59
7 changed files with 447 additions and 38 deletions

View File

@ -163,7 +163,10 @@
clippy
pre-commit
rustfmt
rust-analyzer
];
# https://nixos.wiki/wiki/Rust#Shell.nix_example
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
};
devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2";

View File

@ -403,6 +403,43 @@ namespace {
}
}
}
// Simulates `this_ndarray[:] = src_ndarray`, with automatic broadcasting.
// Caution on https://github.com/numpy/numpy/issues/21744
// Also see `NDArray::broadcast_to`
void assign_with(NDArray<SizeT>* src_ndarray) {
irrt_assert(
ndarray_util::can_broadcast_shape_to(
this->ndims,
this->shape,
src_ndarray->ndims,
src_ndarray->shape
)
);
// Broadcast the `src_ndarray` to make the reading process *much* easier
SizeT* broadcasted_src_ndarray_strides = __builtin_alloca(sizeof(SizeT) * this->ndims); // Remember to allocate strides beforehand
NDArray<SizeT> broadcasted_src_ndarray = {
.ndims = this->ndims,
.shape = this->shape,
.strides = broadcasted_src_ndarray_strides
};
src_ndarray->broadcast_to(&broadcasted_src_ndarray);
// Using iter instead of `get_nth_pelement` because it is slightly faster
SizeT* indices = __builtin_alloca(sizeof(SizeT) * this->ndims);
auto iter = NDArrayIndicesIter<SizeT> {
.ndims = this->ndims,
.shape = this->shape,
.indices = indices
};
const SizeT this_size = this->size();
for (SizeT i = 0; i < this_size; i++, iter.next()) {
uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement(indices);
uint8_t* this_pelement = this->get_pelement(indices);
this->set_value_at_pelement(src_pelement, src_pelement);
}
}
};
}

View File

@ -165,7 +165,7 @@ void test_ndarray_indices_iter_normal() {
int32_t shape[3] = { 1, 2, 3 };
int32_t indices[3] = { 0, 0, 0 };
auto iter = NDArrayIndicesIter<int32_t> {
.ndims = 3u,
.ndims = 3,
.shape = shape,
.indices = indices
};
@ -629,6 +629,15 @@ void test_ndarray_broadcast_1() {
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {1, 2, 3})));
}
void test_assign_with() {
/*
```
xs = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=np.float64)
ys = xs.shape
```
*/
}
int main() {
test_calc_size_from_shape_normal();
test_calc_size_from_shape_has_zero();
@ -644,5 +653,6 @@ int main() {
test_ndslice_2();
test_can_broadcast_shape();
test_ndarray_broadcast_1();
test_assign_with();
return 0;
}

View File

@ -1,8 +1,6 @@
use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
llvm_intrinsics::call_int_umin, stmt::gen_for_callback_incrementing, CodeGenContext,
CodeGenerator,
};
use inkwell::context::Context;
use inkwell::types::{ArrayType, BasicType, StructType};
@ -12,6 +10,7 @@ use inkwell::{
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use itertools::Itertools;
/// A LLVM type that is used to represent a non-primitive type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> {
@ -1601,7 +1600,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
todo!()
// call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
}
}
@ -1675,17 +1675,19 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices_elem_ty.get_bit_width()
);
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
todo!()
unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[index],
name.unwrap_or_default(),
)
.unwrap()
}
// let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
// unsafe {
// ctx.builder
// .build_in_bounds_gep(
// self.base_ptr(ctx, generator),
// &[index],
// name.unwrap_or_default(),
// )
// .unwrap()
// }
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
@ -1761,3 +1763,307 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
for NDArrayDataProxy<'ctx, '_>
{
}
#[derive(Debug, Clone, Copy)]
pub struct StructField<'ctx> {
/// The GEP index of this struct field.
pub gep_index: u32,
/// Name of this struct field.
///
/// Used for generating names.
pub name: &'static str,
/// The type of this struct field.
pub ty: BasicTypeEnum<'ctx>,
}
pub struct StructFields<'ctx> {
/// Name of the struct.
///
/// Used for generating names.
pub name: &'static str,
/// All the [`StructField`]s of this struct.
///
/// **NOTE:** The index position of a [`StructField`]
/// matches the element's [`StructField::index`].
pub fields: Vec<StructField<'ctx>>,
}
struct StructFieldsBuilder<'ctx> {
gep_index_counter: u32,
/// Name of the struct to be built.
name: &'static str,
fields: Vec<StructField<'ctx>>,
}
impl<'ctx> StructField<'ctx> {
pub fn gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
ptr: PointerValue<'ctx>,
) -> PointerValue<'ctx> {
ctx.builder.build_struct_gep(ptr, self.gep_index, self.name).unwrap()
}
pub fn load(
&self,
ctx: &CodeGenContext<'ctx, '_>,
ptr: PointerValue<'ctx>,
) -> BasicValueEnum<'ctx> {
ctx.builder.build_load(self.gep(ctx, ptr), self.name).unwrap()
}
pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>, value: V)
where
V: BasicValue<'ctx>,
{
ctx.builder.build_store(ptr, value).unwrap();
}
}
type IsInstanceError = String;
type IsInstanceResult = Result<(), IsInstanceError>;
pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult
where
A: BasicType<'ctx>,
B: BasicType<'ctx>,
{
let expected = expected.as_basic_type_enum();
let got = got.as_basic_type_enum();
// Put those logic into here,
// otherwise there is always a fallback reporting on any kind of mismatch
match (expected, got) {
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
if expected.get_bit_width() != got.get_bit_width() {
return Err(format!(
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
));
}
}
(expected, got) => {
if expected != got {
return Err(format!("Expected {expected}, got {got}"));
}
}
}
Ok(())
}
impl<'ctx> StructFields<'ctx> {
pub fn num_fields(&self) -> u32 {
self.fields.len() as u32
}
pub fn as_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec();
ctx.struct_type(llvm_fields.as_slice(), false)
}
pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult {
// Check scrutinee's number of struct fields
if scrutinee.count_fields() != self.num_fields() {
return Err(format!(
"Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}",
struct_name = self.name,
expected_count = self.num_fields(),
got_count = scrutinee.count_fields(),
));
}
// Check the scrutinee's field types
for field in self.fields.iter() {
let expected_field_ty = field.ty;
let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap();
if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) {
return Err(format!(
"Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}",
gep_index = field.gep_index,
struct_name = self.name,
field_name = field.name,
));
}
}
// Done
Ok(())
}
}
impl<'ctx> StructFieldsBuilder<'ctx> {
fn start(name: &'static str) -> Self {
StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() }
}
fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> {
let index = self.gep_index_counter;
self.gep_index_counter += 1;
StructField { gep_index: index, name, ty }
}
fn end(self) -> StructFields<'ctx> {
StructFields { name: self.name, fields: self.fields }
}
}
#[derive(Debug, Clone, Copy)]
pub struct NpArrayType<'ctx> {
pub size_type: IntType<'ctx>,
pub elem_type: BasicTypeEnum<'ctx>,
}
pub struct NpArrayStructFields<'ctx> {
pub whole_struct: StructFields<'ctx>,
pub data: StructField<'ctx>,
pub itemsize: StructField<'ctx>,
pub ndims: StructField<'ctx>,
pub shape: StructField<'ctx>,
pub strides: StructField<'ctx>,
}
impl<'ctx> NpArrayType<'ctx> {
pub fn new_opaque_elem(
ctx: &CodeGenContext<'ctx, '_>,
size_type: IntType<'ctx>,
) -> NpArrayType<'ctx> {
NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() }
}
pub fn struct_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructType<'ctx> {
self.fields().whole_struct.as_struct_type(ctx.ctx)
}
pub fn fields(&self) -> NpArrayStructFields<'ctx> {
let mut builder = StructFieldsBuilder::start("NpArray");
let addrspace = AddressSpace::default();
let byte_type = self.size_type.get_context().i8_type();
// Make sure the struct matches PERFECTLY with that defined in `nac3core/irrt`.
let data = builder.add_field("data", byte_type.ptr_type(addrspace).into());
let itemsize = builder.add_field("itemsize", self.size_type.into());
let ndims = builder.add_field("ndims", self.size_type.into());
let shape = builder.add_field("shape", self.size_type.ptr_type(addrspace).into());
let strides = builder.add_field("strides", self.size_type.ptr_type(addrspace).into());
NpArrayStructFields { whole_struct: builder.end(), data, itemsize, ndims, shape, strides }
}
/// Allocate an `ndarray` on stack, with the following notes:
///
/// - `ndarray.ndims` will be initialized to `in_ndims`.
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
/// all with empty/uninitialized values.
pub fn alloca(
&self,
ctx: &CodeGenContext<'ctx, '_>,
in_ndims: IntValue<'ctx>,
name: &str,
) -> NpArrayValue<'ctx> {
let fields = self.fields();
let ptr =
ctx.builder.build_alloca(fields.whole_struct.as_struct_type(ctx.ctx), name).unwrap();
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
let allocated_shape =
ctx.builder.build_array_alloca(fields.shape.ty, in_ndims, "allocated_shape").unwrap();
let allocated_strides = ctx
.builder
.build_array_alloca(fields.strides.ty, in_ndims, "allocated_strides")
.unwrap();
let value = NpArrayValue { ty: *self, ptr };
value.store_ndims(ctx, in_ndims);
value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
value.store_shape(ctx, allocated_shape);
value.store_strides(ctx, allocated_strides);
return value;
}
}
#[derive(Debug, Clone, Copy)]
pub struct NpArrayValue<'ctx> {
pub ty: NpArrayType<'ctx>,
pub ptr: PointerValue<'ctx>,
}
impl<'ctx> NpArrayValue<'ctx> {
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let field = self.ty.fields().ndims;
field.load(ctx, self.ptr).into_int_value()
}
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
let field = self.ty.fields().ndims;
field.store(ctx, self.ptr, value);
}
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let field = self.ty.fields().itemsize;
field.load(ctx, self.ptr).into_int_value()
}
pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
let field = self.ty.fields().itemsize;
field.store(ctx, self.ptr, value);
}
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let field = self.ty.fields().shape;
field.load(ctx, self.ptr).into_pointer_value()
}
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
let field = self.ty.fields().shape;
field.store(ctx, self.ptr, value);
}
pub fn load_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let field = self.ty.fields().strides;
field.load(ctx, self.ptr).into_pointer_value()
}
pub fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
let field = self.ty.fields().strides;
field.store(ctx, self.ptr, value);
}
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
pub fn shape_slice(
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let field = self.ty.fields().shape;
field.gep(ctx, self.ptr);
let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter {
adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
}
}
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
pub fn strides_slice(
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let field = self.ty.fields().strides;
field.gep(ctx, self.ptr);
let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter {
adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
}
}
}

View File

@ -1,11 +1,11 @@
use crate::typecheck::typedef::Type;
use crate::{typecheck::typedef::Type, util::SizeVariant};
mod test;
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, NpArrayType,
NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics, CodeGenContext, CodeGenerator,
};
@ -16,8 +16,8 @@ use inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::Module,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType},
values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
@ -929,3 +929,63 @@ pub fn call_ndarray_calc_broadcast_index<
Box::new(|_, v| v.into()),
)
}
fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant {
match ty.get_bit_width() {
32 => SizeVariant::Bits32,
64 => SizeVariant::Bits64,
_ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()),
}
}
fn get_size_type_dependent_function<'ctx, BuildFuncTypeFn>(
ctx: &CodeGenContext<'ctx, '_>,
size_type: IntType<'ctx>,
base_name: &str,
build_func_type: BuildFuncTypeFn,
) -> FunctionValue<'ctx>
where
BuildFuncTypeFn: Fn() -> FunctionType<'ctx>,
{
let mut fn_name = base_name.to_owned();
match get_size_variant(size_type) {
SizeVariant::Bits32 => {
// The original fn_name is the correct function name
}
SizeVariant::Bits64 => {
// Append "64" at the end, this is the naming convention for 64-bit
fn_name.push_str("64");
}
}
// Get (or declare then get if does not exist) the corresponding function
ctx.module.get_function(&fn_name).unwrap_or_else(|| {
let fn_type = build_func_type();
ctx.module.add_function(&fn_name, fn_type, None)
})
}
fn get_ndarray_struct_ptr<'ctx>(ctx: &'ctx Context, size_type: IntType<'ctx>) -> PointerType<'ctx> {
let i8_type = ctx.i8_type();
let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() };
let struct_ty = ndarray_ty.fields().whole_struct.as_struct_type(ctx);
struct_ty.ptr_type(AddressSpace::default())
}
pub fn call_nac3_ndarray_size<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
) -> IntValue<'ctx> {
let size_type = ndarray.ty.size_type;
let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || {
size_type.fn_type(&[get_ndarray_struct_ptr(ctx.ctx, size_type).into()], false)
});
ctx.builder
.build_call(function, &[ndarray.ptr.into()], "size")
.unwrap()
.try_as_basic_value()
.unwrap_left()
.into_int_value()
}

View File

@ -23,3 +23,4 @@ pub mod codegen;
pub mod symbol_resolver;
pub mod toplevel;
pub mod typecheck;
pub mod util;

View File

@ -1,5 +1,6 @@
use std::iter::once;
use crate::util::SizeVariant;
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
use indexmap::IndexMap;
use inkwell::{
@ -278,19 +279,10 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
.collect()
}
/// A helper enum used by [`BuiltinBuilder`]
#[derive(Clone, Copy)]
enum SizeVariant {
Bits32,
Bits64,
}
impl SizeVariant {
fn of_int(self, primitives: &PrimitiveStore) -> Type {
match self {
SizeVariant::Bits32 => primitives.int32,
SizeVariant::Bits64 => primitives.int64,
}
fn size_variant_to_int_type(variant: SizeVariant, primitives: &PrimitiveStore) -> Type {
match variant {
SizeVariant::Bits32 => primitives.int32,
SizeVariant::Bits64 => primitives.int64,
}
}
@ -1061,7 +1053,7 @@ impl<'a> BuiltinBuilder<'a> {
);
// The size variant of the function determines the size of the returned int.
let int_sized = size_variant.of_int(self.primitives);
let int_sized = size_variant_to_int_type(size_variant, self.primitives);
let ndarray_int_sized =
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
@ -1086,7 +1078,7 @@ impl<'a> BuiltinBuilder<'a> {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives);
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
}),
)
@ -1127,7 +1119,7 @@ impl<'a> BuiltinBuilder<'a> {
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
// The size variant of the function determines the type of int returned
let int_sized = size_variant.of_int(self.primitives);
let int_sized = size_variant_to_int_type(size_variant, self.primitives);
let ndarray_int_sized =
make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty));
@ -1150,7 +1142,7 @@ impl<'a> BuiltinBuilder<'a> {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives);
let func = match kind {
Kind::Ceil => builtin_fns::call_ceil,
Kind::Floor => builtin_fns::call_floor,