Compare commits

..

4 Commits

72 changed files with 1111 additions and 1667 deletions

111
Cargo.lock generated
View File

@ -65,12 +65,11 @@ dependencies = [
[[package]] [[package]]
name = "anstyle-wincon" name = "anstyle-wincon"
version = "3.0.7" version = "3.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125"
dependencies = [ dependencies = [
"anstyle", "anstyle",
"once_cell",
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
@ -106,9 +105,9 @@ checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "2.8.0" version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]] [[package]]
name = "block-buffer" name = "block-buffer"
@ -127,9 +126,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.9" version = "1.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7"
dependencies = [ dependencies = [
"shlex", "shlex",
] ]
@ -142,9 +141,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.26" version = "4.5.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -152,9 +151,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.26" version = "4.5.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@ -164,14 +163,14 @@ dependencies = [
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.5.24" version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
dependencies = [ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.96", "syn 2.0.94",
] ]
[[package]] [[package]]
@ -473,7 +472,7 @@ checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.96", "syn 2.0.94",
] ]
[[package]] [[package]]
@ -582,9 +581,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.4.15" version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]] [[package]]
name = "llvm-sys" name = "llvm-sys"
@ -611,9 +610,9 @@ dependencies = [
[[package]] [[package]]
name = "log" name = "log"
version = "0.4.25" version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]] [[package]]
name = "memchr" name = "memchr"
@ -679,7 +678,7 @@ dependencies = [
"proc-macro-error", "proc-macro-error",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.96", "syn 2.0.94",
"trybuild", "trybuild",
] ]
@ -762,45 +761,45 @@ dependencies = [
[[package]] [[package]]
name = "phf" name = "phf"
version = "0.11.3" version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
dependencies = [ dependencies = [
"phf_macros", "phf_macros",
"phf_shared 0.11.3", "phf_shared 0.11.2",
] ]
[[package]] [[package]]
name = "phf_codegen" name = "phf_codegen"
version = "0.11.3" version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a"
dependencies = [ dependencies = [
"phf_generator", "phf_generator",
"phf_shared 0.11.3", "phf_shared 0.11.2",
] ]
[[package]] [[package]]
name = "phf_generator" name = "phf_generator"
version = "0.11.3" version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0"
dependencies = [ dependencies = [
"phf_shared 0.11.3", "phf_shared 0.11.2",
"rand", "rand",
] ]
[[package]] [[package]]
name = "phf_macros" name = "phf_macros"
version = "0.11.3" version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b"
dependencies = [ dependencies = [
"phf_generator", "phf_generator",
"phf_shared 0.11.3", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.96", "syn 2.0.94",
] ]
[[package]] [[package]]
@ -809,16 +808,16 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096"
dependencies = [ dependencies = [
"siphasher 0.3.11", "siphasher",
] ]
[[package]] [[package]]
name = "phf_shared" name = "phf_shared"
version = "0.11.3" version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b"
dependencies = [ dependencies = [
"siphasher 1.0.1", "siphasher",
] ]
[[package]] [[package]]
@ -874,9 +873,9 @@ dependencies = [
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.93" version = "1.0.92"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@ -928,7 +927,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.96", "syn 2.0.94",
] ]
[[package]] [[package]]
@ -941,7 +940,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.96", "syn 2.0.94",
] ]
[[package]] [[package]]
@ -1050,9 +1049,9 @@ dependencies = [
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.43" version = "0.38.42"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"errno", "errno",
@ -1111,14 +1110,14 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.96", "syn 2.0.94",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.135" version = "1.0.134"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d"
dependencies = [ dependencies = [
"itoa", "itoa",
"memchr", "memchr",
@ -1175,12 +1174,6 @@ version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
[[package]]
name = "siphasher"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.13.2" version = "1.13.2"
@ -1233,7 +1226,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.96", "syn 2.0.94",
] ]
[[package]] [[package]]
@ -1249,9 +1242,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.96" version = "2.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" checksum = "987bc0be1cdea8b10216bd06e2ca407d40b9543468fafd3ddfb02f36e77f71f3"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1333,7 +1326,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.96", "syn 2.0.94",
] ]
[[package]] [[package]]
@ -1611,9 +1604,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]] [[package]]
name = "winnow" name = "winnow"
version = "0.6.24" version = "0.6.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" checksum = "39281189af81c07ec09db316b302a3e67bf9bd7cbf6c820b50e35fee9c2fa980"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
@ -1645,5 +1638,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.96", "syn 2.0.94",
] ]

6
flake.lock generated
View File

@ -2,11 +2,11 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1736798957, "lastModified": 1735834308,
"narHash": "sha256-qwpCtZhSsSNQtK4xYGzMiyEDhkNzOCz/Vfu4oL2ETsQ=", "narHash": "sha256-dklw3AXr3OGO4/XT1Tu3Xz9n/we8GctZZ75ZWVqAVhk=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "9abb87b552b7f55ac8916b6fc9e5cb486656a2f3", "rev": "6df24922a1400241dae323af55f30e4318a6ca65",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -180,9 +180,12 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
shellHook = shellHook =
'' ''
export PYTHONPATH=/home/abdul/nac3/pyo3:$PYTHONPATH
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
''; '';

View File

@ -1,5 +1,6 @@
from min_artiq import * from min_artiq import *
attr1: Kernel[str] = "ss"
@nac3 @nac3
class Demo: class Demo:
@ -14,6 +15,8 @@ class Demo:
@kernel @kernel
def run(self): def run(self):
global attr1
# attr1 = "2"
self.core.reset() self.core.reset()
while True: while True:
with parallel: with parallel:

View File

@ -0,0 +1,39 @@
class EmbeddingMap:
def __init__(self):
self.object_inverse_map = {}
self.object_map = {}
self.string_map = {}
self.string_reverse_map = {}
self.function_map = {}
self.attributes_writeback = []
def store_function(self, key, fun):
self.function_map[key] = fun
return key
def store_object(self, obj):
obj_id = id(obj)
if obj_id in self.object_inverse_map:
return self.object_inverse_map[obj_id]
key = len(self.object_map) + 1
self.object_map[key] = obj
self.object_inverse_map[obj_id] = key
return key
def store_str(self, s):
if s in self.string_reverse_map:
return self.string_reverse_map[s]
key = len(self.string_map)
self.string_map[key] = s
self.string_reverse_map[s] = key
return key
def retrieve_function(self, key):
return self.function_map[key]
def retrieve_object(self, key):
return self.object_map[key]
def retrieve_str(self, key):
return self.string_map[key]

View File

@ -6,6 +6,7 @@ from typing import Generic, TypeVar
from math import floor, ceil from math import floor, ceil
import nac3artiq import nac3artiq
from embedding_map import EmbeddingMap
__all__ = [ __all__ = [
@ -192,46 +193,6 @@ def print_int64(x: int64):
raise NotImplementedError("syscall not simulated") raise NotImplementedError("syscall not simulated")
class EmbeddingMap:
def __init__(self):
self.object_inverse_map = {}
self.object_map = {}
self.string_map = {}
self.string_reverse_map = {}
self.function_map = {}
self.attributes_writeback = []
def store_function(self, key, fun):
self.function_map[key] = fun
return key
def store_object(self, obj):
obj_id = id(obj)
if obj_id in self.object_inverse_map:
return self.object_inverse_map[obj_id]
key = len(self.object_map) + 1
self.object_map[key] = obj
self.object_inverse_map[obj_id] = key
return key
def store_str(self, s):
if s in self.string_reverse_map:
return self.string_reverse_map[s]
key = len(self.string_map)
self.string_map[key] = s
self.string_reverse_map[s] = key
return key
def retrieve_function(self, key):
return self.function_map[key]
def retrieve_object(self, key):
return self.object_map[key]
def retrieve_str(self, key):
return self.string_map[key]
@nac3 @nac3
class Core: class Core:
ref_period: KernelInvariant[float] ref_period: KernelInvariant[float]

View File

@ -1,26 +0,0 @@
from min_artiq import *
from numpy import int32
# Global Variable Definition
X: Kernel[int32] = 1
# TopLevelFunction Defintion
@kernel
def display_X():
print_int32(X)
# TopLevel Class Definition
@nac3
class A:
@kernel
def __init__(self):
self.set_x(1)
@kernel
def set_x(self, new_val: int32):
global X
X = new_val
@kernel
def get_X(self) -> int32:
return X

View File

@ -1,26 +0,0 @@
from min_artiq import *
import module as module_definition
@nac3
class TestModuleSupport:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def run(self):
# Accessing classes
obj = module_definition.A()
obj.get_X()
obj.set_x(2)
# Calling functions
module_definition.display_X()
# Updating global variables
module_definition.X = 9
module_definition.display_X()
if __name__ == "__main__":
TestModuleSupport().run()

View File

@ -1 +0,0 @@
../../target/release/libnac3artiq.so

View File

@ -0,0 +1,24 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
core: KernelInvariant[Core]
attr1: KernelInvariant[str]
attr2: KernelInvariant[int32]
def __init__(self):
self.core = Core()
self.attr2 = 32
self.attr1 = "SAMPLE"
@kernel
def run(self):
print_int32(self.attr2)
self.attr1
if __name__ == "__main__":
Demo().run()

View File

@ -0,0 +1,40 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
attr2: int32 = 4
attr3: Kernel[int32]
@kernel
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
attr4: KernelInvariant[int32] = 16
def __init__(self):
self.core = Core()
@kernel
def run(self):
Demo.attr1 # Supported
# Demo.attr2 # Field not accessible on Kernel
# Demo.attr3 # Only attributes can be accessed in this way
# Demo.attr1 = 2 # Attributes are immutable
self.attr4 # Attributes can be accessed within class
obj = Demo()
obj.attr1 # Attributes can be accessed by class objects
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

View File

@ -8,8 +8,7 @@ use std::{
use itertools::Itertools; use itertools::Itertools;
use pyo3::{ use pyo3::{
types::{PyDict, PyList}, types::{PyDict, PyList}, PyObject, PyResult, Python
PyObject, PyResult, Python,
}; };
use super::{symbol_resolver::InnerResolver, timeline::TimeFns}; use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
@ -29,7 +28,6 @@ use nac3core::{
inkwell::{ inkwell::{
context::Context, context::Context,
module::Linkage, module::Linkage,
targets::TargetMachine,
types::{BasicType, IntType}, types::{BasicType, IntType},
values::{BasicValueEnum, IntValue, PointerValue, StructValue}, values::{BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
@ -39,7 +37,7 @@ use nac3core::{
toplevel::{ toplevel::{
helper::{extract_ndims, PrimDef}, helper::{extract_ndims, PrimDef},
numpy::unpack_ndarray_var_tys, numpy::unpack_ndarray_var_tys,
DefinitionId, GenCall, DefinitionId, GenCall, TopLevelDef,
}, },
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
}; };
@ -88,13 +86,13 @@ pub struct ArtiqCodeGenerator<'a> {
impl<'a> ArtiqCodeGenerator<'a> { impl<'a> ArtiqCodeGenerator<'a> {
pub fn new( pub fn new(
name: String, name: String,
size_t: IntType<'_>, size_t: u32,
timeline: &'a (dyn TimeFns + Sync), timeline: &'a (dyn TimeFns + Sync),
) -> ArtiqCodeGenerator<'a> { ) -> ArtiqCodeGenerator<'a> {
assert!(matches!(size_t.get_bit_width(), 32 | 64)); assert!(size_t == 32 || size_t == 64);
ArtiqCodeGenerator { ArtiqCodeGenerator {
name, name,
size_t: size_t.get_bit_width(), size_t,
name_counter: 0, name_counter: 0,
start: None, start: None,
end: None, end: None,
@ -103,17 +101,6 @@ impl<'a> ArtiqCodeGenerator<'a> {
} }
} }
#[must_use]
pub fn with_target_machine(
name: String,
ctx: &Context,
target_machine: &TargetMachine,
timeline: &'a (dyn TimeFns + Sync),
) -> ArtiqCodeGenerator<'a> {
let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None);
Self::new(name, llvm_usize, timeline)
}
/// If the generator is currently in a direct-`parallel` block context, emits IR that resets the /// If the generator is currently in a direct-`parallel` block context, emits IR that resets the
/// position of the timeline to the initial timeline position before entering the `parallel` /// position of the timeline to the initial timeline position before entering the `parallel`
/// block. /// block.
@ -471,13 +458,13 @@ fn format_rpc_arg<'ctx>(
// libproto_artiq: NDArray = [data[..], dim_sz[..]] // libproto_artiq: NDArray = [data[..], dim_sz[..]]
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let ndims = extract_ndims(&ctx.unifier, ndims); let ndims = extract_ndims(&ctx.unifier, ndims);
let dtype = ctx.get_llvm_type(generator, elem_ty); let dtype = ctx.get_llvm_type(generator, elem_ty);
let ndarray = let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, ndims)
NDArrayType::new(ctx, dtype, ndims).map_value(arg.into_pointer_value(), None); .map_value(arg.into_pointer_value(), None);
let ndims = llvm_usize.const_int(ndims, false); let ndims = llvm_usize.const_int(ndims, false);
@ -556,7 +543,7 @@ fn format_rpc_ret<'ctx>(
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false); let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| { let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
@ -609,7 +596,7 @@ fn format_rpc_ret<'ctx>(
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let dtype_llvm = ctx.get_llvm_type(generator, dtype); let dtype_llvm = ctx.get_llvm_type(generator, dtype);
let ndims = extract_ndims(&ctx.unifier, ndims); let ndims = extract_ndims(&ctx.unifier, ndims);
let ndarray = NDArrayType::new(ctx, dtype_llvm, ndims) let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, ndims)
.construct_uninitialized(generator, ctx, None); .construct_uninitialized(generator, ctx, None);
// NOTE: Current content of `ndarray`: // NOTE: Current content of `ndarray`:
@ -697,7 +684,7 @@ fn format_rpc_ret<'ctx>(
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes) // debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let num_elements = ndarray.size(ctx); let num_elements = ndarray.size(generator, ctx);
let expected_ndarray_nbytes = let expected_ndarray_nbytes =
ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap(); ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap();
@ -809,7 +796,7 @@ fn rpc_codegen_callback_fn<'ctx>(
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let int8 = ctx.ctx.i8_type(); let int8 = ctx.ctx.i8_type();
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_type = ctx.get_size_type(); let size_type = generator.get_size_type(ctx.ctx);
let ptr_type = int8.ptr_type(AddressSpace::default()); let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
@ -994,7 +981,8 @@ pub fn attributes_writeback<'ctx>(
values.push((ty, obj.to_basic_value_enum(ctx, generator, ty).unwrap())); values.push((ty, obj.to_basic_value_enum(ctx, generator, ty).unwrap()));
} }
for val in (*globals).values() { // For now the global variables are just values so completely useless (14.0 stored as a float object in globals)
for (gloabl_id, val) in &*globals {
let val = val.as_ref(py); let val = val.as_ref(py);
let ty = inner_resolver.get_obj_type( let ty = inner_resolver.get_obj_type(
py, py,
@ -1007,6 +995,33 @@ pub fn attributes_writeback<'ctx>(
return Ok(Err(ty)); return Ok(Err(ty));
} }
let ty = ty.unwrap(); let ty = ty.unwrap();
if let Some(def_id) = inner_resolver.pyid_to_def.read().get(gloabl_id) {
if let TopLevelDef::Variable { name, simple_name, .. } = &*top_levels[def_id.0].read() {
// println!("[+] Varaible with type: {:?}\n{:?}\n", ctx.unifier.stringify(*ty), ctx.unifier.get_ty(*ty));
println!("Sending Value of {:?}", val.to_string());
if gen_rpc_tag(ctx, ty, &mut scratch_buffer).is_ok() {
// let Some(val) = ctx.module.get_global(simple_name.to_string().as_str()) else {continue;};
// let val = val.as_pointer_value();
let pydict = PyDict::new(py);
pydict.set_item("global", val)?;
pydict.set_item("name", name)?;
host_attributes.append(pydict)?;
values.push((
ty,
// ctx.build_gep_and_load(
// val,
// &[zero, int32.const_int(0, false)],
// None,
// ),
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
));
}
continue;
}
}
match &*ctx.unifier.get_ty(ty) { match &*ctx.unifier.get_ty(ty) {
TypeEnum::TObj { fields, obj_id, .. } TypeEnum::TObj { fields, obj_id, .. }
if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
@ -1015,6 +1030,15 @@ pub fn attributes_writeback<'ctx>(
// for non-primitive attributes, they should be in another global // for non-primitive attributes, they should be in another global
let mut attributes = Vec::new(); let mut attributes = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
if !obj.is_pointer_value() && gen_rpc_tag(ctx, ty, &mut scratch_buffer).is_ok() {
println!("[-] Other function skipped");
// values.push((ty, obj));
// let pydict = PyDict::new(py);
// pydict.set_item("global", val)?;
// host_attributes.append(pydict)?;
// continue;
}
for (name, (field_ty, is_mutable)) in fields { for (name, (field_ty, is_mutable)) in fields {
if !is_mutable { if !is_mutable {
continue; continue;
@ -1052,34 +1076,6 @@ pub fn attributes_writeback<'ctx>(
)); ));
} }
} }
TypeEnum::TModule { attributes, .. } => {
let mut fields = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
for (name, (field_ty, is_method)) in attributes {
if *is_method {
continue;
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
fields.push(name.to_string());
let (index, _) = ctx.get_attr_index(ty, *name);
values.push((
*field_ty,
ctx.build_gep_and_load(
obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
None,
),
));
}
}
if !fields.is_empty() {
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
pydict.set_item("fields", fields)?;
host_attributes.append(pydict)?;
}
}
_ => {} _ => {}
} }
} }
@ -1195,7 +1191,7 @@ fn polymorphic_print<'ctx>(
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_i64 = ctx.ctx.i64_type(); let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let suffix = suffix.unwrap_or_default(); let suffix = suffix.unwrap_or_default();

View File

