Compare commits

...

9 Commits

Author SHA1 Message Date
David Mak 23b2fee4e7 standalone: Add test case for ndarray slicing 2024-06-03 16:40:05 +08:00
David Mak ed79d5bb9e core/expr: Add support for multi-dim slicing of NDArrays 2024-06-03 16:40:05 +08:00
David Mak c35ad06949 core/expr: Add support for 1D slicing of NDArrays 2024-06-03 16:40:05 +08:00
David Mak 135ef557f9 core/numpy: Implement ndarray_sliced_{copy,copyto_impl}
Performing copying with optional support for slicing. Also made
copy_impl delegate to sliced_copy, as sliced_copy now performs a
superset of operations that copy_impl can already do.
2024-06-03 16:40:05 +08:00
David Mak a176c3eb70 core/irrt: Change handle_slice_indices to instead take length of object
So that all other array-like datatypes (e.g. ndarray) can also take
advantage of it.
2024-06-03 16:40:05 +08:00
David Mak 2cf79510c2 core/numpy: Add more helper functions 2024-06-03 16:40:05 +08:00
David Mak b6ff75dcaf core/irrt: Add support for calculating partial size of NDArray 2024-06-03 16:40:05 +08:00
David Mak 588c15f80d core/stmt: Add gen_for_range_callback
For generating for loops over range objects or array slices.
2024-06-03 16:40:05 +08:00
David Mak 82cc693b11 meta: Update dependencies 2024-06-03 16:40:02 +08:00
13 changed files with 780 additions and 249 deletions

112
Cargo.lock generated
View File

