diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..278e1fa9 --- /dev/null +++ b/.clang-format @@ -0,0 +1,3 @@ +BasedOnStyle: Microsoft +IndentWidth: 4 +ReflowComments: false \ No newline at end of file diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index d5a53b16..e4a0d30f 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1,12 +1,10 @@ use nac3core::{ codegen::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType, - NDArrayValue, RangeValue, UntypedArrayLikeAccessor, - }, + classes::{ListValue, RangeValue, UntypedArrayLikeAccessor}, expr::{destructure_range, gen_call}, - irrt::call_ndarray_calc_size, - llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave}, + llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave}, + model::*, + object::{any::AnyObject, ndarray::NDArrayObject}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, CodeGenContext, CodeGenerator, }, @@ -20,7 +18,7 @@ use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use inkwell::{ context::Context, module::Linkage, - types::{BasicType, IntType}, + types::IntType, values::{BasicValueEnum, PointerValue, StructValue}, AddressSpace, IntPredicate, }; @@ -456,58 +454,41 @@ fn format_rpc_arg<'ctx>( // NAC3: NDArray = { usize, usize*, T* } // libproto_artiq: NDArray = [data[..], dim_sz[..]] - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let ndarray = AnyObject { ty: arg_ty, value: arg }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let llvm_arg_ty = - NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty)); - let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); + let dtype = ctx.get_llvm_type(generator, ndarray.dtype); + let ndims = ndarray.ndims_llvm(generator, ctx.ctx); - let llvm_usize_sizeof = ctx - .builder - .build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "") - .unwrap(); - let llvm_pdata_sizeof = ctx - .builder - .build_int_truncate_or_bit_cast( - llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(), - llvm_usize, - "", - ) - .unwrap(); + // `ndarray.data` is possibly not contiguous, and we need it to be contiguous for + // the reader. + let carray = ndarray.make_contiguous_ndarray(generator, ctx, Any(dtype)); - let dims_buf_sz = - ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); + let sizeof_sizet = Int(SizeT).sizeof(generator, ctx.ctx); + let sizeof_sizet = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_sizet); - let buffer_size = - ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); + let sizeof_pdata = Ptr(Any(dtype)).sizeof(generator, ctx.ctx); + let sizeof_pdata = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_pdata); - let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); - let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); + let sizeof_buf_shape = sizeof_sizet.mul(ctx, ndims); + let sizeof_buf = sizeof_buf_shape.add(ctx, sizeof_pdata); - let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap(); - ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap(); + // buf = { data: void*, shape: [size_t; ndims]; } + let buf = Int(Byte).array_alloca(generator, ctx, sizeof_buf.value); + let buf_data = buf; + let buf_shape = buf_data.offset(ctx, sizeof_pdata.value); - call_memcpy_generic( - ctx, - buffer.base_ptr(ctx, generator), - ppdata, - llvm_pdata_sizeof, - llvm_i1.const_zero(), - ); + // Write to `buf->data` + let carray_data = carray.get(generator, ctx, |f| f.data); // has type Ptr + let carray_data = carray_data.pointer_cast(generator, ctx, Int(Byte)); + buf_data.copy_from(generator, ctx, carray_data, sizeof_pdata.value); - let pbuffer_dims_begin = - unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; - call_memcpy_generic( - ctx, - pbuffer_dims_begin, - llvm_arg.dim_sizes().base_ptr(ctx, generator), - dims_buf_sz, - llvm_i1.const_zero(), - ); + // Write to `buf->shape` + let carray_shape = ndarray.instance.get(generator, ctx, |f| f.shape); + let carray_shape_i8 = carray_shape.pointer_cast(generator, ctx, Int(Byte)); + buf_shape.copy_from(generator, ctx, carray_shape_i8, sizeof_buf_shape.value); - buffer.base_ptr(ctx, generator) + buf.value } _ => { @@ -1091,56 +1072,46 @@ fn polymorphic_print<'ctx>( } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - fmt.push_str("array(["); flush(ctx, generator, &mut fmt, &mut args); - let val = NDArrayValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None); - let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); - let last = - ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); + let ndarray = AnyObject { ty, value }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) }; + let num_0 = Int(SizeT).const_0(generator, ctx.ctx); - polymorphic_print( - ctx, - generator, - &[(elem_ty, elem.into())], - "", - None, - true, - as_rtio, - )?; + // Print `ndarray` as a flat list delimited by interspersed with ", \0" + ndarray.foreach(generator, ctx, |generator, ctx, _, hdl| { + let i = hdl.get_index(generator, ctx); + let scalar = hdl.get_scalar(generator, ctx); - gen_if_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::ULT, i, last, "") - .unwrap()) - }, - |generator, ctx| { - printf(ctx, generator, ", \0".into(), Vec::default()); + // if (i != 0) { puts(", "); } + gen_if_callback( + generator, + ctx, + |_, ctx| { + let not_first = i.compare(ctx, IntPredicate::NE, num_0); + Ok(not_first.value) + }, + |generator, ctx| { + printf(ctx, generator, ", \0".into(), Vec::default()); + Ok(()) + }, + |_, _| Ok(()), + )?; - Ok(()) - }, - |_, _| Ok(()), - )?; - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; + // Print element + polymorphic_print( + ctx, + generator, + &[(scalar.ty, scalar.value.into())], + "", + None, + true, + as_rtio, + )?; + Ok(()) + })?; fmt.push_str(")]"); flush(ctx, generator, &mut fmt, &mut args); diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index be2853c7..6b310b3d 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -33,6 +33,7 @@ use inkwell::{ OptimizationLevel, }; use itertools::Itertools; +use nac3core::codegen::irrt::setup_irrt_exceptions; use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions}; use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap}; @@ -557,6 +558,11 @@ impl Nac3 { .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false) .unwrap(); + // Process IRRT + let context = inkwell::context::Context::create(); + let irrt = load_irrt(&context); + setup_irrt_exceptions(&context, &irrt, resolver.as_ref()); + let fun_signature = FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() }; let mut store = ConcreteTypeStore::new(); @@ -727,7 +733,7 @@ impl Nac3 { membuffer.lock().push(buffer); }); - let context = inkwell::context::Context::create(); + // Link all modules into `main`. let buffers = membuffers.lock(); let main = context .create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main")) @@ -756,8 +762,7 @@ impl Nac3 { ) .unwrap(); - main.link_in_module(load_irrt(&context)) - .map_err(|err| CompileError::new_err(err.to_string()))?; + main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?; let mut function_iter = main.get_first_function(); while let Some(func) = function_iter { diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 9470ee71..21b1ac10 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1,14 +1,15 @@ use crate::PrimitivePythonId; use inkwell::{ module::Linkage, - types::{BasicType, BasicTypeEnum}, - values::BasicValueEnum, + types::BasicType, + values::{BasicValue, BasicValueEnum}, AddressSpace, }; use itertools::Itertools; use nac3core::{ codegen::{ - classes::{NDArrayType, ProxyType}, + model::*, + object::ndarray::{make_contiguous_strides, NDArray}, CodeGenContext, CodeGenerator, }, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, @@ -26,7 +27,7 @@ use nac3parser::ast::{self, StrRef}; use parking_lot::{Mutex, RwLock}; use pyo3::{ types::{PyDict, PyTuple}, - PyAny, PyObject, PyResult, Python, + PyAny, PyErr, PyObject, PyResult, Python, }; use std::{ collections::{HashMap, HashSet}, @@ -1086,15 +1087,12 @@ impl InnerResolver { let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); - let llvm_usize = generator.get_size_type(ctx.ctx); - let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); - let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty); - + let dtype = Any(ctx.get_llvm_type(generator, ndarray_dtype)); { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global( - ndarray_llvm_ty.as_underlying_type(), + Struct(NDArray).get_type(generator, ctx.ctx), Some(AddressSpace::default()), &id_str, ) @@ -1114,100 +1112,138 @@ impl InnerResolver { } else { todo!("Unpacking literal of more than one element unimplemented") }; - let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else { + let Ok(ndims) = u64::try_from(ndarray_ndims) else { unreachable!("Expected u64 value for ndarray_ndims") }; // Obtain the shape of the ndarray let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; - assert_eq!(shape_tuple.len(), ndarray_ndims as usize); - let shape_values: Result>, _> = shape_tuple + assert_eq!(shape_tuple.len(), ndims as usize); + + // The Rust type inferencer cannot figure this out + let shape_values: Result>>, PyErr> = shape_tuple .iter() .enumerate() .map(|(i, elem)| { - self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err( - |e| super::CompileError::new_err(format!("Error getting element {i}: {e}")), - ) + let value = self + .get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) + })? + .unwrap(); + let value = Int(SizeT).check_value(generator, ctx.ctx, value).unwrap(); + Ok(value) }) .collect(); - let shape_values = shape_values?.unwrap(); - let shape_values = llvm_usize.const_array( - &shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(), - ); + let shape_values = shape_values?; + + // Also use this opportunity to get the constant values of `shape_values` for calculating strides. + let shape_u64s = shape_values + .iter() + .map(|dim| { + assert!(dim.value.is_const()); + dim.value.get_zero_extended_constant().unwrap() + }) + .collect_vec(); + let shape_values = Int(SizeT).const_array(generator, ctx.ctx, &shape_values); // create a global for ndarray.shape and initialize it using the shape let shape_global = ctx.module.add_global( - llvm_usize.array_type(ndarray_ndims as u32), + Array { len: AnyLen(ndims as u32), item: Int(SizeT) }.get_type(generator, ctx.ctx), Some(AddressSpace::default()), &(id_str.clone() + ".shape"), ); - shape_global.set_initializer(&shape_values); + shape_global.set_initializer(&shape_values.value); // Obtain the (flattened) elements of the ndarray let sz: usize = obj.getattr("size")?.extract()?; - let data: Result>, _> = (0..sz) + let data_values: Vec> = (0..sz) .map(|i| { obj.getattr("flat")?.get_item(i).and_then(|elem| { - self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| { - super::CompileError::new_err(format!("Error getting element {i}: {e}")) - }) + let value = self + .get_obj_value(py, elem, ctx, generator, ndarray_dtype) + .map_err(|e| { + super::CompileError::new_err(format!( + "Error getting element {i}: {e}" + )) + })? + .unwrap(); + + let value = dtype.check_value(generator, ctx.ctx, value).unwrap(); + Ok(value) }) }) - .collect(); - let data = data?.unwrap().into_iter(); - let data = match ndarray_dtype_llvm_ty { - BasicTypeEnum::ArrayType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec()) - } - - BasicTypeEnum::FloatType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec()) - } - - BasicTypeEnum::IntType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec()) - } - - BasicTypeEnum::PointerType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec()) - } - - BasicTypeEnum::StructType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec()) - } - - BasicTypeEnum::VectorType(_) => unreachable!(), - }; + .try_collect()?; + let data = dtype.const_array(generator, ctx.ctx, &data_values); // create a global for ndarray.data and initialize it using the elements + // + // NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`. + // We will have to cast it to an `u8*` later. let data_global = ctx.module.add_global( - ndarray_dtype_llvm_ty.array_type(sz as u32), + Array { len: AnyLen(sz as u32), item: dtype }.get_type(generator, ctx.ctx), Some(AddressSpace::default()), &(id_str.clone() + ".data"), ); - data_global.set_initializer(&data); + data_global.set_initializer(&data.value); + + // Get the constant itemsize. + let itemsize = dtype.get_type(generator, ctx.ctx).size_of().unwrap(); + let itemsize = itemsize.get_zero_extended_constant().unwrap(); + + // Create the strides needed for ndarray.strides + let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s); + let strides = strides + .into_iter() + .map(|stride| Int(SizeT).const_int(generator, ctx.ctx, stride)) + .collect_vec(); + let strides = Int(SizeT).const_array(generator, ctx.ctx, &strides); + + // create a global for ndarray.strides and initialize it + let strides_global = ctx.module.add_global( + Array { len: AnyLen(ndims as u32), item: Int(Byte) }.get_type(generator, ctx.ctx), + Some(AddressSpace::default()), + &(id_str.clone() + ".strides"), + ); + strides_global.set_initializer(&strides.value); // create a global for the ndarray object and initialize it - let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[ - llvm_usize.const_int(ndarray_ndims, false).into(), - shape_global - .as_pointer_value() - .const_cast(llvm_usize.ptr_type(AddressSpace::default())) - .into(), - data_global - .as_pointer_value() - .const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default())) - .into(), - ]); + // We are also doing [`Model::check_value`] instead of [`Model::believe_value`] to catch bugs. - let ndarray = ctx.module.add_global( - ndarray_llvm_ty.as_underlying_type(), + // NOTE: data_global is an array of dtype, we want a `u8*`. + let ndarray_data = Ptr(dtype).check_value(generator, ctx.ctx, data_global).unwrap(); + let ndarray_data = Ptr(Int(Byte)).pointer_cast(generator, ctx, ndarray_data.value); + + let ndarray_itemsize = Int(SizeT).const_int(generator, ctx.ctx, itemsize); + + let ndarray_ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims); + + let ndarray_shape = + Ptr(Int(SizeT)).check_value(generator, ctx.ctx, shape_global).unwrap(); + + let ndarray_strides = + Ptr(Int(SizeT)).check_value(generator, ctx.ctx, strides_global).unwrap(); + + let ndarray = Struct(NDArray).const_struct( + generator, + ctx.ctx, + &[ + ndarray_data.value.as_basic_value_enum(), + ndarray_itemsize.value.as_basic_value_enum(), + ndarray_ndims.value.as_basic_value_enum(), + ndarray_shape.value.as_basic_value_enum(), + ndarray_strides.value.as_basic_value_enum(), + ], + ); + + let ndarray_global = ctx.module.add_global( + Struct(NDArray).get_type(generator, ctx.ctx), Some(AddressSpace::default()), &id_str, ); - ndarray.set_initializer(&value); + ndarray_global.set_initializer(&ndarray.value); - Ok(Some(ndarray.as_pointer_value().into())) + Ok(Some(ndarray_global.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { diff --git a/nac3core/build.rs b/nac3core/build.rs index 38e3382f..b56803dd 100644 --- a/nac3core/build.rs +++ b/nac3core/build.rs @@ -8,37 +8,49 @@ use std::{ }; fn main() { - const FILE: &str = "src/codegen/irrt/irrt.cpp"; + let out_dir = env::var("OUT_DIR").unwrap(); + let out_dir = Path::new(&out_dir); + let irrt_dir = Path::new("irrt"); + + let irrt_cpp_path = irrt_dir.join("irrt.cpp"); /* * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * Compiling for WASM32 and filtering the output with regex is the closest we can get. */ - let flags: &[&str] = &[ + let mut flags: Vec<&str> = vec![ "--target=wasm32", - FILE, "-x", "c++", "-fno-discard-value-names", "-fno-exceptions", "-fno-rtti", - match env::var("PROFILE").as_deref() { - Ok("debug") => "-O0", - Ok("release") => "-O3", - flavor => panic!("Unknown or missing build flavor {flavor:?}"), - }, "-emit-llvm", "-S", "-Wall", "-Wextra", "-o", "-", + "-I", + irrt_dir.to_str().unwrap(), + irrt_cpp_path.to_str().unwrap(), ]; - println!("cargo:rerun-if-changed={FILE}"); - let out_dir = env::var("OUT_DIR").unwrap(); - let out_path = Path::new(&out_dir); + match env::var("PROFILE").as_deref() { + Ok("debug") => { + flags.push("-O0"); + flags.push("-DIRRT_DEBUG_ASSERT"); + } + Ok("release") => { + flags.push("-O3"); + } + flavor => panic!("Unknown or missing build flavor {flavor:?}"), + } + // Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes + println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap()); + + // Compile IRRT and capture the LLVM IR output let output = Command::new("clang-irrt") .args(flags) .output() @@ -52,7 +64,17 @@ fn main() { let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n"); let mut filtered_output = String::with_capacity(output.len()); - let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap(); + // Filter out irrelevant IR + // + // Regex: + // - `(?ms:^define.*?\}$)` captures LLVM `define` blocks + // - `(?m:^declare.*?$)` captures LLVM `declare` lines + // - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations + // - `(?m:^@.+?=.+$)` captures global constants + let regex_filter = Regex::new( + r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)", + ) + .unwrap(); for f in regex_filter.captures_iter(&output) { assert_eq!(f.len(), 1); filtered_output.push_str(&f[0]); @@ -63,18 +85,22 @@ fn main() { .unwrap() .replace_all(&filtered_output, ""); - println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT"); - if env::var("DEBUG_DUMP_IRRT").is_ok() { - let mut file = File::create(out_path.join("irrt.ll")).unwrap(); + // For debugging + // Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated + const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT"; + println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}"); + if env::var(DEBUG_DUMP_IRRT).is_ok() { + let mut file = File::create(out_dir.join("irrt.ll")).unwrap(); file.write_all(output.as_bytes()).unwrap(); - let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap(); + + let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap(); file.write_all(filtered_output.as_bytes()).unwrap(); } let mut llvm_as = Command::new("llvm-as-irrt") .stdin(Stdio::piped()) .arg("-o") - .arg(out_path.join("irrt.bc")) + .arg(out_dir.join("irrt.bc")) .spawn() .unwrap(); llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp new file mode 100644 index 00000000..cd434f85 --- /dev/null +++ b/nac3core/irrt/irrt.cpp @@ -0,0 +1,16 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include \ No newline at end of file diff --git a/nac3core/irrt/irrt/cslice.hpp b/nac3core/irrt/irrt/cslice.hpp new file mode 100644 index 00000000..06b3fc2f --- /dev/null +++ b/nac3core/irrt/irrt/cslice.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include + +template struct CSlice +{ + uint8_t *base; + SizeT len; +}; \ No newline at end of file diff --git a/nac3core/irrt/irrt/cstr_util.hpp b/nac3core/irrt/irrt/cstr_util.hpp new file mode 100644 index 00000000..cf6ed34d --- /dev/null +++ b/nac3core/irrt/irrt/cstr_util.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace cstr +{ +/** + * @brief Implementation of `strlen()`. + */ +uint32_t length(const char *str) +{ + uint32_t length = 0; + while (*str != '\0') + { + length++; + str++; + } + return length; +} +} // namespace cstr \ No newline at end of file diff --git a/nac3core/irrt/irrt/debug.hpp b/nac3core/irrt/irrt/debug.hpp new file mode 100644 index 00000000..f5baea7a --- /dev/null +++ b/nac3core/irrt/irrt/debug.hpp @@ -0,0 +1,23 @@ +#pragma once + +// Set in nac3core/build.rs +#ifdef IRRT_DEBUG_ASSERT +#define IRRT_DEBUG_ASSERT_BOOL true +#else +#define IRRT_DEBUG_ASSERT_BOOL false +#endif + +#define raise_debug_assert(SizeT, msg, param1, param2, param3) \ + raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3); + +#define debug_assert_eq(SizeT, lhs, rhs) \ + if (IRRT_DEBUG_ASSERT_BOOL && (lhs) != (rhs)) \ + { \ + raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \ + } + +#define debug_assert(SizeT, expr) \ + if (IRRT_DEBUG_ASSERT_BOOL && !(expr)) \ + { \ + raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \ + } \ No newline at end of file diff --git a/nac3core/irrt/irrt/exception.hpp b/nac3core/irrt/irrt/exception.hpp new file mode 100644 index 00000000..22eea6ae --- /dev/null +++ b/nac3core/irrt/irrt/exception.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include + +/** + * @brief The int type of ARTIQ exception IDs. + */ +typedef int32_t ExceptionId; + +/* + * Set of exceptions C++ IRRT can use. + * Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`. + */ +extern "C" +{ + ExceptionId EXN_INDEX_ERROR; + ExceptionId EXN_VALUE_ERROR; + ExceptionId EXN_ASSERTION_ERROR; + ExceptionId EXN_TYPE_ERROR; +} + +/** + * @brief Extern function to `__nac3_raise` + * + * The parameter `err` could be `Exception` or `Exception`. The caller + * must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime. + */ +extern "C" void __nac3_raise(void *err); + +namespace +{ +/** + * @brief NAC3's Exception struct + */ +template struct Exception +{ + ExceptionId id; + CSlice filename; + int32_t line; + int32_t column; + CSlice function; + CSlice msg; + int64_t params[3]; +}; + +const int64_t NO_PARAM = 0; + +template +void _raise_exception_helper(ExceptionId id, const char *filename, int32_t line, const char *function, const char *msg, + int64_t param0, int64_t param1, int64_t param2) +{ + Exception e = { + .id = id, + .filename = {.base = (uint8_t *)filename, .len = (int32_t)cstr::length(filename)}, + .line = line, + .column = 0, + .function = {.base = (uint8_t *)function, .len = (int32_t)cstr::length(function)}, + .msg = {.base = (uint8_t *)msg, .len = (int32_t)cstr::length(msg)}, + }; + e.params[0] = param0; + e.params[1] = param1; + e.params[2] = param2; + __nac3_raise((void *)&e); + __builtin_unreachable(); +} + +/** + * @brief Raise an exception with location details (location in the IRRT source files). + * @param SizeT The runtime `size_t` type. + * @param id The ID of the exception to raise. + * @param msg A global constant C-string of the error message. + * + * `param0` to `param2` are optional format arguments of `msg`. They should be set to + * `NO_PARAM` to indicate they are unused. + */ +#define raise_exception(SizeT, id, msg, param0, param1, param2) \ + _raise_exception_helper(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2) +} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt/int_types.hpp b/nac3core/irrt/irrt/int_types.hpp new file mode 100644 index 00000000..2aa900d8 --- /dev/null +++ b/nac3core/irrt/irrt/int_types.hpp @@ -0,0 +1,8 @@ +#pragma once + +using int8_t = _BitInt(8); +using uint8_t = unsigned _BitInt(8); +using int32_t = _BitInt(32); +using uint32_t = unsigned _BitInt(32); +using int64_t = _BitInt(64); +using uint64_t = unsigned _BitInt(64); diff --git a/nac3core/irrt/irrt/list.hpp b/nac3core/irrt/irrt/list.hpp new file mode 100644 index 00000000..25011f11 --- /dev/null +++ b/nac3core/irrt/irrt/list.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace +{ +/** + * @brief A list in NAC3. + * + * The `items` field is opaque. You must rely on external contexts to + * know how to interpret it. + */ +template struct List +{ + uint8_t *items; + SizeT len; +}; +} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt/math_util.hpp b/nac3core/irrt/irrt/math_util.hpp new file mode 100644 index 00000000..d7ac779b --- /dev/null +++ b/nac3core/irrt/irrt/math_util.hpp @@ -0,0 +1,14 @@ +#pragma once + +namespace +{ +template const T &max(const T &a, const T &b) +{ + return a > b ? a : b; +} + +template const T &min(const T &a, const T &b) +{ + return a > b ? b : a; +} +} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/array.hpp b/nac3core/irrt/irrt/ndarray/array.hpp new file mode 100644 index 00000000..84084388 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/array.hpp @@ -0,0 +1,157 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace +{ +namespace ndarray +{ +namespace array +{ +/** + * @brief In the context of `np.array()`, deduce the ndarray's shape produced by `` and raise + * an exception if there is anything wrong with `` (e.g., inconsistent dimensions `np.array([[1.0, 2.0], [3.0]])`) + * + * If this function finds no issues with ``, the deduced shape is written to `shape`. The caller has the responsibility to + * allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because of implementation details. + */ +template +void set_and_validate_list_shape_helper(SizeT axis, List *list, SizeT ndims, SizeT *shape) +{ + if (shape[axis] == -1) + { + // Dimension is unspecified. Set it. + shape[axis] = list->len; + } + else + { + // Dimension is specified. Check. + if (shape[axis] != list->len) + { + // Mismatch, throw an error. + // NOTE: NumPy's error message is more complex and needs more PARAMS to display. + raise_exception(SizeT, EXN_VALUE_ERROR, + "The requested array has an inhomogenous shape " + "after {0} dimension(s).", + axis, shape[axis], list->len); + } + } + + if (axis + 1 == ndims) + { + // `list` has type `list[ItemType]` + // Do nothing + } + else + { + // `list` has type `list[list[...]]` + List **lists = (List **)(list->items); + for (SizeT i = 0; i < list->len; i++) + { + set_and_validate_list_shape_helper(axis + 1, lists[i], ndims, shape); + } + } +} + +/** + * @brief See `set_and_validate_list_shape_helper`. + */ +template void set_and_validate_list_shape(List *list, SizeT ndims, SizeT *shape) +{ + for (SizeT axis = 0; axis < ndims; axis++) + { + shape[axis] = -1; // Sentinel to say this dimension is unspecified. + } + set_and_validate_list_shape_helper(0, list, ndims, shape); +} + +/** + * @brief In the context of `np.array()`, copied the contents stored in `list` to `ndarray`. + * + * `list` is assumed to be "legal". (i.e., no inconsistent dimensions) + * + * # Notes on `ndarray` + * The caller is responsible for allocating space for `ndarray`. + * Here is what this function expects from `ndarray` when called: + * - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values. + * - `ndarray->itemsize` has to be initialized. + * - `ndarray->ndims` has to be initialized. + * - `ndarray->shape` has to be initialized. + * - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous. + * When this function call ends: + * - `ndarray->data` is written with contents from ``. + */ +template +void write_list_to_array_helper(SizeT axis, SizeT *index, List *list, NDArray *ndarray) +{ + debug_assert_eq(SizeT, list->len, ndarray->shape[axis]); + if (IRRT_DEBUG_ASSERT_BOOL) + { + if (!ndarray::basic::is_c_contiguous(ndarray)) + { + raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1], + NO_PARAM); + } + } + + if (axis + 1 == ndarray->ndims) + { + // `list` has type `list[scalar]` + // `ndarray` is contiguous, so we can do this, and this is fast. + uint8_t *dst = ndarray->data + (ndarray->itemsize * (*index)); + __builtin_memcpy(dst, list->items, ndarray->itemsize * list->len); + *index += list->len; + } + else + { + // `list` has type `list[list[...]]` + List **lists = (List **)(list->items); + + for (SizeT i = 0; i < list->len; i++) + { + write_list_to_array_helper(axis + 1, index, lists[i], ndarray); + } + } +} + +/** + * @brief See `write_list_to_array_helper`. + */ +template void write_list_to_array(List *list, NDArray *ndarray) +{ + SizeT index = 0; + write_list_to_array_helper((SizeT)0, &index, list, ndarray); +} +} // namespace array +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::array; + + void __nac3_ndarray_array_set_and_validate_list_shape(List *list, int32_t ndims, int32_t *shape) + { + set_and_validate_list_shape(list, ndims, shape); + } + + void __nac3_ndarray_array_set_and_validate_list_shape64(List *list, int64_t ndims, int64_t *shape) + { + set_and_validate_list_shape(list, ndims, shape); + } + + void __nac3_ndarray_array_write_list_to_array(List *list, NDArray *ndarray) + { + write_list_to_array(list, ndarray); + } + + void __nac3_ndarray_array_write_list_to_array64(List *list, NDArray *ndarray) + { + write_list_to_array(list, ndarray); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/basic.hpp b/nac3core/irrt/irrt/ndarray/basic.hpp new file mode 100644 index 00000000..604e4d76 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/basic.hpp @@ -0,0 +1,371 @@ +#pragma once + +#include +#include +#include +#include + +namespace +{ +namespace ndarray +{ +namespace basic +{ +/** + * @brief Assert that `shape` does not contain negative dimensions. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape to check on + */ +template void assert_shape_no_negative(SizeT ndims, const SizeT *shape) +{ + for (SizeT axis = 0; axis < ndims; axis++) + { + if (shape[axis] < 0) + { + raise_exception(SizeT, EXN_VALUE_ERROR, + "negative dimensions are not allowed; axis {0} " + "has dimension {1}", + axis, shape[axis], NO_PARAM); + } + } +} + +/** + * @brief Assert that two shapes are the same in the context of writing output to an ndarray. + */ +template +void assert_output_shape_same(SizeT ndarray_ndims, const SizeT *ndarray_shape, SizeT output_ndims, + const SizeT *output_shape) +{ + if (ndarray_ndims != output_ndims) + { + // There is no corresponding NumPy error message like this. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}", + output_ndims, ndarray_ndims, NO_PARAM); + } + + for (SizeT axis = 0; axis < ndarray_ndims; axis++) + { + if (ndarray_shape[axis] != output_shape[axis]) + { + // There is no corresponding NumPy error message like this. + raise_exception(SizeT, EXN_VALUE_ERROR, + "Mismatched dimensions on axis {0}, output has " + "dimension {1}, but destination ndarray has dimension {2}.", + axis, output_shape[axis], ndarray_shape[axis]); + } + } +} + +/** + * @brief Return the number of elements of an ndarray given its shape. + * + * @param ndims Number of dimensions in `shape` + * @param shape The shape of the ndarray + */ +template SizeT calc_size_from_shape(SizeT ndims, const SizeT *shape) +{ + SizeT size = 1; + for (SizeT axis = 0; axis < ndims; axis++) + size *= shape[axis]; + return size; +} + +/** + * @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape. + * + * @param ndims Number of elements in `shape` and `indices` + * @param shape The shape of the ndarray + * @param indices The returned indices indexing the ndarray with shape `shape`. + * @param nth The index of the element of interest. + */ +template void set_indices_by_nth(SizeT ndims, const SizeT *shape, SizeT *indices, SizeT nth) +{ + for (SizeT i = 0; i < ndims; i++) + { + SizeT axis = ndims - i - 1; + SizeT dim = shape[axis]; + + indices[axis] = nth % dim; + nth /= dim; + } +} + +/** + * @brief Return the number of elements of an `ndarray` + * + * This function corresponds to `.size` + */ +template SizeT size(const NDArray *ndarray) +{ + return calc_size_from_shape(ndarray->ndims, ndarray->shape); +} + +/** + * @brief Return of the number of its content of an `ndarray`. + * + * This function corresponds to `.nbytes`. + */ +template SizeT nbytes(const NDArray *ndarray) +{ + return size(ndarray) * ndarray->itemsize; +} + +/** + * @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object. + * + * This function corresponds to `.__len__`. + * + * @param dst_length The length. + */ +template SizeT len(const NDArray *ndarray) +{ + // numpy prohibits `__len__` on unsized objects + if (ndarray->ndims == 0) + { + raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM); + } + else + { + return ndarray->shape[0]; + } +} + +/** + * @brief Return a boolean indicating if `ndarray` is (C-)contiguous. + * + * You may want to see ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + */ +template bool is_c_contiguous(const NDArray *ndarray) +{ + // References: + // - tinynumpy's implementation: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102 + // - ndarray's flags["C_CONTIGUOUS"]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags + // - ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45 + + // From https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45: + // + // The traditional rule is that for an array to be flagged as C contiguous, + // the following must hold: + // + // strides[-1] == itemsize + // strides[i] == shape[i+1] * strides[i + 1] + // [...] + // According to these rules, a 0- or 1-dimensional array is either both + // C- and F-contiguous, or neither; and an array with 2+ dimensions + // can be C- or F- contiguous, or neither, but not both. Though there + // there are exceptions for arrays with zero or one item, in the first + // case the check is relaxed up to and including the first dimension + // with shape[i] == 0. In the second case `strides == itemsize` will + // can be true for all dimensions and both flags are set. + + if (ndarray->ndims == 0) + { + return true; + } + + if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) + { + return false; + } + + for (SizeT i = 1; i < ndarray->ndims; i++) + { + SizeT axis_i = ndarray->ndims - i - 1; + if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) + { + return false; + } + } + + return true; +} + +/** + * @brief Return the pointer to the element indexed by `indices` along the ndarray's axes. + * + * This function does no bound check. + */ +template uint8_t *get_pelement_by_indices(const NDArray *ndarray, const SizeT *indices) +{ + uint8_t *element = ndarray->data; + for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++) + element += indices[dim_i] * ndarray->strides[dim_i]; + return element; +} + +/** + * @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view. + * + * This function does no bound check. + */ +template uint8_t *get_nth_pelement(const NDArray *ndarray, SizeT nth) +{ + uint8_t *element = ndarray->data; + for (SizeT i = 0; i < ndarray->ndims; i++) + { + SizeT axis = ndarray->ndims - i - 1; + SizeT dim = ndarray->shape[axis]; + element += ndarray->strides[axis] * (nth % dim); + nth /= dim; + } + return element; +} + +/** + * @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous. + * + * You might want to read https://ajcr.net/stride-guide-part-1/. + */ +template void set_strides_by_shape(NDArray *ndarray) +{ + SizeT stride_product = 1; + for (SizeT i = 0; i < ndarray->ndims; i++) + { + SizeT axis = ndarray->ndims - i - 1; + ndarray->strides[axis] = stride_product * ndarray->itemsize; + stride_product *= ndarray->shape[axis]; + } +} + +/** + * @brief Set an element in `ndarray`. + * + * @param pelement Pointer to the element in `ndarray` to be set. + * @param pvalue Pointer to the value `pelement` will be set to. + */ +template void set_pelement_value(NDArray *ndarray, uint8_t *pelement, const uint8_t *pvalue) +{ + __builtin_memcpy(pelement, pvalue, ndarray->itemsize); +} + +/** + * @brief Copy data from one ndarray to another of the exact same size and itemsize. + * + * Both ndarrays will be viewed in their flatten views when copying the elements. + */ +template void copy_data(const NDArray *src_ndarray, NDArray *dst_ndarray) +{ + // TODO: Make this faster with memcpy when we see a contiguous segment. + // TODO: Handle overlapping. + + debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize); + + for (SizeT i = 0; i < size(src_ndarray); i++) + { + auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i); + auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i); + ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element); + } +} +} // namespace basic +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::basic; + + void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t *shape) + { + assert_shape_no_negative(ndims, shape); + } + + void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t *shape) + { + assert_shape_no_negative(ndims, shape); + } + + void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims, const int32_t *ndarray_shape, + int32_t output_ndims, const int32_t *output_shape) + { + assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); + } + + void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims, const int64_t *ndarray_shape, + int64_t output_ndims, const int64_t *output_shape) + { + assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape); + } + + uint32_t __nac3_ndarray_size(NDArray *ndarray) + { + return size(ndarray); + } + + uint64_t __nac3_ndarray_size64(NDArray *ndarray) + { + return size(ndarray); + } + + uint32_t __nac3_ndarray_nbytes(NDArray *ndarray) + { + return nbytes(ndarray); + } + + uint64_t __nac3_ndarray_nbytes64(NDArray *ndarray) + { + return nbytes(ndarray); + } + + int32_t __nac3_ndarray_len(NDArray *ndarray) + { + return len(ndarray); + } + + int64_t __nac3_ndarray_len64(NDArray *ndarray) + { + return len(ndarray); + } + + bool __nac3_ndarray_is_c_contiguous(NDArray *ndarray) + { + return is_c_contiguous(ndarray); + } + + bool __nac3_ndarray_is_c_contiguous64(NDArray *ndarray) + { + return is_c_contiguous(ndarray); + } + + uint8_t *__nac3_ndarray_get_nth_pelement(const NDArray *ndarray, int32_t nth) + { + return get_nth_pelement(ndarray, nth); + } + + uint8_t *__nac3_ndarray_get_nth_pelement64(const NDArray *ndarray, int64_t nth) + { + return get_nth_pelement(ndarray, nth); + } + + uint8_t *__nac3_ndarray_get_pelement_by_indices(const NDArray *ndarray, int32_t *indices) + { + return get_pelement_by_indices(ndarray, indices); + } + + uint8_t *__nac3_ndarray_get_pelement_by_indices64(const NDArray *ndarray, int64_t *indices) + { + return get_pelement_by_indices(ndarray, indices); + } + + void __nac3_ndarray_set_strides_by_shape(NDArray *ndarray) + { + set_strides_by_shape(ndarray); + } + + void __nac3_ndarray_set_strides_by_shape64(NDArray *ndarray) + { + set_strides_by_shape(ndarray); + } + + void __nac3_ndarray_copy_data(NDArray *src_ndarray, NDArray *dst_ndarray) + { + copy_data(src_ndarray, dst_ndarray); + } + + void __nac3_ndarray_copy_data64(NDArray *src_ndarray, NDArray *dst_ndarray) + { + copy_data(src_ndarray, dst_ndarray); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/broadcast.hpp b/nac3core/irrt/irrt/ndarray/broadcast.hpp new file mode 100644 index 00000000..699bd8fa --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/broadcast.hpp @@ -0,0 +1,188 @@ +#pragma once + +#include +#include +#include + +namespace +{ +template struct ShapeEntry +{ + SizeT ndims; + SizeT *shape; +}; +} // namespace + +namespace +{ +namespace ndarray +{ +namespace broadcast +{ +/** + * @brief Return true if `src_shape` can broadcast to `dst_shape`. + * + * See https://numpy.org/doc/stable/user/basics.broadcasting.html + */ +template +bool can_broadcast_shape_to(SizeT target_ndims, const SizeT *target_shape, SizeT src_ndims, const SizeT *src_shape) +{ + if (src_ndims > target_ndims) + { + return false; + } + + for (SizeT i = 0; i < src_ndims; i++) + { + SizeT target_dim = target_shape[target_ndims - i - 1]; + SizeT src_dim = src_shape[src_ndims - i - 1]; + if (!(src_dim == 1 || target_dim == src_dim)) + { + return false; + } + } + return true; +} + +/** + * @brief Performs `np.broadcast_shapes()` + * + * @param num_shapes Number of entries in `shapes` + * @param shapes The list of shape to do `np.broadcast_shapes` on. + * @param dst_ndims The length of `dst_shape`. + * `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it. + * for this function since they should already know in order to allocate `dst_shape` in the first place. + * @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result + * of `np.broadcast_shapes` and write it here. + */ +template +void broadcast_shapes(SizeT num_shapes, const ShapeEntry *shapes, SizeT dst_ndims, SizeT *dst_shape) +{ + for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) + { + dst_shape[dst_axis] = 1; + } + +#ifdef IRRT_DEBUG_ASSERT + SizeT max_ndims_found = 0; +#endif + + for (SizeT i = 0; i < num_shapes; i++) + { + ShapeEntry entry = shapes[i]; + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert(SizeT, entry.ndims <= dst_ndims); + +#ifdef IRRT_DEBUG_ASSERT + max_ndims_found = max(max_ndims_found, entry.ndims); +#endif + + for (SizeT j = 0; j < entry.ndims; j++) + { + SizeT entry_axis = entry.ndims - j - 1; + SizeT dst_axis = dst_ndims - j - 1; + + SizeT entry_dim = entry.shape[entry_axis]; + SizeT dst_dim = dst_shape[dst_axis]; + + if (dst_dim == 1) + { + dst_shape[dst_axis] = entry_dim; + } + else if (entry_dim == 1 || entry_dim == dst_dim) + { + // Do nothing + } + else + { + raise_exception(SizeT, EXN_VALUE_ERROR, + "shape mismatch: objects cannot be broadcast " + "to a single shape.", + NO_PARAM, NO_PARAM, NO_PARAM); + } + } + } + + // Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])` + debug_assert_eq(SizeT, max_ndims_found, dst_ndims); +} + +/** + * @brief Perform `np.broadcast_to(, )` and appropriate assertions. + * + * This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`, + * and return the result by modifying `dst_ndarray`. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape` + * - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is unchanged. + * - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works. + */ +template void broadcast_to(const NDArray *src_ndarray, NDArray *dst_ndarray) +{ + if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims, + src_ndarray->shape)) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM, + NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + for (SizeT i = 0; i < dst_ndarray->ndims; i++) + { + SizeT src_axis = src_ndarray->ndims - i - 1; + SizeT dst_axis = dst_ndarray->ndims - i - 1; + if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) + { + // Freeze the steps in-place + dst_ndarray->strides[dst_axis] = 0; + } + else + { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + } +} +} // namespace broadcast +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::broadcast; + + void __nac3_ndarray_broadcast_to(NDArray *src_ndarray, NDArray *dst_ndarray) + { + broadcast_to(src_ndarray, dst_ndarray); + } + + void __nac3_ndarray_broadcast_to64(NDArray *src_ndarray, NDArray *dst_ndarray) + { + broadcast_to(src_ndarray, dst_ndarray); + } + + void __nac3_ndarray_broadcast_shapes(int32_t num_shapes, const ShapeEntry *shapes, int32_t dst_ndims, + int32_t *dst_shape) + { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); + } + + void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes, const ShapeEntry *shapes, int64_t dst_ndims, + int64_t *dst_shape) + { + broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/def.hpp b/nac3core/irrt/irrt/ndarray/def.hpp new file mode 100644 index 00000000..fab8cbe9 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/def.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include + +namespace +{ +/** + * @brief The NDArray object + * + * Official numpy implementation: https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst + */ +template struct NDArray +{ + /** + * @brief The underlying data this `ndarray` is pointing to. + */ + uint8_t *data; + + /** + * @brief The number of bytes of a single element in `data`. + */ + SizeT itemsize; + + /** + * @brief The number of dimensions of this shape. + */ + SizeT ndims; + + /** + * @brief The NDArray shape, with length equal to `ndims`. + * + * Note that it may contain 0. + */ + SizeT *shape; + + /** + * @brief Array strides, with length equal to `ndims` + * + * The stride values are in units of bytes, not number of elements. + * + * Note that `strides` can have negative values or contain 0. + */ + SizeT *strides; +}; +} // namespace \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/indexing.hpp b/nac3core/irrt/irrt/ndarray/indexing.hpp new file mode 100644 index 00000000..bdef5130 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/indexing.hpp @@ -0,0 +1,249 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace +{ +typedef uint8_t NDIndexType; + +/** + * @brief A single element index + * + * `data` points to a `int32_t`. + */ +const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0; + +/** + * @brief A slice index + * + * `data` points to a `Slice`. + */ +const NDIndexType ND_INDEX_TYPE_SLICE = 1; + +/** + * @brief `np.newaxis` / `None` + * + * `data` is unused. + */ +const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2; + +/** + * @brief `Ellipsis` / `...` + * + * `data` is unused. + */ +const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3; + +/** + * @brief An index used in ndarray indexing + * + * That is: + * ``` + * my_ndarray[::-1, 3, ..., np.newaxis] + * ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex. + * ``` + */ +struct NDIndex +{ + /** + * @brief Enum tag to specify the type of index. + * + * Please see the comment of each enum constant. + */ + NDIndexType type; + + /** + * @brief The accompanying data associated with `type`. + * + * Please see the comment of each enum constant. + */ + uint8_t *data; +}; +} // namespace + +namespace +{ +namespace ndarray +{ +namespace indexing +{ +/** + * @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing) + * + * This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python. + * + * This function also does proper assertions on `indices` to check for out of bounds access and more. + * + * # Notes on `dst_ndarray` + * The caller is responsible for allocating space for the resulting ndarray. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after + * indexing `src_ndarray` with `indices`. + * - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data`. + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`. + * - `dst_ndarray->ndims` is unchanged. + * - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed. + * - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works. + * + * @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python. + * @param src_ndarray The NDArray to be indexed. + * @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above, + */ +template +void index(SizeT num_indices, const NDIndex *indices, const NDArray *src_ndarray, NDArray *dst_ndarray) +{ + // Validate `indices`. + + // Expected value of `dst_ndarray->ndims`. + SizeT expected_dst_ndims = src_ndarray->ndims; + // To check for "too many indices for array: array is ?-dimensional, but ? were indexed" + SizeT num_indexed = 0; + // There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis. + SizeT num_ellipsis = 0; + + for (SizeT i = 0; i < num_indices; i++) + { + if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) + { + expected_dst_ndims--; + num_indexed++; + } + else if (indices[i].type == ND_INDEX_TYPE_SLICE) + { + num_indexed++; + } + else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) + { + expected_dst_ndims++; + } + else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) + { + num_ellipsis++; + if (num_ellipsis > 1) + { + raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM, + NO_PARAM, NO_PARAM); + } + } + else + { + __builtin_unreachable(); + } + } + + debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims); + + if (src_ndarray->ndims - num_indexed < 0) + { + raise_exception(SizeT, EXN_INDEX_ERROR, + "too many indices for array: array is {0}-dimensional, " + "but {1} were indexed", + src_ndarray->ndims, num_indices, NO_PARAM); + } + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Reference code: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 + SizeT src_axis = 0; + SizeT dst_axis = 0; + + for (int32_t i = 0; i < num_indices; i++) + { + const NDIndex *index = &indices[i]; + if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) + { + SizeT input = (SizeT) * ((int32_t *)index->data); + + SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input); + if (k == -1) + { + raise_exception(SizeT, EXN_INDEX_ERROR, + "index {0} is out of bounds for axis {1} " + "with size {2}", + input, src_axis, src_ndarray->shape[src_axis]); + } + + dst_ndarray->data += k * src_ndarray->strides[src_axis]; + + src_axis++; + } + else if (index->type == ND_INDEX_TYPE_SLICE) + { + Slice *slice = (Slice *)index->data; + + Range range = slice->indices_checked(src_ndarray->shape[src_axis]); + + dst_ndarray->data += (SizeT)range.start * src_ndarray->strides[src_axis]; + dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis]; + dst_ndarray->shape[dst_axis] = (SizeT)range.len(); + + dst_axis++; + src_axis++; + } + else if (index->type == ND_INDEX_TYPE_NEWAXIS) + { + dst_ndarray->strides[dst_axis] = 0; + dst_ndarray->shape[dst_axis] = 1; + + dst_axis++; + } + else if (index->type == ND_INDEX_TYPE_ELLIPSIS) + { + // The number of ':' entries this '...' implies. + SizeT ellipsis_size = src_ndarray->ndims - num_indexed; + + for (SizeT j = 0; j < ellipsis_size; j++) + { + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; + + dst_axis++; + src_axis++; + } + } + else + { + __builtin_unreachable(); + } + } + + for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) + { + dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis]; + dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis]; + } + + debug_assert_eq(SizeT, src_ndarray->ndims, src_axis); + debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis); +} +} // namespace indexing +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::indexing; + + void __nac3_ndarray_index(int32_t num_indices, NDIndex *indices, NDArray *src_ndarray, + NDArray *dst_ndarray) + { + index(num_indices, indices, src_ndarray, dst_ndarray); + } + + void __nac3_ndarray_index64(int64_t num_indices, NDIndex *indices, NDArray *src_ndarray, + NDArray *dst_ndarray) + { + index(num_indices, indices, src_ndarray, dst_ndarray); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/iter.hpp b/nac3core/irrt/irrt/ndarray/iter.hpp new file mode 100644 index 00000000..4d9f6606 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/iter.hpp @@ -0,0 +1,139 @@ +#pragma once + +#include +#include + +namespace +{ +/** + * @brief Helper struct to enumerate through an ndarray *efficiently*. + * + * Interesting cases: + * - If ndims == 0, there is one iteration. + * - If shape contains zeroes, there are no iterations. + */ +template struct NDIter +{ + // Information about the ndarray being iterated over. + SizeT ndims; + SizeT *shape; + SizeT *strides; + + /** + * @brief The current indices. + * + * Must be allocated by the caller. + */ + SizeT *indices; + + /** + * @brief The nth (0-based) index of the current indices. + * + * Initially this is all 0s. + */ + SizeT nth; + + /** + * @brief Pointer to the current element. + * + * Initially this points to first element of the ndarray. + */ + uint8_t *element; + + /** + * @brief Cache for the product of shape. + * + * Could be 0 if `shape` has 0s in it. + */ + SizeT size; + + // TODO:: Not implemented: There is something called backstrides to speedup iteration. + // See https://ajcr.net/stride-guide-part-1/, and https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides. + + void initialize(SizeT ndims, SizeT *shape, SizeT *strides, uint8_t *element, SizeT *indices) + { + this->ndims = ndims; + this->shape = shape; + this->strides = strides; + + this->indices = indices; + this->element = element; + + // Compute size + this->size = 1; + for (SizeT i = 0; i < ndims; i++) + { + this->size *= shape[i]; + } + + for (SizeT axis = 0; axis < ndims; axis++) + indices[axis] = 0; + nth = 0; + } + + void initialize_by_ndarray(NDArray *ndarray, SizeT *indices) + { + this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices); + } + + bool has_next() + { + return nth < size; + } + + void next() + { + for (SizeT i = 0; i < ndims; i++) + { + SizeT axis = ndims - i - 1; + indices[axis]++; + if (indices[axis] >= shape[axis]) + { + indices[axis] = 0; + + // TODO: Can be optimized with backstrides. + element -= strides[axis] * (shape[axis] - 1); + } + else + { + element += strides[axis]; + break; + } + } + nth++; + } +}; +} // namespace + +extern "C" +{ + void __nac3_nditer_initialize(NDIter *iter, NDArray *ndarray, int32_t *indices) + { + iter->initialize_by_ndarray(ndarray, indices); + } + + void __nac3_nditer_initialize64(NDIter *iter, NDArray *ndarray, int64_t *indices) + { + iter->initialize_by_ndarray(ndarray, indices); + } + + bool __nac3_nditer_has_next(NDIter *iter) + { + return iter->has_next(); + } + + bool __nac3_nditer_has_next64(NDIter *iter) + { + return iter->has_next(); + } + + void __nac3_nditer_next(NDIter *iter) + { + iter->next(); + } + + void __nac3_nditer_next64(NDIter *iter) + { + iter->next(); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/matmul.hpp b/nac3core/irrt/irrt/ndarray/matmul.hpp new file mode 100644 index 00000000..99da3653 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/matmul.hpp @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +// NOTE: Everything would be much easier and elegant if einsum is implemented. + +namespace +{ +namespace ndarray +{ +namespace matmul +{ + +/** + * @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`. + * + * Example: + * Suppose `a_shape == [1, 97, 4, 2]` + * and `b_shape == [99, 98, 1, 2, 5]`, + * + * ...then `new_a_shape == [99, 98, 97, 4, 2]`, + * `new_b_shape == [99, 98, 97, 2, 5]`, + * and `dst_shape == [99, 98, 97, 4, 5]`. + * ^^^^^^^^^^ ^^^^ + * (broadcasted) (4x2 @ 2x5 => 4x5) + * + * @param a_ndims Length of `a_shape`. + * @param a_shape Shape of `a`. + * @param b_ndims Length of `b_shape`. + * @param b_shape Shape of `b`. + * @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`, + * `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting. + */ +template +void calculate_shapes(SizeT a_ndims, SizeT *a_shape, SizeT b_ndims, SizeT *b_shape, SizeT final_ndims, + SizeT *new_a_shape, SizeT *new_b_shape, SizeT *dst_shape) +{ + debug_assert(SizeT, a_ndims >= 2); + debug_assert(SizeT, b_ndims >= 2); + debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims); + + // Check that a and b are compatible for matmul + if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) + { + // This is a custom error message. Different from NumPy. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})", + a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM); + } + + const SizeT num_entries = 2; + ShapeEntry entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape}, + {.ndims = b_ndims - 2, .shape = b_shape}}; + + // TODO: Optimize this + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_a_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_b_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, dst_shape); + + new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1]; + new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2]; + new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1]; + dst_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + dst_shape[final_ndims - 1] = b_shape[b_ndims - 1]; +} +} // namespace matmul +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::matmul; + + void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, int32_t *a_shape, int32_t b_ndims, int32_t *b_shape, + int32_t final_ndims, int32_t *new_a_shape, int32_t *new_b_shape, + int32_t *dst_shape) + { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); + } + + void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, int64_t *a_shape, int64_t b_ndims, int64_t *b_shape, + int64_t final_ndims, int64_t *new_a_shape, int64_t *new_b_shape, + int64_t *dst_shape) + { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/reshape.hpp b/nac3core/irrt/irrt/ndarray/reshape.hpp new file mode 100644 index 00000000..aab363e1 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/reshape.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include +#include + +namespace +{ +namespace ndarray +{ +namespace reshape +{ +/** + * @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(, new_shape)` + * + * If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be + * modified to contain the resolved dimension. + * + * To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual + * `` object itself, but only the `.size` of the ``. + * + * @param size The `.size` of `` + * @param new_ndims Number of elements in `new_shape` + * @param new_shape Target shape to reshape to + */ +template void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT *new_shape) +{ + // Is there a -1 in `new_shape`? + bool neg1_exists = false; + // Location of -1, only initialized if `neg1_exists` is true + SizeT neg1_axis_i; + // The computed ndarray size of `new_shape` + SizeT new_size = 1; + + for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) + { + SizeT dim = new_shape[axis_i]; + if (dim < 0) + { + if (dim == -1) + { + if (neg1_exists) + { + // Multiple `-1` found. Throw an error. + raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM, + NO_PARAM, NO_PARAM); + } + else + { + neg1_exists = true; + neg1_axis_i = axis_i; + } + } + else + { + // TODO: What? In `np.reshape` any negative dimensions is + // treated like its `-1`. + // + // Try running `np.zeros((3, 4)).reshape((-999, 2))` + // + // It is not documented by numpy. + // Throw an error for now... + + raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i, + NO_PARAM); + } + } + else + { + new_size *= dim; + } + } + + bool can_reshape; + if (neg1_exists) + { + // Let `x` be the unknown dimension + // Solve `x * = ` + if (new_size == 0 && size == 0) + { + // `x` has infinitely many solutions + can_reshape = false; + } + else if (new_size == 0 && size != 0) + { + // `x` has no solutions + can_reshape = false; + } + else if (size % new_size != 0) + { + // `x` has no integer solutions + can_reshape = false; + } + else + { + can_reshape = true; + new_shape[neg1_axis_i] = size / new_size; // Resolve dimension + } + } + else + { + can_reshape = (new_size == size); + } + + if (!can_reshape) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM, + NO_PARAM); + } +} +} // namespace reshape +} // namespace ndarray +} // namespace + +extern "C" +{ + void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t *new_shape) + { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); + } + + void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t *new_shape) + { + ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape); + } +} diff --git a/nac3core/irrt/irrt/ndarray/transpose.hpp b/nac3core/irrt/irrt/ndarray/transpose.hpp new file mode 100644 index 00000000..ab5fe009 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/transpose.hpp @@ -0,0 +1,155 @@ +#pragma once + +#include +#include +#include + +/* + * Notes on `np.transpose(, )` + * + * TODO: `axes`, if specified, can actually contain negative indices, + * but it is not documented in numpy. + * + * Supporting it for now. + */ + +namespace +{ +namespace ndarray +{ +namespace transpose +{ +/** + * @brief Do assertions on `` in `np.transpose(, )`. + * + * Note that `np.transpose`'s `` argument is optional. If the argument + * is specified but the user, use this function to do assertions on it. + * + * @param ndims The number of dimensions of `` + * @param num_axes Number of elements in `` as specified by the user. + * This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown. + * @param axes The user specified ``. + */ +template void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT *axes) +{ + if (ndims != num_axes) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM); + } + + // TODO: Optimize this + bool *axe_specified = (bool *)__builtin_alloca(sizeof(bool) * ndims); + for (SizeT i = 0; i < ndims; i++) + axe_specified[i] = false; + + for (SizeT i = 0; i < ndims; i++) + { + SizeT axis = slice::resolve_index_in_length(ndims, axes[i]); + if (axis == -1) + { + // TODO: numpy actually throws a `numpy.exceptions.AxisError` + raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims, + NO_PARAM); + } + + if (axe_specified[axis]) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM); + } + + axe_specified[axis] = true; + } +} + +/** + * @brief Create a transpose view of `src_ndarray` and perform proper assertions. + * + * This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, )`. + * If `` is supposed to be `None`, caller can pass in a `nullptr` to ``. + * + * The transpose view created is returned by modifying `dst_ndarray`. + * + * The caller is responsible for setting up `dst_ndarray` before calling this function. + * Here is what this function expects from `dst_ndarray` when called: + * - `dst_ndarray->data` does not have to be initialized. + * - `dst_ndarray->itemsize` does not have to be initialized. + * - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`. + * - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values. + * - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values. + * When this function call ends: + * - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`) + * - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize` + * - `dst_ndarray->ndims` is unchanged + * - `dst_ndarray->shape` is updated according to how `np.transpose` works + * - `dst_ndarray->strides` is updated according to how `np.transpose` works + * + * @param src_ndarray The NDArray to build a transpose view on + * @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above, + * @param num_axes Number of elements in axes. Unused if `axes` is nullptr. + * @param axes Axes permutation. Set it to `nullptr` if `` is `None`. + */ +template +void transpose(const NDArray *src_ndarray, NDArray *dst_ndarray, SizeT num_axes, const SizeT *axes) +{ + debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims); + const auto ndims = src_ndarray->ndims; + + if (axes != nullptr) + assert_transpose_axes(ndims, num_axes, axes); + + dst_ndarray->data = src_ndarray->data; + dst_ndarray->itemsize = src_ndarray->itemsize; + + // Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes. + if (axes == nullptr) + { + // `np.transpose(, axes=None)` + + /* + * Minor note: `np.transpose(, axes=None)` is equivalent to + * `np.transpose(, axes=[N-1, N-2, ..., 0])` - basically it + * is reversing the order of strides and shape. + * + * This is a fast implementation to handle this special (but very common) case. + */ + + for (SizeT axis = 0; axis < ndims; axis++) + { + dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1]; + dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1]; + } + } + else + { + // `np.transpose(, )` + + // Permute strides and shape according to `axes`, while resolving negative indices in `axes` + for (SizeT axis = 0; axis < ndims; axis++) + { + // `i` cannot be OUT_OF_BOUNDS because of assertions + SizeT i = slice::resolve_index_in_length(ndims, axes[axis]); + + dst_ndarray->shape[axis] = src_ndarray->shape[i]; + dst_ndarray->strides[axis] = src_ndarray->strides[i]; + } + } +} +} // namespace transpose +} // namespace ndarray +} // namespace + +extern "C" +{ + using namespace ndarray::transpose; + void __nac3_ndarray_transpose(const NDArray *src_ndarray, NDArray *dst_ndarray, int32_t num_axes, + const int32_t *axes) + { + transpose(src_ndarray, dst_ndarray, num_axes, axes); + } + + void __nac3_ndarray_transpose64(const NDArray *src_ndarray, NDArray *dst_ndarray, + int64_t num_axes, const int64_t *axes) + { + transpose(src_ndarray, dst_ndarray, num_axes, axes); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt/original.hpp b/nac3core/irrt/irrt/original.hpp new file mode 100644 index 00000000..9eaf5518 --- /dev/null +++ b/nac3core/irrt/irrt/original.hpp @@ -0,0 +1,215 @@ +#pragma once + +#include +#include + +// The type of an index or a value describing the length of a range/slice is always `int32_t`. +using SliceIndex = int32_t; + +namespace +{ +// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c +// need to make sure `exp >= 0` before calling this function +template T __nac3_int_exp_impl(T base, T exp) +{ + T res = 1; + /* repeated squaring method */ + do + { + if (exp & 1) + { + res *= base; /* for n odd */ + } + exp >>= 1; + base *= base; + } while (exp); + return res; +} +} // namespace + +extern "C" +{ +#define DEF_nac3_int_exp_(T) \ + T __nac3_int_exp_##T(T base, T exp) \ + { \ + return __nac3_int_exp_impl(base, exp); \ + } + + DEF_nac3_int_exp_(int32_t) DEF_nac3_int_exp_(int64_t) DEF_nac3_int_exp_(uint32_t) DEF_nac3_int_exp_(uint64_t) + + SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) + { + if (i < 0) + { + i = len + i; + } + if (i < 0) + { + return 0; + } + else if (i > len) + { + return len; + } + return i; + } + + SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) + { + SliceIndex diff = end - start; + if (diff > 0 && step > 0) + { + return ((diff - 1) / step) + 1; + } + else if (diff < 0 && step < 0) + { + return ((diff + 1) / step) + 1; + } + else + { + return 0; + } + } + + // Handle list assignment and dropping part of the list when + // both dest_step and src_step are +1. + // - All the index must *not* be out-of-bound or negative, + // - The end index is *inclusive*, + // - The length of src and dest slice size should already + // be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest) + SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start, SliceIndex dest_end, SliceIndex dest_step, + uint8_t *dest_arr, SliceIndex dest_arr_len, SliceIndex src_start, + SliceIndex src_end, SliceIndex src_step, uint8_t *src_arr, + SliceIndex src_arr_len, const SliceIndex size) + { + /* if dest_arr_len == 0, do nothing since we do not support extending list */ + if (dest_arr_len == 0) + return dest_arr_len; + /* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */ + if (src_step == dest_step && dest_step == 1) + { + const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0; + const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0; + if (src_len > 0) + { + __builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size); + } + if (dest_len > 0) + { + /* dropping */ + __builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size, + (dest_arr_len - dest_end - 1) * size); + } + /* shrink size */ + return dest_arr_len - (dest_len - src_len); + } + /* if two range overlaps, need alloca */ + uint8_t need_alloca = (dest_arr == src_arr) && !(max(dest_start, dest_end) < min(src_start, src_end) || + max(src_start, src_end) < min(dest_start, dest_end)); + if (need_alloca) + { + uint8_t *tmp = reinterpret_cast(__builtin_alloca(src_arr_len * size)); + __builtin_memcpy(tmp, src_arr, src_arr_len * size); + src_arr = tmp; + } + SliceIndex src_ind = src_start; + SliceIndex dest_ind = dest_start; + for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) + { + /* for constant optimization */ + if (size == 1) + { + __builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1); + } + else if (size == 4) + { + __builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4); + } + else if (size == 8) + { + __builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8); + } + else + { + /* memcpy for var size, cannot overlap after previous alloca */ + __builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size); + } + } + /* only dest_step == 1 can we shrink the dest list. */ + /* size should be ensured prior to calling this function */ + if (dest_step == 1 && dest_end >= dest_start) + { + __builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size, + (dest_arr_len - dest_end - 1) * size); + return dest_arr_len - (dest_end - dest_ind) - 1; + } + return dest_arr_len; + } + + int32_t __nac3_isinf(double x) + { + return __builtin_isinf(x); + } + + int32_t __nac3_isnan(double x) + { + return __builtin_isnan(x); + } + + double tgamma(double arg); + + double __nac3_gamma(double z) + { + // Handling for denormals + // | x | Python gamma(x) | C tgamma(x) | + // --- | ----------------- | --------------- | ----------- | + // (1) | nan | nan | nan | + // (2) | -inf | -inf | inf | + // (3) | inf | inf | inf | + // (4) | 0.0 | inf | inf | + // (5) | {-1.0, -2.0, ...} | inf | nan | + + // (1)-(3) + if (__builtin_isinf(z) || __builtin_isnan(z)) + { + return z; + } + + double v = tgamma(z); + + // (4)-(5) + return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v; + } + + double lgamma(double arg); + + double __nac3_gammaln(double x) + { + // libm's handling of value overflows differs from scipy: + // - scipy: gammaln(-inf) -> -inf + // - libm : lgamma(-inf) -> inf + + if (__builtin_isinf(x)) + { + return x; + } + + return lgamma(x); + } + + double j0(double x); + + double __nac3_j0(double x) + { + // libm's handling of value overflows differs from scipy: + // - scipy: j0(inf) -> nan + // - libm : j0(inf) -> 0.0 + + if (__builtin_isinf(x)) + { + return __builtin_nan(""); + } + + return j0(x); + } +} // extern "C" \ No newline at end of file diff --git a/nac3core/irrt/irrt/range.hpp b/nac3core/irrt/irrt/range.hpp new file mode 100644 index 00000000..71cf0216 --- /dev/null +++ b/nac3core/irrt/irrt/range.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +namespace +{ +namespace range +{ +template T len(T start, T stop, T step) +{ + // Reference: + // https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933 + if (step > 0 && start < stop) + return 1 + (stop - 1 - start) / step; + else if (step < 0 && start > stop) + return 1 + (start - 1 - stop) / (-step); + else + return 0; +} +} // namespace range + +/** + * @brief A Python range. + */ +template struct Range +{ + T start; + T stop; + T step; + + /** + * @brief Calculate the `len()` of this range. + */ + template T len() + { + debug_assert(SizeT, step != 0); + return range::len(start, stop, step); + } +}; +} // namespace diff --git a/nac3core/irrt/irrt/slice.hpp b/nac3core/irrt/irrt/slice.hpp new file mode 100644 index 00000000..5b7f805b --- /dev/null +++ b/nac3core/irrt/irrt/slice.hpp @@ -0,0 +1,158 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace +{ +namespace slice +{ +/** + * @brief Resolve a possibly negative index in a list of a known length. + * + * Returns -1 if the resolved index is out of the list's bounds. + */ +template T resolve_index_in_length(T length, T index) +{ + T resolved = index < 0 ? length + index : index; + if (0 <= resolved && resolved < length) + { + return resolved; + } + else + { + return -1; + } +} + +/** + * @brief Resolve a slice as a range. + * + * This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python. + */ +template +void indices(bool start_defined, T start, bool stop_defined, T stop, bool step_defined, T step, T length, + T *range_start, T *range_stop, T *range_step) +{ + // Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388 + *range_step = step_defined ? step : 1; + bool step_is_negative = *range_step < 0; + + T lower, upper; + if (step_is_negative) + { + lower = -1; + upper = length - 1; + } + else + { + lower = 0; + upper = length; + } + + if (start_defined) + { + *range_start = start < 0 ? max(lower, start + length) : min(upper, start); + } + else + { + *range_start = step_is_negative ? upper : lower; + } + + if (stop_defined) + { + *range_stop = stop < 0 ? max(lower, stop + length) : min(upper, stop); + } + else + { + *range_stop = step_is_negative ? lower : upper; + } +} +} // namespace slice + +/** + * @brief A Python-like slice with **unresolved** indices. + */ +template struct Slice +{ + bool start_defined; + T start; + + bool stop_defined; + T stop; + + bool step_defined; + T step; + + Slice() + { + this->reset(); + } + + void reset() + { + this->start_defined = false; + this->stop_defined = false; + this->step_defined = false; + } + + void set_start(T start) + { + this->start_defined = true; + this->start = start; + } + + void set_stop(T stop) + { + this->stop_defined = true; + this->stop = stop; + } + + void set_step(T step) + { + this->step_defined = true; + this->step = step; + } + + /** + * @brief Resolve this slice as a range. + * + * In Python, this would be `range(*slice(start, stop, step).indices(length))`. + */ + template Range indices(T length) + { + // Reference: + // https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388 + debug_assert(SizeT, length >= 0); + + Range result; + slice::indices(start_defined, start, stop_defined, stop, step_defined, step, length, &result.start, + &result.stop, &result.step); + return result; + } + + /** + * @brief Like `.indices()` but with assertions. + */ + template Range indices_checked(T length) + { + // TODO: Switch to `SizeT length` + + if (length < 0) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM, + NO_PARAM); + } + + if (this->step_defined && this->step == 0) + { + raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM); + } + + return this->indices(length); + } +}; +} // namespace diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 9914d81c..53dd93b0 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,20 +1,21 @@ use inkwell::types::BasicTypeEnum; -use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue}; +use inkwell::values::{BasicValue, BasicValueEnum, IntValue}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; -use crate::codegen::classes::{ - ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, -}; +use super::model::*; +use super::object::any::AnyObject; +use super::object::list::ListObject; +use super::object::ndarray::NDArrayObject; +use super::object::tuple::TupleObject; +use crate::codegen::classes::RangeValue; use crate::codegen::expr::destructure_range; use crate::codegen::irrt::calculate_len_for_slice_range; -use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; -use crate::codegen::stmt::gen_for_callback_incrementing; -use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; +use crate::codegen::object::ndarray::{NDArrayOut, ScalarOrNDArray}; +use crate::codegen::{extern_fns, irrt, llvm_intrinsics, CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PrimDef; -use crate::toplevel::numpy::unpack_ndarray_var_tys; -use crate::typecheck::typedef::{Type, TypeEnum}; +use crate::typecheck::typedef::Type; +use crate::typecheck::typedef::TypeEnum; /// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// @@ -32,58 +33,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let range_ty = ctx.primitives.range; let (arg_ty, arg) = n; - - Ok(if ctx.unifier.unioned(arg_ty, range_ty) { + Ok(if ctx.unifier.unioned(arg_ty, ctx.primitives.range) { let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { - match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { - let zero = llvm_i32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, llvm_i32.const_int(1, false)], - None, - ) - .into_int_value(); - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + let arg = AnyObject { ty: arg_ty, value: arg }; + let len: Instance<'ctx, Int> = match &*ctx.unifier.get_ty(arg_ty) { + TypeEnum::TTuple { .. } => { + let tuple = TupleObject::from_object(ctx, arg); + tuple.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); - - let ndims = arg.dim_sizes().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "") - .unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - }; - - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayObject::from_object(generator, ctx, arg); + ndarray.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) } - _ => unreachable!(), - } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListObject::from_object(generator, ctx, arg); + list.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) + } + _ => unsupported_type(ctx, "len", &[arg_ty]), + }; + len.value }) } @@ -94,7 +70,6 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; Ok(match n { @@ -128,21 +103,20 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_int_truncate(to_int64, llvm_i32, "conv").map(Into::into).unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int32, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_int32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.int32 }, + |generator, ctx, scalar| call_int32(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "int32", &[n_ty]), @@ -156,7 +130,6 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -190,21 +163,20 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.int64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_int64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.int64 }, + |generator, ctx, scalar| call_int64(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "int64", &[n_ty]), @@ -218,7 +190,6 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -268,21 +239,20 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint32, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint32(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.uint32 }, + |generator, ctx, scalar| call_uint32(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "uint32", &[n_ty]), @@ -296,7 +266,6 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -335,21 +304,20 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.uint64, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_uint64(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.uint64 }, + |generator, ctx, scalar| call_uint64(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "uint64", &[n_ty]), @@ -363,7 +331,6 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (n_ty, n) = n; @@ -401,21 +368,20 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( n.into() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_float(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.float }, + |generator, ctx, scalar| call_float(generator, ctx, (ndarray.dtype, scalar)), + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, "float", &[n_ty]), @@ -431,8 +397,6 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "round"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); @@ -447,21 +411,22 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_round(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ret_elem_ty }, + |generator, ctx, scalar| { + call_round(generator, ctx, (ndarray.dtype, scalar), ret_elem_ty) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -476,8 +441,6 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_round"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; Ok(match n { @@ -487,21 +450,22 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_rint(ctx, n, None).into() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.float, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_numpy_round(generator, ctx, (elem_ty, val)), - )?; + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.float }, + |generator, ctx, scalar| { + call_numpy_round(generator, ctx, (ndarray.dtype, scalar)) + }, + ) + .unwrap(); - ndarray.as_base_value().into() + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -516,8 +480,6 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "bool"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; Ok(match n { @@ -552,25 +514,22 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| { - let elem = call_bool(generator, ctx, (elem_ty, val))?; - - Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) - }, - )?; - - ndarray.as_base_value().into() + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.bool }, + |generator, ctx, scalar| { + let elem = call_bool(generator, ctx, (ndarray.dtype, scalar))?; + Ok(generator.bool_to_i8(ctx, elem.into_int_value()).into()) + }, + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -586,8 +545,6 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "floor"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); @@ -606,21 +563,21 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ret_elem_ty }, + |generator, ctx, scalar| { + call_floor(generator, ctx, (ndarray.dtype, scalar), ret_elem_ty) + }, + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -636,8 +593,6 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "ceil"; - let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); @@ -656,21 +611,21 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) - if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); + _ if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: n_ty, value: n }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(n, llvm_usize, None), - |generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty), - )?; - - ndarray.as_base_value().into() + let result = ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ret_elem_ty }, + |generator, ctx, scalar| { + call_ceil(generator, ctx, (ndarray.dtype, scalar), ret_elem_ty) + }, + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[n_ty]), @@ -762,47 +717,32 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => + _ if [&x1_ty, &x2_ty] + .into_iter() + .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2).to_ndarray(generator, ctx); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1.dtype, x2.dtype)); + let common_dtype = x1.dtype; - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayObject::broadcast_starmap( generator, ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: common_dtype }, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1.dtype, x1_scalar), (x2.dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -869,7 +809,6 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); let llvm_int64 = ctx.ctx.i64_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); let (a_ty, a) = a; Ok(match a { @@ -891,23 +830,21 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( _ => unreachable!(), } } - BasicValueEnum::PointerValue(n) - if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + _ if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { + let ndarray = AnyObject { ty: a_ty, value: a }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let dtype_llvm = ctx.get_llvm_type(generator, ndarray.dtype); + + let zero = Int(SizeT).const_0(generator, ctx.ctx); - let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); - let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let n_sz_eqz = ctx - .builder - .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") - .unwrap(); + let size_isnt_zero = + ndarray.size(generator, ctx).compare(ctx, IntPredicate::NE, zero); ctx.make_assert( generator, - n_sz_eqz, + size_isnt_zero.value, "0:ValueError", format!("zero-size array to reduction operation {fn_name}").as_str(), [None, None, None], @@ -915,44 +852,43 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ); } - let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?; + let extremum = generator.gen_var_alloc(ctx, dtype_llvm, None)?; + let extremum_idx = Int(SizeT).var_alloca(generator, ctx, None)?; - unsafe { - let identity = - n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - ctx.builder.build_store(accumulator_addr, identity).unwrap(); - ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap(); - } + let first_value = ndarray.get_nth_scalar(generator, ctx, zero).value; + ctx.builder.build_store(extremum, first_value).unwrap(); + extremum_idx.store(ctx, zero); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_int64.const_int(1, false), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; - let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); - let cur_idx = ctx.builder.build_load(res_idx, "").unwrap(); + // The first element is iterated, but this doesn't matter. + ndarray + .foreach(generator, ctx, |generator, ctx, _hooks, nditer| { + let old_extremum = ctx.builder.build_load(extremum, "").unwrap(); + let old_extremum_idx = extremum_idx.load(generator, ctx); - let result = match fn_name { - "np_argmin" | "np_min" => { - call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)) - } - "np_argmax" | "np_max" => { - call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)) - } + let curr_value = nditer.get_scalar(generator, ctx).value; + let curr_idx = nditer.get_index(generator, ctx); + + let new_extremum = match fn_name { + "np_argmin" | "np_min" => call_min( + ctx, + (ndarray.dtype, old_extremum), + (ndarray.dtype, curr_value), + ), + "np_argmax" | "np_max" => call_max( + ctx, + (ndarray.dtype, old_extremum), + (ndarray.dtype, curr_value), + ), _ => unreachable!(), }; - let updated_idx = match (accumulator, result) { + let new_extremum_idx = match (old_extremum, new_extremum) { (BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx .builder .build_select( ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(), - idx.into(), - cur_idx, + curr_idx.value, + old_extremum_idx.value, "", ) .unwrap(), @@ -962,24 +898,31 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( ctx.builder .build_float_compare(FloatPredicate::ONE, m, n, "") .unwrap(), - idx.into(), - cur_idx, + curr_idx.value, + old_extremum_idx.value, "", ) .unwrap(), - _ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]), + _ => unsupported_type(ctx, fn_name, &[ndarray.dtype, ndarray.dtype]), }; - ctx.builder.build_store(res_idx, updated_idx).unwrap(); - ctx.builder.build_store(accumulator_addr, result).unwrap(); + + ctx.builder.build_store(extremum, new_extremum).unwrap(); + + let new_extremum_idx = + Int(SizeT).believe_value(new_extremum_idx.into_int_value()); + extremum_idx.store(ctx, new_extremum_idx); Ok(()) - }, - llvm_int64.const_int(1, false), - )?; + }) + .unwrap(); match fn_name { - "np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(), - "np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(), + "np_argmin" | "np_argmax" => extremum_idx + .load(generator, ctx) + .s_extend_or_bit_cast(generator, ctx, Int64) + .value + .as_basic_value_enum(), + "np_max" | "np_min" => ctx.builder.build_load(extremum, "").unwrap(), _ => unreachable!(), } } @@ -1024,47 +967,32 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => + _ if [&x1_ty, &x2_ty] + .into_iter() + .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1).to_ndarray(generator, ctx); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2).to_ndarray(generator, ctx); - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + debug_assert!(ctx.unifier.unioned(x1.dtype, x2.dtype)); + let common_dtype = x1.dtype; - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( + let result = NDArrayObject::broadcast_starmap( generator, ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) + &[x1, x2], + NDArrayOut::NewNDArray { dtype: common_dtype }, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1.dtype, x1_scalar), (x2.dtype, x2_scalar))) }, - )? - .as_base_value() - .into() + ) + .unwrap(); + result.instance.value.as_basic_value_enum() } _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), @@ -1097,39 +1025,19 @@ where ) -> Option>, RetElemFn: Fn(&mut CodeGenContext<'ctx, '_>, Type) -> Type, { - let result = match arg_val { - BasicValueEnum::PointerValue(x) - if arg_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => - { - let llvm_usize = generator.get_size_type(ctx.ctx); - let (arg_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let ret_elem_ty = get_ret_elem_type(ctx, arg_elem_ty); + let arg = AnyObject { ty: arg_ty, value: arg_val }; + let arg = ScalarOrNDArray::split_object(generator, ctx, arg); - let ndarray = ndarray_elementwise_unaryop_impl( - generator, - ctx, - ret_elem_ty, - None, - NDArrayValue::from_ptr_val(x, llvm_usize, None), - |generator, ctx, elem_val| { - helper_call_numpy_unary_elementwise( - generator, - ctx, - (arg_elem_ty, elem_val), - fn_name, - get_ret_elem_type, - on_scalar, - ) - }, - )?; - ndarray.as_base_value().into() - } + let dtype = arg.get_dtype(); - _ => on_scalar(generator, ctx, arg_ty, arg_val) - .unwrap_or_else(|| unsupported_type(ctx, fn_name, &[arg_ty])), - }; - - Ok(result) + let ret_ty = get_ret_elem_type(ctx, dtype); + let result = arg.map(generator, ctx, ret_ty, |generator, ctx, scalar| { + let Some(result) = on_scalar(generator, ctx, dtype, scalar) else { + unsupported_type(ctx, fn_name, &[arg_ty]) + }; + Ok(result) + })?; + Ok(result.to_basic_value_enum()) } pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( @@ -1448,391 +1356,220 @@ create_helper_call_numpy_unary_elementwise_float_to_float!( pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_arctan2"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - extern_fns::call_atan2(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_copysign` builtin function. pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_copysign"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_copysign(ctx, x1, x2, None) + .as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmax` builtin function. pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_fmax"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_fmin` builtin function. pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_fmin"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_ldexp` builtin function. pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_ldexp"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.int32)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - extern_fns::call_ldexp(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1_scalar), BasicValueEnum::IntValue(x2_scalar)) => { + debug_assert!(ctx.unifier.unioned(x1.get_dtype(), ctx.primitives.float)); + debug_assert!(ctx.unifier.unioned(x2.get_dtype(), ctx.primitives.int32)); + Ok(extern_fns::call_ldexp(ctx, x1_scalar, x2_scalar, None) + .as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = - if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 } else { x1_ty }; - - let x1_scalar_ty = dtype; - let x2_scalar_ty = - if is_ndarray2 { unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_hypot` builtin function. pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_hypot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - extern_fns::call_hypot(ctx, x1, x2, None).into() - } + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_hypot(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) + Ok(result.to_basic_value_enum()) } /// Invokes the `np_nextafter` builtin function. @@ -1847,555 +1584,359 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( let (x1_ty, x1) = x1; let (x2_ty, x2) = x2; - Ok(match (x1, x2) { - (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { - debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); - debug_assert!(ctx.unifier.unioned(x2_ty, ctx.primitives.float)); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = ScalarOrNDArray::split_object(generator, ctx, x1); + let x2 = AnyObject { ty: x2_ty, value: x2 }; + let x2 = ScalarOrNDArray::split_object(generator, ctx, x2); - extern_fns::call_nextafter(ctx, x1, x2, None).into() - } + let result = ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[x1, x2], + ctx.primitives.float, + |_generator, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; - (x1, x2) - if [&x1_ty, &x2_ty].into_iter().any(|ty| { - ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) - }) => - { - let is_ndarray1 = - x1_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - x2_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + match (x1_scalar, x2_scalar) { + (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { + Ok(extern_fns::call_nextafter(ctx, x1, x2, None).as_basic_value_enum()) + } + _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), + } + }, + ) + .unwrap(); - let dtype = if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty); - - debug_assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); - - ndarray_dtype1 - } else if is_ndarray1 { - unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 - } else if is_ndarray2 { - unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0 - } else { - unreachable!() - }; - - let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty }; - let x2_scalar_ty = if is_ndarray2 { dtype } else { x2_ty }; - - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - dtype, - None, - (x1, !is_ndarray1), - (x2, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs)) - }, - )? - .as_base_value() - .into() - } - - _ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]), - }) -} - -/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it -fn build_output_struct<'ctx>( - ctx: &mut CodeGenContext<'ctx, '_>, - out_matrices: Vec>, -) -> PointerValue<'ctx> { - let field_ty = - out_matrices.iter().map(BasicValueEnum::get_type).collect::>(); - let out_ty = ctx.ctx.struct_type(&field_ty, false); - let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); - - for (i, v) in out_matrices.into_iter().enumerate() { - unsafe { - let ptr = ctx - .builder - .build_in_bounds_gep( - out_ptr, - &[ - ctx.ctx.i32_type().const_zero(), - ctx.ctx.i32_type().const_int(i as u64, false), - ], - "", - ) - .unwrap(); - ctx.builder.build_store(ptr, v).unwrap(); - } - } - out_ptr + Ok(result.to_basic_value_enum()) } /// Invokes the `np_linalg_cholesky` linalg function pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_cholesky"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let out = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + out.copy_shape_from_ndarray(generator, ctx, x1); + out.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_cholesky( + ctx, + x1_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_cholesky(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + Ok(out.instance.value.as_basic_value_enum()) } /// Invokes the `np_linalg_qr` linalg function pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_qr"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let x1_shape = x1.instance.get(generator, ctx, |f| f.shape); + let d0 = x1_shape.get_index_const(generator, ctx, 0); + let d1 = x1_shape.get_index_const(generator, ctx, 1); + let dk = + Int(SizeT).believe_value(llvm_intrinsics::call_int_smin(ctx, d0.value, d1.value, None)); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unimplemented!("{FN_NAME} operates on float type NdArrays only"); - }; + let q = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d0, dk]); + q.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + let r = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[dk, d1]); + r.create_data(generator, ctx); - let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let q_c = q.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let r_c = r.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_qr( + ctx, + x1_c.value.as_basic_value_enum(), + q_c.value.as_basic_value_enum(), + r_c.value.as_basic_value_enum(), + None, + ); - extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None); - - let out_ptr = build_output_struct(ctx, vec![out_q, out_r]); - - Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let q = q.to_any(ctx); + let r = r.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [q, r]); + Ok(tuple.value.as_basic_value_enum()) } /// Invokes the `np_linalg_svd` linalg function pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_svd"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let x1_shape = x1.instance.get(generator, ctx, |f| f.shape); + let d0 = x1_shape.get_index_const(generator, ctx, 0); + let d1 = x1_shape.get_index_const(generator, ctx, 1); + let dk = + Int(SizeT).believe_value(llvm_intrinsics::call_int_smin(ctx, d0.value, d1.value, None)); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let u = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d0, d0]); + u.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let s = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[dk]); + s.create_data(generator, ctx); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + let vh = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d1, d1]); + vh.create_data(generator, ctx); - let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let u_c = u.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let s_c = s.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let vh_c = vh.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_svd( + ctx, + x1_c.value.as_basic_value_enum(), + u_c.value.as_basic_value_enum(), + s_c.value.as_basic_value_enum(), + vh_c.value.as_basic_value_enum(), + None, + ); - extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None); - - let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]); - - Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let u = u.to_any(ctx); + let s = s.to_any(ctx); + let vh = vh.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [u, s, vh]); + Ok(tuple.value.as_basic_value_enum()) } /// Invokes the `np_linalg_inv` linalg function pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_inv"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let out = NDArrayObject::alloca(generator, ctx, x1.dtype, 2); + out.copy_shape_from_ndarray(generator, ctx, x1); + out.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_inv( + ctx, + x1_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_inv(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + Ok(out.instance.value.as_basic_value_enum()) } /// Invokes the `np_linalg_pinv` linalg function pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_pinv"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let x1_shape = x1.instance.get(generator, ctx, |f| f.shape); + let d0 = x1_shape.get_index_const(generator, ctx, 0); + let d1 = x1_shape.get_index_const(generator, ctx, 1); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let out = NDArrayObject::alloca_dynamic_shape(generator, ctx, x1.dtype, &[d1, d0]); + out.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_pinv( + ctx, + x1_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_pinv(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + Ok(out.instance.value.as_basic_value_enum()) } /// Invokes the `sp_linalg_lu` linalg function pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "sp_linalg_lu"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let x1_shape = x1.instance.get(generator, ctx, |f| f.shape); + let d0 = x1_shape.get_index_const(generator, ctx, 0); + let d1 = x1_shape.get_index_const(generator, ctx, 1); + let dk = + Int(SizeT).believe_value(llvm_intrinsics::call_int_smin(ctx, d0.value, d1.value, None)); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let l = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d0, dk]); + l.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let u = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[dk, d1]); + u.create_data(generator, ctx); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let l_c = l.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let u_c = u.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_sp_linalg_lu( + ctx, + x1_c.value.as_basic_value_enum(), + l_c.value.as_basic_value_enum(), + u_c.value.as_basic_value_enum(), + None, + ); - let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None); - - let out_ptr = build_output_struct(ctx, vec![out_l, out_u]); - Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let l = l.to_any(ctx); + let u = u.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [l, u]); + Ok(tuple.value.as_basic_value_enum()) } /// Invokes the `np_linalg_matrix_power` linalg function pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_matrix_power"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); + + // x2 is a float, but we are promoting this to a 1D ndarray (.shape == [1]) for uniformity in function call. let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap(); + let x2 = AnyObject { ty: ctx.primitives.float, value: x2 }; + let x2 = NDArrayObject::make_unsized(generator, ctx, x2); // x2.shape == [] + let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] - let llvm_usize = generator.get_size_type(ctx.ctx); - if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let out = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + out.copy_shape_from_ndarray(generator, ctx, x1); + out.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]); - }; + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let x2_c = x2.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_matrix_power( + ctx, + x1_c.value.as_basic_value_enum(), + x2_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - // Changing second parameter to a `NDArray` for uniformity in function call - let n2_array = numpy::create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[llvm_usize.const_int(1, false)], - ) - .unwrap(); - unsafe { - n2_array.data().set_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - n2.as_basic_value_enum(), - ); - }; - let n2_array = n2_array.as_base_value().as_basic_value_enum(); - - let outdim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let outdim1 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) - } + Ok(out.instance.value.as_basic_value_enum()) } /// Invokes the `np_linalg_det` linalg function pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "np_linalg_matrix_power"; - let (x1_ty, x1) = x1; + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); - let llvm_usize = generator.get_size_type(ctx.ctx); - if let BasicValueEnum::PointerValue(_) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. + let det = NDArrayObject::alloca_constant_shape(generator, ctx, ctx.primitives.float, &[1]); + det.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let out_c = det.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_np_linalg_det( + ctx, + x1_c.value.as_basic_value_enum(), + out_c.value.as_basic_value_enum(), + None, + ); - // Changing second parameter to a `NDArray` for uniformity in function call - let out = numpy::create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[llvm_usize.const_int(1, false)], - ) - .unwrap(); - extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None); - let res = - unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - Ok(res) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + // Get the determinant out of `out` + let zero = Int(SizeT).const_0(generator, ctx.ctx); + let det = det.get_nth_scalar(generator, ctx, zero); + Ok(det.value) } /// Invokes the `sp_linalg_schur` linalg function pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "sp_linalg_schur"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); + assert_eq!(x1.ndims, 2); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let t = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + t.copy_shape_from_ndarray(generator, ctx, x1); + t.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let z = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + z.copy_shape_from_ndarray(generator, ctx, x1); + z.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let t_c = t.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let z_c = z.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_sp_linalg_schur( + ctx, + x1_c.value.as_basic_value_enum(), + t_c.value.as_basic_value_enum(), + z_c.value.as_basic_value_enum(), + None, + ); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - - extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None); - - let out_ptr = build_output_struct(ctx, vec![out_t, out_z]); - Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let t = t.to_any(ctx); + let z = z.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [t, z]); + Ok(tuple.value.as_basic_value_enum()) } /// Invokes the `sp_linalg_hessenberg` linalg function pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - const FN_NAME: &str = "sp_linalg_hessenberg"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); + let x1 = AnyObject { ty: x1_ty, value: x1 }; + let x1 = NDArrayObject::from_object(generator, ctx, x1); + assert_eq!(x1.ndims, 2); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); + let h = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + h.copy_shape_from_ndarray(generator, ctx, x1); + h.create_data(generator, ctx); - let BasicTypeEnum::FloatType(_) = n1_elem_ty else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let q = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2); + q.copy_shape_from_ndarray(generator, ctx, x1); + q.create_data(generator, ctx); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); + let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64)); + let h_c = h.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + let q_c = q.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`. + extern_fns::call_sp_linalg_hessenberg( + ctx, + x1_c.value.as_basic_value_enum(), + h_c.value.as_basic_value_enum(), + q_c.value.as_basic_value_enum(), + None, + ); - let dim0 = unsafe { - n1.dim_sizes() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); - extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None); - - let out_ptr = build_output_struct(ctx, vec![out_h, out_q]); - Ok(ctx - .builder - .build_load(out_ptr, "Hessenberg_decomposition_result") - .map(Into::into) - .unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) - } + let h = h.to_any(ctx); + let q = q.to_any(ctx); + let tuple = TupleObject::from_objects(generator, ctx, [h, q]); + Ok(tuple.value.as_basic_value_enum()) } diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 52e9cca0..42676404 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1,9 +1,4 @@ -use crate::codegen::{ - irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, - llvm_intrinsics::call_int_umin, - stmt::gen_for_callback_incrementing, - CodeGenContext, CodeGenerator, -}; +use crate::codegen::{CodeGenContext, CodeGenerator}; use inkwell::context::Context; use inkwell::types::{ArrayType, BasicType, StructType}; use inkwell::values::{ArrayValue, BasicValue, StructValue}; @@ -1141,624 +1136,3 @@ impl<'ctx> From> for PointerValue<'ctx> { value.as_base_value() } } - -/// Proxy type for a `ndarray` type in LLVM. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct NDArrayType<'ctx> { - ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, -} - -impl<'ctx> NDArrayType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - let llvm_ndarray_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); - }; - if llvm_ndarray_ty.count_fields() != 3 { - return Err(format!( - "Expected 3 fields in `NDArray`, got {}", - llvm_ndarray_ty.count_fields() - )); - } - - let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); - let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { - return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")); - }; - if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected {}-bit int type for `ndarray.0`, got {}-bit int", - llvm_usize.get_bit_width(), - ndarray_ndims_ty.get_bit_width() - )); - } - - let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); - let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else { - return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}")); - }; - let ndarray_dims = ndarray_pdims.get_element_type(); - let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else { - return Err(format!( - "Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}" - )); - }; - if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int", - llvm_usize.get_bit_width(), - ndarray_dims.get_bit_width() - )); - } - - let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); - let Ok(_) = PointerType::try_from(ndarray_data_ty) else { - return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")); - }; - - Ok(()) - } - - /// Creates an instance of [`ListType`]. - #[must_use] - pub fn new( - generator: &G, - ctx: &'ctx Context, - dtype: BasicTypeEnum<'ctx>, - ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - - // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } - // - // * num_dims: Number of dimensions in the array - // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data - let llvm_ndarray = ctx - .struct_type( - &[ - llvm_usize.into(), - llvm_usize.ptr_type(AddressSpace::default()).into(), - dtype.ptr_type(AddressSpace::default()).into(), - ], - false, - ) - .ptr_type(AddressSpace::default()); - - NDArrayType::from_type(llvm_ndarray, llvm_usize) - } - - /// Creates an [`NDArrayType`] from a [`PointerType`]. - #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok()); - - NDArrayType { ty: ptr_ty, llvm_usize } - } - - /// Returns the type of the `size` field of this `ndarray` type. - #[must_use] - pub fn size_type(&self) -> IntType<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(0) - .map(BasicTypeEnum::into_int_type) - .unwrap() - } - - /// Returns the element type of this `ndarray` type. - #[must_use] - pub fn element_type(&self) -> BasicTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(2) - .unwrap() - } -} - -impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { - type Base = PointerType<'ctx>; - type Underlying = StructType<'ctx>; - type Value = NDArrayValue<'ctx>; - - fn new_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Value { - self.create_value( - generator.gen_var_alloc(ctx, self.as_underlying_type().into(), name).unwrap(), - name, - ) - } - - fn create_value( - &self, - value: >::Base, - name: Option<&'ctx str>, - ) -> Self::Value { - debug_assert_eq!(value.get_type(), self.as_base_type()); - - NDArrayValue { value, llvm_usize: self.llvm_usize, name } - } - - fn as_base_type(&self) -> Self::Base { - self.ty - } - - fn as_underlying_type(&self) -> Self::Underlying { - self.as_base_type().get_element_type().into_struct_type() - } -} - -impl<'ctx> From> for PointerType<'ctx> { - fn from(value: NDArrayType<'ctx>) -> Self { - value.as_base_type() - } -} - -/// Proxy type for accessing an `NDArray` value in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayValue<'ctx> { - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - name: Option<&'ctx str>, -} - -impl<'ctx> NDArrayValue<'ctx> { - /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an - /// instance. - pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - NDArrayType::is_type(value.get_type(), llvm_usize) - } - - /// Creates an [`NDArrayValue`] from a [`PointerValue`]. - #[must_use] - pub fn from_ptr_val( - ptr: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - name: Option<&'ctx str>, - ) -> Self { - debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); - - >::Type::from_type(ptr.get_type(), llvm_usize) - .create_value(ptr, name) - } - - /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. - fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the number of dimensions `ndims` into this instance. - pub fn store_ndims( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ndims: IntValue<'ctx>, - ) { - debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); - - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_store(pndims, ndims).unwrap(); - } - - /// Returns the number of dimensions of this `NDArray` as a value. - pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() - } - - /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` - /// on the field. - fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the array of dimension sizes `dims` into this instance. - fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap(); - } - - /// Convenience method for creating a new array storing dimension sizes with the given `size`. - pub fn create_dim_sizes( - &self, - ctx: &CodeGenContext<'ctx, '_>, - llvm_usize: IntType<'ctx>, - size: IntValue<'ctx>, - ) { - self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); - } - - /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. - #[must_use] - pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> { - NDArrayDimsProxy(self) - } - - /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` - /// on the field. - fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); - } - - /// Convenience method for creating a new array storing data elements with the given element - /// type `elem_ty` and `size`. - pub fn create_data( - &self, - ctx: &CodeGenContext<'ctx, '_>, - elem_ty: BasicTypeEnum<'ctx>, - size: IntValue<'ctx>, - ) { - self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "").unwrap()); - } - - /// Returns a proxy object to the field storing the data of this `NDArray`. - #[must_use] - pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { - NDArrayDataProxy(self) - } -} - -impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { - type Base = PointerValue<'ctx>; - type Underlying = StructValue<'ctx>; - type Type = NDArrayType<'ctx>; - - fn get_type(&self) -> Self::Type { - NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize) - } - - fn as_base_value(&self) -> Self::Base { - self.value - } - - fn as_underlying_value( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Underlying { - ctx.builder - .build_load(self.as_base_value(), name.unwrap_or_default()) - .map(BasicValueEnum::into_struct_value) - .unwrap() - } -} - -impl<'ctx> From> for PointerValue<'ctx> { - fn from(value: NDArrayValue<'ctx>) -> Self { - value.as_base_value() - } -} - -/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); - -impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { - fn element_type( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> AnyTypeEnum<'ctx> { - self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type() - } - - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() - } - - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> IntValue<'ctx> { - self.0.load_ndims(ctx) - } -} - -impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let size = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); - ctx.make_assert( - generator, - in_range, - "0:IndexError", - "index {0} is out of bounds for axis 0 with size {1}", - [Some(*idx), Some(self.0.load_ndims(ctx)), None], - ctx.current_loc, - ); - - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } - } -} - -impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} - -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { - fn downcast_to_type( - &self, - _: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, - ) -> IntValue<'ctx> { - value.into_int_value() - } -} - -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { - fn upcast_from_type( - &self, - _: &mut CodeGenContext<'ctx, '_>, - value: IntValue<'ctx>, - ) -> BasicValueEnum<'ctx> { - value.into() - } -} - -/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); - -impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { - fn element_type( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> AnyTypeEnum<'ctx> { - self.0.data().base_ptr(ctx, generator).get_type().get_element_type() - } - - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.ptr_to_data(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() - } - - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> IntValue<'ctx> { - call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None)) - } -} - -impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[*idx], - name.unwrap_or_default(), - ) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let data_sz = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap(); - ctx.make_assert( - generator, - in_range, - "0:IndexError", - "index {0} is out of bounds with size {1}", - [Some(*idx), Some(self.0.load_ndims(ctx)), None], - ctx.current_loc, - ); - - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } - } -} - -impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} - -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> - for NDArrayDataProxy<'ctx, '_> -{ - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - indices: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let indices_elem_ty = indices - .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None) - .get_type() - .get_element_type(); - let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { - panic!("Expected list[int32] but got {indices_elem_ty}") - }; - assert_eq!( - indices_elem_ty.get_bit_width(), - 32, - "Expected list[int32] but got list[int{}]", - indices_elem_ty.get_bit_width() - ); - - let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[index], - name.unwrap_or_default(), - ) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - indices: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let indices_size = indices.size(ctx, generator); - let nidx_leq_ndims = ctx - .builder - .build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "") - .unwrap(); - ctx.make_assert( - generator, - nidx_leq_ndims, - "0:IndexError", - "invalid index to scalar variable", - [None, None, None], - ctx.current_loc, - ); - - let indices_len = indices.size(ctx, generator); - let ndarray_len = self.0.load_ndims(ctx); - let len = call_int_umin(ctx, indices_len, ndarray_len, None); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let (dim_idx, dim_sz) = unsafe { - ( - indices.get_unchecked(ctx, generator, &i, None).into_int_value(), - self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None), - ) - }; - let dim_idx = ctx - .builder - .build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "") - .unwrap(); - - let dim_lt = - ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap(); - - ctx.make_assert( - generator, - dim_lt, - "0:IndexError", - "index {0} is out of bounds for axis 0 with size {1}", - [Some(dim_idx), Some(dim_sz), None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) } - } -} - -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index> - for NDArrayDataProxy<'ctx, '_> -{ -} -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index> - for NDArrayDataProxy<'ctx, '_> -{ -} diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 50d70eb7..2c8e3b52 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1,8 +1,8 @@ use crate::{ codegen::{ classes::{ - ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType, - ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, + ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, ProxyType, ProxyValue, + RangeValue, UntypedArrayLikeAccessor, }, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name, @@ -11,7 +11,8 @@ use crate::{ call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, call_int_umin, call_memcpy_generic, }, - need_sret, numpy, + need_sret, + object::ndarray::{NDArrayOut, ScalarOrNDArray}, stmt::{ gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_var, @@ -19,11 +20,7 @@ use crate::{ CodeGenContext, CodeGenTask, CodeGenerator, }, symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{ - helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - DefinitionId, TopLevelDef, - }, + toplevel::{helper::PrimDef, DefinitionId, TopLevelDef}, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -32,7 +29,10 @@ use crate::{ use inkwell::{ attributes::{Attribute, AttributeLoc}, types::{AnyType, BasicType, BasicTypeEnum}, - values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, StructValue}, + values::{ + BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, + StructValue, + }, AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::{chain, izip, Either, Itertools}; @@ -44,6 +44,14 @@ use std::cmp::min; use std::iter::{repeat, repeat_with}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; +use super::{ + model::*, + object::{ + any::AnyObject, + ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject}, + }, +}; + pub fn get_subst_key( unifier: &mut Unifier, obj: Option, @@ -1540,99 +1548,75 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); + let left = + ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty1, value: left_val }); + let right = + ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty2, value: right_val }); - let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + // Inhomogeneous binary operations are not supported. + assert!(ctx.unifier.unioned(left.get_dtype(), right.get_dtype())); - if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); + let common_dtype = left.get_dtype(); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let out = match op.variant { + BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: common_dtype }, + BinopVariant::AugAssign => { + // If this is an augmented assignment. + // `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it. + if let ScalarOrNDArray::NDArray(out_ndarray) = left { + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } + } else { + panic!("left must be an ndarray") + } + } + }; - let left_val = - NDArrayValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None); - let right_val = - NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); - - let res = if op.base == Operator::MatMult { - // MatMult is the only binop which is not an elementwise op - numpy::ndarray_matmul_2d( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - left_val, - right_val, - )? - } else { - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - (left_val.as_base_value().into(), false), - (right_val.as_base_value().into(), false), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype1), lhs), - op, - (&Some(ndarray_dtype2), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ndarray_dtype1, - ) - }, - )? - }; - - Ok(Some(res.as_base_value().into())) + if op.base == Operator::MatMult { + // Handle matrix multiplication. + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + let result = NDArrayObject::matmul(generator, ctx, left, right, out) + .split_unsized(generator, ctx); + Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum()))) } else { - let (ndarray_dtype, _) = - unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); - let ndarray_val = NDArrayValue::from_ptr_val( - if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), - llvm_usize, - None, - ); - let res = numpy::ndarray_elementwise_binop_impl( + // For other operations, they are all elementwise operations. + + // There are only three cases: + // - LHS is a scalar, RHS is an ndarray. + // - LHS is an ndarray, RHS is a scalar. + // - LHS is an ndarray, RHS is an ndarray. + // + // For all cases, the scalar operand is promoted to an ndarray, + // the two are then broadcasted, and starmapped through. + + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + + let result = NDArrayObject::broadcast_starmap( generator, ctx, - ndarray_dtype, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(ndarray_val), - }, - (left_val, !is_ndarray1), - (right_val, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( + &[left, right], + out, + |generator, ctx, scalars| { + let left_value = scalars[0]; + let right_value = scalars[1]; + + let result = gen_binop_expr_with_values( generator, ctx, - (&Some(ndarray_dtype), lhs), + (&Some(left.dtype), left_value), op, - (&Some(ndarray_dtype), rhs), + (&Some(right.dtype), right_value), ctx.current_loc, )? .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) - }, - )?; + .to_basic_value_enum(ctx, generator, common_dtype)?; - Ok(Some(res.as_base_value().into())) + Ok(result) + }, + ) + .unwrap(); + Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum()))) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); @@ -1790,14 +1774,12 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( _ => val.into(), } } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); - let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - - let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None); + let ndarray = AnyObject { value: val, ty }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function - let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) { + let op = if ndarray.dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) { if op == ast::Unaryop::Invert { ast::Unaryop::Not } else { @@ -1810,20 +1792,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( op }; - let res = numpy::ndarray_elementwise_unaryop_impl( + let mapped_ndarray = ndarray.map( generator, ctx, - ndarray_dtype, - None, - val, - |generator, ctx, val| { - gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))? + NDArrayOut::NewNDArray { dtype: ndarray.dtype }, + |generator, ctx, scalar| { + gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray.dtype), scalar))? .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) + .to_basic_value_enum(ctx, generator, ndarray.dtype) }, )?; - res.as_base_value().into() + ValueEnum::Dynamic(mapped_ndarray.instance.value.as_basic_value_enum()) } else { unimplemented!() })) @@ -1866,85 +1846,46 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (Some(left_ty), lhs) = left else { unreachable!() }; - let (Some(right_ty), rhs) = comparators[0] else { unreachable!() }; + let (Some(left_ty), left) = left else { unreachable!() }; + let (Some(right_ty), right) = comparators[0] else { unreachable!() }; let op = ops[0]; - let is_ndarray1 = - left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = - right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left = AnyObject { value: left, ty: left_ty }; + let left = + ScalarOrNDArray::split_object(generator, ctx, left).to_ndarray(generator, ctx); - return if is_ndarray1 && is_ndarray2 { - let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty); + let right = AnyObject { value: right, ty: right_ty }; + let right = + ScalarOrNDArray::split_object(generator, ctx, right).to_ndarray(generator, ctx); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let result_ndarray = NDArrayObject::broadcast_starmap( + generator, + ctx, + &[left, right], + NDArrayOut::NewNDArray { dtype: ctx.primitives.bool }, + |generator, ctx, scalars| { + let left_scalar = scalars[0]; + let right_scalar = scalars[1]; - let left_val = - NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (left_val.as_base_value().into(), false), - (rhs, false), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype1), lhs), - &[op], - &[(Some(ndarray_dtype2), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; + let val = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left.dtype), left_scalar), + &[op], + &[(Some(right.dtype), right_scalar)], + )? + .unwrap() + .to_basic_value_enum( + ctx, + generator, + ctx.primitives.bool, + )?; - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; + Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) + }, + )?; - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = unpack_ndarray_var_tys( - &mut ctx.unifier, - if is_ndarray1 { left_ty } else { right_ty }, - ); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ctx.primitives.bool, - None, - (lhs, !is_ndarray1), - (rhs, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - let val = gen_cmpop_expr_with_values( - generator, - ctx, - (Some(ndarray_dtype), lhs), - &[op], - &[(Some(ndarray_dtype), rhs)], - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ctx.primitives.bool, - )?; - - Ok(generator.bool_to_i8(ctx, val.into_int_value()).into()) - }, - )?; - - Ok(Some(res.as_base_value().into())) - }; + return Ok(Some(result_ndarray.instance.value.into())); } } @@ -2492,338 +2433,6 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( ) } -/// Generates code for a subscript expression on an `ndarray`. -/// -/// * `ty` - The `Type` of the `NDArray` elements. -/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. -/// * `v` - The `NDArray` value. -/// * `slice` - The slice expression used to subscript into the `ndarray`. -fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ty: Type, - ndims: Type, - v: NDArrayValue<'ctx>, - slice: &Expr>, -) -> Result>, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { - unreachable!() - }; - - let ndims = values - .iter() - .map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone())) - .collect::, _>>() - .map_err(|val| { - format!( - "Expected non-negative literal for ndarray.ndims, got {}", - i128::try_from(val).unwrap() - ) - })?; - - assert!(!ndims.is_empty()); - - // The number of dimensions subscripted by the index expression. - // Slicing a ndarray will yield the same number of dimensions, whereas indexing into a - // dimension will remove a dimension. - let subscripted_dims = match &slice.node { - ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| { - if let ExprKind::Slice { .. } = &value_subexpr.node { - acc - } else { - acc + 1 - } - }), - - ExprKind::Slice { .. } => 0, - _ => 1, - }; - - let ndarray_ndims_ty = ctx.unifier.get_fresh_literal( - ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(), - None, - ); - let ndarray_ty = - make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty)); - let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); - let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); - let sizeof_elem = llvm_ndarray_data_t.size_of().unwrap(); - - // Check that len is non-zero - let len = v.load_ndims(ctx); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), "").unwrap(), - "0:IndexError", - "too many indices for array: array is {0}-dimensional but 1 were indexed", - [Some(len), None, None], - slice.location, - ); - - // Normalizes a possibly-negative index to its corresponding positive index - let normalize_index = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - dim: u64| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::SGE, index, index.get_type().const_zero(), "") - .unwrap()) - }, - |_, _| Ok(Some(index)), - |generator, ctx| { - let llvm_i32 = ctx.ctx.i32_type(); - - let len = unsafe { - v.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, true), - None, - ) - }; - - let index = ctx - .builder - .build_int_add( - len, - ctx.builder.build_int_s_extend(index, llvm_usize, "").unwrap(), - "", - ) - .unwrap(); - - Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap())) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value)) - }; - - // Converts a slice expression into a slice-range tuple - let expr_to_slice = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - node: &ExprKind>, - dim: u64| { - match node { - ExprKind::Constant { value: Constant::Int(v), .. } => { - let Some(index) = - normalize_index(generator, ctx, llvm_i32.const_int(*v as u64, true), dim)? - else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - - ExprKind::Slice { lower, upper, step } => { - let dim_sz = unsafe { - v.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(dim, false), - None, - ) - }; - - handle_slice_indices(lower, upper, step, ctx, generator, dim_sz) - } - - _ => { - let Some(index) = generator.gen_expr(ctx, slice)? else { return Ok(None) }; - let index = index - .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, dim)? else { - return Ok(None); - }; - - Ok(Some((index, index, llvm_i32.const_int(1, true)))) - } - } - }; - - let make_indices_arr = |generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>| - -> Result<_, String> { - Ok(if let ExprKind::Tuple { elts, .. } = &slice.node { - let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(elts.len() as u64, false), - None, - )?; - - for (i, elt) in elts.iter().enumerate() { - let Some(index) = generator.gen_expr(ctx, elt)? else { - return Ok(None); - }; - - let index = index - .to_basic_value_enum(ctx, generator, elt.custom.unwrap())? - .into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { - return Ok(None); - }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - None, - ) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - } - - Some(index_addr) - } else if let Some(index) = generator.gen_expr(ctx, slice)? { - let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap()); - let index_addr = generator.gen_array_var_alloc( - ctx, - llvm_int_ty, - llvm_usize.const_int(1u64, false), - None, - )?; - - let index = - index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value(); - let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) }; - - let store_ptr = unsafe { - index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - ctx.builder.build_store(store_ptr, index).unwrap(); - - Some(index_addr) - } else { - None - }) - }; - - Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 { - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - v.data().get(ctx, generator, &index_addr, None).into() - } else { - match &slice.node { - ExprKind::Tuple { elts, .. } => { - let slices = elts - .iter() - .enumerate() - .map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64)) - .take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some)) - .collect::, _>>()?; - if slices.len() < elts.len() { - return Ok(None); - } - - let slices = slices.into_iter().map(Option::unwrap).collect_vec(); - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into() - } - - ExprKind::Slice { .. } => { - let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else { - return Ok(None); - }; - - numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into() - } - - _ => { - // Accessing an element from a multi-dimensional `ndarray` - - let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) }; - - // Create a new array, remove the top dimension from the dimension-size-list, and copy the - // elements over - let subscripted_ndarray = - generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None); - - let num_dims = v.load_ndims(ctx); - ndarray.store_ndims( - ctx, - generator, - ctx.builder - .build_int_sub(num_dims, llvm_usize.const_int(1, false), "") - .unwrap(), - ); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - - let ndarray_num_dims = ctx - .builder - .build_int_z_extend_or_bit_cast( - ndarray.load_ndims(ctx), - llvm_usize.size_of().get_type(), - "", - ) - .unwrap(); - let v_dims_src_ptr = unsafe { - v.dim_sizes().ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - call_memcpy_generic( - ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), - v_dims_src_ptr, - ctx.builder - .build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "") - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - let ndarray_num_elems = ctx - .builder - .build_int_z_extend_or_bit_cast(ndarray_num_elems, sizeof_elem.get_type(), "") - .unwrap(); - ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); - - let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None); - call_memcpy_generic( - ctx, - ndarray.data().base_ptr(ctx, generator), - v_data_src_ptr, - ctx.builder - .build_int_mul( - ndarray_num_elems, - llvm_ndarray_data_t.size_of().unwrap(), - "", - ) - .map(Into::into) - .unwrap(), - llvm_i1.const_zero(), - ); - - ndarray.as_base_value().into() - } - } - })) -} - /// See [`CodeGenerator::gen_expr`]. pub fn gen_expr<'ctx, G: CodeGenerator>( generator: &mut G, @@ -3463,18 +3072,26 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v.data().get(ctx, generator, &index, None).into() } } - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { - let (ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); - - let v = if let Some(v) = generator.gen_expr(ctx, value)? { - v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? - .into_pointer_value() - } else { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let Some(ndarray) = generator.gen_expr(ctx, value)? else { return Ok(None); }; - let v = NDArrayValue::from_ptr_val(v, usize, None); - return gen_ndarray_subscript_expr(generator, ctx, *ty, *ndims, v, slice); + let ndarray_ty = value.custom.unwrap(); + let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = NDArrayObject::from_object( + generator, + ctx, + AnyObject { ty: ndarray_ty, value: ndarray }, + ); + + let indices = gen_ndarray_subscript_ndindices(generator, ctx, slice)?; + let result = ndarray + .index(generator, ctx, &indices) + .split_unsized(generator, ctx) + .to_basic_value_enum(); + return Ok(Some(ValueEnum::Dynamic(result))); } TypeEnum::TTuple { .. } => { let index: u32 = @@ -3517,3 +3134,42 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( _ => unimplemented!(), })) } + +/// Generate LLVM IR for an [`ExprKind::Slice`] +#[allow(clippy::type_complexity)] +pub fn gen_slice<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lower: &Option>>>, + upper: &Option>>>, + step: &Option>>>, +) -> Result< + ( + Option>>, + Option>>, + Option>>, + ), + String, +> { + let mut help = |value_expr: &Option>>>| -> Result<_, String> { + Ok(match value_expr { + None => None, + Some(value_expr) => { + let value_expr = generator + .gen_expr(ctx, value_expr)? + .unwrap() + .to_basic_value_enum(ctx, generator, ctx.primitives.int32)?; + + let value_expr = Int(Int32).check_value(generator, ctx.ctx, value_expr).unwrap(); + + Some(value_expr) + } + }) + }; + + let lower = help(lower)?; + let upper = help(upper)?; + let step = help(step)?; + + Ok((lower, upper, step)) +} diff --git a/nac3core/src/codegen/irrt/irrt.cpp b/nac3core/src/codegen/irrt/irrt.cpp deleted file mode 100644 index 6032518d..00000000 --- a/nac3core/src/codegen/irrt/irrt.cpp +++ /dev/null @@ -1,414 +0,0 @@ -using int8_t = _BitInt(8); -using uint8_t = unsigned _BitInt(8); -using int32_t = _BitInt(32); -using uint32_t = unsigned _BitInt(32); -using int64_t = _BitInt(64); -using uint64_t = unsigned _BitInt(64); - -// NDArray indices are always `uint32_t`. -using NDIndex = uint32_t; -// The type of an index or a value describing the length of a range/slice is always `int32_t`. -using SliceIndex = int32_t; - -namespace { -template -const T& max(const T& a, const T& b) { - return a > b ? a : b; -} - -template -const T& min(const T& a, const T& b) { - return a > b ? b : a; -} - -// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c -// need to make sure `exp >= 0` before calling this function -template -T __nac3_int_exp_impl(T base, T exp) { - T res = 1; - /* repeated squaring method */ - do { - if (exp & 1) { - res *= base; /* for n odd */ - } - exp >>= 1; - base *= base; - } while (exp); - return res; -} - -template -SizeT __nac3_ndarray_calc_size_impl( - const SizeT* list_data, - SizeT list_len, - SizeT begin_idx, - SizeT end_idx -) { - __builtin_assume(end_idx <= list_len); - - SizeT num_elems = 1; - for (SizeT i = begin_idx; i < end_idx; ++i) { - SizeT val = list_data[i]; - __builtin_assume(val > 0); - num_elems *= val; - } - return num_elems; -} - -template -void __nac3_ndarray_calc_nd_indices_impl( - SizeT index, - const SizeT* dims, - SizeT num_dims, - NDIndex* idxs -) { - SizeT stride = 1; - for (SizeT dim = 0; dim < num_dims; dim++) { - SizeT i = num_dims - dim - 1; - __builtin_assume(dims[i] > 0); - idxs[i] = (index / stride) % dims[i]; - stride *= dims[i]; - } -} - -template -SizeT __nac3_ndarray_flatten_index_impl( - const SizeT* dims, - SizeT num_dims, - const NDIndex* indices, - SizeT num_indices -) { - SizeT idx = 0; - SizeT stride = 1; - for (SizeT i = 0; i < num_dims; ++i) { - SizeT ri = num_dims - i - 1; - if (ri < num_indices) { - idx += stride * indices[ri]; - } - - __builtin_assume(dims[i] > 0); - stride *= dims[ri]; - } - return idx; -} - -template -void __nac3_ndarray_calc_broadcast_impl( - const SizeT* lhs_dims, - SizeT lhs_ndims, - const SizeT* rhs_dims, - SizeT rhs_ndims, - SizeT* out_dims -) { - SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; - - for (SizeT i = 0; i < max_ndims; ++i) { - const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr; - const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr; - SizeT* out_dim = &out_dims[max_ndims - i - 1]; - - if (lhs_dim_sz == nullptr) { - *out_dim = *rhs_dim_sz; - } else if (rhs_dim_sz == nullptr) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == 1) { - *out_dim = *rhs_dim_sz; - } else if (*rhs_dim_sz == 1) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == *rhs_dim_sz) { - *out_dim = *lhs_dim_sz; - } else { - __builtin_unreachable(); - } - } -} - -template -void __nac3_ndarray_calc_broadcast_idx_impl( - const SizeT* src_dims, - SizeT src_ndims, - const NDIndex* in_idx, - NDIndex* out_idx -) { - for (SizeT i = 0; i < src_ndims; ++i) { - SizeT src_i = src_ndims - i - 1; - out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; - } -} -} // namespace - -extern "C" { -#define DEF_nac3_int_exp_(T) \ - T __nac3_int_exp_##T(T base, T exp) {\ - return __nac3_int_exp_impl(base, exp);\ - } - -DEF_nac3_int_exp_(int32_t) -DEF_nac3_int_exp_(int64_t) -DEF_nac3_int_exp_(uint32_t) -DEF_nac3_int_exp_(uint64_t) - -SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) { - if (i < 0) { - i = len + i; - } - if (i < 0) { - return 0; - } else if (i > len) { - return len; - } - return i; -} - -SliceIndex __nac3_range_slice_len( - const SliceIndex start, - const SliceIndex end, - const SliceIndex step -) { - SliceIndex diff = end - start; - if (diff > 0 && step > 0) { - return ((diff - 1) / step) + 1; - } else if (diff < 0 && step < 0) { - return ((diff + 1) / step) + 1; - } else { - return 0; - } -} - -// Handle list assignment and dropping part of the list when -// both dest_step and src_step are +1. -// - All the index must *not* be out-of-bound or negative, -// - The end index is *inclusive*, -// - The length of src and dest slice size should already -// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest) -SliceIndex __nac3_list_slice_assign_var_size( - SliceIndex dest_start, - SliceIndex dest_end, - SliceIndex dest_step, - uint8_t* dest_arr, - SliceIndex dest_arr_len, - SliceIndex src_start, - SliceIndex src_end, - SliceIndex src_step, - uint8_t* src_arr, - SliceIndex src_arr_len, - const SliceIndex size -) { - /* if dest_arr_len == 0, do nothing since we do not support extending list */ - if (dest_arr_len == 0) return dest_arr_len; - /* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */ - if (src_step == dest_step && dest_step == 1) { - const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0; - const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0; - if (src_len > 0) { - __builtin_memmove( - dest_arr + dest_start * size, - src_arr + src_start * size, - src_len * size - ); - } - if (dest_len > 0) { - /* dropping */ - __builtin_memmove( - dest_arr + (dest_start + src_len) * size, - dest_arr + (dest_end + 1) * size, - (dest_arr_len - dest_end - 1) * size - ); - } - /* shrink size */ - return dest_arr_len - (dest_len - src_len); - } - /* if two range overlaps, need alloca */ - uint8_t need_alloca = - (dest_arr == src_arr) - && !( - max(dest_start, dest_end) < min(src_start, src_end) - || max(src_start, src_end) < min(dest_start, dest_end) - ); - if (need_alloca) { - uint8_t* tmp = reinterpret_cast(__builtin_alloca(src_arr_len * size)); - __builtin_memcpy(tmp, src_arr, src_arr_len * size); - src_arr = tmp; - } - SliceIndex src_ind = src_start; - SliceIndex dest_ind = dest_start; - for (; - (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); - src_ind += src_step, dest_ind += dest_step - ) { - /* for constant optimization */ - if (size == 1) { - __builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1); - } else if (size == 4) { - __builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4); - } else if (size == 8) { - __builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8); - } else { - /* memcpy for var size, cannot overlap after previous alloca */ - __builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size); - } - } - /* only dest_step == 1 can we shrink the dest list. */ - /* size should be ensured prior to calling this function */ - if (dest_step == 1 && dest_end >= dest_start) { - __builtin_memmove( - dest_arr + dest_ind * size, - dest_arr + (dest_end + 1) * size, - (dest_arr_len - dest_end - 1) * size - ); - return dest_arr_len - (dest_end - dest_ind) - 1; - } - return dest_arr_len; -} - -int32_t __nac3_isinf(double x) { - return __builtin_isinf(x); -} - -int32_t __nac3_isnan(double x) { - return __builtin_isnan(x); -} - -double tgamma(double arg); - -double __nac3_gamma(double z) { - // Handling for denormals - // | x | Python gamma(x) | C tgamma(x) | - // --- | ----------------- | --------------- | ----------- | - // (1) | nan | nan | nan | - // (2) | -inf | -inf | inf | - // (3) | inf | inf | inf | - // (4) | 0.0 | inf | inf | - // (5) | {-1.0, -2.0, ...} | inf | nan | - - // (1)-(3) - if (__builtin_isinf(z) || __builtin_isnan(z)) { - return z; - } - - double v = tgamma(z); - - // (4)-(5) - return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v; -} - -double lgamma(double arg); - -double __nac3_gammaln(double x) { - // libm's handling of value overflows differs from scipy: - // - scipy: gammaln(-inf) -> -inf - // - libm : lgamma(-inf) -> inf - - if (__builtin_isinf(x)) { - return x; - } - - return lgamma(x); -} - -double j0(double x); - -double __nac3_j0(double x) { - // libm's handling of value overflows differs from scipy: - // - scipy: j0(inf) -> nan - // - libm : j0(inf) -> 0.0 - - if (__builtin_isinf(x)) { - return __builtin_nan(""); - } - - return j0(x); -} - -uint32_t __nac3_ndarray_calc_size( - const uint32_t* list_data, - uint32_t list_len, - uint32_t begin_idx, - uint32_t end_idx -) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -uint64_t __nac3_ndarray_calc_size64( - const uint64_t* list_data, - uint64_t list_len, - uint64_t begin_idx, - uint64_t end_idx -) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -void __nac3_ndarray_calc_nd_indices( - uint32_t index, - const uint32_t* dims, - uint32_t num_dims, - NDIndex* idxs -) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} - -void __nac3_ndarray_calc_nd_indices64( - uint64_t index, - const uint64_t* dims, - uint64_t num_dims, - NDIndex* idxs -) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} - -uint32_t __nac3_ndarray_flatten_index( - const uint32_t* dims, - uint32_t num_dims, - const NDIndex* indices, - uint32_t num_indices -) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - -uint64_t __nac3_ndarray_flatten_index64( - const uint64_t* dims, - uint64_t num_dims, - const NDIndex* indices, - uint64_t num_indices -) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - -void __nac3_ndarray_calc_broadcast( - const uint32_t* lhs_dims, - uint32_t lhs_ndims, - const uint32_t* rhs_dims, - uint32_t rhs_ndims, - uint32_t* out_dims -) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast64( - const uint64_t* lhs_dims, - uint64_t lhs_ndims, - const uint64_t* rhs_dims, - uint64_t rhs_ndims, - uint64_t* out_dims -) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast_idx( - const uint32_t* src_dims, - uint32_t src_ndims, - const NDIndex* in_idx, - NDIndex* out_idx -) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} - -void __nac3_ndarray_calc_broadcast_idx64( - const uint64_t* src_dims, - uint64_t src_ndims, - const NDIndex* in_idx, - NDIndex* out_idx -) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} -} // extern "C" \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 91e62e94..ba8d78db 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,21 +1,22 @@ -use crate::typecheck::typedef::Type; +use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type}; use super::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, - TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + classes::{ArrayLikeValue, ListValue}, + model::*, + object::{ + list::List, + ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray}, }, - llvm_intrinsics, CodeGenContext, CodeGenerator, + CodeGenContext, CodeGenerator, }; -use crate::codegen::classes::TypedArrayLikeAccessor; -use crate::codegen::stmt::gen_for_callback_incrementing; +use function::CallFunction; use inkwell::{ attributes::{Attribute, AttributeLoc}, context::Context, memory_buffer::MemoryBuffer, module::Module, - types::{BasicTypeEnum, IntType}, - values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, + types::BasicTypeEnum, + values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue}, AddressSpace, IntPredicate, }; use itertools::Either; @@ -563,369 +564,324 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo .unwrap() } -/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the -/// calculated total size. -/// -/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension -/// respectively. -pub fn call_ndarray_calc_size<'ctx, G, Dims>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - dims: &Dims, - (begin, end): (Option>, Option>), -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); +/// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`]. +pub fn setup_irrt_exceptions<'ctx>( + ctx: &'ctx Context, + module: &Module<'ctx>, + symbol_resolver: &dyn SymbolResolver, +) { + let exn_id_type = ctx.i32_type(); - let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_size", - 64 => "__nac3_ndarray_calc_size64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], - false, - ); - let ndarray_calc_size_fn = - ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) + let errors = &[ + ("EXN_INDEX_ERROR", "0:IndexError"), + ("EXN_VALUE_ERROR", "0:ValueError"), + ("EXN_ASSERTION_ERROR", "0:AssertionError"), + ("EXN_TYPE_ERROR", "0:TypeError"), + ]; + + for (irrt_name, symbol_name) in errors { + let exn_id = symbol_resolver.get_string_id(symbol_name); + let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum(); + + let global = module.get_global(irrt_name).unwrap_or_else(|| { + panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module") }); - - let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); - let end = end.unwrap_or_else(|| dims.size(ctx, generator)); - ctx.builder - .build_call( - ndarray_calc_size_fn, - &[ - dims.base_ptr(ctx, generator).into(), - dims.size(ctx, generator).into(), - begin.into(), - end.into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() + global.set_initializer(&exn_id); + } } -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] -/// containing `i32` indices of the flattened index. -/// -/// * `index` - The index to compute the multidimensional index for. -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_nd_indices", - 64 => "__nac3_ndarray_calc_nd_indices64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_nd_indices_fn = - ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); - - ctx.builder - .build_call( - ndarray_calc_nd_indices_fn, - &[ - index.into(), - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) +// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}". +// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64". +#[must_use] +pub fn get_sizet_dependent_function_name( + generator: &mut G, + ctx: &CodeGenContext<'_, '_>, + name: &str, +) -> String { + let mut name = name.to_owned(); + match generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => {} + 64 => name.push_str("64"), + bit_width => { + panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits") + } + } + name } -fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Indices, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, -{ - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - debug_assert_eq!( - IntType::try_from(indices.element_type(ctx, generator)) - .map(IntType::get_bit_width) - .unwrap_or_default(), - llvm_i32.get_bit_width(), - "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" - ); - debug_assert_eq!( - indices.size(ctx, generator).get_type().get_bit_width(), - llvm_usize.get_bit_width(), - "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" - ); - - let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_flatten_index", - 64 => "__nac3_ndarray_flatten_index64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_flatten_index_fn = - ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], - false, - ); - - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let index = ctx - .builder - .build_call( - ndarray_flatten_index_fn, - &[ - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.base_ptr(ctx, generator).into(), - indices.size(ctx, generator).into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - index -} - -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the -/// multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( +pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. -pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast", - 64 => "__nac3_ndarray_calc_broadcast64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); - - gen_for_callback_incrementing( + ndims: Instance<'ctx, Int>, + shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( generator, ctx, - None, - llvm_usize.const_zero(), - (min_ndims, false), - |generator, ctx, _, idx| { - let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); - let (lhs_dim_sz, rhs_dim_sz) = unsafe { - ( - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - ) - }; - - let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let rhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - - let lhs_eq_rhs = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") - .unwrap(); - - let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); - - ctx.make_assert( - generator, - is_compatible, - "0:ValueError", - "operands could not be broadcast together", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); - let rhs_ndims = rhs.load_ndims(ctx); - let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); - let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[ - lhs_dims.into(), - lhs_ndims.into(), - rhs_dims.into(), - rhs_ndims.into(), - out_dims.base_ptr(ctx, generator).into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - out_dims, - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) + "__nac3_ndarray_util_assert_shape_no_negative", + ); + CallFunction::begin(generator, ctx, &name).arg(ndims).arg(shape).returning_void(); } -/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] -/// containing the indices used for accessing `array` corresponding to the index of the broadcasted -/// array `broadcast_idx`. -pub fn call_ndarray_calc_broadcast_index< - 'ctx, - G: CodeGenerator + ?Sized, - BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, ->( +pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - array: NDArrayValue<'ctx>, - broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast_idx", - 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => unreachable!("Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let broadcast_size = broadcast_idx.size(ctx, generator); - let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - - let array_dims = array.dim_sizes().base_ptr(ctx, generator); - let array_ndims = array.load_ndims(ctx); - let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) + ndarray_ndims: Instance<'ctx, Int>, + ndarray_shape: Instance<'ctx, Ptr>>, + output_ndims: Instance<'ctx, Int>, + output_shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_util_assert_output_shape_same", + ); + CallFunction::begin(generator, ctx, &name) + .arg(ndarray_ndims) + .arg(ndarray_shape) + .arg(output_ndims) + .arg(output_shape) + .returning_void(); +} + +pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_size"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("size") +} + +pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("nbytes") +} + +pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_len"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("len") +} + +pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_auto("is_c_contiguous") +} + +pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, + index: Instance<'ctx, Int>, +) -> Instance<'ctx, Ptr>> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).arg(index).returning_auto("pelement") +} + +pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, + indices: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Ptr>> { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).arg(indices).returning_auto("pelement") +} + +pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: Instance<'ctx, Ptr>>, +) { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape"); + CallFunction::begin(generator, ctx, &name).arg(ndarray).returning_void(); +} + +pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data"); + CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); +} + +pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + iter: Instance<'ctx, Ptr>>, + ndarray: Instance<'ctx, Ptr>>, + indices: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_initialize"); + CallFunction::begin(generator, ctx, &name).arg(iter).arg(ndarray).arg(indices).returning_void(); +} + +pub fn call_nac3_nditer_has_next<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + iter: Instance<'ctx, Ptr>>, +) -> Instance<'ctx, Int> { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_has_next"); + CallFunction::begin(generator, ctx, &name).arg(iter).returning_auto("has_next") +} + +pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + iter: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_nditer_next"); + CallFunction::begin(generator, ctx, &name).arg(iter).returning_void(); +} + +pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + num_indices: Instance<'ctx, Int>, + indices: Instance<'ctx, Ptr>>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_index"); + CallFunction::begin(generator, ctx, &name) + .arg(num_indices) + .arg(indices) + .arg(src_ndarray) + .arg(dst_ndarray) + .returning_void(); +} + +pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: Instance<'ctx, Ptr>>>>, + ndims: Instance<'ctx, Int>, + shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_set_and_validate_list_shape", + ); + CallFunction::begin(generator, ctx, &name).arg(list).arg(ndims).arg(shape).returning_void(); +} + +pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: Instance<'ctx, Ptr>>>>, + ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_array_write_list_to_array", + ); + CallFunction::begin(generator, ctx, &name).arg(list).arg(ndarray).returning_void(); +} + +pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: Instance<'ctx, Int>, + new_ndims: Instance<'ctx, Int>, + new_shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name( + generator, + ctx, + "__nac3_ndarray_reshape_resolve_and_check_new_shape", + ); + CallFunction::begin(generator, ctx, &name) + .arg(size) + .arg(new_ndims) + .arg(new_shape) + .returning_void(); +} + +pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to"); + CallFunction::begin(generator, ctx, &name).arg(src_ndarray).arg(dst_ndarray).returning_void(); +} + +pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + num_shape_entries: Instance<'ctx, Int>, + shape_entries: Instance<'ctx, Ptr>>, + dst_ndims: Instance<'ctx, Int>, + dst_shape: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes"); + CallFunction::begin(generator, ctx, &name) + .arg(num_shape_entries) + .arg(shape_entries) + .arg(dst_ndims) + .arg(dst_shape) + .returning_void(); +} + +pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: Instance<'ctx, Ptr>>, + dst_ndarray: Instance<'ctx, Ptr>>, + num_axes: Instance<'ctx, Int>, + axes: Instance<'ctx, Ptr>>, +) { + let name = get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose"); + CallFunction::begin(generator, ctx, &name) + .arg(src_ndarray) + .arg(dst_ndarray) + .arg(num_axes) + .arg(axes) + .returning_void(); +} + +#[allow(clippy::too_many_arguments)] +pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a_ndims: Instance<'ctx, Int>, + a_shape: Instance<'ctx, Ptr>>, + b_ndims: Instance<'ctx, Int>, + b_shape: Instance<'ctx, Ptr>>, + final_ndims: Instance<'ctx, Int>, + new_a_shape: Instance<'ctx, Ptr>>, + new_b_shape: Instance<'ctx, Ptr>>, + dst_shape: Instance<'ctx, Ptr>>, +) { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes"); + CallFunction::begin(generator, ctx, &name) + .arg(a_ndims) + .arg(a_shape) + .arg(b_ndims) + .arg(b_shape) + .arg(final_ndims) + .arg(new_a_shape) + .arg(new_b_shape) + .arg(dst_shape) + .returning_void(); } diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 71a2d52a..b90cb59b 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,7 +1,7 @@ use crate::{ - codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, + codegen::classes::{ListType, ProxyType, RangeType}, symbol_resolver::{StaticValue, SymbolResolver}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef}, + toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, @@ -24,7 +24,9 @@ use inkwell::{ AddressSpace, IntPredicate, OptimizationLevel, }; use itertools::Itertools; +use model::*; use nac3parser::ast::{Location, Stmt, StrRef}; +use object::ndarray::NDArray; use parking_lot::{Condvar, Mutex}; use std::collections::{HashMap, HashSet}; use std::sync::{ @@ -41,7 +43,9 @@ pub mod extern_fns; mod generator; pub mod irrt; pub mod llvm_intrinsics; +pub mod model; pub mod numpy; +pub mod object; pub mod stmt; #[cfg(test)] @@ -489,12 +493,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); - let element_type = get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, dtype, - ); - - NDArrayType::new(generator, ctx, element_type).as_base_type().into() + Ptr(Struct(NDArray)).get_type(generator, ctx).as_basic_type_enum() } _ => unreachable!( diff --git a/nac3core/src/codegen/model/any.rs b/nac3core/src/codegen/model/any.rs new file mode 100644 index 00000000..9df863e8 --- /dev/null +++ b/nac3core/src/codegen/model/any.rs @@ -0,0 +1,42 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum}, + values::BasicValueEnum, +}; + +use crate::codegen::CodeGenerator; + +use super::*; + +/// A [`Model`] of any [`BasicTypeEnum`]. +/// +/// Use this when it is infeasible to use model abstractions. +#[derive(Debug, Clone, Copy)] +pub struct Any<'ctx>(pub BasicTypeEnum<'ctx>); + +impl<'ctx> Model<'ctx> for Any<'ctx> { + type Value = BasicValueEnum<'ctx>; + type Type = BasicTypeEnum<'ctx>; + + fn get_type( + &self, + _generator: &G, + _ctx: &'ctx Context, + ) -> Self::Type { + self.0 + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + _generator: &mut G, + _ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + if ty == self.0 { + Ok(()) + } else { + Err(ModelError(format!("Expecting {}, but got {}", self.0, ty))) + } + } +} diff --git a/nac3core/src/codegen/model/array.rs b/nac3core/src/codegen/model/array.rs new file mode 100644 index 00000000..be8dc0be --- /dev/null +++ b/nac3core/src/codegen/model/array.rs @@ -0,0 +1,143 @@ +use std::fmt; + +use inkwell::{ + context::Context, + types::{ArrayType, BasicType, BasicTypeEnum}, + values::{ArrayValue, IntValue}, +}; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +/// Trait for Rust structs identifying length values for [`Array`]. +pub trait LenKind: fmt::Debug + Clone + Copy { + fn get_length(&self) -> u32; +} + +/// A statically known length. +#[derive(Debug, Clone, Copy, Default)] +pub struct Len; + +/// A dynamically known length. +#[derive(Debug, Clone, Copy)] +pub struct AnyLen(pub u32); + +impl LenKind for Len { + fn get_length(&self) -> u32 { + N + } +} + +impl LenKind for AnyLen { + fn get_length(&self) -> u32 { + self.0 + } +} + +/// A Model for an [`ArrayType`]. +/// +/// `Len` should be of a [`LenKind`] and `Item` should be a of [`Model`]. +#[derive(Debug, Clone, Copy, Default)] +pub struct Array { + /// Length of this array. + pub len: Len, + /// [`Model`] of the array items. + pub item: Item, +} + +impl<'ctx, Len: LenKind, Item: Model<'ctx>> Model<'ctx> for Array { + type Value = ArrayValue<'ctx>; + type Type = ArrayType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.item.get_type(generator, ctx).array_type(self.len.get_length()) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let BasicTypeEnum::ArrayType(ty) = ty else { + return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}"))); + }; + + if ty.len() != self.len.get_length() { + return Err(ModelError(format!( + "Expecting ArrayType with size {}, but got an ArrayType with size {}", + ty.len(), + self.len.get_length() + ))); + } + + self.item + .check_type(generator, ctx, ty.get_element_type()) + .map_err(|err| err.under_context("an ArrayType"))?; + + Ok(()) + } +} + +impl<'ctx, Len: LenKind, Item: Model<'ctx>> Instance<'ctx, Ptr>> { + /// Get the pointer to the `i`-th (0-based) array element. + pub fn gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + i: IntValue<'ctx>, + ) -> Instance<'ctx, Ptr> { + let zero = ctx.ctx.i32_type().const_zero(); + let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], "").unwrap() }; + + Ptr(self.model.0.item).believe_value(ptr) + } + + /// Like `gep` but `i` is a constant. + pub fn gep_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64) -> Instance<'ctx, Ptr> { + assert!( + i < u64::from(self.model.0.len.get_length()), + "Index {i} is out of bounds. Array length = {}", + self.model.0.len.get_length() + ); + + let i = ctx.ctx.i32_type().const_int(i, false); + self.gep(ctx, i) + } + + /// Convenience function equivalent to `.gep(...).load(...)`. + pub fn get( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + i: IntValue<'ctx>, + ) -> Instance<'ctx, Item> { + self.gep(ctx, i).load(generator, ctx) + } + + /// Like `get` but `i` is a constant. + pub fn get_const( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + i: u64, + ) -> Instance<'ctx, Item> { + self.gep_const(ctx, i).load(generator, ctx) + } + + /// Convenience function equivalent to `.gep(...).store(...)`. + pub fn set( + &self, + ctx: &CodeGenContext<'ctx, '_>, + i: IntValue<'ctx>, + value: Instance<'ctx, Item>, + ) { + self.gep(ctx, i).store(ctx, value); + } + + /// Like `set` but `i` is a constant. + pub fn set_const(&self, ctx: &CodeGenContext<'ctx, '_>, i: u64, value: Instance<'ctx, Item>) { + self.gep_const(ctx, i).store(ctx, value); + } +} diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs new file mode 100644 index 00000000..25faeea5 --- /dev/null +++ b/nac3core/src/codegen/model/core.rs @@ -0,0 +1,202 @@ +use std::fmt; + +use inkwell::{context::Context, types::*, values::*}; +use itertools::Itertools; + +use super::*; +use crate::codegen::{CodeGenContext, CodeGenerator}; + +/// A error type for reporting any [`Model`]-related error (e.g., a [`BasicType`] mismatch). +#[derive(Debug, Clone)] +pub struct ModelError(pub String); + +impl ModelError { + // Append a context message to the error. + pub(super) fn under_context(mut self, context: &str) -> Self { + self.0.push_str(" ... in "); + self.0.push_str(context); + self + } +} + +/// Trait for Rust structs identifying [`BasicType`]s in the context of a known [`CodeGenerator`] and [`CodeGenContext`]. +/// +/// For instance, +/// - [`Int`] identifies an [`IntType`] with 32-bits. +/// - [`Int`] identifies an [`IntType`] with bit-width [`CodeGenerator::get_size_type`]. +/// - [`Ptr>`] identifies a [`PointerType`] that points to an [`IntType`] with bit-width [`CodeGenerator::get_size_type`]. +/// - [`Int`] identifies an [`IntType`] with bit-width of whatever is set in the [`AnyInt`] object. +/// - [`Any`] identifies a [`BasicType`] set in the [`Any`] object itself. +/// +/// You can get the [`BasicType`] out of a model with [`Model::get_type`]. +/// +/// Furthermore, [`Instance<'ctx, M>`] is a simple structure that carries a [`BasicValue`] with [`BasicType`] identified by model `M`. +/// +/// The main purpose of this abstraction is to have a more Rust type-safe way to use Inkwell and give type-hints for programmers. +/// +/// ### Notes on `Default` trait +/// +/// For some models like [`Int`] or [`Int`], they have a [`Default`] trait since just by looking at their types, it is possible +/// to tell the [`BasicType`]s they are identifying. +/// +/// This can be used to create strongly-typed interfaces accepting only values of a specific [`BasicType`] without having to worry about +/// writing debug assertions to check, for example, if the programmer has passed in an [`IntValue`] with the wrong bit-width. +/// ```ignore +/// fn give_me_i32_and_get_a_size_t_back<'ctx>(i32: Instance<'ctx, Int>) -> Instance<'ctx, Int> { +/// // code... +/// } +/// ``` +/// +/// ### Notes on converting between Inkwell and model. +/// +/// Suppose you have an [`IntValue`], and you want to pass it into a function that takes a [`Instance<'ctx, Int>`]. You can do use +/// [`Model::check_value`] or [`Model::believe_value`]. +/// ```ignore +/// let my_value: IntValue<'ctx>; +/// +/// let my_value = Int(Int32).check_value(my_value).unwrap(); // Panics if `my_value` is not 32-bit with a descriptive error message. +/// +/// // or, if you are absolutely certain that `my_value` is 32-bit and doing extra checks is a waste of time: +/// let my_value = Int(Int32).believe_value(my_value); +/// ``` +pub trait Model<'ctx>: fmt::Debug + Clone + Copy { + /// The [`BasicType`] *variant* this model is identifying. + type Type: BasicType<'ctx>; + + /// The [`BasicValue`] type of the [`BasicType`] of this model. + type Value: BasicValue<'ctx> + TryFrom>; + + /// Return the [`BasicType`] of this model. + #[must_use] + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type; + + /// Get the number of bytes of the [`BasicType`] of this model. + fn sizeof( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> IntValue<'ctx> { + self.get_type(generator, ctx).size_of().unwrap() + } + + /// Check if a [`BasicType`] matches the [`BasicType`] of this model. + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError>; + + /// Create an instance from a value. + /// + /// Caller must make sure the type of `value` and the type of this `model` are equivalent. + #[must_use] + fn believe_value(&self, value: Self::Value) -> Instance<'ctx, Self> { + Instance { model: *self, value } + } + + /// Check if a [`BasicValue`]'s type is equivalent to the type of this model. + /// Wrap the [`BasicValue`] into an [`Instance`] if it is. + fn check_value, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + value: V, + ) -> Result, ModelError> { + let value = value.as_basic_value_enum(); + self.check_type(generator, ctx, value.get_type()) + .map_err(|err| err.under_context(format!("the value {value:?}").as_str()))?; + + let Ok(value) = Self::Value::try_from(value) else { + unreachable!("check_type() has bad implementation") + }; + Ok(self.believe_value(value)) + } + + // Allocate a value on the stack and return its pointer. + fn alloca( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Ptr> { + let p = ctx.builder.build_alloca(self.get_type(generator, ctx.ctx), "").unwrap(); + Ptr(*self).believe_value(p) + } + + // Allocate an array on the stack and return its pointer. + fn array_alloca( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + ) -> Instance<'ctx, Ptr> { + let p = ctx.builder.build_array_alloca(self.get_type(generator, ctx.ctx), len, "").unwrap(); + Ptr(*self).believe_value(p) + } + + fn var_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&str>, + ) -> Result>, String> { + let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum(); + let p = generator.gen_var_alloc(ctx, ty, name)?; + Ok(Ptr(*self).believe_value(p)) + } + + fn array_var_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + len: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> Result>, String> { + // TODO: Remove ArraySliceValue + let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum(); + let p = generator.gen_array_var_alloc(ctx, ty, len, name)?; + Ok(Ptr(*self).believe_value(PointerValue::from(p))) + } + + /// Allocate a constant array. + fn const_array( + &self, + generator: &mut G, + ctx: &'ctx Context, + values: &[Instance<'ctx, Self>], + ) -> Instance<'ctx, Array> { + macro_rules! make { + ($t:expr, $into_value:expr) => { + $t.const_array( + &values + .iter() + .map(|x| $into_value(x.value.as_basic_value_enum())) + .collect_vec(), + ) + }; + } + + let value = match self.get_type(generator, ctx).as_basic_type_enum() { + BasicTypeEnum::ArrayType(t) => make!(t, BasicValueEnum::into_array_value), + BasicTypeEnum::IntType(t) => make!(t, BasicValueEnum::into_int_value), + BasicTypeEnum::FloatType(t) => make!(t, BasicValueEnum::into_float_value), + BasicTypeEnum::PointerType(t) => make!(t, BasicValueEnum::into_pointer_value), + BasicTypeEnum::StructType(t) => make!(t, BasicValueEnum::into_struct_value), + BasicTypeEnum::VectorType(t) => make!(t, BasicValueEnum::into_vector_value), + }; + + Array { len: AnyLen(values.len() as u32), item: *self } + .check_value(generator, ctx, value) + .unwrap() + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Instance<'ctx, M: Model<'ctx>> { + /// The model of this instance. + pub model: M, + /// The value of this instance. + /// + /// It is guaranteed the [`BasicType`] of `value` is consistent with that of `model`. + pub value: M::Value, +} diff --git a/nac3core/src/codegen/model/float.rs b/nac3core/src/codegen/model/float.rs new file mode 100644 index 00000000..88bff80b --- /dev/null +++ b/nac3core/src/codegen/model/float.rs @@ -0,0 +1,90 @@ +use std::fmt; + +use inkwell::{ + context::Context, + types::{BasicType, FloatType}, + values::FloatValue, +}; + +use crate::codegen::CodeGenerator; + +use super::*; + +pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy { + fn get_float_type( + &self, + generator: &G, + ctx: &'ctx Context, + ) -> FloatType<'ctx>; +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Float32; +#[derive(Debug, Clone, Copy, Default)] +pub struct Float64; + +impl<'ctx> FloatKind<'ctx> for Float32 { + fn get_float_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> FloatType<'ctx> { + ctx.f32_type() + } +} + +impl<'ctx> FloatKind<'ctx> for Float64 { + fn get_float_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> FloatType<'ctx> { + ctx.f64_type() + } +} + +#[derive(Debug, Clone, Copy)] +pub struct AnyFloat<'ctx>(FloatType<'ctx>); + +impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> { + fn get_float_type( + &self, + _generator: &G, + _ctx: &'ctx Context, + ) -> FloatType<'ctx> { + self.0 + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Float(pub N); + +impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for Float { + type Value = FloatValue<'ctx>; + type Type = FloatType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.0.get_float_type(generator, ctx) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = FloatType::try_from(ty) else { + return Err(ModelError(format!("Expecting FloatType, but got {ty:?}"))); + }; + + let exp_ty = self.0.get_float_type(generator, ctx); + + // TODO: Inkwell does not have get_bit_width for FloatType? + if ty != exp_ty { + return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}"))); + } + + Ok(()) + } +} diff --git a/nac3core/src/codegen/model/function.rs b/nac3core/src/codegen/model/function.rs new file mode 100644 index 00000000..7ff2d746 --- /dev/null +++ b/nac3core/src/codegen/model/function.rs @@ -0,0 +1,122 @@ +use inkwell::{ + attributes::{Attribute, AttributeLoc}, + types::{BasicMetadataTypeEnum, BasicType, FunctionType}, + values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue}, +}; +use itertools::Itertools; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +#[derive(Debug, Clone, Copy)] +struct Arg<'ctx> { + ty: BasicMetadataTypeEnum<'ctx>, + val: BasicMetadataValueEnum<'ctx>, +} + +/// A convenience structure to construct & call an LLVM function. +/// +/// ### Usage +/// +/// The syntax is like this: +/// ```ignore +/// let result = CallFunction::begin("my_function_name") +/// .attrs(...) +/// .arg(arg1) +/// .arg(arg2) +/// .arg(arg3) +/// .returning("my_function_result", Int32); +/// ``` +/// +/// The function `my_function_name` is called when `.returning()` (or its variants) is called, returning +/// the result as an `Instance<'ctx, Int>`. +/// +/// If `my_function_name` has not been declared in `ctx.module`, once `.returning()` is called, a function +/// declaration of `my_function_name` is added to `ctx.module`, where the [`FunctionType`] is deduced from +/// the argument types and returning type. +pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> { + generator: &'d mut G, + ctx: &'b CodeGenContext<'ctx, 'a>, + /// Function name + name: &'c str, + /// Call arguments + args: Vec>, + /// LLVM function Attributes + attrs: Vec<&'static str>, +} + +impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> { + pub fn begin(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self { + CallFunction { generator, ctx, name, args: Vec::new(), attrs: Vec::new() } + } + + /// Push a list of LLVM function attributes to the function declaration. + #[must_use] + pub fn attrs(mut self, attrs: Vec<&'static str>) -> Self { + self.attrs = attrs; + self + } + + /// Push a call argument to the function call. + #[allow(clippy::needless_pass_by_value)] + #[must_use] + pub fn arg>(mut self, arg: Instance<'ctx, M>) -> Self { + let arg = Arg { + ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(), + val: arg.value.as_basic_value_enum().into(), + }; + self.args.push(arg); + self + } + + /// Call the function and expect the function to return a value of type of `return_model`. + #[must_use] + pub fn returning>(self, name: &str, return_model: M) -> Instance<'ctx, M> { + let ret_ty = return_model.get_type(self.generator, self.ctx.ctx); + + let ret = self.call(|tys| ret_ty.fn_type(tys, false), name); + let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work + let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work + ret + } + + /// Like [`CallFunction::returning_`] but `return_model` is automatically inferred. + #[must_use] + pub fn returning_auto + Default>(self, name: &str) -> Instance<'ctx, M> { + self.returning(name, M::default()) + } + + /// Call the function and expect the function to return a void-type. + pub fn returning_void(self) { + let ret_ty = self.ctx.ctx.void_type(); + + let _ = self.call(|tys| ret_ty.fn_type(tys, false), ""); + } + + fn call(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx> + where + F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>, + { + // Get the LLVM function. + let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| { + // Declare the function if it doesn't exist. + let tys = self.args.iter().map(|arg| arg.ty).collect_vec(); + + let func_type = make_fn_type(&tys); + let func = self.ctx.module.add_function(self.name, func_type, None); + + for attr in &self.attrs { + func.add_attribute( + AttributeLoc::Function, + self.ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), + ); + } + + func + }); + + let vals = self.args.iter().map(|arg| arg.val).collect_vec(); + self.ctx.builder.build_call(func, &vals, return_value_name).unwrap() + } +} diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs new file mode 100644 index 00000000..3a8a4fe3 --- /dev/null +++ b/nac3core/src/codegen/model/int.rs @@ -0,0 +1,417 @@ +use std::{cmp::Ordering, fmt}; + +use inkwell::{ + context::Context, + types::{BasicType, IntType}, + values::IntValue, + IntPredicate, +}; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy { + fn get_int_type( + &self, + generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx>; +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Bool; +#[derive(Debug, Clone, Copy, Default)] +pub struct Byte; +#[derive(Debug, Clone, Copy, Default)] +pub struct Int32; +#[derive(Debug, Clone, Copy, Default)] +pub struct Int64; +#[derive(Debug, Clone, Copy, Default)] +pub struct SizeT; + +impl<'ctx> IntKind<'ctx> for Bool { + fn get_int_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + ctx.bool_type() + } +} + +impl<'ctx> IntKind<'ctx> for Byte { + fn get_int_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + ctx.i8_type() + } +} + +impl<'ctx> IntKind<'ctx> for Int32 { + fn get_int_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + ctx.i32_type() + } +} + +impl<'ctx> IntKind<'ctx> for Int64 { + fn get_int_type( + &self, + _generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + ctx.i64_type() + } +} + +impl<'ctx> IntKind<'ctx> for SizeT { + fn get_int_type( + &self, + generator: &G, + ctx: &'ctx Context, + ) -> IntType<'ctx> { + generator.get_size_type(ctx) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct AnyInt<'ctx>(pub IntType<'ctx>); + +impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> { + fn get_int_type( + &self, + _generator: &G, + _ctx: &'ctx Context, + ) -> IntType<'ctx> { + self.0 + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Int(pub N); + +impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for Int { + type Value = IntValue<'ctx>; + type Type = IntType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.0.get_int_type(generator, ctx) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = IntType::try_from(ty) else { + return Err(ModelError(format!("Expecting IntType, but got {ty:?}"))); + }; + + let exp_ty = self.0.get_int_type(generator, ctx); + if ty.get_bit_width() != exp_ty.get_bit_width() { + return Err(ModelError(format!( + "Expecting IntType to have {} bit(s), but got {} bit(s)", + exp_ty.get_bit_width(), + ty.get_bit_width() + ))); + } + + Ok(()) + } +} + +impl<'ctx, N: IntKind<'ctx>> Int { + pub fn const_int( + &self, + generator: &mut G, + ctx: &'ctx Context, + value: u64, + ) -> Instance<'ctx, Self> { + let value = self.get_type(generator, ctx).const_int(value, false); + self.believe_value(value) + } + + pub fn const_0( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + let value = self.get_type(generator, ctx).const_zero(); + self.believe_value(value) + } + + pub fn const_1( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + self.const_int(generator, ctx, 1) + } + + pub fn const_all_ones( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + let value = self.get_type(generator, ctx).const_all_ones(); + self.believe_value(value) + } + + pub fn s_extend_or_bit_cast( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + <= self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = ctx + .builder + .build_int_s_extend_or_bit_cast(value, self.get_type(generator, ctx.ctx), "") + .unwrap(); + self.believe_value(value) + } + + pub fn s_extend( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + < self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = + ctx.builder.build_int_s_extend(value, self.get_type(generator, ctx.ctx), "").unwrap(); + self.believe_value(value) + } + + pub fn z_extend_or_bit_cast( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + <= self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = ctx + .builder + .build_int_z_extend_or_bit_cast(value, self.get_type(generator, ctx.ctx), "") + .unwrap(); + self.believe_value(value) + } + + pub fn z_extend( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + < self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = + ctx.builder.build_int_z_extend(value, self.get_type(generator, ctx.ctx), "").unwrap(); + self.believe_value(value) + } + + pub fn truncate_or_bit_cast( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + >= self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = ctx + .builder + .build_int_truncate_or_bit_cast(value, self.get_type(generator, ctx.ctx), "") + .unwrap(); + self.believe_value(value) + } + + pub fn truncate( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + assert!( + value.get_type().get_bit_width() + > self.0.get_int_type(generator, ctx.ctx).get_bit_width() + ); + let value = + ctx.builder.build_int_truncate(value, self.get_type(generator, ctx.ctx), "").unwrap(); + self.believe_value(value) + } + + /// `sext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths. + pub fn s_extend_or_truncate( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + let their_width = value.get_type().get_bit_width(); + let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width(); + match their_width.cmp(&our_width) { + Ordering::Less => self.s_extend(generator, ctx, value), + Ordering::Equal => self.believe_value(value), + Ordering::Greater => self.truncate(generator, ctx, value), + } + } + + /// `zext` or `trunc` an int to this model's int type. Does nothing if equal bit-widths. + pub fn z_extend_or_truncate( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + value: IntValue<'ctx>, + ) -> Instance<'ctx, Self> { + let their_width = value.get_type().get_bit_width(); + let our_width = self.0.get_int_type(generator, ctx.ctx).get_bit_width(); + match their_width.cmp(&our_width) { + Ordering::Less => self.z_extend(generator, ctx, value), + Ordering::Equal => self.believe_value(value), + Ordering::Greater => self.truncate(generator, ctx, value), + } + } +} + +impl Int { + #[must_use] + pub fn const_false<'ctx, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + self.const_int(generator, ctx, 0) + } + + #[must_use] + pub fn const_true<'ctx, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Self> { + self.const_int(generator, ctx, 1) + } +} + +impl<'ctx, N: IntKind<'ctx>> Instance<'ctx, Int> { + pub fn s_extend_or_bit_cast, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).s_extend_or_bit_cast(generator, ctx, self.value) + } + + pub fn s_extend, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).s_extend(generator, ctx, self.value) + } + + pub fn z_extend_or_bit_cast, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).z_extend_or_bit_cast(generator, ctx, self.value) + } + + pub fn z_extend, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).z_extend(generator, ctx, self.value) + } + + pub fn truncate_or_bit_cast, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).truncate_or_bit_cast(generator, ctx, self.value) + } + + pub fn truncate, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).truncate(generator, ctx, self.value) + } + + pub fn s_extend_or_truncate, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).s_extend_or_truncate(generator, ctx, self.value) + } + + pub fn z_extend_or_truncate, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + to_int_kind: NewN, + ) -> Instance<'ctx, Int> { + Int(to_int_kind).z_extend_or_truncate(generator, ctx, self.value) + } + + #[must_use] + pub fn add(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self { + let value = ctx.builder.build_int_add(self.value, other.value, "").unwrap(); + self.model.believe_value(value) + } + + #[must_use] + pub fn sub(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self { + let value = ctx.builder.build_int_sub(self.value, other.value, "").unwrap(); + self.model.believe_value(value) + } + + #[must_use] + pub fn mul(&self, ctx: &CodeGenContext<'ctx, '_>, other: Self) -> Self { + let value = ctx.builder.build_int_mul(self.value, other.value, "").unwrap(); + self.model.believe_value(value) + } + + pub fn compare( + &self, + ctx: &CodeGenContext<'ctx, '_>, + op: IntPredicate, + other: Self, + ) -> Instance<'ctx, Int> { + let value = ctx.builder.build_int_compare(op, self.value, other.value, "").unwrap(); + Int(Bool).believe_value(value) + } +} diff --git a/nac3core/src/codegen/model/mod.rs b/nac3core/src/codegen/model/mod.rs new file mode 100644 index 00000000..4256e84e --- /dev/null +++ b/nac3core/src/codegen/model/mod.rs @@ -0,0 +1,17 @@ +mod any; +mod array; +mod core; +mod float; +pub mod function; +mod int; +mod ptr; +mod structure; +pub mod util; + +pub use any::*; +pub use array::*; +pub use core::*; +pub use float::*; +pub use int::*; +pub use ptr::*; +pub use structure::*; diff --git a/nac3core/src/codegen/model/ptr.rs b/nac3core/src/codegen/model/ptr.rs new file mode 100644 index 00000000..a68ac106 --- /dev/null +++ b/nac3core/src/codegen/model/ptr.rs @@ -0,0 +1,219 @@ +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, PointerType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; + +use crate::codegen::{llvm_intrinsics::call_memcpy_generic, CodeGenContext, CodeGenerator}; + +use super::*; + +/// A model for [`PointerType`]. +/// +/// `Item` is the element type this pointer is pointing to, and should be of a [`Model`]. +/// +// TODO: LLVM 15: `Item` is a Rust type-hint for the LLVM type of value the `.store()/.load()` family +// of functions return. If a truly opaque pointer is needed, tell the programmer to use `OpaquePtr`. +#[derive(Debug, Clone, Copy, Default)] +pub struct Ptr(pub Item); + +/// An opaque pointer. Like [`Ptr`] but without any Rust type-hints about its element type. +/// +/// `.load()/.store()` is not available for [`Instance`]s of opaque pointers. +pub type OpaquePtr = Ptr<()>; + +// TODO: LLVM 15: `Item: Model<'ctx>` don't even need to be a model anymore. It will only be +// a type-hint for the `.load()/.store()` functions for the `pointee_ty`. +// +// See https://thedan64.github.io/inkwell/inkwell/builder/struct.Builder.html#method.build_load. +impl<'ctx, Item: Model<'ctx>> Model<'ctx> for Ptr { + type Value = PointerValue<'ctx>; + type Type = PointerType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + // TODO: LLVM 15: ctx.ptr_type(AddressSpace::default()) + self.0.get_type(generator, ctx).ptr_type(AddressSpace::default()) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = PointerType::try_from(ty) else { + return Err(ModelError(format!("Expecting PointerType, but got {ty:?}"))); + }; + + let elem_ty = ty.get_element_type(); + let Ok(elem_ty) = BasicTypeEnum::try_from(elem_ty) else { + return Err(ModelError(format!( + "Expecting pointer element type to be a BasicTypeEnum, but got {elem_ty:?}" + ))); + }; + + // TODO: inkwell `get_element_type()` will be deprecated. + // Remove the check for `get_element_type()` when the time comes. + self.0 + .check_type(generator, ctx, elem_ty) + .map_err(|err| err.under_context("a PointerType"))?; + + Ok(()) + } +} + +impl<'ctx, Item: Model<'ctx>> Ptr { + /// Return a ***constant*** nullptr. + pub fn nullptr( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Ptr> { + let ptr = self.get_type(generator, ctx).const_null(); + self.believe_value(ptr) + } + + /// Cast a pointer into this model with [`inkwell::builder::Builder::build_pointer_cast`] + pub fn pointer_cast( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + ptr: PointerValue<'ctx>, + ) -> Instance<'ctx, Ptr> { + // TODO: LLVM 15: Write in an impl where `Item` does not have to be `Model<'ctx>`. + // TODO: LLVM 15: This function will only have to be: + // ``` + // return self.believe_value(ptr); + // ``` + let t = self.get_type(generator, ctx.ctx); + let ptr = ctx.builder.build_pointer_cast(ptr, t, "").unwrap(); + self.believe_value(ptr) + } +} + +impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr> { + /// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`]. + #[must_use] + pub fn offset( + &self, + ctx: &CodeGenContext<'ctx, '_>, + offset: IntValue<'ctx>, + ) -> Instance<'ctx, Ptr> { + let p = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], "").unwrap() }; + self.model.believe_value(p) + } + + /// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset. + #[must_use] + pub fn offset_const( + &self, + ctx: &CodeGenContext<'ctx, '_>, + offset: u64, + ) -> Instance<'ctx, Ptr> { + let offset = ctx.ctx.i32_type().const_int(offset, false); + self.offset(ctx, offset) + } + + pub fn set_index( + &self, + ctx: &CodeGenContext<'ctx, '_>, + index: IntValue<'ctx>, + value: Instance<'ctx, Item>, + ) { + self.offset(ctx, index).store(ctx, value); + } + + pub fn set_index_const( + &self, + ctx: &CodeGenContext<'ctx, '_>, + index: u64, + value: Instance<'ctx, Item>, + ) { + self.offset_const(ctx, index).store(ctx, value); + } + + pub fn get_index( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + index: IntValue<'ctx>, + ) -> Instance<'ctx, Item> { + self.offset(ctx, index).load(generator, ctx) + } + + pub fn get_index_const( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + index: u64, + ) -> Instance<'ctx, Item> { + self.offset_const(ctx, index).load(generator, ctx) + } + + /// Load the value with [`inkwell::builder::Builder::build_load`]. + pub fn load( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Item> { + let value = ctx.builder.build_load(self.value, "").unwrap(); + self.model.0.check_value(generator, ctx.ctx, value).unwrap() // If unwrap() panics, there is a logic error. + } + + /// Store a value with [`inkwell::builder::Builder::build_store`]. + pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: Instance<'ctx, Item>) { + ctx.builder.build_store(self.value, value.value).unwrap(); + } + + /// Return a casted pointer of element type `NewElement` with [`inkwell::builder::Builder::build_pointer_cast`]. + pub fn pointer_cast, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + new_item: NewItem, + ) -> Instance<'ctx, Ptr> { + // TODO: LLVM 15: Write in an impl where `Item` does not have to be `Model<'ctx>`. + Ptr(new_item).pointer_cast(generator, ctx, self.value) + } + + /// Cast this pointer to `uint8_t*` + pub fn cast_to_pi8( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Ptr>> { + Ptr(Int(Byte)).pointer_cast(generator, ctx, self.value) + } + + /// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`]. + pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int> { + let value = ctx.builder.build_is_null(self.value, "").unwrap(); + Int(Bool).believe_value(value) + } + + /// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`]. + pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>) -> Instance<'ctx, Int> { + let value = ctx.builder.build_is_not_null(self.value, "").unwrap(); + Int(Bool).believe_value(value) + } + + /// `memcpy` from another pointer. + pub fn copy_from( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + source: Self, + num_items: IntValue<'ctx>, + ) { + // Force extend `num_items` and `itemsize` to `i64` so their types would match. + let itemsize = self.model.sizeof(generator, ctx.ctx); + let itemsize = Int(Int64).z_extend_or_truncate(generator, ctx, itemsize); + let num_items = Int(Int64).z_extend_or_truncate(generator, ctx, num_items); + let totalsize = itemsize.mul(ctx, num_items); + + let is_volatile = ctx.ctx.bool_type().const_zero(); // is_volatile = false + call_memcpy_generic(ctx, self.value, source.value, totalsize.value, is_volatile); + } +} diff --git a/nac3core/src/codegen/model/structure.rs b/nac3core/src/codegen/model/structure.rs new file mode 100644 index 00000000..a9899049 --- /dev/null +++ b/nac3core/src/codegen/model/structure.rs @@ -0,0 +1,359 @@ +use std::fmt; + +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, StructType}, + values::{BasicValueEnum, StructValue}, +}; + +use crate::codegen::{CodeGenContext, CodeGenerator}; + +use super::*; + +/// A traveral that traverses a Rust `struct` that is used to declare an LLVM's struct's field types. +pub trait FieldTraversal<'ctx> { + /// Output type of [`FieldTraversal::add`]. + type Out; + + /// Traverse through the type of a declared field and do something with it. + /// + /// * `name` - The cosmetic name of the LLVM field. Used for debugging. + /// * `model` - The [`Model`] representing the LLVM type of this field. + fn add>(&mut self, name: &'static str, model: M) -> Self::Out; + + /// Like [`FieldTraversal::add`] but [`Model`] is automatically inferred from its [`Default`] trait. + fn add_auto + Default>(&mut self, name: &'static str) -> Self::Out { + self.add(name, M::default()) + } +} + +/// Descriptor of an LLVM struct field. +#[derive(Debug, Clone, Copy)] +pub struct GepField { + /// The GEP index of this field. This is the index to use with `build_gep`. + pub gep_index: u64, + /// The cosmetic name of this field. + pub name: &'static str, + /// The [`Model`] of this field's type. + pub model: M, +} + +/// A traversal to calculate the GEP index of fields. +pub struct GepFieldTraversal { + /// The current GEP index. + gep_index_counter: u64, +} + +impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal { + type Out = GepField; + + fn add>(&mut self, name: &'static str, model: M) -> Self::Out { + let gep_index = self.gep_index_counter; + self.gep_index_counter += 1; + Self::Out { gep_index, name, model } + } +} + +/// A traversal to collect the field types of a struct. +/// +/// This is used to collect field types and construct the LLVM struct type with [`Context::struct_type`]. +struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> { + generator: &'a G, + ctx: &'ctx Context, + /// The collected field types so far in exact order. + field_types: Vec>, +} + +impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> { + type Out = (); // Checking types return nothing. + + fn add>(&mut self, _name: &'static str, model: M) -> Self::Out { + let t = model.get_type(self.generator, self.ctx).as_basic_type_enum(); + self.field_types.push(t); + } +} + +/// A traversal to check the types of fields. +struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> { + generator: &'a mut G, + ctx: &'ctx Context, + /// The current GEP index, so we can tell the index of the field we are checking + /// and report the GEP index. + gep_index_counter: u32, + /// The [`StructType`] to check. + scrutinee: StructType<'ctx>, + /// The list of collected errors so far. + errors: Vec, +} + +impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> + for CheckTypeFieldTraversal<'ctx, 'a, G> +{ + type Out = (); // Checking types return nothing. + + fn add>(&mut self, name: &'static str, model: M) -> Self::Out { + let gep_index = self.gep_index_counter; + self.gep_index_counter += 1; + + if let Some(t) = self.scrutinee.get_field_type_at_index(gep_index) { + if let Err(err) = model.check_type(self.generator, self.ctx, t) { + self.errors + .push(err.under_context(format!("field #{gep_index} '{name}'").as_str())); + } + } // Otherwise, it will be caught by Struct's `check_type`. + } +} + +/// A trait for Rust structs identifying LLVM structures. +/// +/// ### Example +/// +/// Suppose you want to define this structure: +/// ```c +/// template +/// struct ContiguousNDArray { +/// size_t ndims; +/// size_t* shape; +/// T* data; +/// } +/// ``` +/// +/// This is how it should be done: +/// ```ignore +/// pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> { +/// pub ndims: F::Out>, +/// pub shape: F::Out>>, +/// pub data: F::Out>, +/// } +/// +/// /// An ndarray without strides and non-opaque `data` field in NAC3. +/// #[derive(Debug, Clone, Copy)] +/// pub struct ContiguousNDArray { +/// /// [`Model`] of the items. +/// pub item: M, +/// } +/// +/// impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray { +/// type Fields> = ContiguousNDArrayFields<'ctx, F, Item>; +/// +/// fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { +/// // The order of `traversal.add*` is important +/// Self::Fields { +/// ndims: traversal.add_auto("ndims"), +/// shape: traversal.add_auto("shape"), +/// data: traversal.add("data", Ptr(self.item)), +/// } +/// } +/// } +/// ``` +/// +/// The [`FieldTraversal`] here is a mechanism to allow the fields of `ContiguousNDArrayFields` to be +/// traversed to do useful work such as: +/// +/// - To create the [`StructType`] of `ContiguousNDArray` by collecting [`BasicType`]s of the fields. +/// - To enable the `.gep(ctx, |f| f.ndims).store(ctx, ...)` syntax. +/// +/// Suppose now that you have defined `ContiguousNDArray` and you want to allocate a `ContiguousNDArray` +/// with dtype `float64` in LLVM, this is how you do it: +/// ```ignore +/// type F64NDArray = Struct>>; // Type alias for leaner documentation +/// let model: F64NDArray = Struct(ContigousNDArray { item: Float(Float64) }); +/// let ndarray: Instance<'ctx, Ptr> = model.alloca(generator, ctx); +/// ``` +/// +/// ...and here is how you may manipulate/access `ndarray`: +/// +/// (NOTE: some arguments have been omitted) +/// +/// ```ignore +/// // Get `&ndarray->data` +/// ndarray.gep(|f| f.data); // type: Instance<'ctx, Ptr>> +/// +/// // Get `ndarray->ndims` +/// ndarray.get(|f| f.ndims); // type: Instance<'ctx, Int> +/// +/// // Get `&ndarray->ndims` +/// ndarray.gep(|f| f.ndims); // type: Instance<'ctx, Ptr>> +/// +/// // Get `ndarray->shape[0]` +/// ndarray.get(|f| f.shape).get_index_const(0); // Instance<'ctx, Int> +/// +/// // Get `&ndarray->shape[2]` +/// ndarray.get(|f| f.shape).offset_const(2); // Instance<'ctx, Ptr>> +/// +/// // Do `ndarray->ndims = 3;` +/// let num_3 = Int(SizeT).const_int(3); +/// ndarray.set(|f| f.ndims, num_3); +/// ``` +pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy { + /// The associated fields of this struct. + type Fields>; + + /// Traverse through all fields of this [`StructKind`]. + /// + /// Only used internally in this module for implementing other components. + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields; + + /// Get a convenience structure to get a struct field's GEP index through its corresponding Rust field. + /// + /// Only used internally in this module for implementing other components. + fn fields(&self) -> Self::Fields { + self.traverse_fields(&mut GepFieldTraversal { gep_index_counter: 0 }) + } + + /// Get the LLVM [`StructType`] of this [`StructKind`]. + fn get_struct_type( + &self, + generator: &G, + ctx: &'ctx Context, + ) -> StructType<'ctx> { + let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() }; + self.traverse_fields(&mut traversal); + + ctx.struct_type(&traversal.field_types, false) + } +} + +/// A model for LLVM struct. +/// +/// `S` should be of a [`StructKind`]. +#[derive(Debug, Clone, Copy, Default)] +pub struct Struct(pub S); + +impl<'ctx, S: StructKind<'ctx>> Struct { + /// Create a constant struct value from its fields. + /// + /// This function also validates `fields` and panic when there is something wrong. + pub fn const_struct( + &self, + generator: &mut G, + ctx: &'ctx Context, + fields: &[BasicValueEnum<'ctx>], + ) -> Instance<'ctx, Self> { + // NOTE: There *could* have been a functor `F = Instance<'ctx, M>` for `S::Fields` + // to create a more user-friendly interface, but Rust's type system is not sophisticated enough + // and if you try doing that Rust would force you put lifetimes everywhere. + let val = ctx.const_struct(fields, false); + self.check_value(generator, ctx, val).unwrap() + } +} + +impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for Struct { + type Value = StructValue<'ctx>; + type Type = StructType<'ctx>; + + fn get_type(&self, generator: &G, ctx: &'ctx Context) -> Self::Type { + self.0.get_struct_type(generator, ctx) + } + + fn check_type, G: CodeGenerator + ?Sized>( + &self, + generator: &mut G, + ctx: &'ctx Context, + ty: T, + ) -> Result<(), ModelError> { + let ty = ty.as_basic_type_enum(); + let Ok(ty) = StructType::try_from(ty) else { + return Err(ModelError(format!("Expecting StructType, but got {ty:?}"))); + }; + + // Check each field individually. + let mut traversal = CheckTypeFieldTraversal { + generator, + ctx, + gep_index_counter: 0, + errors: Vec::new(), + scrutinee: ty, + }; + self.0.traverse_fields(&mut traversal); + + // Check the number of fields. + let exp_num_fields = traversal.gep_index_counter; + let got_num_fields = u32::try_from(ty.get_field_types().len()).unwrap(); + if exp_num_fields != got_num_fields { + return Err(ModelError(format!( + "Expecting StructType with {exp_num_fields} field(s), but got {got_num_fields}" + ))); + } + + if !traversal.errors.is_empty() { + // Currently, only the first error is reported. + return Err(traversal.errors[0].clone()); + } + + Ok(()) + } +} + +impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Struct> { + /// Get a field with [`StructValue::get_field_at_index`]. + pub fn get_field( + &self, + generator: &mut G, + ctx: &'ctx Context, + get_field: GetField, + ) -> Instance<'ctx, M> + where + M: Model<'ctx>, + GetField: FnOnce(S::Fields) -> GepField, + { + let field = get_field(self.model.0.fields()); + let val = self.value.get_field_at_index(field.gep_index as u32).unwrap(); + field.model.check_value(generator, ctx, val).unwrap() + } +} + +impl<'ctx, S: StructKind<'ctx>> Instance<'ctx, Ptr>> { + /// Get a pointer to a field with [`Builder::build_in_bounds_gep`]. + pub fn gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + get_field: GetField, + ) -> Instance<'ctx, Ptr> + where + M: Model<'ctx>, + GetField: FnOnce(S::Fields) -> GepField, + { + let field = get_field(self.model.0 .0.fields()); + let llvm_i32 = ctx.ctx.i32_type(); + + let ptr = unsafe { + ctx.builder + .build_in_bounds_gep( + self.value, + &[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)], + field.name, + ) + .unwrap() + }; + + Ptr(field.model).believe_value(ptr) + } + + /// Convenience function equivalent to `.gep(...).load(...)`. + pub fn get( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + get_field: GetField, + ) -> Instance<'ctx, M> + where + M: Model<'ctx>, + GetField: FnOnce(S::Fields) -> GepField, + { + self.gep(ctx, get_field).load(generator, ctx) + } + + /// Convenience function equivalent to `.gep(...).store(...)`. + pub fn set( + &self, + ctx: &CodeGenContext<'ctx, '_>, + get_field: GetField, + value: Instance<'ctx, M>, + ) where + M: Model<'ctx>, + GetField: FnOnce(S::Fields) -> GepField, + { + self.gep(ctx, get_field).store(ctx, value); + } +} diff --git a/nac3core/src/codegen/model/util.rs b/nac3core/src/codegen/model/util.rs new file mode 100644 index 00000000..41679740 --- /dev/null +++ b/nac3core/src/codegen/model/util.rs @@ -0,0 +1,42 @@ +use crate::codegen::{ + stmt::{gen_for_callback_incrementing, BreakContinueHooks}, + CodeGenContext, CodeGenerator, +}; + +use super::*; + +/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions. +/// +/// `stop` is not included. +pub fn gen_for_model<'ctx, 'a, G, F, N>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + start: Instance<'ctx, Int>, + stop: Instance<'ctx, Int>, + step: Instance<'ctx, Int>, + body: F, +) -> Result<(), String> +where + G: CodeGenerator + ?Sized, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + Instance<'ctx, Int>, + ) -> Result<(), String>, + N: IntKind<'ctx> + Default, +{ + let int_model = Int(N::default()); + gen_for_callback_incrementing( + generator, + ctx, + None, + start.value, + (stop.value, false), + |g, ctx, hooks, i| { + let i = int_model.believe_value(i); + body(g, ctx, hooks, i) + }, + step.value, + ) +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index d58b566b..fed1f0bd 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,1732 +1,22 @@ use crate::{ codegen::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue, - ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, - TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + model::*, + object::{ + any::AnyObject, + ndarray::{nditer::NDIterHandle, shape_util::parse_numpy_int_sequence, NDArrayObject}, }, - expr::gen_binop_expr_with_values, - irrt::{ - calculate_len_for_slice_range, call_ndarray_calc_broadcast, - call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, - call_ndarray_calc_size, - }, - llvm_intrinsics::{self, call_memcpy_generic}, - stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, + stmt::gen_for_callback, CodeGenContext, CodeGenerator, }, symbol_resolver::ValueEnum, - toplevel::{ - helper::PrimDef, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - DefinitionId, - }, - typecheck::{ - magic_methods::Binop, - typedef::{FunSignature, Type, TypeEnum}, - }, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, + typecheck::typedef::{FunSignature, Type}, }; use inkwell::{ - types::BasicType, - values::{BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, OptimizationLevel, + values::{BasicValue, BasicValueEnum, PointerValue}, + IntPredicate, }; -use inkwell::{ - types::{AnyTypeEnum, BasicTypeEnum, PointerType}, - values::BasicValue, -}; -use nac3parser::ast::{Operator, StrRef}; - -/// Creates an uninitialized `NDArray` instance. -fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> Result, String> { - let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_ndarray_t = ctx - .get_llvm_type(generator, ndarray_ty) - .into_pointer_type() - .get_element_type() - .into_struct_type(); - - let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - - Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None)) -} - -/// Creates an `NDArray` instance from a dynamic shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`. -/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. -/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. -fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - shape: &V, - shape_len_fn: LenFn, - shape_data_fn: DataFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, - DataFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &V, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - // Assert that all dimensions are non-negative - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - - let shape_dim_gez = ctx - .builder - .build_int_compare( - IntPredicate::SGE, - shape_dim, - shape_dim.get_type().const_zero(), - "", - ) - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow dim_sz > u32_MAX - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; - - let num_dims = shape_len_fn(generator, ctx, shape)?; - ndarray.store_ndims(ctx, generator, num_dims); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - - // Copy the dimension sizes from shape to ndarray.dims - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - - let ndarray_pdim = - unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) }; - - ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); - - Ok(ndarray) -} - -/// Creates an `NDArray` instance from a constant shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. -pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: &[IntValue<'ctx>], -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - for &shape_dim in shape { - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let shape_dim_gez = ctx - .builder - .build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "") - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow dim_sz > u32_MAX - } - - let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; - - let num_dims = llvm_usize.const_int(shape.len() as u64, false); - ndarray.store_ndims(ctx, generator, num_dims); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - - for (i, &shape_dim) in shape.iter().enumerate() { - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let ndarray_dim = unsafe { - ndarray.dim_sizes().ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, true), - None, - ) - }; - - ctx.builder.build_store(ndarray_dim, shape_dim).unwrap(); - } - - let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); - - Ok(ndarray) -} - -/// Initializes the `data` field of [`NDArrayValue`] based on the `ndims` and `dim_sz` fields. -fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - ndarray: NDArrayValue<'ctx>, -) -> NDArrayValue<'ctx> { - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); - assert!(llvm_ndarray_data_t.is_sized()); - - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); - - ndarray -} - -fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i32_type().const_zero().into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "").into() - } else { - unreachable!() - } -} - -fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); - ctx.ctx.i32_type().const_int(1, is_signed).into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); - ctx.ctx.i64_type().const_int(1, is_signed).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_float(1.0).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_int(1, false).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "1").into() - } else { - unreachable!() - } -} - -/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -/// -/// ### Notes on `shape` -/// -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` -/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` -/// -/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to -/// learn how `shape` gets from being a Python user expression to here. -fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.empty([600, 800, 3])` - - let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` - // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. - - // Get the length/size of the tuple, which also happens to be the value of `ndims`. - let ndims = shape_tuple.get_type().count_fields(); - - let mut shape = Vec::with_capacity(ndims as usize); - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) - .unwrap() - .into_int_value(); - - shape.push(dim); - } - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` - - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => unreachable!(), - } -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as -/// its input. -fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (ndarray_num_elems, false), - |generator, ctx, _, i| { - let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; - - let value = value_fn(generator, ctx, i)?; - ctx.builder.build_store(elem, value).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices -/// as its input. -fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { - let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray); - - value_fn(generator, ctx, &indices) - }) -} - -fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - src: NDArrayValue<'ctx>, - dest: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| { - let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) }; - - map_fn(generator, ctx, elem) - }) -} - -/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of -/// the target `ndarray`. -fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - target: NDArrayValue<'ctx>, - source: NDArrayValue<'ctx>, -) { - let array_ndims = source.load_ndims(ctx); - let broadcast_size = target.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(), - "0:ValueError", - "operands cannot be broadcast together", - [None, None, None], - ctx.current_loc, - ); -} - -/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value -/// with broadcast-compatible shapes. -fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - res: NDArrayValue<'ctx>, - lhs: (BasicValueEnum<'ctx>, bool), - rhs: (BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (lhs_val, lhs_scalar) = lhs; - let (rhs_val, rhs_scalar) = rhs; - - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - // Assert that all ndarray operands are broadcastable to the target size - if !lhs_scalar { - let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); - ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); - } - - if !rhs_scalar { - let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); - ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); - } - - ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| { - let lhs_elem = if lhs_scalar { - lhs_val - } else { - let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); - let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - - unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } - }; - - let rhs_elem = if rhs_scalar { - rhs_val - } else { - let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); - let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - - unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } - }; - - value_fn(generator, ctx, (lhs_elem, rhs_elem)) - })?; - - Ok(res) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_zero_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_one_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.full`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, - fill_value: BasicValueEnum<'ctx>, -) -> Result, String> { - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = if fill_value.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - fill_value.into_pointer_value(), - fill_value.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if fill_value.is_int_value() || fill_value.is_float_value() { - fill_value - } else { - unreachable!() - }; - - Ok(value) - })?; - - Ok(ndarray) -} - -/// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. -fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ty: PointerType<'ctx>, -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_ty = ListType::from_type(ty, llvm_usize); - let list_elem_ty = list_ty.element_type(); - - let ndims = llvm_usize.const_int(1, false); - match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { - ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) - } - - AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { - todo!("Getting ndims for list[ndarray] not supported") - } - - _ => ndims, - } -} - -/// Returns the number of dimensions for an array-like object as an [`IntValue`]. -fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match value { - BasicValueEnum::PointerValue(v) if NDArrayValue::is_instance(v, llvm_usize).is_ok() => { - NDArrayValue::from_ptr_val(v, llvm_usize, None).load_ndims(ctx) - } - - BasicValueEnum::PointerValue(v) if ListValue::is_instance(v, llvm_usize).is_ok() => { - llvm_ndlist_get_ndims(generator, ctx, v.get_type()) - } - - _ => llvm_usize.const_zero(), - } -} - -/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`]. -fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - src_lst: ListValue<'ctx>, - dim: u64, -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_elem_ty = src_lst.get_type().element_type(); - - match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { - // The stride of elements in this dimension, i.e. the number of elements between arr[i] - // and arr[i + 1] in this dimension - let stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.dim_sizes(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, ctx| Ok(src_lst.load_size(ctx, None)), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, i| { - let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); - - let dst_ptr = - unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; - - let nested_lst_elem = ListValue::from_ptr_val( - unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } - .into_pointer_value(), - llvm_usize, - None, - ); - - ndarray_from_ndlist_impl( - generator, - ctx, - elem_ty, - (dst_arr, dst_ptr), - nested_lst_elem, - dim + 1, - )?; - - Ok(()) - }, - )?; - } - - AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { - todo!("Not implemented for list[ndarray]") - } - - _ => { - let lst_len = src_lst.load_size(ctx, None); - let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); - let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap(); - - let cpy_len = ctx - .builder - .build_int_mul( - ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(), - sizeof_elem, - "", - ) - .unwrap(); - - call_memcpy_generic( - ctx, - dst_slice_ptr, - src_lst.data().base_ptr(ctx, generator), - cpy_len, - llvm_i1.const_zero(), - ); - } - } - - Ok(()) -} - -/// LLVM-typed implementation for `ndarray.array`. -fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - object: BasicValueEnum<'ctx>, - copy: IntValue<'ctx>, - ndmin: IntValue<'ctx>, -) -> Result, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap(); - - // TODO(Derppening): Add assertions for sizes of different dimensions - - // object is not a pointer - 0-dim NDArray - if !object.is_pointer_value() { - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?; - - unsafe { - ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object); - } - - return Ok(ndarray); - } - - let object = object.into_pointer_value(); - - // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims - if NDArrayValue::is_instance(object, llvm_usize).is_ok() { - let object = NDArrayValue::from_ptr_val(object, llvm_usize, None); - - let ndarray = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - let copy_nez = ctx - .builder - .build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "") - .unwrap(); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap()) - }, - |generator, ctx| { - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |_, ctx, object| { - let ndims = object.load_ndims(ctx); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - let ndims = object.load_ndims(ctx); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - // The number of dimensions to prepend 1's to - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::UGE, idx, offset, "") - .unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |_, ctx| Ok(Some(ctx.builder.build_int_sub(idx, offset, "").unwrap())), - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_sliced_copyto_impl( - generator, - ctx, - elem_ty, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (object, object.data().base_ptr(ctx, generator)), - 0, - &[], - )?; - - Ok(Some(ndarray.as_base_value())) - }, - |_, _| Ok(Some(object.as_base_value())), - )?; - - return Ok(NDArrayValue::from_ptr_val( - ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), - llvm_usize, - None, - )); - } - - // Remaining case: TList - assert!(ListValue::is_instance(object, llvm_usize).is_ok()); - let object = ListValue::from_ptr_val(object, llvm_usize, None); - - // The number of dimensions to prepend 1's to - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |generator, ctx, object| { - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin_gt_ndims = - ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::ULT, idx, offset, "").unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |generator, ctx| { - let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| { - ctx.ctx.struct_type( - &[elem_ty.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], - false, - ) - }; - - let llvm_i8 = ctx.ctx.i8_type(); - let llvm_list_i8 = make_llvm_list(llvm_i8.into()); - let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default()); - - // Cast list to { i8*, usize } since we only care about the size - let lst = generator - .gen_var_alloc( - ctx, - ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(), - None, - ) - .unwrap(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bitcast(object.as_base_value(), llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap(); - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, _| Ok(stop), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, _| { - let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) - .ptr_type(AddressSpace::default()); - - let this_dim = ctx - .builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .map(|v| ctx.builder.build_bitcast(v, plist_plist_i8, "").unwrap()) - .map(BasicValueEnum::into_pointer_value) - .unwrap(); - let this_dim = ListValue::from_ptr_val(this_dim, llvm_usize, None); - - // TODO: Assert this_dim.sz != 0 - - let next_dim = unsafe { - this_dim.data().get_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - } - .into_pointer_value(); - ctx.builder - .build_store( - lst, - ctx.builder.build_bitcast(next_dim, llvm_plist_i8, "").unwrap(), - ) - .unwrap(); - - Ok(()) - }, - )?; - - let lst = ListValue::from_ptr_val( - ctx.builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .unwrap(), - llvm_usize, - None, - ); - - Ok(Some(lst.load_size(ctx, None))) - }, - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_from_ndlist_impl( - generator, - ctx, - elem_ty, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - object, - 0, - )?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.eye`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - nrows: IntValue<'ctx>, - ncols: IntValue<'ctx>, - offset: IntValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap(); - let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap(); - - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?; - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| { - let (row, col) = unsafe { - ( - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None), - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None), - ) - }; - - let col_with_offset = ctx - .builder - .build_int_add( - col, - ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(), - "", - ) - .unwrap(); - let is_on_diag = - ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap(); - - let zero = ndarray_zero_value(generator, ctx, elem_ty); - let one = ndarray_one_value(generator, ctx, elem_ty); - - let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap(); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// Copies a slice of an [`NDArrayValue`] to another. -/// -/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz` -/// fields should be populated before calling this function. -/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the destination array. -/// - `src_arr`: The [`NDArrayValue`] instance of the source array. -/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the source array. -/// - `dim`: The index of the currently processing dimension. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be non-negative indices. -fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - dim: u64, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - // If there are no (remaining) slice expressions, memcpy the entire dimension - if slices.is_empty() { - let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); - - let stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.dim_sizes(), - (Some(llvm_usize.const_int(dim, false)), None), - ); - let stride = - ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap(); - - let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); - - call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); - - return Ok(()); - } - - // The stride of elements in this dimension, i.e. the number of elements between arr[i] and - // arr[i + 1] in this dimension - let src_stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.dim_sizes(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - let dst_stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.dim_sizes(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - let (start, stop, step) = slices[0]; - let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap(); - let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap(); - let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap(); - - let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap(); - ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap(); - - gen_for_range_callback( - generator, - ctx, - None, - false, - |_, _| Ok(start), - (|_, _| Ok(stop), true), - |_, _| Ok(step), - |generator, ctx, _, src_i| { - // Calculate the offset of the active slice - let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); - - let (src_ptr, dst_ptr) = unsafe { - ( - ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(), - ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(), - ) - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - elem_ty, - (dst_arr, dst_ptr), - (src_arr, src_ptr), - dim + 1, - &slices[1..], - )?; - - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_i_add1 = - ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap(); - ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap(); - - Ok(()) - }, - )?; - - Ok(()) -} - -/// Copies a [`NDArrayValue`] using slices. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be positive indices. -pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndarray = if slices.is_empty() { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &this, - |_, ctx, shape| Ok(shape.load_ndims(ctx)), - |generator, ctx, shape, idx| unsafe { - Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) - }, - )? - } else { - let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; - ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); - - let ndims = this.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndims); - - // Populate the first slices.len() dimensions by computing the size of each dim slice - for (i, (start, stop, step)) in slices.iter().enumerate() { - // HACK: workaround calculate_len_for_slice_range requiring exclusive stop - let stop = ctx - .builder - .build_select( - ctx.builder - .build_int_compare( - IntPredicate::SLT, - *step, - llvm_i32.const_zero(), - "is_neg", - ) - .unwrap(), - ctx.builder - .build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one") - .unwrap(), - ctx.builder - .build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one") - .unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - - let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step); - let slice_len = - ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); - - unsafe { - ndarray.dim_sizes().set_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - slice_len, - ); - } - } - - // Populate the rest by directly copying the dim size from the source array - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_int(slices.len() as u64, false), - (this.load_ndims(ctx), false), - |generator, ctx, _, idx| { - unsafe { - let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); - ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); - } - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - ndarray_init_data(generator, ctx, elem_ty, ndarray) - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - elem_ty, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (this, this.data().base_ptr(ctx, generator)), - 0, - slices, - )?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.copy`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, -) -> Result, String> { - ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) -} - -pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - operand: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - let res = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &operand, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - }); - - ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| { - map_fn(generator, ctx, elem) - })?; - - Ok(res) -} - -/// LLVM-typed implementation for computing elementwise binary operations on two input operands. -/// -/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output -/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. -/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the -/// `value_fn` arguments tuple for all output elements. -/// -/// The second element of the tuple indicates whether to treat the operand value as a `ndarray` -/// (which would be accessed by its broadcast index) or as a scalar value (which would be -/// broadcast to all elements). -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -/// * `value_fn` - Function mapping the two input elements into the result. -/// -/// # Panic -/// -/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`. -pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - lhs: (BasicValueEnum<'ctx>, bool), - rhs: (BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (lhs_val, lhs_scalar) = lhs; - let (rhs_val, rhs_scalar) = rhs; - - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - let ndarray = res.unwrap_or_else(|| { - if lhs_scalar && rhs_scalar { - let lhs_val = - NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); - let rhs_val = - NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); - - let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray_dims, - |generator, ctx, v| Ok(v.size(ctx, generator)), - |generator, ctx, v, idx| unsafe { - Ok(v.get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } else { - let ndarray = NDArrayValue::from_ptr_val( - if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), - llvm_usize, - None, - ); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } - }); - - ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| { - value_fn(generator, ctx, elems) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - res: Option>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - if cfg!(debug_assertions) { - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - - // lhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // rhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - if let Some(res) = res { - let res_ndims = res.load_ndims(ctx); - let res_dim0 = unsafe { - res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let res_dim1 = unsafe { - res.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let lhs_dim0 = unsafe { - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let rhs_dim1 = unsafe { - rhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - - // res.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::EQ, - res_ndims, - llvm_usize.const_int(2, false), - "", - ) - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[0] == lhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - } - - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let lhs_dim1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_dim0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - // lhs.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - - let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) { - ndarray_copy_impl(generator, ctx, elem_ty, lhs)? - } else { - lhs - }; - - let ndarray = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &(lhs, rhs), - |_, _, _| Ok(llvm_usize.const_int(2, false)), - |generator, ctx, (lhs, rhs), idx| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "") - .unwrap()) - }, - |generator, ctx| { - Ok(Some(unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - })) - }, - |generator, ctx| { - Ok(Some(unsafe { - rhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - })) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) - }, - ) - .unwrap() - }); - - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { - llvm_intrinsics::call_expect( - ctx, - idx.size(ctx, generator).get_type().const_int(2, false), - idx.size(ctx, generator), - None, - ); - - let common_dim = { - let lhs_idx1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_idx0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - - ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() - }; - - let idx0 = unsafe { - let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - - ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() - }; - let idx1 = unsafe { - let idx1 = - idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - - ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() - }; - - let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let result_identity = ndarray_zero_value(generator, ctx, elem_ty); - ctx.builder.build_store(result_addr, result_identity).unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_i32.const_zero(), - (common_dim, false), - |generator, ctx, _, i| { - let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); - - let ab_idx = generator.gen_array_var_alloc( - ctx, - llvm_i32.into(), - llvm_usize.const_int(2, false), - None, - )?; - - let a = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); - - lhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - let b = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); - ab_idx.set_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - idx1.into(), - ); - - rhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - - let a_mul_b = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), a), - Binop::normal(Operator::Mult), - (&Some(elem_ty), b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - let result = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), result), - Binop::normal(Operator::Add), - (&Some(elem_ty), a_mul_b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - ctx.builder.build_store(result_addr, result).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - Ok(result) - })?; - - Ok(ndarray) -} +use nac3parser::ast::StrRef; /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( @@ -1742,8 +32,13 @@ pub fn gen_ndarray_empty<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = NDArrayObject::make_np_empty(generator, context, dtype, ndims, shape); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.zeros`. @@ -1760,8 +55,13 @@ pub fn gen_ndarray_zeros<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = NDArrayObject::make_np_zeros(generator, context, dtype, ndims, shape); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.ones`. @@ -1778,8 +78,13 @@ pub fn gen_ndarray_ones<'ctx>( let shape_ty = fun.0.args[0].ty; let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = NDArrayObject::make_np_ones(generator, context, dtype, ndims, shape); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.full`. @@ -1799,8 +104,14 @@ pub fn gen_ndarray_full<'ctx>( let fill_value_arg = args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?; - call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg) - .map(NDArrayValue::into) + let (dtype, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); + + let shape = AnyObject { value: shape_arg, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, context, shape); + let ndarray = + NDArrayObject::make_np_full(generator, context, dtype, ndims, shape, fill_value_arg); + Ok(ndarray.instance.value) } pub fn gen_ndarray_array<'ctx>( @@ -1814,26 +125,6 @@ pub fn gen_ndarray_array<'ctx>( assert!(matches!(args.len(), 1..=3)); let obj_ty = fun.0.args[0].ty; - let obj_elem_ty = match &*context.unifier.get_ty(obj_ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut context.unifier, obj_ty).0 - } - - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { - let mut ty = *params.iter().next().unwrap().1; - while let TypeEnum::TObj { obj_id, params, .. } = &*context.unifier.get_ty_immutable(ty) - { - if *obj_id != PrimDef::List.id() { - break; - } - - ty = *params.iter().next().unwrap().1; - } - ty - } - - _ => obj_ty, - }; let obj_arg = args[0].1.clone().to_basic_value_enum(context, generator, obj_ty)?; let copy_arg = if let Some(arg) = @@ -1849,28 +140,18 @@ pub fn gen_ndarray_array<'ctx>( ) }; - let ndmin_arg = if let Some(arg) = - args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[2].name)) - { - let ndmin_ty = fun.0.args[2].ty; - arg.1.clone().to_basic_value_enum(context, generator, ndmin_ty)? - } else { - context.gen_symbol_val( - generator, - fun.0.args[2].default_value.as_ref().unwrap(), - fun.0.args[2].ty, - ) - }; + // The ndmin argument is ignored. We can simply force the ndarray's number of dimensions to be + // the `ndims` of the function return type. + let (_, ndims) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + let ndims = extract_ndims(&context.unifier, ndims); - call_ndarray_array_impl( - generator, - context, - obj_elem_ty, - obj_arg, - copy_arg.into_int_value(), - ndmin_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let object = AnyObject { value: obj_arg, ty: obj_ty }; + // NAC3 booleans are i8. + let copy = Int(Bool).truncate(generator, context, copy_arg.into_int_value()); + let ndarray = NDArrayObject::make_np_array(generator, context, object, copy) + .atleast_nd(generator, context, ndims); + + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.eye`. @@ -1909,15 +190,23 @@ pub fn gen_ndarray_eye<'ctx>( )) }?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - nrows_arg.into_int_value(), - ncols_arg.into_int_value(), - offset_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let nrows = Int(Int32) + .check_value(generator, context.ctx, nrows_arg) + .unwrap() + .s_extend_or_bit_cast(generator, context, SizeT); + let ncols = Int(Int32) + .check_value(generator, context.ctx, ncols_arg) + .unwrap() + .s_extend_or_bit_cast(generator, context, SizeT); + let offset = Int(Int32) + .check_value(generator, context.ctx, offset_arg) + .unwrap() + .s_extend_or_bit_cast(generator, context, SizeT); + + let ndarray = NDArrayObject::make_np_eye(generator, context, dtype, nrows, ncols, offset); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.identity`. @@ -1931,20 +220,15 @@ pub fn gen_ndarray_identity<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); let n_ty = fun.0.args[0].ty; let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - n_arg.into_int_value(), - n_arg.into_int_value(), - llvm_usize.const_zero(), - ) - .map(NDArrayValue::into) + let n = Int(Int32).check_value(generator, context.ctx, n_arg).unwrap(); + let n = n.s_extend_or_bit_cast(generator, context, SizeT); + let ndarray = NDArrayObject::make_np_identity(generator, context, dtype, n); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.copy`. @@ -1958,20 +242,14 @@ pub fn gen_ndarray_copy<'ctx>( assert!(obj.is_some()); assert!(args.is_empty()); - let llvm_usize = generator.get_size_type(context.ctx); - let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; - ndarray_copy_impl( - generator, - context, - this_elem_ty, - NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None), - ) - .map(NDArrayValue::into) + let this = AnyObject { value: this_arg, ty: this_ty }; + let this = NDArrayObject::from_object(generator, context, this); + let ndarray = this.make_copy(generator, context); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.fill`. @@ -1985,443 +263,18 @@ pub fn gen_ndarray_fill<'ctx>( assert!(obj.is_some()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); - let this_ty = obj.as_ref().unwrap().0; - let this_arg = obj - .as_ref() - .unwrap() - .1 - .clone() - .to_basic_value_enum(context, generator, this_ty)? - .into_pointer_value(); + let this_arg = + obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; - ndarray_fill_flattened( - generator, - context, - NDArrayValue::from_ptr_val(this_arg, llvm_usize, None), - |generator, ctx, _| { - let value = if value_arg.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - value_arg.into_pointer_value(), - value_arg.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if value_arg.is_int_value() || value_arg.is_float_value() { - value_arg - } else { - unreachable!() - }; - - Ok(value) - }, - )?; - + let this = AnyObject { value: this_arg, ty: this_ty }; + let this = NDArrayObject::from_object(generator, context, this); + this.fill(generator, context, value_arg); Ok(()) } -/// Generates LLVM IR for `ndarray.transpose`. -pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_transpose"; - let (x1_ty, x1) = x1; - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - - // Dimensions are reversed in the transposed array - let out = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &n1, - |_, ctx, n| Ok(n.load_ndims(ctx)), - |generator, ctx, n, idx| { - let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap(); - let new_idx = ctx - .builder - .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") - .unwrap(); - unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) } - }, - ) - .unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - - let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap(); - ctx.builder.build_store(rem_idx, idx).unwrap(); - - // Incrementally calculate the new index in the transposed array - // For each index, we first decompose it into the n-dims and use those to reconstruct the new index - // The formula used for indexing is: - // idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n1.load_ndims(ctx), false), - |generator, ctx, _, ndim| { - let ndim_rev = - ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap(); - let ndim_rev = ctx - .builder - .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") - .unwrap(); - let dim = unsafe { - n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None) - }; - - let rem_idx_val = - ctx.builder.build_load(rem_idx, "").unwrap().into_int_value(); - let new_idx_val = - ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - - let add_component = - ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap(); - let rem_idx_val = - ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap(); - - let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap(); - let new_idx_val = - ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap(); - - ctx.builder.build_store(rem_idx, rem_idx_val).unwrap(); - ctx.builder.build_store(new_idx, new_idx_val).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value(); - unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - unreachable!( - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`. -/// -/// * `x1` - `NDArray` to reshape. -/// * `shape` - The `shape` parameter used to construct the new `NDArray`. -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])` -/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` -/// -/// Note that unlike other generating functions, one of the dimensions in the shape can be negative. -pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - shape: (Type, BasicValueEnum<'ctx>), -) -> Result, String> { - const FN_NAME: &str = "ndarray_reshape"; - let (x1_ty, x1) = x1; - let (_, shape) = shape; - - let llvm_usize = generator.get_size_type(ctx.ctx); - - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - - let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap(); - ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap(); - - let out = match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])` - - let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); - // Check for -1 in dimensions - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_list.load_size(ctx, None), false), - |generator, ctx, _, idx| { - let ele = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - ele, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_neg_value = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_neg_value = ctx - .builder - .build_int_add( - num_neg_value, - llvm_usize.const_int(1, false), - "", - ) - .unwrap(); - ctx.builder.build_store(num_neg, num_neg_value).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_value = - ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_value = - ctx.builder.build_int_mul(acc_value, ele, "").unwrap(); - ctx.builder.build_store(acc, acc_value).unwrap(); - Ok(None) - }, - )?; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - // Generate the output shape by filling -1 with `rem` - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, _| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - let dim = - shape_list.data().get(ctx, generator, &idx, None).into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? - .unwrap() - .into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` - - let ndims = shape_tuple.get_type().count_fields(); - // Check for -1 in dims - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, ctx| -> Result, String> { - let num_negs = - ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - let num_negs = ctx - .builder - .build_int_add(num_negs, llvm_usize.const_int(1, false), "") - .unwrap(); - ctx.builder.build_store(num_neg, num_negs).unwrap(); - Ok(None) - }, - |_, ctx| { - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap(); - ctx.builder.build_store(acc, acc_val).unwrap(); - Ok(None) - }, - )?; - } - - let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); - let mut shape = Vec::with_capacity(ndims as usize); - - // Reconstruct shape filling negatives with rem - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, "") - .unwrap() - .into_int_value(); - let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap(); - - let dim = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - dim, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(rem)), - |_, _| Ok(Some(dim)), - )? - .unwrap() - .into_int_value(); - shape.push(dim); - } - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` - let shape_int = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare( - IntPredicate::SLT, - shape_int, - llvm_usize.const_zero(), - "", - ) - .unwrap()) - }, - |_, _| Ok(Some(n_sz)), - |_, ctx| { - Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap())) - }, - )? - .unwrap() - .into_int_value(); - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => unreachable!(), - } - .unwrap(); - - // Only allow one dimension to be negative - let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value(); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "can only specify one unknown dimension", - [None, None, None], - ctx.current_loc, - ); - - // The new shape must be compatible with the old shape - let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None)); - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), - "0:ValueError", - "cannot reshape array of size {0} into provided shape of size {1}", - [Some(n_sz), Some(out_sz), None], - ctx.current_loc, - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (n_sz, false), - |generator, ctx, _, idx| { - let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) }; - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - Ok(out.as_base_value().into()) - } else { - unreachable!( - "{FN_NAME}() not supported for '{}'", - format!("'{}'", ctx.unifier.stringify(x1_ty)) - ) - } -} - /// Generates LLVM IR for `ndarray.dot`. /// Calculate inner product of two vectors or literals /// For matrix multiplication use `np_matmul` @@ -2436,77 +289,88 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; let (x1_ty, x1) = x1; - let (_, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); + let (x2_ty, x2) = x2; match (x1, x2) { - (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); + (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) => { + let a = AnyObject { ty: x1_ty, value: x1 }; + let b = AnyObject { ty: x2_ty, value: x2 }; - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let a = NDArrayObject::from_object(generator, ctx, a); + let b = NDArrayObject::from_object(generator, ctx, b); + // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. + assert_eq!(a.ndims, 1); + assert_eq!(b.ndims, 1); + let common_dtype = a.dtype; + + // Check shapes. + let a_size = a.size(generator, ctx); + let b_size = b.size(generator, ctx); + let same_shape = a_size.compare(ctx, IntPredicate::EQ, b_size); ctx.make_assert( generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), + same_shape.value, "0:ValueError", - "shapes ({0}), ({1}) not aligned", - [Some(n1_sz), Some(n2_sz), None], + "shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)", + [Some(a_size.value), Some(b_size.value), None], ctx.current_loc, ); - let identity = - unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); - ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); + let dtype_llvm = ctx.get_llvm_type(generator, common_dtype); - gen_for_callback_incrementing( + let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap(); + ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap(); + + // Do dot product. + gen_for_callback( generator, ctx, - None, - llvm_usize.const_zero(), - (n1_sz, false), - |generator, ctx, _, idx| { - let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; + Some("np_dot"), + |generator, ctx| { + let a_iter = NDIterHandle::new(generator, ctx, a); + let b_iter = NDIterHandle::new(generator, ctx, b); + Ok((a_iter, b_iter)) + }, + |generator, ctx, (a_iter, _b_iter)| { + // Only a_iter drives the condition, b_iter should have the same status. + Ok(a_iter.has_next(generator, ctx).value) + }, + |generator, ctx, _hooks, (a_iter, b_iter)| { + let a_scalar = a_iter.get_scalar(generator, ctx).value; + let b_scalar = b_iter.get_scalar(generator, ctx).value; - let product = match elem1 { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_mul(e1, elem2.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_mul(e1, elem2.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => unreachable!(), + let old_result = ctx.builder.build_load(result, "").unwrap(); + let new_result: BasicValueEnum<'ctx> = match old_result { + BasicValueEnum::IntValue(old_result) => { + let a_scalar = a_scalar.into_int_value(); + let b_scalar = b_scalar.into_int_value(); + let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_int_add(old_result, x, "").unwrap().into() + } + BasicValueEnum::FloatValue(old_result) => { + let a_scalar = a_scalar.into_float_value(); + let b_scalar = b_scalar.into_float_value(); + let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_float_add(old_result, x, "").unwrap().into() + } + _ => { + panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype)); + } }; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - let acc_val = match acc_val { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_add(e1, product.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_add(e1, product.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => unreachable!(), - }; - ctx.builder.build_store(acc, acc_val).unwrap(); + ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - Ok(acc_val) + |generator, ctx, (a_iter, b_iter)| { + a_iter.next(generator, ctx); + b_iter.next(generator, ctx); + Ok(()) + }, + ) + .unwrap(); + + Ok(ctx.builder.build_load(result, "").unwrap()) } (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) diff --git a/nac3core/src/codegen/object/any.rs b/nac3core/src/codegen/object/any.rs new file mode 100644 index 00000000..c7a983e0 --- /dev/null +++ b/nac3core/src/codegen/object/any.rs @@ -0,0 +1,12 @@ +use inkwell::values::BasicValueEnum; + +use crate::typecheck::typedef::Type; + +/// A NAC3 LLVM Python object of any type. +#[derive(Debug, Clone, Copy)] +pub struct AnyObject<'ctx> { + /// Typechecker type of the object. + pub ty: Type, + /// LLVM value of the object. + pub value: BasicValueEnum<'ctx>, +} diff --git a/nac3core/src/codegen/object/list.rs b/nac3core/src/codegen/object/list.rs new file mode 100644 index 00000000..04c68043 --- /dev/null +++ b/nac3core/src/codegen/object/list.rs @@ -0,0 +1,87 @@ +use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::{iter_type_vars, Type, TypeEnum}, +}; + +use super::any::AnyObject; + +/// Fields of [`List`] +pub struct ListFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> { + /// Array pointer to content + pub items: F::Out>, + /// Number of items in the array + pub len: F::Out>, +} + +/// A list in NAC3. +#[derive(Debug, Clone, Copy, Default)] +pub struct List { + /// Model of the list items + pub item: Item, +} + +impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for List { + type Fields> = ListFields<'ctx, F, Item>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + items: traversal.add("items", Ptr(self.item)), + len: traversal.add_auto("len"), + } + } +} + +impl<'ctx, Item: Model<'ctx>> Instance<'ctx, Ptr>>> { + /// Cast the items pointer to `uint8_t*`. + pub fn with_pi8_items( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Ptr>>>> { + self.pointer_cast(generator, ctx, Struct(List { item: Int(Byte) })) + } +} + +/// A NAC3 Python List object. +#[derive(Debug, Clone, Copy)] +pub struct ListObject<'ctx> { + /// Typechecker type of the list items + pub item_type: Type, + pub instance: Instance<'ctx, Ptr>>>>, +} + +impl<'ctx> ListObject<'ctx> { + /// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`]. + pub fn from_object( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> Self { + // Check typechecker type and extract `item_type` + let item_type = match &*ctx.unifier.get_ty(object.ty) { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + iter_type_vars(params).next().unwrap().ty // Extract `item_type` + } + _ => { + panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(object.ty)) + } + }; + + let plist = Ptr(Struct(List { item: Any(ctx.get_llvm_type(generator, item_type)) })); + + // Create object + let value = plist.check_value(generator, ctx.ctx, object.value).unwrap(); + ListObject { item_type, instance: value } + } + + /// Get the `len()` of this list. + pub fn len( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + self.instance.get(generator, ctx, |f| f.len) + } +} diff --git a/nac3core/src/codegen/object/mod.rs b/nac3core/src/codegen/object/mod.rs new file mode 100644 index 00000000..17b0b940 --- /dev/null +++ b/nac3core/src/codegen/object/mod.rs @@ -0,0 +1,5 @@ +pub mod any; +pub mod list; +pub mod ndarray; +pub mod tuple; +pub mod utils; diff --git a/nac3core/src/codegen/object/ndarray/array.rs b/nac3core/src/codegen/object/ndarray/array.rs new file mode 100644 index 00000000..23449071 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/array.rs @@ -0,0 +1,184 @@ +use super::NDArrayObject; +use crate::{ + codegen::{ + irrt::{ + call_nac3_ndarray_array_set_and_validate_list_shape, + call_nac3_ndarray_array_write_list_to_array, + }, + model::*, + object::{any::AnyObject, list::ListObject}, + stmt::gen_if_else_expr_callback, + CodeGenContext, CodeGenerator, + }, + toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims}, + typecheck::typedef::{Type, TypeEnum}, +}; + +/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array(list)`. +fn get_list_object_dtype_and_ndims<'ctx>( + ctx: &mut CodeGenContext<'ctx, '_>, + list: ListObject<'ctx>, +) -> (Type, u64) { + let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list.item_type); + + let ndims = arraylike_get_ndims(&mut ctx.unifier, list.item_type); + let ndims = ndims + 1; // To count `list` itself. + + (dtype, ndims) +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Implementation of `np_array(, copy=True)` + fn make_np_array_list_copy_true_impl( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: ListObject<'ctx>, + ) -> Self { + let (dtype, ndims_int) = get_list_object_dtype_and_ndims(ctx, list); + let list_value = list.instance.with_pi8_items(generator, ctx); + + // Validate `list` has a consistent shape. + // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. + // If `list` has a consistent shape, deduce the shape and write it to `shape`. + let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int); + let shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + call_nac3_ndarray_array_set_and_validate_list_shape( + generator, ctx, list_value, ndims, shape, + ); + + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int); + ndarray.copy_shape_from_array(generator, ctx, shape); + ndarray.create_data(generator, ctx); + + // Copy all contents from the list. + call_nac3_ndarray_array_write_list_to_array(generator, ctx, list_value, ndarray.instance); + + ndarray + } + + /// Implementation of `np_array(, copy=None)` + fn make_np_array_list_copy_none_impl( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: ListObject<'ctx>, + ) -> Self { + // np_array without copying is only possible `list` is not nested. + // + // If `list` is `list[T]`, we can create an ndarray with `data` set + // to the array pointer of `list`. + // + // If `list` is `list[list[T]]` or worse, copy. + + let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list); + if ndims == 1 { + // `list` is not nested + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, 1); + + // Set data + let data = list.instance.get(generator, ctx, |f| f.items).cast_to_pi8(generator, ctx); + ndarray.instance.set(ctx, |f| f.data, data); + + // ndarray->shape[0] = list->len; + let shape = ndarray.instance.get(generator, ctx, |f| f.shape); + let list_len = list.instance.get(generator, ctx, |f| f.len); + shape.set_index_const(ctx, 0, list_len); + + // Set strides, the `data` is contiguous + ndarray.set_strides_contiguous(generator, ctx); + + ndarray + } else { + // `list` is nested, copy + NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list) + } + } + + /// Implementation of `np_array(, copy=copy)` + fn make_np_array_list_impl( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + list: ListObject<'ctx>, + copy: Instance<'ctx, Int>, + ) -> Self { + let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list); + + let ndarray = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy.value), + |generator, ctx| { + let ndarray = + NDArrayObject::make_np_array_list_copy_true_impl(generator, ctx, list); + Ok(Some(ndarray.instance.value)) + }, + |generator, ctx| { + let ndarray = + NDArrayObject::make_np_array_list_copy_none_impl(generator, ctx, list); + Ok(Some(ndarray.instance.value)) + }, + ) + .unwrap() + .unwrap(); + + NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims) + } + + /// Implementation of `np_array(, copy=copy)`. + pub fn make_np_array_ndarray_impl( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayObject<'ctx>, + copy: Instance<'ctx, Int>, + ) -> Self { + let ndarray_val = gen_if_else_expr_callback( + generator, + ctx, + |_generator, _ctx| Ok(copy.value), + |generator, ctx| { + let ndarray = ndarray.make_copy(generator, ctx); // Force copy + Ok(Some(ndarray.instance.value)) + }, + |_generator, _ctx| { + // No need to copy. Return `ndarray` itself. + Ok(Some(ndarray.instance.value)) + }, + ) + .unwrap() + .unwrap(); + + NDArrayObject::from_value_and_unpacked_types( + generator, + ctx, + ndarray_val, + ndarray.dtype, + ndarray.ndims, + ) + } + + /// Create a new ndarray like `np.array()`. + /// + /// NOTE: The `ndmin` argument is not here. You may want to + /// do [`NDArrayObject::atleast_nd`] to achieve that. + pub fn make_np_array( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + copy: Instance<'ctx, Int>, + ) -> Self { + match &*ctx.unifier.get_ty(object.ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListObject::from_object(generator, ctx, object); + NDArrayObject::make_np_array_list_impl(generator, ctx, list, copy) + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayObject::from_object(generator, ctx, object); + NDArrayObject::make_np_array_ndarray_impl(generator, ctx, ndarray, copy) + } + _ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/broadcast.rs b/nac3core/src/codegen/object/ndarray/broadcast.rs new file mode 100644 index 00000000..79d61830 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/broadcast.rs @@ -0,0 +1,135 @@ +use itertools::Itertools; + +use crate::codegen::{ + irrt::{call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to}, + model::*, + CodeGenContext, CodeGenerator, +}; + +use super::NDArrayObject; + +/// Fields of [`ShapeEntry`] +pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> { + pub ndims: F::Out>, + pub shape: F::Out>>, +} + +/// An IRRT structure used in broadcasting. +#[derive(Debug, Clone, Copy, Default)] +pub struct ShapeEntry; + +impl<'ctx> StructKind<'ctx> for ShapeEntry { + type Fields> = ShapeEntryFields<'ctx, F>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") } + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Create a broadcast view on this ndarray with a target shape. + /// + /// The input shape will be checked to make sure that it contains no negative values. + /// + /// * `target_ndims` - The ndims type after broadcasting to the given shape. + /// The caller has to figure this out for this function. + /// * `target_shape` - An array pointer pointing to the target shape. + #[must_use] + pub fn broadcast_to( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target_ndims: u64, + target_shape: Instance<'ctx, Ptr>>, + ) -> Self { + let broadcast_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, target_ndims); + broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape); + + call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance); + broadcast_ndarray + } +} +/// A result produced by [`broadcast_all_ndarrays`] +#[derive(Debug, Clone)] +pub struct BroadcastAllResult<'ctx> { + /// The statically known `ndims` of the broadcast result. + pub ndims: u64, + /// The broadcasting shape. + pub shape: Instance<'ctx, Ptr>>, + /// Broadcasted views on the inputs. + /// + /// All of them will have `shape` [`BroadcastAllResult::shape`] and + /// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector + /// is the same as the input. + pub ndarrays: Vec>, +} + +/// Helper function to call `call_nac3_ndarray_broadcast_shapes` +fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_shape_entries: &[(Instance<'ctx, Ptr>>, u64)], // (shape, shape's length/ndims) + broadcast_ndims: u64, + broadcast_shape: Instance<'ctx, Ptr>>, +) { + // Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`. + let num_shape_entries = + Int(SizeT).const_int(generator, ctx.ctx, u64::try_from(in_shape_entries.len()).unwrap()); + let shape_entries = Struct(ShapeEntry).array_alloca(generator, ctx, num_shape_entries.value); + for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() { + let pshape_entry = shape_entries.offset_const(ctx, i as u64); + + let in_ndims = Int(SizeT).const_int(generator, ctx.ctx, *in_ndims); + pshape_entry.set(ctx, |f| f.ndims, in_ndims); + + pshape_entry.set(ctx, |f| f.shape, *in_shape); + } + + let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims); + call_nac3_ndarray_broadcast_shapes( + generator, + ctx, + num_shape_entries, + shape_entries, + broadcast_ndims, + broadcast_shape, + ); +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Broadcast all ndarrays according to `np.broadcast()` and return a [`BroadcastAllResult`] + /// containing all the information of the result of the broadcast operation. + pub fn broadcast( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarrays: &[Self], + ) -> BroadcastAllResult<'ctx> { + assert!(!ndarrays.is_empty()); + + // Infer the broadcast output ndims. + let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap(); + + let broadcast_ndims = Int(SizeT).const_int(generator, ctx.ctx, broadcast_ndims_int); + let broadcast_shape = Int(SizeT).array_alloca(generator, ctx, broadcast_ndims.value); + + let shape_entries = ndarrays + .iter() + .map(|ndarray| (ndarray.instance.get(generator, ctx, |f| f.shape), ndarray.ndims)) + .collect_vec(); + broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape); + + // Broadcast all the inputs to shape `dst_shape`. + let broadcast_ndarrays: Vec<_> = ndarrays + .iter() + .map(|ndarray| { + ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, broadcast_shape) + }) + .collect_vec(); + + BroadcastAllResult { + ndims: broadcast_ndims_int, + shape: broadcast_shape, + ndarrays: broadcast_ndarrays, + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/contiguous.rs b/nac3core/src/codegen/object/ndarray/contiguous.rs new file mode 100644 index 00000000..6f067679 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/contiguous.rs @@ -0,0 +1,134 @@ +use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::Type, +}; + +use super::NDArrayObject; + +/// Fields of [`ContiguousNDArray`] +pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> { + pub ndims: F::Out>, + pub shape: F::Out>>, + pub data: F::Out>, +} + +/// An ndarray without strides and non-opaque `data` field in NAC3. +#[derive(Debug, Clone, Copy)] +pub struct ContiguousNDArray { + /// [`Model`] of the items. + pub item: M, +} + +impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray { + type Fields> = ContiguousNDArrayFields<'ctx, F, Item>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + ndims: traversal.add_auto("ndims"), + shape: traversal.add_auto("shape"), + data: traversal.add("data", Ptr(self.item)), + } + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Create a [`ContiguousNDArray`] from the contents of this ndarray. + /// + /// This function may or may not be expensive depending on if this ndarray has contiguous data. + /// + /// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the `data` field of + /// the returned [`ContiguousNDArray`] and copy contents of this ndarray to there. + /// + /// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created [`ContiguousNDArray`] + /// will share memory with this ndarray. + /// + /// The `item_model` sets the [`Model`] of the returned [`ContiguousNDArray`]'s `Item` model for type-safety, and + /// should match the `ctx.get_llvm_type()` of this ndarray's `dtype`. Otherwise this function panics. Use model [`Any`] + /// if you don't care/cannot know the [`Model`] in advance. + pub fn make_contiguous_ndarray>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + item_model: Item, + ) -> Instance<'ctx, Ptr>>> { + // Sanity check on `self.dtype` and `item_model`. + let dtype_llvm = ctx.get_llvm_type(generator, self.dtype); + item_model.check_type(generator, ctx.ctx, dtype_llvm).unwrap(); + + let cdarray_model = Struct(ContiguousNDArray { item: item_model }); + + let current_bb = ctx.builder.get_insert_block().unwrap(); + let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb"); + let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb"); + let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb"); + + // Allocate and setup the resulting [`ContiguousNDArray`]. + let result = cdarray_model.alloca(generator, ctx); + + // Set ndims and shape. + let ndims = self.ndims_llvm(generator, ctx.ctx); + result.set(ctx, |f| f.ndims, ndims); + + let shape = self.instance.get(generator, ctx, |f| f.shape); + result.set(ctx, |f| f.shape, shape); + + let is_contiguous = self.is_c_contiguous(generator, ctx); + ctx.builder.build_conditional_branch(is_contiguous.value, then_bb, else_bb).unwrap(); + + // Inserting into then_bb; This ndarray is contiguous. + ctx.builder.position_at_end(then_bb); + let data = self.instance.get(generator, ctx, |f| f.data); + let data = data.pointer_cast(generator, ctx, item_model); + result.set(ctx, |f| f.data, data); + ctx.builder.build_unconditional_branch(end_bb).unwrap(); + + // Inserting into else_bb; This ndarray is not contiguous. Do a full-copy on `data`. + // `make_copy` produces an ndarray with contiguous `data`. + ctx.builder.position_at_end(else_bb); + let copied_ndarray = self.make_copy(generator, ctx); + let data = copied_ndarray.instance.get(generator, ctx, |f| f.data); + let data = data.pointer_cast(generator, ctx, item_model); + result.set(ctx, |f| f.data, data); + ctx.builder.build_unconditional_branch(end_bb).unwrap(); + + // Reposition to end_bb for continuation + ctx.builder.position_at_end(end_bb); + + result + } + + /// Create an [`NDArrayObject`] from a [`ContiguousNDArray`]. + /// + /// The operation is super cheap. The newly created [`NDArrayObject`] will share the + /// same memory as the [`ContiguousNDArray`]. + /// + /// `ndims` has to be provided as [`NDArrayObject`] requires a statically known `ndims` value, despite + /// the fact that the information should be contained within the [`ContiguousNDArray`]. + pub fn from_contiguous_ndarray>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + carray: Instance<'ctx, Ptr>>>, + dtype: Type, + ndims: u64, + ) -> Self { + // Sanity check on `dtype` and `contiguous_array`'s `Item` model. + let dtype_llvm = ctx.get_llvm_type(generator, dtype); + carray.model.0 .0.item.check_type(generator, ctx.ctx, dtype_llvm).unwrap(); + + // TODO: Debug assert `ndims == carray.ndims` to catch bugs. + + // Allocate the resulting ndarray. + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims); + + // Copy shape and update strides + let shape = carray.get(generator, ctx, |f| f.shape); + ndarray.copy_shape_from_array(generator, ctx, shape); + ndarray.set_strides_contiguous(generator, ctx); + + // Share data + let data = carray.get(generator, ctx, |f| f.data).pointer_cast(generator, ctx, Int(Byte)); + ndarray.instance.set(ctx, |f| f.data, data); + + ndarray + } +} diff --git a/nac3core/src/codegen/object/ndarray/factory.rs b/nac3core/src/codegen/object/ndarray/factory.rs new file mode 100644 index 00000000..04ee79ae --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/factory.rs @@ -0,0 +1,176 @@ +use inkwell::{values::BasicValueEnum, IntPredicate}; + +use crate::{ + codegen::{ + irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext, + CodeGenerator, + }, + typecheck::typedef::Type, +}; + +use super::NDArrayObject; + +/// Get the zero value in `np.zeros()` of a `dtype`. +fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +/// Get the one value in `np.ones()` of a `dtype`. +fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64] + .iter() + .any(|ty| ctx.unifier.unioned(dtype, *ty)) + { + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Create an ndarray like `np.empty`. + pub fn make_np_empty( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + ) -> Self { + // Validate `shape` + let ndims_llvm = Int(SizeT).const_int(generator, ctx.ctx, ndims); + call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape); + + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims); + ndarray.copy_shape_from_array(generator, ctx, shape); + ndarray.create_data(generator, ctx); + + ndarray + } + + /// Create an ndarray like `np.full`. + pub fn make_np_full( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + fill_value: BasicValueEnum<'ctx>, + ) -> Self { + let ndarray = NDArrayObject::make_np_empty(generator, ctx, dtype, ndims, shape); + ndarray.fill(generator, ctx, fill_value); + ndarray + } + + /// Create an ndarray like `np.zero`. + pub fn make_np_zeros( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + ) -> Self { + let fill_value = ndarray_zero_value(generator, ctx, dtype); + NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value) + } + + /// Create an ndarray like `np.ones`. + pub fn make_np_ones( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + shape: Instance<'ctx, Ptr>>, + ) -> Self { + let fill_value = ndarray_one_value(generator, ctx, dtype); + NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value) + } + + /// Create an ndarray like `np.eye`. + pub fn make_np_eye( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + nrows: Instance<'ctx, Int>, + ncols: Instance<'ctx, Int>, + offset: Instance<'ctx, Int>, + ) -> Self { + let ndzero = ndarray_zero_value(generator, ctx, dtype); + let ndone = ndarray_one_value(generator, ctx, dtype); + + let ndarray = NDArrayObject::alloca_dynamic_shape(generator, ctx, dtype, &[nrows, ncols]); + + // Create data and make the matrix like look np.eye() + ndarray.create_data(generator, ctx); + ndarray + .foreach(generator, ctx, |generator, ctx, _hooks, nditer| { + // NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero + // and this loop would not execute. + + // Load up `row_i` and `col_i` from indices. + let row_i = nditer.get_indices().get_index_const(generator, ctx, 0); + let col_i = nditer.get_indices().get_index_const(generator, ctx, 1); + + let be_one = row_i.add(ctx, offset).compare(ctx, IntPredicate::EQ, col_i); + let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap(); + + let p = nditer.get_pointer(generator, ctx); + ctx.builder.build_store(p, value).unwrap(); + + Ok(()) + }) + .unwrap(); + + ndarray + } + + /// Create an ndarray like `np.identity`. + pub fn make_np_identity( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + size: Instance<'ctx, Int>, + ) -> Self { + // Convenient implementation + let offset = Int(SizeT).const_0(generator, ctx.ctx); + NDArrayObject::make_np_eye(generator, ctx, dtype, size, size, offset) + } +} diff --git a/nac3core/src/codegen/object/ndarray/indexing.rs b/nac3core/src/codegen/object/ndarray/indexing.rs new file mode 100644 index 00000000..06d686e3 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/indexing.rs @@ -0,0 +1,226 @@ +use crate::codegen::{ + irrt::call_nac3_ndarray_index, + model::*, + object::utils::slice::{RustSlice, Slice}, + CodeGenContext, CodeGenerator, +}; + +use super::NDArrayObject; + +pub type NDIndexType = Byte; + +/// Fields of [`NDIndex`] +#[derive(Debug, Clone, Copy)] +pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> { + pub type_: F::Out>, // Defined to be uint8_t in IRRT + pub data: F::Out>>, +} + +/// An IRRT representation of an ndarray subscript index. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct NDIndex; + +impl<'ctx> StructKind<'ctx> for NDIndex { + type Fields> = NDIndexFields<'ctx, F>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") } + } +} + +// A convenience enum representing a [`NDIndex`]. +#[derive(Debug, Clone)] +pub enum RustNDIndex<'ctx> { + SingleElement(Instance<'ctx, Int>), + Slice(RustSlice<'ctx, Int32>), + NewAxis, + Ellipsis, +} + +impl<'ctx> RustNDIndex<'ctx> { + /// Get the value to set `NDIndex::type` for this variant. + fn get_type_id(&self) -> u64 { + // Defined in IRRT, must be in sync + match self { + RustNDIndex::SingleElement(_) => 0, + RustNDIndex::Slice(_) => 1, + RustNDIndex::NewAxis => 2, + RustNDIndex::Ellipsis => 3, + } + } + + /// Write the contents to an LLVM [`NDIndex`]. + fn write_to_ndindex( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + dst_ndindex_ptr: Instance<'ctx, Ptr>>, + ) { + // Set `dst_ndindex_ptr->type` + dst_ndindex_ptr.gep(ctx, |f| f.type_).store( + ctx, + Int(NDIndexType::default()).const_int(generator, ctx.ctx, self.get_type_id()), + ); + + // Set `dst_ndindex_ptr->data` + match self { + RustNDIndex::SingleElement(in_index) => { + let index_ptr = Int(Int32).alloca(generator, ctx); + index_ptr.store(ctx, *in_index); + + dst_ndindex_ptr + .gep(ctx, |f| f.data) + .store(ctx, index_ptr.pointer_cast(generator, ctx, Int(Byte))); + } + RustNDIndex::Slice(in_rust_slice) => { + let user_slice_ptr = Struct(Slice(Int32)).alloca(generator, ctx); + in_rust_slice.write_to_slice(generator, ctx, user_slice_ptr); + + dst_ndindex_ptr + .gep(ctx, |f| f.data) + .store(ctx, user_slice_ptr.pointer_cast(generator, ctx, Int(Byte))); + } + RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {} + } + } + + /// Allocate an array of `NDIndex`es on the stack and return the array pointer. + pub fn make_ndindices( + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + in_ndindices: &[RustNDIndex<'ctx>], + ) -> (Instance<'ctx, Int>, Instance<'ctx, Ptr>>) { + let ndindex_model = Struct(NDIndex); + + let num_ndindices = Int(SizeT).const_int(generator, ctx.ctx, in_ndindices.len() as u64); + let ndindices = ndindex_model.array_alloca(generator, ctx, num_ndindices.value); + for (i, in_ndindex) in in_ndindices.iter().enumerate() { + let pndindex = ndindices.offset_const(ctx, i as u64); + in_ndindex.write_to_ndindex(generator, ctx, pndindex); + } + + (num_ndindices, ndindices) + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Get the expected `ndims` after indexing with `indices`. + #[must_use] + fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 { + let mut ndims = self.ndims; + for index in indices { + match index { + RustNDIndex::SingleElement(_) => { + ndims -= 1; // Single elements decrements ndims + } + RustNDIndex::NewAxis => { + ndims += 1; // `np.newaxis` / `none` adds a new axis + } + RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {} + } + } + ndims + } + + /// Index into the ndarray, and return a newly-allocated view on this ndarray. + /// + /// This function behaves like NumPy's ndarray indexing, but if the indices index + /// into a single element, an unsized ndarray is returned. + #[must_use] + pub fn index( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: &[RustNDIndex<'ctx>], + ) -> Self { + let dst_ndims = self.deduce_ndims_after_indexing_with(indices); + let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims); + + let (num_indices, indices) = RustNDIndex::make_ndindices(generator, ctx, indices); + call_nac3_ndarray_index( + generator, + ctx, + num_indices, + indices, + self.instance, + dst_ndarray.instance, + ); + + dst_ndarray + } +} + +pub mod util { + use itertools::Itertools; + use nac3parser::ast::{Expr, ExprKind}; + + use crate::{ + codegen::{ + expr::gen_slice, model::*, object::utils::slice::RustSlice, CodeGenContext, + CodeGenerator, + }, + typecheck::typedef::Type, + }; + + use super::RustNDIndex; + + /// Generate LLVM code to transform an ndarray subscript expression to + /// its list of [`RustNDIndex`] + /// + /// i.e., + /// ```python + /// my_ndarray[::3, 1, :2:] + /// ^^^^^^^^^^^ Then these into a three `RustNDIndex`es + /// ``` + pub fn gen_ndarray_subscript_ndindices<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + subscript: &Expr>, + ) -> Result>, String> { + // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools + + // Annoying notes about `slice` + // - `my_array[5]` + // - slice is a `Constant` + // - `my_array[:5]` + // - slice is a `Slice` + // - `my_array[:]` + // - slice is a `Slice`, but lower upper step would all be `Option::None` + // - `my_array[:, :]` + // - slice is now a `Tuple` of two `Slice`-s + // + // In summary: + // - when there is a comma "," within [], `slice` will be a `Tuple` of the entries. + // - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself. + // + // So we first "flatten" out the slice expression + let index_exprs = match &subscript.node { + ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(), + _ => vec![subscript], + }; + + // Process all index expressions + let mut rust_ndindices: Vec = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here. + for index_expr in index_exprs { + // NOTE: Currently nac3core's slices do not have an object representation, + // so the code/implementation looks awkward - we have to do pattern matching on the expression + let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node { + // Handle slices + let (lower, upper, step) = gen_slice(generator, ctx, lower, upper, step)?; + RustNDIndex::Slice(RustSlice { int_kind: Int32, start: lower, stop: upper, step }) + } else { + // Treat and handle everything else as a single element index. + let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum( + ctx, + generator, + ctx.primitives.int32, // Must be int32, this checks for illegal values + )?; + let index = Int(Int32).check_value(generator, ctx.ctx, index).unwrap(); + + RustNDIndex::SingleElement(index) + }; + rust_ndindices.push(ndindex); + } + Ok(rust_ndindices) + } +} diff --git a/nac3core/src/codegen/object/ndarray/map.rs b/nac3core/src/codegen/object/ndarray/map.rs new file mode 100644 index 00000000..4fcefe23 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/map.rs @@ -0,0 +1,220 @@ +use inkwell::values::BasicValueEnum; +use itertools::Itertools; + +use crate::{ + codegen::{ + object::ndarray::{AnyObject, NDArrayObject}, + stmt::gen_for_callback, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::Type, +}; + +use super::{nditer::NDIterHandle, NDArrayOut, ScalarOrNDArray}; + +impl<'ctx> NDArrayObject<'ctx> { + /// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping` elementwise. + /// + /// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when iterating through + /// the input `ndarrays` after broadcasting. The output of `mapping` is the result of the elementwise operation. + /// + /// `out` specifies whether the result should be a new ndarray or to be written an existing ndarray. + pub fn broadcast_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarrays: &[Self], + out: NDArrayOut<'ctx>, + mapping: MappingFn, + ) -> Result + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Broadcast inputs + let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays); + + let out_ndarray = match out { + NDArrayOut::NewNDArray { dtype } => { + // Create a new ndarray based on the broadcast shape. + let result_ndarray = + NDArrayObject::alloca(generator, ctx, dtype, broadcast_result.ndims); + result_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape); + result_ndarray.create_data(generator, ctx); + result_ndarray + } + NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => { + // Use an existing ndarray. + + // Check that its shape is compatible with the broadcast shape. + result_ndarray.assert_can_be_written_by_out( + generator, + ctx, + broadcast_result.ndims, + broadcast_result.shape, + ); + result_ndarray + } + }; + + // Map element-wise and store results into `mapped_ndarray`. + let nditer = NDIterHandle::new(generator, ctx, out_ndarray); + gen_for_callback( + generator, + ctx, + Some("broadcast_starmap"), + |generator, ctx| { + // Create NDIters for all broadcasted input ndarrays. + let other_nditers = broadcast_result + .ndarrays + .iter() + .map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray)) + .collect_vec(); + Ok((nditer, other_nditers)) + }, + |generator, ctx, (out_nditer, _in_nditers)| { + // We can simply use `out_nditer`'s `has_next()`. + // `in_nditers`' `has_next()`s should return the same value. + Ok(out_nditer.has_next(generator, ctx).value) + }, + |generator, ctx, _hooks, (out_nditer, in_nditers)| { + // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, + // and write to `out_ndarray`. + + let in_scalars = in_nditers + .iter() + .map(|nditer| nditer.get_scalar(generator, ctx).value) + .collect_vec(); + + let result = mapping(generator, ctx, &in_scalars)?; + + let p = out_nditer.get_pointer(generator, ctx); + ctx.builder.build_store(p, result).unwrap(); + + Ok(()) + }, + |generator, ctx, (out_nditer, in_nditers)| { + // Advance all iterators + out_nditer.next(generator, ctx); + in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx)); + Ok(()) + }, + )?; + + Ok(out_ndarray) + } + + /// Map through this ndarray with an elementwise function. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + out: NDArrayOut<'ctx>, + mapping: Mapping, + ) -> Result + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + NDArrayObject::broadcast_starmap( + generator, + ctx, + &[*self], + out, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a scalar. + /// + /// This function is very helpful when implementing NumPy functions that takes on either scalars or ndarrays or a mix of them + /// as their inputs and produces either an ndarray with broadcast, or a scalar if all its inputs are all scalars. + /// + /// For example ,this function can be used to implement `np.add`, which has the following behaviors: + /// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar + /// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is converted into an ndarray and broadcasted. + /// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) -> ndarray; there is broadcasting. + /// + /// ## Details: + /// + /// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a [`ScalarOrNDArray::Scalar`] with type `ret_dtype`. + /// + /// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be 'as-ndarray'-ed into ndarrays, + /// then all inputs (now all ndarrays) will be passed to [`NDArrayObject::broadcasting_starmap`] and **create** a new ndarray + /// with dtype `ret_dtype`. + pub fn broadcasting_starmap<'a, G, MappingFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + inputs: &[ScalarOrNDArray<'ctx>], + ret_dtype: Type, + mapping: MappingFn, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + MappingFn: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + &[BasicValueEnum<'ctx>], + ) -> Result, String>, + { + // Check if all inputs are Scalars + let all_scalars: Option> = inputs.iter().map(AnyObject::try_from).try_collect().ok(); + + if let Some(scalars) = all_scalars { + let scalars = scalars.iter().map(|scalar| scalar.value).collect_vec(); + let value = mapping(generator, ctx, &scalars)?; + + Ok(ScalarOrNDArray::Scalar(AnyObject { ty: ret_dtype, value })) + } else { + // Promote all input to ndarrays and map through them. + let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); + let ndarray = NDArrayObject::broadcast_starmap( + generator, + ctx, + &inputs, + NDArrayOut::NewNDArray { dtype: ret_dtype }, + mapping, + )?; + Ok(ScalarOrNDArray::NDArray(ndarray)) + } + } + + /// Map through this [`ScalarOrNDArray`] with an elementwise function. + /// + /// If this is a scalar, `mapping` will directly act on the scalar. This function will return a [`ScalarOrNDArray::Scalar`] of that result. + /// + /// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new ndarray of the results will be created and + /// returned as a [`ScalarOrNDArray::NDArray`]. + pub fn map<'a, G, Mapping>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ret_dtype: Type, + mapping: Mapping, + ) -> Result, String> + where + G: CodeGenerator + ?Sized, + Mapping: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BasicValueEnum<'ctx>, + ) -> Result, String>, + { + ScalarOrNDArray::broadcasting_starmap( + generator, + ctx, + &[*self], + ret_dtype, + |generator, ctx, scalars| mapping(generator, ctx, scalars[0]), + ) + } +} diff --git a/nac3core/src/codegen/object/ndarray/matmul.rs b/nac3core/src/codegen/object/ndarray/matmul.rs new file mode 100644 index 00000000..e27d5fb7 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/matmul.rs @@ -0,0 +1,218 @@ +use std::cmp::max; + +use nac3parser::ast::Operator; +use util::gen_for_model; + +use crate::{ + codegen::{ + expr::gen_binop_expr_with_values, irrt::call_nac3_ndarray_matmul_calculate_shapes, + model::*, object::ndarray::indexing::RustNDIndex, CodeGenContext, CodeGenerator, + }, + typecheck::{magic_methods::Binop, typedef::Type}, +}; + +use super::{NDArrayObject, NDArrayOut}; + +/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`. +/// +/// `dst_dtype` defines the dtype of the returned ndarray. +fn matmul_at_least_2d<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst_dtype: Type, + in_a: NDArrayObject<'ctx>, + in_b: NDArrayObject<'ctx>, +) -> NDArrayObject<'ctx> { + assert!(in_a.ndims >= 2); + assert!(in_b.ndims >= 2); + + // Deduce ndims of the result of matmul. + let ndims_int = max(in_a.ndims, in_b.ndims); + let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int); + + let num_0 = Int(SizeT).const_int(generator, ctx.ctx, 0); + let num_1 = Int(SizeT).const_int(generator, ctx.ctx, 1); + + // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the + // destination ndarray to store the result of matmul. + let (lhs, rhs, dst) = { + let in_lhs_ndims = in_a.ndims_llvm(generator, ctx.ctx); + let in_lhs_shape = in_a.instance.get(generator, ctx, |f| f.shape); + let in_rhs_ndims = in_b.ndims_llvm(generator, ctx.ctx); + let in_rhs_shape = in_b.instance.get(generator, ctx, |f| f.shape); + let lhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + let rhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + let dst_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + + // Matmul dimension compatibility is checked here. + call_nac3_ndarray_matmul_calculate_shapes( + generator, + ctx, + in_lhs_ndims, + in_lhs_shape, + in_rhs_ndims, + in_rhs_shape, + ndims, + lhs_shape, + rhs_shape, + dst_shape, + ); + + let lhs = in_a.broadcast_to(generator, ctx, ndims_int, lhs_shape); + let rhs = in_b.broadcast_to(generator, ctx, ndims_int, rhs_shape); + + let dst = NDArrayObject::alloca(generator, ctx, dst_dtype, ndims_int); + dst.copy_shape_from_array(generator, ctx, dst_shape); + dst.create_data(generator, ctx); + + (lhs, rhs, dst) + }; + + let len = lhs.instance.get(generator, ctx, |f| f.shape).get_index_const( + generator, + ctx, + ndims_int - 1, + ); + + let at_row = ndims_int - 2; + let at_col = ndims_int - 1; + + let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype); + let dst_zero = dst_dtype_llvm.const_zero(); + + dst.foreach(generator, ctx, |generator, ctx, _, hdl| { + let pdst_ij = hdl.get_pointer(generator, ctx); + + ctx.builder.build_store(pdst_ij, dst_zero).unwrap(); + + let indices = hdl.get_indices(); + let i = indices.get_index_const(generator, ctx, at_row); + let j = indices.get_index_const(generator, ctx, at_col); + + gen_for_model(generator, ctx, num_0, len, num_1, |generator, ctx, _, k| { + // `indices` is modified to index into `a` and `b`, and restored. + indices.set_index_const(ctx, at_row, i); + indices.set_index_const(ctx, at_col, k); + let a_ik = lhs.get_scalar_by_indices(generator, ctx, indices); + + indices.set_index_const(ctx, at_row, k); + indices.set_index_const(ctx, at_col, j); + let b_kj = rhs.get_scalar_by_indices(generator, ctx, indices); + + // Restore `indices`. + indices.set_index_const(ctx, at_row, i); + indices.set_index_const(ctx, at_col, j); + + // x = a_[...]ik * b_[...]kj + let x = gen_binop_expr_with_values( + generator, + ctx, + (&Some(lhs.dtype), a_ik.value), + Binop::normal(Operator::Mult), + (&Some(rhs.dtype), b_kj.value), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + + // dst_[...]ij += x + let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); + let dst_ij = gen_binop_expr_with_values( + generator, + ctx, + (&Some(dst_dtype), dst_ij), + Binop::normal(Operator::Add), + (&Some(dst_dtype), x), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); + + Ok(()) + }) + }) + .unwrap(); + + dst +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Perform `np.matmul` according to the rules in + /// . + /// + /// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`] + /// to handle when the output could be a scalar. + /// + /// `dst_dtype` defines the dtype of the returned ndarray. + pub fn matmul( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: Self, + b: Self, + out: NDArrayOut<'ctx>, + ) -> Self { + // Sanity check, but type inference should prevent this. + assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input"); + + /* + If both arguments are 2-D they are multiplied like conventional matrices. + If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly. + If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed. + If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed. + */ + + let new_a = if a.ndims == 1 { + // Prepend 1 to its dimensions + a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) + } else { + a + }; + + let new_b = if b.ndims == 1 { + // Append 1 to its dimensions + b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) + } else { + b + }; + + // NOTE: `result` will always be a newly allocated ndarray. + // Current implementation cannot do in-place matrix muliplication. + let mut result = matmul_at_least_2d(generator, ctx, out.get_dtype(), new_a, new_b); + + // Postprocessing on the result to remove prepended/appended axes. + let mut postindices = vec![]; + let zero = Int(Int32).const_0(generator, ctx.ctx); + + if a.ndims == 1 { + // Remove the prepended 1 + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if b.ndims == 1 { + // Remove the appended 1 + postindices.push(RustNDIndex::Ellipsis); + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if !postindices.is_empty() { + result = result.index(generator, ctx, &postindices); + } + + match out { + NDArrayOut::NewNDArray { .. } => result, + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { + let result_shape = result.instance.get(generator, ctx, |f| f.shape); + out_ndarray.assert_can_be_written_by_out( + generator, + ctx, + result.ndims, + result_shape, + ); + + out_ndarray.copy_data_from(generator, ctx, result); + out_ndarray + } + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs new file mode 100644 index 00000000..91a0d7be --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -0,0 +1,668 @@ +pub mod array; +pub mod broadcast; +pub mod contiguous; +pub mod factory; +pub mod indexing; +pub mod map; +pub mod matmul; +pub mod nditer; +pub mod shape_util; +pub mod view; + +use inkwell::{ + context::Context, + types::BasicType, + values::{BasicValue, BasicValueEnum, PointerValue}, + AddressSpace, +}; + +use crate::{ + codegen::{ + irrt::{ + call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement, + call_nac3_ndarray_get_pelement_by_indices, call_nac3_ndarray_is_c_contiguous, + call_nac3_ndarray_len, call_nac3_ndarray_nbytes, + call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, + call_nac3_ndarray_util_assert_output_shape_same, + }, + model::*, + CodeGenContext, CodeGenerator, + }, + toplevel::{ + helper::{create_ndims, extract_ndims}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, + }, + typecheck::typedef::{Type, TypeEnum}, +}; + +use super::{any::AnyObject, tuple::TupleObject}; + +/// Fields of [`NDArray`] +pub struct NDArrayFields<'ctx, F: FieldTraversal<'ctx>> { + pub data: F::Out>>, + pub itemsize: F::Out>, + pub ndims: F::Out>, + pub shape: F::Out>>, + pub strides: F::Out>>, +} + +/// A strided ndarray in NAC3. +/// +/// See IRRT implementation for details about its fields. +#[derive(Debug, Clone, Copy, Default)] +pub struct NDArray; + +impl<'ctx> StructKind<'ctx> for NDArray { + type Fields> = NDArrayFields<'ctx, F>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + data: traversal.add_auto("data"), + itemsize: traversal.add_auto("itemsize"), + ndims: traversal.add_auto("ndims"), + shape: traversal.add_auto("shape"), + strides: traversal.add_auto("strides"), + } + } +} + +/// A NAC3 Python ndarray object. +#[derive(Debug, Clone, Copy)] +pub struct NDArrayObject<'ctx> { + pub dtype: Type, + pub ndims: u64, + pub instance: Instance<'ctx, Ptr>>, +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Attempt to convert an [`AnyObject`] into an [`NDArrayObject`]. + pub fn from_object( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> NDArrayObject<'ctx> { + let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty); + let ndims = extract_ndims(&ctx.unifier, ndims); + Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims) + } + + /// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's + /// `dtype` and `ndims`. + pub fn from_value_and_unpacked_types, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: V, + dtype: Type, + ndims: u64, + ) -> Self { + let value = Ptr(Struct(NDArray)).check_value(generator, ctx.ctx, value).unwrap(); + NDArrayObject { dtype, ndims, instance: value } + } + + /// Get this ndarray's `ndims` as an LLVM constant. + pub fn ndims_llvm( + &self, + generator: &mut G, + ctx: &'ctx Context, + ) -> Instance<'ctx, Int> { + Int(SizeT).const_int(generator, ctx, self.ndims) + } + + /// Get the typechecker ndarray type of this [`NDArrayObject`]. + pub fn get_type(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Type { + let ndims = create_ndims(&mut ctx.unifier, self.ndims); + make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(self.dtype), Some(ndims)) + } + + /// Forget that this is an ndarray and convert into an [`AnyObject`]. + pub fn to_any(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> { + let ty = self.get_type(ctx); + AnyObject { value: self.instance.value.as_basic_value_enum(), ty } + } + + /// Allocate an ndarray on the stack given its `ndims` and `dtype`. + /// + /// `shape` and `strides` will be automatically allocated onto the stack. + /// + /// The returned ndarray's content will be: + /// - `data`: uninitialized. + /// - `itemsize`: set to the `sizeof()` of `dtype`. + /// - `ndims`: set to the value of `ndims`. + /// - `shape`: allocated with an array of length `ndims` with uninitialized values. + /// - `strides`: allocated with an array of length `ndims` with uninitialized values. + pub fn alloca( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + ndims: u64, + ) -> Self { + let ndarray = Struct(NDArray).alloca(generator, ctx); + + let itemsize = ctx.get_llvm_type(generator, dtype).size_of().unwrap(); + let itemsize = Int(SizeT).z_extend_or_truncate(generator, ctx, itemsize); + ndarray.set(ctx, |f| f.itemsize, itemsize); + + let ndims_val = Int(SizeT).const_int(generator, ctx.ctx, ndims); + ndarray.set(ctx, |f| f.ndims, ndims_val); + + let shape = Int(SizeT).array_alloca(generator, ctx, ndims_val.value); + ndarray.set(ctx, |f| f.shape, shape); + + let strides = Int(SizeT).array_alloca(generator, ctx, ndims_val.value); + ndarray.set(ctx, |f| f.strides, strides); + + NDArrayObject { dtype, ndims, instance: ndarray } + } + + /// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape. + /// + /// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized. + pub fn alloca_constant_shape( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &[u64], + ) -> Self { + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64); + + // Write shape + let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape); + for (i, dim) in shape.iter().enumerate() { + let dim = Int(SizeT).const_int(generator, ctx.ctx, *dim); + dst_shape.offset_const(ctx, i as u64).store(ctx, dim); + } + + ndarray + } + + /// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape. + /// + /// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized. + pub fn alloca_dynamic_shape( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + shape: &[Instance<'ctx, Int>], + ) -> Self { + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64); + + // Write shape + let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape); + for (i, dim) in shape.iter().enumerate() { + dst_shape.offset_const(ctx, i as u64).store(ctx, *dim); + } + + ndarray + } + + /// Initialize an ndarray's `data` by allocating a buffer on the stack. + /// The allocated data buffer is considered to be *owned* by the ndarray. + /// + /// `strides` of the ndarray will also be updated with `set_strides_by_shape`. + /// + /// `shape` and `itemsize` of the ndarray ***must*** be initialized first. + pub fn create_data( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + let nbytes = self.nbytes(generator, ctx); + + let data = Int(Byte).array_alloca(generator, ctx, nbytes.value); + self.instance.set(ctx, |f| f.data, data); + + self.set_strides_contiguous(generator, ctx); + } + + /// Copy shape dimensions from an array. + pub fn copy_shape_from_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: Instance<'ctx, Ptr>>, + ) { + let num_items = self.ndims_llvm(generator, ctx.ctx).value; + self.instance.get(generator, ctx, |f| f.shape).copy_from(generator, ctx, shape, num_items); + } + + /// Copy shape dimensions from an ndarray. + /// Panics if `ndims` mismatches. + pub fn copy_shape_from_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayObject<'ctx>, + ) { + assert_eq!(self.ndims, src_ndarray.ndims); + let src_shape = src_ndarray.instance.get(generator, ctx, |f| f.shape); + self.copy_shape_from_array(generator, ctx, src_shape); + } + + /// Copy strides dimensions from an array. + pub fn copy_strides_from_array( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + strides: Instance<'ctx, Ptr>>, + ) { + let num_items = self.ndims_llvm(generator, ctx.ctx).value; + self.instance + .get(generator, ctx, |f| f.strides) + .copy_from(generator, ctx, strides, num_items); + } + + /// Copy strides dimensions from an ndarray. + /// Panics if `ndims` mismatches. + pub fn copy_strides_from_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src_ndarray: NDArrayObject<'ctx>, + ) { + assert_eq!(self.ndims, src_ndarray.ndims); + let src_strides = src_ndarray.instance.get(generator, ctx, |f| f.strides); + self.copy_strides_from_array(generator, ctx, src_strides); + } + + /// Get the `np.size()` of this ndarray. + pub fn size( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_size(generator, ctx, self.instance) + } + + /// Get the `ndarray.nbytes` of this ndarray. + pub fn nbytes( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_nbytes(generator, ctx, self.instance) + } + + /// Get the `len()` of this ndarray. + pub fn len( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_len(generator, ctx, self.instance) + } + + /// Check if this ndarray is C-contiguous. + /// + /// See NumPy's `flags["C_CONTIGUOUS"]`: + pub fn is_c_contiguous( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance) + } + + /// Get the pointer to the n-th (0-based) element. + /// + /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. + pub fn get_nth_pelement( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + nth: Instance<'ctx, Int>, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.dtype); + + let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, nth); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "") + .unwrap() + } + + /// Get the n-th (0-based) scalar. + pub fn get_nth_scalar( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + nth: Instance<'ctx, Int>, + ) -> AnyObject<'ctx> { + let ptr = self.get_nth_pelement(generator, ctx, nth); + let value = ctx.builder.build_load(ptr, "").unwrap(); + AnyObject { ty: self.dtype, value } + } + + /// Get the pointer to the element indexed by `indices`. + /// + /// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`. + pub fn get_pelement_by_indices( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: Instance<'ctx, Ptr>>, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.dtype); + + let p = call_nac3_ndarray_get_pelement_by_indices(generator, ctx, self.instance, indices); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "") + .unwrap() + } + + /// Get the scalar indexed by `indices`. + pub fn get_scalar_by_indices( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + indices: Instance<'ctx, Ptr>>, + ) -> AnyObject<'ctx> { + let ptr = self.get_pelement_by_indices(generator, ctx, indices); + let value = ctx.builder.build_load(ptr, "").unwrap(); + AnyObject { ty: self.dtype, value } + } + + /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. + /// + /// Update the ndarray's strides to make the ndarray contiguous. + pub fn set_strides_contiguous( + self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance); + } + + /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over. + /// + /// The new ndarray will own its data and will be C-contiguous. + #[must_use] + pub fn make_copy( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + let clone = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims); + + let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx); + clone.copy_shape_from_array(generator, ctx, shape); + clone.create_data(generator, ctx); + clone.copy_data_from(generator, ctx, *self); + clone + } + + /// Copy data from another ndarray. + /// + /// This ndarray and `src` is that their `np.size()` should be the same. Their shapes + /// do not matter. The copying order is determined by how their flattened views look. + /// + /// Panics if the `dtype`s of ndarrays are different. + pub fn copy_data_from( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src: NDArrayObject<'ctx>, + ) { + assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match"); + call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance); + } + + /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. + #[must_use] + pub fn is_unsized(&self) -> bool { + self.ndims == 0 + } + + /// If this ndarray is unsized, return its sole value as an [`AnyObject`]. + /// Otherwise, do nothing and return the ndarray itself. + pub fn split_unsized( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> ScalarOrNDArray<'ctx> { + if self.is_unsized() { + // NOTE: `np.size(self) == 0` here is never possible. + let zero = Int(SizeT).const_0(generator, ctx.ctx); + let value = self.get_nth_scalar(generator, ctx, zero).value; + + ScalarOrNDArray::Scalar(AnyObject { ty: self.dtype, value }) + } else { + ScalarOrNDArray::NDArray(*self) + } + } + + /// Fill the ndarray with a scalar. + /// + /// `fill_value` must have the same LLVM type as the `dtype` of this ndarray. + pub fn fill( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + value: BasicValueEnum<'ctx>, + ) { + self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { + let p = nditer.get_pointer(generator, ctx); + ctx.builder.build_store(p, value).unwrap(); + Ok(()) + }) + .unwrap(); + } + + /// Create the shape tuple of this ndarray like `np.shape()`. + /// + /// The returned integers in the tuple are in int32. + pub fn make_shape_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleObject<'ctx> { + // TODO: Return a tuple of SizeT + + let mut objects = Vec::with_capacity(self.ndims as usize); + + for i in 0..self.ndims { + let dim = self + .instance + .get(generator, ctx, |f| f.shape) + .get_index_const(generator, ctx, i) + .truncate_or_bit_cast(generator, ctx, Int32); + + objects.push(AnyObject { + ty: ctx.primitives.int32, + value: dim.value.as_basic_value_enum(), + }); + } + + TupleObject::from_objects(generator, ctx, objects) + } + + /// Create the strides tuple of this ndarray like `.strides`. + /// + /// The returned integers in the tuple are in int32. + pub fn make_strides_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleObject<'ctx> { + // TODO: Return a tuple of SizeT. + + let mut objects = Vec::with_capacity(self.ndims as usize); + + for i in 0..self.ndims { + let dim = self + .instance + .get(generator, ctx, |f| f.strides) + .get_index_const(generator, ctx, i) + .truncate_or_bit_cast(generator, ctx, Int32); + + objects.push(AnyObject { + ty: ctx.primitives.int32, + value: dim.value.as_basic_value_enum(), + }); + } + + TupleObject::from_objects(generator, ctx, objects) + } + + /// Create an unsized ndarray to contain `object`. + pub fn make_unsized( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> NDArrayObject<'ctx> { + // We have to put the value on the stack to get a data pointer. + let data = ctx.builder.build_alloca(object.value.get_type(), "make_unsized").unwrap(); + ctx.builder.build_store(data, object.value).unwrap(); + let data = Ptr(Int(Byte)).pointer_cast(generator, ctx, data); + + let ndarray = NDArrayObject::alloca(generator, ctx, object.ty, 0); + ndarray.instance.set(ctx, |f| f.data, data); + ndarray + } + /// Check if this `NDArray` can be used as an `out` ndarray for an operation. + /// + /// Raise an exception if the shapes do not match. + pub fn assert_can_be_written_by_out( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + out_ndims: u64, + out_shape: Instance<'ctx, Ptr>>, + ) { + let ndarray_ndims = self.ndims_llvm(generator, ctx.ctx); + let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape); + + let output_ndims = Int(SizeT).const_int(generator, ctx.ctx, out_ndims); + let output_shape = out_shape; + + call_nac3_ndarray_util_assert_output_shape_same( + generator, + ctx, + ndarray_ndims, + ndarray_shape, + output_ndims, + output_shape, + ); + } +} + +/// A convenience enum for implementing functions that acts on scalars or ndarrays or both. +#[derive(Debug, Clone, Copy)] +pub enum ScalarOrNDArray<'ctx> { + Scalar(AnyObject<'ctx>), + NDArray(NDArrayObject<'ctx>), +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(scalar) => Ok(*scalar), + ScalarOrNDArray::NDArray(_ndarray) => Err(()), + } + } +} + +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> { + type Error = (); + + fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { + match value { + ScalarOrNDArray::Scalar(_scalar) => Err(()), + ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray), + } + } +} + +impl<'ctx> ScalarOrNDArray<'ctx> { + /// Split on `object` either into a scalar or an ndarray. + /// + /// If `object` is an ndarray, [`ScalarOrNDArray::NDArray`]. + /// + /// For everything else, it is wrapped with [`ScalarOrNDArray::Scalar`]. + pub fn split_object( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + object: AnyObject<'ctx>, + ) -> ScalarOrNDArray<'ctx> { + match &*ctx.unifier.get_ty(object.ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayObject::from_object(generator, ctx, object); + ScalarOrNDArray::NDArray(ndarray) + } + _ => ScalarOrNDArray::Scalar(object), + } + } + + /// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`]. + #[must_use] + pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> { + match self { + ScalarOrNDArray::Scalar(scalar) => scalar.value, + ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(), + } + } + + /// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`. + /// - If this is an ndarray, the ndarray is returned. + /// - If this is a scalar, this function returns new ndarray created with [`NDArrayObject::make_unsized`]. + pub fn to_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> NDArrayObject<'ctx> { + match self { + ScalarOrNDArray::NDArray(ndarray) => *ndarray, + ScalarOrNDArray::Scalar(scalar) => NDArrayObject::make_unsized(generator, ctx, *scalar), + } + } + + /// Get the dtype of the ndarray created if this were called with [`ScalarOrNDArray::to_ndarray`]. + #[must_use] + pub fn get_dtype(&self) -> Type { + match self { + ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.ty, + } + } +} + +/// An helper enum specifying how a function should produce its output. +/// +/// Many functions in NumPy has an optional `out` parameter (e.g., `matmul`). If `out` is specified +/// with an ndarray, the result of a function will be written to `out`. If `out` is not specified, a function will +/// create a new ndarray and store the result in it. +#[derive(Debug, Clone, Copy)] +pub enum NDArrayOut<'ctx> { + /// Tell a function should create a new ndarray with the expected element type `dtype`. + NewNDArray { dtype: Type }, + /// Tell a function to write the result to `ndarray`. + WriteToNDArray { ndarray: NDArrayObject<'ctx> }, +} + +impl<'ctx> NDArrayOut<'ctx> { + /// Get the dtype of this output. + #[must_use] + pub fn get_dtype(&self) -> Type { + match self { + NDArrayOut::NewNDArray { dtype } => *dtype, + NDArrayOut::WriteToNDArray { ndarray } => ndarray.dtype, + } + } +} + +/// A version of [`call_nac3_ndarray_set_strides_by_shape`] in Rust. +/// +/// This function is used generating strides for globally defined contiguous ndarrays. +#[must_use] +pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec { + let mut strides = Vec::with_capacity(ndims as usize); + let mut stride_product = 1u64; + for i in 0..ndims { + let axis = ndims - i - 1; + strides[axis as usize] = stride_product * itemsize; + stride_product *= shape[axis as usize]; + } + strides +} diff --git a/nac3core/src/codegen/object/ndarray/nditer.rs b/nac3core/src/codegen/object/ndarray/nditer.rs new file mode 100644 index 00000000..b2a23c0f --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/nditer.rs @@ -0,0 +1,177 @@ +use inkwell::{types::BasicType, values::PointerValue, AddressSpace}; + +use crate::codegen::{ + irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next}, + model::*, + object::any::AnyObject, + stmt::{gen_for_callback, BreakContinueHooks}, + CodeGenContext, CodeGenerator, +}; + +use super::NDArrayObject; + +/// Fields of [`NDIter`] +pub struct NDIterFields<'ctx, F: FieldTraversal<'ctx>> { + pub ndims: F::Out>, + pub shape: F::Out>>, + pub strides: F::Out>>, + + pub indices: F::Out>>, + pub nth: F::Out>, + pub element: F::Out>>, + + pub size: F::Out>, +} + +/// An IRRT helper structure used to iterate through an ndarray. +#[derive(Debug, Clone, Copy, Default)] +pub struct NDIter; + +impl<'ctx> StructKind<'ctx> for NDIter { + type Fields> = NDIterFields<'ctx, F>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + ndims: traversal.add_auto("ndims"), + shape: traversal.add_auto("shape"), + strides: traversal.add_auto("strides"), + + indices: traversal.add_auto("indices"), + nth: traversal.add_auto("nth"), + element: traversal.add_auto("element"), + + size: traversal.add_auto("size"), + } + } +} + +/// A helper structure with a convenient interface to interact with [`NDIter`]. +#[derive(Debug, Clone)] +pub struct NDIterHandle<'ctx> { + instance: Instance<'ctx, Ptr>>, + /// The ndarray this [`NDIter`] to iterating over. + ndarray: NDArrayObject<'ctx>, + /// The current indices of [`NDIter`]. + indices: Instance<'ctx, Ptr>>, +} + +impl<'ctx> NDIterHandle<'ctx> { + /// Allocate an [`NDIter`] that iterates through an ndarray. + pub fn new( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndarray: NDArrayObject<'ctx>, + ) -> Self { + let nditer = Struct(NDIter).alloca(generator, ctx); + let ndims = ndarray.ndims_llvm(generator, ctx.ctx); + + // The caller has the responsibility to allocate 'indices' for `NDIter`. + let indices = Int(SizeT).array_alloca(generator, ctx, ndims.value); + call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices); + + NDIterHandle { ndarray, instance: nditer, indices } + } + + /// Is there a next element? + /// + /// If `ndarray` is unsized, this returns true only for the first iteration. + /// If `ndarray` is 0-sized, this always returns false. + #[must_use] + pub fn has_next( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + call_nac3_nditer_has_next(generator, ctx, self.instance) + } + + /// Go to the next element. If `has_next()` is false, then this has undefined behavior. + /// + /// If `ndarray` is unsized, this can only be called once. + /// If `ndarray` is 0-sized, this can never be called. + pub fn next( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) { + call_nac3_nditer_next(generator, ctx, self.instance); + } + + /// Get pointer to the current element. + #[must_use] + pub fn get_pointer( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> PointerValue<'ctx> { + let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype); + + let p = self.instance.get(generator, ctx, |f| f.element); + ctx.builder + .build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element") + .unwrap() + } + + /// Get the value of the current element. + #[must_use] + pub fn get_scalar( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + let p = self.get_pointer(generator, ctx); + let value = ctx.builder.build_load(p, "value").unwrap(); + AnyObject { ty: self.ndarray.dtype, value } + } + + /// Get the index of the current element if this ndarray were a flat ndarray. + #[must_use] + pub fn get_index( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + self.instance.get(generator, ctx, |f| f.nth) + } + + /// Get the indices of the current element. + #[must_use] + pub fn get_indices(&self) -> Instance<'ctx, Ptr>> { + self.indices + } +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Iterate through every element in the ndarray. + /// + /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterHandle`] to + /// get properties of the current iteration (e.g., the current element, indices, etc.) + pub fn foreach<'a, G, F>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + body: F, + ) -> Result<(), String> + where + G: CodeGenerator + ?Sized, + F: FnOnce( + &mut G, + &mut CodeGenContext<'ctx, 'a>, + BreakContinueHooks<'ctx>, + NDIterHandle<'ctx>, + ) -> Result<(), String>, + { + gen_for_callback( + generator, + ctx, + Some("ndarray_foreach"), + |generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)), + |generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value), + |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), + |generator, ctx, nditer| { + nditer.next(generator, ctx); + Ok(()) + }, + ) + } +} diff --git a/nac3core/src/codegen/object/ndarray/shape_util.rs b/nac3core/src/codegen/object/ndarray/shape_util.rs new file mode 100644 index 00000000..aa6c3f80 --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/shape_util.rs @@ -0,0 +1,105 @@ +use util::gen_for_model; + +use crate::{ + codegen::{ + model::*, + object::{any::AnyObject, list::ListObject, tuple::TupleObject}, + CodeGenContext, CodeGenerator, + }, + typecheck::typedef::TypeEnum, +}; + +/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length. +/// +/// * `sequence` - The `sequence` parameter. +/// * `sequence_ty` - The typechecker type of `sequence` +/// +/// The `sequence` argument type may only be one of the following: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` +/// +/// All `int32` values will be sign-extended to `SizeT`. +pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + input_sequence: AnyObject<'ctx>, +) -> (Instance<'ctx, Int>, Instance<'ctx, Ptr>>) { + let zero = Int(SizeT).const_0(generator, ctx.ctx); + let one = Int(SizeT).const_1(generator, ctx.ctx); + + // The result `list` to return. + match &*ctx.unifier.get_ty(input_sequence.ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` + + // Check `input_sequence` + let input_sequence = ListObject::from_object(generator, ctx, input_sequence); + + let len = input_sequence.instance.get(generator, ctx, |f| f.len); + let result = Int(SizeT).array_alloca(generator, ctx, len.value); + + // Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result` + gen_for_model(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| { + // Load the i-th int32 in the input sequence + let int = input_sequence + .instance + .get(generator, ctx, |f| f.items) + .get_index(generator, ctx, i.value) + .value + .into_int_value(); + + // Cast to SizeT + let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int); + + // Store + result.set_index(ctx, i.value, int); + + Ok(()) + }) + .unwrap(); + + (len, result) + } + TypeEnum::TTuple { .. } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + let input_sequence = TupleObject::from_object(ctx, input_sequence); + + let len = input_sequence.len(generator, ctx); + + let result = Int(SizeT).array_alloca(generator, ctx, len.value); + + for i in 0..input_sequence.num_elements() { + // Get the i-th element off of the tuple and load it into `result`. + let int = input_sequence.index(ctx, i).value.into_int_value(); + let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, int); + + result.set_index_const(ctx, i as u64, int); + } + + (len, result) + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + let input_int = input_sequence.value.into_int_value(); + + let len = Int(SizeT).const_1(generator, ctx.ctx); + let result = Int(SizeT).array_alloca(generator, ctx, len.value); + let int = Int(SizeT).s_extend_or_bit_cast(generator, ctx, input_int); + + // Storing into result[0] + result.store(ctx, int); + + (len, result) + } + _ => panic!( + "encountered unknown sequence type: {}", + ctx.unifier.stringify(input_sequence.ty) + ), + } +} diff --git a/nac3core/src/codegen/object/ndarray/view.rs b/nac3core/src/codegen/object/ndarray/view.rs new file mode 100644 index 00000000..bae0bbdd --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/view.rs @@ -0,0 +1,119 @@ +use crate::codegen::{ + irrt::{call_nac3_ndarray_reshape_resolve_and_check_new_shape, call_nac3_ndarray_transpose}, + model::*, + CodeGenContext, CodeGenerator, +}; + +use super::{indexing::RustNDIndex, NDArrayObject}; + +impl<'ctx> NDArrayObject<'ctx> { + /// Make sure the ndarray is at least `ndmin`-dimensional. + /// + /// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape. + /// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray. + #[must_use] + pub fn atleast_nd( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ndmin: u64, + ) -> Self { + if self.ndims < ndmin { + // Extend the dimensions with np.newaxis. + let mut indices = vec![]; + for _ in self.ndims..ndmin { + indices.push(RustNDIndex::NewAxis); + } + indices.push(RustNDIndex::Ellipsis); + self.index(generator, ctx, &indices) + } else { + *self + } + } + + /// Create a reshaped view on this ndarray like `np.reshape()`. + /// + /// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be modified as a result. + /// + /// If reshape without copying is impossible, this function will allocate a new ndarray and copy contents. + /// + /// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`]. + /// * `new_shape` - The target shape to do `np.reshape()`. + #[must_use] + pub fn reshape_or_copy( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + new_ndims: u64, + new_shape: Instance<'ctx, Ptr>>, + ) -> Self { + // TODO: The current criterion for whether to do a full copy or not is by checking `is_c_contiguous`, + // but this is not optimal - there are cases when the ndarray is not contiguous but could be reshaped + // without copying data. Look into how numpy does it. + + let current_bb = ctx.builder.get_insert_block().unwrap(); + let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb"); + let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb"); + let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb"); + + let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims); + dst_ndarray.copy_shape_from_array(generator, ctx, new_shape); + + // Reolsve negative indices + let size = self.size(generator, ctx); + let dst_ndims = dst_ndarray.ndims_llvm(generator, ctx.ctx); + let dst_shape = dst_ndarray.instance.get(generator, ctx, |f| f.shape); + call_nac3_ndarray_reshape_resolve_and_check_new_shape( + generator, ctx, size, dst_ndims, dst_shape, + ); + + let is_c_contiguous = self.is_c_contiguous(generator, ctx); + ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap(); + + // Inserting into then_bb: reshape is possible without copying + ctx.builder.position_at_end(then_bb); + dst_ndarray.set_strides_contiguous(generator, ctx); + dst_ndarray.instance.set(ctx, |f| f.data, self.instance.get(generator, ctx, |f| f.data)); + ctx.builder.build_unconditional_branch(end_bb).unwrap(); + + // Inserting into else_bb: reshape is impossible without copying + ctx.builder.position_at_end(else_bb); + dst_ndarray.create_data(generator, ctx); + dst_ndarray.copy_data_from(generator, ctx, *self); + ctx.builder.build_unconditional_branch(end_bb).unwrap(); + + // Reposition for continuation + ctx.builder.position_at_end(end_bb); + + dst_ndarray + } + + /// Create a transposed view on this ndarray like `np.transpose(, = None)`. + /// * `axes` - If specified, should be an array of the permutation (negative indices are **allowed**). + #[must_use] + pub fn transpose( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + axes: Option>>>, + ) -> Self { + // Define models + let transposed_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims); + + let num_axes = self.ndims_llvm(generator, ctx.ctx); + + // `axes = nullptr` if `axes` is unspecified. + let axes = axes.unwrap_or_else(|| Ptr(Int(SizeT)).nullptr(generator, ctx.ctx)); + + call_nac3_ndarray_transpose( + generator, + ctx, + self.instance, + transposed_ndarray.instance, + num_axes, + axes, + ); + + transposed_ndarray + } +} diff --git a/nac3core/src/codegen/object/tuple.rs b/nac3core/src/codegen/object/tuple.rs new file mode 100644 index 00000000..00e52a7e --- /dev/null +++ b/nac3core/src/codegen/object/tuple.rs @@ -0,0 +1,99 @@ +use inkwell::values::StructValue; +use itertools::Itertools; + +use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::{Type, TypeEnum}, +}; + +use super::any::AnyObject; + +/// A NAC3 tuple object. +/// +/// NOTE: This struct has no copy trait. +#[derive(Debug, Clone)] +pub struct TupleObject<'ctx> { + /// The type of the tuple. + pub tys: Vec, + /// The underlying LLVM struct value of this tuple. + pub value: StructValue<'ctx>, +} + +impl<'ctx> TupleObject<'ctx> { + pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self { + // TODO: Keep `is_vararg_ctx` from TTuple? + + // Sanity check on object type. + let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else { + panic!( + "Expected type to be a TypeEnum::TTuple, got {}", + ctx.unifier.stringify(object.ty) + ); + }; + + // Check number of fields + let value = object.value.into_struct_value(); + let value_num_fields = value.get_type().count_fields() as usize; + assert!( + value_num_fields == tys.len(), + "Tuple type has {} item(s), but the LLVM struct value has {} field(s)", + tys.len(), + value_num_fields + ); + + TupleObject { tys: tys.clone(), value } + } + + /// Convenience function. Create a [`TupleObject`] from an iterator of objects. + pub fn from_objects( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + objects: I, + ) -> Self + where + I: IntoIterator>, + { + let (values, tys): (Vec<_>, Vec<_>) = + objects.into_iter().map(|object| (object.value, object.ty)).unzip(); + + let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec(); + let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false); + + let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap(); + for (i, val) in values.into_iter().enumerate() { + let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap(); + ctx.builder.build_store(pval, val).unwrap(); + } + + let value = ctx.builder.build_load(pllvm_tuple, "").unwrap().into_struct_value(); + TupleObject { tys, value } + } + + #[must_use] + pub fn num_elements(&self) -> usize { + self.tys.len() + } + + /// Get the `len()` of this tuple. + #[must_use] + pub fn len( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Instance<'ctx, Int> { + Int(SizeT).const_int(generator, ctx.ctx, self.num_elements() as u64) + } + + /// Get the `i`-th (0-based) object in this tuple. + pub fn index(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize) -> AnyObject<'ctx> { + assert!( + i < self.num_elements(), + "Tuple object with length {} have index {i}", + self.num_elements() + ); + + let value = ctx.builder.build_extract_value(self.value, i as u32, "tuple[{i}]").unwrap(); + let ty = self.tys[i]; + AnyObject { ty, value } + } +} diff --git a/nac3core/src/codegen/object/utils/mod.rs b/nac3core/src/codegen/object/utils/mod.rs new file mode 100644 index 00000000..913812d4 --- /dev/null +++ b/nac3core/src/codegen/object/utils/mod.rs @@ -0,0 +1 @@ +pub mod slice; diff --git a/nac3core/src/codegen/object/utils/slice.rs b/nac3core/src/codegen/object/utils/slice.rs new file mode 100644 index 00000000..6ea145b1 --- /dev/null +++ b/nac3core/src/codegen/object/utils/slice.rs @@ -0,0 +1,125 @@ +use crate::codegen::{model::*, CodeGenContext, CodeGenerator}; + +/// Fields of [`Slice`] +#[derive(Debug, Clone)] +pub struct SliceFields<'ctx, F: FieldTraversal<'ctx>, N: IntKind<'ctx>> { + pub start_defined: F::Out>, + pub start: F::Out>, + pub stop_defined: F::Out>, + pub stop: F::Out>, + pub step_defined: F::Out>, + pub step: F::Out>, +} + +/// An IRRT representation of an (unresolved) slice. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct Slice(pub N); + +impl<'ctx, N: IntKind<'ctx>> StructKind<'ctx> for Slice { + type Fields> = SliceFields<'ctx, F, N>; + + fn traverse_fields>(&self, traversal: &mut F) -> Self::Fields { + Self::Fields { + start_defined: traversal.add_auto("start_defined"), + start: traversal.add("start", Int(self.0)), + stop_defined: traversal.add_auto("stop_defined"), + stop: traversal.add("stop", Int(self.0)), + step_defined: traversal.add_auto("step_defined"), + step: traversal.add("step", Int(self.0)), + } + } +} + +/// A Rust structure that has [`Slice`] utilities and looks like a [`Slice`] but +/// `start`, `stop` and `step` are held by LLVM registers only and possibly +/// [`Option::None`] if unspecified. +#[derive(Debug, Clone)] +pub struct RustSlice<'ctx, N: IntKind<'ctx>> { + // It is possible that `start`, `stop`, and `step` are all `None`. + // We need to know the `int_kind` even when that is the case. + pub int_kind: N, + pub start: Option>>, + pub stop: Option>>, + pub step: Option>>, +} + +impl<'ctx, N: IntKind<'ctx>> RustSlice<'ctx, N> { + /// Write the contents to an LLVM [`Slice`]. + pub fn write_to_slice( + &self, + generator: &mut G, + ctx: &CodeGenContext<'ctx, '_>, + dst_slice_ptr: Instance<'ctx, Ptr>>>, + ) { + let false_ = Int(Bool).const_false(generator, ctx.ctx); + let true_ = Int(Bool).const_true(generator, ctx.ctx); + + match self.start { + Some(start) => { + dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_); + dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start); + } + None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_), + } + + match self.stop { + Some(stop) => { + dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_); + dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop); + } + None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_), + } + + match self.step { + Some(step) => { + dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_); + dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step); + } + None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_), + } + } +} + +pub mod util { + use nac3parser::ast::Expr; + + use crate::{ + codegen::{model::*, CodeGenContext, CodeGenerator}, + typecheck::typedef::Type, + }; + + use super::RustSlice; + + /// Generate LLVM IR for an [`ExprKind::Slice`] and convert it into a [`RustSlice`]. + #[allow(clippy::type_complexity)] + pub fn gen_slice<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lower: &Option>>>, + upper: &Option>>>, + step: &Option>>>, + ) -> Result, String> { + let mut help = |value_expr: &Option>>>| -> Result<_, String> { + Ok(match value_expr { + None => None, + Some(value_expr) => { + let value_expr = generator + .gen_expr(ctx, value_expr)? + .unwrap() + .to_basic_value_enum(ctx, generator, ctx.primitives.int32)?; + + let value_expr = + Int(Int32).check_value(generator, ctx.ctx, value_expr).unwrap(); + + Some(value_expr) + } + }) + }; + + let start = help(lower)?; + let stop = help(upper)?; + let step = help(step)?; + + Ok(RustSlice { int_kind: Int32, start, stop, step }) + } +} diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 081a5cef..7e053c06 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -2,6 +2,12 @@ use super::{ super::symbol_resolver::ValueEnum, expr::destructure_range, irrt::{handle_slice_indices, list_slice_assignment}, + object::{ + any::AnyObject, + ndarray::{ + indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject, ScalarOrNDArray, + }, + }, CodeGenContext, CodeGenerator, }; use crate::{ @@ -401,7 +407,47 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => { // Handle NDArray item assignment - todo!("ndarray subscript assignment is not yet implemented"); + // Process target + let target = generator + .gen_expr(ctx, target)? + .unwrap() + .to_basic_value_enum(ctx, generator, target_ty)?; + let target = AnyObject { value: target, ty: target_ty }; + + // Process key + let key = gen_ndarray_subscript_ndindices(generator, ctx, key)?; + + // Process value + let value = value.to_basic_value_enum(ctx, generator, value_ty)?; + let value = AnyObject { value, ty: value_ty }; + + /* + Reference code: + ```python + target = target[key] + value = np.asarray(value) + + shape = np.broadcast_shape((target, value)) + + target = np.broadcast_to(target, shape) + value = np.broadcast_to(value, shape) + + ...and finally copy 1-1 from value to target. + ``` + */ + + let target = NDArrayObject::from_object(generator, ctx, target); + let target = target.index(generator, ctx, &key); + + let value = + ScalarOrNDArray::split_object(generator, ctx, value).to_ndarray(generator, ctx); + + let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]); + + let target = broadcast_result.ndarrays[0]; + let value = broadcast_result.ndarrays[1]; + + target.copy_data_from(generator, ctx, value); } _ => { panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 9ed495e0..09c99553 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - classes::{ListType, NDArrayType, ProxyType, RangeType}, + classes::{ListType, ProxyType, RangeType}, concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry, @@ -456,15 +456,3 @@ fn test_classes_range_type_new() { let llvm_range = RangeType::new(&ctx); assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok()); } - -#[test] -fn test_classes_ndarray_type_new() { - let ctx = inkwell::context::Context::create(); - let generator = DefaultCodeGenerator::new(String::new(), 64); - - let llvm_i32 = ctx.i32_type(); - let llvm_usize = generator.get_size_type(&ctx); - - let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into()); - assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); -} diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e2325ebb..1a810f28 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,6 +1,6 @@ use std::iter::once; -use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails}; +use helper::{debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDefDetails}; use indexmap::IndexMap; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -9,13 +9,19 @@ use inkwell::{ IntPredicate, }; use itertools::Either; +use numpy::unpack_ndarray_var_tys; use strum::IntoEnumIterator; use crate::{ codegen::{ builtin_fns, classes::{ProxyValue, RangeValue}, + model::*, numpy::*, + object::{ + any::AnyObject, + ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject}, + }, stmt::exn_constructor, }, symbol_resolver::SymbolValue, @@ -511,6 +517,14 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), + PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => { + self.build_ndarray_property_getter_function(prim) + } + + PrimDef::FunNpBroadcastTo | PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { + self.build_ndarray_view_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -576,10 +590,6 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpHypot | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), - PrimDef::FunNpTranspose | PrimDef::FunNpReshape => { - self.build_np_sp_ndarray_function(prim) - } - PrimDef::FunNpDot | PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr @@ -1385,6 +1395,171 @@ impl<'a> BuiltinBuilder<'a> { } } + fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides], + ); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpSize => create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + self.primitives.int32, + &[(in_ndarray_ty.ty, "a")], + Box::new(|ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let size = + ndarray.size(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32); + Ok(Some(size.value.as_basic_value_enum())) + }), + ), + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + // The function signatures of `np_shape` an `np_size` are the same. + // Mixed together for convenience. + + // The return type is a tuple of variable length depending on the ndims of the input ndarray. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding + + create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + ret_ty, + &[(in_ndarray_ty.ty, "a")], + Box::new(move |ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let result_tuple = match prim { + PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx), + PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx), + _ => unreachable!(), + }; + + Ok(Some(result_tuple.value.as_basic_value_enum())) + }), + ) + } + _ => unreachable!(), + } + } + + /// Build np/sp functions that take as input `NDArray` only + fn build_ndarray_view_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpBroadcastTo, PrimDef::FunNpTranspose, PrimDef::FunNpReshape], + ); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpTranspose => { + create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + in_ndarray_ty.ty, + &[(in_ndarray_ty.ty, "x")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let arg = AnyObject { ty: arg_ty, value: arg_val }; + let ndarray = NDArrayObject::from_object(generator, ctx, arg); + + let ndarray = ndarray.transpose(generator, ctx, None); // TODO: Add axes argument + Ok(Some(ndarray.instance.value.as_basic_value_enum())) + }), + ) + } + + // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and + // the `param_ty` for `create_fn_by_codegen`. + // + // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking + // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], + // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. + PrimDef::FunNpBroadcastTo | PrimDef::FunNpReshape => { + // These two functions have the same function signature. + // Mixed together for convenience. + + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special holding + + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + ret_ty, + &[ + (in_ndarray_ty.ty, "x"), + (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape"), // Handled by special folding + ], + Box::new(move |ctx, _, fun, args, generator| { + let ndarray_ty = fun.0.args[0].ty; + let ndarray_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let shape_ty = fun.0.args[1].ty; + let shape_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?; + + let ndarray = AnyObject { value: ndarray_val, ty: ndarray_ty }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let shape = AnyObject { value: shape_val, ty: shape_ty }; + let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape); + + // The ndims after reshaping is gotten from the return type of the call. + let (_, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); + let ndims = extract_ndims(&ctx.unifier, ndims); + + let new_ndarray = match prim { + PrimDef::FunNpBroadcastTo => { + ndarray.broadcast_to(generator, ctx, ndims, shape) + } + PrimDef::FunNpReshape => { + ndarray.reshape_or_copy(generator, ctx, ndims, shape) + } + _ => unreachable!(), + }; + Ok(Some(new_ndarray.instance.value.as_basic_value_enum())) + }), + ) + } + + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; @@ -1872,57 +2047,6 @@ impl<'a> BuiltinBuilder<'a> { } } - /// Build np/sp functions that take as input `NDArray` only - fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef { - debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]); - - match prim { - PrimDef::FunNpTranspose => { - let ndarray_ty = self.unifier.get_fresh_var_with_range( - &[self.ndarray_num_ty], - Some("T".into()), - None, - ); - create_fn_by_codegen( - self.unifier, - &into_var_map([ndarray_ty]), - prim.name(), - ndarray_ty.ty, - &[(ndarray_ty.ty, "x")], - Box::new(move |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?)) - }), - ) - } - - // NOTE: on `ndarray_factory_fn_shape_arg_tvar` and - // the `param_ty` for `create_fn_by_codegen`. - // - // Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking - // to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`], - // and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`. - PrimDef::FunNpReshape => create_fn_by_codegen( - self.unifier, - &VarMap::new(), - prim.name(), - self.ndarray_num_ty, - &[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - }), - ), - - _ => unreachable!(), - } - } - /// Build `np_linalg` and `sp_linalg` functions /// /// The input to these functions must be floating point `NDArray` @@ -1954,10 +2078,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?; + Ok(Some(result)) }), ), diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 21aeb9db..2533489a 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -52,6 +52,16 @@ pub enum PrimDef { FunNpEye, FunNpIdentity, + // NumPy ndarray property getters + FunNpSize, + FunNpShape, + FunNpStrides, + + // NumPy ndarray view functions + FunNpBroadcastTo, + FunNpTranspose, + FunNpReshape, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -99,8 +109,6 @@ pub enum PrimDef { FunNpLdExp, FunNpHypot, FunNpNextAfter, - FunNpTranspose, - FunNpReshape, // Linalg functions FunNpDot, @@ -238,6 +246,16 @@ impl PrimDef { PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), + // NumPy NDArray property getters, + PrimDef::FunNpSize => fun("np_size", None), + PrimDef::FunNpShape => fun("np_shape", None), + PrimDef::FunNpStrides => fun("np_strides", None), + + // NumPy NDArray view functions + PrimDef::FunNpBroadcastTo => fun("np_broadcast_to", None), + PrimDef::FunNpTranspose => fun("np_transpose", None), + PrimDef::FunNpReshape => fun("np_reshape", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), @@ -285,8 +303,6 @@ impl PrimDef { PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None), - PrimDef::FunNpTranspose => fun("np_transpose", None), - PrimDef::FunNpReshape => fun("np_reshape", None), // Linalg functions PrimDef::FunNpDot => fun("np_dot", None), @@ -1000,3 +1016,23 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { _ => 0, } } + +/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible. +/// The `ndims` must only contain 1 value. +#[must_use] +pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 { + let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty); + let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else { + panic!("ndims_ty should be a TLiteral"); + }; + + assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value"); + + let ndims = values[0].clone(); + u64::try_from(ndims).unwrap() +} + +/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value. +pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type { + unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None) +} diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 53ff774f..a997444a 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(257)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 2621337c..df4517b4 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar246]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar246\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index d0769305..b1d9bf99 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(259)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 5ebdf86c..83746579 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar245, typevar246]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar245\", \"typevar246\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 502abbd6..aea0aee4 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(265)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(273)]\n}\n", ] diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 325f837a..cc36754c 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,5 +1,5 @@ use crate::symbol_resolver::SymbolValue; -use crate::toplevel::helper::PrimDef; +use crate::toplevel::helper::{extract_ndims, PrimDef}; use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::typecheck::{ type_inferencer::*, @@ -13,6 +13,8 @@ use std::collections::HashMap; use std::rc::Rc; use strum::IntoEnumIterator; +use super::typedef::into_var_map; + /// The variant of a binary operator. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BinopVariant { @@ -171,19 +173,8 @@ pub fn impl_binop( ops: &[Operator], ) { with_fields(unifier, ty, |unifier, fields| { - let (other_ty, other_var_id) = if other_ty.len() == 1 { - (other_ty[0], None) - } else { - let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); - (tvar.ty, Some(tvar.id)) - }; - - let function_vars = if let Some(var_id) = other_var_id { - vec![(var_id, other_ty)].into_iter().collect::() - } else { - VarMap::new() - }; - + let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + let function_vars = into_var_map([other_tvar]); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) { @@ -194,7 +185,7 @@ pub fn impl_binop( ret: ret_ty, vars: function_vars.clone(), args: vec![FuncArg { - ty: other_ty, + ty: other_tvar.ty, default_value: None, name: "other".into(), is_vararg: false, @@ -520,36 +511,41 @@ pub fn typeof_binop( } Operator::MatMult => { - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); - let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() + let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = extract_ndims(unifier, lhs_ndims); + + let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = extract_ndims(unifier, rhs_ndims); + + if !(unifier.unioned(lhs_dtype, primitives.float) + && unifier.unioned(rhs_dtype, primitives.float)) + { + return Err(format!( + "ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}", + unifier.stringify(lhs), + unifier.stringify(rhs) + )); + } + + let result_ndims = match (lhs_ndims, rhs_ndims) { + (0, _) | (_, 0) => { + return Err( + "ndarray.__matmul__ does not allow unsized ndarray input".to_string() + ) } - _ => unreachable!(), - }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); - let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() - } - _ => unreachable!(), + (1, 1) => 0, + (1, _) => rhs_ndims - 1, + (_, 1) => lhs_ndims - 1, + (m, n) => max(m, n), }; - match (lhs_ndims, rhs_ndims) { - (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, - (lhs, rhs) if lhs == 0 || rhs == 0 => { - return Err(format!( - "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", - u8::from(rhs == 0) - )) - } - (lhs, rhs) => { - return Err(format!( - "ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported" - )) - } + if result_ndims == 0 { + // If the result is unsized, NumPy returns a scalar. + primitives.float + } else { + let result_ndims_ty = + unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None); + make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty)) } } @@ -752,7 +748,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); + impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 0408cf1c..a2a2710f 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1,7 +1,7 @@ use std::cmp::max; use std::collections::{HashMap, HashSet}; use std::convert::{From, TryInto}; -use std::iter::once; +use std::iter::{self, once}; use std::{cell::RefCell, sync::Arc}; use super::{ @@ -1181,6 +1181,45 @@ impl<'a> Inferencer<'a> { })); } + if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 { + let ndarray = self.fold_expr(args.remove(0))?; + + let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap()); + + // Make a tuple of size `ndims` full of int32 (TODO: Make it usize) + let ret_ty = TypeEnum::TTuple { + ty: iter::repeat(self.primitives.int32).take(ndims as usize).collect_vec(), + is_vararg_ctx: false, + }; + let ret_ty = self.unifier.add_ty(ret_ty); + + let func_ty = TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "a".into(), + default_value: None, + ty: ndarray.custom.unwrap(), + is_vararg: false, + }], + ret: ret_ty, + vars: VarMap::new(), + }); + let func_ty = self.unifier.add_ty(func_ty); + + return Ok(Some(Located { + location, + custom: Some(ret_ty), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(func_ty), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![ndarray], + keywords: vec![], + }, + })); + } + if id == &"np_dot".into() { let arg0 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?; @@ -1502,7 +1541,7 @@ impl<'a> Inferencer<'a> { })); } // 2-argument ndarray n-dimensional factory functions - if id == &"np_reshape".into() && args.len() == 2 { + if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 { let arg0 = self.fold_expr(args.remove(0))?; let shape_expr = args.remove(0); diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 4f19db95..8784ce53 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -179,6 +179,16 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray view functions + module.np_broadcast_to = np.broadcast_to + module.np_transpose = np.transpose + module.np_reshape = np.reshape + + # NumPy NDArray property getters + module.np_size = np.size + module.np_shape = np.shape + module.np_strides = lambda ndarray: ndarray.strides + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf @@ -218,8 +228,6 @@ def patch(module): module.np_ldexp = np.ldexp module.np_hypot = np.hypot module.np_nextafter = np.nextafter - module.np_transpose = np.transpose - module.np_reshape = np.reshape # SciPy Math functions module.sp_spec_erf = special.erf diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 9664b3f0..3d4cae67 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -68,6 +68,19 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]): for c in range(len(n[r])): output_float64(n[r][c]) +def output_ndarray_float_3(n: ndarray[float, Literal[3]]): + for d in range(len(n)): + for r in range(len(n[d])): + for c in range(len(n[d][r])): + output_float64(n[d][r][c]) + +def output_ndarray_float_4(n: ndarray[float, Literal[4]]): + for x in range(len(n)): + for y in range(len(n[x])): + for z in range(len(n[x][y])): + for w in range(len(n[x][y][z])): + output_float64(n[x][y][z][w]) + def consume_ndarray_1(n: ndarray[float, Literal[1]]): pass @@ -186,6 +199,104 @@ def test_ndarray_nd_idx(): output_float64(x[1, 0]) output_float64(x[1, 1]) +def test_ndarray_transpose(): + x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]]) + y = np_transpose(x) + z = np_transpose(y) + + output_int32(np_shape(x)[0]) + output_int32(np_shape(x)[1]) + output_ndarray_float_2(x) + + output_int32(np_shape(y)[0]) + output_int32(np_shape(y)[1]) + output_ndarray_float_2(y) + + output_int32(np_shape(z)[0]) + output_int32(np_shape(z)[1]) + output_ndarray_float_2(z) + +def test_ndarray_reshape(): + w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + x = np_reshape(w, (1, 2, 1, -1)) + y = np_reshape(x, [2, -1]) + z = np_reshape(y, 10) + + output_int32(np_shape(w)[0]) + output_ndarray_float_1(w) + + output_int32(np_shape(x)[0]) + output_int32(np_shape(x)[1]) + output_int32(np_shape(x)[2]) + output_int32(np_shape(x)[3]) + output_ndarray_float_4(x) + + output_int32(np_shape(y)[0]) + output_int32(np_shape(y)[1]) + output_ndarray_float_2(y) + + output_int32(np_shape(z)[0]) + output_ndarray_float_1(z) + + x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) + x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) + + output_int32(np_shape(x1)[0]) + output_ndarray_int32_1(x1) + + output_int32(np_shape(x2)[0]) + output_int32(np_shape(x2)[1]) + output_ndarray_int32_2(x2) + +def test_ndarray_broadcast_to(): + xs = np_array([1.0, 2.0, 3.0]) + ys = np_broadcast_to(xs, (1, 3)) + zs = np_broadcast_to(ys, (2, 4, 3)) + + output_int32(np_shape(xs)[0]) + output_ndarray_float_1(xs) + + output_int32(np_shape(ys)[0]) + output_int32(np_shape(ys)[1]) + output_ndarray_float_2(ys) + + output_int32(np_shape(zs)[0]) + output_int32(np_shape(zs)[1]) + output_int32(np_shape(zs)[2]) + output_ndarray_float_3(zs) + +def test_ndarray_subscript_assignment(): + xs = np_array([[11.0, 22.0, 33.0, 44.0], [55.0, 66.0, 77.0, 88.0]]) + + xs[0, 0] = 99.0 + output_ndarray_float_2(xs) + + xs[0] = 100.0 + output_ndarray_float_2(xs) + + xs[:, ::2] = 101.0 + output_ndarray_float_2(xs) + + xs[1:, 0] = 102.0 + output_ndarray_float_2(xs) + + xs[0] = np_array([-1.0, -2.0, -3.0, -4.0]) + output_ndarray_float_2(xs) + + xs[:] = np_array([-5.0, -6.0, -7.0, -8.0]) + output_ndarray_float_2(xs) + + # Test assignment with memory sharing + ys1 = np_reshape(xs, (2, 4)) + ys2 = np_transpose(ys1) + ys3 = ys2[::-1, 0] + ys3[0] = -999.0 + + output_ndarray_float_2(xs) + output_ndarray_float_2(ys1) + output_ndarray_float_2(ys2) + output_ndarray_float_1(ys3) + def test_ndarray_add(): x = np_identity(2) y = x + np_ones([2, 2]) @@ -530,11 +641,59 @@ def test_ndarray_ipow_broadcast_scalar(): output_ndarray_float_2(x) def test_ndarray_matmul(): - x = np_identity(2) - y = x @ np_ones([2, 2]) + # 2D @ 2D -> 2D + a1 = np_array([[2.0, 3.0], [5.0, 7.0]]) + b1 = np_array([[11.0, 13.0], [17.0, 23.0]]) + c1 = a1 @ b1 + output_int32(np_shape(c1)[0]) + output_int32(np_shape(c1)[1]) + output_ndarray_float_2(c1) - output_ndarray_float_2(x) - output_ndarray_float_2(y) + # 1D @ 1D -> Scalar + a2 = np_array([2.0, 3.0, 5.0]) + b2 = np_array([7.0, 11.0, 13.0]) + c2 = a2 @ b2 + output_float64(c2) + + # 2D @ 1D -> 1D + a3 = np_array([[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]]) + b3 = np_array([4.0, 5.0, 6.0]) + c3 = a3 @ b3 + output_int32(np_shape(c3)[0]) + output_ndarray_float_1(c3) + + # 1D @ 2D -> 1D + a4 = np_array([1.0, 2.0, 3.0]) + b4 = np_array([[4.0, 5.0], [6.0, 7.0], [8.0, 9.0]]) + c4 = a4 @ b4 + output_int32(np_shape(c4)[0]) + output_ndarray_float_1(c4) + + # Broadcasting + a5 = np_array([ + [[ 0.0, 1.0, 2.0, 3.0], + [ 4.0, 5.0, 6.0, 7.0]], + [[ 8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0]], + [[16.0, 17.0, 18.0, 19.0], + [20.0, 21.0, 22.0, 23.0]] + ]) + b5 = np_array([ + [[[ 0.0, 1.0, 2.0], + [ 3.0, 4.0, 5.0], + [ 6.0, 7.0, 8.0], + [ 9.0, 10.0, 11.0]]], + [[[12.0, 13.0, 14.0], + [15.0, 16.0, 17.0], + [18.0, 19.0, 20.0], + [21.0, 22.0, 23.0]]] + ]) + c5 = a5 @ b5 + output_int32(np_shape(c5)[0]) + output_int32(np_shape(c5)[1]) + output_int32(np_shape(c5)[2]) + output_int32(np_shape(c5)[3]) + output_ndarray_float_4(c5) def test_ndarray_imatmul(): x = np_identity(2) @@ -1429,27 +1588,6 @@ def test_ndarray_nextafter_broadcast_rhs_scalar(): output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_ones) -def test_ndarray_transpose(): - x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]]) - y = np_transpose(x) - z = np_transpose(y) - - output_ndarray_float_2(x) - output_ndarray_float_2(y) - -def test_ndarray_reshape(): - w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) - x = np_reshape(w, (1, 2, 1, -1)) - y = np_reshape(x, [2, -1]) - z = np_reshape(y, 10) - - x1: ndarray[int32, 1] = np_array([1, 2, 3, 4]) - x2: ndarray[int32, 2] = np_reshape(x1, (2, 2)) - - output_ndarray_float_1(w) - output_ndarray_float_2(y) - output_ndarray_float_1(z) - def test_ndarray_dot(): x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) @@ -1581,6 +1719,11 @@ def run() -> int32: test_ndarray_slices() test_ndarray_nd_idx() + test_ndarray_transpose() + test_ndarray_reshape() + test_ndarray_broadcast_to() + test_ndarray_subscript_assignment() + test_ndarray_add() test_ndarray_add_broadcast() test_ndarray_add_broadcast_lhs_scalar() @@ -1744,8 +1887,6 @@ def run() -> int32: test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar() - test_ndarray_transpose() - test_ndarray_reshape() test_ndarray_dot() test_ndarray_cholesky() diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index cc4811c1..da9b0e6f 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -14,6 +14,7 @@ use inkwell::{ memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*, OptimizationLevel, }; +use nac3core::codegen::irrt::setup_irrt_exceptions; use nac3core::{ codegen::{ concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions, @@ -314,6 +315,16 @@ fn main() { let resolver = Arc::new(Resolver(internal_resolver.clone())) as Arc; + let context = inkwell::context::Context::create(); + + // Process IRRT + let irrt = load_irrt(&context); + setup_irrt_exceptions(&context, &irrt, resolver.as_ref()); + if emit_llvm { + irrt.write_bitcode_to_path(Path::new("irrt.bc")); + } + + // Process the Python script let parser_result = parser::parse_program(&program, file_name.into()).unwrap(); for stmt in parser_result { @@ -418,8 +429,8 @@ fn main() { registry.add_task(task); registry.wait_tasks_complete(handles); + // Link all modules together into `main` let buffers = membuffers.lock(); - let context = inkwell::context::Context::create(); let main = context .create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main")) .unwrap(); @@ -439,12 +450,9 @@ fn main() { main.link_in_module(other).unwrap(); } - let irrt = load_irrt(&context); - if emit_llvm { - irrt.write_bitcode_to_path(Path::new("irrt.bc")); - } main.link_in_module(irrt).unwrap(); + // Private all functions except "run" let mut function_iter = main.get_first_function(); while let Some(func) = function_iter { if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != "run" { @@ -453,6 +461,7 @@ fn main() { function_iter = func.get_next_function(); } + // Optimize `main` let target_machine = llvm_options .target .create_target_machine(llvm_options.opt_level) @@ -466,6 +475,7 @@ fn main() { panic!("Failed to run optimization for module `main`: {}", err.to_string()); } + // Write output target_machine .write_to_file(&main, FileType::Object, Path::new("module.o")) .expect("couldn't write module to file");