@ -43,7 +43,7 @@ use nac3core::{
OptimizationLevel, OptimizationLevel,
}, },
nac3parser::{ nac3parser::{
ast::{self, Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
parser::parse_program, parser::parse_program,
}, },
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
@ -78,62 +78,14 @@ enum Isa {
} }
impl Isa { impl Isa {
/// Returns the [`TargetTriple`] used for compiling to this ISA. /// Returns the number of bits in `size_t` for the [`Isa`].
pub fn get_llvm_target_triple(self) -> TargetTriple { fn get_size_type(self) -> u32 {
match self { if self == Isa::Host {
Isa::Host => TargetMachine::get_default_triple(), 64u32
Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"), } else {
Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"), 32u32
} }
} }
/// Returns the [`String`] representing the target CPU used for compiling to this ISA.
pub fn get_llvm_target_cpu(self) -> String {
match self {
Isa::Host => TargetMachine::get_host_cpu_name().to_string(),
Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(),
Isa::CortexA9 => "cortex-a9".to_string(),
}
}
/// Returns the [`String`] representing the target features used for compiling to this ISA.
pub fn get_llvm_target_features(self) -> String {
match self {
Isa::Host => TargetMachine::get_host_cpu_features().to_string(),
Isa::RiscV32G => "+a,+m,+f,+d".to_string(),
Isa::RiscV32IMA => "+a,+m".to_string(),
Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(),
}
}
/// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine
/// options used for compiling to this ISA.
pub fn get_llvm_target_options(self) -> CodeGenTargetMachineOptions {
CodeGenTargetMachineOptions {
triple: self.get_llvm_target_triple().as_str().to_string_lossy().into_owned(),
cpu: self.get_llvm_target_cpu(),
features: self.get_llvm_target_features(),
reloc_mode: RelocMode::PIC,
..CodeGenTargetMachineOptions::from_host()
}
}
/// Returns an instance of [`TargetMachine`] used in compiling and linking of a program of this
/// ISA.
pub fn create_llvm_target_machine(self, opt_level: OptimizationLevel) -> TargetMachine {
self.get_llvm_target_options()
.create_target_machine(opt_level)
.expect("couldn't create target machine")
}
/// Returns the number of bits in `size_t` for this ISA.
fn get_size_type(self, ctx: &Context) -> u32 {
ctx.ptr_sized_int_type(
&self.create_llvm_target_machine(OptimizationLevel::Default).get_target_data(),
None,
)
.get_bit_width()
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -159,7 +111,6 @@ pub struct PrimitivePythonId {
generic_alias: (u64, u64), generic_alias: (u64, u64),
virtual_id: u64, virtual_id: u64,
option: u64, option: u64,
module: u64,
} }
type TopLevelComponent = (Stmt, String, PyObject); type TopLevelComponent = (Stmt, String, PyObject);
@ -277,10 +228,17 @@ impl Nac3 {
} }
}) })
} }
// Allow global variable declaration with `Kernel` type annotation // Allow global declaration with Kernel[X] to pass across
// These are kept in sync between kernel and host
// KernelInvariants are constants, and hence not included here
StmtKind::AnnAssign { ref annotation, .. } => { StmtKind::AnnAssign { ref annotation, .. } => {
matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into())) match &annotation.node {
ExprKind::Subscript { value, .. }
if matches!(&value.node, ExprKind::Name { id, .. }
if id.to_string().as_str() == "Kernel") => true,
_ => false
} }
},
_ => false, _ => false,
}; };
@ -431,7 +389,7 @@ impl Nac3 {
py: Python, py: Python,
link_fn: &dyn Fn(&Module) -> PyResult<T>, link_fn: &dyn Fn(&Module) -> PyResult<T>,
) -> PyResult<T> { ) -> PyResult<T> {
let size_t = self.isa.get_size_type(&Context::create()); let size_t = self.isa.get_size_type();
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
self.builtins.clone(), self.builtins.clone(),
Self::get_lateinit_builtins(), Self::get_lateinit_builtins(),
@ -474,14 +432,12 @@ impl Nac3 {
]; ];
add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names); add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names);
// Stores a mapping from module id to attributes
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new(); let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
let mut rpc_ids = vec![]; let mut rpc_ids = vec![];
for (stmt, path, module) in &self.top_levels { for (stmt, path, module) in &self.top_levels {
let py_module: &PyAny = module.extract(py)?; let py_module: &PyAny = module.extract(py)?;
let module_id: u64 = id_fn.call1((py_module,))?.extract()?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
let module_name: String = py_module.getattr("__name__")?.extract()?;
let helper = helper.clone(); let helper = helper.clone();
let class_obj; let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node { if let StmtKind::ClassDef { name, .. } = &stmt.node {
@ -496,7 +452,7 @@ impl Nac3 {
} else { } else {
class_obj = None; class_obj = None;
} }
let (name_to_pyid, resolver, _, _) = let (name_to_pyid, resolver) =
module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| { module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new(); let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let members: &PyDict = let members: &PyDict =
@ -525,17 +481,9 @@ impl Nac3 {
}))) })))
as Arc<dyn SymbolResolver + Send + Sync>; as Arc<dyn SymbolResolver + Send + Sync>;
let name_to_pyid = Rc::new(name_to_pyid); let name_to_pyid = Rc::new(name_to_pyid);
let module_location = ast::Location::new(1, 1, stmt.location.file); module_to_resolver_cache
module_to_resolver_cache.insert( .insert(module_id, (name_to_pyid.clone(), resolver.clone()));
module_id, (name_to_pyid, resolver)
(
name_to_pyid.clone(),
resolver.clone(),
module_name.clone(),
Some(module_location),
),
);
(name_to_pyid, resolver, module_name, Some(module_location))
}); });
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
@ -607,24 +555,15 @@ impl Nac3 {
pyid_to_ty.insert(id, ty); pyid_to_ty.insert(id, ty);
} }
} }
if let StmtKind::AnnAssign { target, .. } = &stmt.node {
let ExprKind::Name { id: name, .. } = target.node else { unreachable!() };
global_value_ids
.write()
// .insert(id, py_module.getattr(name.to_string().as_str()).unwrap().into());
.insert(id, module.as_ref(py).into());
} }
// Adding top level module definitions
for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in
module_to_resolver_cache
{
let def_id = composer
.register_top_level_module(
&module_name,
&module_name_to_pyid,
module_resolver,
module_location,
)
.map_err(|e| {
CompileError::new_err(format!("compilation failed\n----------\n{e}"))
})?;
self.pyid_to_def.write().insert(module_id, def_id);
} }
let id_fun = PyModule::import(py, "builtins")?.getattr("id")?; let id_fun = PyModule::import(py, "builtins")?.getattr("id")?;
@ -746,9 +685,6 @@ impl Nac3 {
"Unsupported @rpc annotation on global variable", "Unsupported @rpc annotation on global variable",
))) )))
} }
TopLevelDef::Module { .. } => {
unreachable!("Type module cannot be decorated with @rpc")
}
} }
} }
} }
@ -787,18 +723,14 @@ impl Nac3 {
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let size_t = context
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let num_threads = if is_multithreaded() { 4 } else { 1 }; let num_threads = if is_multithreaded() { 4 } else { 1 };
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect(); let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names let threads: Vec<_> = thread_names
.iter() .iter()
.map(|s| { .map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns)))
Box::new(ArtiqCodeGenerator::with_target_machine(
s.to_string(),
&context,
&self.get_llvm_target_machine(),
self.time_fns,
))
})
.collect(); .collect();
let membuffer = membuffers.clone(); let membuffer = membuffers.clone();
@ -807,13 +739,8 @@ impl Nac3 {
let (registry, handles) = let (registry, handles) =
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f); WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns);
let context = Context::create(); let context = Context::create();
let mut generator = ArtiqCodeGenerator::with_target_machine(
"main".to_string(),
&context,
&self.get_llvm_target_machine(),
self.time_fns,
);
let module = context.create_module("main"); let module = context.create_module("main");
let target_machine = self.llvm_options.create_target_machine().unwrap(); let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout()); module.set_data_layout(&target_machine.get_target_data().get_data_layout());
@ -932,10 +859,52 @@ impl Nac3 {
link_fn(&main) link_fn(&main)
} }
/// Returns the [`TargetTriple`] used for compiling to [isa].
fn get_llvm_target_triple(isa: Isa) -> TargetTriple {
match isa {
Isa::Host => TargetMachine::get_default_triple(),
Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"),
Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"),
}
}
/// Returns the [`String`] representing the target CPU used for compiling to [isa].
fn get_llvm_target_cpu(isa: Isa) -> String {
match isa {
Isa::Host => TargetMachine::get_host_cpu_name().to_string(),
Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(),
Isa::CortexA9 => "cortex-a9".to_string(),
}
}
/// Returns the [`String`] representing the target features used for compiling to [isa].
fn get_llvm_target_features(isa: Isa) -> String {
match isa {
Isa::Host => TargetMachine::get_host_cpu_features().to_string(),
Isa::RiscV32G => "+a,+m,+f,+d".to_string(),
Isa::RiscV32IMA => "+a,+m".to_string(),
Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(),
}
}
/// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine
/// options used for compiling to [isa].
fn get_llvm_target_options(isa: Isa) -> CodeGenTargetMachineOptions {
CodeGenTargetMachineOptions {
triple: Nac3::get_llvm_target_triple(isa).as_str().to_string_lossy().into_owned(),
cpu: Nac3::get_llvm_target_cpu(isa),
features: Nac3::get_llvm_target_features(isa),
reloc_mode: RelocMode::PIC,
..CodeGenTargetMachineOptions::from_host()
}
}
/// Returns an instance of [`TargetMachine`] used in compiling and linking of a program to the /// Returns an instance of [`TargetMachine`] used in compiling and linking of a program to the
/// target [ISA][isa]. /// target [isa].
fn get_llvm_target_machine(&self) -> TargetMachine { fn get_llvm_target_machine(&self) -> TargetMachine {
self.isa.create_llvm_target_machine(self.llvm_options.opt_level) Nac3::get_llvm_target_options(self.isa)
.create_target_machine(self.llvm_options.opt_level)
.expect("couldn't create target machine")
} }
} }
@ -1043,8 +1012,7 @@ impl Nac3 {
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
}; };
let (primitive, _) = let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type());
TopLevelComposer::make_primitives(isa.get_size_type(&Context::create()));
let builtins = vec![ let builtins = vec![
( (
"now_mu".into(), "now_mu".into(),
@ -1132,7 +1100,6 @@ impl Nac3 {
tuple: get_attr_id(builtins_mod, "tuple"), tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"), exception: get_attr_id(builtins_mod, "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
module: get_attr_id(types_mod, "ModuleType"),
}; };
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
@ -1194,7 +1161,7 @@ impl Nac3 {
deferred_eval_store: DeferredEvaluationStore::new(), deferred_eval_store: DeferredEvaluationStore::new(),
llvm_options: CodeGenLLVMOptions { llvm_options: CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: isa.get_llvm_target_options(), target: Nac3::get_llvm_target_options(isa),
}, },
}) })
} }

View File

@ -674,48 +674,6 @@ impl InnerResolver {
}) })
}); });
// check if obj is module
if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::<u64>(py)?
== self.primitive_ids.module
&& self.pyid_to_def.read().contains_key(&py_obj_id)
{
let def_id = self.pyid_to_def.read()[&py_obj_id];
let def = defs[def_id.0].read();
let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } =
&*def
else {
unreachable!("must be a module here");
};
// Construct the module return type
let mut module_attributes = HashMap::new();
for (name, _) in attributes {
let attribute_obj = obj.getattr(name.to_string().as_str())?;
let attribute_ty =
self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?;
if let Ok(attribute_ty) = attribute_ty {
module_attributes.insert(*name, (attribute_ty, false));
} else {
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
}
}
for name in methods.keys() {
let method_obj = obj.getattr(name.to_string().as_str())?;
let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?;
if let Ok(method_ty) = method_ty {
module_attributes.insert(*name, (method_ty, true));
} else {
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
}
}
let module_ty =
TypeEnum::TModule { module_id: *module_id, attributes: module_attributes };
let ty = unifier.add_ty(module_ty);
return Ok(Ok(ty));
}
if let Some(ty) = constructor_ty { if let Some(ty) = constructor_ty {
self.pyid_to_type.write().insert(py_obj_id, ty); self.pyid_to_type.write().insert(py_obj_id, ty);
return Ok(Ok(ty)); return Ok(Ok(ty));
@ -1049,7 +1007,7 @@ impl InnerResolver {
} }
_ => unreachable!("must be list"), _ => unreachable!("must be list"),
}; };
let size_t = ctx.get_size_type(); let size_t = generator.get_size_type(ctx.ctx);
let ty = if len == 0 let ty = if len == 0
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. }) && matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
{ {
@ -1138,7 +1096,7 @@ impl InnerResolver {
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty); let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty);
let dtype = llvm_ndarray.element_type(); let dtype = llvm_ndarray.element_type();
@ -1415,77 +1373,6 @@ impl InnerResolver {
None => Ok(None), None => Ok(None),
} }
} }
} else if ty_id == self.primitive_ids.module {
let id_str = id.to_string();
if let Some(global) = ctx.module.get_global(&id_str) {
return Ok(Some(global.as_pointer_value().into()));
}
let top_level_defs = ctx.top_level.definitions.read();
let ty = self
.get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)?
.unwrap();
let ty = ctx
.get_llvm_type(generator, ty)
.into_pointer_type()
.get_element_type()
.into_struct_type();
{
if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
}
let fields = {
let definition =
top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read();
let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() };
attributes
.iter()
.filter_map(|f| {
let definition = top_level_defs.get(f.1 .0).unwrap().read();
if let TopLevelDef::Variable { ty, .. } = &*definition {
Some((f.0, *ty))
} else {
None
}
})
.collect_vec()
};
let values: Result<Option<Vec<_>>, _> = fields
.iter()
.map(|(name, ty)| {
self.get_obj_value(
py,
obj.getattr(name.to_string().as_str())?,
ctx,
generator,
*ty,
)
.map_err(|e| {
super::CompileError::new_err(format!("Error getting field {name}: {e}"))
})
})
.collect();
let values = values?;
if let Some(values) = values {
let val = ty.const_named_struct(&values);
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
});
global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into()))
} else {
Ok(None)
}
} else { } else {
let id_str = id.to_string(); let id_str = id.to_string();
@ -1671,45 +1558,35 @@ impl SymbolResolver for Resolver {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
if let Some(def_id) = self.0.id_to_def.read().get(&id) { if let Some(global_value) = self
let top_levels = ctx.top_level.definitions.read(); .0
if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) { .name_to_pyid
let module_val = &self.0.module; .get(&id)
let ret = Python::with_gil(|py| -> PyResult<Result<BasicValueEnum, String>> { .and_then(|pyid| self.0.global_value_ids.read().get(pyid).cloned())
let module_val = module_val.as_ref(py); {
let val = ctx.module.get_global(id.to_string().as_str()).unwrap_or_else(|| {
let ty = self.0.get_obj_type( let v = Python::with_gil(|py| -> PyResult<SymbolValue> {
py, Ok(self
module_val, .0
&mut ctx.unifier, .get_default_param_obj_value(py, global_value.as_ref(py))
&top_levels, .unwrap()
&ctx.primitives, .unwrap())
)?;
if let Err(ty) = ty {
return Ok(Err(ty));
}
let ty = ty.unwrap();
let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap();
let (idx, _) = ctx.get_attr_index(ty, id);
let ret = unsafe {
ctx.builder.build_gep(
obj.into_pointer_value(),
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(idx as u64, false),
],
id.to_string().as_str(),
)
}
.unwrap();
Ok(Ok(ret.as_basic_value_enum()))
}) })
.unwrap(); .unwrap();
if ret.is_err() {
return None; let ty = v.get_type(&ctx.primitives, &mut ctx.unifier);
}
return Some(ret.unwrap().into()); let init_val = ctx.gen_symbol_val(generator, &v, ty);
} let llvm_ty = init_val.get_type();
println!("Adding {id}");
let global = ctx.module.add_global(llvm_ty, None, &id.to_string());
global.set_linkage(Linkage::LinkOnceAny);
global.set_initializer(&init_val);
global
});
return Some(val.as_basic_value_enum().into());
} }
let sym_value = { let sym_value = {

View File

@ -64,7 +64,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty)
.map_value(arg.into_pointer_value(), None); .map_value(arg.into_pointer_value(), None);
ctx.builder ctx.builder
.build_int_truncate_or_bit_cast(ndarray.len(ctx), llvm_i32, "len") .build_int_truncate_or_bit_cast(ndarray.len(generator, ctx), llvm_i32, "len")
.unwrap() .unwrap()
} }
@ -752,8 +752,12 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype));
let llvm_common_dtype = x1.get_type().element_type(); let llvm_common_dtype = x1.get_type().element_type();
let result = let result = NDArrayType::new_broadcast(
NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) generator,
ctx.ctx,
llvm_common_dtype,
&[x1.get_type(), x2.get_type()],
)
.broadcast_starmap( .broadcast_starmap(
generator, generator,
ctx, ctx,
@ -831,7 +835,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name));
let llvm_int64 = ctx.ctx.i64_type(); let llvm_int64 = ctx.ctx.i64_type();
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
Ok(match a { Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
@ -866,7 +870,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let size_nez = ctx let size_nez = ctx
.builder .builder
.build_int_compare(IntPredicate::NE, ndarray.size(ctx), zero, "") .build_int_compare(IntPredicate::NE, ndarray.size(generator, ctx), zero, "")
.unwrap(); .unwrap();
ctx.make_assert( ctx.make_assert(
@ -1011,8 +1015,12 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype)); debug_assert!(ctx.unifier.unioned(x1_dtype, x2_dtype));
let llvm_common_dtype = x1.get_type().element_type(); let llvm_common_dtype = x1.get_type().element_type();
let result = let result = NDArrayType::new_broadcast(
NDArrayType::new_broadcast(ctx, llvm_common_dtype, &[x1.get_type(), x2.get_type()]) generator,
ctx.ctx,
llvm_common_dtype,
&[x1.get_type(), x2.get_type()],
)
.broadcast_starmap( .broadcast_starmap(
generator, generator,
ctx, ctx,
@ -1644,7 +1652,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
} }
let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2)
.construct_uninitialized(generator, ctx, None); .construct_uninitialized(generator, ctx, None);
out.copy_shape_from_ndarray(generator, ctx, x1); out.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { out.create_data(generator, ctx) }; unsafe { out.create_data(generator, ctx) };
@ -1668,7 +1676,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr"; const FN_NAME: &str = "np_linalg_qr";
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
@ -1686,7 +1694,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
}; };
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
unsafe { q.create_data(generator, ctx) }; unsafe { q.create_data(generator, ctx) };
@ -1707,11 +1715,8 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
let q = q.as_base_value().as_basic_value_enum(); let q = q.as_base_value().as_basic_value_enum();
let r = r.as_base_value().as_basic_value_enum(); let r = r.as_base_value().as_basic_value_enum();
let tuple = TupleType::new(ctx, &[q.get_type(), r.get_type()]).construct_from_objects( let tuple = TupleType::new(generator, ctx.ctx, &[q.get_type(), r.get_type()])
ctx, .construct_from_objects(ctx, [q, r], None);
[q, r],
None,
);
Ok(tuple.as_base_value().into()) Ok(tuple.as_base_value().into())
} }
@ -1723,7 +1728,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd"; const FN_NAME: &str = "np_linalg_svd";
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
@ -1741,8 +1746,8 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
}; };
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
let out_ndarray1_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1); let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1);
let out_ndarray2_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None); let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None);
unsafe { u.create_data(generator, ctx) }; unsafe { u.create_data(generator, ctx) };
@ -1770,7 +1775,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
let u = u.as_base_value().as_basic_value_enum(); let u = u.as_base_value().as_basic_value_enum();
let s = s.as_base_value().as_basic_value_enum(); let s = s.as_base_value().as_basic_value_enum();
let vh = vh.as_base_value().as_basic_value_enum(); let vh = vh.as_base_value().as_basic_value_enum();
let tuple = TupleType::new(ctx, &[u.get_type(), s.get_type(), vh.get_type()]) let tuple = TupleType::new(generator, ctx.ctx, &[u.get_type(), s.get_type(), vh.get_type()])
.construct_from_objects(ctx, [u, s, vh], None); .construct_from_objects(ctx, [u, s, vh], None);
Ok(tuple.as_base_value().into()) Ok(tuple.as_base_value().into())
} }
@ -1791,7 +1796,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
} }
let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2)
.construct_uninitialized(generator, ctx, None); .construct_uninitialized(generator, ctx, None);
out.copy_shape_from_ndarray(generator, ctx, x1); out.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { out.create_data(generator, ctx) }; unsafe { out.create_data(generator, ctx) };
@ -1816,7 +1821,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv"; const FN_NAME: &str = "np_linalg_pinv";
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
@ -1833,12 +1838,8 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
}; };
let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2).construct_dyn_shape( let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2)
generator, .construct_dyn_shape(generator, ctx, &[d0, d1], None);
ctx,
&[d0, d1],
None,
);
unsafe { out.create_data(generator, ctx) }; unsafe { out.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx); let x1_c = x1.make_contiguous_ndarray(generator, ctx);
@ -1861,7 +1862,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu"; const FN_NAME: &str = "sp_linalg_lu";
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
@ -1879,7 +1880,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
}; };
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
unsafe { l.create_data(generator, ctx) }; unsafe { l.create_data(generator, ctx) };
@ -1900,11 +1901,8 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
let l = l.as_base_value().as_basic_value_enum(); let l = l.as_base_value().as_basic_value_enum();
let u = u.as_base_value().as_basic_value_enum(); let u = u.as_base_value().as_basic_value_enum();
let tuple = TupleType::new(ctx, &[l.get_type(), u.get_type()]).construct_from_objects( let tuple = TupleType::new(generator, ctx.ctx, &[l.get_type(), u.get_type()])
ctx, .construct_from_objects(ctx, [l, u], None);
[l, u],
None,
);
Ok(tuple.as_base_value().into()) Ok(tuple.as_base_value().into())
} }
@ -1917,7 +1915,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power"; const FN_NAME: &str = "np_linalg_matrix_power";
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let BasicValueEnum::PointerValue(x1) = x1 else { let BasicValueEnum::PointerValue(x1) = x1 else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -1938,11 +1936,11 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}; };
let x2 = NDArrayType::new_unsized(ctx, ctx.ctx.f64_type().into()) let x2 = NDArrayType::new_unsized(generator, ctx.ctx, ctx.ctx.f64_type().into())
.construct_unsized(generator, ctx, &x2, None); // x2.shape == [] .construct_unsized(generator, ctx, &x2, None); // x2.shape == []
let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1] let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1]
let out = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2) let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2)
.construct_uninitialized(generator, ctx, None); .construct_uninitialized(generator, ctx, None);
out.copy_shape_from_ndarray(generator, ctx, x1); out.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { out.create_data(generator, ctx) }; unsafe { out.create_data(generator, ctx) };
@ -1970,7 +1968,7 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power"; const FN_NAME: &str = "np_linalg_matrix_power";
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
@ -1981,12 +1979,8 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
} }
// The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call.
let det = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 1).construct_const_shape( let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 1)
generator, .construct_const_shape(generator, ctx, &[1], None);
ctx,
&[1],
None,
);
unsafe { det.create_data(generator, ctx) }; unsafe { det.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx); let x1_c = x1.make_contiguous_ndarray(generator, ctx);
@ -2020,7 +2014,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
} }
let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None); let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None);
t.copy_shape_from_ndarray(generator, ctx, x1); t.copy_shape_from_ndarray(generator, ctx, x1);
@ -2043,11 +2037,8 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
let t = t.as_base_value().as_basic_value_enum(); let t = t.as_base_value().as_basic_value_enum();
let z = z.as_base_value().as_basic_value_enum(); let z = z.as_base_value().as_basic_value_enum();
let tuple = TupleType::new(ctx, &[t.get_type(), z.get_type()]).construct_from_objects( let tuple = TupleType::new(generator, ctx.ctx, &[t.get_type(), z.get_type()])
ctx, .construct_from_objects(ctx, [t, z], None);
[t, z],
None,
);
Ok(tuple.as_base_value().into()) Ok(tuple.as_base_value().into())
} }
@ -2068,7 +2059,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
unsupported_type(ctx, FN_NAME, &[x1_ty]); unsupported_type(ctx, FN_NAME, &[x1_ty]);
} }
let out_ndarray_ty = NDArrayType::new(ctx, ctx.ctx.f64_type().into(), 2); let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), 2);
let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None); let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None);
h.copy_shape_from_ndarray(generator, ctx, x1); h.copy_shape_from_ndarray(generator, ctx, x1);
@ -2091,10 +2082,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
let h = h.as_base_value().as_basic_value_enum(); let h = h.as_base_value().as_basic_value_enum();
let q = q.as_base_value().as_basic_value_enum(); let q = q.as_base_value().as_basic_value_enum();
let tuple = TupleType::new(ctx, &[h.get_type(), q.get_type()]).construct_from_objects( let tuple = TupleType::new(generator, ctx.ctx, &[h.get_type(), q.get_type()])
ctx, .construct_from_objects(ctx, [h, q], None);
[h, q],
None,
);
Ok(tuple.as_base_value().into()) Ok(tuple.as_base_value().into())
} }

