From 8e614d83de169a5dbd4a36219a3ab37a09c236e7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 14 Jan 2025 18:20:09 +0800 Subject: [PATCH] [core] codegen: Add ProxyType::new overloads and refactor to use them --- nac3artiq/src/codegen.rs | 6 +- nac3core/src/codegen/builtin_fns.rs | 130 ++++++++++-------- nac3core/src/codegen/expr.rs | 42 ++---- nac3core/src/codegen/mod.rs | 6 +- nac3core/src/codegen/numpy.rs | 16 +-- nac3core/src/codegen/stmt.rs | 3 +- nac3core/src/codegen/test.rs | 4 +- nac3core/src/codegen/types/list.rs | 45 ++++-- nac3core/src/codegen/types/ndarray/array.rs | 11 +- .../src/codegen/types/ndarray/broadcast.rs | 20 ++- .../src/codegen/types/ndarray/contiguous.rs | 22 ++- .../src/codegen/types/ndarray/indexing.rs | 17 ++- nac3core/src/codegen/types/ndarray/map.rs | 14 +- nac3core/src/codegen/types/ndarray/mod.rs | 80 +++++++---- nac3core/src/codegen/types/ndarray/nditer.rs | 20 ++- nac3core/src/codegen/types/tuple.rs | 27 +++- nac3core/src/codegen/types/utils/slice.rs | 24 +++- nac3core/src/codegen/values/list.rs | 8 +- .../src/codegen/values/ndarray/broadcast.rs | 4 +- .../src/codegen/values/ndarray/contiguous.rs | 11 +- .../src/codegen/values/ndarray/indexing.rs | 8 +- nac3core/src/codegen/values/ndarray/matmul.rs | 2 +- nac3core/src/codegen/values/ndarray/mod.rs | 22 +-- nac3core/src/codegen/values/ndarray/nditer.rs | 4 +- nac3core/src/codegen/values/ndarray/view.rs | 2 +- 25 files changed, 320 insertions(+), 228 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index cb75606d..c968198b 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -476,8 +476,8 @@ fn format_rpc_arg<'ctx>( let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let ndims = extract_ndims(&ctx.unifier, ndims); let dtype = ctx.get_llvm_type(generator, elem_ty); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, ndims) - .map_value(arg.into_pointer_value(), None); + let ndarray = + NDArrayType::new(ctx, dtype, ndims).map_value(arg.into_pointer_value(), None); let ndims = llvm_usize.const_int(ndims, false); @@ -609,7 +609,7 @@ fn format_rpc_ret<'ctx>( let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let dtype_llvm = ctx.get_llvm_type(generator, dtype); let ndims = extract_ndims(&ctx.unifier, ndims); - let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, ndims) + let ndarray = NDArrayType::new(ctx, dtype_llvm, ndims) .construct_uninitialized(generator, ctx, None); // NOTE: Current content of `ndarray`: diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 96f8c700..911e3dc1 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -752,24 +752,20 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); let llvm_common_dtype = x1.get_type().element_type(); - let result = NDArrayType::new_broadcast( - generator, - ctx.ctx, - llvm_common_dtype, - &[x1.get_type(), x2.get_type()], - ) - .broadcast_starmap( - generator, - ctx, - &[x1, x2], - NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, - |_, ctx, scalars| { - let x1_scalar = scalars[0]; - let x2_scalar = scalars[1]; - Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) - }, - ) - .unwrap(); + let result = + NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) + .broadcast_starmap( + generator, + ctx, + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_min(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) + }, + ) + .unwrap(); result.as_base_value().into() } @@ -1015,24 +1011,20 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); let llvm_common_dtype = x1.get_type().element_type(); - let result = NDArrayType::new_broadcast( - generator, - ctx.ctx, - llvm_common_dtype, - &[x1.get_type(), x2.get_type()], - ) - .broadcast_starmap( - generator, - ctx, - &[x1, x2], - NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, - |_, ctx, scalars| { - let x1_scalar = scalars[0]; - let x2_scalar = scalars[1]; - Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) - }, - ) - .unwrap(); + let result = + NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) + .broadcast_starmap( + generator, + ctx, + &[x1, x2], + NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + |_, ctx, scalars| { + let x1_scalar = scalars[0]; + let x2_scalar = scalars[1]; + Ok(call_max(ctx, (x1_dtype, x1_scalar), (x2_dtype, x2_scalar))) + }, + ) + .unwrap(); result.as_base_value().into() } @@ -1652,7 +1644,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1694,7 +1686,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { q.create_data(generator, ctx) }; @@ -1715,8 +1707,11 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( let q = q.as_base_value().as_basic_value_enum(); let r = r.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[q.get_type(), r.get_type()]) - .construct_from_objects(ctx, [q, r], None); + let tuple = TupleType::new(ctx, &[q.get_type(), r.get_type()]).construct_from_objects( + ctx, + [q, r], + None, + ); Ok(tuple.as_base_value().into()) } @@ -1746,8 +1741,8 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1); - let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray1_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1); + let out_ndarray2_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None); unsafe { u.create_data(generator, ctx) }; @@ -1775,7 +1770,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( let u = u.as_base_value().as_basic_value_enum(); let s = s.as_base_value().as_basic_value_enum(); let vh = vh.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[u.get_type(), s.get_type(), vh.get_type()]) + let tuple = TupleType::new(ctx, &[u.get_type(), s.get_type(), vh.get_type()]) .construct_from_objects(ctx, [u, s, vh], None); Ok(tuple.as_base_value().into()) } @@ -1796,7 +1791,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1838,8 +1833,12 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) }; - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) - .construct_dyn_shape(generator, ctx, &[d0, d1], None); + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2).construct_dyn_shape( + generator, + ctx, + &[d0, d1], + None, + ); unsafe { out.create_data(generator, ctx) }; let x1_c = x1.make_contiguous_ndarray(generator, ctx); @@ -1880,7 +1879,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( }; let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); unsafe { l.create_data(generator, ctx) }; @@ -1901,8 +1900,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let l = l.as_base_value().as_basic_value_enum(); let u = u.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[l.get_type(), u.get_type()]) - .construct_from_objects(ctx, [l, u], None); + let tuple = TupleType::new(ctx, &[l.get_type(), u.get_type()]).construct_from_objects( + ctx, + [l, u], + None, + ); Ok(tuple.as_base_value().into()) } @@ -1936,11 +1938,11 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) }; - let x2 = NDArrayType::new_unsized(generator, ctx.ctx, ctx.ctx.f64_type().into()) + let x2 = NDArrayType::new_unsized(ctx, ctx.ctx.f64_type().into()) .construct_unsized(generator, ctx, &x2, None); // x2.shape == [] let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] - let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2) + let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) .construct_uninitialized(generator, ctx, None); out.copy_shape_from_ndarray(generator, ctx, x1); unsafe { out.create_data(generator, ctx) }; @@ -1979,8 +1981,12 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( } // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. - let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1) - .construct_const_shape(generator, ctx, &[1], None); + let det = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1).construct_const_shape( + generator, + ctx, + &[1], + None, + ); unsafe { det.create_data(generator, ctx) }; let x1_c = x1.make_contiguous_ndarray(generator, ctx); @@ -2014,7 +2020,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None); t.copy_shape_from_ndarray(generator, ctx, x1); @@ -2037,8 +2043,11 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( let t = t.as_base_value().as_basic_value_enum(); let z = z.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[t.get_type(), z.get_type()]) - .construct_from_objects(ctx, [t, z], None); + let tuple = TupleType::new(ctx, &[t.get_type(), z.get_type()]).construct_from_objects( + ctx, + [t, z], + None, + ); Ok(tuple.as_base_value().into()) } @@ -2059,7 +2068,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]); } - let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2); + let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None); h.copy_shape_from_ndarray(generator, ctx, x1); @@ -2082,7 +2091,10 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( let h = h.as_base_value().as_basic_value_enum(); let q = q.as_base_value().as_basic_value_enum(); - let tuple = TupleType::new(generator, ctx.ctx, &[h.get_type(), q.get_type()]) - .construct_from_objects(ctx, [h, q], None); + let tuple = TupleType::new(ctx, &[h.get_type(), q.get_type()]).construct_from_objects( + ctx, + [h, q], + None, + ); Ok(tuple.as_base_value().into()) } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 00290d34..8f52e929 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1167,7 +1167,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( "listcomp.alloc_size", ) .unwrap(); - list = ListType::new(generator, ctx.ctx, elem_ty).construct( + list = ListType::new(ctx, &elem_ty).construct( generator, ctx, list_alloc_size.into_int_value(), @@ -1218,12 +1218,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( Some("length"), ) .into_int_value(); - list = ListType::new(generator, ctx.ctx, elem_ty).construct( - generator, - ctx, - length, - Some("listcomp"), - ); + list = ListType::new(ctx, &elem_ty).construct(generator, ctx, length, Some("listcomp")); let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; // counter = -1 @@ -1386,8 +1381,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( .build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "") .unwrap(); - let new_list = ListType::new(generator, ctx.ctx, llvm_elem_ty) - .construct(generator, ctx, size, None); + let new_list = + ListType::new(ctx, &llvm_elem_ty).construct(generator, ctx, size, None); let lhs_size = ctx .builder @@ -1474,7 +1469,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let sizeof_elem = elem_llvm_ty.size_of().unwrap(); - let new_list = ListType::new(generator, ctx.ctx, elem_llvm_ty).construct( + let new_list = ListType::new(ctx, &elem_llvm_ty).construct( generator, ctx, ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), @@ -1576,8 +1571,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let right = right.to_ndarray(generator, ctx); let result = NDArrayType::new_broadcast( - generator, - ctx.ctx, + ctx, llvm_common_dtype, &[left.get_type(), right.get_type()], ) @@ -1850,8 +1844,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( .to_ndarray(generator, ctx); let result_ndarray = NDArrayType::new_broadcast( - generator, - ctx.ctx, + ctx, ctx.ctx.i8_type().into(), &[left.get_type(), right.get_type()], ) @@ -2480,18 +2473,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( }; let length = ctx.get_size_type().const_int(elements.len() as u64, false); let arr_str_ptr = if let Some(ty) = ty { - ListType::new(generator, ctx.ctx, ty).construct( - generator, - ctx, - length, - Some("list"), - ) + ListType::new(ctx, &ty).construct(generator, ctx, length, Some("list")) } else { - ListType::new_untyped(generator, ctx.ctx).construct_empty( - generator, - ctx, - Some("list"), - ) + ListType::new_untyped(ctx).construct_empty(generator, ctx, Some("list")) }; let arr_ptr = arr_str_ptr.data(); for (i, v) in elements.iter().enumerate() { @@ -2970,12 +2954,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( .unwrap(), step, ); - let res_array_ret = ListType::new(generator, ctx.ctx, ty).construct( - generator, - ctx, - length, - Some("ret"), - ); + let res_array_ret = + ListType::new(ctx, &ty).construct(generator, ctx, length, Some("ret")); let Some(res_ind) = handle_slice_indices( &None, &None, diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index f7483ef8..dcfa2b8c 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -530,7 +530,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( *params.iter().next().unwrap().1, ); - ListType::new(generator, ctx, element_type).as_base_type().into() + ListType::new_with_generator(generator, ctx, element_type).as_base_type().into() } TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { @@ -540,7 +540,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new(generator, ctx, element_type, ndims).as_base_type().into() + NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_base_type().into() } _ => unreachable!( @@ -594,7 +594,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) }) .collect_vec(); - TupleType::new(generator, ctx, &fields).as_base_type().into() + TupleType::new_with_generator(generator, ctx, &fields).as_base_type().into() } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 6700af45..3cdd1ef3 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -42,7 +42,7 @@ pub fn gen_ndarray_empty<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_empty(generator, context, &shape, None); Ok(ndarray.as_base_value()) } @@ -67,7 +67,7 @@ pub fn gen_ndarray_zeros<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_zeros(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -92,7 +92,7 @@ pub fn gen_ndarray_ones<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims) + let ndarray = NDArrayType::new(context, llvm_dtype, ndims) .construct_numpy_ones(generator, context, dtype, &shape, None); Ok(ndarray.as_base_value()) } @@ -120,7 +120,7 @@ pub fn gen_ndarray_full<'ctx>( let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims).construct_numpy_full( + let ndarray = NDArrayType::new(context, llvm_dtype, ndims).construct_numpy_full( generator, context, &shape, @@ -223,7 +223,7 @@ pub fn gen_ndarray_eye<'ctx>( .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) + let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); Ok(ndarray.as_base_value()) } @@ -251,7 +251,7 @@ pub fn gen_ndarray_identity<'ctx>( .builder .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") .unwrap(); - let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2) + let ndarray = NDArrayType::new(context, llvm_dtype, 2) .construct_numpy_identity(generator, context, dtype, n, None); Ok(ndarray.as_base_value()) } @@ -349,8 +349,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ctx, Some("np_dot"), |generator, ctx| { - let a_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, a); - let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b); + let a_iter = NDIterType::new(ctx).construct(generator, ctx, a); + let b_iter = NDIterType::new(ctx).construct(generator, ctx, b); Ok((a_iter, b_iter)) }, |_, ctx, (a_iter, _b_iter)| { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index c3274057..85a894ac 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -448,8 +448,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>( let broadcast_ndims = [target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap(); let broadcast_result = NDArrayType::new( - generator, - ctx.ctx, + ctx, value.get_type().element_type(), broadcast_ndims, ) diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 6518d858..a58a9847 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -446,7 +446,7 @@ fn test_classes_list_type_new() { let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into()); + let llvm_list = ListType::new_with_generator(&generator, &ctx, llvm_i32.into()); assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok()); } @@ -466,6 +466,6 @@ fn test_classes_ndarray_type_new() { let llvm_i32 = ctx.i32_type(); let llvm_usize = generator.get_size_type(&ctx); - let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), 2); + let llvm_ndarray = NDArrayType::new_with_generator(&generator, &ctx, llvm_i32.into(), 2); assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); } diff --git a/nac3core/src/codegen/types/list.rs b/nac3core/src/codegen/types/list.rs index 9ea4acaa..637cced3 100644 --- a/nac3core/src/codegen/types/list.rs +++ b/nac3core/src/codegen/types/list.rs @@ -104,7 +104,7 @@ impl<'ctx> ListType<'ctx> { element_type: Option>, llvm_usize: IntType<'ctx>, ) -> PointerType<'ctx> { - let element_type = element_type.unwrap_or(llvm_usize.into()); + let element_type = element_type.map_or(llvm_usize.into(), |ty| ty.as_basic_type_enum()); let field_tys = Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec(); @@ -112,26 +112,45 @@ impl<'ctx> ListType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } + fn new_impl( + ctx: &'ctx Context, + element_type: Option>, + llvm_usize: IntType<'ctx>, + ) -> Self { + let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize); + + Self { ty: llvm_list, item: element_type, llvm_usize } + } + /// Creates an instance of [`ListType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, element_type: &impl BasicType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, Some(element_type.as_basic_type_enum()), ctx.get_size_type()) + } + + /// Creates an instance of [`ListType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, element_type: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, Some(element_type), llvm_usize); - - Self { ty: llvm_list, item: Some(element_type), llvm_usize } + Self::new_impl(ctx, Some(element_type.as_basic_type_enum()), generator.get_size_type(ctx)) } /// Creates an instance of [`ListType`] with an unknown element type. #[must_use] - pub fn new_untyped(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_list = Self::llvm_type(ctx, None, llvm_usize); + pub fn new_untyped(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, None, ctx.get_size_type()) + } - Self { ty: llvm_list, item: None, llvm_usize } + /// Creates an instance of [`ListType`] with an unknown element type. + #[must_use] + pub fn new_untyped_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, None, generator.get_size_type(ctx)) } /// Creates an [`ListType`] from a [unifier type][Type]. @@ -159,11 +178,7 @@ impl<'ctx> ListType<'ctx> { Some(ctx.get_llvm_type(generator, elem_type)) }; - Self { - ty: Self::llvm_type(ctx.ctx, llvm_elem_type, llvm_usize), - item: llvm_elem_type, - llvm_usize, - } + Self::new_impl(ctx.ctx, llvm_elem_type, llvm_usize) } /// Creates an [`ListType`] from a [`PointerType`]. diff --git a/nac3core/src/codegen/types/ndarray/array.rs b/nac3core/src/codegen/types/ndarray/array.rs index b0c9d637..70611127 100644 --- a/nac3core/src/codegen/types/ndarray/array.rs +++ b/nac3core/src/codegen/types/ndarray/array.rs @@ -44,7 +44,7 @@ impl<'ctx> NDArrayType<'ctx> { assert!(self.ndims >= ndims_int); assert_eq!(dtype, self.dtype); - let list_value = list.as_i8_list(generator, ctx); + let list_value = list.as_i8_list(ctx); // Validate `list` has a consistent shape. // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. @@ -61,8 +61,8 @@ impl<'ctx> NDArrayType<'ctx> { generator, ctx, list_value, ndims, &shape, ); - let ndarray = Self::new(generator, ctx.ctx, dtype, ndims_int) - .construct_uninitialized(generator, ctx, name); + let ndarray = + Self::new(ctx, dtype, ndims_int).construct_uninitialized(generator, ctx, name); ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); unsafe { ndarray.create_data(generator, ctx) }; @@ -96,8 +96,7 @@ impl<'ctx> NDArrayType<'ctx> { let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); - let ndarray = Self::new(generator, ctx.ctx, dtype, 1) - .construct_uninitialized(generator, ctx, name); + let ndarray = Self::new(ctx, dtype, 1).construct_uninitialized(generator, ctx, name); // Set data let data = ctx @@ -168,7 +167,7 @@ impl<'ctx> NDArrayType<'ctx> { .map(BasicValueEnum::into_pointer_value) .unwrap(); - NDArrayType::new(generator, ctx.ctx, dtype, ndims).map_value(ndarray, None) + NDArrayType::new(ctx, dtype, ndims).map_value(ndarray, None) } /// Implementation of `np_array(, copy=copy)`. diff --git a/nac3core/src/codegen/types/ndarray/broadcast.rs b/nac3core/src/codegen/types/ndarray/broadcast.rs index 5ee28454..3a1fd8da 100644 --- a/nac3core/src/codegen/types/ndarray/broadcast.rs +++ b/nac3core/src/codegen/types/ndarray/broadcast.rs @@ -79,15 +79,27 @@ impl<'ctx> ShapeEntryType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`ShapeEntryType`]. - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_ty = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_ty, llvm_usize } } + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`ShapeEntryType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs index f4a8b73d..c751d573 100644 --- a/nac3core/src/codegen/types/ndarray/contiguous.rs +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -117,17 +117,26 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } + fn new_impl(ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize); + + Self { ty: llvm_cndarray, item, llvm_usize } + } + /// Creates an instance of [`ContiguousNDArrayType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, item: &impl BasicType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, item.as_basic_type_enum(), ctx.get_size_type()) + } + + /// Creates an instance of [`ContiguousNDArrayType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize); - - Self { ty: llvm_cndarray, item, llvm_usize } + Self::new_impl(ctx, item, generator.get_size_type(ctx)) } /// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type]. @@ -140,9 +149,8 @@ impl<'ctx> ContiguousNDArrayType<'ctx> { let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = ctx.get_size_type(); - Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize } + Self::new_impl(ctx.ctx, llvm_dtype, ctx.get_size_type()) } /// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`. diff --git a/nac3core/src/codegen/types/ndarray/indexing.rs b/nac3core/src/codegen/types/ndarray/indexing.rs index 644e173c..3e4e1362 100644 --- a/nac3core/src/codegen/types/ndarray/indexing.rs +++ b/nac3core/src/codegen/types/ndarray/indexing.rs @@ -75,14 +75,25 @@ impl<'ctx> NDIndexType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_ndindex = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_ndindex, llvm_usize } } + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); diff --git a/nac3core/src/codegen/types/ndarray/map.rs b/nac3core/src/codegen/types/ndarray/map.rs index 6fdd9e12..ae24458c 100644 --- a/nac3core/src/codegen/types/ndarray/map.rs +++ b/nac3core/src/codegen/types/ndarray/map.rs @@ -46,9 +46,8 @@ impl<'ctx> NDArrayType<'ctx> { let out_ndarray = match out { NDArrayOut::NewNDArray { dtype } => { // Create a new ndarray based on the broadcast shape. - let result_ndarray = - NDArrayType::new(generator, ctx.ctx, dtype, broadcast_result.ndims) - .construct_uninitialized(generator, ctx, None); + let result_ndarray = NDArrayType::new(ctx, dtype, broadcast_result.ndims) + .construct_uninitialized(generator, ctx, None); result_ndarray.copy_shape_from_array( generator, ctx, @@ -70,7 +69,7 @@ impl<'ctx> NDArrayType<'ctx> { }; // Map element-wise and store results into `mapped_ndarray`. - let nditer = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, out_ndarray); + let nditer = NDIterType::new(ctx).construct(generator, ctx, out_ndarray); gen_for_callback( generator, ctx, @@ -80,9 +79,7 @@ impl<'ctx> NDArrayType<'ctx> { let other_nditers = broadcast_result .ndarrays .iter() - .map(|ndarray| { - NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *ndarray) - }) + .map(|ndarray| NDIterType::new(ctx).construct(generator, ctx, *ndarray)) .collect_vec(); Ok((nditer, other_nditers)) }, @@ -169,8 +166,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { // Promote all input to ndarrays and map through them. let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); let ndarray = NDArrayType::new_broadcast( - generator, - ctx.ctx, + ctx, ret_dtype, &inputs.iter().map(NDArrayValue::get_type).collect_vec(), ) diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index a7bcb7ef..fe73307d 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -107,24 +107,56 @@ impl<'ctx> NDArrayType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`NDArrayType`]. - #[must_use] - pub fn new( - generator: &G, + fn new_impl( ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, ndims: u64, + llvm_usize: IntType<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize } } + /// Creates an instance of [`NDArrayType`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>, ndims: u64) -> Self { + Self::new_impl(ctx.ctx, dtype, ndims, ctx.get_size_type()) + } + + /// Creates an instance of [`NDArrayType`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + dtype: BasicTypeEnum<'ctx>, + ndims: u64, + ) -> Self { + Self::new_impl(ctx, dtype, ndims, generator.get_size_type(ctx)) + } + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more /// `ndarray` operands. #[must_use] - pub fn new_broadcast( + pub fn new_broadcast( + ctx: &CodeGenContext<'ctx, '_>, + dtype: BasicTypeEnum<'ctx>, + inputs: &[NDArrayType<'ctx>], + ) -> Self { + assert!(!inputs.is_empty()); + + Self::new_impl( + ctx.ctx, + dtype, + inputs.iter().map(NDArrayType::ndims).max().unwrap(), + ctx.get_size_type(), + ) + } + + /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more + /// `ndarray` operands. + #[must_use] + pub fn new_broadcast_with_generator( generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, @@ -132,20 +164,28 @@ impl<'ctx> NDArrayType<'ctx> { ) -> Self { assert!(!inputs.is_empty()); - Self::new(generator, ctx, dtype, inputs.iter().map(NDArrayType::ndims).max().unwrap()) + Self::new_impl( + ctx, + dtype, + inputs.iter().map(NDArrayType::ndims).max().unwrap(), + generator.get_size_type(ctx), + ) } /// Creates an instance of [`NDArrayType`] with `ndims` of 0. #[must_use] - pub fn new_unsized( + pub fn new_unsized(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>) -> Self { + Self::new_impl(ctx.ctx, dtype, 0, ctx.get_size_type()) + } + + /// Creates an instance of [`NDArrayType`] with `ndims` of 0. + #[must_use] + pub fn new_unsized_with_generator( generator: &G, ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); - - NDArrayType { ty: llvm_ndarray, dtype, ndims: 0, llvm_usize } + Self::new_impl(ctx, dtype, 0, generator.get_size_type(ctx)) } /// Creates an [`NDArrayType`] from a [unifier type][Type]. @@ -158,15 +198,9 @@ impl<'ctx> NDArrayType<'ctx> { let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let llvm_dtype = ctx.get_llvm_type(generator, dtype); - let llvm_usize = ctx.get_size_type(); let ndims = extract_ndims(&ctx.unifier, ndims); - NDArrayType { - ty: Self::llvm_type(ctx.ctx, llvm_usize), - dtype: llvm_dtype, - ndims, - llvm_usize, - } + Self::new_impl(ctx.ctx, llvm_dtype, ndims, ctx.get_size_type()) } /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. @@ -304,7 +338,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> >::Value { assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) + let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = ctx.get_size_type(); @@ -339,7 +373,7 @@ impl<'ctx> NDArrayType<'ctx> { ) -> >::Value { assert_eq!(shape.len() as u64, self.ndims); - let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64) + let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) .construct_uninitialized(generator, ctx, name); let llvm_usize = ctx.get_size_type(); @@ -389,8 +423,8 @@ impl<'ctx> NDArrayType<'ctx> { .build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .unwrap(); - let ndarray = Self::new_unsized(generator, ctx.ctx, value.get_type()) - .construct_uninitialized(generator, ctx, name); + let ndarray = + Self::new_unsized(ctx, value.get_type()).construct_uninitialized(generator, ctx, name); ctx.builder.build_store(ndarray.ptr_to_data(ctx), data).unwrap(); ndarray } diff --git a/nac3core/src/codegen/types/ndarray/nditer.rs b/nac3core/src/codegen/types/ndarray/nditer.rs index c77e4571..1d83742f 100644 --- a/nac3core/src/codegen/types/ndarray/nditer.rs +++ b/nac3core/src/codegen/types/ndarray/nditer.rs @@ -86,15 +86,27 @@ impl<'ctx> NDIterType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`NDIter`]. - #[must_use] - pub fn new(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); + fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { let llvm_nditer = Self::llvm_type(ctx, llvm_usize); Self { ty: llvm_nditer, llvm_usize } } + /// Creates an instance of [`NDIter`]. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type()) + } + + /// Creates an instance of [`NDIter`]. + #[must_use] + pub fn new_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx)) + } + /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. #[must_use] pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { diff --git a/nac3core/src/codegen/types/tuple.rs b/nac3core/src/codegen/types/tuple.rs index 947f95ad..5c736528 100644 --- a/nac3core/src/codegen/types/tuple.rs +++ b/nac3core/src/codegen/types/tuple.rs @@ -32,17 +32,34 @@ impl<'ctx> TupleType<'ctx> { ctx.struct_type(tys, false) } + fn new_impl( + ctx: &'ctx Context, + tys: &[BasicTypeEnum<'ctx>], + llvm_usize: IntType<'ctx>, + ) -> Self { + let llvm_tuple = Self::llvm_type(ctx, tys); + + Self { ty: llvm_tuple, llvm_usize } + } + /// Creates an instance of [`TupleType`]. #[must_use] - pub fn new( + pub fn new(ctx: &CodeGenContext<'ctx, '_>, tys: &[impl BasicType<'ctx>]) -> Self { + Self::new_impl( + ctx.ctx, + &tys.iter().map(BasicType::as_basic_type_enum).collect_vec(), + ctx.get_size_type(), + ) + } + + /// Creates an instance of [`TupleType`]. + #[must_use] + pub fn new_with_generator( generator: &G, ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>], ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - let llvm_tuple = Self::llvm_type(ctx, tys); - - Self { ty: llvm_tuple, llvm_usize } + Self::new_impl(ctx, tys, generator.get_size_type(ctx)) } /// Creates an [`TupleType`] from a [unifier type][Type]. diff --git a/nac3core/src/codegen/types/utils/slice.rs b/nac3core/src/codegen/types/utils/slice.rs index fa5a3474..0ef4d1b0 100644 --- a/nac3core/src/codegen/types/utils/slice.rs +++ b/nac3core/src/codegen/types/utils/slice.rs @@ -122,19 +122,31 @@ impl<'ctx> SliceType<'ctx> { ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) } - /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. - #[must_use] - pub fn new(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + fn new_impl(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { let llvm_ty = Self::llvm_type(ctx, int_ty); Self { ty: llvm_ty, int_ty, llvm_usize } } + /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type. + #[must_use] + pub fn new(ctx: &CodeGenContext<'ctx, '_>, int_ty: IntType<'ctx>) -> Self { + Self::new_impl(ctx.ctx, int_ty, ctx.get_size_type()) + } + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. #[must_use] - pub fn new_usize(generator: &G, ctx: &'ctx Context) -> Self { - let llvm_usize = generator.get_size_type(ctx); - Self::new(ctx, llvm_usize, llvm_usize) + pub fn new_usize(ctx: &CodeGenContext<'ctx, '_>) -> Self { + Self::new_impl(ctx.ctx, ctx.get_size_type(), ctx.get_size_type()) + } + + /// Creates an instance of [`SliceType`] with `usize` as its backing integer type. + #[must_use] + pub fn new_usize_with_generator( + generator: &G, + ctx: &'ctx Context, + ) -> Self { + Self::new_impl(ctx, generator.get_size_type(ctx), generator.get_size_type(ctx)) } /// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`. diff --git a/nac3core/src/codegen/values/list.rs b/nac3core/src/codegen/values/list.rs index 08d2b6b5..4ba5b6af 100644 --- a/nac3core/src/codegen/values/list.rs +++ b/nac3core/src/codegen/values/list.rs @@ -114,13 +114,9 @@ impl<'ctx> ListValue<'ctx> { /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. #[must_use] - pub fn as_i8_list( - &self, - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ) -> ListValue<'ctx> { + pub fn as_i8_list(&self, ctx: &CodeGenContext<'ctx, '_>) -> ListValue<'ctx> { let llvm_i8 = ctx.ctx.i8_type(); - let llvm_list_i8 = ::Type::new(generator, ctx.ctx, llvm_i8.into()); + let llvm_list_i8 = ::Type::new(ctx, &llvm_i8); Self::from_pointer_value( ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), diff --git a/nac3core/src/codegen/values/ndarray/broadcast.rs b/nac3core/src/codegen/values/ndarray/broadcast.rs index b145746e..b5182a2b 100644 --- a/nac3core/src/codegen/values/ndarray/broadcast.rs +++ b/nac3core/src/codegen/values/ndarray/broadcast.rs @@ -104,7 +104,7 @@ impl<'ctx> NDArrayValue<'ctx> { assert!(self.ndims <= target_ndims); assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); - let broadcast_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, target_ndims) + let broadcast_ndarray = NDArrayType::new(ctx, self.dtype, target_ndims) .construct_uninitialized(generator, ctx, None); broadcast_ndarray.copy_shape_from_array( generator, @@ -147,7 +147,7 @@ fn broadcast_shapes<'ctx, G, Shape>( + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, { let llvm_usize = ctx.get_size_type(); - let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx); + let llvm_shape_ty = ShapeEntryType::new(ctx); assert!(in_shape_entries .iter() diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs index 52082df6..0fbb85f0 100644 --- a/nac3core/src/codegen/values/ndarray/contiguous.rs +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -117,8 +117,8 @@ impl<'ctx> NDArrayValue<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) -> ContiguousNDArrayValue<'ctx> { - let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype) - .alloca_var(generator, ctx, self.name); + let result = + ContiguousNDArrayType::new(ctx, &self.dtype).alloca_var(generator, ctx, self.name); // Set ndims and shape. let ndims = self.llvm_usize.const_int(self.ndims, false); @@ -178,8 +178,11 @@ impl<'ctx> NDArrayValue<'ctx> { // TODO: Debug assert `ndims == carray.ndims` to catch bugs. // Allocate the resulting ndarray. - let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, ndims) - .construct_uninitialized(generator, ctx, carray.name); + let ndarray = NDArrayType::new(ctx, carray.item, ndims).construct_uninitialized( + generator, + ctx, + carray.name, + ); // Copy shape and update strides let shape = carray.load_shape(ctx); diff --git a/nac3core/src/codegen/values/ndarray/indexing.rs b/nac3core/src/codegen/values/ndarray/indexing.rs index 1a96522b..60c9c3b7 100644 --- a/nac3core/src/codegen/values/ndarray/indexing.rs +++ b/nac3core/src/codegen/values/ndarray/indexing.rs @@ -128,11 +128,10 @@ impl<'ctx> NDArrayValue<'ctx> { indices: &[RustNDIndex<'ctx>], ) -> Self { let dst_ndims = self.deduce_ndims_after_indexing_with(indices); - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims) + let dst_ndarray = NDArrayType::new(ctx, self.dtype, dst_ndims) .construct_uninitialized(generator, ctx, None); - let indices = - NDIndexType::new(generator, ctx.ctx).construct_ndindices(generator, ctx, indices); + let indices = NDIndexType::new(ctx).construct_ndindices(generator, ctx, indices); irrt::ndarray::call_nac3_ndarray_index(generator, ctx, indices, *self, dst_ndarray); dst_ndarray @@ -245,8 +244,7 @@ impl<'ctx> RustNDIndex<'ctx> { } RustNDIndex::Slice(in_rust_slice) => { let user_slice_ptr = - SliceType::new(ctx.ctx, ctx.ctx.i32_type(), ctx.get_size_type()) - .alloca_var(generator, ctx, None); + SliceType::new(ctx, ctx.ctx.i32_type()).alloca_var(generator, ctx, None); in_rust_slice.write_to_slice(ctx, user_slice_ptr); dst_ndindex.store_data( diff --git a/nac3core/src/codegen/values/ndarray/matmul.rs b/nac3core/src/codegen/values/ndarray/matmul.rs index a24316b4..f12d36c1 100644 --- a/nac3core/src/codegen/values/ndarray/matmul.rs +++ b/nac3core/src/codegen/values/ndarray/matmul.rs @@ -108,7 +108,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>( let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); - let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, ndims_int) + let dst = NDArrayType::new(ctx, llvm_dst_dtype, ndims_int) .construct_uninitialized(generator, ctx, None); dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); unsafe { diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 595345e8..705412e0 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -377,12 +377,8 @@ impl<'ctx> NDArrayValue<'ctx> { .map(|obj| obj.as_basic_value_enum()) .collect_vec(); - TupleType::new( - generator, - ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), - ) - .construct_from_objects(ctx, objects, None) + TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) + .construct_from_objects(ctx, objects, None) } /// Create the strides tuple of this ndarray like @@ -411,12 +407,8 @@ impl<'ctx> NDArrayValue<'ctx> { .map(|obj| obj.as_basic_value_enum()) .collect_vec(); - TupleType::new( - generator, - ctx.ctx, - &repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(), - ) - .construct_from_objects(ctx, objects, None) + TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) + .construct_from_objects(ctx, objects, None) } /// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar. @@ -998,10 +990,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> { ) -> NDArrayValue<'ctx> { match self { ScalarOrNDArray::NDArray(ndarray) => *ndarray, - ScalarOrNDArray::Scalar(scalar) => { - NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type()) - .construct_unsized(generator, ctx, scalar, None) - } + ScalarOrNDArray::Scalar(scalar) => NDArrayType::new_unsized(ctx, scalar.get_type()) + .construct_unsized(generator, ctx, scalar, None), } } diff --git a/nac3core/src/codegen/values/ndarray/nditer.rs b/nac3core/src/codegen/values/ndarray/nditer.rs index 3784193d..dd900d64 100644 --- a/nac3core/src/codegen/values/ndarray/nditer.rs +++ b/nac3core/src/codegen/values/ndarray/nditer.rs @@ -160,9 +160,7 @@ impl<'ctx> NDArrayValue<'ctx> { generator, ctx, Some("ndarray_foreach"), - |generator, ctx| { - Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self)) - }, + |generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)), |_, ctx, nditer| Ok(nditer.has_element(ctx)), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), |_, ctx, nditer| { diff --git a/nac3core/src/codegen/values/ndarray/view.rs b/nac3core/src/codegen/values/ndarray/view.rs index f68931f7..9ab3d306 100644 --- a/nac3core/src/codegen/values/ndarray/view.rs +++ b/nac3core/src/codegen/values/ndarray/view.rs @@ -65,7 +65,7 @@ impl<'ctx> NDArrayValue<'ctx> { // not contiguous but could be reshaped without copying data. Look into how numpy does // it. - let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, new_ndims) + let dst_ndarray = NDArrayType::new(ctx, self.dtype, new_ndims) .construct_uninitialized(generator, ctx, None); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator));