From 94e2414df0c9905b515aa0800ccc23b8ddee003c Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 11 Nov 2024 15:00:24 +0800 Subject: [PATCH 01/16] [meta] Update cargo dependencies --- Cargo.lock | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7e09602..fa19e0d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -547,9 +547,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.162" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libloading" @@ -1000,9 +1000,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.40" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags", "errno", @@ -1066,9 +1066,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa", "memchr", -- 2.44.2 From fe67ed076cf1308787cfbcd81e3d3e14d0ecad4a Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 8 Nov 2024 15:49:35 +0800 Subject: [PATCH 02/16] [meta] Update pre-commit configuration --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d37807..4022a79 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks -default_stages: [commit] +default_stages: [pre-commit] repos: - repo: local -- 2.44.2 From 2822074b2d3acd83bccba67fc56384e12ad6deff Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 19 Nov 2024 13:43:57 +0800 Subject: [PATCH 03/16] [meta] Cleanup from upgrading Rust version - Remove rust_2024_edition warnings, since it wouldn't be released for another 3 months - Fix new clippy warnings --- nac3artiq/src/lib.rs | 12 +++------ nac3ast/src/lib.rs | 8 +----- nac3core/build.rs | 3 +-- nac3core/src/lib.rs | 8 +----- nac3core/src/typecheck/type_inferencer/mod.rs | 5 ++-- nac3ld/src/lib.rs | 8 +----- nac3parser/src/lib.rs | 8 +----- nac3standalone/src/main.rs | 8 +----- runkernel/src/main.rs | 26 +++++++------------ 9 files changed, 21 insertions(+), 65 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 553264d..b49a88e 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -1,10 +1,4 @@ -#![deny( - future_incompatible, - let_underscore, - nonstandard_style, - clippy::all -)] -#![warn(rust_2024_compatibility)] +#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)] #![warn(clippy::pedantic)] #![allow( unsafe_op_in_unsafe_fn, @@ -741,7 +735,7 @@ impl Nac3 { }; let return_obj = - generator.gen_expr(ctx, &expr)?.map(|value| (expr.custom.unwrap(), value)); + generator.gen_expr(ctx, expr)?.map(|value| (expr.custom.unwrap(), value)); has_return = return_obj.is_some(); registry.wait_tasks_complete(handles); attributes_writeback( @@ -765,7 +759,7 @@ impl Nac3 { let buffers = membuffers.lock(); let main = context .create_module_from_ir(MemoryBuffer::create_from_memory_range( - &buffers.last().unwrap(), + buffers.last().unwrap(), "main", )) .unwrap(); diff --git a/nac3ast/src/lib.rs b/nac3ast/src/lib.rs index e313946..7691915 100644 --- a/nac3ast/src/lib.rs +++ b/nac3ast/src/lib.rs @@ -1,10 +1,4 @@ -#![deny( - future_incompatible, - let_underscore, - nonstandard_style, - clippy::all -)] -#![warn(rust_2024_compatibility)] +#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)] #![warn(clippy::pedantic)] #![allow( clippy::missing_errors_doc, diff --git a/nac3core/build.rs b/nac3core/build.rs index 8a5c56f..c2059a2 100644 --- a/nac3core/build.rs +++ b/nac3core/build.rs @@ -56,9 +56,8 @@ fn main() { let output = Command::new("clang-irrt") .args(flags) .output() - .map(|o| { + .inspect(|o| { assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap()); - o }) .unwrap(); diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index c2faade..91ae05a 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -1,10 +1,4 @@ -#![deny( - future_incompatible, - let_underscore, - nonstandard_style, - clippy::all -)] -#![warn(rust_2024_compatibility)] +#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)] #![warn(clippy::pedantic)] #![allow( dead_code, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index e5a6cf8..6068f63 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -536,9 +536,8 @@ impl<'a> Fold<()> for Inferencer<'a> { } ast::StmtKind::Assert { test, msg, .. } => { self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?; - match msg { - Some(m) => self.unify(m.custom.unwrap(), self.primitives.str, &m.location)?, - None => (), + if let Some(m) = msg { + self.unify(m.custom.unwrap(), self.primitives.str, &m.location)?; } } _ => return report_error("Unsupported statement type", stmt.location), diff --git a/nac3ld/src/lib.rs b/nac3ld/src/lib.rs index 4d5a51f..73e065d 100644 --- a/nac3ld/src/lib.rs +++ b/nac3ld/src/lib.rs @@ -1,10 +1,4 @@ -#![deny( - future_incompatible, - let_underscore, - nonstandard_style, - clippy::all -)] -#![warn(rust_2024_compatibility)] +#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)] #![warn(clippy::pedantic)] #![allow( clippy::cast_possible_truncation, diff --git a/nac3parser/src/lib.rs b/nac3parser/src/lib.rs index 6f3fcaa..7aa6e82 100644 --- a/nac3parser/src/lib.rs +++ b/nac3parser/src/lib.rs @@ -15,13 +15,7 @@ //! //! ``` -#![deny( - future_incompatible, - let_underscore, - nonstandard_style, - clippy::all -)] -#![warn(rust_2024_compatibility)] +#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)] #![warn(clippy::pedantic)] #![allow( clippy::enum_glob_use, diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index 698ff9c..2fce5d1 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -1,10 +1,4 @@ -#![deny( - future_incompatible, - let_underscore, - nonstandard_style, - clippy::all -)] -#![warn(rust_2024_compatibility)] +#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)] #![warn(clippy::pedantic)] #![allow(clippy::too_many_lines, clippy::wildcard_imports)] diff --git a/runkernel/src/main.rs b/runkernel/src/main.rs index fa99f69..971f3a7 100644 --- a/runkernel/src/main.rs +++ b/runkernel/src/main.rs @@ -1,10 +1,4 @@ -#![deny( - future_incompatible, - let_underscore, - nonstandard_style, - clippy::all -)] -#![warn(rust_2024_compatibility)] +#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)] #![warn(clippy::pedantic)] #![allow(clippy::semicolon_if_nothing_returned, clippy::uninlined_format_args)] @@ -12,47 +6,47 @@ use std::env; static mut NOW: i64 = 0; -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn now_mu() -> i64 { unsafe { NOW } } -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn at_mu(t: i64) { unsafe { NOW = t } } -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn delay_mu(dt: i64) { unsafe { NOW += dt } } -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn rtio_init() { println!("rtio_init"); } -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn rtio_get_counter() -> i64 { 0 } -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn rtio_output(target: i32, data: i32) { println!("rtio_output @{} target={target:04x} data={data}", unsafe { NOW }); } -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn print_int32(x: i32) { println!("print_int32: {x}"); } -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn print_int64(x: i64) { println!("print_int64: {x}"); } -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn __nac3_personality(_state: u32, _exception_object: u32, _context: u32) -> u32 { unimplemented!(); } -- 2.44.2 From 26a1b85206855073f55788e1fdec63a5a1391c1f Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 29 Oct 2024 13:41:11 +0800 Subject: [PATCH 04/16] [core] codegen/classes: Remove Underlying type This is confusing and we want a better abstraction than this. --- nac3artiq/src/symbol_resolver.rs | 11 +-- nac3core/src/codegen/classes.rs | 161 ++++++++++++++++--------------- 2 files changed, 86 insertions(+), 86 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index fd8ed0d..11662b4 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -14,10 +14,7 @@ use pyo3::{ }; use nac3core::{ - codegen::{ - classes::{NDArrayType, ProxyType}, - CodeGenContext, CodeGenerator, - }, + codegen::{classes::NDArrayType, CodeGenContext, CodeGenerator}, inkwell::{ module::Linkage, types::{BasicType, BasicTypeEnum}, @@ -1096,7 +1093,7 @@ impl InnerResolver { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global( - ndarray_llvm_ty.as_underlying_type(), + ndarray_llvm_ty.element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ) @@ -1190,7 +1187,7 @@ impl InnerResolver { data_global.set_initializer(&data); // create a global for the ndarray object and initialize it - let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[ + let value = ndarray_llvm_ty.element_type().into_struct_type().const_named_struct(&[ llvm_usize.const_int(ndarray_ndims, false).into(), shape_global .as_pointer_value() @@ -1203,7 +1200,7 @@ impl InnerResolver { ]); let ndarray = ctx.module.add_global( - ndarray_llvm_ty.as_underlying_type(), + ndarray_llvm_ty.element_type().into_struct_type(), Some(AddressSpace::default()), &id_str, ); diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 8628aaa..6cabcc6 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1,7 +1,7 @@ use inkwell::{ context::Context, - types::{AnyTypeEnum, ArrayType, BasicType, BasicTypeEnum, IntType, PointerType, StructType}, - values::{ArrayValue, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, }; @@ -18,10 +18,6 @@ pub trait ProxyType<'ctx>: Into { /// [LLVM pointer type][PointerType]. type Base: BasicType<'ctx>; - /// The underlying LLVM type used to represent values. This is usually the element type of - /// [`Base`] if it is a pointer, otherwise this is the same type as `Base`. - type Underlying: BasicType<'ctx>; - /// The type of values represented by this type. type Value: ProxyValue<'ctx>; @@ -40,11 +36,7 @@ pub trait ProxyType<'ctx>: Into { ctx: &mut CodeGenContext<'ctx, '_>, size: IntValue<'ctx>, name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc(ctx, self.as_underlying_type().as_basic_type_enum(), size, name) - .unwrap() - } + ) -> ArraySliceValue<'ctx>; /// Creates a [`value`][ProxyValue] with this as its type. fn create_value( @@ -55,9 +47,6 @@ pub trait ProxyType<'ctx>: Into { /// Returns the [base type][Self::Base] of this proxy. fn as_base_type(&self) -> Self::Base; - - /// Returns the [underlying type][Self::Underlying] of this proxy. - fn as_underlying_type(&self) -> Self::Underlying; } /// A LLVM type that is used to represent a non-primitive value in NAC3. @@ -66,10 +55,6 @@ pub trait ProxyValue<'ctx>: Into { /// [LLVM pointer type][PointerValue]. type Base: BasicValue<'ctx>; - /// The underlying type of LLVM values represented by this instance. This is usually the element - /// type of [`Base`] if it is a pointer, otherwise this is the same type as `Base`. - type Underlying: BasicValue<'ctx>; - /// The type of this value. type Type: ProxyType<'ctx>; @@ -79,13 +64,13 @@ pub trait ProxyValue<'ctx>: Into { /// Returns the [base value][Self::Base] of this proxy. fn as_base_value(&self) -> Self::Base; - /// Loads this value into its [underlying representation][Self::Underlying]. Usually involves a - /// `getelementptr` if [`Self::Base`] is a [pointer value][PointerValue]. - fn as_underlying_value( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Underlying; + // /// Loads this value into its [underlying representation][Self::Underlying]. Usually involves a + // /// `getelementptr` if [`Self::Base`] is a [pointer value][PointerValue]. + // fn as_underlying_value( + // &self, + // ctx: &mut CodeGenContext<'ctx, '_>, + // name: Option<&'ctx str>, + // ) -> Self::Underlying; } /// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of @@ -600,7 +585,6 @@ impl<'ctx> ListType<'ctx> { impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { type Base = PointerType<'ctx>; - type Underlying = StructType<'ctx>; type Value = ListValue<'ctx>; fn new_value( @@ -610,11 +594,34 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { name: Option<&'ctx str>, ) -> Self::Value { self.create_value( - generator.gen_var_alloc(ctx, self.as_underlying_type().into(), name).unwrap(), + generator + .gen_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + name, + ) + .unwrap(), name, ) } + fn new_array_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + size, + name, + ) + .unwrap() + } + fn create_value( &self, value: >::Base, @@ -628,10 +635,6 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { fn as_base_type(&self) -> Self::Base { self.ty } - - fn as_underlying_type(&self) -> Self::Underlying { - self.as_base_type().get_element_type().into_struct_type() - } } impl<'ctx> From> for PointerType<'ctx> { @@ -770,7 +773,6 @@ impl<'ctx> ListValue<'ctx> { impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { type Base = PointerValue<'ctx>; - type Underlying = StructValue<'ctx>; type Type = ListType<'ctx>; fn get_type(&self) -> Self::Type { @@ -780,17 +782,6 @@ impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } - - fn as_underlying_value( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Underlying { - ctx.builder - .build_load(self.as_base_value(), name.unwrap_or_default()) - .map(BasicValueEnum::into_struct_value) - .unwrap() - } } impl<'ctx> From> for PointerValue<'ctx> { @@ -940,7 +931,6 @@ impl<'ctx> RangeType<'ctx> { impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { type Base = PointerType<'ctx>; - type Underlying = ArrayType<'ctx>; type Value = RangeValue<'ctx>; fn new_value( @@ -950,7 +940,13 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { name: Option<&'ctx str>, ) -> Self::Value { self.create_value( - generator.gen_var_alloc(ctx, self.as_underlying_type().into(), name).unwrap(), + generator + .gen_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + name, + ) + .unwrap(), name, ) } @@ -965,12 +961,25 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { RangeValue { value, name } } - fn as_base_type(&self) -> Self::Base { - self.ty + fn new_array_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + size, + name, + ) + .unwrap() } - fn as_underlying_type(&self) -> Self::Underlying { - self.as_base_type().get_element_type().into_array_type() + fn as_base_type(&self) -> Self::Base { + self.ty } } @@ -1112,7 +1121,6 @@ impl<'ctx> RangeValue<'ctx> { impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { type Base = PointerValue<'ctx>; - type Underlying = ArrayValue<'ctx>; type Type = RangeType<'ctx>; fn get_type(&self) -> Self::Type { @@ -1122,17 +1130,6 @@ impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } - - fn as_underlying_value( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Underlying { - ctx.builder - .build_load(self.as_base_value(), name.unwrap_or_default()) - .map(BasicValueEnum::into_array_value) - .unwrap() - } } impl<'ctx> From> for PointerValue<'ctx> { @@ -1262,7 +1259,6 @@ impl<'ctx> NDArrayType<'ctx> { impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { type Base = PointerType<'ctx>; - type Underlying = StructType<'ctx>; type Value = NDArrayValue<'ctx>; fn new_value( @@ -1272,11 +1268,34 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { name: Option<&'ctx str>, ) -> Self::Value { self.create_value( - generator.gen_var_alloc(ctx, self.as_underlying_type().into(), name).unwrap(), + generator + .gen_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + name, + ) + .unwrap(), name, ) } + fn new_array_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + size, + name, + ) + .unwrap() + } + fn create_value( &self, value: >::Base, @@ -1290,10 +1309,6 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { fn as_base_type(&self) -> Self::Base { self.ty } - - fn as_underlying_type(&self) -> Self::Underlying { - self.as_base_type().get_element_type().into_struct_type() - } } impl<'ctx> From> for PointerType<'ctx> { @@ -1445,7 +1460,6 @@ impl<'ctx> NDArrayValue<'ctx> { impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { type Base = PointerValue<'ctx>; - type Underlying = StructValue<'ctx>; type Type = NDArrayType<'ctx>; fn get_type(&self) -> Self::Type { @@ -1455,17 +1469,6 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { fn as_base_value(&self) -> Self::Base { self.value } - - fn as_underlying_value( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Underlying { - ctx.builder - .build_load(self.as_base_value(), name.unwrap_or_default()) - .map(BasicValueEnum::into_struct_value) - .unwrap() - } } impl<'ctx> From> for PointerValue<'ctx> { -- 2.44.2 From 9d9ead211eecb411a26a7ff3be17a1d410073be4 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 29 Oct 2024 13:57:28 +0800 Subject: [PATCH 05/16] [core] Move Proxies to their own modules --- nac3artiq/src/codegen.rs | 9 +- nac3artiq/src/symbol_resolver.rs | 2 +- nac3core/src/codegen/builtin_fns.rs | 8 +- nac3core/src/codegen/classes.rs | 1768 ------------------------ nac3core/src/codegen/expr.rs | 9 +- nac3core/src/codegen/generator.rs | 2 +- nac3core/src/codegen/irrt/mod.rs | 8 +- nac3core/src/codegen/mod.rs | 5 +- nac3core/src/codegen/numpy.rs | 11 +- nac3core/src/codegen/stmt.rs | 2 +- nac3core/src/codegen/test.rs | 2 +- nac3core/src/codegen/types/list.rs | 163 +++ nac3core/src/codegen/types/mod.rs | 50 + nac3core/src/codegen/types/ndarray.rs | 191 +++ nac3core/src/codegen/types/range.rs | 132 ++ nac3core/src/codegen/values/array.rs | 426 ++++++ nac3core/src/codegen/values/list.rs | 238 ++++ nac3core/src/codegen/values/mod.rs | 28 + nac3core/src/codegen/values/ndarray.rs | 466 +++++++ nac3core/src/codegen/values/range.rs | 153 ++ nac3core/src/toplevel/builtins.rs | 2 +- 21 files changed, 1879 insertions(+), 1796 deletions(-) delete mode 100644 nac3core/src/codegen/classes.rs create mode 100644 nac3core/src/codegen/types/list.rs create mode 100644 nac3core/src/codegen/types/mod.rs create mode 100644 nac3core/src/codegen/types/ndarray.rs create mode 100644 nac3core/src/codegen/types/range.rs create mode 100644 nac3core/src/codegen/values/array.rs create mode 100644 nac3core/src/codegen/values/list.rs create mode 100644 nac3core/src/codegen/values/mod.rs create mode 100644 nac3core/src/codegen/values/ndarray.rs create mode 100644 nac3core/src/codegen/values/range.rs diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 8bec23b..85f9b1c 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -14,14 +14,15 @@ use pyo3::{ use nac3core::{ codegen::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType, - NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor, - }, expr::{destructure_range, gen_call}, irrt::call_ndarray_calc_size, llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, + types::{NDArrayType, ProxyType}, + values::{ + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue, + RangeValue, UntypedArrayLikeAccessor, + }, CodeGenContext, CodeGenerator, }, inkwell::{ diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 11662b4..4aeaaf1 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -14,7 +14,7 @@ use pyo3::{ }; use nac3core::{ - codegen::{classes::NDArrayType, CodeGenContext, CodeGenerator}, + codegen::{types::NDArrayType, CodeGenContext, CodeGenerator}, inkwell::{ module::Linkage, types::{BasicType, BasicTypeEnum}, diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 4260ebe..a6749d5 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -6,10 +6,6 @@ use inkwell::{ use itertools::Itertools; use super::{ - classes::{ - ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, - }, expr::destructure_range, extern_fns, irrt, irrt::calculate_len_for_slice_range, @@ -18,6 +14,10 @@ use super::{ numpy, numpy::ndarray_elementwise_unaryop_impl, stmt::gen_for_callback_incrementing, + values::{ + ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + }, CodeGenContext, CodeGenerator, }; use crate::{ diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs deleted file mode 100644 index 6cabcc6..0000000 --- a/nac3core/src/codegen/classes.rs +++ /dev/null @@ -1,1768 +0,0 @@ -use inkwell::{ - context::Context, - types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, -}; - -use super::{ - irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, - llvm_intrinsics::call_int_umin, - stmt::gen_for_callback_incrementing, - CodeGenContext, CodeGenerator, -}; - -/// A LLVM type that is used to represent a non-primitive type in NAC3. -pub trait ProxyType<'ctx>: Into { - /// The LLVM type of which values of this type possess. This is usually a - /// [LLVM pointer type][PointerType]. - type Base: BasicType<'ctx>; - - /// The type of values represented by this type. - type Value: ProxyValue<'ctx>; - - /// Creates a new value of this type. - fn new_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Value; - - /// Creates a new array value of this type. - fn new_array_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx>; - - /// Creates a [`value`][ProxyValue] with this as its type. - fn create_value( - &self, - value: >::Base, - name: Option<&'ctx str>, - ) -> Self::Value; - - /// Returns the [base type][Self::Base] of this proxy. - fn as_base_type(&self) -> Self::Base; -} - -/// A LLVM type that is used to represent a non-primitive value in NAC3. -pub trait ProxyValue<'ctx>: Into { - /// The type of LLVM values represented by this instance. This is usually the - /// [LLVM pointer type][PointerValue]. - type Base: BasicValue<'ctx>; - - /// The type of this value. - type Type: ProxyType<'ctx>; - - /// Returns the [type][ProxyType] of this value. - fn get_type(&self) -> Self::Type; - - /// Returns the [base value][Self::Base] of this proxy. - fn as_base_value(&self) -> Self::Base; - - // /// Loads this value into its [underlying representation][Self::Underlying]. Usually involves a - // /// `getelementptr` if [`Self::Base`] is a [pointer value][PointerValue]. - // fn as_underlying_value( - // &self, - // ctx: &mut CodeGenContext<'ctx, '_>, - // name: Option<&'ctx str>, - // ) -> Self::Underlying; -} - -/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of -/// elements. -pub trait ArrayLikeValue<'ctx> { - /// Returns the element type of this array-like value. - fn element_type( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> AnyTypeEnum<'ctx>; - - /// Returns the base pointer to the array. - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> PointerValue<'ctx>; - - /// Returns the size of this array-like value. - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> IntValue<'ctx>; - - /// Returns a [`ArraySliceValue`] representing this value. - fn as_slice_value( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> ArraySliceValue<'ctx> { - ArraySliceValue::from_ptr_val( - self.base_ptr(ctx, generator), - self.size(ctx, generator), - None, - ) - } -} - -/// An array-like value that can be indexed by memory offset. -pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> { - /// # Safety - /// - /// This function should be called with a valid index. - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx>; - - /// Returns the pointer to the data at the `idx`-th index. - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx>; -} - -/// An array-like value that can have its array elements accessed as a [`BasicValueEnum`]. -pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: - ArrayLikeIndexer<'ctx, Index> -{ - /// # Safety - /// - /// This function should be called with a valid index. - unsafe fn get_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - name: Option<&str>, - ) -> BasicValueEnum<'ctx> { - let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }; - ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap() - } - - /// Returns the data at the `idx`-th index. - fn get( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - name: Option<&str>, - ) -> BasicValueEnum<'ctx> { - let ptr = self.ptr_offset(ctx, generator, idx, name); - ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap() - } -} - -/// An array-like value that can have its array elements mutated as a [`BasicValueEnum`]. -pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: - ArrayLikeIndexer<'ctx, Index> -{ - /// # Safety - /// - /// This function should be called with a valid index. - unsafe fn set_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - value: BasicValueEnum<'ctx>, - ) { - let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, None) }; - ctx.builder.build_store(ptr, value).unwrap(); - } - - /// Sets the data at the `idx`-th index. - fn set( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - value: BasicValueEnum<'ctx>, - ) { - let ptr = self.ptr_offset(ctx, generator, idx, None); - ctx.builder.build_store(ptr, value).unwrap(); - } -} - -/// An array-like value that can have its array elements accessed as an arbitrary type `T`. -pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: - UntypedArrayLikeAccessor<'ctx, Index> -{ - /// Casts an element from [`BasicValueEnum`] into `T`. - fn downcast_to_type( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, - ) -> T; - - /// # Safety - /// - /// This function should be called with a valid index. - unsafe fn get_typed_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - name: Option<&str>, - ) -> T { - let value = unsafe { self.get_unchecked(ctx, generator, idx, name) }; - self.downcast_to_type(ctx, value) - } - - /// Returns the data at the `idx`-th index. - fn get_typed( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - name: Option<&str>, - ) -> T { - let value = self.get(ctx, generator, idx, name); - self.downcast_to_type(ctx, value) - } -} - -/// An array-like value that can have its array elements mutated as an arbitrary type `T`. -pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: - UntypedArrayLikeMutator<'ctx, Index> -{ - /// Casts an element from T into [`BasicValueEnum`]. - fn upcast_from_type( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - value: T, - ) -> BasicValueEnum<'ctx>; - - /// # Safety - /// - /// This function should be called with a valid index. - unsafe fn set_typed_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - value: T, - ) { - let value = self.upcast_from_type(ctx, value); - unsafe { self.set_unchecked(ctx, generator, idx, value) } - } - - /// Sets the data at the `idx`-th index. - fn set_typed( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - value: T, - ) { - let value = self.upcast_from_type(ctx, value); - self.set(ctx, generator, idx, value); - } -} - -/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`. -type ValueDowncastFn<'ctx, T> = - Box, BasicValueEnum<'ctx>) -> T>; -/// Type alias for a function that casts a `T` into a [`BasicValueEnum`]. -type ValueUpcastFn<'ctx, T> = Box, T) -> BasicValueEnum<'ctx>>; - -/// An adapter for constraining untyped array values as typed values. -pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>> { - adapted: Adapted, - downcast_fn: ValueDowncastFn<'ctx, T>, - upcast_fn: ValueUpcastFn<'ctx, T>, -} - -impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted> -where - Adapted: ArrayLikeValue<'ctx>, -{ - /// Creates a [`TypedArrayLikeAdapter`]. - /// - /// * `adapted` - The value to be adapted. - /// * `downcast_fn` - The function converting a [`BasicValueEnum`] into a `T`. - /// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`]. - pub fn from( - adapted: Adapted, - downcast_fn: ValueDowncastFn<'ctx, T>, - upcast_fn: ValueUpcastFn<'ctx, T>, - ) -> Self { - TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn } - } -} - -impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted> -where - Adapted: ArrayLikeValue<'ctx>, -{ - fn element_type( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> AnyTypeEnum<'ctx> { - self.adapted.element_type(ctx, generator) - } - - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> PointerValue<'ctx> { - self.adapted.base_ptr(ctx, generator) - } - - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> IntValue<'ctx> { - self.adapted.size(ctx, generator) - } -} - -impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> -where - Adapted: ArrayLikeIndexer<'ctx, Index>, -{ - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx> { - unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx> { - self.adapted.ptr_offset(ctx, generator, idx, name) - } -} - -impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> -where - Adapted: UntypedArrayLikeAccessor<'ctx, Index>, -{ -} -impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> -where - Adapted: UntypedArrayLikeMutator<'ctx, Index>, -{ -} - -impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> -where - Adapted: UntypedArrayLikeAccessor<'ctx, Index>, -{ - fn downcast_to_type( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, - ) -> T { - (self.downcast_fn)(ctx, value) - } -} - -impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index> - for TypedArrayLikeAdapter<'ctx, T, Adapted> -where - Adapted: UntypedArrayLikeMutator<'ctx, Index>, -{ - fn upcast_from_type( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - value: T, - ) -> BasicValueEnum<'ctx> { - (self.upcast_fn)(ctx, value) - } -} - -/// An LLVM value representing an array slice, consisting of a pointer to the data and the size of -/// the slice. -#[derive(Copy, Clone)] -pub struct ArraySliceValue<'ctx>(PointerValue<'ctx>, IntValue<'ctx>, Option<&'ctx str>); - -impl<'ctx> ArraySliceValue<'ctx> { - /// Creates an [`ArraySliceValue`] from a [`PointerValue`] and its size. - #[must_use] - pub fn from_ptr_val( - ptr: PointerValue<'ctx>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> Self { - ArraySliceValue(ptr, size, name) - } -} - -impl<'ctx> From> for PointerValue<'ctx> { - fn from(value: ArraySliceValue<'ctx>) -> Self { - value.0 - } -} - -impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> { - fn element_type( - &self, - _: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> AnyTypeEnum<'ctx> { - self.0.get_type().get_element_type() - } - - fn base_ptr( - &self, - _: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> PointerValue<'ctx> { - self.0 - } - - fn size( - &self, - _: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> IntValue<'ctx> { - self.1 - } -} - -impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); - - let size = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); - ctx.make_assert( - generator, - in_range, - "0:IndexError", - "list index out of range", - [None, None, None], - ctx.current_loc, - ); - - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } - } -} - -impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {} - -/// Proxy type for a `list` type in LLVM. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct ListType<'ctx> { - ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, -} - -impl<'ctx> ListType<'ctx> { - /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. - pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - let llvm_list_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { - return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")); - }; - if llvm_list_ty.count_fields() != 2 { - return Err(format!( - "Expected 2 fields in `list`, got {}", - llvm_list_ty.count_fields() - )); - } - - let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap(); - let Ok(_) = PointerType::try_from(list_size_ty) else { - return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")); - }; - - let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap(); - let Ok(list_data_ty) = IntType::try_from(list_data_ty) else { - return Err(format!("Expected int type for `list.1`, got {list_data_ty}")); - }; - if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected {}-bit int type for `list.1`, got {}-bit int", - llvm_usize.get_bit_width(), - list_data_ty.get_bit_width() - )); - } - - Ok(()) - } - - /// Creates an instance of [`ListType`]. - #[must_use] - pub fn new( - generator: &G, - ctx: &'ctx Context, - element_type: BasicTypeEnum<'ctx>, - ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_list = ctx - .struct_type( - &[element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], - false, - ) - .ptr_type(AddressSpace::default()); - - ListType::from_type(llvm_list, llvm_usize) - } - - /// Creates an [`ListType`] from a [`PointerType`]. - #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok()); - - ListType { ty: ptr_ty, llvm_usize } - } - - /// Returns the type of the `size` field of this `list` type. - #[must_use] - pub fn size_type(&self) -> IntType<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(1) - .map(BasicTypeEnum::into_int_type) - .unwrap() - } - - /// Returns the element type of this `list` type. - #[must_use] - pub fn element_type(&self) -> AnyTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(0) - .map(BasicTypeEnum::into_pointer_type) - .map(PointerType::get_element_type) - .unwrap() - } -} - -impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { - type Base = PointerType<'ctx>; - type Value = ListValue<'ctx>; - - fn new_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Value { - self.create_value( - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap(), - name, - ) - } - - fn new_array_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() - } - - fn create_value( - &self, - value: >::Base, - name: Option<&'ctx str>, - ) -> Self::Value { - debug_assert_eq!(value.get_type(), self.as_base_type()); - - ListValue { value, llvm_usize: self.llvm_usize, name } - } - - fn as_base_type(&self) -> Self::Base { - self.ty - } -} - -impl<'ctx> From> for PointerType<'ctx> { - fn from(value: ListType<'ctx>) -> Self { - value.as_base_type() - } -} - -/// Proxy type for accessing a `list` value in LLVM. -#[derive(Copy, Clone)] -pub struct ListValue<'ctx> { - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - name: Option<&'ctx str>, -} - -impl<'ctx> ListValue<'ctx> { - /// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an - /// instance. - pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - ListType::is_type(value.get_type(), llvm_usize) - } - - /// Creates an [`ListValue`] from a [`PointerValue`]. - #[must_use] - pub fn from_ptr_val( - ptr: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - name: Option<&'ctx str>, - ) -> Self { - debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); - - >::Type::from_type(ptr.get_type(), llvm_usize) - .create_value(ptr, name) - } - - /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` - /// on the field. - fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Returns the pointer to the field storing the size of this `list`. - fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap(); - } - - /// Convenience method for creating a new array storing data elements with the given element - /// type `elem_ty` and `size`. - /// - /// If `size` is [None], the size stored in the field of this instance is used instead. - pub fn create_data( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: BasicTypeEnum<'ctx>, - size: Option>, - ) { - let size = size.unwrap_or_else(|| self.load_size(ctx, None)); - - let data = ctx - .builder - .build_select( - ctx.builder - .build_int_compare(IntPredicate::NE, size, self.llvm_usize.const_zero(), "") - .unwrap(), - ctx.builder.build_array_alloca(elem_ty, size, "").unwrap(), - elem_ty.ptr_type(AddressSpace::default()).const_zero(), - "", - ) - .map(BasicValueEnum::into_pointer_value) - .unwrap(); - self.store_data(ctx, data); - } - - /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` - /// on the field. - #[must_use] - pub fn data(&self) -> ListDataProxy<'ctx, '_> { - ListDataProxy(self) - } - - /// Stores the `size` of this `list` into this instance. - pub fn store_size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - size: IntValue<'ctx>, - ) { - debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); - - let psize = self.ptr_to_size(ctx); - ctx.builder.build_store(psize, size).unwrap(); - } - - /// Returns the size of this `list` as a value. - pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { - let psize = self.ptr_to_size(ctx); - let var_name = name - .map(ToString::to_string) - .or_else(|| self.name.map(|v| format!("{v}.size"))) - .unwrap_or_default(); - - ctx.builder - .build_load(psize, var_name.as_str()) - .map(BasicValueEnum::into_int_value) - .unwrap() - } -} - -impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { - type Base = PointerValue<'ctx>; - type Type = ListType<'ctx>; - - fn get_type(&self) -> Self::Type { - ListType::from_type(self.as_base_value().get_type(), self.llvm_usize) - } - - fn as_base_value(&self) -> Self::Base { - self.value - } -} - -impl<'ctx> From> for PointerValue<'ctx> { - fn from(value: ListValue<'ctx>) -> Self { - value.as_base_value() - } -} - -/// Proxy type for accessing the `data` array of an `list` instance in LLVM. -#[derive(Copy, Clone)] -pub struct ListDataProxy<'ctx, 'a>(&'a ListValue<'ctx>); - -impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> { - fn element_type( - &self, - _: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> AnyTypeEnum<'ctx> { - self.0.value.get_type().get_element_type() - } - - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.pptr_to_data(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() - } - - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> IntValue<'ctx> { - self.0.load_size(ctx, None) - } -} - -impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); - - let size = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); - ctx.make_assert( - generator, - in_range, - "0:IndexError", - "list index out of range", - [None, None, None], - ctx.current_loc, - ); - - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } - } -} - -impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {} - -/// Proxy type for a `range` type in LLVM. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct RangeType<'ctx> { - ty: PointerType<'ctx>, -} - -impl<'ctx> RangeType<'ctx> { - /// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not. - pub fn is_type(llvm_ty: PointerType<'ctx>) -> Result<(), String> { - let llvm_range_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { - return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")); - }; - if llvm_range_ty.len() != 3 { - return Err(format!( - "Expected 3 elements for `range` type, got {}", - llvm_range_ty.len() - )); - } - - let llvm_range_elem_ty = llvm_range_ty.get_element_type(); - let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else { - return Err(format!( - "Expected int type for `range` element type, got {llvm_range_elem_ty}" - )); - }; - if llvm_range_elem_ty.get_bit_width() != 32 { - return Err(format!( - "Expected 32-bit int type for `range` element type, got {}", - llvm_range_elem_ty.get_bit_width() - )); - } - - Ok(()) - } - - /// Creates an instance of [`RangeType`]. - #[must_use] - pub fn new(ctx: &'ctx Context) -> Self { - let llvm_i32 = ctx.i32_type(); - let llvm_range = llvm_i32.array_type(3).ptr_type(AddressSpace::default()); - - RangeType::from_type(llvm_range) - } - - /// Creates an [`RangeType`] from a [`PointerType`]. - #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self { - debug_assert!(Self::is_type(ptr_ty).is_ok()); - - RangeType { ty: ptr_ty } - } - - /// Returns the type of all fields of this `range` type. - #[must_use] - pub fn value_type(&self) -> IntType<'ctx> { - self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() - } -} - -impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { - type Base = PointerType<'ctx>; - type Value = RangeValue<'ctx>; - - fn new_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Value { - self.create_value( - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap(), - name, - ) - } - - fn create_value( - &self, - value: >::Base, - name: Option<&'ctx str>, - ) -> Self::Value { - debug_assert_eq!(value.get_type(), self.as_base_type()); - - RangeValue { value, name } - } - - fn new_array_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() - } - - fn as_base_type(&self) -> Self::Base { - self.ty - } -} - -impl<'ctx> From> for PointerType<'ctx> { - fn from(value: RangeType<'ctx>) -> Self { - value.as_base_type() - } -} - -/// Proxy type for accessing a `range` value in LLVM. -#[derive(Copy, Clone)] -pub struct RangeValue<'ctx> { - value: PointerValue<'ctx>, - name: Option<&'ctx str>, -} - -impl<'ctx> RangeValue<'ctx> { - /// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance. - pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> { - RangeType::is_type(value.get_type()) - } - - /// Creates an [`RangeValue`] from a [`PointerValue`]. - #[must_use] - pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { - debug_assert!(Self::is_instance(ptr).is_ok()); - - >::Type::from_type(ptr.get_type()).create_value(ptr, name) - } - - fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], - var_name.as_str(), - ) - .unwrap() - } - } - - fn ptr_to_end(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], - var_name.as_str(), - ) - .unwrap() - } - } - - fn ptr_to_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the `start` value into this instance. - pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, start: IntValue<'ctx>) { - debug_assert_eq!(start.get_type().get_bit_width(), 32); - - let pstart = self.ptr_to_start(ctx); - ctx.builder.build_store(pstart, start).unwrap(); - } - - /// Returns the `start` value of this `range`. - pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { - let pstart = self.ptr_to_start(ctx); - let var_name = name - .map(ToString::to_string) - .or_else(|| self.name.map(|v| format!("{v}.start"))) - .unwrap_or_default(); - - ctx.builder - .build_load(pstart, var_name.as_str()) - .map(BasicValueEnum::into_int_value) - .unwrap() - } - - /// Stores the `end` value into this instance. - pub fn store_end(&self, ctx: &CodeGenContext<'ctx, '_>, end: IntValue<'ctx>) { - debug_assert_eq!(end.get_type().get_bit_width(), 32); - - let pend = self.ptr_to_end(ctx); - ctx.builder.build_store(pend, end).unwrap(); - } - - /// Returns the `end` value of this `range`. - pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { - let pend = self.ptr_to_end(ctx); - let var_name = name - .map(ToString::to_string) - .or_else(|| self.name.map(|v| format!("{v}.end"))) - .unwrap_or_default(); - - ctx.builder.build_load(pend, var_name.as_str()).map(BasicValueEnum::into_int_value).unwrap() - } - - /// Stores the `step` value into this instance. - pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, step: IntValue<'ctx>) { - debug_assert_eq!(step.get_type().get_bit_width(), 32); - - let pstep = self.ptr_to_step(ctx); - ctx.builder.build_store(pstep, step).unwrap(); - } - - /// Returns the `step` value of this `range`. - pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { - let pstep = self.ptr_to_step(ctx); - let var_name = name - .map(ToString::to_string) - .or_else(|| self.name.map(|v| format!("{v}.step"))) - .unwrap_or_default(); - - ctx.builder - .build_load(pstep, var_name.as_str()) - .map(BasicValueEnum::into_int_value) - .unwrap() - } -} - -impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { - type Base = PointerValue<'ctx>; - type Type = RangeType<'ctx>; - - fn get_type(&self) -> Self::Type { - RangeType::from_type(self.value.get_type()) - } - - fn as_base_value(&self) -> Self::Base { - self.value - } -} - -impl<'ctx> From> for PointerValue<'ctx> { - fn from(value: RangeValue<'ctx>) -> Self { - value.as_base_value() - } -} - -/// Proxy type for a `ndarray` type in LLVM. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct NDArrayType<'ctx> { - ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, -} - -impl<'ctx> NDArrayType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - let llvm_ndarray_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); - }; - if llvm_ndarray_ty.count_fields() != 3 { - return Err(format!( - "Expected 3 fields in `NDArray`, got {}", - llvm_ndarray_ty.count_fields() - )); - } - - let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); - let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { - return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")); - }; - if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected {}-bit int type for `ndarray.0`, got {}-bit int", - llvm_usize.get_bit_width(), - ndarray_ndims_ty.get_bit_width() - )); - } - - let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); - let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else { - return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}")); - }; - let ndarray_dims = ndarray_pdims.get_element_type(); - let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else { - return Err(format!( - "Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}" - )); - }; - if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int", - llvm_usize.get_bit_width(), - ndarray_dims.get_bit_width() - )); - } - - let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); - let Ok(_) = PointerType::try_from(ndarray_data_ty) else { - return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")); - }; - - Ok(()) - } - - /// Creates an instance of [`ListType`]. - #[must_use] - pub fn new( - generator: &G, - ctx: &'ctx Context, - dtype: BasicTypeEnum<'ctx>, - ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - - // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } - // - // * num_dims: Number of dimensions in the array - // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data - let llvm_ndarray = ctx - .struct_type( - &[ - llvm_usize.into(), - llvm_usize.ptr_type(AddressSpace::default()).into(), - dtype.ptr_type(AddressSpace::default()).into(), - ], - false, - ) - .ptr_type(AddressSpace::default()); - - NDArrayType::from_type(llvm_ndarray, llvm_usize) - } - - /// Creates an [`NDArrayType`] from a [`PointerType`]. - #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok()); - - NDArrayType { ty: ptr_ty, llvm_usize } - } - - /// Returns the type of the `size` field of this `ndarray` type. - #[must_use] - pub fn size_type(&self) -> IntType<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(0) - .map(BasicTypeEnum::into_int_type) - .unwrap() - } - - /// Returns the element type of this `ndarray` type. - #[must_use] - pub fn element_type(&self) -> AnyTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(2) - .map(BasicTypeEnum::into_pointer_type) - .map(PointerType::get_element_type) - .unwrap() - } -} - -impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { - type Base = PointerType<'ctx>; - type Value = NDArrayValue<'ctx>; - - fn new_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Value { - self.create_value( - generator - .gen_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - name, - ) - .unwrap(), - name, - ) - } - - fn new_array_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - size: IntValue<'ctx>, - name: Option<&'ctx str>, - ) -> ArraySliceValue<'ctx> { - generator - .gen_array_var_alloc( - ctx, - self.as_base_type().get_element_type().into_struct_type().into(), - size, - name, - ) - .unwrap() - } - - fn create_value( - &self, - value: >::Base, - name: Option<&'ctx str>, - ) -> Self::Value { - debug_assert_eq!(value.get_type(), self.as_base_type()); - - NDArrayValue { value, llvm_usize: self.llvm_usize, name } - } - - fn as_base_type(&self) -> Self::Base { - self.ty - } -} - -impl<'ctx> From> for PointerType<'ctx> { - fn from(value: NDArrayType<'ctx>) -> Self { - value.as_base_type() - } -} - -/// Proxy type for accessing an `NDArray` value in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayValue<'ctx> { - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - name: Option<&'ctx str>, -} - -impl<'ctx> NDArrayValue<'ctx> { - /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an - /// instance. - pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - NDArrayType::is_type(value.get_type(), llvm_usize) - } - - /// Creates an [`NDArrayValue`] from a [`PointerValue`]. - #[must_use] - pub fn from_ptr_val( - ptr: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - name: Option<&'ctx str>, - ) -> Self { - debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); - - >::Type::from_type(ptr.get_type(), llvm_usize) - .create_value(ptr, name) - } - - /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. - fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the number of dimensions `ndims` into this instance. - pub fn store_ndims( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ndims: IntValue<'ctx>, - ) { - debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); - - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_store(pndims, ndims).unwrap(); - } - - /// Returns the number of dimensions of this `NDArray` as a value. - pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() - } - - /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` - /// on the field. - fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the array of dimension sizes `dims` into this instance. - fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap(); - } - - /// Convenience method for creating a new array storing dimension sizes with the given `size`. - pub fn create_dim_sizes( - &self, - ctx: &CodeGenContext<'ctx, '_>, - llvm_usize: IntType<'ctx>, - size: IntValue<'ctx>, - ) { - self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); - } - - /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. - #[must_use] - pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> { - NDArrayDimsProxy(self) - } - - /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` - /// on the field. - pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); - } - - /// Convenience method for creating a new array storing data elements with the given element - /// type `elem_ty` and `size`. - pub fn create_data( - &self, - ctx: &CodeGenContext<'ctx, '_>, - elem_ty: BasicTypeEnum<'ctx>, - size: IntValue<'ctx>, - ) { - self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "").unwrap()); - } - - /// Returns a proxy object to the field storing the data of this `NDArray`. - #[must_use] - pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { - NDArrayDataProxy(self) - } -} - -impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { - type Base = PointerValue<'ctx>; - type Type = NDArrayType<'ctx>; - - fn get_type(&self) -> Self::Type { - NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize) - } - - fn as_base_value(&self) -> Self::Base { - self.value - } -} - -impl<'ctx> From> for PointerValue<'ctx> { - fn from(value: NDArrayValue<'ctx>) -> Self { - value.as_base_value() - } -} - -/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); - -impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { - fn element_type( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> AnyTypeEnum<'ctx> { - self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type() - } - - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() - } - - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> IntValue<'ctx> { - self.0.load_ndims(ctx) - } -} - -impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let size = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); - ctx.make_assert( - generator, - in_range, - "0:IndexError", - "index {0} is out of bounds for axis 0 with size {1}", - [Some(*idx), Some(self.0.load_ndims(ctx)), None], - ctx.current_loc, - ); - - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } - } -} - -impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} - -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { - fn downcast_to_type( - &self, - _: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, - ) -> IntValue<'ctx> { - value.into_int_value() - } -} - -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { - fn upcast_from_type( - &self, - _: &mut CodeGenContext<'ctx, '_>, - value: IntValue<'ctx>, - ) -> BasicValueEnum<'ctx> { - value.into() - } -} - -/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); - -impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { - fn element_type( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> AnyTypeEnum<'ctx> { - self.0.data().base_ptr(ctx, generator).get_type().get_element_type() - } - - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.ptr_to_data(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() - } - - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> IntValue<'ctx> { - call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None)) - } -} - -impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[*idx], - name.unwrap_or_default(), - ) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let data_sz = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap(); - ctx.make_assert( - generator, - in_range, - "0:IndexError", - "index {0} is out of bounds with size {1}", - [Some(*idx), Some(self.0.load_ndims(ctx)), None], - ctx.current_loc, - ); - - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } - } -} - -impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} - -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> - for NDArrayDataProxy<'ctx, '_> -{ - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - indices: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let indices_elem_ty = indices - .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None) - .get_type() - .get_element_type(); - let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { - panic!("Expected list[int32] but got {indices_elem_ty}") - }; - assert_eq!( - indices_elem_ty.get_bit_width(), - 32, - "Expected list[int32] but got list[int{}]", - indices_elem_ty.get_bit_width() - ); - - let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[index], - name.unwrap_or_default(), - ) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - indices: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let indices_size = indices.size(ctx, generator); - let nidx_leq_ndims = ctx - .builder - .build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "") - .unwrap(); - ctx.make_assert( - generator, - nidx_leq_ndims, - "0:IndexError", - "invalid index to scalar variable", - [None, None, None], - ctx.current_loc, - ); - - let indices_len = indices.size(ctx, generator); - let ndarray_len = self.0.load_ndims(ctx); - let len = call_int_umin(ctx, indices_len, ndarray_len, None); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let (dim_idx, dim_sz) = unsafe { - ( - indices.get_unchecked(ctx, generator, &i, None).into_int_value(), - self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None), - ) - }; - let dim_idx = ctx - .builder - .build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "") - .unwrap(); - - let dim_lt = - ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap(); - - ctx.make_assert( - generator, - dim_lt, - "0:IndexError", - "index {0} is out of bounds for axis 0 with size {1}", - [Some(dim_idx), Some(dim_sz), None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) } - } -} - -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index> - for NDArrayDataProxy<'ctx, '_> -{ -} -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index> - for NDArrayDataProxy<'ctx, '_> -{ -} diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index cae650b..7584826 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -19,10 +19,6 @@ use nac3parser::ast::{ }; use super::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, ProxyValue, - RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, - }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name, irrt::*, @@ -36,6 +32,11 @@ use super::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, }, + types::{ListType, ProxyType}, + values::{ + ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue, + TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + }, CodeGenContext, CodeGenTask, CodeGenerator, }; use crate::{ diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index 5441cbb..f277ec9 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -6,7 +6,7 @@ use inkwell::{ use nac3parser::ast::{Expr, Stmt, StrRef}; -use super::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext}; +use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext}; use crate::{ symbol_resolver::ValueEnum, toplevel::{DefinitionId, TopLevelDef}, diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 15af3cd..7e70a36 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -12,13 +12,13 @@ use itertools::Either; use nac3parser::ast::Expr; use super::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, - }, llvm_intrinsics, macros::codegen_unreachable, stmt::gen_for_callback_incrementing, + values::{ + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + }, CodeGenContext, CodeGenerator, }; use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type}; diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 7e99c58..b2bb5ad 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -36,12 +36,11 @@ use crate::{ typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, }, }; -use classes::{ListType, NDArrayType, ProxyType, RangeType}; use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; pub use generator::{CodeGenerator, DefaultCodeGenerator}; +use types::{ListType, NDArrayType, ProxyType, RangeType}; pub mod builtin_fns; -pub mod classes; pub mod concrete_type; pub mod expr; pub mod extern_fns; @@ -50,6 +49,8 @@ pub mod irrt; pub mod llvm_intrinsics; pub mod numpy; pub mod stmt; +pub mod types; +pub mod values; #[cfg(test)] mod test; diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index ffe0d83..538809c 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -7,11 +7,6 @@ use inkwell::{ use nac3parser::ast::{Operator, StrRef}; use super::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue, - ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, - TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, - }, expr::gen_binop_expr_with_values, irrt::{ calculate_len_for_slice_range, call_ndarray_calc_broadcast, @@ -20,6 +15,12 @@ use super::{ llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, + types::{ListType, NDArrayType, ProxyType}, + values::{ + ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, + TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + }, CodeGenContext, CodeGenerator, }; use crate::{ diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index cfc188c..d008f7b 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -12,11 +12,11 @@ use nac3parser::ast::{ }; use super::{ - classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, expr::{destructure_range, gen_binop_expr}, gen_in_range_check, irrt::{handle_slice_indices, list_slice_assignment}, macros::codegen_unreachable, + values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue}, CodeGenContext, CodeGenerator, }; use crate::{ diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 2bd02a7..5dd6070 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -16,8 +16,8 @@ use nac3parser::{ use parking_lot::RwLock; use super::{ - classes::{ListType, NDArrayType, ProxyType, RangeType}, concrete_type::ConcreteTypeStore, + types::{ListType, NDArrayType, ProxyType, RangeType}, CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry, }; diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs new file mode 100644 index 0000000..c68e4ab --- /dev/null +++ b/nac3core/src/codegen/types/list.rs @@ -0,0 +1,163 @@ +use inkwell::{ + context::Context, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + values::IntValue, + AddressSpace, +}; + +use super::ProxyType; +use crate::codegen::{ + values::{ArraySliceValue, ListValue, ProxyValue}, + CodeGenContext, CodeGenerator, +}; + +/// Proxy type for a `list` type in LLVM. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ListType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +impl<'ctx> ListType<'ctx> { + /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. + pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let llvm_list_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { + return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")); + }; + if llvm_list_ty.count_fields() != 2 { + return Err(format!( + "Expected 2 fields in `list`, got {}", + llvm_list_ty.count_fields() + )); + } + + let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap(); + let Ok(_) = PointerType::try_from(list_size_ty) else { + return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}")); + }; + + let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap(); + let Ok(list_data_ty) = IntType::try_from(list_data_ty) else { + return Err(format!("Expected int type for `list.1`, got {list_data_ty}")); + }; + if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() { + return Err(format!( + "Expected {}-bit int type for `list.1`, got {}-bit int", + llvm_usize.get_bit_width(), + list_data_ty.get_bit_width() + )); + } + + Ok(()) + } + + /// Creates an instance of [`ListType`]. + #[must_use] + pub fn new( + generator: &G, + ctx: &'ctx Context, + element_type: BasicTypeEnum<'ctx>, + ) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_list = ctx + .struct_type( + &[element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], + false, + ) + .ptr_type(AddressSpace::default()); + + ListType::from_type(llvm_list, llvm_usize) + } + + /// Creates an [`ListType`] from a [`PointerType`]. + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok()); + + ListType { ty: ptr_ty, llvm_usize } + } + + /// Returns the type of the `size` field of this `list` type. + #[must_use] + pub fn size_type(&self) -> IntType<'ctx> { + self.as_base_type() + .get_element_type() + .into_struct_type() + .get_field_type_at_index(1) + .map(BasicTypeEnum::into_int_type) + .unwrap() + } + + /// Returns the element type of this `list` type. + #[must_use] + pub fn element_type(&self) -> AnyTypeEnum<'ctx> { + self.as_base_type() + .get_element_type() + .into_struct_type() + .get_field_type_at_index(0) + .map(BasicTypeEnum::into_pointer_type) + .map(PointerType::get_element_type) + .unwrap() + } +} + +impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { + type Base = PointerType<'ctx>; + type Value = ListValue<'ctx>; + + fn new_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> Self::Value { + self.map_value( + generator + .gen_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + name, + ) + .unwrap(), + name, + ) + } + + fn new_array_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + size, + name, + ) + .unwrap() + } + + fn map_value( + &self, + value: >::Base, + name: Option<&'ctx str>, + ) -> Self::Value { + debug_assert_eq!(value.get_type(), self.as_base_type()); + + ListValue::from_ptr_val(value, self.llvm_usize, name) + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: ListType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs new file mode 100644 index 0000000..6032936 --- /dev/null +++ b/nac3core/src/codegen/types/mod.rs @@ -0,0 +1,50 @@ +use inkwell::{types::BasicType, values::IntValue}; + +use super::{ + values::{ArraySliceValue, ProxyValue}, + {CodeGenContext, CodeGenerator}, +}; +pub use list::*; +pub use ndarray::*; +pub use range::*; + +mod list; +mod ndarray; +mod range; + +/// A LLVM type that is used to represent a corresponding type in NAC3. +pub trait ProxyType<'ctx>: Into { + /// The LLVM type of which values of this type possess. This is usually a + /// [LLVM pointer type][PointerType] for any non-primitive types. + type Base: BasicType<'ctx>; + + /// The type of values represented by this type. + type Value: ProxyValue<'ctx>; + + /// Creates a new value of this type. + fn new_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> Self::Value; + + /// Creates a new array value of this type. + fn new_array_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx>; + + /// Converts an existing value into a [`ProxyValue`] of this type. + fn map_value( + &self, + value: >::Base, + name: Option<&'ctx str>, + ) -> Self::Value; + + /// Returns the [base type][Self::Base] of this proxy. + fn as_base_type(&self) -> Self::Base; +} diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs new file mode 100644 index 0000000..780a7a6 --- /dev/null +++ b/nac3core/src/codegen/types/ndarray.rs @@ -0,0 +1,191 @@ +use inkwell::{ + context::Context, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + values::IntValue, + AddressSpace, +}; + +use super::ProxyType; +use crate::codegen::{ + values::{ArraySliceValue, NDArrayValue, ProxyValue}, + {CodeGenContext, CodeGenerator}, +}; + +/// Proxy type for a `ndarray` type in LLVM. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct NDArrayType<'ctx> { + ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, +} + +impl<'ctx> NDArrayType<'ctx> { + /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. + pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { + let llvm_ndarray_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); + }; + if llvm_ndarray_ty.count_fields() != 3 { + return Err(format!( + "Expected 3 fields in `NDArray`, got {}", + llvm_ndarray_ty.count_fields() + )); + } + + let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); + let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { + return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")); + }; + if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { + return Err(format!( + "Expected {}-bit int type for `ndarray.0`, got {}-bit int", + llvm_usize.get_bit_width(), + ndarray_ndims_ty.get_bit_width() + )); + } + + let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); + let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else { + return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}")); + }; + let ndarray_dims = ndarray_pdims.get_element_type(); + let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else { + return Err(format!( + "Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}" + )); + }; + if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() { + return Err(format!( + "Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int", + llvm_usize.get_bit_width(), + ndarray_dims.get_bit_width() + )); + } + + let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); + let Ok(_) = PointerType::try_from(ndarray_data_ty) else { + return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")); + }; + + Ok(()) + } + + /// Creates an instance of [`ListType`]. + #[must_use] + pub fn new( + generator: &G, + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + ) -> Self { + let llvm_usize = generator.get_size_type(ctx); + + // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } + // + // * num_dims: Number of dimensions in the array + // * dims: Pointer to an array containing the size of each dimension + // * data: Pointer to an array containing the array data + let llvm_ndarray = ctx + .struct_type( + &[ + llvm_usize.into(), + llvm_usize.ptr_type(AddressSpace::default()).into(), + dtype.ptr_type(AddressSpace::default()).into(), + ], + false, + ) + .ptr_type(AddressSpace::default()); + + NDArrayType::from_type(llvm_ndarray, llvm_usize) + } + + /// Creates an [`NDArrayType`] from a [`PointerType`]. + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok()); + + NDArrayType { ty: ptr_ty, llvm_usize } + } + + /// Returns the type of the `size` field of this `ndarray` type. + #[must_use] + pub fn size_type(&self) -> IntType<'ctx> { + self.as_base_type() + .get_element_type() + .into_struct_type() + .get_field_type_at_index(0) + .map(BasicTypeEnum::into_int_type) + .unwrap() + } + + /// Returns the element type of this `ndarray` type. + #[must_use] + pub fn element_type(&self) -> AnyTypeEnum<'ctx> { + self.as_base_type() + .get_element_type() + .into_struct_type() + .get_field_type_at_index(2) + .map(BasicTypeEnum::into_pointer_type) + .map(PointerType::get_element_type) + .unwrap() + } +} + +impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { + type Base = PointerType<'ctx>; + type Value = NDArrayValue<'ctx>; + + fn new_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> Self::Value { + self.map_value( + generator + .gen_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + name, + ) + .unwrap(), + name, + ) + } + + fn new_array_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + size, + name, + ) + .unwrap() + } + + fn map_value( + &self, + value: >::Base, + name: Option<&'ctx str>, + ) -> Self::Value { + debug_assert_eq!(value.get_type(), self.as_base_type()); + + NDArrayValue::from_ptr_val(value, self.llvm_usize, name) + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: NDArrayType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs new file mode 100644 index 0000000..3b22916 --- /dev/null +++ b/nac3core/src/codegen/types/range.rs @@ -0,0 +1,132 @@ +use inkwell::{ + context::Context, + types::{AnyTypeEnum, IntType, PointerType}, + values::IntValue, + AddressSpace, +}; + +use super::ProxyType; +use crate::codegen::{ + values::{ArraySliceValue, ProxyValue, RangeValue}, + {CodeGenContext, CodeGenerator}, +}; + +/// Proxy type for a `range` type in LLVM. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct RangeType<'ctx> { + ty: PointerType<'ctx>, +} + +impl<'ctx> RangeType<'ctx> { + /// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not. + pub fn is_type(llvm_ty: PointerType<'ctx>) -> Result<(), String> { + let llvm_range_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { + return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")); + }; + if llvm_range_ty.len() != 3 { + return Err(format!( + "Expected 3 elements for `range` type, got {}", + llvm_range_ty.len() + )); + } + + let llvm_range_elem_ty = llvm_range_ty.get_element_type(); + let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else { + return Err(format!( + "Expected int type for `range` element type, got {llvm_range_elem_ty}" + )); + }; + if llvm_range_elem_ty.get_bit_width() != 32 { + return Err(format!( + "Expected 32-bit int type for `range` element type, got {}", + llvm_range_elem_ty.get_bit_width() + )); + } + + Ok(()) + } + + /// Creates an instance of [`RangeType`]. + #[must_use] + pub fn new(ctx: &'ctx Context) -> Self { + let llvm_i32 = ctx.i32_type(); + let llvm_range = llvm_i32.array_type(3).ptr_type(AddressSpace::default()); + + RangeType::from_type(llvm_range) + } + + /// Creates an [`RangeType`] from a [`PointerType`]. + #[must_use] + pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self { + debug_assert!(Self::is_type(ptr_ty).is_ok()); + + RangeType { ty: ptr_ty } + } + + /// Returns the type of all fields of this `range` type. + #[must_use] + pub fn value_type(&self) -> IntType<'ctx> { + self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type() + } +} + +impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { + type Base = PointerType<'ctx>; + type Value = RangeValue<'ctx>; + + fn new_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> Self::Value { + self.map_value( + generator + .gen_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + name, + ) + .unwrap(), + name, + ) + } + + fn new_array_value( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + size, + name, + ) + .unwrap() + } + + fn map_value( + &self, + value: >::Base, + name: Option<&'ctx str>, + ) -> Self::Value { + debug_assert_eq!(value.get_type(), self.as_base_type()); + + RangeValue::from_ptr_val(value, name) + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: RangeType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/values/array.rs b/nac3core/src/codegen/values/array.rs new file mode 100644 index 0000000..8d14fe8 --- /dev/null +++ b/nac3core/src/codegen/values/array.rs @@ -0,0 +1,426 @@ +use inkwell::{ + types::AnyTypeEnum, + values::{BasicValueEnum, IntValue, PointerValue}, + IntPredicate, +}; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of +/// elements. +pub trait ArrayLikeValue<'ctx> { + /// Returns the element type of this array-like value. + fn element_type( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> AnyTypeEnum<'ctx>; + + /// Returns the base pointer to the array. + fn base_ptr( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> PointerValue<'ctx>; + + /// Returns the size of this array-like value. + fn size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> IntValue<'ctx>; + + /// Returns a [`ArraySliceValue`] representing this value. + fn as_slice_value( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> ArraySliceValue<'ctx> { + ArraySliceValue::from_ptr_val( + self.base_ptr(ctx, generator), + self.size(ctx, generator), + None, + ) + } +} + +/// An array-like value that can be indexed by memory offset. +pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> { + /// # Safety + /// + /// This function should be called with a valid index. + unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + name: Option<&str>, + ) -> PointerValue<'ctx>; + + /// Returns the pointer to the data at the `idx`-th index. + fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + name: Option<&str>, + ) -> PointerValue<'ctx>; +} + +/// An array-like value that can have its array elements accessed as a [`BasicValueEnum`]. +pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: + ArrayLikeIndexer<'ctx, Index> +{ + /// # Safety + /// + /// This function should be called with a valid index. + unsafe fn get_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }; + ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap() + } + + /// Returns the data at the `idx`-th index. + fn get( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + name: Option<&str>, + ) -> BasicValueEnum<'ctx> { + let ptr = self.ptr_offset(ctx, generator, idx, name); + ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap() + } +} + +/// An array-like value that can have its array elements mutated as a [`BasicValueEnum`]. +pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: + ArrayLikeIndexer<'ctx, Index> +{ + /// # Safety + /// + /// This function should be called with a valid index. + unsafe fn set_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + value: BasicValueEnum<'ctx>, + ) { + let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, None) }; + ctx.builder.build_store(ptr, value).unwrap(); + } + + /// Sets the data at the `idx`-th index. + fn set( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + value: BasicValueEnum<'ctx>, + ) { + let ptr = self.ptr_offset(ctx, generator, idx, None); + ctx.builder.build_store(ptr, value).unwrap(); + } +} + +/// An array-like value that can have its array elements accessed as an arbitrary type `T`. +pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: + UntypedArrayLikeAccessor<'ctx, Index> +{ + /// Casts an element from [`BasicValueEnum`] into `T`. + fn downcast_to_type( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) -> T; + + /// # Safety + /// + /// This function should be called with a valid index. + unsafe fn get_typed_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + name: Option<&str>, + ) -> T { + let value = unsafe { self.get_unchecked(ctx, generator, idx, name) }; + self.downcast_to_type(ctx, value) + } + + /// Returns the data at the `idx`-th index. + fn get_typed( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + name: Option<&str>, + ) -> T { + let value = self.get(ctx, generator, idx, name); + self.downcast_to_type(ctx, value) + } +} + +/// An array-like value that can have its array elements mutated as an arbitrary type `T`. +pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: + UntypedArrayLikeMutator<'ctx, Index> +{ + /// Casts an element from T into [`BasicValueEnum`]. + fn upcast_from_type( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + value: T, + ) -> BasicValueEnum<'ctx>; + + /// # Safety + /// + /// This function should be called with a valid index. + unsafe fn set_typed_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + value: T, + ) { + let value = self.upcast_from_type(ctx, value); + unsafe { self.set_unchecked(ctx, generator, idx, value) } + } + + /// Sets the data at the `idx`-th index. + fn set_typed( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + value: T, + ) { + let value = self.upcast_from_type(ctx, value); + self.set(ctx, generator, idx, value); + } +} + +/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`. +type ValueDowncastFn<'ctx, T> = + Box, BasicValueEnum<'ctx>) -> T>; +/// Type alias for a function that casts a `T` into a [`BasicValueEnum`]. +type ValueUpcastFn<'ctx, T> = Box, T) -> BasicValueEnum<'ctx>>; + +/// An adapter for constraining untyped array values as typed values. +pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>> { + adapted: Adapted, + downcast_fn: ValueDowncastFn<'ctx, T>, + upcast_fn: ValueUpcastFn<'ctx, T>, +} + +impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: ArrayLikeValue<'ctx>, +{ + /// Creates a [`TypedArrayLikeAdapter`]. + /// + /// * `adapted` - The value to be adapted. + /// * `downcast_fn` - The function converting a [`BasicValueEnum`] into a `T`. + /// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`]. + pub fn from( + adapted: Adapted, + downcast_fn: ValueDowncastFn<'ctx, T>, + upcast_fn: ValueUpcastFn<'ctx, T>, + ) -> Self { + TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn } + } +} + +impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: ArrayLikeValue<'ctx>, +{ + fn element_type( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> AnyTypeEnum<'ctx> { + self.adapted.element_type(ctx, generator) + } + + fn base_ptr( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> PointerValue<'ctx> { + self.adapted.base_ptr(ctx, generator) + } + + fn size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> IntValue<'ctx> { + self.adapted.size(ctx, generator) + } +} + +impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: ArrayLikeIndexer<'ctx, Index>, +{ + unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + name: Option<&str>, + ) -> PointerValue<'ctx> { + unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) } + } + + fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &Index, + name: Option<&str>, + ) -> PointerValue<'ctx> { + self.adapted.ptr_offset(ctx, generator, idx, name) + } +} + +impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: UntypedArrayLikeAccessor<'ctx, Index>, +{ +} +impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: UntypedArrayLikeMutator<'ctx, Index>, +{ +} + +impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: UntypedArrayLikeAccessor<'ctx, Index>, +{ + fn downcast_to_type( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) -> T { + (self.downcast_fn)(ctx, value) + } +} + +impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index> + for TypedArrayLikeAdapter<'ctx, T, Adapted> +where + Adapted: UntypedArrayLikeMutator<'ctx, Index>, +{ + fn upcast_from_type( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + value: T, + ) -> BasicValueEnum<'ctx> { + (self.upcast_fn)(ctx, value) + } +} + +/// An LLVM value representing an array slice, consisting of a pointer to the data and the size of +/// the slice. +#[derive(Copy, Clone)] +pub struct ArraySliceValue<'ctx>(PointerValue<'ctx>, IntValue<'ctx>, Option<&'ctx str>); + +impl<'ctx> ArraySliceValue<'ctx> { + /// Creates an [`ArraySliceValue`] from a [`PointerValue`] and its size. + #[must_use] + pub fn from_ptr_val( + ptr: PointerValue<'ctx>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + ArraySliceValue(ptr, size, name) + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: ArraySliceValue<'ctx>) -> Self { + value.0 + } +} + +impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> { + fn element_type( + &self, + _: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> AnyTypeEnum<'ctx> { + self.0.get_type().get_element_type() + } + + fn base_ptr( + &self, + _: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> PointerValue<'ctx> { + self.0 + } + + fn size( + &self, + _: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> IntValue<'ctx> { + self.1 + } +} + +impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> { + unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) + .unwrap() + } + } + + fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + + let size = self.size(ctx, generator); + let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "list index out of range", + [None, None, None], + ctx.current_loc, + ); + + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } + } +} + +impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {} +impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {} diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs new file mode 100644 index 0000000..e81993d --- /dev/null +++ b/nac3core/src/codegen/values/list.rs @@ -0,0 +1,238 @@ +use inkwell::{ + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, + values::{BasicValueEnum, IntValue, PointerValue}, + AddressSpace, IntPredicate, +}; + +use super::{ + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, +}; +use crate::codegen::{ + types::ListType, + {CodeGenContext, CodeGenerator}, +}; + +/// Proxy type for accessing a `list` value in LLVM. +#[derive(Copy, Clone)] +pub struct ListValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> ListValue<'ctx> { + /// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an + /// instance. + pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { + ListType::is_type(value.get_type(), llvm_usize) + } + + /// Creates an [`ListValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_ptr_val( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); + + ListValue { value: ptr, llvm_usize, name } + } + + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` + /// on the field. + fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + var_name.as_str(), + ) + .unwrap() + } + } + + /// Returns the pointer to the field storing the size of this `list`. + fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + var_name.as_str(), + ) + .unwrap() + } + } + + /// Stores the array of data elements `data` into this instance. + fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap(); + } + + /// Convenience method for creating a new array storing data elements with the given element + /// type `elem_ty` and `size`. + /// + /// If `size` is [None], the size stored in the field of this instance is used instead. + pub fn create_data( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: BasicTypeEnum<'ctx>, + size: Option>, + ) { + let size = size.unwrap_or_else(|| self.load_size(ctx, None)); + + let data = ctx + .builder + .build_select( + ctx.builder + .build_int_compare(IntPredicate::NE, size, self.llvm_usize.const_zero(), "") + .unwrap(), + ctx.builder.build_array_alloca(elem_ty, size, "").unwrap(), + elem_ty.ptr_type(AddressSpace::default()).const_zero(), + "", + ) + .map(BasicValueEnum::into_pointer_value) + .unwrap(); + self.store_data(ctx, data); + } + + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` + /// on the field. + #[must_use] + pub fn data(&self) -> ListDataProxy<'ctx, '_> { + ListDataProxy(self) + } + + /// Stores the `size` of this `list` into this instance. + pub fn store_size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + size: IntValue<'ctx>, + ) { + debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx)); + + let psize = self.ptr_to_size(ctx); + ctx.builder.build_store(psize, size).unwrap(); + } + + /// Returns the size of this `list` as a value. + pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let psize = self.ptr_to_size(ctx); + let var_name = name + .map(ToString::to_string) + .or_else(|| self.name.map(|v| format!("{v}.size"))) + .unwrap_or_default(); + + ctx.builder + .build_load(psize, var_name.as_str()) + .map(BasicValueEnum::into_int_value) + .unwrap() + } +} + +impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> { + type Base = PointerValue<'ctx>; + type Type = ListType<'ctx>; + + fn get_type(&self) -> Self::Type { + ListType::from_type(self.as_base_value().get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: ListValue<'ctx>) -> Self { + value.as_base_value() + } +} + +/// Proxy type for accessing the `data` array of an `list` instance in LLVM. +#[derive(Copy, Clone)] +pub struct ListDataProxy<'ctx, 'a>(&'a ListValue<'ctx>); + +impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> { + fn element_type( + &self, + _: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> AnyTypeEnum<'ctx> { + self.0.value.get_type().get_element_type() + } + + fn base_ptr( + &self, + ctx: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> PointerValue<'ctx> { + let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); + + ctx.builder + .build_load(self.0.pptr_to_data(ctx), var_name.as_str()) + .map(BasicValueEnum::into_pointer_value) + .unwrap() + } + + fn size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> IntValue<'ctx> { + self.0.load_size(ctx, None) + } +} + +impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> { + unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) + .unwrap() + } + } + + fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx)); + + let size = self.size(ctx, generator); + let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "list index out of range", + [None, None, None], + ctx.current_loc, + ); + + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } + } +} + +impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {} +impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {} diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs new file mode 100644 index 0000000..db1c7bf --- /dev/null +++ b/nac3core/src/codegen/values/mod.rs @@ -0,0 +1,28 @@ +use inkwell::values::BasicValue; + +use super::types::ProxyType; +pub use array::*; +pub use list::*; +pub use ndarray::*; +pub use range::*; + +mod array; +mod list; +mod ndarray; +mod range; + +/// A LLVM type that is used to represent a non-primitive value in NAC3. +pub trait ProxyValue<'ctx>: Into { + /// The type of LLVM values represented by this instance. This is usually the + /// [LLVM pointer type][PointerValue]. + type Base: BasicValue<'ctx>; + + /// The type of this value. + type Type: ProxyType<'ctx>; + + /// Returns the [type][ProxyType] of this value. + fn get_type(&self) -> Self::Type; + + /// Returns the [base value][Self::Base] of this proxy. + fn as_base_value(&self) -> Self::Base; +} diff --git a/nac3core/src/codegen/values/ndarray.rs b/nac3core/src/codegen/values/ndarray.rs new file mode 100644 index 0000000..23e8836 --- /dev/null +++ b/nac3core/src/codegen/values/ndarray.rs @@ -0,0 +1,466 @@ +use inkwell::{ + types::{AnyTypeEnum, BasicTypeEnum, IntType}, + values::{BasicValueEnum, IntValue, PointerValue}, + IntPredicate, +}; + +use super::{ + ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, +}; +use crate::codegen::{ + irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, + llvm_intrinsics::call_int_umin, + stmt::gen_for_callback_incrementing, + types::NDArrayType, + CodeGenContext, CodeGenerator, +}; + +/// Proxy type for accessing an `NDArray` value in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayValue<'ctx> { + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an + /// instance. + pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { + NDArrayType::is_type(value.get_type(), llvm_usize) + } + + /// Creates an [`NDArrayValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_ptr_val( + ptr: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); + + NDArrayValue { value: ptr, llvm_usize, name } + } + + /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. + fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + var_name.as_str(), + ) + .unwrap() + } + } + + /// Stores the number of dimensions `ndims` into this instance. + pub fn store_ndims( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ndims: IntValue<'ctx>, + ) { + debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); + + let pndims = self.ptr_to_ndims(ctx); + ctx.builder.build_store(pndims, ndims).unwrap(); + } + + /// Returns the number of dimensions of this `NDArray` as a value. + pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let pndims = self.ptr_to_ndims(ctx); + ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() + } + + /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` + /// on the field. + fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + var_name.as_str(), + ) + .unwrap() + } + } + + /// Stores the array of dimension sizes `dims` into this instance. + fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { + ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap(); + } + + /// Convenience method for creating a new array storing dimension sizes with the given `size`. + pub fn create_dim_sizes( + &self, + ctx: &CodeGenContext<'ctx, '_>, + llvm_usize: IntType<'ctx>, + size: IntValue<'ctx>, + ) { + self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); + } + + /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. + #[must_use] + pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> { + NDArrayDimsProxy(self) + } + + /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` + /// on the field. + pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + var_name.as_str(), + ) + .unwrap() + } + } + + /// Stores the array of data elements `data` into this instance. + fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); + } + + /// Convenience method for creating a new array storing data elements with the given element + /// type `elem_ty` and `size`. + pub fn create_data( + &self, + ctx: &CodeGenContext<'ctx, '_>, + elem_ty: BasicTypeEnum<'ctx>, + size: IntValue<'ctx>, + ) { + self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "").unwrap()); + } + + /// Returns a proxy object to the field storing the data of this `NDArray`. + #[must_use] + pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { + NDArrayDataProxy(self) + } +} + +impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { + type Base = PointerValue<'ctx>; + type Type = NDArrayType<'ctx>; + + fn get_type(&self) -> Self::Type { + NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: NDArrayValue<'ctx>) -> Self { + value.as_base_value() + } +} + +/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); + +impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { + fn element_type( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> AnyTypeEnum<'ctx> { + self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type() + } + + fn base_ptr( + &self, + ctx: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> PointerValue<'ctx> { + let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); + + ctx.builder + .build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) + .map(BasicValueEnum::into_pointer_value) + .unwrap() + } + + fn size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> IntValue<'ctx> { + self.0.load_ndims(ctx) + } +} + +impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { + unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) + .unwrap() + } + } + + fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let size = self.size(ctx, generator); + let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "index {0} is out of bounds for axis 0 with size {1}", + [Some(*idx), Some(self.0.load_ndims(ctx)), None], + ctx.current_loc, + ); + + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } + } +} + +impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} +impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} + +impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { + fn downcast_to_type( + &self, + _: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) -> IntValue<'ctx> { + value.into_int_value() + } +} + +impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { + fn upcast_from_type( + &self, + _: &mut CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> BasicValueEnum<'ctx> { + value.into() + } +} + +/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. +#[derive(Copy, Clone)] +pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); + +impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { + fn element_type( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> AnyTypeEnum<'ctx> { + self.0.data().base_ptr(ctx, generator).get_type().get_element_type() + } + + fn base_ptr( + &self, + ctx: &CodeGenContext<'ctx, '_>, + _: &G, + ) -> PointerValue<'ctx> { + let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); + + ctx.builder + .build_load(self.0.ptr_to_data(ctx), var_name.as_str()) + .map(BasicValueEnum::into_pointer_value) + .unwrap() + } + + fn size( + &self, + ctx: &CodeGenContext<'ctx, '_>, + generator: &G, + ) -> IntValue<'ctx> { + call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None)) + } +} + +impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { + unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + unsafe { + ctx.builder + .build_in_bounds_gep( + self.base_ptr(ctx, generator), + &[*idx], + name.unwrap_or_default(), + ) + .unwrap() + } + } + + fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + idx: &IntValue<'ctx>, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let data_sz = self.size(ctx, generator); + let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap(); + ctx.make_assert( + generator, + in_range, + "0:IndexError", + "index {0} is out of bounds with size {1}", + [Some(*idx), Some(self.0.load_ndims(ctx)), None], + ctx.current_loc, + ); + + unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } + } +} + +impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} +impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} + +impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> + for NDArrayDataProxy<'ctx, '_> +{ + unsafe fn ptr_offset_unchecked( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + indices: &Index, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let indices_elem_ty = indices + .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None) + .get_type() + .get_element_type(); + let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { + panic!("Expected list[int32] but got {indices_elem_ty}") + }; + assert_eq!( + indices_elem_ty.get_bit_width(), + 32, + "Expected list[int32] but got list[int{}]", + indices_elem_ty.get_bit_width() + ); + + let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.base_ptr(ctx, generator), + &[index], + name.unwrap_or_default(), + ) + .unwrap() + } + } + + fn ptr_offset( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + indices: &Index, + name: Option<&str>, + ) -> PointerValue<'ctx> { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let indices_size = indices.size(ctx, generator); + let nidx_leq_ndims = ctx + .builder + .build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "") + .unwrap(); + ctx.make_assert( + generator, + nidx_leq_ndims, + "0:IndexError", + "invalid index to scalar variable", + [None, None, None], + ctx.current_loc, + ); + + let indices_len = indices.size(ctx, generator); + let ndarray_len = self.0.load_ndims(ctx); + let len = call_int_umin(ctx, indices_len, ndarray_len, None); + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (len, false), + |generator, ctx, _, i| { + let (dim_idx, dim_sz) = unsafe { + ( + indices.get_unchecked(ctx, generator, &i, None).into_int_value(), + self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None), + ) + }; + let dim_idx = ctx + .builder + .build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "") + .unwrap(); + + let dim_lt = + ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap(); + + ctx.make_assert( + generator, + dim_lt, + "0:IndexError", + "index {0} is out of bounds for axis 0 with size {1}", + [Some(dim_idx), Some(dim_sz), None], + ctx.current_loc, + ); + + Ok(()) + }, + llvm_usize.const_int(1, false), + ) + .unwrap(); + + unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) } + } +} + +impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index> + for NDArrayDataProxy<'ctx, '_> +{ +} +impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index> + for NDArrayDataProxy<'ctx, '_> +{ +} diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs new file mode 100644 index 0000000..40deb35 --- /dev/null +++ b/nac3core/src/codegen/values/range.rs @@ -0,0 +1,153 @@ +use inkwell::values::{BasicValueEnum, IntValue, PointerValue}; + +use super::ProxyValue; +use crate::codegen::{types::RangeType, CodeGenContext}; + +/// Proxy type for accessing a `range` value in LLVM. +#[derive(Copy, Clone)] +pub struct RangeValue<'ctx> { + value: PointerValue<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> RangeValue<'ctx> { + /// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance. + pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> { + RangeType::is_type(value.get_type()) + } + + /// Creates an [`RangeValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { + debug_assert!(Self::is_instance(ptr).is_ok()); + + RangeValue { value: ptr, name } + } + + fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(0, false)], + var_name.as_str(), + ) + .unwrap() + } + } + + fn ptr_to_end(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + var_name.as_str(), + ) + .unwrap() + } + } + + fn ptr_to_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let llvm_i32 = ctx.ctx.i32_type(); + let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default(); + + unsafe { + ctx.builder + .build_in_bounds_gep( + self.as_base_value(), + &[llvm_i32.const_zero(), llvm_i32.const_int(2, false)], + var_name.as_str(), + ) + .unwrap() + } + } + + /// Stores the `start` value into this instance. + pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, start: IntValue<'ctx>) { + debug_assert_eq!(start.get_type().get_bit_width(), 32); + + let pstart = self.ptr_to_start(ctx); + ctx.builder.build_store(pstart, start).unwrap(); + } + + /// Returns the `start` value of this `range`. + pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let pstart = self.ptr_to_start(ctx); + let var_name = name + .map(ToString::to_string) + .or_else(|| self.name.map(|v| format!("{v}.start"))) + .unwrap_or_default(); + + ctx.builder + .build_load(pstart, var_name.as_str()) + .map(BasicValueEnum::into_int_value) + .unwrap() + } + + /// Stores the `end` value into this instance. + pub fn store_end(&self, ctx: &CodeGenContext<'ctx, '_>, end: IntValue<'ctx>) { + debug_assert_eq!(end.get_type().get_bit_width(), 32); + + let pend = self.ptr_to_end(ctx); + ctx.builder.build_store(pend, end).unwrap(); + } + + /// Returns the `end` value of this `range`. + pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let pend = self.ptr_to_end(ctx); + let var_name = name + .map(ToString::to_string) + .or_else(|| self.name.map(|v| format!("{v}.end"))) + .unwrap_or_default(); + + ctx.builder.build_load(pend, var_name.as_str()).map(BasicValueEnum::into_int_value).unwrap() + } + + /// Stores the `step` value into this instance. + pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, step: IntValue<'ctx>) { + debug_assert_eq!(step.get_type().get_bit_width(), 32); + + let pstep = self.ptr_to_step(ctx); + ctx.builder.build_store(pstep, step).unwrap(); + } + + /// Returns the `step` value of this `range`. + pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> { + let pstep = self.ptr_to_step(ctx); + let var_name = name + .map(ToString::to_string) + .or_else(|| self.name.map(|v| format!("{v}.step"))) + .unwrap_or_default(); + + ctx.builder + .build_load(pstep, var_name.as_str()) + .map(BasicValueEnum::into_int_value) + .unwrap() + } +} + +impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> { + type Base = PointerValue<'ctx>; + type Type = RangeType<'ctx>; + + fn get_type(&self) -> Self::Type { + RangeType::from_type(self.value.get_type()) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: RangeValue<'ctx>) -> Self { + value.as_base_value() + } +} diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 4d88b6e..d4f9664 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -18,9 +18,9 @@ use super::{ use crate::{ codegen::{ builtin_fns, - classes::{ProxyValue, RangeValue}, numpy::*, stmt::exn_constructor, + values::{ProxyValue, RangeValue}, }, symbol_resolver::SymbolValue, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, -- 2.44.2 From a4f53b6e6bf910d695379a8cfd8a5f10d33f48f8 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 1 Nov 2024 15:17:00 +0800 Subject: [PATCH 06/16] [core] codegen: Refactor ProxyType and ProxyValue Accepts generator+context object for generic type checking. Also implements more default trait impl for easier delegation. --- nac3artiq/src/codegen.rs | 11 ++-- nac3core/src/codegen/builtin_fns.rs | 47 +++++++-------- nac3core/src/codegen/expr.rs | 32 +++++----- nac3core/src/codegen/numpy.rs | 81 +++++++++++++++----------- nac3core/src/codegen/stmt.rs | 7 ++- nac3core/src/codegen/test.rs | 6 +- nac3core/src/codegen/types/list.rs | 31 ++++++++-- nac3core/src/codegen/types/mod.rs | 17 +++++- nac3core/src/codegen/types/ndarray.rs | 29 ++++++++- nac3core/src/codegen/types/range.rs | 28 +++++++-- nac3core/src/codegen/values/list.rs | 11 ++-- nac3core/src/codegen/values/mod.rs | 23 +++++++- nac3core/src/codegen/values/ndarray.rs | 11 ++-- nac3core/src/codegen/values/range.rs | 8 +-- nac3core/src/toplevel/builtins.rs | 2 +- 15 files changed, 235 insertions(+), 109 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 85f9b1c..1fcfd4b 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -461,7 +461,8 @@ fn format_rpc_arg<'ctx>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); - let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); + let llvm_arg = + NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None); let llvm_usize_sizeof = ctx .builder @@ -1315,7 +1316,8 @@ fn polymorphic_print<'ctx>( fmt.push('['); flush(ctx, generator, &mut fmt, &mut args); - let val = ListValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None); + let val = + ListValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None); let len = val.load_size(ctx, None); let last = ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); @@ -1371,7 +1373,8 @@ fn polymorphic_print<'ctx>( fmt.push_str("array(["); flush(ctx, generator, &mut fmt, &mut args); - let val = NDArrayValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None); + let val = + NDArrayValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None); let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); let last = ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); @@ -1425,7 +1428,7 @@ fn polymorphic_print<'ctx>( fmt.push_str("range("); flush(ctx, generator, &mut fmt, &mut args); - let val = RangeValue::from_ptr_val(value.into_pointer_value(), None); + let val = RangeValue::from_pointer_value(value.into_pointer_value(), None); let (start, stop, step) = destructure_range(ctx, val); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index a6749d5..7765753 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -47,7 +47,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( let (arg_ty, arg) = n; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); + let arg = RangeValue::from_pointer_value(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { @@ -67,7 +67,8 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let llvm_usize = generator.get_size_type(ctx.ctx); - let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); + let arg = + NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None); let ndims = arg.dim_sizes().size(ctx, generator); ctx.make_assert( @@ -148,7 +149,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.int32, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), )?; @@ -210,7 +211,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.int64, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), )?; @@ -288,7 +289,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.uint32, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), )?; @@ -355,7 +356,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.uint64, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), )?; @@ -421,7 +422,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.float, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), )?; @@ -467,7 +468,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( ctx, ret_elem_ty, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -507,7 +508,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.float, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), )?; @@ -572,7 +573,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ctx, ctx.primitives.bool, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| { let elem = call_bool(generator, ctx, (elem_ty, val))?; @@ -626,7 +627,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( ctx, ret_elem_ty, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -676,7 +677,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( ctx, ret_elem_ty, None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_usize, None), |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -907,7 +908,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + let n = NDArrayValue::from_pointer_value(n, llvm_usize, None); let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx @@ -1120,7 +1121,7 @@ where ctx, ret_elem_ty, None, - NDArrayValue::from_ptr_val(x, llvm_usize, None), + NDArrayValue::from_pointer_value(x, llvm_usize, None), |generator, ctx, elem_val| { helper_call_numpy_unary_elementwise( generator, @@ -1959,7 +1960,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2001,7 +2002,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( unimplemented!("{FN_NAME} operates on float type NdArrays only"); }; - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2051,7 +2052,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2106,7 +2107,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2148,7 +2149,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2191,7 +2192,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2244,7 +2245,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); }; - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); // Changing second parameter to a `NDArray` for uniformity in function call let n2_array = numpy::create_ndarray_const_shape( generator, @@ -2339,7 +2340,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2382,7 +2383,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 7584826..93720f9 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1168,7 +1168,8 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { - let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); + let iter_val = + RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range")); let (start, stop, step) = destructure_range(ctx, iter_val); let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); // add 1 to the length as the value is rounded to zero @@ -1399,8 +1400,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty1); let sizeof_elem = llvm_elem_ty.size_of().unwrap(); - let lhs = ListValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None); - let rhs = ListValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); + let lhs = + ListValue::from_pointer_value(left_val.into_pointer_value(), llvm_usize, None); + let rhs = + ListValue::from_pointer_value(right_val.into_pointer_value(), llvm_usize, None); let size = ctx .builder @@ -1483,7 +1486,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( codegen_unreachable!(ctx) }; let list_val = - ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None); + ListValue::from_pointer_value(list_val.into_pointer_value(), llvm_usize, None); let int_val = ctx .builder .build_int_s_extend(int_val.into_int_value(), llvm_usize, "") @@ -1562,9 +1565,9 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); let left_val = - NDArrayValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None); + NDArrayValue::from_pointer_value(left_val.into_pointer_value(), llvm_usize, None); let right_val = - NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); + NDArrayValue::from_pointer_value(right_val.into_pointer_value(), llvm_usize, None); let res = if op.base == Operator::MatMult { // MatMult is the only binop which is not an elementwise op @@ -1613,7 +1616,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else { let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); - let ndarray_val = NDArrayValue::from_ptr_val( + let ndarray_val = NDArrayValue::from_pointer_value( if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), llvm_usize, None, @@ -1808,7 +1811,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( let llvm_usize = generator.get_size_type(ctx.ctx); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); + let val = NDArrayValue::from_pointer_value(val.into_pointer_value(), llvm_usize, None); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function @@ -1900,7 +1903,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); let left_val = - NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None); + NDArrayValue::from_pointer_value(lhs.into_pointer_value(), llvm_usize, None); let res = numpy::ndarray_elementwise_binop_impl( generator, ctx, @@ -2202,9 +2205,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( } let left_val = - ListValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None); + ListValue::from_pointer_value(lhs.into_pointer_value(), llvm_usize, None); let right_val = - ListValue::from_ptr_val(rhs.into_pointer_value(), llvm_usize, None); + ListValue::from_pointer_value(rhs.into_pointer_value(), llvm_usize, None); Ok(gen_if_else_expr_callback( generator, @@ -2768,7 +2771,8 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( // elements over let subscripted_ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None); + let ndarray = + NDArrayValue::from_pointer_value(subscripted_ndarray, llvm_usize, None); let num_dims = v.load_ndims(ctx); ndarray.store_ndims( @@ -3403,7 +3407,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { return Ok(None); }; - let v = ListValue::from_ptr_val(v, usize, Some("arr")); + let v = ListValue::from_pointer_value(v, usize, Some("arr")); let ty = ctx.get_llvm_type(generator, *ty); if let ExprKind::Slice { lower, upper, step } = &slice.node { let one = int32.const_int(1, false); @@ -3513,7 +3517,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { return Ok(None); }; - let v = NDArrayValue::from_ptr_val(v, usize, None); + let v = NDArrayValue::from_pointer_value(v, usize, None); return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); } diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 538809c..4589ba4 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -54,7 +54,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None)) + Ok(NDArrayValue::from_pointer_value(ndarray, llvm_usize, None)) } /// Creates an `NDArray` instance from a dynamic shape. @@ -314,11 +314,11 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( match shape { BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => + if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => { // 1. A list of ints; e.g., `np.empty([600, 800, 3])` - let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); + let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); create_ndarray_dyn_shape( generator, ctx, @@ -499,12 +499,14 @@ where // Assert that all ndarray operands are broadcastable to the target size if !lhs_scalar { - let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + let lhs_val = + NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None); ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); } if !rhs_scalar { - let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + let rhs_val = + NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None); ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); } @@ -512,7 +514,8 @@ where let lhs_elem = if lhs_scalar { lhs_val } else { - let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + let lhs = + NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None); let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } @@ -521,7 +524,8 @@ where let rhs_elem = if rhs_scalar { rhs_val } else { - let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + let rhs = + NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None); let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } @@ -647,11 +651,15 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( let ndims = llvm_usize.const_int(1, false); match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { + AnyTypeEnum::PointerType(ptr_ty) + if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => + { ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) } - AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { + AnyTypeEnum::PointerType(ptr_ty) + if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => + { todo!("Getting ndims for list[ndarray] not supported") } @@ -668,11 +676,13 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); match value { - BasicValueEnum::PointerValue(v) if NDArrayValue::is_instance(v, llvm_usize).is_ok() => { - NDArrayValue::from_ptr_val(v, llvm_usize, None).load_ndims(ctx) + BasicValueEnum::PointerValue(v) + if NDArrayValue::is_representable(v, llvm_usize).is_ok() => + { + NDArrayValue::from_pointer_value(v, llvm_usize, None).load_ndims(ctx) } - BasicValueEnum::PointerValue(v) if ListValue::is_instance(v, llvm_usize).is_ok() => { + BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { llvm_ndlist_get_ndims(generator, ctx, v.get_type()) } @@ -695,7 +705,9 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( let list_elem_ty = src_lst.get_type().element_type(); match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { + AnyTypeEnum::PointerType(ptr_ty) + if ListType::is_representable(ptr_ty, llvm_usize).is_ok() => + { // The stride of elements in this dimension, i.e. the number of elements between arr[i] // and arr[i + 1] in this dimension let stride = call_ndarray_calc_size( @@ -719,7 +731,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( let dst_ptr = unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; - let nested_lst_elem = ListValue::from_ptr_val( + let nested_lst_elem = ListValue::from_pointer_value( unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } .into_pointer_value(), llvm_usize, @@ -740,7 +752,9 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( )?; } - AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { + AnyTypeEnum::PointerType(ptr_ty) + if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() => + { todo!("Not implemented for list[ndarray]") } @@ -801,8 +815,8 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( let object = object.into_pointer_value(); // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims - if NDArrayValue::is_instance(object, llvm_usize).is_ok() { - let object = NDArrayValue::from_ptr_val(object, llvm_usize, None); + if NDArrayValue::is_representable(object, llvm_usize).is_ok() { + let object = NDArrayValue::from_pointer_value(object, llvm_usize, None); let ndarray = gen_if_else_expr_callback( generator, @@ -876,7 +890,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( |_, _| Ok(Some(object.as_base_value())), )?; - return Ok(NDArrayValue::from_ptr_val( + return Ok(NDArrayValue::from_pointer_value( ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), llvm_usize, None, @@ -884,8 +898,8 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( } // Remaining case: TList - assert!(ListValue::is_instance(object, llvm_usize).is_ok()); - let object = ListValue::from_ptr_val(object, llvm_usize, None); + assert!(ListValue::is_representable(object, llvm_usize).is_ok()); + let object = ListValue::from_pointer_value(object, llvm_usize, None); // The number of dimensions to prepend 1's to let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); @@ -965,7 +979,8 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( .map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap()) .map(BasicValueEnum::into_pointer_value) .unwrap(); - let this_dim = ListValue::from_ptr_val(this_dim, llvm_usize, None); + let this_dim = + ListValue::from_pointer_value(this_dim, llvm_usize, None); // TODO: Assert this_dim.sz != 0 @@ -991,7 +1006,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( }, )?; - let lst = ListValue::from_ptr_val( + let lst = ListValue::from_pointer_value( ctx.builder .build_load(lst, "") .map(BasicValueEnum::into_pointer_value) @@ -1388,9 +1403,9 @@ where let ndarray = res.unwrap_or_else(|| { if lhs_scalar && rhs_scalar { let lhs_val = - NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None); let rhs_val = - NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None); let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); @@ -1406,7 +1421,7 @@ where ) .unwrap() } else { - let ndarray = NDArrayValue::from_ptr_val( + let ndarray = NDArrayValue::from_pointer_value( if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), llvm_usize, None, @@ -1970,7 +1985,7 @@ pub fn gen_ndarray_copy<'ctx>( generator, context, this_elem_ty, - NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None), + NDArrayValue::from_pointer_value(this_arg.into_pointer_value(), llvm_usize, None), ) .map(NDArrayValue::into) } @@ -2002,7 +2017,7 @@ pub fn gen_ndarray_fill<'ctx>( ndarray_fill_flattened( generator, context, - NDArrayValue::from_ptr_val(this_arg, llvm_usize, None), + NDArrayValue::from_pointer_value(this_arg, llvm_usize, None), |generator, ctx, _| { let value = if value_arg.is_pointer_value() { let llvm_i1 = ctx.ctx.bool_type(); @@ -2043,7 +2058,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( if let BasicValueEnum::PointerValue(n1) = x1 { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); // Dimensions are reversed in the transposed array @@ -2162,7 +2177,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( if let BasicValueEnum::PointerValue(n1) = x1 { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; @@ -2172,11 +2187,11 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( let out = match shape { BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => + if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() => { // 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])` - let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); + let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None); // Check for -1 in dimensions gen_for_callback_incrementing( generator, @@ -2445,8 +2460,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n2 = NDArrayValue::from_pointer_value(n2, llvm_usize, None); let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index d008f7b..3595528 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -310,7 +310,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( .unwrap() .to_basic_value_enum(ctx, generator, target_ty)? .into_pointer_value(); - let target = ListValue::from_ptr_val(target, llvm_usize, None); + let target = ListValue::from_pointer_value(target, llvm_usize, None); if let ExprKind::Slice { .. } = &key.node { // Handle assigning to a slice @@ -331,7 +331,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let value = value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value(); - let value = ListValue::from_ptr_val(value, llvm_usize, None); + let value = ListValue::from_pointer_value(value, llvm_usize, None); let target_item_ty = ctx.get_llvm_type(generator, target_item_ty); let Some(src_ind) = handle_slice_indices( @@ -463,7 +463,8 @@ pub fn gen_for( TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => { - let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); + let iter_val = + RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range")); // Internal variable for loop; Cannot be assigned let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 5dd6070..a1c391a 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -452,7 +452,7 @@ fn test_classes_list_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into()); - assert!(ListType::is_type(llvm_list.as_base_type(), llvm_usize).is_ok()); + assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok()); } #[test] @@ -460,7 +460,7 @@ fn test_classes_range_type_new() { let ctx = inkwell::context::Context::create(); let llvm_range = RangeType::new(&ctx); - assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok()); + assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok()); } #[test] @@ -472,5 +472,5 @@ fn test_classes_ndarray_type_new() { let llvm_usize = generator.get_size_type(&ctx); let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into()); - assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); + assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index c68e4ab..4561b48 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -20,7 +20,10 @@ pub struct ListType<'ctx> { impl<'ctx> ListType<'ctx> { /// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not. - pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { + pub fn is_representable( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { let llvm_list_ty = llvm_ty.get_element_type(); let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else { return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}")); @@ -73,7 +76,7 @@ impl<'ctx> ListType<'ctx> { /// Creates an [`ListType`] from a [`PointerType`]. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); ListType { ty: ptr_ty, llvm_usize } } @@ -106,6 +109,26 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { type Base = PointerType<'ctx>; type Value = ListValue<'ctx>; + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn is_representable( + generator: &G, + ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + } + fn new_value( &self, generator: &mut G, @@ -146,9 +169,7 @@ impl<'ctx> ProxyType<'ctx> for ListType<'ctx> { value: >::Base, name: Option<&'ctx str>, ) -> Self::Value { - debug_assert_eq!(value.get_type(), self.as_base_type()); - - ListValue::from_ptr_val(value, self.llvm_usize, name) + Self::Value::from_pointer_value(value, self.llvm_usize, name) } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index 6032936..ab3d46b 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -1,4 +1,4 @@ -use inkwell::{types::BasicType, values::IntValue}; +use inkwell::{context::Context, types::BasicType, values::IntValue}; use super::{ values::{ArraySliceValue, ProxyValue}, @@ -19,7 +19,20 @@ pub trait ProxyType<'ctx>: Into { type Base: BasicType<'ctx>; /// The type of values represented by this type. - type Value: ProxyValue<'ctx>; + type Value: ProxyValue<'ctx, Type = Self>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String>; + + /// Checks whether `llvm_ty` can be represented by this [`ProxyType`]. + fn is_representable( + generator: &G, + ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String>; /// Creates a new value of this type. fn new_value( diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index 780a7a6..ca463b6 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -20,7 +20,10 @@ pub struct NDArrayType<'ctx> { impl<'ctx> NDArrayType<'ctx> { /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { + pub fn is_representable( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { let llvm_ndarray_ty = llvm_ty.get_element_type(); let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); @@ -101,7 +104,7 @@ impl<'ctx> NDArrayType<'ctx> { /// Creates an [`NDArrayType`] from a [`PointerType`]. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok()); + debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); NDArrayType { ty: ptr_ty, llvm_usize } } @@ -134,6 +137,26 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { type Base = PointerType<'ctx>; type Value = NDArrayValue<'ctx>; + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn is_representable( + generator: &G, + ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + } + fn new_value( &self, generator: &mut G, @@ -176,7 +199,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { ) -> Self::Value { debug_assert_eq!(value.get_type(), self.as_base_type()); - NDArrayValue::from_ptr_val(value, self.llvm_usize, name) + NDArrayValue::from_pointer_value(value, self.llvm_usize, name) } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index 3b22916..89a1b72 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -1,6 +1,6 @@ use inkwell::{ context::Context, - types::{AnyTypeEnum, IntType, PointerType}, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, values::IntValue, AddressSpace, }; @@ -19,7 +19,7 @@ pub struct RangeType<'ctx> { impl<'ctx> RangeType<'ctx> { /// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not. - pub fn is_type(llvm_ty: PointerType<'ctx>) -> Result<(), String> { + pub fn is_representable(llvm_ty: PointerType<'ctx>) -> Result<(), String> { let llvm_range_ty = llvm_ty.get_element_type(); let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else { return Err(format!("Expected array type for `range` type, got {llvm_range_ty}")); @@ -59,7 +59,7 @@ impl<'ctx> RangeType<'ctx> { /// Creates an [`RangeType`] from a [`PointerType`]. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self { - debug_assert!(Self::is_type(ptr_ty).is_ok()); + debug_assert!(Self::is_representable(ptr_ty).is_ok()); RangeType { ty: ptr_ty } } @@ -75,6 +75,26 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { type Base = PointerType<'ctx>; type Value = RangeValue<'ctx>; + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn is_representable( + _: &G, + _: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty) + } + fn new_value( &self, generator: &mut G, @@ -117,7 +137,7 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> { ) -> Self::Value { debug_assert_eq!(value.get_type(), self.as_base_type()); - RangeValue::from_ptr_val(value, name) + RangeValue::from_pointer_value(value, name) } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index e81993d..7b1975f 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -23,18 +23,21 @@ pub struct ListValue<'ctx> { impl<'ctx> ListValue<'ctx> { /// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an /// instance. - pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - ListType::is_type(value.get_type(), llvm_usize) + pub fn is_representable( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + ListType::is_representable(value.get_type(), llvm_usize) } /// Creates an [`ListValue`] from a [`PointerValue`]. #[must_use] - pub fn from_ptr_val( + pub fn from_pointer_value( ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); ListValue { value: ptr, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/mod.rs b/nac3core/src/codegen/values/mod.rs index db1c7bf..29f534e 100644 --- a/nac3core/src/codegen/values/mod.rs +++ b/nac3core/src/codegen/values/mod.rs @@ -1,6 +1,7 @@ -use inkwell::values::BasicValue; +use inkwell::{context::Context, values::BasicValue}; use super::types::ProxyType; +use crate::codegen::CodeGenerator; pub use array::*; pub use list::*; pub use ndarray::*; @@ -18,7 +19,25 @@ pub trait ProxyValue<'ctx>: Into { type Base: BasicValue<'ctx>; /// The type of this value. - type Type: ProxyType<'ctx>; + type Type: ProxyType<'ctx, Value = Self>; + + /// Checks whether `value` can be represented by this [`ProxyValue`]. + fn is_instance( + generator: &G, + ctx: &'ctx Context, + value: impl BasicValue<'ctx>, + ) -> Result<(), String> { + Self::Type::is_type(generator, ctx, value.as_basic_value_enum().get_type()) + } + + /// Checks whether `value` can be represented by this [`ProxyValue`]. + fn is_representable( + generator: &G, + ctx: &'ctx Context, + value: Self::Base, + ) -> Result<(), String> { + Self::is_instance(generator, ctx, value.as_basic_value_enum()) + } /// Returns the [type][ProxyType] of this value. fn get_type(&self) -> Self::Type; diff --git a/nac3core/src/codegen/values/ndarray.rs b/nac3core/src/codegen/values/ndarray.rs index 23e8836..908ad2f 100644 --- a/nac3core/src/codegen/values/ndarray.rs +++ b/nac3core/src/codegen/values/ndarray.rs @@ -27,18 +27,21 @@ pub struct NDArrayValue<'ctx> { impl<'ctx> NDArrayValue<'ctx> { /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an /// instance. - pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - NDArrayType::is_type(value.get_type(), llvm_usize) + pub fn is_representable( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + NDArrayType::is_representable(value.get_type(), llvm_usize) } /// Creates an [`NDArrayValue`] from a [`PointerValue`]. #[must_use] - pub fn from_ptr_val( + pub fn from_pointer_value( ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { - debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); + debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); NDArrayValue { value: ptr, llvm_usize, name } } diff --git a/nac3core/src/codegen/values/range.rs b/nac3core/src/codegen/values/range.rs index 40deb35..7e9976a 100644 --- a/nac3core/src/codegen/values/range.rs +++ b/nac3core/src/codegen/values/range.rs @@ -12,14 +12,14 @@ pub struct RangeValue<'ctx> { impl<'ctx> RangeValue<'ctx> { /// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance. - pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> { - RangeType::is_type(value.get_type()) + pub fn is_representable(value: PointerValue<'ctx>) -> Result<(), String> { + RangeType::is_representable(value.get_type()) } /// Creates an [`RangeValue`] from a [`PointerValue`]. #[must_use] - pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { - debug_assert!(Self::is_instance(ptr).is_ok()); + pub fn from_pointer_value(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self { + debug_assert!(Self::is_representable(ptr).is_ok()); RangeValue { value: ptr, name } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index d4f9664..e382222 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -710,7 +710,7 @@ impl<'a> BuiltinBuilder<'a> { let (zelf_ty, zelf) = obj.unwrap(); let zelf = zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value(); - let zelf = RangeValue::from_ptr_val(zelf, Some("range")); + let zelf = RangeValue::from_pointer_value(zelf, Some("range")); let mut start = None; let mut stop = None; -- 2.44.2 From d7633c42bc49876bc031c60a61097290c5c5593c Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Nov 2024 15:40:47 +0800 Subject: [PATCH 07/16] [core] codegen/types: Implement StructField{,s} Loosely based on FieldTraversal by lyken. --- nac3core/src/codegen/types/mod.rs | 1 + nac3core/src/codegen/types/structure.rs | 203 ++++++++++++++++++++++++ 2 files changed, 204 insertions(+) create mode 100644 nac3core/src/codegen/types/structure.rs diff --git a/nac3core/src/codegen/types/mod.rs b/nac3core/src/codegen/types/mod.rs index ab3d46b..fc64c9d 100644 --- a/nac3core/src/codegen/types/mod.rs +++ b/nac3core/src/codegen/types/mod.rs @@ -11,6 +11,7 @@ pub use range::*; mod list; mod ndarray; mod range; +pub mod structure; /// A LLVM type that is used to represent a corresponding type in NAC3. pub trait ProxyType<'ctx>: Into { diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs new file mode 100644 index 0000000..444fa2c --- /dev/null +++ b/nac3core/src/codegen/types/structure.rs @@ -0,0 +1,203 @@ +use std::marker::PhantomData; + +use inkwell::{ + context::AsContextRef, + types::{BasicTypeEnum, IntType}, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, +}; + +use crate::codegen::CodeGenContext; + +/// Trait indicating that the structure is a field-wise representation of an LLVM structure. +/// +/// # Usage +/// +/// For example, for a simple C-slice LLVM structure: +/// +/// ```ignore +/// struct CSliceFields<'ctx> { +/// ptr: StructField<'ctx, PointerValue<'ctx>>, +/// len: StructField<'ctx, IntValue<'ctx>> +/// } +/// ``` +pub trait StructFields<'ctx>: Eq + Copy { + /// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types. + fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self; + + /// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in + /// the type definition. + #[must_use] + fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>; + + /// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear + /// in the type definition. + #[must_use] + fn iter(&self) -> impl Iterator)> { + self.to_vec().into_iter() + } + + /// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in + /// the type definition. + #[must_use] + fn into_vec(self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> + where + Self: Sized, + { + self.to_vec() + } + + /// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear + /// in the type definition. + #[must_use] + fn into_iter(self) -> impl Iterator)> + where + Self: Sized, + { + self.into_vec().into_iter() + } +} + +/// A single field of an LLVM structure. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct StructField<'ctx, Value> +where + Value: BasicValue<'ctx> + TryFrom, Error = ()>, +{ + /// The index of this field within the structure. + index: u32, + + /// The name of this field. + name: &'static str, + + /// The type of this field. + ty: BasicTypeEnum<'ctx>, + + /// Instance of [`PhantomData`] containing [`Value`], used to implement automatic downcasts. + _value_ty: PhantomData, +} + +impl<'ctx, Value> StructField<'ctx, Value> +where + Value: BasicValue<'ctx> + TryFrom, Error = ()>, +{ + /// Creates an instance of [`StructField`]. + /// + /// * `idx_counter` - The instance of [`FieldIndexCounter`] used to track the current field + /// index. + /// * `name` - Name of the field. + /// * `ty` - The type of this field. + pub fn create( + idx_counter: &mut FieldIndexCounter, + name: &'static str, + ty: impl Into>, + ) -> Self { + StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData } + } + + /// Creates an instance of [`StructField`] with a given index. + /// + /// * `index` - The index of this field within its enclosing structure. + /// * `name` - Name of the field. + /// * `ty` - The type of this field. + pub fn create_at(index: u32, name: &'static str, ty: impl Into>) -> Self { + StructField { index, name, ty: ty.into(), _value_ty: PhantomData } + } + + /// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32 + /// {idx...}, i32 {self.index}`. + pub fn ptr_by_array_gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pobj: PointerValue<'ctx>, + idx: &[IntValue<'ctx>], + ) -> PointerValue<'ctx> { + unsafe { + ctx.builder.build_in_bounds_gep( + pobj, + &[idx, &[ctx.ctx.i32_type().const_int(u64::from(self.index), false)]].concat(), + "", + ) + } + .unwrap() + } + + /// Creates a pointer to this field in an arbitrary structure by performing the equivalent of + /// `getelementptr i32 0, i32 {self.index}`. + pub fn ptr_by_gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pobj: PointerValue<'ctx>, + obj_name: Option<&'ctx str>, + ) -> PointerValue<'ctx> { + ctx.builder + .build_struct_gep( + pobj, + self.index, + &obj_name.map(|name| format!("{name}.{}.addr", self.name)).unwrap_or_default(), + ) + .unwrap() + } + + /// Gets the value of this field for a given `obj`. + #[must_use] + pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value { + obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap() + } + + /// Sets the value of this field for a given `obj`. + pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) { + obj.set_field_at_index(self.index, value); + } + + /// Gets the value of this field for a pointer-to-structure. + pub fn get( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pobj: PointerValue<'ctx>, + obj_name: Option<&'ctx str>, + ) -> Value { + ctx.builder + .build_load( + self.ptr_by_gep(ctx, pobj, obj_name), + &obj_name.map(|name| format!("{name}.{}", self.name)).unwrap_or_default(), + ) + .map_err(|_| ()) + .and_then(|value| Value::try_from(value)) + .unwrap() + } + + /// Sets the value of this field for a pointer-to-structure. + pub fn set( + &self, + ctx: &CodeGenContext<'ctx, '_>, + pobj: PointerValue<'ctx>, + value: Value, + obj_name: Option<&'ctx str>, + ) { + ctx.builder.build_store(self.ptr_by_gep(ctx, pobj, obj_name), value).unwrap(); + } +} + +impl<'ctx, Value> From> for (&'static str, BasicTypeEnum<'ctx>) +where + Value: BasicValue<'ctx> + TryFrom, Error = ()>, +{ + fn from(value: StructField<'ctx, Value>) -> Self { + (value.name, value.ty) + } +} + +/// A counter that tracks the next index of a field using a monotonically increasing counter. +#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)] +pub struct FieldIndexCounter(u32); + +impl FieldIndexCounter { + /// Increments the number stored by this counter, returning the previous value. + /// + /// Functionally equivalent to `i++` in C-based languages. + pub fn increment(&mut self) -> u32 { + let v = self.0; + self.0 += 1; + v + } +} -- 2.44.2 From 88e57f71206723424941ae7d2a22f767c5cd36d3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 19 Nov 2024 13:14:51 +0800 Subject: [PATCH 08/16] [core_derive] Initial implementation --- Cargo.lock | 132 ++++++++ Cargo.toml | 1 + nac3core/Cargo.toml | 3 + nac3core/nac3core_derive/Cargo.toml | 21 ++ nac3core/nac3core_derive/src/lib.rs | 320 ++++++++++++++++++ .../tests/structfields_empty.rs | 9 + .../tests/structfields_ndarray.rs | 20 ++ .../tests/structfields_slice.rs | 18 + .../tests/structfields_slice_context.rs | 18 + .../tests/structfields_slice_ctx.rs | 18 + .../tests/structfields_slice_sizet.rs | 18 + .../tests/structfields_test.rs | 10 + nac3core/src/lib.rs | 2 + 13 files changed, 590 insertions(+) create mode 100644 nac3core/nac3core_derive/Cargo.toml create mode 100644 nac3core/nac3core_derive/src/lib.rs create mode 100644 nac3core/nac3core_derive/tests/structfields_empty.rs create mode 100644 nac3core/nac3core_derive/tests/structfields_ndarray.rs create mode 100644 nac3core/nac3core_derive/tests/structfields_slice.rs create mode 100644 nac3core/nac3core_derive/tests/structfields_slice_context.rs create mode 100644 nac3core/nac3core_derive/tests/structfields_slice_ctx.rs create mode 100644 nac3core/nac3core_derive/tests/structfields_slice_sizet.rs create mode 100644 nac3core/nac3core_derive/tests/structfields_test.rs diff --git a/Cargo.lock b/Cargo.lock index fa19e0d..bf68184 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -282,6 +282,12 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "dissimilar" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59f8e79d1fbf76bdfbde321e902714bf6c49df88a7dda6fc682fc2979226962d" + [[package]] name = "either" version = "1.13.0" @@ -370,6 +376,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "hashbrown" version = "0.12.3" @@ -648,6 +660,7 @@ dependencies = [ "inkwell", "insta", "itertools", + "nac3core_derive", "nac3parser", "parking_lot", "rayon", @@ -657,6 +670,18 @@ dependencies = [ "test-case", ] +[[package]] +name = "nac3core_derive" +version = "0.1.0" +dependencies = [ + "nac3core", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.87", + "trybuild", +] + [[package]] name = "nac3ld" version = "0.1.0" @@ -822,6 +847,30 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.89" @@ -1076,6 +1125,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +dependencies = [ + "serde", +] + [[package]] name = "serde_yaml" version = "0.8.26" @@ -1199,6 +1257,12 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "target-triple" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078" + [[package]] name = "tempfile" version = "3.14.0" @@ -1222,6 +1286,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "test-case" version = "1.2.3" @@ -1255,6 +1328,56 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "toml" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +dependencies = [ + "indexmap 2.6.0", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] +name = "trybuild" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dcd332a5496c026f1e14b7f3d2b7bd98e509660c04239c58b0ba38a12daded4" +dependencies = [ + "dissimilar", + "glob", + "serde", + "serde_derive", + "serde_json", + "target-triple", + "termcolor", + "toml", +] + [[package]] name = "typenum" version = "1.17.0" @@ -1478,6 +1601,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +dependencies = [ + "memchr", +] + [[package]] name = "yaml-rust" version = "0.4.5" diff --git a/Cargo.toml b/Cargo.toml index 765ab39..7a28185 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "nac3ast", "nac3parser", "nac3core", + "nac3core/nac3core_derive", "nac3standalone", "nac3artiq", "runkernel", diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index 5c0c36a..6521a33 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -5,6 +5,8 @@ authors = ["M-Labs"] edition = "2021" [features] +default = ["derive"] +derive = ["dep:nac3core_derive"] no-escape-analysis = [] [dependencies] @@ -13,6 +15,7 @@ crossbeam = "0.8" indexmap = "2.6" parking_lot = "0.12" rayon = "1.10" +nac3core_derive = { path = "nac3core_derive", optional = true } nac3parser = { path = "../nac3parser" } strum = "0.26" strum_macros = "0.26" diff --git a/nac3core/nac3core_derive/Cargo.toml b/nac3core/nac3core_derive/Cargo.toml new file mode 100644 index 0000000..adf4ad8 --- /dev/null +++ b/nac3core/nac3core_derive/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "nac3core_derive" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[[test]] +name = "structfields_tests" +path = "tests/structfields_test.rs" + +[dev-dependencies] +nac3core = { path = ".." } +trybuild = { version = "1.0", features = ["diff"] } + +[dependencies] +proc-macro2 = "1.0" +proc-macro-error = "1.0" +syn = "2.0" +quote = "1.0" diff --git a/nac3core/nac3core_derive/src/lib.rs b/nac3core/nac3core_derive/src/lib.rs new file mode 100644 index 0000000..44d6aeb --- /dev/null +++ b/nac3core/nac3core_derive/src/lib.rs @@ -0,0 +1,320 @@ +use proc_macro::TokenStream; +use proc_macro_error::{abort, proc_macro_error}; +use quote::quote; +use syn::{ + parse_macro_input, spanned::Spanned, Data, DataStruct, Expr, ExprField, ExprMethodCall, + ExprPath, GenericArgument, Ident, LitStr, Path, PathArguments, Type, TypePath, +}; + +/// Extracts all generic arguments of a [`Type`] into a [`Vec`]. +/// +/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with +/// `expected_ty_name`, otherwise returns [`None`]. +fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option> { + let Type::Path(TypePath { qself: None, path, .. }) = ty else { + return None; + }; + + let segments = &path.segments; + if segments.len() != 1 { + return None; + }; + + let segment = segments.iter().next().unwrap(); + if segment.ident != expected_ty_name { + return None; + } + + let PathArguments::AngleBracketed(path_args) = &segment.arguments else { + return Some(Vec::new()); + }; + let args = &path_args.args; + + Some(args.iter().cloned().collect::>()) +} + +/// Maps a `path` matching one of the `target_idents` into the `replacement` [`Ident`]. +fn map_path_to_ident(path: &Path, target_idents: &[&str], replacement: &str) -> Option { + path.require_ident() + .ok() + .filter(|ident| target_idents.iter().any(|target| ident == target)) + .map(|ident| Ident::new(replacement, ident.span())) +} + +/// Extracts the left-hand side of a dot-expression. +fn extract_dot_operand(expr: &Expr) -> Option<&Expr> { + match expr { + Expr::MethodCall(ExprMethodCall { receiver: operand, .. }) + | Expr::Field(ExprField { base: operand, .. }) => Some(operand), + _ => None, + } +} + +/// Replaces the top-level receiver of a dot-expression with an [`Ident`], returning `Some(&mut expr)` if the +/// replacement is performed. +/// +/// The top-level receiver is the left-most receiver expression, e.g. the top-level receiver of `a.b.c.foo()` is `a`. +fn replace_top_level_receiver(expr: &mut Expr, ident: Ident) -> Option<&mut Expr> { + if let Expr::MethodCall(ExprMethodCall { receiver: operand, .. }) + | Expr::Field(ExprField { base: operand, .. }) = expr + { + return if extract_dot_operand(operand).is_some() { + if replace_top_level_receiver(operand, ident).is_some() { + Some(expr) + } else { + None + } + } else { + *operand = Box::new(Expr::Path(ExprPath { + attrs: Vec::default(), + qself: None, + path: ident.into(), + })); + + Some(expr) + }; + } + + None +} + +/// Iterates all operands to the left-hand side of the `.` of an [expression][`Expr`], i.e. the container operand of all +/// [`Expr::Field`] and the receiver operand of all [`Expr::MethodCall`]. +/// +/// The iterator will return the operand expressions in reverse order of appearance. For example, `a.b.c.func()` will +/// return `vec![c, b, a]`. +fn iter_dot_operands(expr: &Expr) -> impl Iterator { + let mut o = extract_dot_operand(expr); + + std::iter::from_fn(move || { + let this = o; + o = o.as_ref().and_then(|o| extract_dot_operand(o)); + + this + }) +} + +/// Normalizes a value expression for use when creating an instance of this structure, returning a +/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression. +fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream { + match &expr { + Expr::Path(ExprPath { qself: None, path, .. }) => { + if let Some(ident) = map_path_to_ident(path, &["usize", "size_t"], "llvm_usize") { + quote! { #ident } + } else { + abort!( + path, + format!( + "Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}", + quote!(#expr).to_string(), + ) + ) + } + } + + Expr::Call(_) => { + quote! { ctx.#expr } + } + + Expr::MethodCall(_) => { + let base_receiver = iter_dot_operands(expr).last(); + + match base_receiver { + // `usize.{...}`, `size_t.{...}` -> Rewrite the identifiers to `llvm_usize` + Some(Expr::Path(ExprPath { qself: None, path, .. })) + if map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").is_some() => + { + let ident = + map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").unwrap(); + + let mut expr = expr.clone(); + let expr = replace_top_level_receiver(&mut expr, ident).unwrap(); + + quote!(#expr) + } + + // `ctx.{...}`, `context.{...}` -> Rewrite the identifiers to `ctx` + Some(Expr::Path(ExprPath { qself: None, path, .. })) + if map_path_to_ident(path, &["ctx", "context"], "ctx").is_some() => + { + let ident = map_path_to_ident(path, &["ctx", "context"], "ctx").unwrap(); + + let mut expr = expr.clone(); + let expr = replace_top_level_receiver(&mut expr, ident).unwrap(); + + quote!(#expr) + } + + // No reserved identifier prefix -> Prepend `ctx.` to the entire expression + _ => quote! { ctx.#expr }, + } + } + + _ => { + abort!( + expr, + format!( + "Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}", + quote!(#expr).to_string(), + ) + ) + } + } +} + +/// Derives an implementation of `codegen::types::structure::StructFields`. +/// +/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by +/// `impl StructFields` is automatically generated by this implementation, including the field index as required by +/// `StructField::new` and the fields as returned by `StructFields::to_vec`. +/// +/// # Prerequisites +/// +/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by +/// `StructFields`. +/// +/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields +/// with either `StructField` or [`PhantomData`] types. +/// +/// # Attributes for [`StructFields`] +/// +/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type` +/// accepts one of the following: +/// +/// - An expression returning an instance of `inkwell::types::BasicType` (with or without the receiver `ctx`/`context`). +/// For example, `context.i8_type()`, `ctx.i8_type()`, and `i8_type()` all refer to `i8`. +/// - The reserved identifiers `usize` and `size_t` referring to an `inkwell::types::IntType` of the platform-dependent +/// integer size. `usize` and `size_t` can also be used as the receiver to other method calls, e.g. +/// `usize.array_type(3)`. +/// +/// # Example +/// +/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`. +/// +/// ``` +/// use nac3core::{ +/// codegen::types::structure::StructField, +/// inkwell::{ +/// values::{IntValue, PointerValue}, +/// AddressSpace, +/// }, +/// }; +/// use nac3core_derive::StructFields; +/// +/// // All classes that implement StructFields must also implement Eq and Copy +/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)] +/// pub struct SliceValue<'ctx> { +/// // Declares ptr have a value type of i8* +/// // +/// // Can also be written as `ctx.i8_type().ptr_type(...)` or `context.i8_type().ptr_type(...)` +/// #[value_type(i8_type().ptr_type(AddressSpace::default()))] +/// ptr: StructField<'ctx, PointerValue<'ctx>>, +/// +/// // Declares len have a value type of usize, depending on the target compilation platform +/// #[value_type(usize)] +/// len: StructField<'ctx, IntValue<'ctx>>, +/// } +/// ``` +#[proc_macro_derive(StructFields, attributes(value_type))] +#[proc_macro_error] +pub fn derive(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as syn::DeriveInput); + let ident = &input.ident; + + let Data::Struct(DataStruct { fields, .. }) = &input.data else { + abort!(input, "Only structs with named fields are supported"); + }; + if let Err(err_span) = + fields + .iter() + .try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) }) + { + abort!(err_span, "Only structs with named fields are supported"); + }; + + // Check if struct<'ctx> + if input.generics.params.len() != 1 { + abort!(input.generics, "Expected exactly 1 generic parameter") + } + + let phantom_info = fields + .iter() + .filter(|field| extract_generic_args("PhantomData", &field.ty).is_some()) + .map(|field| field.ident.as_ref().unwrap()) + .cloned() + .collect::>(); + + let field_info = fields + .iter() + .filter(|field| extract_generic_args("PhantomData", &field.ty).is_none()) + .map(|field| { + let ident = field.ident.as_ref().unwrap(); + let ty = &field.ty; + + let Some(_) = extract_generic_args("StructField", ty) else { + abort!(field, "Only StructField and PhantomData are allowed") + }; + + let attrs = &field.attrs; + let Some(value_type_attr) = + attrs.iter().find(|attr| attr.path().is_ident("value_type")) + else { + abort!(field, "Expected #[value_type(...)] attribute for field"); + }; + + let Ok(value_type_expr) = value_type_attr.parse_args::() else { + abort!(value_type_attr, "Expected expression in #[value_type(...)]"); + }; + + let value_expr_toks = normalize_value_expr(&value_type_expr); + + (ident.clone(), value_expr_toks) + }) + .collect::>(); + + // `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new` + let phantoms_create = phantom_info + .iter() + .map(|id| quote! { #id: ::std::marker::PhantomData }) + .collect::>(); + let fields_create = field_info + .iter() + .map(|(id, ty)| { + let id_lit = LitStr::new(&id.to_string(), id.span()); + quote! { + #id: ::nac3core::codegen::types::structure::StructField::create( + &mut counter, + #id_lit, + #ty, + ) + } + }) + .collect::>(); + + // `.into()` impl of `StructField` for `StructFields::to_vec` + let fields_into = + field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::>(); + + let impl_block = quote! { + impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> { + fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self { + let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) }; + + let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default(); + + #ident { + #(#fields_create),* + #(#phantoms_create),* + } + } + + fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> { + vec![ + #(#fields_into),* + ] + } + } + }; + + impl_block.into() +} diff --git a/nac3core/nac3core_derive/tests/structfields_empty.rs b/nac3core/nac3core_derive/tests/structfields_empty.rs new file mode 100644 index 0000000..0a3b19b --- /dev/null +++ b/nac3core/nac3core_derive/tests/structfields_empty.rs @@ -0,0 +1,9 @@ +use nac3core_derive::StructFields; +use std::marker::PhantomData; + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct EmptyValue<'ctx> { + _phantom: PhantomData<&'ctx ()>, +} + +fn main() {} diff --git a/nac3core/nac3core_derive/tests/structfields_ndarray.rs b/nac3core/nac3core_derive/tests/structfields_ndarray.rs new file mode 100644 index 0000000..b556c80 --- /dev/null +++ b/nac3core/nac3core_derive/tests/structfields_ndarray.rs @@ -0,0 +1,20 @@ +use nac3core::{ + codegen::types::structure::StructField, + inkwell::{ + values::{IntValue, PointerValue}, + AddressSpace, + }, +}; +use nac3core_derive::StructFields; + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct NDArrayValue<'ctx> { + #[value_type(usize)] + ndims: StructField<'ctx, IntValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + shape: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + data: StructField<'ctx, PointerValue<'ctx>>, +} + +fn main() {} diff --git a/nac3core/nac3core_derive/tests/structfields_slice.rs b/nac3core/nac3core_derive/tests/structfields_slice.rs new file mode 100644 index 0000000..a191459 --- /dev/null +++ b/nac3core/nac3core_derive/tests/structfields_slice.rs @@ -0,0 +1,18 @@ +use nac3core::{ + codegen::types::structure::StructField, + inkwell::{ + values::{IntValue, PointerValue}, + AddressSpace, + }, +}; +use nac3core_derive::StructFields; + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct SliceValue<'ctx> { + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + ptr: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(usize)] + len: StructField<'ctx, IntValue<'ctx>>, +} + +fn main() {} diff --git a/nac3core/nac3core_derive/tests/structfields_slice_context.rs b/nac3core/nac3core_derive/tests/structfields_slice_context.rs new file mode 100644 index 0000000..7a9f2e4 --- /dev/null +++ b/nac3core/nac3core_derive/tests/structfields_slice_context.rs @@ -0,0 +1,18 @@ +use nac3core::{ + codegen::types::structure::StructField, + inkwell::{ + values::{IntValue, PointerValue}, + AddressSpace, + }, +}; +use nac3core_derive::StructFields; + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct SliceValue<'ctx> { + #[value_type(context.i8_type().ptr_type(AddressSpace::default()))] + ptr: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(usize)] + len: StructField<'ctx, IntValue<'ctx>>, +} + +fn main() {} diff --git a/nac3core/nac3core_derive/tests/structfields_slice_ctx.rs b/nac3core/nac3core_derive/tests/structfields_slice_ctx.rs new file mode 100644 index 0000000..deee780 --- /dev/null +++ b/nac3core/nac3core_derive/tests/structfields_slice_ctx.rs @@ -0,0 +1,18 @@ +use nac3core::{ + codegen::types::structure::StructField, + inkwell::{ + values::{IntValue, PointerValue}, + AddressSpace, + }, +}; +use nac3core_derive::StructFields; + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct SliceValue<'ctx> { + #[value_type(ctx.i8_type().ptr_type(AddressSpace::default()))] + ptr: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(usize)] + len: StructField<'ctx, IntValue<'ctx>>, +} + +fn main() {} diff --git a/nac3core/nac3core_derive/tests/structfields_slice_sizet.rs b/nac3core/nac3core_derive/tests/structfields_slice_sizet.rs new file mode 100644 index 0000000..29efa4b --- /dev/null +++ b/nac3core/nac3core_derive/tests/structfields_slice_sizet.rs @@ -0,0 +1,18 @@ +use nac3core::{ + codegen::types::structure::StructField, + inkwell::{ + values::{IntValue, PointerValue}, + AddressSpace, + }, +}; +use nac3core_derive::StructFields; + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct SliceValue<'ctx> { + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + ptr: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(size_t)] + len: StructField<'ctx, IntValue<'ctx>>, +} + +fn main() {} diff --git a/nac3core/nac3core_derive/tests/structfields_test.rs b/nac3core/nac3core_derive/tests/structfields_test.rs new file mode 100644 index 0000000..fb4af4f --- /dev/null +++ b/nac3core/nac3core_derive/tests/structfields_test.rs @@ -0,0 +1,10 @@ +#[test] +fn test_parse_empty() { + let t = trybuild::TestCases::new(); + t.pass("tests/structfields_empty.rs"); + t.pass("tests/structfields_slice.rs"); + t.pass("tests/structfields_slice_ctx.rs"); + t.pass("tests/structfields_slice_context.rs"); + t.pass("tests/structfields_slice_sizet.rs"); + t.pass("tests/structfields_ndarray.rs"); +} diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index 91ae05a..0c12732 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -21,3 +21,5 @@ pub mod codegen; pub mod symbol_resolver; pub mod toplevel; pub mod typecheck; + +extern crate self as nac3core; -- 2.44.2 From 48e2148c0fe2bb76ec3e9d7899a22f8ef08c8930 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 11:32:01 +0800 Subject: [PATCH 09/16] core/toplevel/helper: add {extract,create}_ndims --- nac3core/src/toplevel/helper.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 9661310..71ee35b 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1134,3 +1134,23 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { _ => 0, } } + +/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible. +/// The `ndims` must only contain 1 value. +#[must_use] +pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 { + let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty); + let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else { + panic!("ndims_ty should be a TLiteral"); + }; + + assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value"); + + let ndims = values[0].clone(); + u64::try_from(ndims).unwrap() +} + +/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value. +pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type { + unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None) +} -- 2.44.2 From f95f979ad3de4860fb35ee76bd7c3fda3704aee8 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 11:38:05 +0800 Subject: [PATCH 10/16] core/irrt: fix exception.hpp C++ castings --- nac3core/irrt/irrt/exception.hpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nac3core/irrt/irrt/exception.hpp b/nac3core/irrt/irrt/exception.hpp index 5b1ec59..4a4b093 100644 --- a/nac3core/irrt/irrt/exception.hpp +++ b/nac3core/irrt/irrt/exception.hpp @@ -55,11 +55,14 @@ void _raise_exception_helper(ExceptionId id, int64_t param2) { Exception e = { .id = id, - .filename = {.base = reinterpret_cast(filename), .len = __builtin_strlen(filename)}, + .filename = {.base = reinterpret_cast(const_cast(filename)), + .len = static_cast(__builtin_strlen(filename))}, .line = line, .column = 0, - .function = {.base = reinterpret_cast(function), .len = __builtin_strlen(function)}, - .msg = {.base = reinterpret_cast(msg), .len = __builtin_strlen(msg)}, + .function = {.base = reinterpret_cast(const_cast(function)), + .len = static_cast(__builtin_strlen(function))}, + .msg = {.base = reinterpret_cast(const_cast(msg)), + .len = static_cast(__builtin_strlen(msg))}, }; e.params[0] = param0; e.params[1] = param1; -- 2.44.2 From 1ba2e287a6dad556a2fa63ed18979a9c224cd888 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 8 Nov 2024 15:35:45 +0800 Subject: [PATCH 11/16] [core] codegen: Add Self::llvm_type to all type abstractions --- nac3core/src/codegen/types/list.rs | 20 +++++++++----- nac3core/src/codegen/types/ndarray.rs | 38 ++++++++++++++++----------- nac3core/src/codegen/types/range.rs | 11 ++++++-- 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 4561b48..7b08236 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -55,6 +55,19 @@ impl<'ctx> ListType<'ctx> { Ok(()) } + /// Creates an LLVM type corresponding to the expected structure of a `List`. + #[must_use] + fn llvm_type( + ctx: &'ctx Context, + element_type: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> PointerType<'ctx> { + // struct List { data: T*, size: size_t } + let field_tys = [element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()]; + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + /// Creates an instance of [`ListType`]. #[must_use] pub fn new( @@ -63,12 +76,7 @@ impl<'ctx> ListType<'ctx> { element_type: BasicTypeEnum<'ctx>, ) -> Self { let llvm_usize = generator.get_size_type(ctx); - let llvm_list = ctx - .struct_type( - &[element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], - false, - ) - .ptr_type(AddressSpace::default()); + let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize); ListType::from_type(llvm_list, llvm_usize) } diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index ca463b6..98bcdb6 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -73,6 +73,27 @@ impl<'ctx> NDArrayType<'ctx> { Ok(()) } + /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. + #[must_use] + fn llvm_type( + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> PointerType<'ctx> { + // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } + // + // * num_dims: Number of dimensions in the array + // * dims: Pointer to an array containing the size of each dimension + // * data: Pointer to an array containing the array data + let field_tys = [ + llvm_usize.into(), + llvm_usize.ptr_type(AddressSpace::default()).into(), + dtype.ptr_type(AddressSpace::default()).into(), + ]; + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + /// Creates an instance of [`ListType`]. #[must_use] pub fn new( @@ -81,22 +102,7 @@ impl<'ctx> NDArrayType<'ctx> { dtype: BasicTypeEnum<'ctx>, ) -> Self { let llvm_usize = generator.get_size_type(ctx); - - // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } - // - // * num_dims: Number of dimensions in the array - // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data - let llvm_ndarray = ctx - .struct_type( - &[ - llvm_usize.into(), - llvm_usize.ptr_type(AddressSpace::default()).into(), - dtype.ptr_type(AddressSpace::default()).into(), - ], - false, - ) - .ptr_type(AddressSpace::default()); + let llvm_ndarray = Self::llvm_type(ctx, dtype, llvm_usize); NDArrayType::from_type(llvm_ndarray, llvm_usize) } diff --git a/nac3core/src/codegen/types/range.rs b/nac3core/src/codegen/types/range.rs index 89a1b72..49c4388 100644 --- a/nac3core/src/codegen/types/range.rs +++ b/nac3core/src/codegen/types/range.rs @@ -47,11 +47,18 @@ impl<'ctx> RangeType<'ctx> { Ok(()) } + /// Creates an LLVM type corresponding to the expected structure of a `Range`. + #[must_use] + fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> { + // typedef int32_t Range[3]; + let llvm_i32 = ctx.i32_type(); + llvm_i32.array_type(3).ptr_type(AddressSpace::default()) + } + /// Creates an instance of [`RangeType`]. #[must_use] pub fn new(ctx: &'ctx Context) -> Self { - let llvm_i32 = ctx.i32_type(); - let llvm_range = llvm_i32.array_type(3).ptr_type(AddressSpace::default()); + let llvm_range = Self::llvm_type(ctx); RangeType::from_type(llvm_range) } -- 2.44.2 From 1a535db558adeb6221a008d3914d3321622676e2 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 8 Nov 2024 15:49:01 +0800 Subject: [PATCH 12/16] [core] codegen: Add dtype to NDArrayType We won't have this once NDArray is refactored to strided impl. --- nac3core/src/codegen/types/ndarray.rs | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index 98bcdb6..3f25f82 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -15,6 +15,7 @@ use crate::codegen::{ #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct NDArrayType<'ctx> { ty: PointerType<'ctx>, + dtype: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, } @@ -112,7 +113,14 @@ impl<'ctx> NDArrayType<'ctx> { pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); - NDArrayType { ty: ptr_ty, llvm_usize } + NDArrayType { + ty: ptr_ty, + dtype: ptr_ty + .get_element_type() + .try_into() + .expect("Expected BasicTypeEnum for dtype of NDArray"), + llvm_usize, + } } /// Returns the type of the `size` field of this `ndarray` type. @@ -128,14 +136,8 @@ impl<'ctx> NDArrayType<'ctx> { /// Returns the element type of this `ndarray` type. #[must_use] - pub fn element_type(&self) -> AnyTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(2) - .map(BasicTypeEnum::into_pointer_type) - .map(PointerType::get_element_type) - .unwrap() + pub fn element_type(&self) -> BasicTypeEnum<'ctx> { + self.dtype } } -- 2.44.2 From b58c99369ea3335dfc05f890c951813d663f96c8 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 11 Nov 2024 15:38:24 +0800 Subject: [PATCH 13/16] [core] irrt: Update some IRRT implementation - Change CSlice to use `void*` for better pointer compatibility - Only include impl *.hpp files in irrt.cpp - Refactor typedef to using declaration - Add missing ``// namespace` --- nac3core/irrt/irrt.cpp | 1 - nac3core/irrt/irrt/cslice.hpp | 2 +- nac3core/irrt/irrt/exception.hpp | 16 ++++++++-------- nac3core/irrt/irrt/int_types.hpp | 5 +++++ nac3core/irrt/irrt/list.hpp | 26 ++++++++++++++++---------- nac3core/irrt/irrt/math.hpp | 2 +- nac3core/irrt/irrt/ndarray.hpp | 2 +- nac3core/irrt/irrt/slice.hpp | 2 +- 8 files changed, 33 insertions(+), 23 deletions(-) diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 404681c..7966322 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -1,5 +1,4 @@ #include "irrt/exception.hpp" -#include "irrt/int_types.hpp" #include "irrt/list.hpp" #include "irrt/math.hpp" #include "irrt/ndarray.hpp" diff --git a/nac3core/irrt/irrt/cslice.hpp b/nac3core/irrt/irrt/cslice.hpp index 6f1d0a2..41441d3 100644 --- a/nac3core/irrt/irrt/cslice.hpp +++ b/nac3core/irrt/irrt/cslice.hpp @@ -4,6 +4,6 @@ template struct CSlice { - uint8_t* base; + void* base; SizeT len; }; \ No newline at end of file diff --git a/nac3core/irrt/irrt/exception.hpp b/nac3core/irrt/irrt/exception.hpp index 4a4b093..78accbc 100644 --- a/nac3core/irrt/irrt/exception.hpp +++ b/nac3core/irrt/irrt/exception.hpp @@ -6,7 +6,7 @@ /** * @brief The int type of ARTIQ exception IDs. */ -typedef int32_t ExceptionId; +using ExceptionId = int32_t; /* * Set of exceptions C++ IRRT can use. @@ -55,14 +55,14 @@ void _raise_exception_helper(ExceptionId id, int64_t param2) { Exception e = { .id = id, - .filename = {.base = reinterpret_cast(const_cast(filename)), - .len = static_cast(__builtin_strlen(filename))}, + .filename = {.base = reinterpret_cast(const_cast(filename)), + .len = static_cast(__builtin_strlen(filename))}, .line = line, .column = 0, - .function = {.base = reinterpret_cast(const_cast(function)), - .len = static_cast(__builtin_strlen(function))}, - .msg = {.base = reinterpret_cast(const_cast(msg)), - .len = static_cast(__builtin_strlen(msg))}, + .function = {.base = reinterpret_cast(const_cast(function)), + .len = static_cast(__builtin_strlen(function))}, + .msg = {.base = reinterpret_cast(const_cast(msg)), + .len = static_cast(__builtin_strlen(msg))}, }; e.params[0] = param0; e.params[1] = param1; @@ -70,6 +70,7 @@ void _raise_exception_helper(ExceptionId id, __nac3_raise(reinterpret_cast(&e)); __builtin_unreachable(); } +} // namespace /** * @brief Raise an exception with location details (location in the IRRT source files). @@ -82,4 +83,3 @@ void _raise_exception_helper(ExceptionId id, */ #define raise_exception(SizeT, id, msg, param0, param1, param2) \ _raise_exception_helper(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2) -} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt/int_types.hpp b/nac3core/irrt/irrt/int_types.hpp index 656a060..6f2fcef 100644 --- a/nac3core/irrt/irrt/int_types.hpp +++ b/nac3core/irrt/irrt/int_types.hpp @@ -8,12 +8,17 @@ using uint32_t = unsigned _BitInt(32); using int64_t = _BitInt(64); using uint64_t = unsigned _BitInt(64); #else + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-type" using int8_t = _ExtInt(8); using uint8_t = unsigned _ExtInt(8); using int32_t = _ExtInt(32); using uint32_t = unsigned _ExtInt(32); using int64_t = _ExtInt(64); using uint64_t = unsigned _ExtInt(64); +#pragma clang diagnostic pop + #endif // NDArray indices are always `uint32_t`. diff --git a/nac3core/irrt/irrt/list.hpp b/nac3core/irrt/irrt/list.hpp index b389197..2854394 100644 --- a/nac3core/irrt/irrt/list.hpp +++ b/nac3core/irrt/irrt/list.hpp @@ -13,12 +13,12 @@ extern "C" { SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start, SliceIndex dest_end, SliceIndex dest_step, - uint8_t* dest_arr, + void* dest_arr, SliceIndex dest_arr_len, SliceIndex src_start, SliceIndex src_end, SliceIndex src_step, - uint8_t* src_arr, + void* src_arr, SliceIndex src_arr_len, const SliceIndex size) { /* if dest_arr_len == 0, do nothing since we do not support extending list */ @@ -29,11 +29,13 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start, const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0; const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0; if (src_len > 0) { - __builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size); + __builtin_memmove(static_cast(dest_arr) + dest_start * size, + static_cast(src_arr) + src_start * size, src_len * size); } if (dest_len > 0) { /* dropping */ - __builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size, + __builtin_memmove(static_cast(dest_arr) + (dest_start + src_len) * size, + static_cast(dest_arr) + (dest_end + 1) * size, (dest_arr_len - dest_end - 1) * size); } /* shrink size */ @@ -44,7 +46,7 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start, && !(max(dest_start, dest_end) < min(src_start, src_end) || max(src_start, src_end) < min(dest_start, dest_end)); if (need_alloca) { - uint8_t* tmp = reinterpret_cast(__builtin_alloca(src_arr_len * size)); + void* tmp = __builtin_alloca(src_arr_len * size); __builtin_memcpy(tmp, src_arr, src_arr_len * size); src_arr = tmp; } @@ -53,20 +55,24 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start, for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) { /* for constant optimization */ if (size == 1) { - __builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1); + __builtin_memcpy(static_cast(dest_arr) + dest_ind, static_cast(src_arr) + src_ind, 1); } else if (size == 4) { - __builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4); + __builtin_memcpy(static_cast(dest_arr) + dest_ind * 4, + static_cast(src_arr) + src_ind * 4, 4); } else if (size == 8) { - __builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8); + __builtin_memcpy(static_cast(dest_arr) + dest_ind * 8, + static_cast(src_arr) + src_ind * 8, 8); } else { /* memcpy for var size, cannot overlap after previous alloca */ - __builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size); + __builtin_memcpy(static_cast(dest_arr) + dest_ind * size, + static_cast(src_arr) + src_ind * size, size); } } /* only dest_step == 1 can we shrink the dest list. */ /* size should be ensured prior to calling this function */ if (dest_step == 1 && dest_end >= dest_start) { - __builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size, + __builtin_memmove(static_cast(dest_arr) + dest_ind * size, + static_cast(dest_arr) + (dest_end + 1) * size, (dest_arr_len - dest_end - 1) * size); return dest_arr_len - (dest_end - dest_ind) - 1; } diff --git a/nac3core/irrt/irrt/math.hpp b/nac3core/irrt/irrt/math.hpp index ff10f3f..1872f56 100644 --- a/nac3core/irrt/irrt/math.hpp +++ b/nac3core/irrt/irrt/math.hpp @@ -90,4 +90,4 @@ double __nac3_j0(double x) { return j0(x); } -} \ No newline at end of file +} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp index 5dda173..b239152 100644 --- a/nac3core/irrt/irrt/ndarray.hpp +++ b/nac3core/irrt/irrt/ndarray.hpp @@ -141,4 +141,4 @@ void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, NDIndex* out_idx) { __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); } -} \ No newline at end of file +} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt/slice.hpp b/nac3core/irrt/irrt/slice.hpp index a1523dd..3f6d83a 100644 --- a/nac3core/irrt/irrt/slice.hpp +++ b/nac3core/irrt/irrt/slice.hpp @@ -25,4 +25,4 @@ SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, return 0; } } -} \ No newline at end of file +} // namespace \ No newline at end of file -- 2.44.2 From f7e296da530c9940d247e128a3229ddd7389d374 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 11 Nov 2024 16:16:23 +0800 Subject: [PATCH 14/16] [core] irrt: Break IRRT into several impl files Each IRRT file is now mapped to one Rust file. --- nac3core/src/codegen/irrt/list.rs | 162 ++++++ nac3core/src/codegen/irrt/math.rs | 152 ++++++ nac3core/src/codegen/irrt/mod.rs | 749 +-------------------------- nac3core/src/codegen/irrt/ndarray.rs | 384 ++++++++++++++ nac3core/src/codegen/irrt/slice.rs | 76 +++ 5 files changed, 786 insertions(+), 737 deletions(-) create mode 100644 nac3core/src/codegen/irrt/list.rs create mode 100644 nac3core/src/codegen/irrt/math.rs create mode 100644 nac3core/src/codegen/irrt/ndarray.rs create mode 100644 nac3core/src/codegen/irrt/slice.rs diff --git a/nac3core/src/codegen/irrt/list.rs b/nac3core/src/codegen/irrt/list.rs new file mode 100644 index 0000000..a7fec59 --- /dev/null +++ b/nac3core/src/codegen/irrt/list.rs @@ -0,0 +1,162 @@ +use inkwell::{ + types::BasicTypeEnum, + values::{BasicValueEnum, CallSiteValue, IntValue}, + AddressSpace, IntPredicate, +}; +use itertools::Either; + +use super::calculate_len_for_slice_range; +use crate::codegen::{ + macros::codegen_unreachable, + values::{ArrayLikeValue, ListValue}, + CodeGenContext, CodeGenerator, +}; + +/// This function handles 'end' **inclusively**. +/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step'). +/// Negative index should be handled before entering this function +pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: BasicTypeEnum<'ctx>, + dest_arr: ListValue<'ctx>, + dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), + src_arr: ListValue<'ctx>, + src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), +) { + let size_ty = generator.get_size_type(ctx.ctx); + let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); + let int32 = ctx.ctx.i32_type(); + let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); + let slice_assign_fun = { + let ty_vec = vec![ + int32.into(), // dest start idx + int32.into(), // dest end idx + int32.into(), // dest step + elem_ptr_type.into(), // dest arr ptr + int32.into(), // dest arr len + int32.into(), // src start idx + int32.into(), // src end idx + int32.into(), // src step + elem_ptr_type.into(), // src arr ptr + int32.into(), // src arr len + int32.into(), // size + ]; + ctx.module.get_function(fun_symbol).unwrap_or_else(|| { + let fn_t = int32.fn_type(ty_vec.as_slice(), false); + ctx.module.add_function(fun_symbol, fn_t, None) + }) + }; + + let zero = int32.const_zero(); + let one = int32.const_int(1, false); + let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); + let dest_arr_ptr = + ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); + let dest_len = dest_arr.load_size(ctx, Some("dest.len")); + let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); + let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); + let src_arr_ptr = + ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); + let src_len = src_arr.load_size(ctx, Some("src.len")); + let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); + + // index in bound and positive should be done + // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and + // throw exception if not satisfied + let src_end = ctx + .builder + .build_select( + ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(), + ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(), + ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(), + "final_e", + ) + .map(BasicValueEnum::into_int_value) + .unwrap(); + let dest_end = ctx + .builder + .build_select( + ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(), + ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(), + ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(), + "final_e", + ) + .map(BasicValueEnum::into_int_value) + .unwrap(); + let src_slice_len = + calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2); + let dest_slice_len = + calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2); + let src_eq_dest = ctx + .builder + .build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest") + .unwrap(); + let src_slt_dest = ctx + .builder + .build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest") + .unwrap(); + let dest_step_eq_one = ctx + .builder + .build_int_compare( + IntPredicate::EQ, + dest_idx.2, + dest_idx.2.get_type().const_int(1, false), + "slice_dest_step_eq_one", + ) + .unwrap(); + let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap(); + let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap(); + ctx.make_assert( + generator, + cond, + "0:ValueError", + "attempt to assign sequence of size {0} to slice of size {1} with step size {2}", + [Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)], + ctx.current_loc, + ); + + let new_len = { + let args = vec![ + dest_idx.0.into(), // dest start idx + dest_idx.1.into(), // dest end idx + dest_idx.2.into(), // dest step + dest_arr_ptr.into(), // dest arr ptr + dest_len.into(), // dest arr len + src_idx.0.into(), // src start idx + src_idx.1.into(), // src end idx + src_idx.2.into(), // src step + src_arr_ptr.into(), // src arr ptr + src_len.into(), // src arr len + { + let s = match ty { + BasicTypeEnum::FloatType(t) => t.size_of(), + BasicTypeEnum::IntType(t) => t.size_of(), + BasicTypeEnum::PointerType(t) => t.size_of(), + BasicTypeEnum::StructType(t) => t.size_of().unwrap(), + _ => codegen_unreachable!(ctx), + }; + ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() + } + .into(), + ]; + ctx.builder + .build_call(slice_assign_fun, args.as_slice(), "slice_assign") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() + }; + // update length + let need_update = + ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); + let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); + let update_bb = ctx.ctx.append_basic_block(current, "update"); + let cont_bb = ctx.ctx.append_basic_block(current, "cont"); + ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); + ctx.builder.position_at_end(update_bb); + let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); + dest_arr.store_size(ctx, generator, new_len); + ctx.builder.build_unconditional_branch(cont_bb).unwrap(); + ctx.builder.position_at_end(cont_bb); +} diff --git a/nac3core/src/codegen/irrt/math.rs b/nac3core/src/codegen/irrt/math.rs new file mode 100644 index 0000000..4bc9591 --- /dev/null +++ b/nac3core/src/codegen/irrt/math.rs @@ -0,0 +1,152 @@ +use inkwell::{ + values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, + IntPredicate, +}; +use itertools::Either; + +use crate::codegen::{ + macros::codegen_unreachable, + {CodeGenContext, CodeGenerator}, +}; + +// repeated squaring method adapted from GNU Scientific Library: +// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c +pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + base: IntValue<'ctx>, + exp: IntValue<'ctx>, + signed: bool, +) -> IntValue<'ctx> { + let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { + (32, 32, true) => "__nac3_int_exp_int32_t", + (64, 64, true) => "__nac3_int_exp_int64_t", + (32, 32, false) => "__nac3_int_exp_uint32_t", + (64, 64, false) => "__nac3_int_exp_uint64_t", + _ => codegen_unreachable!(ctx), + }; + let base_type = base.get_type(); + let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { + let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); + ctx.module.add_function(symbol, fn_type, None) + }); + // throw exception when exp < 0 + let ge_zero = ctx + .builder + .build_int_compare( + IntPredicate::SGE, + exp, + exp.get_type().const_zero(), + "assert_int_pow_ge_0", + ) + .unwrap(); + ctx.make_assert( + generator, + ge_zero, + "0:ValueError", + "integer power must be positive or zero", + [None, None, None], + ctx.current_loc, + ); + ctx.builder + .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Generates a call to `isinf` in IR. Returns an `i1` representing the result. +pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + v: FloatValue<'ctx>, +) -> IntValue<'ctx> { + let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { + let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + ctx.module.add_function("__nac3_isinf", fn_type, None) + }); + + let ret = ctx + .builder + .build_call(intrinsic_fn, &[v.into()], "isinf") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + + generator.bool_to_i1(ctx, ret) +} + +/// Generates a call to `isnan` in IR. Returns an `i1` representing the result. +pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + v: FloatValue<'ctx>, +) -> IntValue<'ctx> { + let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { + let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); + ctx.module.add_function("__nac3_isnan", fn_type, None) + }); + + let ret = ctx + .builder + .build_call(intrinsic_fn, &[v.into()], "isnan") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + + generator.bool_to_i1(ctx, ret) +} + +/// Generates a call to `gamma` in IR. Returns an `f64` representing the result. +pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { + let llvm_f64 = ctx.ctx.f64_type(); + + let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + ctx.module.add_function("__nac3_gamma", fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[v.into()], "gamma") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. +pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { + let llvm_f64 = ctx.ctx.f64_type(); + + let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + ctx.module.add_function("__nac3_gammaln", fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[v.into()], "gammaln") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Generates a call to `j0` in IR. Returns an `f64` representing the result. +pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { + let llvm_f64 = ctx.ctx.f64_type(); + + let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { + let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); + ctx.module.add_function("__nac3_j0", fn_type, None) + }); + + ctx.builder + .build_call(intrinsic_fn, &[v.into()], "j0") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_float_value)) + .map(Either::unwrap_left) + .unwrap() +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 7e70a36..f6c4a1e 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -3,25 +3,23 @@ use inkwell::{ context::Context, memory_buffer::MemoryBuffer, module::Module, - types::{BasicTypeEnum, IntType}, - values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue}, - AddressSpace, IntPredicate, + values::{BasicValue, BasicValueEnum, IntValue}, + IntPredicate, }; -use itertools::Either; use nac3parser::ast::Expr; -use super::{ - llvm_intrinsics, - macros::codegen_unreachable, - stmt::gen_for_callback_incrementing, - values::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, - }, - CodeGenContext, CodeGenerator, -}; +use super::{CodeGenContext, CodeGenerator}; use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type}; +pub use list::*; +pub use math::*; +pub use ndarray::*; +pub use slice::*; + +mod list; +mod math; +mod ndarray; +mod slice; #[must_use] pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> { @@ -62,88 +60,6 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) irrt_mod } -// repeated squaring method adapted from GNU Scientific Library: -// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c -pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - base: IntValue<'ctx>, - exp: IntValue<'ctx>, - signed: bool, -) -> IntValue<'ctx> { - let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { - (32, 32, true) => "__nac3_int_exp_int32_t", - (64, 64, true) => "__nac3_int_exp_int64_t", - (32, 32, false) => "__nac3_int_exp_uint32_t", - (64, 64, false) => "__nac3_int_exp_uint64_t", - _ => codegen_unreachable!(ctx), - }; - let base_type = base.get_type(); - let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { - let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); - ctx.module.add_function(symbol, fn_type, None) - }); - // throw exception when exp < 0 - let ge_zero = ctx - .builder - .build_int_compare( - IntPredicate::SGE, - exp, - exp.get_type().const_zero(), - "assert_int_pow_ge_0", - ) - .unwrap(); - ctx.make_assert( - generator, - ge_zero, - "0:ValueError", - "integer power must be positive or zero", - [None, None, None], - ctx.current_loc, - ); - ctx.builder - .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - -pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - start: IntValue<'ctx>, - end: IntValue<'ctx>, - step: IntValue<'ctx>, -) -> IntValue<'ctx> { - const SYMBOL: &str = "__nac3_range_slice_len"; - let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false); - ctx.module.add_function(SYMBOL, fn_t, None) - }); - - // assert step != 0, throw exception if not - let not_zero = ctx - .builder - .build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne") - .unwrap(); - ctx.make_assert( - generator, - not_zero, - "0:ValueError", - "step must not be zero", - [None, None, None], - ctx.current_loc, - ); - ctx.builder - .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - /// NOTE: the output value of the end index of this function should be compared ***inclusively***, /// because python allows `a[2::-1]`, whose semantic is `[a[2], a[1], a[0]]`, which is equivalent to /// NO numeric slice in python. @@ -309,644 +225,3 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>( } })) } - -/// this function allows index out of range, since python -/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`). -pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>( - i: &Expr>, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - length: IntValue<'ctx>, -) -> Result>, String> { - const SYMBOL: &str = "__nac3_slice_index_bound"; - let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { - let i32_t = ctx.ctx.i32_type(); - let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false); - ctx.module.add_function(SYMBOL, fn_t, None) - }); - - let i = if let Some(v) = generator.gen_expr(ctx, i)? { - v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? - } else { - return Ok(None); - }; - Ok(Some( - ctx.builder - .build_call(func, &[i.into(), length.into()], "bounded_ind") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(), - )) -} - -/// This function handles 'end' **inclusively**. -/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step'). -/// Negative index should be handled before entering this function -pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: BasicTypeEnum<'ctx>, - dest_arr: ListValue<'ctx>, - dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), - src_arr: ListValue<'ctx>, - src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), -) { - let size_ty = generator.get_size_type(ctx.ctx); - let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let int32 = ctx.ctx.i32_type(); - let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); - let slice_assign_fun = { - let ty_vec = vec![ - int32.into(), // dest start idx - int32.into(), // dest end idx - int32.into(), // dest step - elem_ptr_type.into(), // dest arr ptr - int32.into(), // dest arr len - int32.into(), // src start idx - int32.into(), // src end idx - int32.into(), // src step - elem_ptr_type.into(), // src arr ptr - int32.into(), // src arr len - int32.into(), // size - ]; - ctx.module.get_function(fun_symbol).unwrap_or_else(|| { - let fn_t = int32.fn_type(ty_vec.as_slice(), false); - ctx.module.add_function(fun_symbol, fn_t, None) - }) - }; - - let zero = int32.const_zero(); - let one = int32.const_int(1, false); - let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); - let dest_arr_ptr = - ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap(); - let dest_len = dest_arr.load_size(ctx, Some("dest.len")); - let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); - let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); - let src_arr_ptr = - ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap(); - let src_len = src_arr.load_size(ctx, Some("src.len")); - let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); - - // index in bound and positive should be done - // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and - // throw exception if not satisfied - let src_end = ctx - .builder - .build_select( - ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(), - ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(), - ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - let dest_end = ctx - .builder - .build_select( - ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(), - ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(), - ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - let src_slice_len = - calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2); - let dest_slice_len = - calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2); - let src_eq_dest = ctx - .builder - .build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest") - .unwrap(); - let src_slt_dest = ctx - .builder - .build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest") - .unwrap(); - let dest_step_eq_one = ctx - .builder - .build_int_compare( - IntPredicate::EQ, - dest_idx.2, - dest_idx.2.get_type().const_int(1, false), - "slice_dest_step_eq_one", - ) - .unwrap(); - let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap(); - let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap(); - ctx.make_assert( - generator, - cond, - "0:ValueError", - "attempt to assign sequence of size {0} to slice of size {1} with step size {2}", - [Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)], - ctx.current_loc, - ); - - let new_len = { - let args = vec![ - dest_idx.0.into(), // dest start idx - dest_idx.1.into(), // dest end idx - dest_idx.2.into(), // dest step - dest_arr_ptr.into(), // dest arr ptr - dest_len.into(), // dest arr len - src_idx.0.into(), // src start idx - src_idx.1.into(), // src end idx - src_idx.2.into(), // src step - src_arr_ptr.into(), // src arr ptr - src_len.into(), // src arr len - { - let s = match ty { - BasicTypeEnum::FloatType(t) => t.size_of(), - BasicTypeEnum::IntType(t) => t.size_of(), - BasicTypeEnum::PointerType(t) => t.size_of(), - BasicTypeEnum::StructType(t) => t.size_of().unwrap(), - _ => codegen_unreachable!(ctx), - }; - ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap() - } - .into(), - ]; - ctx.builder - .build_call(slice_assign_fun, args.as_slice(), "slice_assign") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() - }; - // update length - let need_update = - ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap(); - let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); - let update_bb = ctx.ctx.append_basic_block(current, "update"); - let cont_bb = ctx.ctx.append_basic_block(current, "cont"); - ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); - ctx.builder.position_at_end(update_bb); - let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap(); - dest_arr.store_size(ctx, generator, new_len); - ctx.builder.build_unconditional_branch(cont_bb).unwrap(); - ctx.builder.position_at_end(cont_bb); -} - -/// Generates a call to `isinf` in IR. Returns an `i1` representing the result. -pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &CodeGenContext<'ctx, '_>, - v: FloatValue<'ctx>, -) -> IntValue<'ctx> { - let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); - ctx.module.add_function("__nac3_isinf", fn_type, None) - }); - - let ret = ctx - .builder - .build_call(intrinsic_fn, &[v.into()], "isinf") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - generator.bool_to_i1(ctx, ret) -} - -/// Generates a call to `isnan` in IR. Returns an `i1` representing the result. -pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &CodeGenContext<'ctx, '_>, - v: FloatValue<'ctx>, -) -> IntValue<'ctx> { - let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| { - let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false); - ctx.module.add_function("__nac3_isnan", fn_type, None) - }); - - let ret = ctx - .builder - .build_call(intrinsic_fn, &[v.into()], "isnan") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - generator.bool_to_i1(ctx, ret) -} - -/// Generates a call to `gamma` in IR. Returns an `f64` representing the result. -pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { - let llvm_f64 = ctx.ctx.f64_type(); - - let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_gamma", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "gamma") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. -pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { - let llvm_f64 = ctx.ctx.f64_type(); - - let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_gammaln", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "gammaln") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `j0` in IR. Returns an `f64` representing the result. -pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> { - let llvm_f64 = ctx.ctx.f64_type(); - - let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { - let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false); - ctx.module.add_function("__nac3_j0", fn_type, None) - }); - - ctx.builder - .build_call(intrinsic_fn, &[v.into()], "j0") - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_float_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the -/// calculated total size. -/// -/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension -/// respectively. -pub fn call_ndarray_calc_size<'ctx, G, Dims>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - dims: &Dims, - (begin, end): (Option>, Option>), -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_size", - 64 => "__nac3_ndarray_calc_size64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], - false, - ); - let ndarray_calc_size_fn = - ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) - }); - - let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); - let end = end.unwrap_or_else(|| dims.size(ctx, generator)); - ctx.builder - .build_call( - ndarray_calc_size_fn, - &[ - dims.base_ptr(ctx, generator).into(), - dims.size(ctx, generator).into(), - begin.into(), - end.into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] -/// containing `i32` indices of the flattened index. -/// -/// * `index` - The index to compute the multidimensional index for. -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_nd_indices", - 64 => "__nac3_ndarray_calc_nd_indices64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_nd_indices_fn = - ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); - - ctx.builder - .build_call( - ndarray_calc_nd_indices_fn, - &[ - index.into(), - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Indices, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, -{ - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - debug_assert_eq!( - IntType::try_from(indices.element_type(ctx, generator)) - .map(IntType::get_bit_width) - .unwrap_or_default(), - llvm_i32.get_bit_width(), - "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" - ); - debug_assert_eq!( - indices.size(ctx, generator).get_type().get_bit_width(), - llvm_usize.get_bit_width(), - "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" - ); - - let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_flatten_index", - 64 => "__nac3_ndarray_flatten_index64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_flatten_index_fn = - ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], - false, - ); - - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let index = ctx - .builder - .build_call( - ndarray_flatten_index_fn, - &[ - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.base_ptr(ctx, generator).into(), - indices.size(ctx, generator).into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - index -} - -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the -/// multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. -pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast", - 64 => "__nac3_ndarray_calc_broadcast64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (min_ndims, false), - |generator, ctx, _, idx| { - let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); - let (lhs_dim_sz, rhs_dim_sz) = unsafe { - ( - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - ) - }; - - let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let rhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - - let lhs_eq_rhs = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") - .unwrap(); - - let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); - - ctx.make_assert( - generator, - is_compatible, - "0:ValueError", - "operands could not be broadcast together", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); - let rhs_ndims = rhs.load_ndims(ctx); - let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); - let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[ - lhs_dims.into(), - lhs_ndims.into(), - rhs_dims.into(), - rhs_ndims.into(), - out_dims.base_ptr(ctx, generator).into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - out_dims, - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] -/// containing the indices used for accessing `array` corresponding to the index of the broadcasted -/// array `broadcast_idx`. -pub fn call_ndarray_calc_broadcast_index< - 'ctx, - G: CodeGenerator + ?Sized, - BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, ->( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - array: NDArrayValue<'ctx>, - broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast_idx", - 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let broadcast_size = broadcast_idx.size(ctx, generator); - let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - - let array_dims = array.dim_sizes().base_ptr(ctx, generator); - let array_ndims = array.load_ndims(ctx); - let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} diff --git a/nac3core/src/codegen/irrt/ndarray.rs b/nac3core/src/codegen/irrt/ndarray.rs new file mode 100644 index 0000000..bfec1d5 --- /dev/null +++ b/nac3core/src/codegen/irrt/ndarray.rs @@ -0,0 +1,384 @@ +use inkwell::{ + types::IntType, + values::{BasicValueEnum, CallSiteValue, IntValue}, + AddressSpace, IntPredicate, +}; +use itertools::Either; + +use crate::codegen::{ + llvm_intrinsics, + macros::codegen_unreachable, + stmt::gen_for_callback_incrementing, + values::{ + ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, NDArrayValue, TypedArrayLikeAccessor, + TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + }, + CodeGenContext, CodeGenerator, +}; + +/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the +/// calculated total size. +/// +/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. +/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, +/// or [`None`] if starting from the first dimension and ending at the last dimension +/// respectively. +pub fn call_ndarray_calc_size<'ctx, G, Dims>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + dims: &Dims, + (begin, end): (Option>, Option>), +) -> IntValue<'ctx> +where + G: CodeGenerator + ?Sized, + Dims: ArrayLikeIndexer<'ctx>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_size", + 64 => "__nac3_ndarray_calc_size64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_calc_size_fn_t = llvm_usize.fn_type( + &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], + false, + ); + let ndarray_calc_size_fn = + ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { + ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) + }); + + let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); + let end = end.unwrap_or_else(|| dims.size(ctx, generator)); + ctx.builder + .build_call( + ndarray_calc_size_fn, + &[ + dims.base_ptr(ctx, generator).into(), + dims.size(ctx, generator).into(), + begin.into(), + end.into(), + ], + "", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} + +/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] +/// containing `i32` indices of the flattened index. +/// +/// * `index` - The index to compute the multidimensional index for. +/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an +/// `NDArray`. +pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + index: IntValue<'ctx>, + ndarray: NDArrayValue<'ctx>, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_void = ctx.ctx.void_type(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_nd_indices", + 64 => "__nac3_ndarray_calc_nd_indices64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_calc_nd_indices_fn = + ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], + false, + ); + + ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) + }); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + let ndarray_dims = ndarray.dim_sizes(); + + let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); + + ctx.builder + .build_call( + ndarray_calc_nd_indices_fn, + &[ + index.into(), + ndarray_dims.base_ptr(ctx, generator).into(), + ndarray_num_dims.into(), + indices.into(), + ], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) +} + +fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( + generator: &G, + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: &Indices, +) -> IntValue<'ctx> +where + G: CodeGenerator + ?Sized, + Indices: ArrayLikeIndexer<'ctx>, +{ + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + debug_assert_eq!( + IntType::try_from(indices.element_type(ctx, generator)) + .map(IntType::get_bit_width) + .unwrap_or_default(), + llvm_i32.get_bit_width(), + "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" + ); + debug_assert_eq!( + indices.size(ctx, generator).get_type().get_bit_width(), + llvm_usize.get_bit_width(), + "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" + ); + + let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_flatten_index", + 64 => "__nac3_ndarray_flatten_index64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_flatten_index_fn = + ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], + false, + ); + + ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) + }); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + let ndarray_dims = ndarray.dim_sizes(); + + let index = ctx + .builder + .build_call( + ndarray_flatten_index_fn, + &[ + ndarray_dims.base_ptr(ctx, generator).into(), + ndarray_num_dims.into(), + indices.base_ptr(ctx, generator).into(), + indices.size(ctx, generator).into(), + ], + "", + ) + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(); + + index +} + +/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the +/// multidimensional index. +/// +/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an +/// `NDArray`. +/// * `indices` - The multidimensional index to compute the flattened index for. +pub fn call_ndarray_flatten_index<'ctx, G, Index>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayValue<'ctx>, + indices: &Index, +) -> IntValue<'ctx> +where + G: CodeGenerator + ?Sized, + Index: ArrayLikeIndexer<'ctx>, +{ + call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) +} + +/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of +/// dimension and size of each dimension of the resultant `ndarray`. +pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lhs: NDArrayValue<'ctx>, + rhs: NDArrayValue<'ctx>, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_broadcast", + 64 => "__nac3_ndarray_calc_broadcast64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_calc_broadcast_fn = + ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[ + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); + + let lhs_ndims = lhs.load_ndims(ctx); + let rhs_ndims = rhs.load_ndims(ctx); + let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); + + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (min_ndims, false), + |generator, ctx, _, idx| { + let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); + let (lhs_dim_sz, rhs_dim_sz) = unsafe { + ( + lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), + rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), + ) + }; + + let llvm_usize_const_one = llvm_usize.const_int(1, false); + let lhs_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") + .unwrap(); + let rhs_eqz = ctx + .builder + .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") + .unwrap(); + let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); + + let lhs_eq_rhs = ctx + .builder + .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") + .unwrap(); + + let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); + + ctx.make_assert( + generator, + is_compatible, + "0:ValueError", + "operands could not be broadcast together", + [None, None, None], + ctx.current_loc, + ); + + Ok(()) + }, + llvm_usize.const_int(1, false), + ) + .unwrap(); + + let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); + let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); + let lhs_ndims = lhs.load_ndims(ctx); + let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); + let rhs_ndims = rhs.load_ndims(ctx); + let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); + let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); + + ctx.builder + .build_call( + ndarray_calc_broadcast_fn, + &[ + lhs_dims.into(), + lhs_ndims.into(), + rhs_dims.into(), + rhs_ndims.into(), + out_dims.base_ptr(ctx, generator).into(), + ], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + out_dims, + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) +} + +/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] +/// containing the indices used for accessing `array` corresponding to the index of the broadcasted +/// array `broadcast_idx`. +pub fn call_ndarray_calc_broadcast_index< + 'ctx, + G: CodeGenerator + ?Sized, + BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, +>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + array: NDArrayValue<'ctx>, + broadcast_idx: &BroadcastIdx, +) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { + 32 => "__nac3_ndarray_calc_broadcast_idx", + 64 => "__nac3_ndarray_calc_broadcast_idx64", + bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), + }; + let ndarray_calc_broadcast_fn = + ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { + let fn_type = llvm_usize.fn_type( + &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], + false, + ); + + ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) + }); + + let broadcast_size = broadcast_idx.size(ctx, generator); + let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); + + let array_dims = array.dim_sizes().base_ptr(ctx, generator); + let array_ndims = array.load_ndims(ctx); + let broadcast_idx_ptr = unsafe { + broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + }; + + ctx.builder + .build_call( + ndarray_calc_broadcast_fn, + &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], + "", + ) + .unwrap(); + + TypedArrayLikeAdapter::from( + ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), + Box::new(|_, v| v.into_int_value()), + Box::new(|_, v| v.into()), + ) +} diff --git a/nac3core/src/codegen/irrt/slice.rs b/nac3core/src/codegen/irrt/slice.rs new file mode 100644 index 0000000..eb7037a --- /dev/null +++ b/nac3core/src/codegen/irrt/slice.rs @@ -0,0 +1,76 @@ +use inkwell::{ + values::{BasicValueEnum, CallSiteValue, IntValue}, + IntPredicate, +}; +use itertools::Either; +use nac3parser::ast::Expr; + +use crate::{ + codegen::{CodeGenContext, CodeGenerator}, + typecheck::typedef::Type, +}; + +/// this function allows index out of range, since python +/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`). +pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>( + i: &Expr>, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut G, + length: IntValue<'ctx>, +) -> Result>, String> { + const SYMBOL: &str = "__nac3_slice_index_bound"; + let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { + let i32_t = ctx.ctx.i32_type(); + let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false); + ctx.module.add_function(SYMBOL, fn_t, None) + }); + + let i = if let Some(v) = generator.gen_expr(ctx, i)? { + v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? + } else { + return Ok(None); + }; + Ok(Some( + ctx.builder + .build_call(func, &[i.into(), length.into()], "bounded_ind") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap(), + )) +} + +pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + start: IntValue<'ctx>, + end: IntValue<'ctx>, + step: IntValue<'ctx>, +) -> IntValue<'ctx> { + const SYMBOL: &str = "__nac3_range_slice_len"; + let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { + let i32_t = ctx.ctx.i32_type(); + let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false); + ctx.module.add_function(SYMBOL, fn_t, None) + }); + + // assert step != 0, throw exception if not + let not_zero = ctx + .builder + .build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne") + .unwrap(); + ctx.make_assert( + generator, + not_zero, + "0:ValueError", + "step must not be zero", + [None, None, None], + ctx.current_loc, + ); + ctx.builder + .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") + .map(CallSiteValue::try_as_basic_value) + .map(|v| v.map_left(BasicValueEnum::into_int_value)) + .map(Either::unwrap_left) + .unwrap() +} -- 2.44.2 From c58ce9c3a9b0e23b352c3ad2d147d7d6511697f7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 28 Aug 2024 16:33:03 +0800 Subject: [PATCH 15/16] [core] codegen/types: Implement NDArray in terms of i8* Better aligns with the future implementation of ndstrides. --- nac3artiq/src/codegen.rs | 12 +- nac3core/src/codegen/builtin_fns.rs | 107 +++++++++------ nac3core/src/codegen/expr.rs | 66 ++++++--- nac3core/src/codegen/numpy.rs | 181 +++++++++++++++++++------ nac3core/src/codegen/types/ndarray.rs | 49 ++++--- nac3core/src/codegen/values/ndarray.rs | 108 +++++++++++++-- nac3standalone/demo/src/ndarray.py | 3 +- 7 files changed, 378 insertions(+), 148 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 1fcfd4b..aece926 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -461,8 +461,7 @@ fn format_rpc_arg<'ctx>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); - let llvm_arg = - NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None); + let llvm_arg = llvm_arg_ty.map_value(arg.into_pointer_value(), None); let llvm_usize_sizeof = ctx .builder @@ -1369,12 +1368,17 @@ fn polymorphic_print<'ctx>( TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); fmt.push_str("array(["); flush(ctx, generator, &mut fmt, &mut args); - let val = - NDArrayValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None); + let val = NDArrayValue::from_pointer_value( + value.into_pointer_value(), + llvm_elem_ty, + llvm_usize, + None, + ); let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); let last = ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 7765753..e693faf 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -21,7 +21,10 @@ use super::{ CodeGenContext, CodeGenerator, }; use crate::{ - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys}, + toplevel::{ + helper::{arraylike_flatten_element_type, PrimDef}, + numpy::unpack_ndarray_var_tys, + }, typecheck::typedef::{Type, TypeEnum}, }; @@ -65,10 +68,15 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, arg_ty); let llvm_usize = generator.get_size_type(ctx.ctx); - let arg = - NDArrayValue::from_pointer_value(arg.into_pointer_value(), llvm_usize, None); + let arg = NDArrayValue::from_pointer_value( + arg.into_pointer_value(), + ctx.get_llvm_type(generator, elem_ty), + llvm_usize, + None, + ); let ndims = arg.dim_sizes().size(ctx, generator); ctx.make_assert( @@ -143,13 +151,14 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.int32, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), )?; @@ -205,13 +214,14 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.int64, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), )?; @@ -283,13 +293,14 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.uint32, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), )?; @@ -350,13 +361,14 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.uint64, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), )?; @@ -416,13 +428,14 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.float, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), )?; @@ -462,13 +475,14 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -502,13 +516,14 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.float, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), )?; @@ -567,13 +582,14 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ctx.primitives.bool, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| { let elem = call_bool(generator, ctx, (elem_ty, val))?; @@ -621,13 +637,14 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -671,13 +688,14 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( generator, ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(n, llvm_usize, None), + NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None), |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), )?; @@ -806,8 +824,8 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -906,9 +924,9 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let n = NDArrayValue::from_pointer_value(n, llvm_usize, None); + let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None); let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx @@ -926,7 +944,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ); } - let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; + let accumulator_addr = generator.gen_var_alloc(ctx, llvm_elem_ty, None)?; let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; unsafe { @@ -1068,8 +1086,8 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1114,6 +1132,7 @@ where { let llvm_usize = generator.get_size_type(ctx.ctx); let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); + let llvm_arg_elem_ty = ctx.get_llvm_type(generator, arg_elem_ty); let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1121,7 +1140,7 @@ where ctx, ret_elem_ty, None, - NDArrayValue::from_pointer_value(x, llvm_usize, None), + NDArrayValue::from_pointer_value(x, llvm_arg_elem_ty, llvm_usize, None), |generator, ctx, elem_val| { helper_call_numpy_unary_elementwise( generator, @@ -1508,8 +1527,8 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1575,8 +1594,8 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1642,8 +1661,8 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1709,8 +1728,8 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1765,8 +1784,8 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1832,8 +1851,8 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1899,8 +1918,8 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( ctx, dtype, None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), + (x1_ty, x1, !is_ndarray1), + (x2_ty, x2, !is_ndarray2), |generator, ctx, (lhs, rhs)| { call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) }, @@ -1960,7 +1979,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2002,7 +2021,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( unimplemented!("{FN_NAME} operates on float type NdArrays only"); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2052,7 +2071,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2107,7 +2126,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) @@ -2149,7 +2168,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2192,7 +2211,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2245,7 +2264,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); // Changing second parameter to a `NDArray` for uniformity in function call let n2_array = numpy::create_ndarray_const_shape( generator, @@ -2340,7 +2359,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() @@ -2383,7 +2402,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); }; - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { n1.dim_sizes() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 93720f9..01047b3 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1564,10 +1564,21 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - let left_val = - NDArrayValue::from_pointer_value(left_val.into_pointer_value(), llvm_usize, None); - let right_val = - NDArrayValue::from_pointer_value(right_val.into_pointer_value(), llvm_usize, None); + let llvm_ndarray_dtype1 = ctx.get_llvm_type(generator, ndarray_dtype1); + let llvm_ndarray_dtype2 = ctx.get_llvm_type(generator, ndarray_dtype2); + + let left_val = NDArrayValue::from_pointer_value( + left_val.into_pointer_value(), + llvm_ndarray_dtype1, + llvm_usize, + None, + ); + let right_val = NDArrayValue::from_pointer_value( + right_val.into_pointer_value(), + llvm_ndarray_dtype2, + llvm_usize, + None, + ); let res = if op.base == Operator::MatMult { // MatMult is the only binop which is not an elementwise op @@ -1591,8 +1602,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( BinopVariant::Normal => None, BinopVariant::AugAssign => Some(left_val), }, - (left_val.as_base_value().into(), false), - (right_val.as_base_value().into(), false), + (ty1, left_val.as_base_value().into(), false), + (ty2, right_val.as_base_value().into(), false), |generator, ctx, (lhs, rhs)| { gen_binop_expr_with_values( generator, @@ -1616,8 +1627,10 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else { let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); + let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype); let ndarray_val = NDArrayValue::from_pointer_value( if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), + llvm_ndarray_dtype, llvm_usize, None, ); @@ -1629,8 +1642,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( BinopVariant::Normal => None, BinopVariant::AugAssign => Some(ndarray_val), }, - (left_val, !is_ndarray1), - (right_val, !is_ndarray2), + (ty1, left_val, !is_ndarray1), + (ty2, right_val, !is_ndarray2), |generator, ctx, (lhs, rhs)| { gen_binop_expr_with_values( generator, @@ -1810,8 +1823,14 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let llvm_ndarray_dtype = ctx.get_llvm_type(generator, ndarray_dtype); - let val = NDArrayValue::from_pointer_value(val.into_pointer_value(), llvm_usize, None); + let val = NDArrayValue::from_pointer_value( + val.into_pointer_value(), + llvm_ndarray_dtype, + llvm_usize, + None, + ); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function @@ -1902,15 +1921,21 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - let left_val = - NDArrayValue::from_pointer_value(lhs.into_pointer_value(), llvm_usize, None); + let llvm_ndarray_dtype1 = ctx.get_llvm_type(generator, ndarray_dtype1); + + let left_val = NDArrayValue::from_pointer_value( + lhs.into_pointer_value(), + llvm_ndarray_dtype1, + llvm_usize, + None, + ); let res = numpy::ndarray_elementwise_binop_impl( generator, ctx, ctx.primitives.bool, None, - (left_val.as_base_value().into(), false), - (rhs, false), + (left_ty, left_val.as_base_value().into(), false), + (right_ty, rhs, false), |generator, ctx, (lhs, rhs)| { let val = gen_cmpop_expr_with_values( generator, @@ -1941,8 +1966,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx, ctx.primitives.bool, None, - (lhs, !is_ndarray1), - (rhs, !is_ndarray2), + (left_ty, lhs, !is_ndarray1), + (right_ty, rhs, !is_ndarray2), |generator, ctx, (lhs, rhs)| { let val = gen_cmpop_expr_with_values( generator, @@ -2771,8 +2796,12 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( // elements over let subscripted_ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - let ndarray = - NDArrayValue::from_pointer_value(subscripted_ndarray, llvm_usize, None); + let ndarray = NDArrayValue::from_pointer_value( + subscripted_ndarray, + llvm_ndarray_data_t, + llvm_usize, + None, + ); let num_dims = v.load_ndims(ctx); ndarray.store_ndims( @@ -3510,6 +3539,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); + let llvm_ty = ctx.get_llvm_type(generator, *ty); let v = if let Some(v) = generator.gen_expr(ctx, value)? { v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? @@ -3517,7 +3547,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { return Ok(None); }; - let v = NDArrayValue::from_pointer_value(v, usize, None); + let v = NDArrayValue::from_pointer_value(v, llvm_ty, usize, None); return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); } diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 4589ba4..5db4ac2 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -26,7 +26,7 @@ use super::{ use crate::{ symbol_resolver::ValueEnum, toplevel::{ - helper::PrimDef, + helper::{arraylike_flatten_element_type, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, @@ -42,6 +42,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, elem_ty: Type, ) -> Result, String> { + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -54,7 +55,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - Ok(NDArrayValue::from_pointer_value(ndarray, llvm_usize, None)) + Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None)) } /// Creates an `NDArray` instance from a dynamic shape. @@ -473,8 +474,8 @@ fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, 'a>, res: NDArrayValue<'ctx>, - lhs: (BasicValueEnum<'ctx>, bool), - rhs: (BasicValueEnum<'ctx>, bool), + lhs: (Type, BasicValueEnum<'ctx>, bool), + rhs: (Type, BasicValueEnum<'ctx>, bool), value_fn: ValueFn, ) -> Result, String> where @@ -487,8 +488,8 @@ where { let llvm_usize = generator.get_size_type(ctx.ctx); - let (lhs_val, lhs_scalar) = lhs; - let (rhs_val, rhs_scalar) = rhs; + let (lhs_ty, lhs_val, lhs_scalar) = lhs; + let (rhs_ty, rhs_val, rhs_scalar) = rhs; assert!( !(lhs_scalar && rhs_scalar), @@ -499,14 +500,26 @@ where // Assert that all ndarray operands are broadcastable to the target size if !lhs_scalar { - let lhs_val = - NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None); + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty); + let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype); + let lhs_val = NDArrayValue::from_pointer_value( + lhs_val.into_pointer_value(), + llvm_lhs_elem_ty, + llvm_usize, + None, + ); ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); } if !rhs_scalar { - let rhs_val = - NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty); + let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype); + let rhs_val = NDArrayValue::from_pointer_value( + rhs_val.into_pointer_value(), + llvm_rhs_elem_ty, + llvm_usize, + None, + ); ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); } @@ -514,8 +527,14 @@ where let lhs_elem = if lhs_scalar { lhs_val } else { - let lhs = - NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None); + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty); + let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype); + let lhs = NDArrayValue::from_pointer_value( + lhs_val.into_pointer_value(), + llvm_lhs_elem_ty, + llvm_usize, + None, + ); let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } @@ -524,8 +543,14 @@ where let rhs_elem = if rhs_scalar { rhs_val } else { - let rhs = - NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty); + let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype); + let rhs = NDArrayValue::from_pointer_value( + rhs_val.into_pointer_value(), + llvm_rhs_elem_ty, + llvm_usize, + None, + ); let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } @@ -671,7 +696,7 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, + (ty, value): (Type, BasicValueEnum<'ctx>), ) -> IntValue<'ctx> { let llvm_usize = generator.get_size_type(ctx.ctx); @@ -679,7 +704,9 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( BasicValueEnum::PointerValue(v) if NDArrayValue::is_representable(v, llvm_usize).is_ok() => { - NDArrayValue::from_pointer_value(v, llvm_usize, None).load_ndims(ctx) + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty); + let llvm_elem_ty = ctx.get_llvm_type(generator, dtype); + NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx) } BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => { @@ -694,7 +721,6 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), src_lst: ListValue<'ctx>, dim: u64, @@ -727,6 +753,20 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( |_, _| Ok(llvm_usize.const_int(1, false)), |generator, ctx, _, i| { let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); + let offset = ctx + .builder + .build_int_mul( + offset, + ctx.builder + .build_int_truncate_or_bit_cast( + dst_arr.get_type().element_type().size_of().unwrap(), + offset.get_type(), + "", + ) + .unwrap(), + "", + ) + .unwrap(); let dst_ptr = unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; @@ -741,7 +781,6 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_from_ndlist_impl( generator, ctx, - elem_ty, (dst_arr, dst_ptr), nested_lst_elem, dim + 1, @@ -760,7 +799,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( _ => { let lst_len = src_lst.load_size(ctx, None); - let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); + let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap(); let cpy_len = ctx @@ -816,7 +855,8 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims if NDArrayValue::is_representable(object, llvm_usize).is_ok() { - let object = NDArrayValue::from_pointer_value(object, llvm_usize, None); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None); let ndarray = gen_if_else_expr_callback( generator, @@ -878,7 +918,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copyto_impl( generator, ctx, - elem_ty, (ndarray, ndarray.data().base_ptr(ctx, generator)), (object, object.data().base_ptr(ctx, generator)), 0, @@ -892,6 +931,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( return Ok(NDArrayValue::from_pointer_value( ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), + llvm_elem_ty, llvm_usize, None, )); @@ -1026,7 +1066,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_from_ndlist_impl( generator, ctx, - elem_ty, (ndarray, ndarray.data().base_ptr(ctx, generator)), object, 0, @@ -1099,7 +1138,6 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), dim: u64, @@ -1108,10 +1146,12 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( let llvm_i1 = ctx.ctx.bool_type(); let llvm_usize = generator.get_size_type(ctx.ctx); + assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type()); + + let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap(); + // If there are no (remaining) slice expressions, memcpy the entire dimension if slices.is_empty() { - let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); - let stride = call_ndarray_calc_size( generator, ctx, @@ -1162,9 +1202,29 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( |generator, ctx, _, src_i| { // Calculate the offset of the active slice let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); + let src_data_offset = ctx + .builder + .build_int_mul( + src_data_offset, + ctx.builder + .build_int_cast(sizeof_elem, src_data_offset.get_type(), "") + .unwrap(), + "", + ) + .unwrap(); let dst_i = ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); + let dst_data_offset = ctx + .builder + .build_int_mul( + dst_data_offset, + ctx.builder + .build_int_cast(sizeof_elem, dst_data_offset.get_type(), "") + .unwrap(), + "", + ) + .unwrap(); let (src_ptr, dst_ptr) = unsafe { ( @@ -1176,7 +1236,6 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copyto_impl( generator, ctx, - elem_ty, (dst_arr, dst_ptr), (src_arr, src_ptr), dim + 1, @@ -1293,7 +1352,6 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( ndarray_sliced_copyto_impl( generator, ctx, - elem_ty, (ndarray, ndarray.data().base_ptr(ctx, generator)), (this, this.data().base_ptr(ctx, generator)), 0, @@ -1376,8 +1434,8 @@ pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, res: Option>, - lhs: (BasicValueEnum<'ctx>, bool), - rhs: (BasicValueEnum<'ctx>, bool), + lhs: (Type, BasicValueEnum<'ctx>, bool), + rhs: (Type, BasicValueEnum<'ctx>, bool), value_fn: ValueFn, ) -> Result, String> where @@ -1390,8 +1448,8 @@ where { let llvm_usize = generator.get_size_type(ctx.ctx); - let (lhs_val, lhs_scalar) = lhs; - let (rhs_val, rhs_scalar) = rhs; + let (lhs_ty, lhs_val, lhs_scalar) = lhs; + let (rhs_ty, rhs_val, rhs_scalar) = rhs; assert!( !(lhs_scalar && rhs_scalar), @@ -1402,10 +1460,22 @@ where let ndarray = res.unwrap_or_else(|| { if lhs_scalar && rhs_scalar { - let lhs_val = - NDArrayValue::from_pointer_value(lhs_val.into_pointer_value(), llvm_usize, None); - let rhs_val = - NDArrayValue::from_pointer_value(rhs_val.into_pointer_value(), llvm_usize, None); + let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty); + let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype); + let lhs_val = NDArrayValue::from_pointer_value( + lhs_val.into_pointer_value(), + llvm_lhs_elem_ty, + llvm_usize, + None, + ); + let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty); + let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype); + let rhs_val = NDArrayValue::from_pointer_value( + rhs_val.into_pointer_value(), + llvm_rhs_elem_ty, + llvm_usize, + None, + ); let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); @@ -1421,8 +1491,14 @@ where ) .unwrap() } else { + let dtype = arraylike_flatten_element_type( + &mut ctx.unifier, + if lhs_scalar { rhs_ty } else { lhs_ty }, + ); + let llvm_elem_ty = ctx.get_llvm_type(generator, dtype); let ndarray = NDArrayValue::from_pointer_value( if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), + llvm_elem_ty, llvm_usize, None, ); @@ -1981,11 +2057,18 @@ pub fn gen_ndarray_copy<'ctx>( let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; + let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty); + ndarray_copy_impl( generator, context, this_elem_ty, - NDArrayValue::from_pointer_value(this_arg.into_pointer_value(), llvm_usize, None), + NDArrayValue::from_pointer_value( + this_arg.into_pointer_value(), + llvm_elem_ty, + llvm_usize, + None, + ), ) .map(NDArrayValue::into) } @@ -2004,6 +2087,7 @@ pub fn gen_ndarray_fill<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let this_ty = obj.as_ref().unwrap().0; + let this_elem_ty = arraylike_flatten_element_type(&mut context.unifier, this_ty); let this_arg = obj .as_ref() .unwrap() @@ -2014,10 +2098,12 @@ pub fn gen_ndarray_fill<'ctx>( let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; + let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty); + ndarray_fill_flattened( generator, context, - NDArrayValue::from_pointer_value(this_arg, llvm_usize, None), + NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None), |generator, ctx, _| { let value = if value_arg.is_pointer_value() { let llvm_i1 = ctx.ctx.bool_type(); @@ -2058,7 +2144,8 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( if let BasicValueEnum::PointerValue(n1) = x1 { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); // Dimensions are reversed in the transposed array @@ -2177,7 +2264,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( if let BasicValueEnum::PointerValue(n1) = x1 { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); + let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; @@ -2454,14 +2542,19 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; let (x1_ty, x1) = x1; - let (_, x2) = x2; + let (x2_ty, x2) = x2; let llvm_usize = generator.get_size_type(ctx.ctx); match (x1, x2) { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayValue::from_pointer_value(n1, llvm_usize, None); - let n2 = NDArrayValue::from_pointer_value(n2, llvm_usize, None); + let n1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); + let n2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty); + let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype); + let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype); + + let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None); + let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None); let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); @@ -2501,7 +2594,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( .build_float_mul(e1, elem2.into_float_value(), "") .unwrap() .as_basic_value_enum(), - _ => codegen_unreachable!(ctx), + _ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()), }; let acc_val = ctx.builder.build_load(acc, "").unwrap(); let acc_val = match acc_val { @@ -2515,7 +2608,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( .build_float_add(e1, product.into_float_value(), "") .unwrap() .as_basic_value_enum(), - _ => codegen_unreachable!(ctx), + _ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()), }; ctx.builder.build_store(acc, acc_val).unwrap(); diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index 3f25f82..d688732 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -67,21 +67,29 @@ impl<'ctx> NDArrayType<'ctx> { } let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); - let Ok(_) = PointerType::try_from(ndarray_data_ty) else { + let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else { return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")); }; + let ndarray_data = ndarray_pdata.get_element_type(); + let Ok(ndarray_data) = IntType::try_from(ndarray_data) else { + return Err(format!( + "Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}" + )); + }; + if ndarray_data.get_bit_width() != 8 { + return Err(format!( + "Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int", + ndarray_data.get_bit_width() + )); + } Ok(()) } /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] - fn llvm_type( - ctx: &'ctx Context, - dtype: BasicTypeEnum<'ctx>, - llvm_usize: IntType<'ctx>, - ) -> PointerType<'ctx> { - // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } + fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { + // struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } // // * num_dims: Number of dimensions in the array // * dims: Pointer to an array containing the size of each dimension @@ -89,13 +97,13 @@ impl<'ctx> NDArrayType<'ctx> { let field_tys = [ llvm_usize.into(), llvm_usize.ptr_type(AddressSpace::default()).into(), - dtype.ptr_type(AddressSpace::default()).into(), + ctx.i8_type().ptr_type(AddressSpace::default()).into(), ]; ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`ListType`]. + /// Creates an instance of [`NDArrayType`]. #[must_use] pub fn new( generator: &G, @@ -103,24 +111,21 @@ impl<'ctx> NDArrayType<'ctx> { dtype: BasicTypeEnum<'ctx>, ) -> Self { let llvm_usize = generator.get_size_type(ctx); - let llvm_ndarray = Self::llvm_type(ctx, dtype, llvm_usize); + let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - NDArrayType::from_type(llvm_ndarray, llvm_usize) + NDArrayType { ty: llvm_ndarray, dtype, llvm_usize } } - /// Creates an [`NDArrayType`] from a [`PointerType`]. + /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + pub fn from_type( + ptr_ty: PointerType<'ctx>, + dtype: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); - NDArrayType { - ty: ptr_ty, - dtype: ptr_ty - .get_element_type() - .try_into() - .expect("Expected BasicTypeEnum for dtype of NDArray"), - llvm_usize, - } + NDArrayType { ty: ptr_ty, dtype, llvm_usize } } /// Returns the type of the `size` field of this `ndarray` type. @@ -207,7 +212,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { ) -> Self::Value { debug_assert_eq!(value.get_type(), self.as_base_type()); - NDArrayValue::from_pointer_value(value, self.llvm_usize, name) + NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name) } fn as_base_type(&self) -> Self::Base { diff --git a/nac3core/src/codegen/values/ndarray.rs b/nac3core/src/codegen/values/ndarray.rs index 908ad2f..42c42f0 100644 --- a/nac3core/src/codegen/values/ndarray.rs +++ b/nac3core/src/codegen/values/ndarray.rs @@ -1,7 +1,7 @@ use inkwell::{ - types::{AnyTypeEnum, BasicTypeEnum, IntType}, + types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType}, values::{BasicValueEnum, IntValue, PointerValue}, - IntPredicate, + AddressSpace, IntPredicate, }; use super::{ @@ -20,6 +20,7 @@ use crate::codegen::{ #[derive(Copy, Clone)] pub struct NDArrayValue<'ctx> { value: PointerValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, } @@ -38,12 +39,13 @@ impl<'ctx> NDArrayValue<'ctx> { #[must_use] pub fn from_pointer_value( ptr: PointerValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>, ) -> Self { debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); - NDArrayValue { value: ptr, llvm_usize, name } + NDArrayValue { value: ptr, dtype, llvm_usize, name } } /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. @@ -138,6 +140,10 @@ impl<'ctx> NDArrayValue<'ctx> { /// Stores the array of data elements `data` into this instance. fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { + let data = ctx + .builder + .build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") + .unwrap(); ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); } @@ -149,7 +155,15 @@ impl<'ctx> NDArrayValue<'ctx> { elem_ty: BasicTypeEnum<'ctx>, size: IntValue<'ctx>, ) { - self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "").unwrap()); + let itemsize = + ctx.builder.build_int_cast(elem_ty.size_of().unwrap(), size.get_type(), "").unwrap(); + let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap(); + + // TODO: What about alignment? + self.store_data( + ctx, + ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap(), + ); } /// Returns a proxy object to the field storing the data of this `NDArray`. @@ -164,7 +178,7 @@ impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { type Type = NDArrayType<'ctx>; fn get_type(&self) -> Self::Type { - NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize) + NDArrayType::from_type(self.as_base_value().get_type(), self.dtype, self.llvm_usize) } fn as_base_value(&self) -> Self::Base { @@ -282,10 +296,10 @@ pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { fn element_type( &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, + _: &CodeGenContext<'ctx, '_>, + _: &G, ) -> AnyTypeEnum<'ctx> { - self.0.data().base_ptr(ctx, generator).get_type().get_element_type() + self.0.dtype.as_any_type_enum() } fn base_ptr( @@ -318,15 +332,37 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { idx: &IntValue<'ctx>, name: Option<&str>, ) -> PointerValue<'ctx> { - unsafe { + let sizeof_elem = ctx + .builder + .build_int_truncate_or_bit_cast( + self.element_type(ctx, generator).size_of().unwrap(), + idx.get_type(), + "", + ) + .unwrap(); + let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap(); + let ptr = unsafe { ctx.builder .build_in_bounds_gep( self.base_ptr(ctx, generator), - &[*idx], + &[idx], name.unwrap_or_default(), ) .unwrap() - } + }; + + // Current implementation is transparent - The returned pointer type is + // already cast into the expected type, allowing for immediately + // load/store. + ctx.builder + .build_pointer_cast( + ptr, + BasicTypeEnum::try_from(self.element_type(ctx, generator)) + .unwrap() + .ptr_type(AddressSpace::default()), + "", + ) + .unwrap() } fn ptr_offset( @@ -347,7 +383,20 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { ctx.current_loc, ); - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } + let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }; + + // Current implementation is transparent - The returned pointer type is + // already cast into the expected type, allowing for immediately + // load/store. + ctx.builder + .build_pointer_cast( + ptr, + BasicTypeEnum::try_from(self.element_type(ctx, generator)) + .unwrap() + .ptr_type(AddressSpace::default()), + "", + ) + .unwrap() } } @@ -381,8 +430,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> ); let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); + let sizeof_elem = ctx + .builder + .build_int_truncate_or_bit_cast( + self.element_type(ctx, generator).size_of().unwrap(), + index.get_type(), + "", + ) + .unwrap(); + let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap(); - unsafe { + let ptr = unsafe { ctx.builder .build_in_bounds_gep( self.base_ptr(ctx, generator), @@ -390,7 +448,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> name.unwrap_or_default(), ) .unwrap() - } + }; + // TODO: Current implementation is transparent + ctx.builder + .build_pointer_cast( + ptr, + BasicTypeEnum::try_from(self.element_type(ctx, generator)) + .unwrap() + .ptr_type(AddressSpace::default()), + "", + ) + .unwrap() } fn ptr_offset( @@ -455,7 +523,17 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> ) .unwrap(); - unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) } + let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) }; + // TODO: Current implementation is transparent + ctx.builder + .build_pointer_cast( + ptr, + BasicTypeEnum::try_from(self.element_type(ctx, generator)) + .unwrap() + .ptr_type(AddressSpace::default()), + "", + ) + .unwrap() } } diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 577ad9c..d42f3b9 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -144,6 +144,7 @@ def test_ndarray_array(): # Copy n2_cpy: ndarray[float, 2] = np_array(n2, copy=False) + output_ndarray_float_2(n2_cpy) n2_cpy.fill(0.0) output_ndarray_float_2(n2_cpy) @@ -1756,7 +1757,7 @@ def run() -> int32: test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_transpose() test_ndarray_reshape() - + test_ndarray_dot() test_ndarray_cholesky() test_ndarray_qr() -- 2.44.2 From 144f0922dbee2c533f55ac9e0a296982d51cc3a2 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Nov 2024 15:53:29 +0800 Subject: [PATCH 16/16] [core] coregen/types: Implement StructFields for NDArray Also rename some fields to better align with their naming in numpy. --- nac3artiq/src/codegen.rs | 10 +-- nac3core/src/codegen/builtin_fns.rs | 43 ++++++------- nac3core/src/codegen/expr.rs | 12 ++-- nac3core/src/codegen/irrt/ndarray.rs | 14 ++--- nac3core/src/codegen/numpy.rs | 73 ++++++++++------------ nac3core/src/codegen/types/ndarray.rs | 51 ++++++++++++--- nac3core/src/codegen/values/ndarray.rs | 86 ++++++++++---------------- 7 files changed, 143 insertions(+), 146 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index aece926..9999879 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -498,7 +498,7 @@ fn format_rpc_arg<'ctx>( call_memcpy_generic( ctx, pbuffer_dims_begin, - llvm_arg.dim_sizes().base_ptr(ctx, generator), + llvm_arg.shape().base_ptr(ctx, generator), dims_buf_sz, llvm_i1.const_zero(), ); @@ -612,7 +612,7 @@ fn format_rpc_ret<'ctx>( // Set `ndarray.ndims` ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); // Allocate `ndarray.shape` [size_t; ndims] - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx)); + ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx)); /* ndarray now: @@ -702,7 +702,7 @@ fn format_rpc_ret<'ctx>( call_memcpy_generic( ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), + ndarray.shape().base_ptr(ctx, generator), pbuffer_dims, sizeof_dims, llvm_i1.const_zero(), @@ -714,7 +714,7 @@ fn format_rpc_ret<'ctx>( // `ndarray.shape` must be initialized beforehand in this implementation // (for ndarray.create_data() to know how many elements to allocate) let num_elements = - call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)); + call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None)); // debug_assert(nelems * sizeof(T) >= ndarray_nbytes) if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { @@ -1379,7 +1379,7 @@ fn polymorphic_print<'ctx>( llvm_usize, None, ); - let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); + let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None)); let last = ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index e693faf..20d3500 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -78,7 +78,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( None, ); - let ndims = arg.dim_sizes().size(ctx, generator); + let ndims = arg.shape().size(ctx, generator); ctx.make_assert( generator, ctx.builder @@ -91,12 +91,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( ); let len = unsafe { - arg.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) + arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() @@ -927,7 +922,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n = NDArrayValue::from_pointer_value(n, llvm_elem_ty, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); + let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let n_sz_eqz = ctx .builder @@ -1981,12 +1976,12 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2023,12 +2018,12 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2074,12 +2069,12 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2128,12 +2123,12 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2171,12 +2166,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2214,12 +2209,12 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let dim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2284,12 +2279,12 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( let n2_array = n2_array.as_base_value().as_basic_value_enum(); let outdim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; let outdim1 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) .into_int_value() }; @@ -2362,7 +2357,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; @@ -2405,7 +2400,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, llvm_usize, None); let dim0 = unsafe { - n1.dim_sizes() + n1.shape() .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) .into_int_value() }; diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 01047b3..f8462ab 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2631,7 +2631,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( let llvm_i32 = ctx.ctx.i32_type(); let len = unsafe { - v.dim_sizes().get_typed_unchecked( + v.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(dim, true), @@ -2672,7 +2672,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ExprKind::Slice { lower, upper, step } => { let dim_sz = unsafe { - v.dim_sizes().get_typed_unchecked( + v.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(dim, false), @@ -2813,7 +2813,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ); let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); + ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); let ndarray_num_dims = ctx .builder @@ -2824,7 +2824,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( ) .unwrap(); let v_dims_src_ptr = unsafe { - v.dim_sizes().ptr_offset_unchecked( + v.shape().ptr_offset_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -2833,7 +2833,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( }; call_memcpy_generic( ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), + ndarray.shape().base_ptr(ctx, generator), v_dims_src_ptr, ctx.builder .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") @@ -2845,7 +2845,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), + &ndarray.shape().as_slice_value(ctx, generator), (None, None), ); let ndarray_num_elems = ctx diff --git a/nac3core/src/codegen/irrt/ndarray.rs b/nac3core/src/codegen/irrt/ndarray.rs index bfec1d5..541c117 100644 --- a/nac3core/src/codegen/irrt/ndarray.rs +++ b/nac3core/src/codegen/irrt/ndarray.rs @@ -103,7 +103,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( }); let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); + let ndarray_dims = ndarray.shape(); let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); @@ -172,7 +172,7 @@ where }); let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); + let ndarray_dims = ndarray.shape(); let index = ctx .builder @@ -259,8 +259,8 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); let (lhs_dim_sz, rhs_dim_sz) = unsafe { ( - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), + lhs.shape().get_typed_unchecked(ctx, generator, &idx, None), + rhs.shape().get_typed_unchecked(ctx, generator, &idx, None), ) }; @@ -298,9 +298,9 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( .unwrap(); let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); + let lhs_dims = lhs.shape().base_ptr(ctx, generator); let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); + let rhs_dims = rhs.shape().base_ptr(ctx, generator); let rhs_ndims = rhs.load_ndims(ctx); let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); @@ -362,7 +362,7 @@ pub fn call_ndarray_calc_broadcast_index< let broadcast_size = broadcast_idx.size(ctx, generator); let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - let array_dims = array.dim_sizes().base_ptr(ctx, generator); + let array_dims = array.shape().base_ptr(ctx, generator); let array_ndims = array.load_ndims(ctx); let broadcast_idx_ptr = unsafe { broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 5db4ac2..58869f2 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -128,7 +128,7 @@ where ndarray.store_ndims(ctx, generator, num_dims); let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); + ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); // Copy the dimension sizes from shape to ndarray.dims let shape_len = shape_len_fn(generator, ctx, shape)?; @@ -144,7 +144,7 @@ where let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let ndarray_pdim = - unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) }; + unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) }; ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); @@ -195,12 +195,12 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( ndarray.store_ndims(ctx, generator, num_dims); let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); + ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims); for (i, &shape_dim) in shape.iter().enumerate() { let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let ndarray_dim = unsafe { - ndarray.dim_sizes().ptr_offset_unchecked( + ndarray.shape().ptr_offset_unchecked( ctx, generator, &llvm_usize.const_int(i as u64, true), @@ -229,7 +229,7 @@ fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>( let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), + &ndarray.shape().as_slice_value(ctx, generator), (None, None), ); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); @@ -380,7 +380,7 @@ where let ndarray_num_elems = call_ndarray_calc_size( generator, ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), + &ndarray.shape().as_slice_value(ctx, generator), (None, None), ); @@ -739,7 +739,7 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( let stride = call_ndarray_calc_size( generator, ctx, - &dst_arr.dim_sizes(), + &dst_arr.shape(), (Some(llvm_usize.const_int(dim + 1, false)), None), ); @@ -1155,7 +1155,7 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( let stride = call_ndarray_calc_size( generator, ctx, - &src_arr.dim_sizes(), + &src_arr.shape(), (Some(llvm_usize.const_int(dim, false)), None), ); let stride = @@ -1173,13 +1173,13 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( let src_stride = call_ndarray_calc_size( generator, ctx, - &src_arr.dim_sizes(), + &src_arr.shape(), (Some(llvm_usize.const_int(dim + 1, false)), None), ); let dst_stride = call_ndarray_calc_size( generator, ctx, - &dst_arr.dim_sizes(), + &dst_arr.shape(), (Some(llvm_usize.const_int(dim + 1, false)), None), ); @@ -1278,7 +1278,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( &this, |_, ctx, shape| Ok(shape.load_ndims(ctx)), |generator, ctx, shape, idx| unsafe { - Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) + Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None)) }, )? } else { @@ -1286,7 +1286,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); let ndims = this.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndims); + ndarray.create_shape(ctx, llvm_usize, ndims); // Populate the first slices.len() dimensions by computing the size of each dim slice for (i, (start, stop, step)) in slices.iter().enumerate() { @@ -1318,7 +1318,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); unsafe { - ndarray.dim_sizes().set_typed_unchecked( + ndarray.shape().set_typed_unchecked( ctx, generator, &llvm_usize.const_int(i as u64, false), @@ -1336,8 +1336,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( (this.load_ndims(ctx), false), |generator, ctx, _, idx| { unsafe { - let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); - ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); + let dim_sz = this.shape().get_typed_unchecked(ctx, generator, &idx, None); + ndarray.shape().set_typed_unchecked(ctx, generator, &idx, dim_sz); } Ok(()) @@ -1397,7 +1397,7 @@ where &operand, |_, ctx, v| Ok(v.load_ndims(ctx)), |generator, ctx, v, idx| unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) + Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) }, ) .unwrap() @@ -1510,7 +1510,7 @@ where &ndarray, |_, ctx, v| Ok(v.load_ndims(ctx)), |generator, ctx, v, idx| unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) + Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None)) }, ) .unwrap() @@ -1571,10 +1571,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( if let Some(res) = res { let res_ndims = res.load_ndims(ctx); let res_dim0 = unsafe { - res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; let res_dim1 = unsafe { - res.dim_sizes().get_typed_unchecked( + res.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -1582,10 +1582,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( ) }; let lhs_dim0 = unsafe { - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; let rhs_dim1 = unsafe { - rhs.dim_sizes().get_typed_unchecked( + rhs.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -1634,15 +1634,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { let lhs_dim1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) + lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) }; let rhs_dim0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; // lhs.dims[1] == rhs.dims[0] @@ -1681,7 +1676,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( }, |generator, ctx| { Ok(Some(unsafe { - lhs.dim_sizes().get_typed_unchecked( + lhs.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_zero(), @@ -1691,7 +1686,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( }, |generator, ctx| { Ok(Some(unsafe { - rhs.dim_sizes().get_typed_unchecked( + rhs.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -1718,7 +1713,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( let common_dim = { let lhs_idx1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( + lhs.shape().get_typed_unchecked( ctx, generator, &llvm_usize.const_int(1, false), @@ -1726,7 +1721,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( ) }; let rhs_idx0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); @@ -2146,7 +2141,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); // Dimensions are reversed in the transposed array let out = create_ndarray_dyn_shape( @@ -2161,7 +2156,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( .builder .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") .unwrap(); - unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) } + unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) } }, ) .unwrap(); @@ -2198,7 +2193,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") .unwrap(); let dim = unsafe { - n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None) + n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None) }; let rem_idx_val = @@ -2266,7 +2261,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; @@ -2494,7 +2489,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( ); // The new shape must be compatible with the old shape - let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None)); + let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None)); ctx.make_assert( generator, ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), @@ -2556,8 +2551,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None); let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None); - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); + let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None)); ctx.make_assert( generator, diff --git a/nac3core/src/codegen/types/ndarray.rs b/nac3core/src/codegen/types/ndarray.rs index d688732..d149b91 100644 --- a/nac3core/src/codegen/types/ndarray.rs +++ b/nac3core/src/codegen/types/ndarray.rs @@ -1,11 +1,17 @@ use inkwell::{ context::Context, types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, - values::IntValue, + values::{IntValue, PointerValue}, AddressSpace, }; +use itertools::Itertools; -use super::ProxyType; +use nac3core_derive::StructFields; + +use super::{ + structure::{StructField, StructFields}, + ProxyType, +}; use crate::codegen::{ values::{ArraySliceValue, NDArrayValue, ProxyValue}, {CodeGenContext, CodeGenerator}, @@ -19,6 +25,16 @@ pub struct NDArrayType<'ctx> { llvm_usize: IntType<'ctx>, } +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct NDArrayStructFields<'ctx> { + #[value_type(usize)] + pub ndims: StructField<'ctx, IntValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + pub shape: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub data: StructField<'ctx, PointerValue<'ctx>>, +} + impl<'ctx> NDArrayType<'ctx> { /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. pub fn is_representable( @@ -86,19 +102,34 @@ impl<'ctx> NDArrayType<'ctx> { Ok(()) } + // TODO: Move this into e.g. StructProxyType + #[must_use] + fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> { + NDArrayStructFields::new(ctx, llvm_usize) + } + + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields( + &self, + ctx: &'ctx Context, + llvm_usize: IntType<'ctx>, + ) -> NDArrayStructFields<'ctx> { + Self::fields(ctx, llvm_usize) + } + /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. #[must_use] fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> { // struct NDArray { num_dims: size_t, dims: size_t*, data: i8* } // - // * num_dims: Number of dimensions in the array - // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data - let field_tys = [ - llvm_usize.into(), - llvm_usize.ptr_type(AddressSpace::default()).into(), - ctx.i8_type().ptr_type(AddressSpace::default()).into(), - ]; + // * data : Pointer to an array containing the array data + // * itemsize: The size of each NDArray elements in bytes + // * ndims : Number of dimensions in the array + // * shape : Pointer to an array containing the shape of the NDArray + // * strides : Pointer to an array indicating the number of bytes between each element at a dimension + let field_tys = + Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec(); ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } diff --git a/nac3core/src/codegen/values/ndarray.rs b/nac3core/src/codegen/values/ndarray.rs index 42c42f0..38c6a98 100644 --- a/nac3core/src/codegen/values/ndarray.rs +++ b/nac3core/src/codegen/values/ndarray.rs @@ -50,18 +50,10 @@ impl<'ctx> NDArrayValue<'ctx> { /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } + self.get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .ndims + .ptr_by_gep(ctx, self.value, self.name) } /// Stores the number of dimensions `ndims` into this instance. @@ -83,59 +75,43 @@ impl<'ctx> NDArrayValue<'ctx> { ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() } - /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` - /// on the field. - fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } + /// Returns the double-indirection pointer to the `shape` array, as if by calling + /// `getelementptr` on the field. + fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .shape + .ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of dimension sizes `dims` into this instance. - fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap(); + fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { + ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap(); } /// Convenience method for creating a new array storing dimension sizes with the given `size`. - pub fn create_dim_sizes( + pub fn create_shape( &self, ctx: &CodeGenContext<'ctx, '_>, llvm_usize: IntType<'ctx>, size: IntValue<'ctx>, ) { - self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); + self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); } /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. #[must_use] - pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> { - NDArrayDimsProxy(self) + pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> { + NDArrayShapeProxy(self) } /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// on the field. pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - var_name.as_str(), - ) - .unwrap() - } + self.get_type() + .get_fields(ctx.ctx, self.llvm_usize) + .data + .ptr_by_gep(ctx, self.value, self.name) } /// Stores the array of data elements `data` into this instance. @@ -194,15 +170,15 @@ impl<'ctx> From> for PointerValue<'ctx> { /// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. #[derive(Copy, Clone)] -pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); +pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); -impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { +impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> { fn element_type( &self, ctx: &CodeGenContext<'ctx, '_>, generator: &G, ) -> AnyTypeEnum<'ctx> { - self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type() + self.0.shape().base_ptr(ctx, generator).get_type().get_element_type() } fn base_ptr( @@ -213,7 +189,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); ctx.builder - .build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) + .build_load(self.0.ptr_to_shape(ctx), var_name.as_str()) .map(BasicValueEnum::into_pointer_value) .unwrap() } @@ -227,7 +203,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { } } -impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { +impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { unsafe fn ptr_offset_unchecked( &self, ctx: &mut CodeGenContext<'ctx, '_>, @@ -266,10 +242,10 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> } } -impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} +impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} +impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {} -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { +impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { fn downcast_to_type( &self, _: &mut CodeGenContext<'ctx, '_>, @@ -279,7 +255,7 @@ impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ct } } -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { +impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> { fn upcast_from_type( &self, _: &mut CodeGenContext<'ctx, '_>, @@ -497,7 +473,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> let (dim_idx, dim_sz) = unsafe { ( indices.get_unchecked(ctx, generator, &i, None).into_int_value(), - self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None), + self.0.shape().get_typed_unchecked(ctx, generator, &i, None), ) }; let dim_idx = ctx -- 2.44.2