View File

@ -56,10 +56,6 @@ pub enum ConcreteTypeEnum {
fields: HashMap<StrRef, (ConcreteType, bool)>, fields: HashMap<StrRef, (ConcreteType, bool)>,
params: IndexMap<TypeVarId, ConcreteType>, params: IndexMap<TypeVarId, ConcreteType>,
}, },
TModule {
module_id: DefinitionId,
methods: HashMap<StrRef, (ConcreteType, bool)>,
},
TVirtual { TVirtual {
ty: ConcreteType, ty: ConcreteType,
}, },
@ -209,19 +205,6 @@ impl ConcreteTypeStore {
}) })
.collect(), .collect(),
}, },
TypeEnum::TModule { module_id, attributes } => ConcreteTypeEnum::TModule {
module_id: *module_id,
methods: attributes
.iter()
.filter_map(|(name, ty)| match &*unifier.get_ty(ty.0) {
TypeEnum::TFunc(..) | TypeEnum::TObj { .. } => None,
_ => Some((
*name,
(self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1),
)),
})
.collect(),
},
TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual { TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual {
ty: self.from_unifier_type(unifier, primitives, *ty, cache), ty: self.from_unifier_type(unifier, primitives, *ty, cache),
}, },
@ -301,15 +284,6 @@ impl ConcreteTypeStore {
TypeVar { id, ty } TypeVar { id, ty }
})), })),
}, },
ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule {
module_id: *module_id,
attributes: methods
.iter()
.map(|(name, cty)| {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
})
.collect::<HashMap<_, _>>(),
},
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args args: args
.iter() .iter()

View File

@ -61,13 +61,8 @@ pub fn get_subst_key(
) -> String { ) -> String {
let mut vars = obj let mut vars = obj
.map(|ty| { .map(|ty| {
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() };
params.clone() params.clone()
} else if let TypeEnum::TModule { .. } = &*unifier.get_ty(ty) {
indexmap::IndexMap::new()
} else {
unreachable!()
}
}) })
.unwrap_or_default(); .unwrap_or_default();
vars.extend(fun_vars); vars.extend(fun_vars);
@ -125,7 +120,6 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) {
let obj_id = match &*self.unifier.get_ty(ty) { let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id, TypeEnum::TObj { obj_id, .. } => *obj_id,
TypeEnum::TModule { module_id, .. } => *module_id,
// we cannot have other types, virtual type should be handled by function calls // we cannot have other types, virtual type should be handled by function calls
_ => codegen_unreachable!(self), _ => codegen_unreachable!(self),
}; };
@ -137,8 +131,6 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap(); let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap();
(attribute_index.0, Some(attribute_index.1 .2.clone())) (attribute_index.0, Some(attribute_index.1 .2.clone()))
} }
} else if let TopLevelDef::Module { attributes, .. } = &*def.read() {
(attributes.iter().find_position(|x| x.0 == attr).unwrap().0, None)
} else { } else {
codegen_unreachable!(self) codegen_unreachable!(self)
}; };
@ -173,7 +165,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
.build_global_string_ptr(v, "const") .build_global_string_ptr(v, "const")
.map(|v| v.as_pointer_value().into()) .map(|v| v.as_pointer_value().into())
.unwrap(); .unwrap();
let size = self.get_size_type().const_int(v.len() as u64, false); let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false);
let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type(); let ty = self.get_llvm_type(generator, self.primitives.str).into_struct_type();
ty.const_named_struct(&[str_ptr, size.into()]).into() ty.const_named_struct(&[str_ptr, size.into()]).into()
} }
@ -326,7 +318,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
.build_global_string_ptr(v, "const") .build_global_string_ptr(v, "const")
.map(|v| v.as_pointer_value().into()) .map(|v| v.as_pointer_value().into())
.unwrap(); .unwrap();
let size = self.get_size_type().const_int(v.len() as u64, false); let size = generator.get_size_type(self.ctx).const_int(v.len() as u64, false);
let ty = self.get_llvm_type(generator, self.primitives.str); let ty = self.get_llvm_type(generator, self.primitives.str);
let val = let val =
ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into(); ty.into_struct_type().const_named_struct(&[str_ptr, size.into()]).into();
@ -828,7 +820,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap(); let definition = ctx.top_level.definitions.read().get(fun.1 .0).cloned().unwrap();
let id; let id;
@ -987,7 +979,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
TopLevelDef::Class { .. } => { TopLevelDef::Class { .. } => {
return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?)) return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?))
} }
TopLevelDef::Variable { .. } | TopLevelDef::Module { .. } => unreachable!(), TopLevelDef::Variable { .. } => unreachable!(),
} }
} }
.or_else(|_: String| { .or_else(|_: String| {
@ -1028,7 +1020,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
} }
let is_vararg = args.iter().any(|arg| arg.is_vararg); let is_vararg = args.iter().any(|arg| arg.is_vararg);
if is_vararg { if is_vararg {
params.push(ctx.get_size_type().into()); params.push(generator.get_size_type(ctx.ctx).into());
} }
let fun_ty = match ret_type { let fun_ty = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, is_vararg), Some(ret_type) if !has_sret => ret_type.fn_type(&params, is_vararg),
@ -1136,7 +1128,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
return Ok(None); return Ok(None);
}; };
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_t = ctx.get_size_type(); let size_t = generator.get_size_type(ctx.ctx);
let zero_size_t = size_t.const_zero(); let zero_size_t = size_t.const_zero();
let zero_32 = int32.const_zero(); let zero_32 = int32.const_zero();
@ -1175,7 +1167,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
"listcomp.alloc_size", "listcomp.alloc_size",
) )
.unwrap(); .unwrap();
list = ListType::new(ctx, &elem_ty).construct( list = ListType::new(generator, ctx.ctx, elem_ty).construct(
generator, generator,
ctx, ctx,
list_alloc_size.into_int_value(), list_alloc_size.into_int_value(),
@ -1226,7 +1218,12 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
Some("length"), Some("length"),
) )
.into_int_value(); .into_int_value();
list = ListType::new(ctx, &elem_ty).construct(generator, ctx, length, Some("listcomp")); list = ListType::new(generator, ctx.ctx, elem_ty).construct(
generator,
ctx,
length,
Some("listcomp"),
);
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?;
// counter = -1 // counter = -1
@ -1261,10 +1258,12 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
} }
// Emits the content of `cont_bb` // Emits the content of `cont_bb`
let emit_cont_bb = |ctx: &CodeGenContext<'ctx, '_>, list: ListValue<'ctx>| { let emit_cont_bb =
|ctx: &CodeGenContext<'ctx, '_>, generator: &dyn CodeGenerator, list: ListValue<'ctx>| {
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
list.store_size( list.store_size(
ctx, ctx,
generator,
ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(), ctx.builder.build_load(index, "index").map(BasicValueEnum::into_int_value).unwrap(),
); );
}; };
@ -1275,7 +1274,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
} else { } else {
// Bail if the predicate is an ellipsis - Emit cont_bb contents in case the // Bail if the predicate is an ellipsis - Emit cont_bb contents in case the
// no element matches the predicate // no element matches the predicate
emit_cont_bb(ctx, list); emit_cont_bb(ctx, generator, list);
return Ok(None); return Ok(None);
}; };
@ -1288,7 +1287,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
let Some(elem) = generator.gen_expr(ctx, elt)? else { let Some(elem) = generator.gen_expr(ctx, elt)? else {
// Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents // Similarly, bail if the generator expression is an ellipsis, but keep cont_bb contents
emit_cont_bb(ctx, list); emit_cont_bb(ctx, generator, list);
return Ok(None); return Ok(None);
}; };
@ -1305,7 +1304,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
.unwrap(); .unwrap();
ctx.builder.build_unconditional_branch(test_bb).unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap();
emit_cont_bb(ctx, list); emit_cont_bb(ctx, generator, list);
Ok(Some(list.as_base_value().into())) Ok(Some(list.as_base_value().into()))
} }
@ -1351,7 +1350,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())
{ {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
if op.variant == BinopVariant::AugAssign { if op.variant == BinopVariant::AugAssign {
todo!("Augmented assignment operators not implemented for lists") todo!("Augmented assignment operators not implemented for lists")
@ -1389,8 +1388,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
.build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "") .build_int_add(lhs.load_size(ctx, None), rhs.load_size(ctx, None), "")
.unwrap(); .unwrap();
let new_list = let new_list = ListType::new(generator, ctx.ctx, llvm_elem_ty)
ListType::new(ctx, &llvm_elem_ty).construct(generator, ctx, size, None); .construct(generator, ctx, size, None);
let lhs_size = ctx let lhs_size = ctx
.builder .builder
@ -1477,7 +1476,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty);
let sizeof_elem = elem_llvm_ty.size_of().unwrap(); let sizeof_elem = elem_llvm_ty.size_of().unwrap();
let new_list = ListType::new(ctx, &elem_llvm_ty).construct( let new_list = ListType::new(generator, ctx.ctx, elem_llvm_ty).construct(
generator, generator,
ctx, ctx,
ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(),
@ -1579,7 +1578,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let right = right.to_ndarray(generator, ctx); let right = right.to_ndarray(generator, ctx);
let result = NDArrayType::new_broadcast( let result = NDArrayType::new_broadcast(
ctx, generator,
ctx.ctx,
llvm_common_dtype, llvm_common_dtype,
&[left.get_type(), right.get_type()], &[left.get_type(), right.get_type()],
) )
@ -1852,7 +1852,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
.to_ndarray(generator, ctx); .to_ndarray(generator, ctx);
let result_ndarray = NDArrayType::new_broadcast( let result_ndarray = NDArrayType::new_broadcast(
ctx, generator,
ctx.ctx,
ctx.ctx.i8_type().into(), ctx.ctx.i8_type().into(),
&[left.get_type(), right.get_type()], &[left.get_type(), right.get_type()],
) )
@ -1971,7 +1972,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
let rhs = rhs.into_struct_value(); let rhs = rhs.into_struct_value();
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
ctx.builder.build_store(plhs, lhs).unwrap(); ctx.builder.build_store(plhs, lhs).unwrap();
@ -1999,7 +2000,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)], &[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
None, None,
).into_int_value(); ).into_int_value();
let result = call_string_eq(ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len); let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
if *op == Cmpop::NotEq { if *op == Cmpop::NotEq {
ctx.builder.build_not(result, "").unwrap() ctx.builder.build_not(result, "").unwrap()
} else { } else {
@ -2009,7 +2010,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
.iter() .iter()
.any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())) .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()))
{ {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let gen_list_cmpop = |generator: &mut G, let gen_list_cmpop = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>| ctx: &mut CodeGenContext<'ctx, '_>|
@ -2130,7 +2131,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
ctx.ctx.bool_type().const_zero(), ctx.ctx.bool_type().const_zero(),
) )
.unwrap(); .unwrap();
hooks.build_break_branch(&ctx.builder); ctx.builder
.build_unconditional_branch(hooks.exit_bb)
.unwrap();
Ok(()) Ok(())
}, },
@ -2372,7 +2375,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
ctx.current_loc = expr.location; ctx.current_loc = expr.location;
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let usize = ctx.get_size_type(); let usize = generator.get_size_type(ctx.ctx);
let zero = int32.const_int(0, false); let zero = int32.const_int(0, false);
let loc = ctx.debug_info.0.create_debug_location( let loc = ctx.debug_info.0.create_debug_location(
@ -2477,11 +2480,20 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else { } else {
Some(elements[0].get_type()) Some(elements[0].get_type())
}; };
let length = ctx.get_size_type().const_int(elements.len() as u64, false); let length = generator.get_size_type(ctx.ctx).const_int(elements.len() as u64, false);
let arr_str_ptr = if let Some(ty) = ty { let arr_str_ptr = if let Some(ty) = ty {
ListType::new(ctx, &ty).construct(generator, ctx, length, Some("list")) ListType::new(generator, ctx.ctx, ty).construct(
generator,
ctx,
length,
Some("list"),
)
} else { } else {
ListType::new_untyped(ctx).construct_empty(generator, ctx, Some("list")) ListType::new_untyped(generator, ctx.ctx).construct_empty(
generator,
ctx,
Some("list"),
)
}; };
let arr_ptr = arr_str_ptr.data(); let arr_ptr = arr_str_ptr.data();
for (i, v) in elements.iter().enumerate() { for (i, v) in elements.iter().enumerate() {
@ -2813,10 +2825,6 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
&*ctx.unifier.get_ty(value.custom.unwrap()) &*ctx.unifier.get_ty(value.custom.unwrap())
{ {
*obj_id *obj_id
} else if let TypeEnum::TModule { module_id, .. } =
&*ctx.unifier.get_ty(value.custom.unwrap())
{
*module_id
} else { } else {
codegen_unreachable!(ctx) codegen_unreachable!(ctx)
}; };
@ -2827,13 +2835,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else { } else {
let defs = ctx.top_level.definitions.read(); let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read(); let obj_def = defs.get(id.0).unwrap().read();
if let TopLevelDef::Class { methods, .. } = &*obj_def { let TopLevelDef::Class { methods, .. } = &*obj_def else {
methods.iter().find(|method| method.0 == *attr).unwrap().2
} else if let TopLevelDef::Module { methods, .. } = &*obj_def {
*methods.iter().find(|method| method.0 == attr).unwrap().1
} else {
codegen_unreachable!(ctx) codegen_unreachable!(ctx)
} };
methods.iter().find(|method| method.0 == *attr).unwrap().2
}; };
// directly generate code for option.unwrap // directly generate code for option.unwrap
// since it needs to return static value to optimize for kernel invariant // since it needs to return static value to optimize for kernel invariant
@ -2966,8 +2972,12 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.unwrap(), .unwrap(),
step, step,
); );
let res_array_ret = let res_array_ret = ListType::new(generator, ctx.ctx, ty).construct(
ListType::new(ctx, &ty).construct(generator, ctx, length, Some("ret")); generator,
ctx,
length,
Some("ret"),
);
let Some(res_ind) = handle_slice_indices( let Some(res_ind) = handle_slice_indices(
&None, &None,
&None, &None,
@ -2999,7 +3009,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}; };
let raw_index = ctx let raw_index = ctx
.builder .builder
.build_int_s_extend(raw_index, ctx.get_size_type(), "sext") .build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext")
.unwrap(); .unwrap();
// handle negative index // handle negative index
let is_negative = ctx let is_negative = ctx
@ -3007,7 +3017,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.build_int_compare( .build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,
raw_index, raw_index,
ctx.get_size_type().const_zero(), generator.get_size_type(ctx.ctx).const_zero(),
"is_neg", "is_neg",
) )
.unwrap(); .unwrap();

View File

@ -1,6 +1,5 @@
use inkwell::{ use inkwell::{
context::Context, context::Context,
targets::TargetMachine,
types::{BasicTypeEnum, IntType}, types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
}; };
@ -19,9 +18,6 @@ pub trait CodeGenerator {
fn get_name(&self) -> &str; fn get_name(&self) -> &str;
/// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance. /// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance.
///
/// Prefer using [`CodeGenContext::get_size_type`] if [`CodeGenContext`] is available, as it is
/// equivalent to this function in a more concise syntax.
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>; fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>;
/// Generate function call and returns the function return value. /// Generate function call and returns the function return value.
@ -274,27 +270,19 @@ pub struct DefaultCodeGenerator {
impl DefaultCodeGenerator { impl DefaultCodeGenerator {
#[must_use] #[must_use]
pub fn new(name: String, size_t: IntType<'_>) -> DefaultCodeGenerator { pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator {
assert!(matches!(size_t.get_bit_width(), 32 | 64)); assert!(matches!(size_t, 32 | 64));
DefaultCodeGenerator { name, size_t: size_t.get_bit_width() } DefaultCodeGenerator { name, size_t }
}
#[must_use]
pub fn with_target_machine(
name: String,
ctx: &Context,
target_machine: &TargetMachine,
) -> DefaultCodeGenerator {
let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None);
Self::new(name, llvm_usize)
} }
} }
impl CodeGenerator for DefaultCodeGenerator { impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str { fn get_name(&self) -> &str {
&self.name &self.name
} }
/// Returns an LLVM integer type representing `size_t`.
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> { fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
// it should be unsigned, but we don't really need unsigned and this could save us from // it should be unsigned, but we don't really need unsigned and this could save us from
// having to do a bit cast... // having to do a bit cast...

View File

@ -24,7 +24,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
src_arr: ListValue<'ctx>, src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
@ -168,7 +168,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.position_at_end(update_bb); ctx.builder.position_at_end(update_bb);
let new_len = let new_len =
ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap(); ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap();
dest_arr.store_size(ctx, new_len); dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
} }

View File