@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cc"
version = "1.0.97"
version = "1.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4"
checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f"
[[package]]
name = "cfg-if"
@ -158,7 +158,7 @@ dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.61",
"syn 2.0.66",
]
[[package]]
@ -200,9 +200,9 @@ dependencies = [
[[package]]
name = "crossbeam-channel"
version = "0.5.12"
version = "0.5.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95"
checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2"
dependencies = [
"crossbeam-utils",
]
@ -237,9 +237,9 @@ dependencies = [
[[package]]
name = "crossbeam-utils"
version = "0.8.19"
version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
[[package]]
name = "crunchy"
@ -270,9 +270,9 @@ dependencies = [
[[package]]
name = "either"
version = "1.11.0"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b"
[[package]]
name = "ena"
@ -297,9 +297,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "errno"
version = "0.3.8"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245"
checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba"
dependencies = [
"libc",
"windows-sys",
@ -421,7 +421,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.61",
"syn 2.0.66",
]
[[package]]
@ -455,9 +455,9 @@ dependencies = [
[[package]]
name = "itertools"
version = "0.12.1"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
@ -507,9 +507,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.154"
version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]]
name = "libloading"
@ -539,9 +539,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]]
name = "linux-raw-sys"
version = "0.4.13"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]]
name = "llvm-sys"
@ -592,7 +592,7 @@ name = "nac3artiq"
version = "0.1.0"
dependencies = [
"inkwell",
"itertools 0.12.1",
"itertools 0.13.0",
"nac3core",
"nac3ld",
"nac3parser",
@ -620,7 +620,7 @@ dependencies = [
"indoc",
"inkwell",
"insta",
"itertools 0.12.1",
"itertools 0.13.0",
"nac3parser",
"parking_lot",
"rayon",
@ -676,9 +676,9 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "parking_lot"
version = "0.12.2"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb"
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
dependencies = [
"lock_api",
"parking_lot_core",
@ -699,9 +699,9 @@ dependencies = [
[[package]]
name = "petgraph"
version = "0.6.4"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [
"fixedbitset",
"indexmap 2.2.6",
@ -747,7 +747,7 @@ dependencies = [
"phf_shared 0.11.2",
"proc-macro2",
"quote",
"syn 2.0.61",
"syn 2.0.66",
]
[[package]]
@ -794,18 +794,18 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]]
name = "proc-macro2"
version = "1.0.82"
version = "1.0.85"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b"
checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23"
dependencies = [
"unicode-ident",
]
[[package]]
name = "pyo3"
version = "0.20.3"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233"
checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8"
dependencies = [
"cfg-if",
"indoc",
@ -821,9 +821,9 @@ dependencies = [
[[package]]
name = "pyo3-build-config"
version = "0.20.3"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7"
checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50"
dependencies = [
"once_cell",
"target-lexicon",
@ -831,9 +831,9 @@ dependencies = [
[[package]]
name = "pyo3-ffi"
version = "0.20.3"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa"
checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403"
dependencies = [
"libc",
"pyo3-build-config",
@ -841,27 +841,27 @@ dependencies = [
[[package]]
name = "pyo3-macros"
version = "0.20.3"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158"
checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.61",
"syn 2.0.66",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.20.3"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185"
checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c"
dependencies = [
"heck 0.4.1",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.61",
"syn 2.0.66",
]
[[package]]
@ -994,9 +994,9 @@ dependencies = [
[[package]]
name = "rustversion"
version = "1.0.16"
version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "092474d1a01ea8278f69e6a358998405fae5b8b963ddaeb2b0b04a128bf1dfb0"
checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6"
[[package]]
name = "ryu"
@ -1027,22 +1027,22 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]]
name = "serde"
version = "1.0.201"
version = "1.0.203"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "780f1cebed1629e4753a1a38a3c72d30b97ec044f0aef68cb26650a3c5cf363c"
checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.201"
version = "1.0.203"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5e405930b9796f1c00bee880d03fc7e0bb4b9a11afc776885ffe84320da2865"
checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.61",
"syn 2.0.66",
]
[[package]]
@ -1088,9 +1088,9 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "string-interner"
version = "0.15.0"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07f9fdfdd31a0ff38b59deb401be81b73913d76c9cc5b1aed4e1330a223420b9"
checksum = "1c6a0d765f5807e98a091107bae0a56ea3799f66a5de47b2c84c94a39c09974e"
dependencies = [
"cfg-if",
"hashbrown 0.14.5",
@ -1129,9 +1129,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.61"
version = "2.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c993ed8ccba56ae856363b1845da7266a7cb78e1d146c8a32d54b45a8b831fc9"
checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5"
dependencies = [
"proc-macro2",
"quote",
@ -1182,22 +1182,22 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.60"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.60"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.61",
"syn 2.0.66",
]
[[package]]
@ -1465,5 +1465,5 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.61",
"syn 2.0.66",
]

View File

@ -9,8 +9,8 @@ name = "nac3artiq"
crate-type = ["cdylib"]
[dependencies]
itertools = "0.12"
pyo3 = { version = "0.20", features = ["extension-module"] }
itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
parking_lot = "0.12"
tempfile = "3.10"
nac3parser = { path = "../nac3parser" }

View File

@ -12,5 +12,5 @@ fold = []
[dependencies]
lazy_static = "1.4"
parking_lot = "0.12"
string-interner = "0.15"
string-interner = "0.17"
fxhash = "0.2"

View File

@ -5,7 +5,7 @@ authors = ["M-Labs"]
edition = "2021"
[dependencies]
itertools = "0.12"
itertools = "0.13"
crossbeam = "0.8"
indexmap = "2.2"
parking_lot = "0.12"

View File

@ -737,7 +737,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx.builder
.build_int_compare(
@ -955,7 +955,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes());
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx.builder
.build_int_compare(

View File

@ -658,6 +658,17 @@ impl<'ctx> RangeValue<'ctx> {
RangeValue(ptr, name)
}
/// Returns the integer type of the elements held by this `range` object.
///
/// The type is recovered from the underlying pointer: its pointee is read as an array type, and
/// that array's element type is read as an integer type.
#[must_use]
pub fn element_type(&self) -> IntType<'ctx> {
    let range_ptr_ty = self.as_ptr_value().get_type();
    let range_array_ty = range_ptr_ty.get_element_type().into_array_type();

    range_array_ty.get_element_type().into_int_type()
}
/// Returns the underlying [`PointerValue`] pointing to the `range` instance.
#[must_use]
pub fn as_ptr_value(&self) -> PointerValue<'ctx> {
@ -1111,7 +1122,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator))
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
}
}

View File

@ -1667,6 +1667,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
@ -1712,13 +1713,11 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
slice.location,
);
if let ExprKind::Slice { .. } = &slice.node {
return Err(String::from("subscript operator for ndarray not implemented"))
}
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
let index = index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
// Normalizes a possibly-negative index to its corresponding positive index
let normalize_index = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
dim: u64| {
gen_if_else_expr_callback(
generator,
ctx,
@ -1738,7 +1737,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
v.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
&llvm_usize.const_int(dim, true),
None,
)
};
@ -1751,96 +1750,194 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
Ok(Some(ctx.builder.build_int_truncate(index, llvm_i32, "").unwrap()))
},
)?.map(BasicValueEnum::into_int_value).unwrap()
} else {
return Ok(None)
).map(|v| v.map(BasicValueEnum::into_int_value))
};
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
ctx.builder.build_store(index_addr, index).unwrap();
if ndims.len() == 1 && ndims[0] == 1 {
// Accessing an element from a 1-dimensional `ndarray`
// Converts a slice expression into a slice-range tuple
let expr_to_slice = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
node: &ExprKind<Option<Type>>,
dim: u64| {
match node {
ExprKind::Constant { value: Constant::Int(v), .. } => {
let Some(index) = normalize_index(
generator, ctx, llvm_i32.const_int(*v as u64, true), dim,
)? else {
return Ok(None)
};
Ok(Some(v.data()
.get(
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
ExprKind::Slice { lower, upper, step } => {
let dim_sz = unsafe {
v.dim_sizes()
.get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(dim, false),
None,
)
};
handle_slice_indices(lower, upper, step, ctx, generator, dim_sz)
}
_ => {
let Some(index) = generator.gen_expr(ctx, slice)? else {
return Ok(None)
};
let index = index
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, dim)? else {
return Ok(None)
};
Ok(Some((index, index, llvm_i32.const_int(1, true))))
}
}
};
Ok(Some(match &slice.node {
ExprKind::Tuple { elts, .. } => {
let slices = elts.iter().enumerate()
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
.collect::<Result<Vec<_>, _>>()?;
if slices.len() < elts.len() {
return Ok(None)
}
let slices = slices.into_iter()
.map(Option::unwrap)
.collect_vec();
numpy::ndarray_sliced_copy(
generator,
ctx,
ty,
v,
&slices,
)?.as_ptr_value().into()
}
ExprKind::Slice { .. } => {
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
return Ok(None)
};
numpy::ndarray_sliced_copy(
generator,
ctx,
ty,
v,
&[slice],
)?.as_ptr_value().into()
}
_ => {
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else {
return Ok(None)
};
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
return Ok(None)
};
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
ctx.builder.build_store(index_addr, index).unwrap();
if ndims.len() == 1 && ndims[0] == 1 {
// Accessing an element from a 1-dimensional `ndarray`
return Ok(Some(v.data()
.get(
ctx,
generator,
&ArraySliceValue::from_ptr_val(
index_addr,
llvm_usize.const_int(1, false),
None,
),
None,
)
.into()))
}
// Accessing an element from a multi-dimensional `ndarray`
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None
)?;
let ndarray = NDArrayValue::from_ptr_val(
subscripted_ndarray,
llvm_usize,
None
);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset(
ctx,
generator,
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
None,
)
.into()))
} else {
// Accessing an element from a multi-dimensional `ndarray`
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None
)?;
let ndarray = NDArrayValue::from_ptr_val(
subscripted_ndarray,
llvm_usize,
None
);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
None
);
call_memcpy_generic(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset(
ctx,
generator,
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
None
);
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
Ok(Some(ndarray.as_ptr_value().into()))
}
ndarray.as_ptr_value().into()
}
}))
}
/// See [`CodeGenerator::gen_expr`].
@ -2263,10 +2360,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let ty = ctx.get_llvm_type(generator, *ty);
if let ExprKind::Slice { lower, upper, step } = &slice.node {
let one = int32.const_int(1, false);
let Some((start, end, step)) =
handle_slice_indices(lower, upper, step, ctx, generator, v)? else {
return Ok(None)
};
let Some((start, end, step)) = handle_slice_indices(
lower,
upper,
step,
ctx,
generator,
v.load_size(ctx, None),
)? else { return Ok(None) };
let length = calculate_len_for_slice_range(
generator,
ctx,
@ -2288,10 +2389,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
step,
);
let res_array_ret = allocate_list(generator, ctx, ty, length, Some("ret"));
let Some(res_ind) =
handle_slice_indices(&None, &None, &None, ctx, generator, res_array_ret)? else {
return Ok(None)
};
let Some(res_ind) = handle_slice_indices(
&None,
&None,
&None,
ctx,
generator,
res_array_ret.load_size(ctx, None),
)? else { return Ok(None) };
list_slice_assignment(
generator,
ctx,

View File

@ -202,10 +202,14 @@ double __nac3_j0(double x) {
uint32_t __nac3_ndarray_calc_size(
const uint64_t *list_data,
uint32_t list_len
uint32_t list_len,
uint32_t begin_idx,
uint32_t end_idx
) {
__builtin_assume(end_idx <= list_len);
uint32_t num_elems = 1;
for (uint32_t i = 0; i < list_len; ++i) {
for (uint32_t i = begin_idx; i < end_idx; ++i) {
uint64_t val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
@ -215,10 +219,14 @@ uint32_t __nac3_ndarray_calc_size(
uint64_t __nac3_ndarray_calc_size64(
const uint64_t *list_data,
uint64_t list_len
uint64_t list_len,
uint64_t begin_idx,
uint64_t end_idx
) {
__builtin_assume(end_idx <= list_len);
uint64_t num_elems = 1;
for (uint64_t i = 0; i < list_len; ++i) {
for (uint64_t i = begin_idx; i < end_idx; ++i) {
uint64_t val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;

View File

@ -175,12 +175,11 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
step: &Option<Box<Expr<Option<Type>>>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
list: ListValue<'ctx>,
length: IntValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let length = list.load_size(ctx, Some("length"));
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
Ok(Some(match (start, end, step) {
(s, e, None) => (
@ -583,12 +582,14 @@ pub fn call_j0<'ctx>(
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `num_dims` - An [`IntValue`] containing the number of dimensions.
/// * `dims` - A [`PointerValue`] to an array containing the size of each dimension.
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
@ -607,6 +608,8 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
&[
llvm_pi64.into(),
llvm_usize.into(),
llvm_usize.into(),
llvm_usize.into(),
],
false,
);
@ -615,12 +618,16 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)

View File

@ -9,6 +9,7 @@ use crate::{
NDArrayValue,
TypedArrayLikeAccessor,
TypedArrayLikeAdapter,
TypedArrayLikeMutator,
UntypedArrayLikeAccessor,
UntypedArrayLikeMutator,
},
@ -16,6 +17,7 @@ use crate::{
CodeGenerator,
expr::gen_binop_expr_with_values,
irrt::{
calculate_len_for_slice_range,
call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index,
call_ndarray_calc_nd_indices,
@ -23,7 +25,7 @@ use crate::{
},
llvm_intrinsics,
llvm_intrinsics::{call_memcpy_generic},
stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback},
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
},
symbol_resolver::ValueEnum,
toplevel::{
@ -33,6 +35,30 @@ use crate::{
typecheck::typedef::{FunSignature, Type},
};
/// Allocates an `NDArray` instance without populating any of its fields.
///
/// * `elem_ty` - The element type of the `NDArray`.
///
/// Returns the [`NDArrayValue`] wrapping the freshly allocated variable, or the error reported by
/// the variable allocation.
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    elem_ty: Type,
) -> Result<NDArrayValue<'ctx>, String> {
    let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
    let size_ty = generator.get_size_type(ctx.ctx);

    // The LLVM type of an ndarray is a pointer to a struct; the allocation needs the struct
    // itself.
    let ndarray_ptr_ty = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
    let ndarray_struct_ty = ndarray_ptr_ty.get_element_type().into_struct_type();

    generator
        .gen_var_alloc(ctx, ndarray_struct_ty.into(), None)
        .map(|ndarray_ptr| NDArrayValue::from_ptr_val(ndarray_ptr, size_ty, None))
}
/// Creates an `NDArray` instance from a dynamic shape.
///
/// * `elem_ty` - The element type of the `NDArray`.
@ -52,15 +78,8 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result<IntValue<'ctx>, String>,
DataFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result<IntValue<'ctx>, String>,
{
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
// Assert that all dimensions are non-negative
let shape_len = shape_len_fn(generator, ctx, shape)?;
gen_for_callback_incrementing(
@ -92,12 +111,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
llvm_usize.const_int(1, false),
)?;
let ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None,
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
let num_dims = shape_len_fn(generator, ctx, shape)?;
ndarray.store_ndims(ctx, generator, num_dims);
@ -130,12 +144,7 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>(
llvm_usize.const_int(1, false),
)?;
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
Ok(ndarray)
}
@ -150,15 +159,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
elem_ty: Type,
shape: &[IntValue<'ctx>],
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
assert!(llvm_ndarray_data_t.is_sized());
for shape_dim in shape {
let shape_dim_gez = ctx.builder
.build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "")
@ -176,12 +178,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
// TODO: Disallow dim_sz > u32_MAX
}
let ndarray = generator.gen_var_alloc(
ctx,
llvm_ndarray_t.into(),
None,
)?;
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;
let num_dims = llvm_usize.const_int(shape.len() as u64, false);
ndarray.store_ndims(ctx, generator, num_dims);
@ -199,14 +196,30 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
}
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
Ok(ndarray)
}
/// Initializes the `data` field of [`NDArrayValue`] based on the `ndims` and `dim_sz` fields.
///
/// * `elem_ty` - The element type of the `NDArray`; its LLVM type must be sized.
/// * `ndarray` - The `NDArray` instance whose dimension sizes are already populated.
///
/// The number of elements is computed over all dimensions (`(None, None)` selects the full
/// range), and the data buffer is created with that element count. Returns the same `ndarray`
/// value that was passed in.
///
/// # Panics
///
/// Panics if the LLVM type corresponding to `elem_ty` is not sized.
fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    elem_ty: Type,
    ndarray: NDArrayValue<'ctx>,
) -> NDArrayValue<'ctx> {
    let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
    assert!(llvm_ndarray_data_t.is_sized());

    let ndarray_num_elems = call_ndarray_calc_size(
        generator,
        ctx,
        &ndarray.dim_sizes().as_slice_value(ctx, generator),
        (None, None),
    );
    ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);

    // This function is infallible, so the value is returned directly rather than wrapped in `Ok`.
    ndarray
}
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
@ -293,6 +306,7 @@ fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
gen_for_callback_incrementing(
@ -633,6 +647,240 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
Ok(ndarray)
}
/// Copies a slice of an [`NDArrayValue`] to another.
///
/// One loop is generated per entry of `slices`; this function calls itself (at code-generation
/// time) from within the loop body to emit the loop for the next dimension. Once no slices
/// remain, everything from dimension `dim` onwards is copied with a single `memcpy`.
///
/// - `elem_ty`: The element type of the `NDArray`.
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz`
///   fields should be populated before calling this function.
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
///   dimensional slice in the destination array.
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
///   dimensional slice in the source array.
/// - `dim`: The index of the currently processing dimension.
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
///   this dimension. The `start`/`stop` values of each slice must be non-negative indices.
///
/// # Errors
///
/// Returns an error if allocating the destination index variable fails, or if generating the
/// loop (including the nested copy for the remaining dimensions) fails.
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    elem_ty: Type,
    (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
    (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
    dim: u64,
    slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
) -> Result<(), String> {
    let llvm_i1 = ctx.ctx.bool_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);

    let Some((&(start, stop, step), remaining_slices)) = slices.split_first() else {
        // If there are no (remaining) slice expressions, memcpy the entire dimension.
        // The element count is the product of the dimension sizes from `dim` to the last one.
        let stride = call_ndarray_calc_size(
            generator,
            ctx,
            &src_arr.dim_sizes(),
            (Some(llvm_usize.const_int(dim, false)), None),
        );
        let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();

        // NOTE(review): `size_of()` may yield an integer of a different width than `llvm_usize`
        // on targets where the size type is not 64-bit - confirm both operands of this
        // multiplication have the same type.
        let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap();

        call_memcpy_generic(
            ctx,
            dst_slice_ptr,
            src_slice_ptr,
            cpy_len,
            llvm_i1.const_zero(),
        );

        return Ok(());
    };

    // The stride of elements in this dimension, i.e. the number of elements between arr[i] and
    // arr[i + 1] in this dimension
    let src_stride = call_ndarray_calc_size(
        generator,
        ctx,
        &src_arr.dim_sizes(),
        (Some(llvm_usize.const_int(dim + 1, false)), None),
    );
    let dst_stride = call_ndarray_calc_size(
        generator,
        ctx,
        &dst_arr.dim_sizes(),
        (Some(llvm_usize.const_int(dim + 1, false)), None),
    );

    // Bring the slice bounds to the size type so they can be multiplied with the strides.
    let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap();
    let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap();
    let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap();

    // The destination is written densely: its index starts at zero and advances by one per
    // iteration, independently of the source index which follows `start`/`stop`/`step`.
    // Allocation failure is propagated to the caller instead of panicking.
    let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None)?;
    ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap();

    // NOTE(review): the leading `false` flag and the `true` paired with the stop bound are
    // forwarded unchanged - confirm their meaning against `gen_for_range_callback`.
    gen_for_range_callback(
        generator,
        ctx,
        false,
        |_, _| Ok(start),
        (|_, _| Ok(stop), true),
        |_, _| Ok(step),
        |generator, ctx, src_i| {
            // Calculate the offset of the active slice
            let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
            let dst_i = ctx.builder
                .build_load(dst_i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();
            let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap();

            let (src_ptr, dst_ptr) = unsafe {
                (
                    ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(),
                    ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(),
                )
            };

            // Emit the copy for the next dimension using the remaining slices.
            ndarray_sliced_copyto_impl(
                generator,
                ctx,
                elem_ty,
                (dst_arr, dst_ptr),
                (src_arr, src_ptr),
                dim + 1,
                remaining_slices,
            )?;

            // Advance the destination index by one.
            let dst_i = ctx.builder
                .build_load(dst_i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();
            let dst_i_add1 = ctx.builder
                .build_int_add(dst_i, llvm_usize.const_int(1, false), "")
                .unwrap();
            ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap();

            Ok(())
        },
    )?;

    Ok(())
}
/// Copies a [`NDArrayValue`] using slices, returning a newly created array.
///
/// When `slices` is empty, the new array takes the shape of `this` and the whole array is copied.
/// Otherwise, the first `slices.len()` dimensions of the new array are sized according to the
/// corresponding slice, and all remaining dimensions keep the size they have in `this`.
///
/// - `elem_ty`: The element type of the `NDArray`.
/// - `this`: The source array.
/// - `slices`: List of all `(start, stop, step)` slices, with the first element corresponding to
///   the slice applicable to the first dimension. The `start`/`stop` values of each slice must be
///   non-negative indices, and `stop` is treated as inclusive. The values are compared against and
///   combined with `i32` constants here, so they are expected to be 32-bit integers.
///
/// Returns the newly created array, or an error if generating any part of the copy fails.
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, '_>,
    elem_ty: Type,
    this: NDArrayValue<'ctx>,
    slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)],
) -> Result<NDArrayValue<'ctx>, String> {
    let llvm_i32 = ctx.ctx.i32_type();
    let llvm_usize = generator.get_size_type(ctx.ctx);

    let ndarray = if slices.is_empty() {
        // No slicing requested: mirror the shape of the source array.
        create_ndarray_dyn_shape(
            generator,
            ctx,
            elem_ty,
            &this,
            |_, ctx, shape| Ok(shape.load_ndims(ctx)),
            |generator, ctx, shape, idx| unsafe {
                Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None))
            },
        )?
    } else {
        let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?;

        // Slicing never changes the number of dimensions.
        let ndims = this.load_ndims(ctx);
        ndarray.store_ndims(ctx, generator, ndims);
        ndarray.create_dim_sizes(ctx, llvm_usize, ndims);

        // Populate the first slices.len() dimensions by computing the size of each dim slice
        for (i, (start, stop, step)) in slices.iter().enumerate() {
            // HACK: workaround calculate_len_for_slice_range requiring exclusive stop.
            // The inclusive stop is moved one position further in the direction of iteration.
            let is_neg = ctx
                .builder
                .build_int_compare(IntPredicate::SLT, *step, llvm_i32.const_zero(), "is_neg")
                .unwrap();
            let stop_minus_one = ctx
                .builder
                .build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one")
                .unwrap();
            let stop_plus_one = ctx
                .builder
                .build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one")
                .unwrap();
            let stop = ctx
                .builder
                .build_select(is_neg, stop_minus_one, stop_plus_one, "final_e")
                .map(BasicValueEnum::into_int_value)
                .unwrap();

            let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step);
            let slice_len =
                ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();

            unsafe {
                ndarray.dim_sizes().set_typed_unchecked(
                    ctx,
                    generator,
                    &llvm_usize.const_int(i as u64, false),
                    slice_len,
                );
            }
        }

        // Populate the rest by directly copying the dim size from the source array
        gen_for_callback_incrementing(
            generator,
            ctx,
            llvm_usize.const_int(slices.len() as u64, false),
            (this.load_ndims(ctx), false),
            |generator, ctx, idx| {
                unsafe {
                    let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None);
                    ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz);
                }

                Ok(())
            },
            llvm_usize.const_int(1, false),
        )?;

        // The shape is final now, so the data buffer can be set up.
        ndarray_init_data(generator, ctx, elem_ty, ndarray)
    };

    ndarray_sliced_copyto_impl(
        generator,
        ctx,
        elem_ty,
        (ndarray, ndarray.data().base_ptr(ctx, generator)),
        (this, this.data().base_ptr(ctx, generator)),
        0,
        slices,
    )?;

    Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.copy`.
///
/// * `elem_ty` - The element type of the `NDArray`.
@ -642,44 +890,7 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
elem_ty: Type,
this: NDArrayValue<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_i1 = ctx.ctx.bool_type();
let ndarray = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&this,
|_, ctx, shape| {
Ok(shape.load_ndims(ctx))
},
|generator, ctx, shape, idx| {
unsafe { Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) }
},
)?;
let len = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
);
let sizeof_ty = ctx.get_llvm_type(generator, elem_ty);
let len_bytes = ctx.builder
.build_int_mul(
len,
sizeof_ty.size_of().unwrap(),
"",
)
.unwrap();
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
this.data().base_ptr(ctx, generator),
len_bytes,
llvm_i1.const_zero(),
);
Ok(ndarray)
ndarray_sliced_copy(generator, ctx, elem_ty, this, &[])
}
pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>(

View File

@ -240,10 +240,14 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
.into_pointer_value();
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
let Some((start, end, step)) =
handle_slice_indices(lower, upper, step, ctx, generator, ls)? else {
return Ok(())
};
let Some((start, end, step)) = handle_slice_indices(
lower,
upper,
step,
ctx,
generator,
ls.load_size(ctx, None),
)? else { return Ok(()) };
let value = value
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
.into_pointer_value();
@ -257,9 +261,14 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
};
let ty = ctx.get_llvm_type(generator, ty);
let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else {
return Ok(())
};
let Some(src_ind) = handle_slice_indices(
&None,
&None,
&None,
ctx,
generator,
value.load_size(ctx, None),
)? else { return Ok(()) };
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
}
_ => {
@ -621,6 +630,139 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
)
}
/// Generates a `for` construct over a `range`-like iterable using lambdas, similar to the following
/// C code:
///
/// ```c
/// int i = start_fn();
/// bool incr = is_unsigned ? (i <= stop_fn()) : (step_fn() > 0);
/// for (; incr ? (i /* < or <= */ stop_fn()) : (i /* > or >= */ stop_fn()); i += step_fn()) {
///     body_fn(i);
/// }
/// ```
///
/// For signed ranges the direction of iteration is derived from the sign of the step rather than
/// from the relative order of `start` and `stop`. This way, an empty range (e.g. `start > stop`
/// with a positive step) generates a loop that runs zero times, and a single-element inclusive
/// range with a negative step (`start == stop`) runs exactly once, instead of looping with a
/// condition that the step can never falsify. Unsigned values carry no sign, so for unsigned ranges
/// the direction is still derived from `start <= stop`.
///
/// - `is_unsigned`: Whether to treat the values of the `range` as unsigned.
/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like
///   iterable. It is evaluated once, and its type determines the type of the loop variable.
/// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like
///   iterable. This value will be extended to the size of `start`. It is re-evaluated every time
///   the loop condition is checked.
/// - `stop_inclusive`: Whether the stop value should be treated as inclusive.
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
///   iterable. This value will be extended to the size of `start`. It is re-evaluated on every
///   update of the loop variable.
/// - `body_fn`: A lambda of IR statements within the loop body.
///
/// Returns an error if any of the lambdas or the underlying loop generation fails.
pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
    generator: &mut G,
    ctx: &mut CodeGenContext<'ctx, 'a>,
    is_unsigned: bool,
    start_fn: StartFn,
    (stop_fn, stop_inclusive): (StopFn, bool),
    step_fn: StepFn,
    body_fn: BodyFn,
) -> Result<(), String>
where
    G: CodeGenerator + ?Sized,
    StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
    StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
    StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
    BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
{
    /// Extends `value` to the integer type of `like` if their bit widths differ, using a
    /// zero-extension for unsigned values and a sign-extension otherwise.
    fn extend_to_type_of<'c>(
        ctx: &CodeGenContext<'c, '_>,
        value: IntValue<'c>,
        like: IntValue<'c>,
        is_unsigned: bool,
    ) -> IntValue<'c> {
        let target_ty = like.get_type();

        if value.get_type().get_bit_width() == target_ty.get_bit_width() {
            value
        } else if is_unsigned {
            ctx.builder.build_int_z_extend(value, target_ty, "").unwrap()
        } else {
            ctx.builder.build_int_s_extend(value, target_ty, "").unwrap()
        }
    }

    gen_for_callback(
        generator,
        ctx,
        // Init: allocate the loop variable, store `start` into it and determine the direction.
        |generator, ctx| {
            let start = start_fn(generator, ctx)?;
            let i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None)?;
            ctx.builder.build_store(i_addr, start).unwrap();

            let incr = if is_unsigned {
                // Unsigned values have no sign to inspect; fall back to comparing the bounds.
                let stop = stop_fn(generator, ctx)?;
                let stop = extend_to_type_of(ctx, stop, start, is_unsigned);

                ctx.builder.build_int_compare(IntPredicate::ULE, start, stop, "").unwrap()
            } else {
                // A strictly positive step counts upwards, everything else counts downwards.
                let step = step_fn(generator, ctx)?;

                ctx.builder
                    .build_int_compare(IntPredicate::SGT, step, step.get_type().const_zero(), "")
                    .unwrap()
            };

            Ok((i_addr, incr))
        },
        // Condition: compare the loop variable against `stop` in the direction of iteration.
        |generator, ctx, (i_addr, incr)| {
            let (lt_cmp_op, gt_cmp_op) = match (is_unsigned, stop_inclusive) {
                (true, true) => (IntPredicate::ULE, IntPredicate::UGE),
                (true, false) => (IntPredicate::ULT, IntPredicate::UGT),
                (false, true) => (IntPredicate::SLE, IntPredicate::SGE),
                (false, false) => (IntPredicate::SLT, IntPredicate::SGT),
            };

            let i = ctx
                .builder
                .build_load(i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();
            let stop = stop_fn(generator, ctx)?;
            let stop = extend_to_type_of(ctx, stop, i, is_unsigned);

            let i_lt_end = ctx.builder.build_int_compare(lt_cmp_op, i, stop, "").unwrap();
            let i_gt_end = ctx.builder.build_int_compare(gt_cmp_op, i, stop, "").unwrap();

            let cond = ctx
                .builder
                .build_select(incr, i_lt_end, i_gt_end, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();

            Ok(cond)
        },
        // Body: hand the current value of the loop variable to the caller.
        |generator, ctx, (i_addr, _)| {
            let i = ctx
                .builder
                .build_load(i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();

            body_fn(generator, ctx, i)
        },
        // Update: add the step to the loop variable.
        |generator, ctx, (i_addr, _)| {
            let i = ctx
                .builder
                .build_load(i_addr, "")
                .map(BasicValueEnum::into_int_value)
                .unwrap();

            let incr_val = step_fn(generator, ctx)?;
            let incr_val = extend_to_type_of(ctx, incr_val, i, is_unsigned);

            let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
            ctx.builder.build_store(i_addr, i).unwrap();

            Ok(())
        },
    )
}
/// See [`CodeGenerator::gen_while`].
pub fn gen_while<G: CodeGenerator>(
generator: &mut G,

View File

@ -2,6 +2,7 @@ use std::collections::{HashMap, HashSet};
use std::convert::{From, TryInto};
use std::iter::once;
use std::{cell::RefCell, sync::Arc};
use std::ops::Not;
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
@ -554,7 +555,10 @@ impl<'a> Fold<()> for Inferencer<'a> {
ExprKind::ListComp { .. }
| ExprKind::Lambda { .. }
| ExprKind::Call { .. } => expr.custom, // already computed
ExprKind::Slice { .. } => None, // we don't need it for slice
ExprKind::Slice { .. } => {
// slices aren't exactly ranges, but for our purposes this should suffice
Some(self.primitives.range)
}
_ => return report_error("not supported", expr.location),
};
Ok(ast::Expr { custom, location: expr.location, node: expr.node })
@ -1642,6 +1646,30 @@ impl<'a> Inferencer<'a> {
}
}
}
ExprKind::Tuple { elts, .. } => {
if value.custom
.unwrap()
.obj_id(self.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)
.not() {
return report_error("Tuple slices are only supported for ndarrays", slice.location)
}
for elt in elts {
if let ExprKind::Slice { lower, upper, step } = &elt.node {
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
}
} else {
self.constrain(elt.custom.unwrap(), self.primitives.int32, &elt.location)?;
}
}
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndarray_ty = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
Ok(ndarray_ty)
}
_ => {
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)

View File

@ -121,6 +121,22 @@ def test_ndarray_neg_idx():
for j in range(-1, -3, -1):
output_float64(x[i][j])
def test_ndarray_slices():
    # Reference array that every slice below is taken from
    base = np_identity(3)
    output_ndarray_float_2(base)

    # Slice with start, stop and step all omitted
    whole = base[::]
    output_ndarray_float_2(whole)

    # Explicit start combined with a step of 2
    stepped = base[0::2]
    output_ndarray_float_2(stepped)

    # Negative step
    reversed_view = base[::-1]
    output_ndarray_float_2(reversed_view)

    # Tuple of slices, one per dimension
    stepped_2d = base[0::2, 0::2]
    output_ndarray_float_2(stepped_2d)
def test_ndarray_add():
x = np_identity(2)
y = x + np_ones([2, 2])
@ -1360,7 +1376,10 @@ def run() -> int32:
test_ndarray_identity()
test_ndarray_fill()
test_ndarray_copy()
test_ndarray_neg_idx()
test_ndarray_slices()
test_ndarray_add()
test_ndarray_add_broadcast()
test_ndarray_add_broadcast_lhs_scalar()