@ -68,9 +68,13 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`. /// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`. /// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
#[must_use] #[must_use]
pub fn get_usize_dependent_function_name(ctx: &CodeGenContext<'_, '_>, name: &str) -> String { pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned(); let mut name = name.to_owned();
match ctx.get_size_type().get_bit_width() { match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => {} 32 => {}
64 => name.push_str("64"), 64 => name.push_str("64"),
bit_width => { bit_width => {

View File

@ -21,7 +21,7 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
ndims: IntValue<'ctx>, ndims: IntValue<'ctx>,
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
assert_eq!(ndims.get_type(), llvm_usize); assert_eq!(ndims.get_type(), llvm_usize);
assert_eq!( assert_eq!(
@ -29,8 +29,11 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
llvm_usize.into() llvm_usize.into()
); );
let name = let name = get_usize_dependent_function_name(
get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_set_and_validate_list_shape"); generator,
ctx,
"__nac3_ndarray_array_set_and_validate_list_shape",
);
infer_and_call_function( infer_and_call_function(
ctx, ctx,
@ -52,14 +55,19 @@ pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerato
/// - `ndarray.ndims`: Must be initialized. /// - `ndarray.ndims`: Must be initialized.
/// - `ndarray.shape`: Must be initialized. /// - `ndarray.shape`: Must be initialized.
/// - `ndarray.data`: Must be allocated and contiguous. /// - `ndarray.data`: Must be allocated and contiguous.
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>( pub fn call_nac3_ndarray_array_write_list_to_array<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
list: ListValue<'ctx>, list: ListValue<'ctx>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) { ) {
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into()); assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_write_list_to_array"); let name = get_usize_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_array_write_list_to_array",
);
infer_and_call_function( infer_and_call_function(
ctx, ctx,

View File

@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator +
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!( assert_eq!(
@ -28,8 +28,11 @@ pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator +
llvm_usize.into() llvm_usize.into()
); );
let name = let name = get_usize_dependent_function_name(
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative"); generator,
ctx,
"__nac3_ndarray_util_assert_shape_no_negative",
);
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -54,7 +57,7 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!( assert_eq!(
@ -66,8 +69,11 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
llvm_usize.into() llvm_usize.into()
); );
let name = let name = get_usize_dependent_function_name(
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same"); generator,
ctx,
"__nac3_ndarray_util_assert_output_shape_same",
);
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -88,14 +94,15 @@ pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator +
/// ///
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an /// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an
/// `ndarray`, corresponding to the value of `ndarray.size`. /// `ndarray`, corresponding to the value of `ndarray.size`.
pub fn call_nac3_ndarray_size<'ctx>( pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -113,14 +120,15 @@ pub fn call_nac3_ndarray_size<'ctx>(
/// ///
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the /// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the
/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`. /// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`.
pub fn call_nac3_ndarray_nbytes<'ctx>( pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -138,14 +146,15 @@ pub fn call_nac3_ndarray_nbytes<'ctx>(
/// ///
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of /// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of
/// the `ndarray`, corresponding to the value of `ndarray.__len__`. /// the `ndarray`, corresponding to the value of `ndarray.__len__`.
pub fn call_nac3_ndarray_len<'ctx>( pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -162,14 +171,15 @@ pub fn call_nac3_ndarray_len<'ctx>(
/// Generates a call to `__nac3_ndarray_is_c_contiguous`. /// Generates a call to `__nac3_ndarray_is_c_contiguous`.
/// ///
/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous. /// Returns an `i1` value indicating whether the `ndarray` is C-contiguous.
pub fn call_nac3_ndarray_is_c_contiguous<'ctx>( pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -186,19 +196,20 @@ pub fn call_nac3_ndarray_is_c_contiguous<'ctx>(
/// Generates a call to `__nac3_ndarray_get_nth_pelement`. /// Generates a call to `__nac3_ndarray_get_nth_pelement`.
/// ///
/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`. /// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`.
pub fn call_nac3_ndarray_get_nth_pelement<'ctx>( pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
index: IntValue<'ctx>, index: IntValue<'ctx>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
assert_eq!(index.get_type(), llvm_usize); assert_eq!(index.get_type(), llvm_usize);
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -225,7 +236,7 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
@ -234,7 +245,8 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
llvm_usize.into() llvm_usize.into()
); );
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices"); let name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -254,13 +266,15 @@ pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized
/// Generates a call to `__nac3_ndarray_set_strides_by_shape`. /// Generates a call to `__nac3_ndarray_set_strides_by_shape`.
/// ///
/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous. /// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous.
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>( pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) { ) {
let llvm_ndarray = ndarray.get_type().as_base_type(); let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape"); let name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -277,12 +291,13 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number /// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number
/// of elements in `src_ndarray` must be greater than or equal to the number of elements in /// of elements in `src_ndarray` must be greater than or equal to the number of elements in
/// `dst_ndarray`. /// `dst_ndarray`.
pub fn call_nac3_ndarray_copy_data<'ctx>( pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>, src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>,
) { ) {
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_copy_data"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
infer_and_call_function( infer_and_call_function(
ctx, ctx,

View File

@ -20,12 +20,13 @@ use crate::codegen::{
/// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`. /// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`.
/// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape. /// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape.
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values. /// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
pub fn call_nac3_ndarray_broadcast_to<'ctx>( pub fn call_nac3_ndarray_broadcast_to<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>, src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>,
) { ) {
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_to"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_to");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,
@ -52,7 +53,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
{ {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
assert_eq!(num_shape_entries.get_type(), llvm_usize); assert_eq!(num_shape_entries.get_type(), llvm_usize);
assert!(ShapeEntryType::is_type( assert!(ShapeEntryType::is_type(
@ -64,7 +65,7 @@ pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
assert_eq!(dst_ndims.get_type(), llvm_usize); assert_eq!(dst_ndims.get_type(), llvm_usize);
assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into()); assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into());
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_shapes"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_broadcast_shapes");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,

View File

@ -17,7 +17,7 @@ pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
src_ndarray: NDArrayValue<'ctx>, src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>,
) { ) {
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_index"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_index");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,

View File

@ -25,7 +25,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
assert_eq!( assert_eq!(
@ -33,7 +33,7 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
llvm_usize.into() llvm_usize.into()
); );
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
create_and_call_function( create_and_call_function(
ctx, ctx,
@ -53,11 +53,12 @@ pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
/// ///
/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter` /// Returns an `i1` value indicating whether there are elements left to traverse for the `iter`
/// object. /// object.
pub fn call_nac3_nditer_has_element<'ctx>( pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
iter: NDIterValue<'ctx>, iter: NDIterValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_has_element"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_has_element");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
@ -74,8 +75,12 @@ pub fn call_nac3_nditer_has_element<'ctx>(
/// Generates a call to `__nac3_nditer_next`. /// Generates a call to `__nac3_nditer_next`.
/// ///
/// Moves `iter` to point to the next element. /// Moves `iter` to point to the next element.
pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) { pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next"); generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
iter: NDIterValue<'ctx>,
) {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_next");
infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None); infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None);
} }

View File

@ -20,7 +20,7 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized
new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>, dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
assert_eq!( assert_eq!(
BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(), BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(),
@ -43,7 +43,8 @@ pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized
llvm_usize.into() llvm_usize.into()
); );
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes"); let name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes");
infer_and_call_function( infer_and_call_function(
ctx, ctx,

View File

@ -18,13 +18,14 @@ pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenera
new_ndims: IntValue<'ctx>, new_ndims: IntValue<'ctx>,
new_shape: ArraySliceValue<'ctx>, new_shape: ArraySliceValue<'ctx>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
assert_eq!(size.get_type(), llvm_usize); assert_eq!(size.get_type(), llvm_usize);
assert_eq!(new_ndims.get_type(), llvm_usize); assert_eq!(new_ndims.get_type(), llvm_usize);
assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into()); assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into());
let name = get_usize_dependent_function_name( let name = get_usize_dependent_function_name(
generator,
ctx, ctx,
"__nac3_ndarray_reshape_resolve_and_check_new_shape", "__nac3_ndarray_reshape_resolve_and_check_new_shape",
); );

View File

@ -23,12 +23,12 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
dst_ndarray: NDArrayValue<'ctx>, dst_ndarray: NDArrayValue<'ctx>,
axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>, axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>,
) { ) {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize)); assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize));
assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into())); assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into()));
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_transpose"); let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_transpose");
infer_and_call_function( infer_and_call_function(
ctx, ctx,
&name, &name,

View File

@ -2,10 +2,11 @@ use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue};
use itertools::Either; use itertools::Either;
use super::get_usize_dependent_function_name; use super::get_usize_dependent_function_name;
use crate::codegen::CodeGenContext; use crate::codegen::{CodeGenContext, CodeGenerator};
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal. /// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
pub fn call_string_eq<'ctx>( pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
str1_ptr: PointerValue<'ctx>, str1_ptr: PointerValue<'ctx>,
str1_len: IntValue<'ctx>, str1_len: IntValue<'ctx>,
@ -14,7 +15,7 @@ pub fn call_string_eq<'ctx>(
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq"); let func_name = get_usize_dependent_function_name(generator, ctx, "nac3_str_eq");
let func = ctx.module.get_function(&func_name).unwrap_or_else(|| { let func = ctx.module.get_function(&func_name).unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(

View File

@ -1,5 +1,4 @@
use std::{ use std::{
cell::OnceCell,
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
sync::{ sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
@ -20,7 +19,7 @@ use inkwell::{
module::Module, module::Module,
passes::PassBuilderOptions, passes::PassBuilderOptions,
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple}, targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
types::{AnyType, BasicType, BasicTypeEnum, IntType}, types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
@ -227,33 +226,14 @@ pub struct CodeGenContext<'ctx, 'a> {
/// The current source location. /// The current source location.
pub current_loc: Location, pub current_loc: Location,
/// The cached type of `size_t`.
llvm_usize: OnceCell<IntType<'ctx>>,
} }
impl<'ctx> CodeGenContext<'ctx, '_> { impl CodeGenContext<'_, '_> {
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
/// contains a [terminator statement][BasicBlock::get_terminator]. /// contains a [terminator statement][BasicBlock::get_terminator].
pub fn is_terminated(&self) -> bool { pub fn is_terminated(&self) -> bool {
self.builder.get_insert_block().and_then(BasicBlock::get_terminator).is_some() self.builder.get_insert_block().and_then(BasicBlock::get_terminator).is_some()
} }
/// Returns a [`IntType`] representing `size_t` for the compilation target as specified by
/// [`self.registry`][WorkerRegistry].
pub fn get_size_type(&self) -> IntType<'ctx> {
*self.llvm_usize.get_or_init(|| {
self.ctx.ptr_sized_int_type(
&self
.registry
.llvm_options
.create_target_machine()
.map(|tm| tm.get_target_data())
.unwrap(),
None,
)
})
}
} }
type Fp = Box<dyn Fn(&Module) + Send + Sync>; type Fp = Box<dyn Fn(&Module) + Send + Sync>;
@ -501,38 +481,6 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| { type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| {
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TModule {module_id, attributes} => {
let top_level_defs = top_level.definitions.read();
let definition = top_level_defs.get(module_id.0).unwrap();
let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else {
unreachable!()
};
let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) {
t.ptr_type(AddressSpace::default()).into()
} else {
let struct_type = ctx.opaque_struct_type(&name.to_string());
type_cache.insert(
unifier.get_representative(ty),
struct_type.ptr_type(AddressSpace::default()).into(),
);
let module_fields: Vec<BasicTypeEnum<'_>> = attribute_fields.iter()
.map(|f| {
get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
attributes[&f.0].0,
)
})
.collect_vec();
struct_type.set_body(&module_fields, false);
struct_type.ptr_type(AddressSpace::default()).into()
};
return ty;
},
TObj { obj_id, fields, .. } => { TObj { obj_id, fields, .. } => {
// check to avoid treating non-class primitives as classes // check to avoid treating non-class primitives as classes
if PrimDef::contains_id(*obj_id) { if PrimDef::contains_id(*obj_id) {
@ -562,7 +510,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
*params.iter().next().unwrap().1, *params.iter().next().unwrap().1,
); );
ListType::new_with_generator(generator, ctx, element_type).as_base_type().into() ListType::new(generator, ctx, element_type).as_base_type().into()
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
@ -572,7 +520,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx, module, generator, unifier, top_level, type_cache, dtype, ctx, module, generator, unifier, top_level, type_cache, dtype,
); );
NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_base_type().into() NDArrayType::new(generator, ctx, element_type, ndims).as_base_type().into()
} }
_ => unreachable!( _ => unreachable!(
@ -626,7 +574,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty) get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
}) })
.collect_vec(); .collect_vec();
TupleType::new_with_generator(generator, ctx, &fields).as_base_type().into() TupleType::new(generator, ctx, &fields).as_base_type().into()
} }
TVirtual { .. } => unimplemented!(), TVirtual { .. } => unimplemented!(),
_ => unreachable!("{}", ty_enum.get_type_name()), _ => unreachable!("{}", ty_enum.get_type_name()),
@ -1039,20 +987,8 @@ pub fn gen_func_impl<
need_sret: has_sret, need_sret: has_sret,
current_loc: Location::default(), current_loc: Location::default(),
debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()), debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()),
llvm_usize: OnceCell::default(),
}; };
let target_llvm_usize = context.ptr_sized_int_type(
&registry.llvm_options.create_target_machine().map(|tm| tm.get_target_data()).unwrap(),
None,
);
let generator_llvm_usize = generator.get_size_type(context);
assert_eq!(
generator_llvm_usize,
target_llvm_usize,
"CodeGenerator (size_t = {generator_llvm_usize}) is not compatible with CodeGen Target (size_t = {target_llvm_usize})",
);
let loc = code_gen_context.debug_info.0.create_debug_location( let loc = code_gen_context.debug_info.0.create_debug_location(
context, context,
row as u32, row as u32,
@ -1244,7 +1180,7 @@ pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>(
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let align_ty = align_ty.into(); let align_ty = align_ty.into();
let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap(); let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap();

View File

@ -42,7 +42,7 @@ pub fn gen_ndarray_empty<'ctx>(
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
let ndarray = NDArrayType::new(context, llvm_dtype, ndims) let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims)
.construct_numpy_empty(generator, context, &shape, None); .construct_numpy_empty(generator, context, &shape, None);
Ok(ndarray.as_base_value()) Ok(ndarray.as_base_value())
} }
@ -67,7 +67,7 @@ pub fn gen_ndarray_zeros<'ctx>(
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
let ndarray = NDArrayType::new(context, llvm_dtype, ndims) let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims)
.construct_numpy_zeros(generator, context, dtype, &shape, None); .construct_numpy_zeros(generator, context, dtype, &shape, None);
Ok(ndarray.as_base_value()) Ok(ndarray.as_base_value())
} }
@ -92,7 +92,7 @@ pub fn gen_ndarray_ones<'ctx>(
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
let ndarray = NDArrayType::new(context, llvm_dtype, ndims) let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims)
.construct_numpy_ones(generator, context, dtype, &shape, None); .construct_numpy_ones(generator, context, dtype, &shape, None);
Ok(ndarray.as_base_value()) Ok(ndarray.as_base_value())
} }
@ -120,7 +120,7 @@ pub fn gen_ndarray_full<'ctx>(
let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg)); let shape = parse_numpy_int_sequence(generator, context, (shape_ty, shape_arg));
let ndarray = NDArrayType::new(context, llvm_dtype, ndims).construct_numpy_full( let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, ndims).construct_numpy_full(
generator, generator,
context, context,
&shape, &shape,
@ -207,7 +207,7 @@ pub fn gen_ndarray_eye<'ctx>(
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let llvm_usize = context.get_size_type(); let llvm_usize = generator.get_size_type(context.ctx);
let llvm_dtype = context.get_llvm_type(generator, dtype); let llvm_dtype = context.get_llvm_type(generator, dtype);
let nrows = context let nrows = context
@ -223,7 +223,7 @@ pub fn gen_ndarray_eye<'ctx>(
.build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "") .build_int_s_extend_or_bit_cast(offset_arg.into_int_value(), llvm_usize, "")
.unwrap(); .unwrap();
let ndarray = NDArrayType::new(context, llvm_dtype, 2) let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2)
.construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None); .construct_numpy_eye(generator, context, dtype, nrows, ncols, offset, None);
Ok(ndarray.as_base_value()) Ok(ndarray.as_base_value())
} }
@ -244,14 +244,14 @@ pub fn gen_ndarray_identity<'ctx>(
let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret);
let llvm_usize = context.get_size_type(); let llvm_usize = generator.get_size_type(context.ctx);
let llvm_dtype = context.get_llvm_type(generator, dtype); let llvm_dtype = context.get_llvm_type(generator, dtype);
let n = context let n = context
.builder .builder
.build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "") .build_int_s_extend_or_bit_cast(n_arg.into_int_value(), llvm_usize, "")
.unwrap(); .unwrap();
let ndarray = NDArrayType::new(context, llvm_dtype, 2) let ndarray = NDArrayType::new(generator, context.ctx, llvm_dtype, 2)
.construct_numpy_identity(generator, context, dtype, n, None); .construct_numpy_identity(generator, context, dtype, n, None);
Ok(ndarray.as_base_value()) Ok(ndarray.as_base_value())
} }
@ -325,8 +325,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty); let common_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
// Check shapes. // Check shapes.
let a_size = a.size(ctx); let a_size = a.size(generator, ctx);
let b_size = b.size(ctx); let b_size = b.size(generator, ctx);
let same_shape = let same_shape =
ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap(); ctx.builder.build_int_compare(IntPredicate::EQ, a_size, b_size, "").unwrap();
ctx.make_assert( ctx.make_assert(
@ -349,13 +349,13 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
Some("np_dot"), Some("np_dot"),
|generator, ctx| { |generator, ctx| {
let a_iter = NDIterType::new(ctx).construct(generator, ctx, a); let a_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, a);
let b_iter = NDIterType::new(ctx).construct(generator, ctx, b); let b_iter = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, b);
Ok((a_iter, b_iter)) Ok((a_iter, b_iter))
}, },
|_, ctx, (a_iter, _b_iter)| { |generator, ctx, (a_iter, _b_iter)| {
// Only a_iter drives the condition, b_iter should have the same status. // Only a_iter drives the condition, b_iter should have the same status.
Ok(a_iter.has_element(ctx)) Ok(a_iter.has_element(generator, ctx))
}, },
|_, ctx, _hooks, (a_iter, b_iter)| { |_, ctx, _hooks, (a_iter, b_iter)| {
let a_scalar = a_iter.get_scalar(ctx); let a_scalar = a_iter.get_scalar(ctx);
@ -385,9 +385,9 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.build_store(result, new_result).unwrap(); ctx.builder.build_store(result, new_result).unwrap();
Ok(()) Ok(())
}, },
|_, ctx, (a_iter, b_iter)| { |generator, ctx, (a_iter, b_iter)| {
a_iter.next(ctx); a_iter.next(generator, ctx);
b_iter.next(ctx); b_iter.next(generator, ctx);
Ok(()) Ok(())
}, },
) )

View File

@ -1,7 +1,6 @@
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock, basic_block::BasicBlock,
builder::Builder,
types::{BasicType, BasicTypeEnum}, types::{BasicType, BasicTypeEnum},
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
IntPredicate, IntPredicate,
@ -307,7 +306,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{ {
// Handle list item assignment // Handle list item assignment
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let target_item_ty = iter_type_vars(list_params).next().unwrap().ty; let target_item_ty = iter_type_vars(list_params).next().unwrap().ty;
let target = generator let target = generator
@ -368,8 +367,10 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, key_ty)? .to_basic_value_enum(ctx, generator, key_ty)?
.into_int_value(); .into_int_value();
let index = let index = ctx
ctx.builder.build_int_s_extend(index, ctx.get_size_type(), "sext").unwrap(); .builder
.build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext")
.unwrap();
// handle negative index // handle negative index
let is_negative = ctx let is_negative = ctx
@ -377,7 +378,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
.build_int_compare( .build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,
index, index,
ctx.get_size_type().const_zero(), generator.get_size_type(ctx.ctx).const_zero(),
"is_neg", "is_neg",
) )
.unwrap(); .unwrap();
@ -449,7 +450,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
let broadcast_ndims = let broadcast_ndims =
[target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap(); [target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap();
let broadcast_result = NDArrayType::new( let broadcast_result = NDArrayType::new(
ctx, generator,
ctx.ctx,
value.get_type().element_type(), value.get_type().element_type(),
broadcast_ndims, broadcast_ndims,
) )
@ -458,7 +460,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
let target = broadcast_result.ndarrays[0]; let target = broadcast_result.ndarrays[0];
let value = broadcast_result.ndarrays[1]; let value = broadcast_result.ndarrays[1];
target.copy_data_from(ctx, value); target.copy_data_from(generator, ctx, value);
} }
_ => { _ => {
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
@ -482,7 +484,7 @@ pub fn gen_for<G: CodeGenerator>(
let var_assignment = ctx.var_assignment.clone(); let var_assignment = ctx.var_assignment.clone();
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_t = ctx.get_size_type(); let size_t = generator.get_size_type(ctx.ctx);
let zero = int32.const_zero(); let zero = int32.const_zero();
let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
let body_bb = ctx.ctx.append_basic_block(current, "for.body"); let body_bb = ctx.ctx.append_basic_block(current, "for.body");
@ -663,25 +665,11 @@ pub fn gen_for<G: CodeGenerator>(
#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)] #[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)]
pub struct BreakContinueHooks<'ctx> { pub struct BreakContinueHooks<'ctx> {
/// The [exit block][`BasicBlock`] to branch to when `break`-ing out of a loop. /// The [exit block][`BasicBlock`] to branch to when `break`-ing out of a loop.
exit_bb: BasicBlock<'ctx>, pub exit_bb: BasicBlock<'ctx>,
/// The [latch basic block][`BasicBlock`] to branch to for `continue`-ing to the next iteration /// The [latch basic block][`BasicBlock`] to branch to for `continue`-ing to the next iteration
/// of the loop. /// of the loop.
latch_bb: BasicBlock<'ctx>, pub latch_bb: BasicBlock<'ctx>,
}
impl<'ctx> BreakContinueHooks<'ctx> {
/// Creates a [`br` instruction][Builder::build_unconditional_branch] to the exit
/// [`BasicBlock`], as if by calling `break`.
pub fn build_break_branch(&self, builder: &Builder<'ctx>) {
builder.build_unconditional_branch(self.exit_bb).unwrap();
}
/// Creates a [`br` instruction][Builder::build_unconditional_branch] to the latch
/// [`BasicBlock`], as if by calling `continue`.
pub fn build_continue_branch(&self, builder: &Builder<'ctx>) {
builder.build_unconditional_branch(self.latch_bb).unwrap();
}
} }
/// Generates a C-style `for` construct using lambdas, similar to the following C code: /// Generates a C-style `for` construct using lambdas, similar to the following C code:

View File

@ -97,8 +97,7 @@ fn test_primitives() {
"}; "};
let statements = parse_program(source, FileName::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let context = inkwell::context::Context::create(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -108,7 +107,7 @@ fn test_primitives() {
Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) }) Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) })
as Arc<dyn SymbolResolver + Send + Sync>; as Arc<dyn SymbolResolver + Send + Sync>;
let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()]; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature { let signature = FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg {
@ -261,8 +260,7 @@ fn test_simple_call() {
"}; "};
let statements_2 = parse_program(source_2, FileName::default()).unwrap(); let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let context = inkwell::context::Context::create(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -309,7 +307,7 @@ fn test_simple_call() {
unreachable!() unreachable!()
} }
let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()]; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let mut function_data = FunctionData { let mut function_data = FunctionData {
resolver: resolver.clone(), resolver: resolver.clone(),
bound_variables: Vec::new(), bound_variables: Vec::new(),
@ -441,12 +439,12 @@ fn test_simple_call() {
#[test] #[test]
fn test_classes_list_type_new() { fn test_classes_list_type_new() {
let ctx = inkwell::context::Context::create(); let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type(); let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx); let llvm_usize = generator.get_size_type(&ctx);
let llvm_list = ListType::new_with_generator(&generator, &ctx, llvm_i32.into()); let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok()); assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok());
} }
@ -461,11 +459,11 @@ fn test_classes_range_type_new() {
#[test] #[test]
fn test_classes_ndarray_type_new() { fn test_classes_ndarray_type_new() {
let ctx = inkwell::context::Context::create(); let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type()); let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type(); let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx); let llvm_usize = generator.get_size_type(&ctx);
let llvm_ndarray = NDArrayType::new_with_generator(&generator, &ctx, llvm_i32.into(), 2); let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), 2);
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
} }

View File

@ -104,7 +104,7 @@ impl<'ctx> ListType<'ctx> {
element_type: Option<BasicTypeEnum<'ctx>>, element_type: Option<BasicTypeEnum<'ctx>>,
llvm_usize: IntType<'ctx>, llvm_usize: IntType<'ctx>,
) -> PointerType<'ctx> { ) -> PointerType<'ctx> {
let element_type = element_type.map_or(llvm_usize.into(), |ty| ty.as_basic_type_enum()); let element_type = element_type.unwrap_or(llvm_usize.into());
let field_tys = let field_tys =
Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec(); Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec();
@ -112,45 +112,26 @@ impl<'ctx> ListType<'ctx> {
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
} }
fn new_impl(
ctx: &'ctx Context,
element_type: Option<BasicTypeEnum<'ctx>>,
llvm_usize: IntType<'ctx>,
) -> Self {
let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize);
Self { ty: llvm_list, item: element_type, llvm_usize }
}
/// Creates an instance of [`ListType`]. /// Creates an instance of [`ListType`].
#[must_use] #[must_use]
pub fn new(ctx: &CodeGenContext<'ctx, '_>, element_type: &impl BasicType<'ctx>) -> Self { pub fn new<G: CodeGenerator + ?Sized>(
Self::new_impl(ctx.ctx, Some(element_type.as_basic_type_enum()), ctx.get_size_type())
}
/// Creates an instance of [`ListType`].
#[must_use]
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
generator: &G, generator: &G,
ctx: &'ctx Context, ctx: &'ctx Context,
element_type: BasicTypeEnum<'ctx>, element_type: BasicTypeEnum<'ctx>,
) -> Self { ) -> Self {
Self::new_impl(ctx, Some(element_type.as_basic_type_enum()), generator.get_size_type(ctx)) let llvm_usize = generator.get_size_type(ctx);
let llvm_list = Self::llvm_type(ctx, Some(element_type), llvm_usize);
Self { ty: llvm_list, item: Some(element_type), llvm_usize }
} }
/// Creates an instance of [`ListType`] with an unknown element type. /// Creates an instance of [`ListType`] with an unknown element type.
#[must_use] #[must_use]
pub fn new_untyped(ctx: &CodeGenContext<'ctx, '_>) -> Self { pub fn new_untyped<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
Self::new_impl(ctx.ctx, None, ctx.get_size_type()) let llvm_usize = generator.get_size_type(ctx);
} let llvm_list = Self::llvm_type(ctx, None, llvm_usize);
/// Creates an instance of [`ListType`] with an unknown element type. Self { ty: llvm_list, item: None, llvm_usize }
#[must_use]
pub fn new_untyped_with_generator<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
) -> Self {
Self::new_impl(ctx, None, generator.get_size_type(ctx))
} }
/// Creates an [`ListType`] from a [unifier type][Type]. /// Creates an [`ListType`] from a [unifier type][Type].
@ -171,14 +152,18 @@ impl<'ctx> ListType<'ctx> {
_ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)), _ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)),
}; };
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) { let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) {
None None
} else { } else {
Some(ctx.get_llvm_type(generator, elem_type)) Some(ctx.get_llvm_type(generator, elem_type))
}; };
Self::new_impl(ctx.ctx, llvm_elem_type, llvm_usize) Self {
ty: Self::llvm_type(ctx.ctx, llvm_elem_type, llvm_usize),
item: llvm_elem_type,
llvm_usize,
}
} }
/// Creates an [`ListType`] from a [`PointerType`]. /// Creates an [`ListType`] from a [`PointerType`].
@ -288,7 +273,7 @@ impl<'ctx> ListType<'ctx> {
} }
let plist = self.alloca_var(generator, ctx, name); let plist = self.alloca_var(generator, ctx, name);
plist.store_size(ctx, len); plist.store_size(ctx, generator, len);
let item = self.item.unwrap_or(self.llvm_usize.into()); let item = self.item.unwrap_or(self.llvm_usize.into());
plist.create_data(ctx, item, None); plist.create_data(ctx, item, None);
@ -315,7 +300,7 @@ impl<'ctx> ListType<'ctx> {
) -> <Self as ProxyType<'ctx>>::Value { ) -> <Self as ProxyType<'ctx>>::Value {
let plist = self.alloca_var(generator, ctx, name); let plist = self.alloca_var(generator, ctx, name);
plist.store_size(ctx, self.llvm_usize.const_zero()); plist.store_size(ctx, generator, self.llvm_usize.const_zero());
plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None); plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None);
plist plist

View File

@ -44,7 +44,7 @@ impl<'ctx> NDArrayType<'ctx> {
assert!(self.ndims >= ndims_int); assert!(self.ndims >= ndims_int);
assert_eq!(dtype, self.dtype); assert_eq!(dtype, self.dtype);
let list_value = list.as_i8_list(ctx); let list_value = list.as_i8_list(generator, ctx);
// Validate `list` has a consistent shape. // Validate `list` has a consistent shape.
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`. // Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
@ -61,13 +61,15 @@ impl<'ctx> NDArrayType<'ctx> {
generator, ctx, list_value, ndims, &shape, generator, ctx, list_value, ndims, &shape,
); );
let ndarray = let ndarray = Self::new(generator, ctx.ctx, dtype, ndims_int)
Self::new(ctx, dtype, ndims_int).construct_uninitialized(generator, ctx, name); .construct_uninitialized(generator, ctx, name);
ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
unsafe { ndarray.create_data(generator, ctx) }; unsafe { ndarray.create_data(generator, ctx) };
// Copy all contents from the list. // Copy all contents from the list.
irrt::ndarray::call_nac3_ndarray_array_write_list_to_array(ctx, list_value, ndarray); irrt::ndarray::call_nac3_ndarray_array_write_list_to_array(
generator, ctx, list_value, ndarray,
);
ndarray ndarray
} }
@ -96,7 +98,8 @@ impl<'ctx> NDArrayType<'ctx> {
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default()); let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let ndarray = Self::new(ctx, dtype, 1).construct_uninitialized(generator, ctx, name); let ndarray = Self::new(generator, ctx.ctx, dtype, 1)
.construct_uninitialized(generator, ctx, name);
// Set data // Set data
let data = ctx let data = ctx
@ -113,7 +116,7 @@ impl<'ctx> NDArrayType<'ctx> {
} }
// Set strides, the `data` is contiguous // Set strides, the `data` is contiguous
ndarray.set_strides_contiguous(ctx); ndarray.set_strides_contiguous(generator, ctx);
ndarray ndarray
} else { } else {
@ -167,7 +170,7 @@ impl<'ctx> NDArrayType<'ctx> {
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
NDArrayType::new(ctx, dtype, ndims).map_value(ndarray, None) NDArrayType::new(generator, ctx.ctx, dtype, ndims).map_value(ndarray, None)
} }
/// Implementation of `np_array(<ndarray>, copy=copy)`. /// Implementation of `np_array(<ndarray>, copy=copy)`.

View File

@ -79,27 +79,15 @@ impl<'ctx> ShapeEntryType<'ctx> {
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
} }
fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { /// Creates an instance of [`ShapeEntryType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ty = Self::llvm_type(ctx, llvm_usize); let llvm_ty = Self::llvm_type(ctx, llvm_usize);
Self { ty: llvm_ty, llvm_usize } Self { ty: llvm_ty, llvm_usize }
} }
/// Creates an instance of [`ShapeEntryType`].
#[must_use]
pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self {
Self::new_impl(ctx.ctx, ctx.get_size_type())
}
/// Creates an instance of [`ShapeEntryType`].
#[must_use]
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
) -> Self {
Self::new_impl(ctx, generator.get_size_type(ctx))
}
/// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`. /// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`.
#[must_use] #[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {

View File

@ -117,26 +117,17 @@ impl<'ctx> ContiguousNDArrayType<'ctx> {
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
} }
fn new_impl(ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize);
Self { ty: llvm_cndarray, item, llvm_usize }
}
/// Creates an instance of [`ContiguousNDArrayType`]. /// Creates an instance of [`ContiguousNDArrayType`].
#[must_use] #[must_use]
pub fn new(ctx: &CodeGenContext<'ctx, '_>, item: &impl BasicType<'ctx>) -> Self { pub fn new<G: CodeGenerator + ?Sized>(
Self::new_impl(ctx.ctx, item.as_basic_type_enum(), ctx.get_size_type())
}
/// Creates an instance of [`ContiguousNDArrayType`].
#[must_use]
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
generator: &G, generator: &G,
ctx: &'ctx Context, ctx: &'ctx Context,
item: BasicTypeEnum<'ctx>, item: BasicTypeEnum<'ctx>,
) -> Self { ) -> Self {
Self::new_impl(ctx, item, generator.get_size_type(ctx)) let llvm_usize = generator.get_size_type(ctx);
let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize);
Self { ty: llvm_cndarray, item, llvm_usize }
} }
/// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type]. /// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type].
@ -149,8 +140,9 @@ impl<'ctx> ContiguousNDArrayType<'ctx> {
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_dtype = ctx.get_llvm_type(generator, dtype); let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx);
Self::new_impl(ctx.ctx, llvm_dtype, ctx.get_size_type()) Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize }
} }
/// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`. /// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`.

View File

@ -75,25 +75,14 @@ impl<'ctx> NDIndexType<'ctx> {
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
} }
fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { #[must_use]
pub fn new<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ndindex = Self::llvm_type(ctx, llvm_usize); let llvm_ndindex = Self::llvm_type(ctx, llvm_usize);
Self { ty: llvm_ndindex, llvm_usize } Self { ty: llvm_ndindex, llvm_usize }
} }
#[must_use]
pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self {
Self::new_impl(ctx.ctx, ctx.get_size_type())
}
#[must_use]
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
) -> Self {
Self::new_impl(ctx, generator.get_size_type(ctx))
}
#[must_use] #[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());

View File

@ -46,7 +46,8 @@ impl<'ctx> NDArrayType<'ctx> {
let out_ndarray = match out { let out_ndarray = match out {
NDArrayOut::NewNDArray { dtype } => { NDArrayOut::NewNDArray { dtype } => {
// Create a new ndarray based on the broadcast shape. // Create a new ndarray based on the broadcast shape.
let result_ndarray = NDArrayType::new(ctx, dtype, broadcast_result.ndims) let result_ndarray =
NDArrayType::new(generator, ctx.ctx, dtype, broadcast_result.ndims)
.construct_uninitialized(generator, ctx, None); .construct_uninitialized(generator, ctx, None);
result_ndarray.copy_shape_from_array( result_ndarray.copy_shape_from_array(
generator, generator,
@ -69,7 +70,7 @@ impl<'ctx> NDArrayType<'ctx> {
}; };
// Map element-wise and store results into `mapped_ndarray`. // Map element-wise and store results into `mapped_ndarray`.
let nditer = NDIterType::new(ctx).construct(generator, ctx, out_ndarray); let nditer = NDIterType::new(generator, ctx.ctx).construct(generator, ctx, out_ndarray);
gen_for_callback( gen_for_callback(
generator, generator,
ctx, ctx,
@ -79,14 +80,16 @@ impl<'ctx> NDArrayType<'ctx> {
let other_nditers = broadcast_result let other_nditers = broadcast_result
.ndarrays .ndarrays
.iter() .iter()
.map(|ndarray| NDIterType::new(ctx).construct(generator, ctx, *ndarray)) .map(|ndarray| {
NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *ndarray)
})
.collect_vec(); .collect_vec();
Ok((nditer, other_nditers)) Ok((nditer, other_nditers))
}, },
|_, ctx, (out_nditer, _in_nditers)| { |generator, ctx, (out_nditer, _in_nditers)| {
// We can simply use `out_nditer`'s `has_element()`. // We can simply use `out_nditer`'s `has_element()`.
// `in_nditers`' `has_element()`s should return the same value. // `in_nditers`' `has_element()`s should return the same value.
Ok(out_nditer.has_element(ctx)) Ok(out_nditer.has_element(generator, ctx))
}, },
|generator, ctx, _hooks, (out_nditer, in_nditers)| { |generator, ctx, _hooks, (out_nditer, in_nditers)| {
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`, // Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
@ -101,10 +104,10 @@ impl<'ctx> NDArrayType<'ctx> {
Ok(()) Ok(())
}, },
|_, ctx, (out_nditer, in_nditers)| { |generator, ctx, (out_nditer, in_nditers)| {
// Advance all iterators // Advance all iterators
out_nditer.next(ctx); out_nditer.next(generator, ctx);
in_nditers.iter().for_each(|nditer| nditer.next(ctx)); in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
Ok(()) Ok(())
}, },
)?; )?;
@ -166,7 +169,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
// Promote all input to ndarrays and map through them. // Promote all input to ndarrays and map through them.
let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec(); let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec();
let ndarray = NDArrayType::new_broadcast( let ndarray = NDArrayType::new_broadcast(
ctx, generator,
ctx.ctx,
ret_dtype, ret_dtype,
&inputs.iter().map(NDArrayValue::get_type).collect_vec(), &inputs.iter().map(NDArrayValue::get_type).collect_vec(),
) )

View File

@ -107,56 +107,24 @@ impl<'ctx> NDArrayType<'ctx> {
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
} }
fn new_impl( /// Creates an instance of [`NDArrayType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context, ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
ndims: u64, ndims: u64,
llvm_usize: IntType<'ctx>,
) -> Self { ) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize); let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize } NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize }
} }
/// Creates an instance of [`NDArrayType`].
#[must_use]
pub fn new(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>, ndims: u64) -> Self {
Self::new_impl(ctx.ctx, dtype, ndims, ctx.get_size_type())
}
/// Creates an instance of [`NDArrayType`].
#[must_use]
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>,
ndims: u64,
) -> Self {
Self::new_impl(ctx, dtype, ndims, generator.get_size_type(ctx))
}
/// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more /// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more
/// `ndarray` operands. /// `ndarray` operands.
#[must_use] #[must_use]
pub fn new_broadcast( pub fn new_broadcast<G: CodeGenerator + ?Sized>(
ctx: &CodeGenContext<'ctx, '_>,
dtype: BasicTypeEnum<'ctx>,
inputs: &[NDArrayType<'ctx>],
) -> Self {
assert!(!inputs.is_empty());
Self::new_impl(
ctx.ctx,
dtype,
inputs.iter().map(NDArrayType::ndims).max().unwrap(),
ctx.get_size_type(),
)
}
/// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more
/// `ndarray` operands.
#[must_use]
pub fn new_broadcast_with_generator<G: CodeGenerator + ?Sized>(
generator: &G, generator: &G,
ctx: &'ctx Context, ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
@ -164,28 +132,20 @@ impl<'ctx> NDArrayType<'ctx> {
) -> Self { ) -> Self {
assert!(!inputs.is_empty()); assert!(!inputs.is_empty());
Self::new_impl( Self::new(generator, ctx, dtype, inputs.iter().map(NDArrayType::ndims).max().unwrap())
ctx,
dtype,
inputs.iter().map(NDArrayType::ndims).max().unwrap(),
generator.get_size_type(ctx),
)
} }
/// Creates an instance of [`NDArrayType`] with `ndims` of 0. /// Creates an instance of [`NDArrayType`] with `ndims` of 0.
#[must_use] #[must_use]
pub fn new_unsized(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>) -> Self { pub fn new_unsized<G: CodeGenerator + ?Sized>(
Self::new_impl(ctx.ctx, dtype, 0, ctx.get_size_type())
}
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
#[must_use]
pub fn new_unsized_with_generator<G: CodeGenerator + ?Sized>(
generator: &G, generator: &G,
ctx: &'ctx Context, ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
) -> Self { ) -> Self {
Self::new_impl(ctx, dtype, 0, generator.get_size_type(ctx)) let llvm_usize = generator.get_size_type(ctx);
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
NDArrayType { ty: llvm_ndarray, dtype, ndims: 0, llvm_usize }
} }
/// Creates an [`NDArrayType`] from a [unifier type][Type]. /// Creates an [`NDArrayType`] from a [unifier type][Type].
@ -198,9 +158,15 @@ impl<'ctx> NDArrayType<'ctx> {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_dtype = ctx.get_llvm_type(generator, dtype); let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndims = extract_ndims(&ctx.unifier, ndims); let ndims = extract_ndims(&ctx.unifier, ndims);
Self::new_impl(ctx.ctx, llvm_dtype, ndims, ctx.get_size_type()) NDArrayType {
ty: Self::llvm_type(ctx.ctx, llvm_usize),
dtype: llvm_dtype,
ndims,
llvm_usize,
}
} }
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`. /// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
@ -293,9 +259,9 @@ impl<'ctx> NDArrayType<'ctx> {
.builder .builder
.build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "") .build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
.unwrap(); .unwrap();
ndarray.store_itemsize(ctx, itemsize); ndarray.store_itemsize(ctx, generator, itemsize);
ndarray.store_ndims(ctx, ndims); ndarray.store_ndims(ctx, generator, ndims);
ndarray.create_shape(ctx, self.llvm_usize, ndims); ndarray.create_shape(ctx, self.llvm_usize, ndims);
ndarray.create_strides(ctx, self.llvm_usize, ndims); ndarray.create_strides(ctx, self.llvm_usize, ndims);
@ -338,10 +304,10 @@ impl<'ctx> NDArrayType<'ctx> {
) -> <Self as ProxyType<'ctx>>::Value { ) -> <Self as ProxyType<'ctx>>::Value {
assert_eq!(shape.len() as u64, self.ndims); assert_eq!(shape.len() as u64, self.ndims);
let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64)
.construct_uninitialized(generator, ctx, name); .construct_uninitialized(generator, ctx, name);
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
// Write shape // Write shape
let ndarray_shape = ndarray.shape(); let ndarray_shape = ndarray.shape();
@ -373,10 +339,10 @@ impl<'ctx> NDArrayType<'ctx> {
) -> <Self as ProxyType<'ctx>>::Value { ) -> <Self as ProxyType<'ctx>>::Value {
assert_eq!(shape.len() as u64, self.ndims); assert_eq!(shape.len() as u64, self.ndims);
let ndarray = Self::new(ctx, self.dtype, shape.len() as u64) let ndarray = Self::new(generator, ctx.ctx, self.dtype, shape.len() as u64)
.construct_uninitialized(generator, ctx, name); .construct_uninitialized(generator, ctx, name);
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
// Write shape // Write shape
let ndarray_shape = ndarray.shape(); let ndarray_shape = ndarray.shape();
@ -423,8 +389,8 @@ impl<'ctx> NDArrayType<'ctx> {
.build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") .build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap(); .unwrap();
let ndarray = let ndarray = Self::new_unsized(generator, ctx.ctx, value.get_type())
Self::new_unsized(ctx, value.get_type()).construct_uninitialized(generator, ctx, name); .construct_uninitialized(generator, ctx, name);
ctx.builder.build_store(ndarray.ptr_to_data(ctx), data).unwrap(); ctx.builder.build_store(ndarray.ptr_to_data(ctx), data).unwrap();
ndarray ndarray
} }

View File

@ -86,27 +86,15 @@ impl<'ctx> NDIterType<'ctx> {
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
} }
fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self { /// Creates an instance of [`NDIter`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_nditer = Self::llvm_type(ctx, llvm_usize); let llvm_nditer = Self::llvm_type(ctx, llvm_usize);
Self { ty: llvm_nditer, llvm_usize } Self { ty: llvm_nditer, llvm_usize }
} }
/// Creates an instance of [`NDIter`].
#[must_use]
pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self {
Self::new_impl(ctx.ctx, ctx.get_size_type())
}
/// Creates an instance of [`NDIter`].
#[must_use]
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
) -> Self {
Self::new_impl(ctx, generator.get_size_type(ctx))
}
/// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`. /// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`.
#[must_use] #[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
@ -163,6 +151,11 @@ impl<'ctx> NDIterType<'ctx> {
} }
/// Allocate an [`NDIter`] that iterates through the given `ndarray`. /// Allocate an [`NDIter`] that iterates through the given `ndarray`.
///
/// Note: This function allocates an array on the stack at the current builder location, which
/// may lead to stack explosion if called in a hot loop. Therefore, callers are recommended to
/// call `llvm.stacksave` before calling this function and call `llvm.stackrestore` after the
/// [`NDIter`] is no longer needed.
#[must_use] #[must_use]
pub fn construct<G: CodeGenerator + ?Sized>( pub fn construct<G: CodeGenerator + ?Sized>(
&self, &self,

View File

@ -32,34 +32,17 @@ impl<'ctx> TupleType<'ctx> {
ctx.struct_type(tys, false) ctx.struct_type(tys, false)
} }
fn new_impl(
ctx: &'ctx Context,
tys: &[BasicTypeEnum<'ctx>],
llvm_usize: IntType<'ctx>,
) -> Self {
let llvm_tuple = Self::llvm_type(ctx, tys);
Self { ty: llvm_tuple, llvm_usize }
}
/// Creates an instance of [`TupleType`]. /// Creates an instance of [`TupleType`].
#[must_use] #[must_use]
pub fn new(ctx: &CodeGenContext<'ctx, '_>, tys: &[impl BasicType<'ctx>]) -> Self { pub fn new<G: CodeGenerator + ?Sized>(
Self::new_impl(
ctx.ctx,
&tys.iter().map(BasicType::as_basic_type_enum).collect_vec(),
ctx.get_size_type(),
)
}
/// Creates an instance of [`TupleType`].
#[must_use]
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
generator: &G, generator: &G,
ctx: &'ctx Context, ctx: &'ctx Context,
tys: &[BasicTypeEnum<'ctx>], tys: &[BasicTypeEnum<'ctx>],
) -> Self { ) -> Self {
Self::new_impl(ctx, tys, generator.get_size_type(ctx)) let llvm_usize = generator.get_size_type(ctx);
let llvm_tuple = Self::llvm_type(ctx, tys);
Self { ty: llvm_tuple, llvm_usize }
} }
/// Creates an [`TupleType`] from a [unifier type][Type]. /// Creates an [`TupleType`] from a [unifier type][Type].
@ -69,7 +52,7 @@ impl<'ctx> TupleType<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type, ty: Type,
) -> Self { ) -> Self {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
// Sanity check on object type. // Sanity check on object type.
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else { let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else {

View File

@ -122,31 +122,19 @@ impl<'ctx> SliceType<'ctx> {
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
} }
fn new_impl(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { /// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type.
#[must_use]
pub fn new(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
let llvm_ty = Self::llvm_type(ctx, int_ty); let llvm_ty = Self::llvm_type(ctx, int_ty);
Self { ty: llvm_ty, int_ty, llvm_usize } Self { ty: llvm_ty, int_ty, llvm_usize }
} }
/// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type.
#[must_use]
pub fn new(ctx: &CodeGenContext<'ctx, '_>, int_ty: IntType<'ctx>) -> Self {
Self::new_impl(ctx.ctx, int_ty, ctx.get_size_type())
}
/// Creates an instance of [`SliceType`] with `usize` as its backing integer type. /// Creates an instance of [`SliceType`] with `usize` as its backing integer type.
#[must_use] #[must_use]
pub fn new_usize(ctx: &CodeGenContext<'ctx, '_>) -> Self { pub fn new_usize<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
Self::new_impl(ctx.ctx, ctx.get_size_type(), ctx.get_size_type()) let llvm_usize = generator.get_size_type(ctx);
} Self::new(ctx, llvm_usize, llvm_usize)
/// Creates an instance of [`SliceType`] with `usize` as its backing integer type.
#[must_use]
pub fn new_usize_with_generator<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
) -> Self {
Self::new_impl(ctx, generator.get_size_type(ctx), generator.get_size_type(ctx))
} }
/// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`. /// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`.

View File

@ -418,7 +418,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
idx: &IntValue<'ctx>, idx: &IntValue<'ctx>,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), ctx.get_size_type()); debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
let size = self.size(ctx, generator); let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();

View File

@ -97,8 +97,13 @@ impl<'ctx> ListValue<'ctx> {
} }
/// Stores the `size` of this `list` into this instance. /// Stores the `size` of this `list` into this instance.
pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) { pub fn store_size<G: CodeGenerator + ?Sized>(
debug_assert_eq!(size.get_type(), ctx.get_size_type()); &self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
size: IntValue<'ctx>,
) {
debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx));
self.len_field(ctx).set(ctx, self.value, size, self.name); self.len_field(ctx).set(ctx, self.value, size, self.name);
} }
@ -114,9 +119,13 @@ impl<'ctx> ListValue<'ctx> {
/// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`. /// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`.
#[must_use] #[must_use]
pub fn as_i8_list(&self, ctx: &CodeGenContext<'ctx, '_>) -> ListValue<'ctx> { pub fn as_i8_list<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> ListValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type(); let llvm_i8 = ctx.ctx.i8_type();
let llvm_list_i8 = <Self as ProxyValue>::Type::new(ctx, &llvm_i8); let llvm_list_i8 = <Self as ProxyValue>::Type::new(generator, ctx.ctx, llvm_i8.into());
Self::from_pointer_value( Self::from_pointer_value(
ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(), ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_base_type(), "").unwrap(),
@ -204,7 +213,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
idx: &IntValue<'ctx>, idx: &IntValue<'ctx>,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), ctx.get_size_type()); debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
let size = self.size(ctx, generator); let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();

View File

@ -104,7 +104,7 @@ impl<'ctx> NDArrayValue<'ctx> {
assert!(self.ndims <= target_ndims); assert!(self.ndims <= target_ndims);
assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into()); assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into());
let broadcast_ndarray = NDArrayType::new(ctx, self.dtype, target_ndims) let broadcast_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, target_ndims)
.construct_uninitialized(generator, ctx, None); .construct_uninitialized(generator, ctx, None);
broadcast_ndarray.copy_shape_from_array( broadcast_ndarray.copy_shape_from_array(
generator, generator,
@ -112,7 +112,7 @@ impl<'ctx> NDArrayValue<'ctx> {
target_shape.base_ptr(ctx, generator), target_shape.base_ptr(ctx, generator),
); );
irrt::ndarray::call_nac3_ndarray_broadcast_to(ctx, *self, broadcast_ndarray); irrt::ndarray::call_nac3_ndarray_broadcast_to(generator, ctx, *self, broadcast_ndarray);
broadcast_ndarray broadcast_ndarray
} }
} }
@ -146,8 +146,8 @@ fn broadcast_shapes<'ctx, G, Shape>(
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>, + TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
{ {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_shape_ty = ShapeEntryType::new(ctx); let llvm_shape_ty = ShapeEntryType::new(generator, ctx.ctx);
assert!(in_shape_entries assert!(in_shape_entries
.iter() .iter()
@ -199,7 +199,7 @@ impl<'ctx> NDArrayType<'ctx> {
) -> BroadcastAllResult<'ctx, G> { ) -> BroadcastAllResult<'ctx, G> {
assert!(!ndarrays.is_empty()); assert!(!ndarrays.is_empty());
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
// Infer the broadcast output ndims. // Infer the broadcast output ndims.
let broadcast_ndims_int = let broadcast_ndims_int =

View File

@ -117,8 +117,8 @@ impl<'ctx> NDArrayValue<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> ContiguousNDArrayValue<'ctx> { ) -> ContiguousNDArrayValue<'ctx> {
let result = let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype)
ContiguousNDArrayType::new(ctx, &self.dtype).alloca_var(generator, ctx, self.name); .alloca_var(generator, ctx, self.name);
// Set ndims and shape. // Set ndims and shape.
let ndims = self.llvm_usize.const_int(self.ndims, false); let ndims = self.llvm_usize.const_int(self.ndims, false);
@ -130,7 +130,7 @@ impl<'ctx> NDArrayValue<'ctx> {
gen_if_callback( gen_if_callback(
generator, generator,
ctx, ctx,
|_, ctx| Ok(self.is_c_contiguous(ctx)), |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)),
|_, ctx| { |_, ctx| {
// This ndarray is contiguous. // This ndarray is contiguous.
let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name); let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name);
@ -178,16 +178,13 @@ impl<'ctx> NDArrayValue<'ctx> {
// TODO: Debug assert `ndims == carray.ndims` to catch bugs. // TODO: Debug assert `ndims == carray.ndims` to catch bugs.
// Allocate the resulting ndarray. // Allocate the resulting ndarray.
let ndarray = NDArrayType::new(ctx, carray.item, ndims).construct_uninitialized( let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, ndims)
generator, .construct_uninitialized(generator, ctx, carray.name);
ctx,
carray.name,
);
// Copy shape and update strides // Copy shape and update strides
let shape = carray.load_shape(ctx); let shape = carray.load_shape(ctx);
ndarray.copy_shape_from_array(generator, ctx, shape); ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.set_strides_contiguous(ctx); ndarray.set_strides_contiguous(generator, ctx);
// Share data // Share data
let data = carray.load_data(ctx); let data = carray.load_data(ctx);

View File

@ -1,101 +0,0 @@
use inkwell::values::{BasicValue, BasicValueEnum};
use super::{NDArrayValue, NDIterValue, ScalarOrNDArray};
use crate::codegen::{
stmt::{gen_for_callback, BreakContinueHooks},
types::ndarray::NDIterType,
CodeGenContext, CodeGenerator,
};
impl<'ctx> NDArrayValue<'ctx> {
/// Folds the elements of this ndarray into an accumulator value by applying `f`, returning the
/// final value.
///
/// `f` has access to [`BreakContinueHooks`] to short-circuit the `fold` operation, an instance
/// of `V` representing the current accumulated value, and an [`NDIterValue`] to get the
/// properties of the current iterated element.
pub fn fold<'a, G, V, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
init: V,
f: F,
) -> Result<V, String>
where
G: CodeGenerator + ?Sized,
V: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>,
<V as TryFrom<BasicValueEnum<'ctx>>>::Error: std::fmt::Debug,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
V,
NDIterValue<'ctx>,
) -> Result<V, String>,
{
let acc_ptr =
generator.gen_var_alloc(ctx, init.as_basic_value_enum().get_type(), None).unwrap();
ctx.builder.build_store(acc_ptr, init).unwrap();
gen_for_callback(
generator,
ctx,
Some("ndarray_fold"),
|generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)),
|_, ctx, nditer| Ok(nditer.has_element(ctx)),
|generator, ctx, hooks, nditer| {
let acc = V::try_from(ctx.builder.build_load(acc_ptr, "").unwrap()).unwrap();
let acc = f(generator, ctx, hooks, acc, nditer)?;
ctx.builder.build_store(acc_ptr, acc).unwrap();
Ok(())
},
|_, ctx, nditer| {
nditer.next(ctx);
Ok(())
},
)?;
let acc = ctx.builder.build_load(acc_ptr, "").unwrap();
Ok(V::try_from(acc).unwrap())
}
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// See [`NDArrayValue::fold`].
///
/// The primary differences between this function and `NDArrayValue::fold` are:
///
/// - The 3rd parameter of `f` is an `Option` of hooks, since `break`/`continue` hooks are not
/// available if this instance represents a scalar value.
/// - The 5th parameter of `f` is a [`BasicValueEnum`], since no [iterator][`NDIterValue`] will
/// be created if this instance represents a scalar value.
pub fn fold<'a, G, V, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
init: V,
f: F,
) -> Result<V, String>
where
G: CodeGenerator + ?Sized,
V: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>,
<V as TryFrom<BasicValueEnum<'ctx>>>::Error: std::fmt::Debug,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Option<&BreakContinueHooks<'ctx>>,
V,
BasicValueEnum<'ctx>,
) -> Result<V, String>,
{
match self {
ScalarOrNDArray::Scalar(v) => f(generator, ctx, None, init, *v),
ScalarOrNDArray::NDArray(v) => {
v.fold(generator, ctx, init, |generator, ctx, hooks, acc, nditer| {
let elem = nditer.get_scalar(ctx);
f(generator, ctx, Some(&hooks), acc, elem)
})
}
}
}
}

View File

@ -128,10 +128,11 @@ impl<'ctx> NDArrayValue<'ctx> {
indices: &[RustNDIndex<'ctx>], indices: &[RustNDIndex<'ctx>],
) -> Self { ) -> Self {
let dst_ndims = self.deduce_ndims_after_indexing_with(indices); let dst_ndims = self.deduce_ndims_after_indexing_with(indices);
let dst_ndarray = NDArrayType::new(ctx, self.dtype, dst_ndims) let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims)
.construct_uninitialized(generator, ctx, None); .construct_uninitialized(generator, ctx, None);
let indices = NDIndexType::new(ctx).construct_ndindices(generator, ctx, indices); let indices =
NDIndexType::new(generator, ctx.ctx).construct_ndindices(generator, ctx, indices);
irrt::ndarray::call_nac3_ndarray_index(generator, ctx, indices, *self, dst_ndarray); irrt::ndarray::call_nac3_ndarray_index(generator, ctx, indices, *self, dst_ndarray);
dst_ndarray dst_ndarray
@ -244,7 +245,8 @@ impl<'ctx> RustNDIndex<'ctx> {
} }
RustNDIndex::Slice(in_rust_slice) => { RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = let user_slice_ptr =
SliceType::new(ctx, ctx.ctx.i32_type()).alloca_var(generator, ctx, None); SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx))
.alloca_var(generator, ctx, None);
in_rust_slice.write_to_slice(ctx, user_slice_ptr); in_rust_slice.write_to_slice(ctx, user_slice_ptr);
dst_ndindex.store_data( dst_ndindex.store_data(

View File

@ -35,7 +35,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty); let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty);
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty); let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty);
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype); let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype);
// Deduce ndims of the result of matmul. // Deduce ndims of the result of matmul.
@ -108,7 +108,7 @@ fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape); let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape);
let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape); let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape);
let dst = NDArrayType::new(ctx, llvm_dst_dtype, ndims_int) let dst = NDArrayType::new(generator, ctx.ctx, llvm_dst_dtype, ndims_int)
.construct_uninitialized(generator, ctx, None); .construct_uninitialized(generator, ctx, None);
dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator)); dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator));
unsafe { unsafe {
@ -315,7 +315,7 @@ impl<'ctx> NDArrayValue<'ctx> {
let result_shape = result.shape(); let result_shape = result.shape();
out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape); out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape);
out_ndarray.copy_data_from(ctx, result); out_ndarray.copy_data_from(generator, ctx, result);
out_ndarray out_ndarray
} }
} }

View File

@ -30,7 +30,6 @@ pub use nditer::*;
mod broadcast; mod broadcast;
mod contiguous; mod contiguous;
mod fold;
mod indexing; mod indexing;
mod map; mod map;
mod matmul; mod matmul;
@ -82,8 +81,13 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Stores the number of dimensions `ndims` into this instance. /// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, ndims: IntValue<'ctx>) { pub fn store_ndims<G: CodeGenerator + ?Sized>(
debug_assert_eq!(ndims.get_type(), ctx.get_size_type()); &self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx); let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap(); ctx.builder.build_store(pndims, ndims).unwrap();
@ -100,8 +104,13 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Stores the size of each element `itemsize` into this instance. /// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, itemsize: IntValue<'ctx>) { pub fn store_itemsize<G: CodeGenerator + ?Sized>(
debug_assert_eq!(itemsize.get_type(), ctx.get_size_type()); &self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
itemsize: IntValue<'ctx>,
) {
debug_assert_eq!(itemsize.get_type(), generator.get_size_type(ctx.ctx));
self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name); self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name);
} }
@ -196,12 +205,12 @@ impl<'ctx> NDArrayValue<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) { ) {
let nbytes = self.nbytes(ctx); let nbytes = self.nbytes(generator, ctx);
let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None); let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None);
self.store_data(ctx, data); self.store_data(ctx, data);
self.set_strides_contiguous(ctx); self.set_strides_contiguous(generator, ctx);
} }
/// Returns a proxy object to the field storing the data of this `NDArray`. /// Returns a proxy object to the field storing the data of this `NDArray`.
@ -275,32 +284,52 @@ impl<'ctx> NDArrayValue<'ctx> {
} }
/// Get the `np.size()` of this ndarray. /// Get the `np.size()` of this ndarray.
pub fn size(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn size<G: CodeGenerator + ?Sized>(
irrt::ndarray::call_nac3_ndarray_size(ctx, *self) &self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self)
} }
/// Get the `ndarray.nbytes` of this ndarray. /// Get the `ndarray.nbytes` of this ndarray.
pub fn nbytes(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn nbytes<G: CodeGenerator + ?Sized>(
irrt::ndarray::call_nac3_ndarray_nbytes(ctx, *self) &self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self)
} }
/// Get the `len()` of this ndarray. /// Get the `len()` of this ndarray.
pub fn len(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn len<G: CodeGenerator + ?Sized>(
irrt::ndarray::call_nac3_ndarray_len(ctx, *self) &self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self)
} }
/// Check if this ndarray is C-contiguous. /// Check if this ndarray is C-contiguous.
/// ///
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags> /// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
pub fn is_c_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
irrt::ndarray::call_nac3_ndarray_is_c_contiguous(ctx, *self) &self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self)
} }
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`. /// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
/// ///
/// Update the ndarray's strides to make the ndarray contiguous. /// Update the ndarray's strides to make the ndarray contiguous.
pub fn set_strides_contiguous(&self, ctx: &CodeGenContext<'ctx, '_>) { pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>(
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(ctx, *self); &self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) {
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
} }
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and /// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and
@ -318,7 +347,7 @@ impl<'ctx> NDArrayValue<'ctx> {
let shape = self.shape(); let shape = self.shape();
clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
unsafe { clone.create_data(generator, ctx) }; unsafe { clone.create_data(generator, ctx) };
clone.copy_data_from(ctx, *self); clone.copy_data_from(generator, ctx, *self);
clone clone
} }
@ -328,9 +357,14 @@ impl<'ctx> NDArrayValue<'ctx> {
/// do not matter. The copying order is determined by how their flattened views look. /// do not matter. The copying order is determined by how their flattened views look.
/// ///
/// Panics if the `dtype`s of ndarrays are different. /// Panics if the `dtype`s of ndarrays are different.
pub fn copy_data_from(&self, ctx: &CodeGenContext<'ctx, '_>, src: NDArrayValue<'ctx>) { pub fn copy_data_from<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
src: NDArrayValue<'ctx>,
) {
assert_eq!(self.dtype, src.dtype, "self and src dtype should match"); assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
irrt::ndarray::call_nac3_ndarray_copy_data(ctx, src, *self); irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
} }
/// Fill the ndarray with a scalar. /// Fill the ndarray with a scalar.
@ -378,7 +412,11 @@ impl<'ctx> NDArrayValue<'ctx> {
.map(|obj| obj.as_basic_value_enum()) .map(|obj| obj.as_basic_value_enum())
.collect_vec(); .collect_vec();
TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) TupleType::new(
generator,
ctx.ctx,
&repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(),
)
.construct_from_objects(ctx, objects, None) .construct_from_objects(ctx, objects, None)
} }
@ -408,7 +446,11 @@ impl<'ctx> NDArrayValue<'ctx> {
.map(|obj| obj.as_basic_value_enum()) .map(|obj| obj.as_basic_value_enum())
.collect_vec(); .collect_vec();
TupleType::new(ctx, &repeat_n(llvm_i32, self.ndims as usize).collect_vec()) TupleType::new(
generator,
ctx.ctx,
&repeat_n(llvm_i32.into(), self.ndims as usize).collect_vec(),
)
.construct_from_objects(ctx, objects, None) .construct_from_objects(ctx, objects, None)
} }
@ -426,7 +468,7 @@ impl<'ctx> NDArrayValue<'ctx> {
) -> Option<BasicValueEnum<'ctx>> { ) -> Option<BasicValueEnum<'ctx>> {
if self.is_unsized() { if self.is_unsized() {
// NOTE: `np.size(self) == 0` here is never possible. // NOTE: `np.size(self) == 0` here is never possible.
let zero = ctx.get_size_type().const_zero(); let zero = generator.get_size_type(ctx.ctx).const_zero();
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) }; let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
Some(value) Some(value)
@ -714,9 +756,9 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
fn size<G: CodeGenerator + ?Sized>( fn size<G: CodeGenerator + ?Sized>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
_: &G, generator: &G,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_len(ctx, *self.0) irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0)
} }
} }
@ -728,7 +770,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
idx: &IntValue<'ctx>, idx: &IntValue<'ctx>,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(ctx, *self.0, *idx); let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx);
// Current implementation is transparent - The returned pointer type is // Current implementation is transparent - The returned pointer type is
// already cast into the expected type, allowing for immediately // already cast into the expected type, allowing for immediately
@ -792,7 +834,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices: &Index, indices: &Index,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
assert_eq!(indices.element_type(ctx, generator), ctx.get_size_type().into()); assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into());
let indices = TypedArrayLikeAdapter::from( let indices = TypedArrayLikeAdapter::from(
indices.as_slice_value(ctx, generator), indices.as_slice_value(ctx, generator),
@ -825,7 +867,7 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices: &Index, indices: &Index,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.size(ctx, generator); let indices_size = indices.size(ctx, generator);
let nidx_leq_ndims = ctx let nidx_leq_ndims = ctx
@ -991,8 +1033,10 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
) -> NDArrayValue<'ctx> { ) -> NDArrayValue<'ctx> {
match self { match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray, ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(scalar) => NDArrayType::new_unsized(ctx, scalar.get_type()) ScalarOrNDArray::Scalar(scalar) => {
.construct_unsized(generator, ctx, scalar, None), NDArrayType::new_unsized(generator, ctx.ctx, scalar.get_type())
.construct_unsized(generator, ctx, scalar, None)
}
} }
} }

View File

@ -53,16 +53,20 @@ impl<'ctx> NDIterValue<'ctx> {
/// If `ndarray` is unsized, this returns true only for the first iteration. /// If `ndarray` is unsized, this returns true only for the first iteration.
/// If `ndarray` is 0-sized, this always returns false. /// If `ndarray` is 0-sized, this always returns false.
#[must_use] #[must_use]
pub fn has_element(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn has_element<G: CodeGenerator + ?Sized>(
irrt::ndarray::call_nac3_nditer_has_element(ctx, *self) &self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_nditer_has_element(generator, ctx, *self)
} }
/// Go to the next element. If `has_element()` is false, then this has undefined behavior. /// Go to the next element. If `has_element()` is false, then this has undefined behavior.
/// ///
/// If `ndarray` is unsized, this can only be called once. /// If `ndarray` is unsized, this can only be called once.
/// If `ndarray` is 0-sized, this can never be called. /// If `ndarray` is 0-sized, this can never be called.
pub fn next(&self, ctx: &CodeGenContext<'ctx, '_>) { pub fn next<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &CodeGenContext<'ctx, '_>) {
irrt::ndarray::call_nac3_nditer_next(ctx, *self); irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self);
} }
fn element_field( fn element_field(
@ -137,6 +141,10 @@ impl<'ctx> NDArrayValue<'ctx> {
/// ///
/// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to /// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to
/// get properties of the current iteration (e.g., the current element, indices, etc.) /// get properties of the current iteration (e.g., the current element, indices, etc.)
///
/// Note: The caller is recommended to call `llvm.stacksave` and `llvm.stackrestore` before and
/// after invoking this function respectively. See [`NDIterType::construct`] for an explanation
/// on why this is suggested.
pub fn foreach<'a, G, F>( pub fn foreach<'a, G, F>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -156,11 +164,13 @@ impl<'ctx> NDArrayValue<'ctx> {
generator, generator,
ctx, ctx,
Some("ndarray_foreach"), Some("ndarray_foreach"),
|generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)), |generator, ctx| {
|_, ctx, nditer| Ok(nditer.has_element(ctx)), Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self))
},
|generator, ctx, nditer| Ok(nditer.has_element(generator, ctx)),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|_, ctx, nditer| { |generator, ctx, nditer| {
nditer.next(ctx); nditer.next(generator, ctx);
Ok(()) Ok(())
}, },
) )

View File

@ -30,7 +30,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
(input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>), (input_seq_ty, input_seq): (Type, BasicValueEnum<'ctx>),
) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> { ) -> impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>> {
let llvm_usize = ctx.get_size_type(); let llvm_usize = generator.get_size_type(ctx.ctx);
let zero = llvm_usize.const_zero(); let zero = llvm_usize.const_zero();
let one = llvm_usize.const_int(1, false); let one = llvm_usize.const_int(1, false);

View File

@ -65,12 +65,12 @@ impl<'ctx> NDArrayValue<'ctx> {
// not contiguous but could be reshaped without copying data. Look into how numpy does // not contiguous but could be reshaped without copying data. Look into how numpy does
// it. // it.
let dst_ndarray = NDArrayType::new(ctx, self.dtype, new_ndims) let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, new_ndims)
.construct_uninitialized(generator, ctx, None); .construct_uninitialized(generator, ctx, None);
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator)); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape.base_ptr(ctx, generator));
// Resolve negative indices // Resolve negative indices
let size = self.size(ctx); let size = self.size(generator, ctx);
let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false); let dst_ndims = self.llvm_usize.const_int(dst_ndarray.get_type().ndims(), false);
let dst_shape = dst_ndarray.shape(); let dst_shape = dst_ndarray.shape();
irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape( irrt::ndarray::call_nac3_ndarray_reshape_resolve_and_check_new_shape(
@ -84,10 +84,10 @@ impl<'ctx> NDArrayValue<'ctx> {
gen_if_callback( gen_if_callback(
generator, generator,
ctx, ctx,
|_, ctx| Ok(self.is_c_contiguous(ctx)), |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)),
|generator, ctx| { |generator, ctx| {
// Reshape is possible without copying // Reshape is possible without copying
dst_ndarray.set_strides_contiguous(ctx); dst_ndarray.set_strides_contiguous(generator, ctx);
dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator)); dst_ndarray.store_data(ctx, self.data().base_ptr(ctx, generator));
Ok(()) Ok(())
@ -97,7 +97,7 @@ impl<'ctx> NDArrayValue<'ctx> {
unsafe { unsafe {
dst_ndarray.create_data(generator, ctx); dst_ndarray.create_data(generator, ctx);
} }
dst_ndarray.copy_data_from(ctx, *self); dst_ndarray.copy_data_from(generator, ctx, *self);
Ok(()) Ok(())
}, },

View File

@ -598,12 +598,10 @@ impl dyn SymbolResolver + Send + Sync {
unifier.internal_stringify( unifier.internal_stringify(
ty, ty,
&mut |id| { &mut |id| {
let top_level_def = &*top_level_defs[id].read(); let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else {
let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) = unreachable!("expected class definition")
top_level_def
else {
unreachable!("expected class/module definition")
}; };
name.to_string() name.to_string()
}, },
&mut |id| format!("typevar{id}"), &mut |id| format!("typevar{id}"),

View File

@ -6,8 +6,7 @@ use strum::IntoEnumIterator;
use super::{ use super::{
helper::{ helper::{
arraylike_flatten_element_type, debug_assert_prim_is_allowed, extract_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, PrimDef, PrimDefDetails,
make_exception_fields, PrimDef, PrimDefDetails,
}, },
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
*, *,
@ -16,12 +15,9 @@ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
numpy::*, numpy::*,
stmt::{exn_constructor, gen_if_callback}, stmt::exn_constructor,
types::ndarray::NDArrayType, types::ndarray::NDArrayType,
values::{ values::{ndarray::shape::parse_numpy_int_sequence, ProxyValue, RangeValue},
ndarray::{shape::parse_numpy_int_sequence, ScalarOrNDArray},
ProxyValue, RangeValue,
},
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
@ -409,8 +405,6 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim), PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim),
PrimDef::FunNpAny | PrimDef::FunNpAll => self.build_np_any_all_function(prim),
PrimDef::FunNpSin PrimDef::FunNpSin
| PrimDef::FunNpCos | PrimDef::FunNpCos
| PrimDef::FunNpTan | PrimDef::FunNpTan
@ -1284,7 +1278,11 @@ impl<'a> BuiltinBuilder<'a> {
let size = ctx let size = ctx
.builder .builder
.build_int_truncate_or_bit_cast(ndarray.size(ctx), ctx.ctx.i32_type(), "") .build_int_truncate_or_bit_cast(
ndarray.size(generator, ctx),
ctx.ctx.i32_type(),
"",
)
.unwrap(); .unwrap();
Ok(Some(size.into())) Ok(Some(size.into()))
}), }),
@ -1726,64 +1724,6 @@ impl<'a> BuiltinBuilder<'a> {
) )
} }
fn build_np_any_all_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpAny, PrimDef::FunNpAll]);
let param_ty = &[(self.num_or_ndarray_ty.ty, "a")];
let ret_ty = self.primitives.bool;
let var_map = &self.num_or_ndarray_var_map;
let codegen_callback: Box<GenCallCallback> =
Box::new(move |ctx, _, fun, args, generator| {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i1_k0 = llvm_i1.const_zero();
let llvm_i1_k1 = llvm_i1.const_all_ones();
let a_ty = fun.0.args[0].ty;
let a_val = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
let a = ScalarOrNDArray::from_value(generator, ctx, (a_ty, a_val));
let a_elem_ty = arraylike_flatten_element_type(&mut ctx.unifier, a_ty);
let (init, sc_val) = match prim {
PrimDef::FunNpAny => (llvm_i1_k0, llvm_i1_k1),
PrimDef::FunNpAll => (llvm_i1_k1, llvm_i1_k0),
_ => unreachable!(),
};
let acc = a.fold(generator, ctx, init, |generator, ctx, hooks, acc, elem| {
gen_if_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::EQ, acc, sc_val, "")
.unwrap())
},
|_, ctx| {
if let Some(hooks) = hooks {
hooks.build_break_branch(&ctx.builder);
}
Ok(())
},
|_, _| Ok(()),
)?;
let is_truthy =
builtin_fns::call_bool(generator, ctx, (a_elem_ty, elem))?.into_int_value();
Ok(match prim {
PrimDef::FunNpAny => ctx.builder.build_or(acc, is_truthy, "").unwrap(),
PrimDef::FunNpAll => ctx.builder.build_and(acc, is_truthy, "").unwrap(),
_ => unreachable!(),
})
})?;
Ok(Some(acc.as_basic_value_enum()))
});
create_fn_by_codegen(self.unifier, var_map, prim.name(), ret_ty, param_ty, codegen_callback)
}
/// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input. /// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input.
fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef { fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed( debug_assert_prim_is_allowed(

View File

@ -101,9 +101,7 @@ impl TopLevelComposer {
let builtin_name_list = definition_ast_list let builtin_name_list = definition_ast_list
.iter() .iter()
.map(|def_ast| match *def_ast.0.read() { .map(|def_ast| match *def_ast.0.read() {
TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => { TopLevelDef::Class { name, .. } => name.to_string(),
name.to_string()
}
TopLevelDef::Function { simple_name, .. } TopLevelDef::Function { simple_name, .. }
| TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(), | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(),
}) })
@ -203,43 +201,6 @@ impl TopLevelComposer {
self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec()
} }
/// register top level modules
pub fn register_top_level_module(
&mut self,
module_name: &str,
name_to_pyid: &Rc<HashMap<StrRef, u64>>,
resolver: Arc<dyn SymbolResolver + Send + Sync>,
location: Option<Location>,
) -> Result<DefinitionId, String> {
let mut methods: HashMap<StrRef, DefinitionId> = HashMap::new();
let mut attributes: Vec<(StrRef, DefinitionId)> = Vec::new();
for (name, _) in name_to_pyid.iter() {
if let Ok(def_id) = resolver.get_identifier_def(*name) {
// Avoid repeated attribute instances resulting from multiple imports of same module
if self.defined_names.contains(&format!("{module_name}.{name}")) {
match &*self.definition_ast_list[def_id.0].0.read() {
TopLevelDef::Class { .. } | TopLevelDef::Function { .. } => {
methods.insert(*name, def_id);
}
_ => attributes.push((*name, def_id)),
}
}
};
}
let module_def = TopLevelDef::Module {
name: module_name.to_string().into(),
module_id: DefinitionId(self.definition_ast_list.len()),
methods,
attributes,
resolver: Some(resolver),
loc: location,
};
self.definition_ast_list.push((Arc::new(RwLock::new(module_def)), None));
Ok(DefinitionId(self.definition_ast_list.len() - 1))
}
/// register, just remember the names of top level classes/function /// register, just remember the names of top level classes/function
/// and check duplicate class/method/function definition /// and check duplicate class/method/function definition
pub fn register_top_level( pub fn register_top_level(
@ -473,7 +434,7 @@ impl TopLevelComposer {
location: Location, location: Location,
) -> Result<(StrRef, DefinitionId, Option<Type>), String> { ) -> Result<(StrRef, DefinitionId, Option<Type>), String> {
if self.keyword_list.contains(&name) { if self.keyword_list.contains(&name) {
return Err(format!("cannot use keyword `{name}` as a class name (at {location})")); return Err(format!("cannot use keyword `{name}` as a variable name (at {location})"));
} }
let global_var_name = let global_var_name =
@ -1987,6 +1948,103 @@ impl TopLevelComposer {
let unifier = &mut self.unifier; let unifier = &mut self.unifier;
let primitives_store = &self.primitives_ty; let primitives_store = &self.primitives_ty;
// let dummy_field_type = unifier.get_dummy_var().ty;
// let annotation = match value {
// None => {
// // handle Kernel[T], KernelInvariant[T]
// let (annotation, mutable) = match &annotation.node {
// ExprKind::Subscript { value, slice, .. }
// if matches!(
// &value.node,
// ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into()
// ) =>
// {
// (slice, false)
// }
// ExprKind::Subscript { value, slice, .. }
// if matches!(
// &value.node,
// ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into())
// ) =>
// {
// (slice, true)
// }
// _ if core_config.kernel_ann.is_none() => (annotation, true),
// _ => continue, // ignore fields annotated otherwise
// };
// class_fields_def.push((*attr, dummy_field_type, mutable));
// annotation
// }
// // Supporting Class Attributes
// Some(boxed_expr) => {
// // Class attributes are set as immutable regardless
// let (annotation, _) = match &annotation.node {
// ExprKind::Subscript { slice, .. } => (slice, false),
// _ if core_config.kernel_ann.is_none() => (annotation, false),
// _ => continue,
// };
// match &**boxed_expr {
// ast::Located {location: _, custom: (), node: ExprKind::Constant { value: v, kind: _ }} => {
// // Restricting the types allowed to be defined as class attributes
// match v {
// ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {}
// _ => {
// return Err(HashSet::from([
// format!(
// "unsupported statement in class definition body (at {})",
// b.location
// ),
// ]))
// }
// }
// class_attributes_def.push((*attr, dummy_field_type, v.clone()));
// }
// _ => {
// return Err(HashSet::from([
// format!(
// "unsupported statement in class definition body (at {})",
// b.location
// ),
// ]))
// }
// }
// annotation
// }
// };
// let parsed_annotation = parse_ast_to_type_annotation_kinds(
// class_resolver,
// temp_def_list,
// unifier,
// primitives,
// annotation.as_ref(),
// vec![(class_id, class_type_vars_def.clone())]
// .into_iter()
// .collect::<HashMap<_, _>>(),
// )?;
// // find type vars within this return type annotation
// let type_vars_within =
// get_type_var_contained_in_type_annotation(&parsed_annotation);
// // handle the class type var and the method type var
// for type_var_within in type_vars_within {
// let TypeAnnotation::TypeVar(t) = type_var_within else {
// unreachable!("must be type var annotation")
// };
// if !class_type_vars_def.contains(&t){
// return Err(HashSet::from([
// format!(
// "class fields can only use type \
// vars over which the class is generic (at {})",
// annotation.location
// ),
// ]))
// }
// }
// type_var_to_concrete_def.insert(dummy_field_type, parsed_annotation);
let mut analyze = |variable_def: &Arc<RwLock<TopLevelDef>>| -> Result<_, HashSet<String>> { let mut analyze = |variable_def: &Arc<RwLock<TopLevelDef>>| -> Result<_, HashSet<String>> {
let TopLevelDef::Variable { ty: dummy_ty, ty_decl, resolver, loc, .. } = let TopLevelDef::Variable { ty: dummy_ty, ty_decl, resolver, loc, .. } =
&*variable_def.read() &*variable_def.read()
@ -2008,7 +2066,7 @@ impl TopLevelComposer {
slice slice
} }
_ if self.core_config.kernel_ann.is_none() => ty_decl, _ if self.core_config.kernel_ann.is_none() => ty_decl,
_ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise _ => unreachable!("Global variables should be annotated with Kernel[]") // ignore fields annotated otherwise
}; };
let ty_annotation = parse_ast_to_type_annotation_kinds( let ty_annotation = parse_ast_to_type_annotation_kinds(

View File

@ -111,8 +111,6 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunNpAny,
FunNpAll,
// Linalg functions // Linalg functions
FunNpDot, FunNpDot,
@ -307,8 +305,6 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpAny => fun("np_any", None),
PrimDef::FunNpAll => fun("np_all", None),
// Linalg functions // Linalg functions
PrimDef::FunNpDot => fun("np_dot", None), PrimDef::FunNpDot => fun("np_dot", None),
@ -379,37 +375,21 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef
impl TopLevelDef { impl TopLevelDef {
pub fn to_string(&self, unifier: &mut Unifier) -> String { pub fn to_string(&self, unifier: &mut Unifier) -> String {
match self { match self {
TopLevelDef::Module { name, attributes, methods, .. } => { TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => {
format!(
"Module {{\nname: {:?},\nattributes: {:?}\nmethods: {:?}\n}}",
name,
attributes.iter().map(|(n, _)| n.to_string()).collect_vec(),
methods.iter().map(|(n, _)| n.to_string()).collect_vec()
)
}
TopLevelDef::Class {
name, ancestors, fields, methods, attributes, type_vars, ..
} => {
let fields_str = fields let fields_str = fields
.iter() .iter()
.map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty)))
.collect_vec(); .collect_vec();
let attributes_str = attributes
.iter()
.map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty)))
.collect_vec();
let methods_str = methods let methods_str = methods
.iter() .iter()
.map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id)) .map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id))
.collect_vec(); .collect_vec();
format!( format!(
"Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nattributes: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}",
name, name,
ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(), ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(),
fields_str.iter().map(|(a, _)| a).collect_vec(), fields_str.iter().map(|(a, _)| a).collect_vec(),
attributes_str.iter().map(|(a, _)| a).collect_vec(),
methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(), methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(),
type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(), type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(),
) )

View File

@ -92,20 +92,6 @@ pub struct FunInstance {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum TopLevelDef { pub enum TopLevelDef {
Module {
/// Name of the module
name: StrRef,
/// Module ID used for [`TypeEnum`]
module_id: DefinitionId,
/// `DefinitionId` of `TopLevelDef::{Class, Function}` within the module
methods: HashMap<StrRef, DefinitionId>,
/// `DefinitionId` of `TopLevelDef::{Variable}` within the module
attributes: Vec<(StrRef, DefinitionId)>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location.
loc: Option<Location>,
},
Class { Class {
/// Name for error messages and symbols. /// Name for error messages and symbols.
name: StrRef, name: StrRef,

View File

@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
] ]

View File

@ -3,13 +3,13 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
] ]

View File

@ -4,10 +4,10 @@ expression: res_vec
--- ---
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
] ]

View File

@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",

View File

@ -3,14 +3,14 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n",

View File

@ -1,7 +1,9 @@
--- ---
source: nac3core/src/toplevel/test.rs source: nac3core/src/toplevel/test.rs
assertion_line: 549
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nattributes: [],\nmethods: [],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n",
] ]

View File

@ -2008,8 +2008,7 @@ impl Inferencer<'_> {
ctx: ExprContext, ctx: ExprContext,
) -> InferenceResult { ) -> InferenceResult {
let ty = value.custom.unwrap(); let ty = value.custom.unwrap();
match &*self.unifier.get_ty(ty) { if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, fields, .. } => {
// just a fast path // just a fast path
match (fields.get(&attr), ctx == ExprContext::Store) { match (fields.get(&attr), ctx == ExprContext::Store) {
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
@ -2047,8 +2046,7 @@ impl Inferencer<'_> {
} }
} }
} }
} } else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
TypeEnum::TFunc(sign) => {
// Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1
let result = { let result = {
self.top_level.definitions.read().iter().find_map(|def| { self.top_level.definitions.read().iter().find_map(|def| {
@ -2069,29 +2067,13 @@ impl Inferencer<'_> {
}; };
match result { match result {
Some(f) if ctx != ExprContext::Store => Ok(f), Some(f) if ctx != ExprContext::Store => Ok(f),
Some(_) => report_error( Some(_) => {
&format!("Class Attribute `{attr}` is immutable"), report_error(&format!("Class Attribute `{attr}` is immutable"), value.location)
value.location, }
),
None => self.infer_general_attribute(value, attr, ctx), None => self.infer_general_attribute(value, attr, ctx),
} }
} } else {
TypeEnum::TModule { attributes, .. } => { self.infer_general_attribute(value, attr, ctx)
match (attributes.get(&attr), ctx == ExprContext::Load) {
(Some((ty, _)), true) | (Some((ty, false)), false) => Ok(*ty),
(Some((ty, true)), false) => report_type_error(
TypeErrorKind::MutationError(RecordKey::Str(attr), *ty),
Some(value.location),
self.unifier,
),
(None, _) => report_type_error(
TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty),
Some(value.location),
self.unifier,
),
}
}
_ => self.infer_general_attribute(value, attr, ctx),
} }
} }
@ -2752,7 +2734,7 @@ impl Inferencer<'_> {
.read() .read()
.iter() .iter()
.map(|def| match *def.read() { .map(|def| match *def.read() {
TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => (name, false), TopLevelDef::Class { name, .. } => (name, false),
TopLevelDef::Function { simple_name, .. } => (simple_name, false), TopLevelDef::Function { simple_name, .. } => (simple_name, false),
TopLevelDef::Variable { simple_name, .. } => (simple_name, true), TopLevelDef::Variable { simple_name, .. } => (simple_name, true),
}) })

View File

@ -270,19 +270,6 @@ pub enum TypeEnum {
/// A function type. /// A function type.
TFunc(FunSignature), TFunc(FunSignature),
/// Module Type
TModule {
/// The [`DefinitionId`] of this object type.
module_id: DefinitionId,
/// The attributes present in this object type.
///
/// The key of the [Mapping] is the identifier of the field, while the value is a tuple
/// containing the [Type] of the field, and a `bool` indicating whether the field is a
/// variable (as opposed to a function).
attributes: Mapping<StrRef, (Type, bool)>,
},
} }
impl TypeEnum { impl TypeEnum {
@ -297,7 +284,6 @@ impl TypeEnum {
TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TVirtual { .. } => "TVirtual",
TypeEnum::TCall { .. } => "TCall", TypeEnum::TCall { .. } => "TCall",
TypeEnum::TFunc { .. } => "TFunc", TypeEnum::TFunc { .. } => "TFunc",
TypeEnum::TModule { .. } => "TModule",
} }
} }
} }
@ -607,8 +593,7 @@ impl Unifier {
| TLiteral { .. } | TLiteral { .. }
// functions are instantiated for each call sites, so the function type can contain // functions are instantiated for each call sites, so the function type can contain
// type variables. // type variables.
| TFunc { .. } | TFunc { .. } => true,
| TModule { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
@ -1330,12 +1315,10 @@ impl Unifier {
|| format!("{id}"), || format!("{id}"),
|top_level| { |top_level| {
let top_level_def = &top_level.definitions.read()[id]; let top_level_def = &top_level.definitions.read()[id];
let top_level_def = top_level_def.read(); let TopLevelDef::Class { name, .. } = &*top_level_def.read() else {
let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) = unreachable!("expected class definition")
&*top_level_def
else {
unreachable!("expected module/class definition")
}; };
name.to_string() name.to_string()
}, },
) )
@ -1463,10 +1446,6 @@ impl Unifier {
let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes); let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes);
format!("fn[[{params}], {ret}]") format!("fn[[{params}], {ret}]")
} }
TypeEnum::TModule { module_id, .. } => {
let name = obj_to_name(module_id.0);
name.to_string()
}
} }
} }
@ -1542,9 +1521,7 @@ impl Unifier {
// variables, i.e. things like TRecord, TCall should not occur, and we // variables, i.e. things like TRecord, TCall should not occur, and we
// should be safe to not implement the substitution for those variants. // should be safe to not implement the substitution for those variants.
match &*ty { match &*ty {
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } | TypeEnum::TModule { .. } => { TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
None
}
TypeEnum::TVar { id, .. } => mapping.get(id).copied(), TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
TypeEnum::TTuple { ty, is_vararg_ctx } => { TypeEnum::TTuple { ty, is_vararg_ctx } => {
let mut new_ty = Cow::from(ty); let mut new_ty = Cow::from(ty);

View File

@ -232,8 +232,6 @@ def patch(module):
module.np_ldexp = np.ldexp module.np_ldexp = np.ldexp
module.np_hypot = np.hypot module.np_hypot = np.hypot
module.np_nextafter = np.nextafter module.np_nextafter = np.nextafter
module.np_any = np.any
module.np_all = np.all
# SciPy Math functions # SciPy Math functions
module.sp_spec_erf = special.erf module.sp_spec_erf = special.erf

View File

@ -1,35 +0,0 @@
@extern
def output_int32(x: int32):
...
@extern
def output_strln(x: str):
...
class A:
a: int32 = 1
b: int32
c: str = "test"
d: str
def __init__(self):
self.b = 2
self.d = "test"
output_int32(self.a) # Attributes can be accessed within class
def run() -> int32:
output_int32(A.a) # Attributes can be directly accessed with class name
# A.b # Only attributes can be accessed in this way
# A.a = 2 # Attributes are immutable
obj = A()
output_int32(obj.a) # Attributes can be accessed by class objects
output_strln(obj.c)
output_strln(obj.d)
return 0

View File

@ -1551,59 +1551,6 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_any():
s0 = 0
output_bool(np_any(s0))
s1 = 1
output_bool(np_any(s1))
x1 = np_identity(5)
y1 = np_any(x1)
output_ndarray_float_2(x1)
output_bool(y1)
x2 = np_identity(1)
y2 = np_any(x2)
output_ndarray_float_2(x2)
output_bool(y2)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_any(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x4 = np_zeros([3, 5])
y4 = np_any(x4)
output_ndarray_float_2(x4)
output_bool(y4)
def test_ndarray_all():
s0 = 0
output_bool(np_all(s0))
s1 = 1
output_bool(np_all(s1))
x1 = np_identity(5)
y1 = np_all(x1)
output_ndarray_float_2(x1)
output_bool(y1)
x2 = np_identity(1)
y2 = np_all(x2)
output_ndarray_float_2(x2)
output_bool(y2)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_all(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x4 = np_zeros([3, 5])
y4 = np_all(x4)
output_ndarray_float_2(x4)
output_bool(y4)
def test_ndarray_dot(): def test_ndarray_dot():
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0]) x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0]) y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
@ -1904,9 +1851,6 @@ def run() -> int32:
test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_any()
test_ndarray_all()
test_ndarray_dot() test_ndarray_dot()
test_ndarray_cholesky() test_ndarray_cholesky()
test_ndarray_qr() test_ndarray_qr()

View File

@ -456,13 +456,7 @@ fn main() {
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let threads = (0..threads) let threads = (0..threads)
.map(|i| { .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t)))
Box::new(DefaultCodeGenerator::with_target_machine(
format!("module{i}"),
&context,
&target_machine,
))
})
.collect(); .collect();
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task); registry.add_task(task);

View File

@ -1,15 +1,15 @@
{ pkgs } : [ { pkgs } : [
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-19.1.6-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-19.1.4-1-any.pkg.tar.zst";
sha256 = "1gv6hbqvfgjzirpljql1shlchldmf5ww3rfsspg90pq1frnwavjl"; sha256 = "0frb5k16bbxdf8g379d16vl3qrh7n9pydn83gpfxpvwf3qlvnzyl";
name = "mingw-w64-clang-x86_64-libunwind-19.1.6-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libunwind-19.1.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-19.1.6-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-19.1.4-1-any.pkg.tar.zst";
sha256 = "1wbkvrx14ahc04cgkydvlxwmsl8jfnqwhy9sy4kn4wkdzmlcp1ax"; sha256 = "0wh5km0v8j50pqz9bxb4f0w7r8zhsvssrjvc94np53iq8wjagk86";
name = "mingw-w64-clang-x86_64-libc++-19.1.6-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libc++-19.1.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -19,15 +19,15 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.18-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst";
sha256 = "0vn5xgx9jjg66f8r9ylm9220qdbjdkffykfl6nwj14zv9y7xh4nj"; sha256 = "1g2bkhgf60dywccxw911ydyigf3m25yqfh81m5099swr7mjsmzyf";
name = "mingw-w64-clang-x86_64-libiconv-1.18-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-runtime-0.23.1-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst";
sha256 = "0wbp5pmrr0rk4mx7d1frvqlk4a061zw31zscs57srmvl0wv3pi2a"; sha256 = "0ll6ci6d3mc7g04q0xixjc209bh8r874dqbczgns69jsad3wg6mi";
name = "mingw-w64-clang-x86_64-gettext-runtime-0.23.1-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -55,69 +55,69 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-19.1.6-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-19.1.4-1-any.pkg.tar.zst";
sha256 = "0fpsnfyf0bg39a4ygzga06sr4wv4jp1jnc8lk6sr3z0nim0nlhjn"; sha256 = "1clrbm8dk893byj8s15pgcgqqijm2zkd10zgyakamd8m354kj9q4";
name = "mingw-w64-clang-x86_64-llvm-libs-19.1.6-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-llvm-libs-19.1.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-19.1.6-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-19.1.4-1-any.pkg.tar.zst";
sha256 = "0whqs9nvfmgxj3c83px6dipcdw9zi858kgd8130201fy1mbnafp1"; sha256 = "1iz2c9475h8p20ydpp0znbhyb62rlrk7wr7xl7cmwbam7wkwr8rn";
name = "mingw-w64-clang-x86_64-llvm-19.1.6-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-llvm-19.1.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-libs-19.1.6-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-libs-19.1.4-1-any.pkg.tar.zst";
sha256 = "0rmzri7h043i73jy3c2jcrg3hy40dr5s9n96kmxgaghfhvlpilps"; sha256 = "1hidciwlakxrp4kyb0j2v6g4lv76nn834g6b88w1j94fk3qc765d";
name = "mingw-w64-clang-x86_64-clang-libs-19.1.6-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-clang-libs-19.1.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-19.1.6-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-19.1.4-1-any.pkg.tar.zst";
sha256 = "04cqlh35asvlh06nmhwnx9h0yrqk8zxd9lpzxmm1xh64kvm9maxn"; sha256 = "1m1yhjkgzlbk10sv966qk4yji009ga0lr25gpgj2w7mcd2wixcr3";
name = "mingw-w64-clang-x86_64-compiler-rt-19.1.6-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-compiler-rt-19.1.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst";
sha256 = "05zsqgq8zwdcfacyqdxdjcf80447bgnrz71xv5cds0y135yziy7l"; sha256 = "08gxc7h2achckknn6fz3p6yi7gxxvbaday8fpm4j56c4sa04n0df";
name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst";
sha256 = "12fkxpk7rwy36snvvc7sdivx81pd4ckzh5ilyh7gl6ly4qayppp6"; sha256 = "0fxd1pb197ki0gzw6z8gmd6wgpd9d28js6cp5d31d55kw7d1vz13";
name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-19.1.6-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-19.1.4-1-any.pkg.tar.zst";
sha256 = "102bbv5acq1fvrfn8bp1x3503cb8hvcxmlpr86qsba4vm11l0wrw"; sha256 = "1a8pjyhrzpc2z3784xxwix4i7yrz03ygnsk1wv9k0yq8m8wi9nbw";
name = "mingw-w64-clang-x86_64-lld-19.1.6-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-lld-19.1.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst";
sha256 = "1sris0qczxk5px9xy85976hbmqrpg49ns7yyzd9p455ckf740cid"; sha256 = "140m312jx1sywqjkvfij69d268m4jpdmilq5bb8khkf0ayb16036";
name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst";
sha256 = "1r0m5xpsxdl00a2daj4p0wgl6037700pvw6p6zl91h1dr092r6pa"; sha256 = "017j4h511wg37bacym73f8g6s0jcfgzbzabzxpc6anr3gy4kkpbg";
name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r473.gce0d0bfb7-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r423.g8bcd5fc1a-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-19.1.6-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-19.1.4-1-any.pkg.tar.zst";
sha256 = "0j4a642fpnvqs79chhinc8r5q53q1wllmc1bzb01a4y7w9rqg4hw"; sha256 = "11f4i4ai2bzvq6f06vxk1ymv7056c9707vdw489f1i2bdrf0c0ii";
name = "mingw-w64-clang-x86_64-clang-19.1.6-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-clang-19.1.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.84.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.83.0-3-any.pkg.tar.zst";
sha256 = "0nrz9788grl50nkbhxswry143rrwpdnc6pk6f0k30kcp19qq6y2d"; sha256 = "0nxs571vb4f1i5vp91134p5blns9ml2r25nx6kdlg0zhd5x85kvm";
name = "mingw-w64-clang-x86_64-rust-1.84.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-rust-1.83.0-3-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -127,9 +127,9 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.34.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.34.3-1-any.pkg.tar.zst";
sha256 = "1dppwwx3wrn0lzrlk2q7bpsainbidrpw1ndp1aasyv42xhxl1sn1"; sha256 = "1mpn397qsdz3l2fav6ymwjlj96ialn9m8sldii3ymbcyhranl3xx";
name = "mingw-w64-clang-x86_64-c-ares-1.34.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-c-ares-1.34.3-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -139,9 +139,9 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.3-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst";
sha256 = "1zg58qbfybyqzcj0dalb13l48f9jsras318h02rka65r7wi0pdcg"; sha256 = "13nz49li39z1zgfx1q9jg4vrmyrmqb6qdq0nqshidaqc6zr16k3g";
name = "mingw-w64-clang-x86_64-libunistring-1.3-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -169,9 +169,9 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20241223-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst";
sha256 = "0c36lg63imzw8i6j1ard42v5wgzpc83phzk8lvifvm0djndq2bbj"; sha256 = "1q5nxhsk04gidz66ai5wgd4dr04lfyakkfja9p0r5hrgg4ppqqjg";
name = "mingw-w64-clang-x86_64-ca-certificates-20241223-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -193,9 +193,9 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.7.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.6.0-1-any.pkg.tar.zst";
sha256 = "0kd2f7yh90815kyldxvdy8c6jyxyw0wv4f7k3shwp98w874m0mxd"; sha256 = "1p7q47fin12vzyf126v1azbbpgpa0y6ighfh6mbfdb6zcyq74kbd";
name = "mingw-w64-clang-x86_64-nghttp3-1.7.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-nghttp3-1.6.0-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -271,15 +271,15 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.5-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst";
sha256 = "0gdn1351knjwgsqgyaa3l55qs135k7dn6mlf04vzjxlc1895wx5z"; sha256 = "1ysbxirpfr0yf7pvyps75lnwc897w2a2kcid3nb4j6ilw6n64jmc";
name = "mingw-w64-clang-x86_64-rhash-1.4.5-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.31.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.31.2-3-any.pkg.tar.zst";
sha256 = "1xjjwgkqf2j97pcx0yd6j0lgmzgbgqjjf0s7j29mc03g89fhdhw0"; sha256 = "139f91r392c68hsajm0c81690pmzkywb0p4x8ms8ms53ncxnz6gz";
name = "mingw-w64-clang-x86_64-cmake-3.31.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-cmake-3.31.2-3-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -289,9 +289,9 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.5.20241228-3-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.5.20240831-1-any.pkg.tar.zst";
sha256 = "0f98pzrwsxil90n55hz2ym2x2rzrrjrmnj8i2203n189qbxbg2c9"; sha256 = "1hlfj9g4s767s502sawwbcv4a0xd3ym3ip4jswmhq48wh5050iyb";
name = "mingw-w64-clang-x86_64-ncurses-6.5.20241228-3-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-ncurses-6.5.20240831-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -331,32 +331,32 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.12.8-2-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.12.7-3-any.pkg.tar.zst";
sha256 = "0lksgrmylvpr7yyjcc1szm30pnag7ixrj7vhdql1ryi4k9309v8s"; sha256 = "1v15j2pzy9wj4n1rjngdi2hf8h0l9z4lri3xb86yvdv1xl2msj6h";
name = "mingw-w64-clang-x86_64-python-3.12.8-2-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-python-3.12.7-3-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-openmp-19.1.6-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-openmp-19.1.4-2-any.pkg.tar.zst";
sha256 = "0d3mm26hnw716n0ppzqhydxcgm4im081hiiy6l4zp267ad3kfg93"; sha256 = "1pn1fbj74rx837s9z8gqs4b0cr7kqi5m1m2mi9ibjpw64m1aqwxv";
name = "mingw-w64-clang-x86_64-llvm-openmp-19.1.6-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-llvm-openmp-19.1.4-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.29-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.28-2-any.pkg.tar.zst";
sha256 = "006f2s12jmk35rppkp20rlm7k4kknsnh5h4krqs2ry2rd6qqkk9h"; sha256 = "18p1zhf7h3k3phf3bl483jg3k7y9zq375z6ww75g62158ic9lfyc";
name = "mingw-w64-clang-x86_64-openblas-0.3.29-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-openblas-0.3.28-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-2.2.1-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-2.1.1-2-any.pkg.tar.zst";
sha256 = "0sgkhax9cwmkkrfrir45l91h6pgg339gaw6147gsayf8h8ag4brg"; sha256 = "1kiy7ail04ias47xbbhl9vpsz02g0g3f29ncgx5gcks9vgqldp6m";
name = "mingw-w64-clang-x86_64-python-numpy-2.2.1-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-python-numpy-2.1.1-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-75.8.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-75.6.0-1-any.pkg.tar.zst";
sha256 = "12ivpaj967y4bi8396q3fpii4fy5aakidxpv16rkyg1b831k0h93"; sha256 = "03l04kjmy5p9whaw0h619gdg7yw1gxbz8phifq4pzh3c1wlw7yfd";
name = "mingw-w64-clang-x86_64-python-setuptools-75.8.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-python-setuptools-75.6.0-1-any.pkg.tar.zst";
}) })
] ]