Compare commits

..

2 Commits

Author SHA1 Message Date
David Mak affa19e88d flake: Remove standalone execution of test cases
This is now executed as part of cargo test.
2023-12-12 19:56:05 +08:00
David Mak ea121673d3 standalone: Add cargo test cases for demos 2023-12-12 19:56:05 +08:00
139 changed files with 9249 additions and 29283 deletions

View File

@ -1,32 +0,0 @@
BasedOnStyle: LLVM
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -1
AlignEscapedNewlines: Left
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: Yes
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Inline
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakInheritanceList: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ContinuationIndentWidth: 4
DerivePointerAlignment: false
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
MaxEmptyLinesToKeep: 1
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SortUsingDeclarations: true
SpaceAfterTemplateKeyword: false
SpacesBeforeTrailingComments: 2
TabWidth: 4
UseTab: Never

View File

@ -1 +0,0 @@
doc-valid-idents = ["CPython", "NumPy", ".."]

1
.gitignore vendored
View File

@ -1,4 +1,3 @@
__pycache__
/target
/nac3standalone/demo/linalg/target
nix/windows/msys2

View File

@ -1,24 +0,0 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_stages: [pre-commit]
repos:
- repo: local
hooks:
- id: nac3-cargo-fmt
name: nac3 cargo format
entry: nix
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo fmt on the codebase.
args: [develop, -c, cargo, fmt, --all]
- id: nac3-cargo-clippy
name: nac3 cargo clippy
entry: nix
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo clippy on the codebase.
args: [develop, -c, cargo, clippy, --tests]

1033
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,6 @@ members = [
"nac3ast",
"nac3parser",
"nac3core",
"nac3core/nac3core_derive",
"nac3standalone",
"nac3artiq",
"runkernel",

View File

@ -51,12 +51,3 @@ Use ``nix develop`` in this repository to enter a development shell.
If you are using a different shell than bash you can use e.g. ``nix develop --command fish``.
Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``.
### Pre-Commit Hooks
You are strongly recommended to use the provided pre-commit hooks to automatically reformat files and check for non-optimal Rust practices using Clippy. Run `pre-commit install` to install the hook and `pre-commit` will automatically run `cargo fmt` and `cargo clippy` for you.
Several things to note:
- If `cargo fmt` or `cargo clippy` returns an error, the pre-commit hook will fail. You should fix all errors before trying to commit again.
- If `cargo fmt` reformats some files, the pre-commit hook will also fail. You should review the changes and, if satisfied, try to commit again.

View File

@ -2,16 +2,16 @@
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1731319897,
"narHash": "sha256-PbABj4tnbWFMfBp6OcUK5iGy1QY+/Z96ZcLpooIbuEI=",
"lastModified": 1701389149,
"narHash": "sha256-rU1suTIEd5DGCaAXKW6yHoCfR1mnYjOXQFOaH7M23js=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "dc460ec76cbff0e66e269457d7b728432263166c",
"rev": "5de0b32be6e85dc1a9404c75131316e4ffbc634c",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"ref": "nixos-23.11",
"repo": "nixpkgs",
"type": "github"
}

View File

@ -1,12 +1,11 @@
{
description = "The third-generation ARTIQ compiler";
inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-unstable;
inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-23.11;
outputs = { self, nixpkgs }:
let
pkgs = import nixpkgs { system = "x86_64-linux"; };
pkgs32 = import nixpkgs { system = "i686-linux"; };
in rec {
packages.x86_64-linux = rec {
llvm-nac3 = pkgs.callPackage ./nix/llvm {};
@ -16,22 +15,6 @@
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
'';
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
name = "demo-linalg-stub";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
name = "demo-linalg-stub32";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
nac3artiq = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage rec {
name = "nac3artiq";
@ -41,19 +24,15 @@
lockFile = ./Cargo.lock;
};
passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase =
''
echo "Checking nac3standalone demos..."
echo "Running Cargo tests..."
pushd nac3standalone/demo
patchShebangs .
export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
./check_demos.sh -i686
popd
echo "Running Cargo tests..."
cargoCheckHook
'';
installPhase =
@ -107,18 +86,18 @@
(pkgs.fetchFromGitHub {
owner = "m-labs";
repo = "sipyco";
rev = "094a6cd63ffa980ef63698920170e50dc9ba77fd";
sha256 = "sha256-PPnAyDedUQ7Og/Cby9x5OT9wMkNGTP8GS53V6N/dk4w=";
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e";
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc=";
})
(pkgs.fetchFromGitHub {
owner = "m-labs";
repo = "artiq";
rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6";
sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak=";
rev = "8b4572f9cad34ac0c2b6f6bba9382e7b59b2f93b";
sha256 = "sha256-O/0sUSxxXU1AL9cmT9qdzCkzdOKREBNftz22/8ouQcc=";
})
];
buildInputs = [
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ]))
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ]))
pkgs.llvmPackages_14.llvm.out
];
phases = [ "buildPhase" "installPhase" ];
@ -168,7 +147,7 @@
buildInputs = with pkgs; [
# build dependencies
packages.x86_64-linux.llvm-nac3
(pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos
llvmPackages_14.clang # demo
packages.x86_64-linux.llvm-tools-irrt
cargo
rustc
@ -178,14 +157,8 @@
# development tools
cargo-insta
clippy
pre-commit
rustfmt
];
shellHook =
''
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
'';
};
devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2";

View File

@ -9,13 +9,18 @@ name = "nac3artiq"
crate-type = ["cdylib"]
[dependencies]
itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
itertools = "0.12"
pyo3 = { version = "0.20", features = ["extension-module"] }
parking_lot = "0.12"
tempfile = "3.13"
tempfile = "3.8"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" }
[dependencies.inkwell]
version = "0.2"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[features]
init-llvm-profile = []
no-escape-analysis = ["nac3core/no-escape-analysis"]

View File

@ -1,24 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

View File

@ -112,15 +112,10 @@ def extern(function):
register_function(function)
return function
def rpc(arg=None, flags={}):
"""Decorates a function or method to be executed on the host interpreter."""
if arg is None:
def inner_decorator(function):
return rpc(function, flags)
return inner_decorator
register_function(arg)
return arg
def rpc(function):
"""Decorates a function declaration defined by the core device runtime."""
register_function(function)
return function
def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device."""
@ -206,7 +201,7 @@ class Core:
embedding = EmbeddingMap()
if allow_registration:
compiler.analyze(registered_functions, registered_classes, set())
compiler.analyze(registered_functions, registered_classes)
allow_registration = False
if hasattr(method, "__self__"):

View File

@ -1,26 +0,0 @@
from min_artiq import *
from numpy import ndarray, zeros as np_zeros
@nac3
class StrFail:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def hello(self, arg: str):
pass
@kernel
def consume_ndarray(self, arg: ndarray[str, 1]):
pass
def run(self):
self.hello("world")
self.consume_ndarray(np_zeros([10], dtype=str))
if __name__ == "__main__":
StrFail().run()

View File

@ -1,24 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
core: KernelInvariant[Core]
attr1: KernelInvariant[str]
attr2: KernelInvariant[int32]
def __init__(self):
self.core = Core()
self.attr2 = 32
self.attr1 = "SAMPLE"
@kernel
def run(self):
print_int32(self.attr2)
self.attr1
if __name__ == "__main__":
Demo().run()

View File

@ -1,40 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
attr2: int32 = 4
attr3: Kernel[int32]
@kernel
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
attr4: KernelInvariant[int32] = 16
def __init__(self):
self.core = Core()
@kernel
def run(self):
Demo.attr1 # Supported
# Demo.attr2 # Field not accessible on Kernel
# Demo.attr3 # Only attributes can be accessed in this way
# Demo.attr1 = 2 # Attributes are immutable
self.attr4 # Attributes can be accessed within class
obj = Demo()
obj.attr1 # Attributes can be accessed by class objects
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

File diff suppressed because it is too large Load Diff

View File

@ -1,74 +1,60 @@
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
#![warn(clippy::pedantic)]
#![allow(
unsafe_op_in_unsafe_fn,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io::Write;
use std::process::Command;
use std::rc::Rc;
use std::sync::Arc;
use std::{
collections::{HashMap, HashSet},
fs,
io::Write,
process::Command,
rc::Rc,
sync::Arc,
};
use itertools::Itertools;
use parking_lot::{Mutex, RwLock};
use pyo3::{
create_exception, exceptions,
prelude::*,
types::{PyBytes, PyDict, PyNone, PySet},
};
use tempfile::{self, TempDir};
use nac3core::{
codegen::{
concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, WithCall, WorkerRegistry,
},
inkwell::{
context::Context,
use inkwell::{
memory_buffer::MemoryBuffer,
module::{FlagBehavior, Linkage, Module},
module::{Linkage, Module},
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel,
},
nac3parser::{
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
};
use itertools::Itertools;
use nac3core::codegen::{CodeGenLLVMOptions, CodeGenTargetMachineOptions, gen_func_impl};
use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{TypeEnum, Unifier};
use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program,
},
};
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use pyo3::create_exception;
use parking_lot::{Mutex, RwLock};
use nac3core::{
codegen::irrt::load_irrt,
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
symbol_resolver::SymbolResolver,
toplevel::{
builtins::get_exn_constructor,
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef,
},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
typecheck::typedef::{FunSignature, FuncArg},
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
};
use nac3ld::Linker;
use codegen::{
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
use tempfile::{self, TempDir};
use crate::codegen::attributes_writeback;
use crate::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore},
};
use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver};
use timeline::TimeFns;
mod codegen;
mod symbol_resolver;
mod timeline;
use timeline::TimeFns;
#[derive(PartialEq, Clone, Copy)]
enum Isa {
Host,
@ -77,17 +63,6 @@ enum Isa {
CortexA9,
}
impl Isa {
/// Returns the number of bits in `size_t` for the [`Isa`].
fn get_size_type(self) -> u32 {
if self == Isa::Host {
64u32
} else {
32u32
}
}
}
#[derive(Clone)]
pub struct PrimitivePythonId {
int: u64,
@ -98,11 +73,7 @@ pub struct PrimitivePythonId {
float: u64,
float64: u64,
bool: u64,
np_bool_: u64,
string: u64,
np_str_: u64,
list: u64,
ndarray: u64,
tuple: u64,
typevar: u64,
const_generic_marker: u64,
@ -122,7 +93,7 @@ struct Nac3 {
isa: Isa,
time_fns: &'static (dyn TimeFns + Sync),
primitive: PrimitiveStore,
builtins: Vec<BuiltinFuncSpec>,
builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>,
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
primitive_ids: PrimitivePythonId,
working_directory: TempDir,
@ -142,38 +113,22 @@ impl Nac3 {
module: &PyObject,
registered_class_ids: &HashSet<u64>,
) -> PyResult<()> {
let (module_name, source_file, source) =
Python::with_gil(|py| -> PyResult<(String, String, String)> {
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
let module: &PyAny = module.extract(py)?;
let source_file = module.getattr("__file__");
let (source_file, source) = if let Ok(source_file) = source_file {
let source_file = source_file.extract()?;
(
source_file,
fs::read_to_string(source_file).map_err(|e| {
exceptions::PyIOError::new_err(format!(
"failed to read input file: {e}"
))
})?,
)
} else {
// kernels submitted by content have no file
// but still can provide source by StringLoader
let get_src_fn = module
.getattr("__loader__")?
.extract::<PyObject>()?
.getattr(py, "get_source")?;
("<expcontent>", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?)
};
Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source))
Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?))
})?;
let source = fs::read_to_string(&source_file).map_err(|e| {
exceptions::PyIOError::new_err(format!("failed to read input file: {e}"))
})?;
let parser_result = parse_program(&source, source_file.into())
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
for mut stmt in parser_result {
let include = match stmt.node {
StmtKind::ClassDef { ref decorator_list, ref mut body, ref mut bases, .. } => {
StmtKind::ClassDef {
ref decorator_list, ref mut body, ref mut bases, ..
} => {
let nac3_class = decorator_list.iter().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "nac3"
@ -193,8 +148,7 @@ impl Nac3 {
if *id == "Exception".into() {
Ok(true)
} else {
let base_obj =
module.getattr(py, id.to_string().as_str())?;
let base_obj = module.getattr(py, id.to_string().as_str())?;
let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id))
}
@ -207,8 +161,10 @@ impl Nac3 {
body.retain(|stmt| {
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) {
id == "kernel" || id == "portable" || id == "rpc"
if let ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "kernel"
|| id.to_string() == "portable"
|| id.to_string() == "rpc"
} else {
false
}
@ -221,8 +177,9 @@ impl Nac3 {
}
StmtKind::FunctionDef { ref decorator_list, .. } => {
decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) {
id == "extern" || id == "kernel" || id == "portable" || id == "rpc"
if let ExprKind::Name { id, .. } = decorator.node {
let id = id.to_string();
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
} else {
false
}
@ -275,7 +232,7 @@ impl Nac3 {
arg_names.len(),
));
}
for (i, FuncArg { ty, default_value, name, .. }) in args.iter().enumerate() {
for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() {
let in_name = match arg_names.get(i) {
Some(n) => n,
None if default_value.is_none() => {
@ -311,64 +268,6 @@ impl Nac3 {
None
}
/// Returns a [`Vec`] of builtins that needs to be initialized during method compilation time.
fn get_lateinit_builtins() -> Vec<Box<BuiltinFuncCreator>> {
vec![
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"core_log".into(),
FunSignature {
args: vec![FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
}],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_core_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"rtio_log".into(),
FunSignature {
args: vec![
FuncArg {
name: "channel".into(),
ty: primitives.str,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
},
],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_rtio_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
]
}
fn compile_method<T>(
&self,
obj: &PyAny,
@ -378,12 +277,9 @@ impl Nac3 {
py: Python,
link_fn: &dyn Fn(&Module) -> PyResult<T>,
) -> PyResult<T> {
let size_t = self.isa.get_size_type();
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
self.builtins.clone(),
Self::get_lateinit_builtins(),
ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
size_t,
);
let builtins = PyModule::import(py, "builtins")?;
@ -431,9 +327,8 @@ impl Nac3 {
let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node {
let class = py_module.getattr(name.to_string().as_str()).unwrap();
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap()
&& class.getattr("artiq_builtin").is_err()
{
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap() &&
class.getattr("artiq_builtin").is_err() {
class_obj = Some(class);
} else {
class_obj = None;
@ -458,6 +353,7 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
name_to_pyid: name_to_pyid.clone(),
module: module.clone(),
id_to_pyval: RwLock::default(),
@ -478,35 +374,19 @@ impl Nac3 {
let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path, false)
.map_err(|e| {
CompileError::new_err(format!("compilation failed\n----------\n{e}"))
CompileError::new_err(format!(
"compilation failed\n----------\n{e}"
))
})?;
if let Some(class_obj) = class_obj {
self.exception_ids
.write()
.insert(def_id.0, store_obj.call1(py, (class_obj,))?.extract(py)?);
self.exception_ids.write().insert(def_id.0, store_obj.call1(py, (class_obj, ))?.extract(py)?);
}
match &stmt.node {
StmtKind::FunctionDef { decorator_list, .. } => {
if decorator_list
.iter()
.any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string()))
{
store_fun
.call1(
py,
(
def_id.0.into_py(py),
module.getattr(py, name.to_string().as_str()).unwrap(),
),
)
.unwrap();
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
rpc_ids.push((None, def_id, is_async));
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap();
rpc_ids.push((None, def_id));
}
}
StmtKind::ClassDef { name, body, .. } => {
@ -514,26 +394,19 @@ impl Nac3 {
let class_obj = module.getattr(py, class_name.as_str()).unwrap();
for stmt in body {
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
if decorator_list.iter().any(|decorator| {
decorator_id_string(decorator) == Some("rpc".to_string())
}) {
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
if name == &"__init__".into() {
return Err(CompileError::new_err(format!(
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
class_name, stmt.location
)));
}
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
}
}
}
}
_ => (),
_ => ()
}
let id = *name_to_pyid.get(&name).unwrap();
@ -572,6 +445,7 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
id_to_pyval: RwLock::default(),
id_to_primitive: RwLock::default(),
field_to_val: RwLock::default(),
@ -582,26 +456,17 @@ impl Nac3 {
exception_ids: self.exception_ids.clone(),
deferred_eval_store: self.deferred_eval_store.clone(),
});
let resolver =
Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let (_, def_id, _) = composer
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
.unwrap();
// Process IRRT
let context = Context::create();
let irrt = load_irrt(&context, resolver.as_ref());
let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
FunSignature { args: vec![], ret: self.primitive.none, vars: HashMap::new() };
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature = store.from_signature(
&mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature =
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache);
let signature = store.add_cty(signature);
if let Err(e) = composer.start_analysis(true) {
@ -620,21 +485,24 @@ impl Nac3 {
msg.unwrap_or(e.iter().sorted().join("\n----------\n"))
)))
} else {
Err(CompileError::new_err(format!(
Err(CompileError::new_err(
format!(
"compilation failed\n----------\n{}",
e.iter().sorted().join("\n----------\n"),
)))
};
),
))
}
}
let top_level = Arc::new(composer.make_top_level_context());
{
let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read();
for (class_data, id, is_async) in &rpc_ids {
for (class_data, id) in &rpc_ids {
let mut def = defs[id.0].write();
match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => {
*codegen_callback = Some(rpc_codegen_callback(*is_async));
*codegen_callback = Some(rpc_codegen.clone());
}
TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap();
@ -645,26 +513,19 @@ impl Nac3 {
if let TopLevelDef::Function { codegen_callback, .. } =
&mut *defs[id.0].write()
{
*codegen_callback = Some(rpc_codegen_callback(*is_async));
*codegen_callback = Some(rpc_codegen.clone());
store_fun
.call1(
py,
(
id.0.into_py(py),
class_def
.getattr(py, name.to_string().as_str())
.unwrap(),
class_def.getattr(py, name.to_string().as_str()).unwrap(),
),
)
.unwrap();
}
}
}
TopLevelDef::Variable { .. } => {
return Err(CompileError::new_err(String::from(
"Unsupported @rpc annotation on global variable",
)))
}
}
}
}
@ -673,8 +534,7 @@ impl Nac3 {
let defs = top_level.definitions.read();
let mut definition = defs[def_id.0].write();
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } =
&mut *definition
else {
&mut *definition else {
unreachable!()
};
@ -685,12 +545,29 @@ impl Nac3 {
let task = CodeGenTask {
subst: Vec::default(),
symbol_name: "__modinit__".to_string(),
body: instance.body,
signature,
resolver: resolver.clone(),
store,
unifier_index: instance.unifier_id,
calls: instance.calls,
id: 0,
};
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature =
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache);
let signature = store.add_cty(signature);
let attributes_writeback_task = CodeGenTask {
subst: Vec::default(),
symbol_name: "attributes_writeback".to_string(),
body: Arc::new(Vec::default()),
signature,
resolver,
store,
unifier_index: instance.unifier_id,
calls: instance.calls,
calls: Arc::new(HashMap::default()),
id: 0,
};
@ -703,9 +580,7 @@ impl Nac3 {
let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer);
})));
let size_t = context
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let size_t = if self.isa == Isa::Host { 64 } else { 32 };
let num_threads = if is_multithreaded() { 4 } else { 1 };
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names
@ -714,81 +589,49 @@ impl Nac3 {
.collect();
let membuffer = membuffers.clone();
let mut has_return = false;
py.allow_threads(|| {
let (registry, handles) =
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns);
let context = Context::create();
let module = context.create_module("main");
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag(
"Debug Info Version",
FlagBehavior::Warning,
context.i32_type().const_int(3, false),
let (registry, handles) = WorkerRegistry::create_workers(
threads,
top_level.clone(),
&self.llvm_options,
&f
);
module.add_basic_value_flag(
"Dwarf Version",
FlagBehavior::Warning,
context.i32_type().const_int(4, false),
);
let builder = context.create_builder();
let (_, module, _) = gen_func_impl(
&context,
&mut generator,
&registry,
builder,
module,
task,
|generator, ctx| {
assert_eq!(instance.body.len(), 1, "toplevel module should have 1 statement");
let StmtKind::Expr { value: ref expr, .. } = instance.body[0].node else {
unreachable!("toplevel statement must be an expression")
};
let ExprKind::Call { .. } = expr.node else {
unreachable!("toplevel expression must be a function call")
};
let return_obj =
generator.gen_expr(ctx, expr)?.map(|value| (expr.custom.unwrap(), value));
has_return = return_obj.is_some();
registry.add_task(task);
registry.wait_tasks_complete(handles);
attributes_writeback(
ctx,
generator,
inner_resolver.as_ref(),
&host_attributes,
return_obj,
)
},
)
.unwrap();
let mut generator = ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback");
let builder = context.create_builder();
let (_, module, _) = gen_func_impl(&context, &mut generator, &registry, builder, module,
attributes_writeback_task, |generator, ctx| {
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes)
}).unwrap();
let buffer = module.write_bitcode_to_memory();
let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer);
});
embedding_map.setattr("expects_return", has_return).unwrap();
// Link all modules into `main`.
let context = inkwell::context::Context::create();
let buffers = membuffers.lock();
let main = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(
buffers.last().unwrap(),
"main",
))
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
.unwrap();
for buffer in buffers.iter().rev().skip(1) {
for buffer in buffers.iter().skip(1) {
let other = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
.unwrap();
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
main.link_in_module(other)
.map_err(|err| CompileError::new_err(err.to_string()))?;
}
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
let builder = context.create_builder();
let modinit_return = main.get_function("__modinit__").unwrap().get_last_basic_block().unwrap().get_terminator().unwrap();
builder.position_before(&modinit_return);
builder.build_call(main.get_function("attributes_writeback").unwrap(), &[], "attributes_writeback");
main.link_in_module(load_irrt(&context))
.map_err(|err| CompileError::new_err(err.to_string()))?;
let mut function_iter = main.get_first_function();
while let Some(func) = function_iter {
@ -799,7 +642,10 @@ impl Nac3 {
}
// Demote all global variables that will not be referenced in the kernel to private
let preserved_symbols: Vec<&'static [u8]> = vec![b"typeinfo", b"now"];
let preserved_symbols: Vec<&'static [u8]> = vec![
b"typeinfo",
b"now",
];
let mut global_option = main.get_first_global();
while let Some(global) = global_option {
if !preserved_symbols.contains(&(global.get_name().to_bytes())) {
@ -808,9 +654,7 @@ impl Nac3 {
global_option = global.get_next_global();
}
let target_machine = self
.llvm_options
.target
let target_machine = self.llvm_options.target
.create_target_machine(self.llvm_options.opt_level)
.expect("couldn't create target machine");
@ -874,42 +718,10 @@ impl Nac3 {
}
}
/// Retrieves the Name.id from a decorator, supports decorators with arguments.
fn decorator_id_string(decorator: &Located<ExprKind>) -> Option<String> {
if let ExprKind::Name { id, .. } = decorator.node {
// Bare decorator
return Some(id.to_string());
} else if let ExprKind::Call { func, .. } = &decorator.node {
// Decorators that are calls (e.g. "@rpc()") have Call for the node,
// need to extract the id from within.
if let ExprKind::Name { id, .. } = func.node {
return Some(id.to_string());
}
}
None
}
/// Retrieves flags from a decorator, if any.
fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
let mut flags = vec![];
if let ExprKind::Call { keywords, .. } = &decorator.node {
for keyword in keywords {
if keyword.node.arg != Some("flags".into()) {
continue;
}
if let ExprKind::Set { elts } = &keyword.node.value.node {
for elt in elts {
if let ExprKind::Constant { value, .. } = &elt.node {
flags.push(value.clone());
}
}
}
}
}
flags
}
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
fn link_with_lld(
elf_filename: String,
obj_filename: String,
) -> PyResult<()>{
let linker_args = vec![
"-shared".to_string(),
"--eh-frame-hdr".to_string(),
@ -928,7 +740,9 @@ fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
return Err(CompileError::new_err("failed to start linker"));
}
} else {
return Err(CompileError::new_err("linker returned non-zero status code"));
return Err(CompileError::new_err(
"linker returned non-zero status code",
));
}
Ok(())
@ -938,7 +752,7 @@ fn add_exceptions(
composer: &mut TopLevelComposer,
builtin_def: &mut HashMap<StrRef, DefinitionId>,
builtin_ty: &mut HashMap<StrRef, Type>,
error_names: &[&str],
error_names: &[&str]
) -> Vec<Type> {
let mut types = Vec::new();
// note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}"
@ -951,7 +765,7 @@ fn add_exceptions(
// constructor id
def_id + 1,
&mut composer.unifier,
&composer.primitives_ty,
&composer.primitives_ty
);
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None));
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None));
@ -978,11 +792,11 @@ impl Nac3 {
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
};
let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type());
let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0;
let builtins = vec![
(
"now_mu".into(),
FunSignature { args: vec![], ret: primitive.int64, vars: VarMap::new() },
FunSignature { args: vec![], ret: primitive.int64, vars: HashMap::new() },
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
Ok(Some(time_fns.emit_now_mu(ctx)))
}))),
@ -994,15 +808,13 @@ impl Nac3 {
name: "t".into(),
ty: primitive.int64,
default_value: None,
is_vararg: false,
}],
ret: primitive.none,
vars: VarMap::new(),
vars: HashMap::new(),
},
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_at_mu(ctx, arg);
Ok(None)
}))),
@ -1014,15 +826,13 @@ impl Nac3 {
name: "dt".into(),
ty: primitive.int64,
default_value: None,
is_vararg: false,
}],
ret: primitive.none,
vars: VarMap::new(),
vars: HashMap::new(),
},
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_delay_mu(ctx, arg);
Ok(None)
}))),
@ -1036,9 +846,8 @@ impl Nac3 {
let types_mod = PyModule::import(py, "types").unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap();
let get_attr_id = |obj: &PyModule, attr| {
id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap()
};
let get_attr_id = |obj: &PyModule, attr| id_fn.call1((obj.getattr(attr).unwrap(),))
.unwrap().extract().unwrap();
let primitive_ids = PrimitivePythonId {
virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()),
generic_alias: (
@ -1047,22 +856,16 @@ impl Nac3 {
),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()),
typevar: get_attr_id(typing_mod, "TypeVar"),
const_generic_marker: get_id(
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(),
),
const_generic_marker: get_id(artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap()),
int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"),
uint32: get_attr_id(numpy_mod, "uint32"),
uint64: get_attr_id(numpy_mod, "uint64"),
bool: get_attr_id(builtins_mod, "bool"),
np_bool_: get_attr_id(numpy_mod, "bool_"),
string: get_attr_id(builtins_mod, "str"),
np_str_: get_attr_id(numpy_mod, "str_"),
float: get_attr_id(builtins_mod, "float"),
float64: get_attr_id(numpy_mod, "float64"),
list: get_attr_id(builtins_mod, "list"),
ndarray: get_attr_id(numpy_mod, "ndarray"),
tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
@ -1086,16 +889,11 @@ impl Nac3 {
llvm_options: CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default,
target: Nac3::get_llvm_target_options(isa),
},
}
})
}
fn analyze(
&mut self,
functions: &PySet,
classes: &PySet,
content_modules: &PySet,
) -> PyResult<()> {
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
let (modules, class_ids) =
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
let mut modules: HashMap<u64, PyObject> = HashMap::new();
@ -1105,22 +903,14 @@ impl Nac3 {
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
for function in functions {
let module: PyObject = getmodule_fn.call1((function,))?.extract()?;
if !module.is_none(py) {
let module = getmodule_fn.call1((function,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
}
for class in classes {
let module: PyObject = getmodule_fn.call1((class,))?.extract()?;
if !module.is_none(py) {
let module = getmodule_fn.call1((class,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
class_ids.insert(id_fn.call1((class,))?.extract()?);
}
for module in content_modules {
let module: PyObject = module.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
Ok((modules, class_ids))
})?;
@ -1149,7 +939,7 @@ impl Nac3 {
.expect("couldn't write module to file");
link_with_lld(
filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string(),
working_directory.join("module.o").to_string_lossy().to_string()
)?;
Ok(())
};
@ -1197,7 +987,7 @@ impl Nac3 {
let filename = filename_path.to_str().unwrap();
link_with_lld(
filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string(),
working_directory.join("module.o").to_string_lossy().to_string()
)?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,9 @@
use itertools::Either;
use nac3core::{
codegen::CodeGenContext,
inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
},
};
use inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering};
use nac3core::codegen::CodeGenContext;
/// Functions for manipulating the timeline.
pub trait TimeFns {
/// Emits LLVM IR for `now_mu`.
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
@ -32,33 +26,32 @@ impl TimeFns for NowPinningTimeFns64 {
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_hiptr =
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr");
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}
.unwrap();
};
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = (
ctx.builder.build_load(now_hiptr, "now.hi"),
ctx.builder.build_load(now_loptr, "now.lo"),
) else {
unreachable!()
};
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "");
let shifted_hi = ctx.builder.build_left_shift(
zext_hi,
i64_type.const_int(32, false),
"",
);
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "");
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").into()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -66,100 +59,105 @@ impl TimeFns for NowPinningTimeFns64 {
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value();
let BasicValueEnum::IntValue(time) = t else {
unreachable!()
};
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi"),
i32_type,
"",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
);
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo");
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_hiptr = ctx.builder.build_bitcast(
now,
i32_type.ptr_type(AddressSpace::default()),
"now.hi.addr",
);
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}
.unwrap();
};
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_hiptr =
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr");
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}
.unwrap();
};
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let dt = dt.into_int_value();
let (
BasicValueEnum::IntValue(now_hi),
BasicValueEnum::IntValue(now_lo),
BasicValueEnum::IntValue(dt),
) = (
ctx.builder.build_load(now_hiptr, "now.hi"),
ctx.builder.build_load(now_loptr, "now.lo"),
dt,
) else {
unreachable!()
};
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "");
let shifted_hi = ctx.builder.build_left_shift(
zext_hi,
i64_type.const_int(32, false),
"",
);
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "");
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now");
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder
.build_right_shift(time, i64_type.const_int(32, false), false, "")
.unwrap(),
let time = ctx.builder.build_int_add(now_val, dt, "time");
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(
time,
i64_type.const_int(32, false),
false,
"",
),
i32_type,
"time.hi",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
);
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo");
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
@ -176,16 +174,16 @@ impl TimeFns for NowPinningTimeFns {
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "now")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now");
let BasicValueEnum::IntValue(now_raw) = now_raw else {
unreachable!()
};
let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap()
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo");
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi");
ctx.builder.build_or(now_lo, now_hi, "now_mu").into()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -193,44 +191,48 @@ impl TimeFns for NowPinningTimeFns {
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value();
let BasicValueEnum::IntValue(time) = t else {
unreachable!()
};
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, ""),
i32_type,
"time.hi",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc").unwrap();
);
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc");
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_hiptr = ctx.builder.build_bitcast(
now,
i32_type.ptr_type(AddressSpace::default()),
"now.hi.addr",
);
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}
.unwrap();
};
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
@ -238,45 +240,41 @@ impl TimeFns for NowPinningTimeFns {
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "");
let dt = dt.into_int_value();
let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) else {
unreachable!()
};
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo");
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi");
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val");
let time = ctx.builder.build_int_add(now_val, dt, "time");
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi"),
i32_type,
"now_trunc",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
);
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo");
let now_hiptr = ctx.builder.build_bitcast(
now,
i32_type.ptr_type(AddressSpace::default()),
"now.hi.addr",
);
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}
.unwrap();
};
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
}
@ -291,11 +289,7 @@ impl TimeFns for ExternTimeFns {
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
});
ctx.builder
.build_call(now_mu, &[], "now_mu")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap()
ctx.builder.build_call(now_mu, &[], "now_mu").try_as_basic_value().left().unwrap()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -306,10 +300,14 @@ impl TimeFns for ExternTimeFns {
None,
)
});
ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
ctx.builder.build_call(at_mu, &[t.into()], "at_mu");
}
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
ctx.module.add_function(
"delay_mu",
@ -317,7 +315,7 @@ impl TimeFns for ExternTimeFns {
None,
)
});
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap();
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu");
}
}

View File

@ -10,6 +10,7 @@ constant-optimization = ["fold"]
fold = []
[dependencies]
lazy_static = "1.4"
parking_lot = "0.12"
string-interner = "0.17"
string-interner = "0.14"
fxhash = "0.2"

File diff suppressed because it is too large Load Diff

View File

@ -28,12 +28,12 @@ impl From<bool> for Constant {
}
impl From<i32> for Constant {
fn from(i: i32) -> Constant {
Self::Int(i128::from(i))
Self::Int(i as i128)
}
}
impl From<i64> for Constant {
fn from(i: i64) -> Constant {
Self::Int(i128::from(i))
Self::Int(i as i128)
}
}
@ -50,7 +50,6 @@ pub enum ConversionFlag {
}
impl ConversionFlag {
#[must_use]
pub fn try_from_byte(b: u8) -> Option<Self> {
match b {
b's' => Some(Self::Str),
@ -70,7 +69,6 @@ pub struct ConstantOptimizer {
#[cfg(feature = "constant-optimization")]
impl ConstantOptimizer {
#[inline]
#[must_use]
pub fn new() -> Self {
Self { _priv: () }
}
@ -87,10 +85,14 @@ impl<U> crate::fold::Fold<U> for ConstantOptimizer {
fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> {
match node.node {
crate::ExprKind::Tuple { elts, ctx } => {
let elts =
elts.into_iter().map(|x| self.fold_expr(x)).collect::<Result<Vec<_>, _>>()?;
let expr =
if elts.iter().all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) {
let elts = elts
.into_iter()
.map(|x| self.fold_expr(x))
.collect::<Result<Vec<_>, _>>()?;
let expr = if elts
.iter()
.all(|e| matches!(e.node, crate::ExprKind::Constant { .. }))
{
let tuple = elts
.into_iter()
.map(|e| match e.node {
@ -98,11 +100,18 @@ impl<U> crate::fold::Fold<U> for ConstantOptimizer {
_ => unreachable!(),
})
.collect();
crate::ExprKind::Constant { value: Constant::Tuple(tuple), kind: None }
crate::ExprKind::Constant {
value: Constant::Tuple(tuple),
kind: None,
}
} else {
crate::ExprKind::Tuple { elts, ctx }
};
Ok(crate::Expr { node: expr, custom: node.custom, location: node.location })
Ok(crate::Expr {
node: expr,
custom: node.custom,
location: node.location,
})
}
_ => crate::fold::fold_expr(self, node),
}
@ -118,7 +127,7 @@ mod tests {
use crate::fold::Fold;
use crate::*;
let location = Location::new(0, 0, FileName::default());
let location = Location::new(0, 0, Default::default());
let custom = ();
let ast = Located {
location,
@ -129,12 +138,18 @@ mod tests {
Located {
location,
custom,
node: ExprKind::Constant { value: 1.into(), kind: None },
node: ExprKind::Constant {
value: 1.into(),
kind: None,
},
},
Located {
location,
custom,
node: ExprKind::Constant { value: 2.into(), kind: None },
node: ExprKind::Constant {
value: 2.into(),
kind: None,
},
},
Located {
location,
@ -145,17 +160,26 @@ mod tests {
Located {
location,
custom,
node: ExprKind::Constant { value: 3.into(), kind: None },
node: ExprKind::Constant {
value: 3.into(),
kind: None,
},
},
Located {
location,
custom,
node: ExprKind::Constant { value: 4.into(), kind: None },
node: ExprKind::Constant {
value: 4.into(),
kind: None,
},
},
Located {
location,
custom,
node: ExprKind::Constant { value: 5.into(), kind: None },
node: ExprKind::Constant {
value: 5.into(),
kind: None,
},
},
],
},
@ -163,7 +187,9 @@ mod tests {
],
},
};
let new_ast = ConstantOptimizer::new().fold_expr(ast).unwrap_or_else(|e| match e {});
let new_ast = ConstantOptimizer::new()
.fold_expr(ast)
.unwrap_or_else(|e| match e {});
assert_eq!(
new_ast,
Located {
@ -173,7 +199,11 @@ mod tests {
value: Constant::Tuple(vec![
1.into(),
2.into(),
Constant::Tuple(vec![3.into(), 4.into(), 5.into(),])
Constant::Tuple(vec![
3.into(),
4.into(),
5.into(),
])
]),
kind: None
},

View File

@ -64,4 +64,11 @@ macro_rules! simple_fold {
};
}
simple_fold!(usize, String, bool, StrRef, constant::Constant, constant::ConversionFlag);
simple_fold!(
usize,
String,
bool,
StrRef,
constant::Constant,
constant::ConversionFlag
);

View File

@ -2,7 +2,6 @@ use crate::{Constant, ExprKind};
impl<U> ExprKind<U> {
/// Returns a short name for the node suitable for use in error messages.
#[must_use]
pub fn name(&self) -> &'static str {
match self {
ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => {
@ -35,7 +34,10 @@ impl<U> ExprKind<U> {
ExprKind::Starred { .. } => "starred",
ExprKind::Slice { .. } => "slice",
ExprKind::JoinedStr { values } => {
if values.iter().any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) {
if values
.iter()
.any(|e| matches!(e.node, ExprKind::JoinedStr { .. }))
{
"f-string expression"
} else {
"literal"

View File

@ -1,12 +1,5 @@
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
#![warn(clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::too_many_lines,
clippy::wildcard_imports
)]
#[macro_use]
extern crate lazy_static;
mod ast_gen;
mod constant;
@ -16,6 +9,6 @@ mod impls;
mod location;
pub use ast_gen::*;
pub use location::{FileName, Location};
pub use location::{Location, FileName};
pub type Suite<U = ()> = Vec<Stmt<U>>;

View File

@ -1,6 +1,6 @@
//! Datatypes to support source location information.
use crate::ast_gen::StrRef;
use std::cmp::Ordering;
use crate::ast_gen::StrRef;
use std::fmt;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
@ -22,7 +22,7 @@ impl From<String> for FileName {
pub struct Location {
pub row: usize,
pub column: usize,
pub file: FileName,
pub file: FileName
}
impl fmt::Display for Location {
@ -35,12 +35,12 @@ impl Ord for Location {
fn cmp(&self, other: &Self) -> Ordering {
let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string());
if file_cmp != Ordering::Equal {
return file_cmp;
return file_cmp
}
let row_cmp = self.row.cmp(&other.row);
if row_cmp != Ordering::Equal {
return row_cmp;
return row_cmp
}
self.column.cmp(&other.column)
@ -76,22 +76,23 @@ impl Location {
)
}
}
Visualize { loc: *self, line, desc }
Visualize {
loc: *self,
line,
desc,
}
}
}
impl Location {
#[must_use]
pub fn new(row: usize, column: usize, file: FileName) -> Self {
Location { row, column, file }
}
#[must_use]
pub fn row(&self) -> usize {
self.row
}
#[must_use]
pub fn column(&self) -> usize {
self.column
}

View File

@ -4,26 +4,17 @@ version = "0.1.0"
authors = ["M-Labs"]
edition = "2021"
[features]
default = ["derive"]
derive = ["dep:nac3core_derive"]
no-escape-analysis = []
[dependencies]
itertools = "0.13"
itertools = "0.12"
crossbeam = "0.8"
indexmap = "2.6"
parking_lot = "0.12"
rayon = "1.10"
nac3core_derive = { path = "nac3core_derive", optional = true }
rayon = "1.5"
nac3parser = { path = "../nac3parser" }
strum = "0.26"
strum_macros = "0.26"
[dependencies.inkwell]
version = "0.5"
version = "0.2"
default-features = false
features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[dev-dependencies]
test-case = "1.2.0"

View File

@ -1,3 +1,4 @@
use regex::Regex;
use std::{
env,
fs::File,
@ -6,58 +7,35 @@ use std::{
process::{Command, Stdio},
};
use regex::Regex;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir);
let irrt_dir = Path::new("irrt");
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
const FILE: &str = "src/codegen/irrt/irrt.c";
/*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/
let mut flags: Vec<&str> = vec![
const FLAG: &[&str] = &[
"--target=wasm32",
"-x",
"c++",
"-std=c++20",
"-fno-discard-value-names",
"-fno-exceptions",
"-fno-rtti",
FILE,
"-O3",
"-emit-llvm",
"-S",
"-Wall",
"-Wextra",
"-o",
"-",
"-I",
irrt_dir.to_str().unwrap(),
irrt_cpp_path.to_str().unwrap(),
];
match env::var("PROFILE").as_deref() {
Ok("debug") => {
flags.push("-O0");
flags.push("-DIRRT_DEBUG_ASSERT");
}
Ok("release") => {
flags.push("-O3");
}
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
}
println!("cargo:rerun-if-changed={FILE}");
let out_dir = env::var("OUT_DIR").unwrap();
let out_path = Path::new(&out_dir);
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
// Compile IRRT and capture the LLVM IR output
let output = Command::new("clang-irrt")
.args(flags)
.args(FLAG)
.output()
.inspect(|o| {
.map(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
o
})
.unwrap();
@ -65,17 +43,7 @@ fn main() {
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len());
// Filter out irrelevant IR
//
// Regex:
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
// - `(?m:^@.+?=.+$)` captures global constants
let regex_filter = Regex::new(
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
)
.unwrap();
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap();
for f in regex_filter.captures_iter(&output) {
assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]);
@ -86,22 +54,18 @@ fn main() {
.unwrap()
.replace_all(&filtered_output, "");
// For debugging
// Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
if env::var(DEBUG_DUMP_IRRT).is_ok() {
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT");
if env::var("DEBUG_DUMP_IRRT").is_ok() {
let mut file = File::create(out_path.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap();
}
let mut llvm_as = Command::new("llvm-as-irrt")
.stdin(Stdio::piped())
.arg("-o")
.arg(out_dir.join("irrt.bc"))
.arg(out_path.join("irrt.bc"))
.spawn()
.unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();

View File

@ -1,5 +0,0 @@
#include "irrt/exception.hpp"
#include "irrt/list.hpp"
#include "irrt/math.hpp"
#include "irrt/ndarray.hpp"
#include "irrt/slice.hpp"

View File

@ -1,9 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
template<typename SizeT>
struct CSlice {
void* base;
SizeT len;
};

View File

@ -1,25 +0,0 @@
#pragma once
// Set in nac3core/build.rs
#ifdef IRRT_DEBUG_ASSERT
#define IRRT_DEBUG_ASSERT_BOOL true
#else
#define IRRT_DEBUG_ASSERT_BOOL false
#endif
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
#define debug_assert_eq(SizeT, lhs, rhs) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if ((lhs) != (rhs)) { \
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
} \
}
#define debug_assert(SizeT, expr) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if (!(expr)) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
} \
}

View File

@ -1,85 +0,0 @@
#pragma once
#include "irrt/cslice.hpp"
#include "irrt/int_types.hpp"
/**
* @brief The int type of ARTIQ exception IDs.
*/
using ExceptionId = int32_t;
/*
* Set of exceptions C++ IRRT can use.
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
*/
extern "C" {
ExceptionId EXN_INDEX_ERROR;
ExceptionId EXN_VALUE_ERROR;
ExceptionId EXN_ASSERTION_ERROR;
ExceptionId EXN_TYPE_ERROR;
}
/**
* @brief Extern function to `__nac3_raise`
*
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/
extern "C" void __nac3_raise(void* err);
namespace {
/**
* @brief NAC3's Exception struct
*/
template<typename SizeT>
struct Exception {
ExceptionId id;
CSlice<SizeT> filename;
int32_t line;
int32_t column;
CSlice<SizeT> function;
CSlice<SizeT> msg;
int64_t params[3];
};
constexpr int64_t NO_PARAM = 0;
template<typename SizeT>
void _raise_exception_helper(ExceptionId id,
const char* filename,
int32_t line,
const char* function,
const char* msg,
int64_t param0,
int64_t param1,
int64_t param2) {
Exception<SizeT> e = {
.id = id,
.filename = {.base = reinterpret_cast<void*>(const_cast<char*>(filename)),
.len = static_cast<SizeT>(__builtin_strlen(filename))},
.line = line,
.column = 0,
.function = {.base = reinterpret_cast<void*>(const_cast<char*>(function)),
.len = static_cast<SizeT>(__builtin_strlen(function))},
.msg = {.base = reinterpret_cast<void*>(const_cast<char*>(msg)),
.len = static_cast<SizeT>(__builtin_strlen(msg))},
};
e.params[0] = param0;
e.params[1] = param1;
e.params[2] = param2;
__nac3_raise(reinterpret_cast<void*>(&e));
__builtin_unreachable();
}
} // namespace
/**
* @brief Raise an exception with location details (location in the IRRT source files).
* @param SizeT The runtime `size_t` type.
* @param id The ID of the exception to raise.
* @param msg A global constant C-string of the error message.
*
* `param0` to `param2` are optional format arguments of `msg`. They should be set to
* `NO_PARAM` to indicate they are unused.
*/
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)

View File

@ -1,27 +0,0 @@
#pragma once
#if __STDC_VERSION__ >= 202000
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#else
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-type"
using int8_t = _ExtInt(8);
using uint8_t = unsigned _ExtInt(8);
using int32_t = _ExtInt(32);
using uint32_t = unsigned _ExtInt(32);
using int64_t = _ExtInt(64);
using uint64_t = unsigned _ExtInt(64);
#pragma clang diagnostic pop
#endif
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;

View File

@ -1,81 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/math_util.hpp"
extern "C" {
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
void* dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
void* src_arr,
SliceIndex src_arr_len,
const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0)
return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_start * size,
static_cast<uint8_t*>(src_arr) + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + (dest_start + src_len) * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca = (dest_arr == src_arr)
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
void* tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind, static_cast<uint8_t*>(src_arr) + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 4,
static_cast<uint8_t*>(src_arr) + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 8,
static_cast<uint8_t*>(src_arr) + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(src_arr) + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
} // extern "C"

View File

@ -1,93 +0,0 @@
#pragma once
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template<typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
} // namespace
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp); \
}
extern "C" {
// Putting semicolons here to make clang-format not reformat this into
// a stair shape.
DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t);
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
} // namespace

View File

@ -1,13 +0,0 @@
#pragma once
namespace {
template<typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template<typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
} // namespace

View File

@ -1,144 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
} // namespace

View File

@ -1,28 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
extern "C" {
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
} // namespace

View File

@ -1,21 +0,0 @@
[package]
name = "nac3core_derive"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[[test]]
name = "structfields_tests"
path = "tests/structfields_test.rs"
[dev-dependencies]
nac3core = { path = ".." }
trybuild = { version = "1.0", features = ["diff"] }
[dependencies]
proc-macro2 = "1.0"
proc-macro-error = "1.0"
syn = "2.0"
quote = "1.0"

View File

@ -1,320 +0,0 @@
use proc_macro::TokenStream;
use proc_macro_error::{abort, proc_macro_error};
use quote::quote;
use syn::{
parse_macro_input, spanned::Spanned, Data, DataStruct, Expr, ExprField, ExprMethodCall,
ExprPath, GenericArgument, Ident, LitStr, Path, PathArguments, Type, TypePath,
};
/// Extracts all generic arguments of a [`Type`] into a [`Vec`].
///
/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with
/// `expected_ty_name`, otherwise returns [`None`].
fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option<Vec<GenericArgument>> {
let Type::Path(TypePath { qself: None, path, .. }) = ty else {
return None;
};
let segments = &path.segments;
if segments.len() != 1 {
return None;
};
let segment = segments.iter().next().unwrap();
if segment.ident != expected_ty_name {
return None;
}
let PathArguments::AngleBracketed(path_args) = &segment.arguments else {
return Some(Vec::new());
};
let args = &path_args.args;
Some(args.iter().cloned().collect::<Vec<_>>())
}
/// Maps a `path` matching one of the `target_idents` into the `replacement` [`Ident`].
fn map_path_to_ident(path: &Path, target_idents: &[&str], replacement: &str) -> Option<Ident> {
path.require_ident()
.ok()
.filter(|ident| target_idents.iter().any(|target| ident == target))
.map(|ident| Ident::new(replacement, ident.span()))
}
/// Extracts the left-hand side of a dot-expression.
fn extract_dot_operand(expr: &Expr) -> Option<&Expr> {
match expr {
Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
| Expr::Field(ExprField { base: operand, .. }) => Some(operand),
_ => None,
}
}
/// Replaces the top-level receiver of a dot-expression with an [`Ident`], returning `Some(&mut expr)` if the
/// replacement is performed.
///
/// The top-level receiver is the left-most receiver expression, e.g. the top-level receiver of `a.b.c.foo()` is `a`.
fn replace_top_level_receiver(expr: &mut Expr, ident: Ident) -> Option<&mut Expr> {
if let Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
| Expr::Field(ExprField { base: operand, .. }) = expr
{
return if extract_dot_operand(operand).is_some() {
if replace_top_level_receiver(operand, ident).is_some() {
Some(expr)
} else {
None
}
} else {
*operand = Box::new(Expr::Path(ExprPath {
attrs: Vec::default(),
qself: None,
path: ident.into(),
}));
Some(expr)
};
}
None
}
/// Iterates all operands to the left-hand side of the `.` of an [expression][`Expr`], i.e. the container operand of all
/// [`Expr::Field`] and the receiver operand of all [`Expr::MethodCall`].
///
/// The iterator will return the operand expressions in reverse order of appearance. For example, `a.b.c.func()` will
/// return `vec![c, b, a]`.
fn iter_dot_operands(expr: &Expr) -> impl Iterator<Item = &Expr> {
let mut o = extract_dot_operand(expr);
std::iter::from_fn(move || {
let this = o;
o = o.as_ref().and_then(|o| extract_dot_operand(o));
this
})
}
/// Normalizes a value expression for use when creating an instance of this structure, returning a
/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression.
fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream {
match &expr {
Expr::Path(ExprPath { qself: None, path, .. }) => {
if let Some(ident) = map_path_to_ident(path, &["usize", "size_t"], "llvm_usize") {
quote! { #ident }
} else {
abort!(
path,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
Expr::Call(_) => {
quote! { ctx.#expr }
}
Expr::MethodCall(_) => {
let base_receiver = iter_dot_operands(expr).last();
match base_receiver {
// `usize.{...}`, `size_t.{...}` -> Rewrite the identifiers to `llvm_usize`
Some(Expr::Path(ExprPath { qself: None, path, .. }))
if map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").is_some() =>
{
let ident =
map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").unwrap();
let mut expr = expr.clone();
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
quote!(#expr)
}
// `ctx.{...}`, `context.{...}` -> Rewrite the identifiers to `ctx`
Some(Expr::Path(ExprPath { qself: None, path, .. }))
if map_path_to_ident(path, &["ctx", "context"], "ctx").is_some() =>
{
let ident = map_path_to_ident(path, &["ctx", "context"], "ctx").unwrap();
let mut expr = expr.clone();
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
quote!(#expr)
}
// No reserved identifier prefix -> Prepend `ctx.` to the entire expression
_ => quote! { ctx.#expr },
}
}
_ => {
abort!(
expr,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
}
/// Derives an implementation of `codegen::types::structure::StructFields`.
///
/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by
/// `impl StructFields` is automatically generated by this implementation, including the field index as required by
/// `StructField::new` and the fields as returned by `StructFields::to_vec`.
///
/// # Prerequisites
///
/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by
/// `StructFields`.
///
/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields
/// with either `StructField` or [`PhantomData`] types.
///
/// # Attributes for [`StructFields`]
///
/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type`
/// accepts one of the following:
///
/// - An expression returning an instance of `inkwell::types::BasicType` (with or without the receiver `ctx`/`context`).
/// For example, `context.i8_type()`, `ctx.i8_type()`, and `i8_type()` all refer to `i8`.
/// - The reserved identifiers `usize` and `size_t` referring to an `inkwell::types::IntType` of the platform-dependent
/// integer size. `usize` and `size_t` can also be used as the receiver to other method calls, e.g.
/// `usize.array_type(3)`.
///
/// # Example
///
/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`.
///
/// ```rust,ignore
/// use nac3core::{
/// codegen::types::structure::StructField,
/// inkwell::{
/// values::{IntValue, PointerValue},
/// AddressSpace,
/// },
/// };
/// use nac3core_derive::StructFields;
///
/// // All classes that implement StructFields must also implement Eq and Copy
/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
/// pub struct SliceValue<'ctx> {
/// // Declares ptr have a value type of i8*
/// //
/// // Can also be written as `ctx.i8_type().ptr_type(...)` or `context.i8_type().ptr_type(...)`
/// #[value_type(i8_type().ptr_type(AddressSpace::default()))]
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
///
/// // Declares len have a value type of usize, depending on the target compilation platform
/// #[value_type(usize)]
/// len: StructField<'ctx, IntValue<'ctx>>,
/// }
/// ```
#[proc_macro_derive(StructFields, attributes(value_type))]
#[proc_macro_error]
pub fn derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as syn::DeriveInput);
let ident = &input.ident;
let Data::Struct(DataStruct { fields, .. }) = &input.data else {
abort!(input, "Only structs with named fields are supported");
};
if let Err(err_span) =
fields
.iter()
.try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) })
{
abort!(err_span, "Only structs with named fields are supported");
};
// Check if struct<'ctx>
if input.generics.params.len() != 1 {
abort!(input.generics, "Expected exactly 1 generic parameter")
}
let phantom_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_some())
.map(|field| field.ident.as_ref().unwrap())
.cloned()
.collect::<Vec<_>>();
let field_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_none())
.map(|field| {
let ident = field.ident.as_ref().unwrap();
let ty = &field.ty;
let Some(_) = extract_generic_args("StructField", ty) else {
abort!(field, "Only StructField and PhantomData are allowed")
};
let attrs = &field.attrs;
let Some(value_type_attr) =
attrs.iter().find(|attr| attr.path().is_ident("value_type"))
else {
abort!(field, "Expected #[value_type(...)] attribute for field");
};
let Ok(value_type_expr) = value_type_attr.parse_args::<Expr>() else {
abort!(value_type_attr, "Expected expression in #[value_type(...)]");
};
let value_expr_toks = normalize_value_expr(&value_type_expr);
(ident.clone(), value_expr_toks)
})
.collect::<Vec<_>>();
// `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new`
let phantoms_create = phantom_info
.iter()
.map(|id| quote! { #id: ::std::marker::PhantomData })
.collect::<Vec<_>>();
let fields_create = field_info
.iter()
.map(|(id, ty)| {
let id_lit = LitStr::new(&id.to_string(), id.span());
quote! {
#id: ::nac3core::codegen::types::structure::StructField::create(
&mut counter,
#id_lit,
#ty,
)
}
})
.collect::<Vec<_>>();
// `.into()` impl of `StructField` for `StructFields::to_vec`
let fields_into =
field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::<Vec<_>>();
let impl_block = quote! {
impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> {
fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self {
let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default();
#ident {
#(#fields_create),*
#(#phantoms_create),*
}
}
fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> {
vec![
#(#fields_into),*
]
}
}
};
impl_block.into()
}

View File

@ -1,9 +0,0 @@
use nac3core_derive::StructFields;
use std::marker::PhantomData;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct EmptyValue<'ctx> {
_phantom: PhantomData<&'ctx ()>,
}
fn main() {}

View File

@ -1,20 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayValue<'ctx> {
#[value_type(usize)]
ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
data: StructField<'ctx, PointerValue<'ctx>>,
}
fn main() {}

View File

@ -1,18 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -1,18 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(context.i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -1,18 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(ctx.i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -1,18 +0,0 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(size_t)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -1,10 +0,0 @@
#[test]
fn test_parse_empty() {
let t = trybuild::TestCases::new();
t.pass("tests/structfields_empty.rs");
t.pass("tests/structfields_slice.rs");
t.pass("tests/structfields_slice_ctx.rs");
t.pass("tests/structfields_slice_context.rs");
t.pass("tests/structfields_slice_sizet.rs");
t.pass("tests/structfields_ndarray.rs");
}

File diff suppressed because it is too large Load Diff

View File

@ -1,20 +1,15 @@
use std::collections::HashMap;
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use crate::{
symbol_resolver::SymbolValue,
toplevel::DefinitionId,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{
into_var_map, FunSignature, FuncArg, Type, TypeEnum, TypeVar, TypeVarId, Unifier,
},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
},
};
use nac3parser::ast::StrRef;
use std::collections::HashMap;
pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>,
}
@ -27,7 +22,6 @@ pub struct ConcreteFuncArg {
pub name: StrRef,
pub ty: ConcreteType,
pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
}
#[derive(Clone, Debug)]
@ -49,12 +43,14 @@ pub enum ConcreteTypeEnum {
TPrimitive(Primitive),
TTuple {
ty: Vec<ConcreteType>,
is_vararg_ctx: bool,
},
TList {
ty: ConcreteType,
},
TObj {
obj_id: DefinitionId,
fields: HashMap<StrRef, (ConcreteType, bool)>,
params: IndexMap<TypeVarId, ConcreteType>,
params: HashMap<u32, ConcreteType>,
},
TVirtual {
ty: ConcreteType,
@ -62,10 +58,11 @@ pub enum ConcreteTypeEnum {
TFunc {
args: Vec<ConcreteFuncArg>,
ret: ConcreteType,
vars: HashMap<TypeVarId, ConcreteType>,
vars: HashMap<u32, ConcreteType>,
},
TLiteral {
values: Vec<SymbolValue>,
TConstant {
value: SymbolValue,
ty: ConcreteType,
},
}
@ -106,16 +103,8 @@ impl ConcreteTypeStore {
.iter()
.map(|arg| ConcreteFuncArg {
name: arg.name,
ty: if arg.is_vararg {
let tuple_ty = unifier
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
} else {
self.from_unifier_type(unifier, primitives, arg.ty, cache)
},
ty: self.from_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
})
.collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
@ -170,12 +159,14 @@ impl ConcreteTypeStore {
cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum {
TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple {
TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple {
ty: ty
.iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(),
is_vararg_ctx: *is_vararg_ctx,
},
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
},
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id,
@ -211,9 +202,10 @@ impl ConcreteTypeStore {
TypeEnum::TFunc(signature) => {
self.from_signature(unifier, primitives, signature, cache)
}
TypeEnum::TLiteral { values, .. } => {
ConcreteTypeEnum::TLiteral { values: values.clone() }
}
TypeEnum::TConstant { value, ty, .. } => ConcreteTypeEnum::TConstant {
value: value.clone(),
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
},
_ => unreachable!("{:?}", ty_enum.get_type_name()),
};
let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() {
@ -239,7 +231,7 @@ impl ConcreteTypeStore {
return if let Some(ty) = ty {
*ty
} else {
*ty = Some(unifier.get_dummy_var().ty);
*ty = Some(unifier.get_dummy_var().0);
ty.unwrap()
};
}
@ -261,13 +253,15 @@ impl ConcreteTypeStore {
*cache.get_mut(&cty).unwrap() = Some(ty);
return ty;
}
ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple {
ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple {
ty: ty
.iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(),
is_vararg_ctx: *is_vararg_ctx,
},
ConcreteTypeEnum::TList { ty } => {
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
@ -279,10 +273,10 @@ impl ConcreteTypeStore {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
})
.collect::<HashMap<_, _>>(),
params: into_var_map(params.iter().map(|(&id, cty)| {
let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
TypeVar { id, ty }
})),
params: params
.iter()
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
.collect::<HashMap<_, _>>(),
},
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args
@ -291,17 +285,18 @@ impl ConcreteTypeStore {
name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(),
is_vararg: false,
})
.collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache),
vars: into_var_map(vars.iter().map(|(&id, cty)| {
let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
TypeVar { id, ty }
})),
vars: vars
.iter()
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
.collect::<HashMap<_, _>>(),
}),
ConcreteTypeEnum::TLiteral { values, .. } => {
TypeEnum::TLiteral { values: values.clone(), loc: None }
ConcreteTypeEnum::TConstant { value, ty } => TypeEnum::TConstant {
value: value.clone(),
ty: self.to_unifier_type(unifier, primitives, *ty, cache),
loc: None,
}
};
let result = unifier.add_ty(result);

File diff suppressed because it is too large Load Diff

View File

@ -1,193 +0,0 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
};
use itertools::Either;
use super::CodeGenContext;
/// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue`
///
/// Arguments:
/// * `unary/binary`: Whether the extern function requires one (unary) or two (binary) operands
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
///
/// Optional Arguments:
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function.
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly".
/// These will be used unless other attributes are specified
/// * `$(,$args:ident)*`: Operands of the extern function
/// The data type of these operands will be set to `FloatValue`
///
macro_rules! generate_extern_fn {
("unary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("unary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg $(,$attributes)*);
};
("binary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("binary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2 $(,$attributes)*);
};
($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => {
#[doc = concat!("Invokes the [`", stringify!($extern_fn), "`](https://en.cppreference.com/w/c/numeric/math/", stringify!($llvm_name), ") function." )]
pub fn $fn_name<'ctx>(
ctx: &CodeGenContext<'ctx, '_>
$(,$args: FloatValue<'ctx>)*,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = $extern_fn;
let llvm_f64 = ctx.ctx.f64_type();
$(debug_assert_eq!($args.get_type(), llvm_f64);)*
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in [$($attributes),*] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
};
}
generate_extern_fn!("unary", call_tan, "tan");
generate_extern_fn!("unary", call_asin, "asin");
generate_extern_fn!("unary", call_acos, "acos");
generate_extern_fn!("unary", call_atan, "atan");
generate_extern_fn!("unary", call_sinh, "sinh");
generate_extern_fn!("unary", call_cosh, "cosh");
generate_extern_fn!("unary", call_tanh, "tanh");
generate_extern_fn!("unary", call_asinh, "asinh");
generate_extern_fn!("unary", call_acosh, "acosh");
generate_extern_fn!("unary", call_atanh, "atanh");
generate_extern_fn!("unary", call_expm1, "expm1");
generate_extern_fn!(
"unary",
call_cbrt,
"cbrt",
"mustprogress",
"nofree",
"nosync",
"nounwind",
"readonly",
"willreturn"
);
generate_extern_fn!("unary", call_erf, "erf", "nounwind");
generate_extern_fn!("unary", call_erfc, "erfc", "nounwind");
generate_extern_fn!("unary", call_j1, "j1", "nounwind");
generate_extern_fn!("binary", call_atan2, "atan2");
generate_extern_fn!("binary", call_hypot, "hypot", "nounwind");
generate_extern_fn!("binary", call_nextafter, "nextafter", "nounwind");
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
pub fn call_ldexp<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
exp: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "ldexp";
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
debug_assert_eq!(exp.get_type(), llvm_i32);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Macro to generate `np_linalg` and `sp_linalg` functions
/// The function takes as input `NDArray` and returns ()
///
/// Arguments:
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
/// * (2/3/4): Number of `NDArray` that function takes as input
///
/// Note:
/// The operands and resulting `NDArray` are both passed as input to the funcion
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
/// The function changes the content of the output `NDArray` in-place
macro_rules! generate_linalg_extern_fn {
($fn_name:ident, $extern_fn:literal, 2) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
};
($fn_name:ident, $extern_fn:literal, 3) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $extern_fn:literal, 4) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
){
const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
}
};
}
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);

View File

@ -1,18 +1,16 @@
use crate::{
codegen::{expr::*, stmt::*, bool_to_i1, bool_to_i8, CodeGenContext},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
use inkwell::{
context::Context,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
};
use nac3parser::ast::{Expr, Stmt, StrRef};
use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
pub trait CodeGenerator {
/// Return the module name for the code generator.
fn get_name(&self) -> &str;
@ -59,7 +57,6 @@ pub trait CodeGenerator {
/// - fun: Function signature, definition ID and the substitution key.
/// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method.
///
/// Note that this function should check if the function is generated in another thread (due to
/// possible race condition), see the default implementation for an example.
fn gen_func_instance<'ctx>(
@ -95,18 +92,6 @@ pub trait CodeGenerator {
gen_var(ctx, ty, name)
}
/// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function.
fn gen_array_var_alloc<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<ArraySliceValue<'ctx>, String> {
gen_array_var(ctx, ty, size, name)
}
/// Return a pointer pointing to the target of the expression.
fn gen_store_target<'ctx>(
&mut self,
@ -126,45 +111,11 @@ pub trait CodeGenerator {
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign(self, ctx, target, value, value_ty)
}
/// Generate code for an assignment expression where LHS is a `"target_list"`.
///
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
fn gen_assign_target_list<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign_target_list(self, ctx, targets, value, value_ty)
}
/// Generate code for an item assignment.
///
/// i.e., `target[key] = value`
fn gen_setitem<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_setitem(self, ctx, target, key, value, value_ty)
gen_assign(self, ctx, target, value)
}
/// Generate code for a while expression.
@ -180,8 +131,8 @@ pub trait CodeGenerator {
gen_while(self, ctx, stmt)
}
/// Generate code for a for expression.
/// Return true if the for loop must early return
/// Generate code for a while expression.
/// Return true if the while loop must early return
fn gen_for(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
@ -247,7 +198,7 @@ pub trait CodeGenerator {
fn bool_to_i1<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>,
bool_value: IntValue<'ctx>
) -> IntValue<'ctx> {
bool_to_i1(&ctx.builder, bool_value)
}
@ -256,7 +207,7 @@ pub trait CodeGenerator {
fn bool_to_i8<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>,
bool_value: IntValue<'ctx>
) -> IntValue<'ctx> {
bool_to_i8(&ctx.builder, ctx.ctx, bool_value)
}
@ -276,6 +227,7 @@ impl DefaultCodeGenerator {
}
impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str {
&self.name

View File

@ -0,0 +1,199 @@
typedef _BitInt(8) int8_t;
typedef unsigned _BitInt(8) uint8_t;
typedef _BitInt(32) int32_t;
typedef unsigned _BitInt(32) uint32_t;
typedef _BitInt(64) int64_t;
typedef unsigned _BitInt(64) uint64_t;
# define MAX(a, b) (a > b ? a : b)
# define MIN(a, b) (a > b ? b : a)
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
T base, \
T exp \
) { \
T res = (T)1; \
/* repeated squaring method */ \
do { \
if (exp & 1) res *= base; /* for n odd */ \
exp >>= 1; \
base *= base; \
} while (exp); \
return res; \
} \
DEF_INT_EXP(int32_t)
DEF_INT_EXP(int64_t)
DEF_INT_EXP(uint32_t)
DEF_INT_EXP(uint64_t)
int32_t __nac3_slice_index_bound(int32_t i, const int32_t len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
int32_t __nac3_range_slice_len(const int32_t start, const int32_t end, const int32_t step) {
int32_t diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
int32_t __nac3_list_slice_assign_var_size(
int32_t dest_start,
int32_t dest_end,
int32_t dest_step,
uint8_t *dest_arr,
int32_t dest_arr_len,
int32_t src_start,
int32_t src_end,
int32_t src_step,
uint8_t *src_arr,
int32_t src_arr_len,
const int32_t size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const int32_t src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const int32_t dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
MAX(dest_start, dest_end) < MIN(src_start, src_end)
|| MAX(src_start, src_end) < MIN(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
int32_t src_ind = src_start;
int32_t dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}

View File

@ -1,162 +0,0 @@
use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use super::calculate_len_for_slice_range;
use crate::codegen::{
macros::codegen_unreachable,
values::{ArrayLikeValue, ListValue},
CodeGenContext, CodeGenerator,
};
/// This function handles 'end' **inclusively**.
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
dest_arr: ListValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
let size_ty = generator.get_size_type(ctx.ctx);
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let slice_assign_fun = {
let ty_vec = vec![
int32.into(), // dest start idx
int32.into(), // dest end idx
int32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
int32.into(), // dest arr len
int32.into(), // src start idx
int32.into(), // src end idx
int32.into(), // src step
elem_ptr_type.into(), // src arr ptr
int32.into(), // src arr len
int32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr =
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let src_arr_ptr =
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied
let src_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let dest_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx
.builder
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
.unwrap();
let src_slt_dest = ctx
.builder
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
.unwrap();
let dest_step_eq_one = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
)
.unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
let new_len = {
let args = vec![
dest_idx.0.into(), // dest start idx
dest_idx.1.into(), // dest end idx
dest_idx.2.into(), // dest step
dest_arr_ptr.into(), // dest arr ptr
dest_len.into(), // dest arr len
src_idx.0.into(), // src start idx
src_idx.1.into(), // src end idx
src_idx.2.into(), // src step
src_arr_ptr.into(), // src arr ptr
src_len.into(), // src arr len
{
let s = match ty {
BasicTypeEnum::FloatType(t) => t.size_of(),
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
}
.into(),
];
ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
};
// update length
let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb);
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb);
}

View File

@ -1,152 +0,0 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
IntPredicate,
};
use itertools::Either;
use crate::codegen::{
macros::codegen_unreachable,
{CodeGenContext, CodeGenerator},
};
// repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>,
exp: IntValue<'ctx>,
signed: bool,
) -> IntValue<'ctx> {
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
(32, 32, true) => "__nac3_int_exp_int32_t",
(64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t",
_ => codegen_unreachable!(ctx),
};
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0
let ge_zero = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)
.unwrap();
ctx.make_assert(
generator,
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,28 +1,19 @@
use crate::typecheck::typedef::Type;
use super::{CodeGenContext, CodeGenerator};
use inkwell::{
attributes::{Attribute, AttributeLoc},
context::Context,
memory_buffer::MemoryBuffer,
module::Module,
values::{BasicValue, BasicValueEnum, IntValue},
IntPredicate,
types::BasicTypeEnum,
values::{FloatValue, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use nac3parser::ast::Expr;
use super::{CodeGenContext, CodeGenerator};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
pub use list::*;
pub use math::*;
pub use ndarray::*;
pub use slice::*;
mod list;
mod math;
mod ndarray;
mod slice;
#[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
pub fn load_irrt(ctx: &Context) -> Module {
let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
"irrt_bitcode_buffer",
@ -38,26 +29,87 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
let function = irrt_mod.get_function(symbol).unwrap();
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
}
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
irrt_mod
}
irrt_mod
// repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>,
exp: IntValue<'ctx>,
signed: bool,
) -> IntValue<'ctx> {
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
(32, 32, true) => "__nac3_int_exp_int32_t",
(64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t",
_ => unreachable!(),
};
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0
let ge_zero = ctx.builder.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
);
ctx.make_assert(
generator,
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
.try_as_basic_value()
.unwrap_left()
.into_int_value()
}
pub fn calculate_len_for_slice_range<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>,
end: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
const SYMBOL: &str = "__nac3_range_slice_len";
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
// assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
);
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.try_as_basic_value()
.left()
.unwrap()
.into_int_value()
}
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
@ -106,12 +158,13 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
step: &Option<Box<Expr<Option<Type>>>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
list: PointerValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
let length = ctx.build_gep_and_load(list, &[zero, one], Some("length")).into_int_value();
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32");
Ok(Some(match (start, end, step) {
(s, e, None) => (
if let Some(s) = s.as_ref() {
@ -131,7 +184,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
} else {
length
};
ctx.builder.build_int_sub(e, one, "final_end").unwrap()
ctx.builder.build_int_sub(e, one, "final_end")
},
one,
),
@ -139,18 +192,15 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
let step = if let Some(v) = generator.gen_expr(ctx, step)? {
v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
} else {
return Ok(None);
return Ok(None)
};
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(
let not_zero = ctx.builder.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
);
ctx.make_assert(
generator,
not_zero,
@ -159,69 +209,340 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
[None, None, None],
ctx.current_loc,
);
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
let neg = ctx
.builder
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
.unwrap();
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1");
let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg");
(
match s {
Some(s) => {
let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
return Ok(None);
return Ok(None)
};
ctx.builder
.build_select(
ctx.builder
.build_and(
ctx.builder
.build_int_compare(
ctx.builder.build_and(
ctx.builder.build_int_compare(
IntPredicate::EQ,
s,
length,
"s_eq_len",
)
.unwrap(),
),
neg,
"should_minus_one",
)
.unwrap(),
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
),
ctx.builder.build_int_sub(s, one, "s_min"),
s,
"final_start",
)
.map(BasicValueEnum::into_int_value)
.unwrap()
.into_int_value()
}
None => ctx
.builder
.build_select(neg, len_id, zero, "stt")
.map(BasicValueEnum::into_int_value)
.unwrap(),
None => ctx.builder.build_select(neg, len_id, zero, "stt").into_int_value(),
},
match e {
Some(e) => {
let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
return Ok(None);
return Ok(None)
};
ctx.builder
.build_select(
neg,
ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
ctx.builder.build_int_add(e, one, "end_add_one"),
ctx.builder.build_int_sub(e, one, "end_sub_one"),
"final_end",
)
.map(BasicValueEnum::into_int_value)
.unwrap()
.into_int_value()
}
None => ctx
.builder
.build_select(neg, zero, len_id, "end")
.map(BasicValueEnum::into_int_value)
.unwrap(),
None => ctx.builder.build_select(neg, zero, len_id, "end").into_int_value(),
},
step,
)
}
}))
}
/// this function allows index out of range, since python
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
i: &Expr<Option<Type>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else {
return Ok(None)
};
Ok(Some(ctx
.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.try_as_basic_value()
.left()
.unwrap()
.into_int_value()))
}
/// This function handles 'end' **inclusively**.
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
dest_arr: PointerValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: PointerValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
let size_ty = generator.get_size_type(ctx.ctx);
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let slice_assign_fun = {
let ty_vec = vec![
int32.into(), // dest start idx
int32.into(), // dest end idx
int32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
int32.into(), // dest arr len
int32.into(), // src start idx
int32.into(), // src end idx
int32.into(), // src step
elem_ptr_type.into(), // src arr ptr
int32.into(), // src arr len
int32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = ctx.build_gep_and_load(dest_arr, &[zero, zero], Some("dest.addr"));
let dest_arr_ptr = ctx.builder.build_pointer_cast(
dest_arr_ptr.into_pointer_value(),
elem_ptr_type,
"dest_arr_ptr_cast",
);
let dest_len = ctx.build_gep_and_load(dest_arr, &[zero, one], Some("dest.len")).into_int_value();
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32");
let src_arr_ptr = ctx.build_gep_and_load(src_arr, &[zero, zero], Some("src.addr"));
let src_arr_ptr = ctx.builder.build_pointer_cast(
src_arr_ptr.into_pointer_value(),
elem_ptr_type,
"src_arr_ptr_cast",
);
let src_len = ctx.build_gep_and_load(src_arr, &[zero, one], Some("src.len")).into_int_value();
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32");
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied
let src_end = ctx.builder
.build_select(
ctx.builder.build_int_compare(
IntPredicate::SLT,
src_idx.2,
zero,
"is_neg",
),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one"),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one"),
"final_e",
)
.into_int_value();
let dest_end = ctx.builder
.build_select(
ctx.builder.build_int_compare(
IntPredicate::SLT,
dest_idx.2,
zero,
"is_neg",
),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one"),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one"),
"final_e",
)
.into_int_value();
let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx.builder.build_int_compare(
IntPredicate::EQ,
src_slice_len,
dest_slice_len,
"slice_src_eq_dest",
);
let src_slt_dest = ctx.builder.build_int_compare(
IntPredicate::SLT,
src_slice_len,
dest_slice_len,
"slice_src_slt_dest",
);
let dest_step_eq_one = ctx.builder.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
);
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1");
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond");
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
let new_len = {
let args = vec![
dest_idx.0.into(), // dest start idx
dest_idx.1.into(), // dest end idx
dest_idx.2.into(), // dest step
dest_arr_ptr.into(), // dest arr ptr
dest_len.into(), // dest arr len
src_idx.0.into(), // src start idx
src_idx.1.into(), // src end idx
src_idx.2.into(), // src step
src_arr_ptr.into(), // src arr ptr
src_len.into(), // src arr len
{
let s = match ty {
BasicTypeEnum::FloatType(t) => t.size_of(),
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => unreachable!(),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size")
}
.into(),
];
ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
.try_as_basic_value()
.unwrap_left()
.into_int_value()
};
// update length
let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update");
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb);
ctx.builder.position_at_end(update_bb);
let dest_len_ptr = unsafe { ctx.builder.build_gep(dest_arr, &[zero, one], "dest_len_ptr") };
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len");
ctx.builder.build_store(dest_len_ptr, new_len);
ctx.builder.build_unconditional_branch(cont_bb);
ctx.builder.position_at_end(cont_bb);
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.try_as_basic_value()
.unwrap_left()
.into_int_value();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.try_as_basic_value()
.unwrap_left()
.into_int_value();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma")
.try_as_basic_value()
.unwrap_left()
.into_float_value()
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln")
.try_as_basic_value()
.unwrap_left()
.into_float_value()
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
.try_as_basic_value()
.unwrap_left()
.into_float_value()
}

View File

@ -1,384 +0,0 @@
use inkwell::{
types::IntType,
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use crate::codegen::{
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
values::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, NDArrayValue, TypedArrayLikeAccessor,
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
CodeGenContext, CodeGenerator,
};
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension
/// respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.shape().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}

View File

@ -1,76 +0,0 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue, IntValue},
IntPredicate,
};
use itertools::Either;
use nac3parser::ast::Expr;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
};
/// this function allows index out of range, since python
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
i: &Expr<Option<Type>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else {
return Ok(None);
};
Ok(Some(
ctx.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap(),
))
}
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>,
end: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
const SYMBOL: &str = "__nac3_range_slice_len";
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,345 +0,0 @@
use inkwell::{
context::Context,
intrinsics::Intrinsic,
types::{AnyTypeEnum::IntType, FloatType},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
AddressSpace,
};
use itertools::Either;
use super::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
// Standard LLVM floating-point types
if ft == ctx.f16_type() {
return "f16";
}
if ft == ctx.f32_type() {
return "f32";
}
if ft == ctx.f64_type() {
return "f64";
}
if ft == ctx.f128_type() {
return "f128";
}
// Non-standard floating-point types
if ft == ctx.x86_f80_type() {
return "f80";
}
if ft == ctx.ppc_f128_type() {
return "ppcf128";
}
unreachable!()
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_start";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_end";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic.
pub fn call_stacksave<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> PointerValue<'ctx> {
const FN_NAME: &str = "llvm.stacksave";
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_pointer_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the
/// [`llvm.stackrestore`](https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic) intrinsic.
///
/// - `ptr`: The pointer storing the address to restore the stack to.
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.stackrestore";
/*
SEE https://github.com/TheDan64/inkwell/issues/496
We want `llvm.stackrestore`, but the following would generate `llvm.stackrestore.p0i8`.
```ignore
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
.unwrap();
```
Temp workaround by manually declaring the intrinsic with the correct function name instead.
*/
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap();
}
/// Invokes the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic) intrinsic.
///
/// * `dest` - The pointer to the destination. Must be a pointer to an integer type.
/// * `src` - The pointer to the source. Must be a pointer to an integer type.
/// * `len` - The number of bytes to copy.
/// * `is_volatile` - Whether the `memcpy` operation should be `volatile`.
pub fn call_memcpy<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
const FN_NAME: &str = "llvm.memcpy";
debug_assert!(dest.get_type().get_element_type().is_int_type());
debug_assert!(src.get_type().get_element_type().is_int_type());
debug_assert_eq!(
dest.get_type().get_element_type().into_int_type().get_bit_width(),
src.get_type().get_element_type().into_int_type().get_bit_width(),
);
debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64));
debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1);
let llvm_dest_t = dest.get_type();
let llvm_src_t = src.get_type();
let llvm_len_t = len.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(
&ctx.module,
&[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()],
)
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[dest.into(), src.into(), len.into(), is_volatile.into()], "")
.unwrap();
}
/// Invokes the `llvm.memcpy` intrinsic.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
pub fn call_memcpy_generic<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let dest_elem_t = dest.get_type().get_element_type();
let src_elem_t = src.get_type().get_element_type();
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
///
/// Arguments:
/// * `$ctx:ident`: Reference to the current Code Generation Context
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type).
/// Use `BasicValueEnum::into_int_value` for Integer return type and
/// `BasicValueEnum::into_float_value` for Float return type
/// * `$llvm_ty:ident`: Type of first operand
/// * `,($val:ident)*`: Comma separated list of operands
macro_rules! generate_llvm_intrinsic_fn_body {
($ctx:ident, $name:ident, $llvm_name:literal, $map_fn:expr, $llvm_ty:ident $(,$val:ident)*) => {{
const FN_NAME: &str = concat!("llvm.", $llvm_name);
let intrinsic_fn = Intrinsic::find(FN_NAME).and_then(|intrinsic| intrinsic.get_declaration(&$ctx.module, &[$llvm_ty.into()])).unwrap();
$ctx.builder.build_call(intrinsic_fn, &[$($val.into()),*], $name.unwrap_or_default()).map(CallSiteValue::try_as_basic_value).map(|v| v.map_left($map_fn)).map(Either::unwrap_left).unwrap()
}};
}
/// Macro to generate the llvm intrinsic function using [`generate_llvm_intrinsic_fn_body`].
///
/// Arguments:
/// * `float/int`: Indicates the return and argument type of the function
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function.
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
/// * `$val:ident`: The operand for unary operations
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations
macro_rules! generate_llvm_intrinsic_fn {
("float", $fn_name:ident, $llvm_name:literal, $val:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_ty = $val.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val)
}
};
("float", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: FloatValue<'ctx>,
$val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!($val1.get_type(), $val2.get_type());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val1, $val2)
}
};
("int", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: IntValue<'ctx>,
$val2: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!($val1.get_type().get_bit_width(), $val2.get_type().get_bit_width());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_int_value, llvm_ty, $val1, $val2)
}
};
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
pub fn call_int_abs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
let src_type = src.get_type();
generate_llvm_intrinsic_fn_body!(
ctx,
name,
"abs",
BasicValueEnum::into_int_value,
src_type,
src,
is_int_min_poison
)
}
generate_llvm_intrinsic_fn!("int", call_int_smax, "smax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_smin, "smin", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umax, "umax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umin, "umin", a, b);
generate_llvm_intrinsic_fn!("int", call_expect, "expect", val, expected_val);
generate_llvm_intrinsic_fn!("float", call_float_sqrt, "sqrt", val);
generate_llvm_intrinsic_fn!("float", call_float_sin, "sin", val);
generate_llvm_intrinsic_fn!("float", call_float_cos, "cos", val);
generate_llvm_intrinsic_fn!("float", call_float_pow, "pow", val, power);
generate_llvm_intrinsic_fn!("float", call_float_exp, "exp", val);
generate_llvm_intrinsic_fn!("float", call_float_exp2, "exp2", val);
generate_llvm_intrinsic_fn!("float", call_float_log, "log", val);
generate_llvm_intrinsic_fn!("float", call_float_log10, "log10", val);
generate_llvm_intrinsic_fn!("float", call_float_log2, "log2", val);
generate_llvm_intrinsic_fn!("float", call_float_fabs, "fabs", src);
generate_llvm_intrinsic_fn!("float", call_float_minnum, "minnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_maxnum, "maxnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_copysign, "copysign", mag, sgn);
generate_llvm_intrinsic_fn!("float", call_float_floor, "floor", val);
generate_llvm_intrinsic_fn!("float", call_float_ceil, "ceil", val);
generate_llvm_intrinsic_fn!("float", call_float_round, "round", val);
generate_llvm_intrinsic_fn!("float", call_float_rint, "rint", val);
/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic.
pub fn call_float_powi<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.powi";
let llvm_val_t = val.get_type();
let llvm_power_t = power.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(&ctx.module, &[llvm_val_t.into(), llvm_power_t.into()])
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,75 +1,50 @@
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
thread,
};
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
builder::Builder,
context::Context,
debug_info::{
AsDIScope, DICompileUnit, DIFlagsConstants, DIScope, DISubprogram, DebugInfoBuilder,
},
module::Module,
passes::PassBuilderOptions,
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel,
};
use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use nac3parser::ast::{Location, Stmt, StrRef};
use crate::{
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
toplevel::{TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
},
};
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
use types::{ListType, NDArrayType, ProxyType, RangeType};
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{
AddressSpace,
IntPredicate,
OptimizationLevel,
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
builder::Builder,
context::Context,
module::Module,
passes::PassBuilderOptions,
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
debug_info::{
DebugInfoBuilder, DICompileUnit, DISubprogram, AsDIScope, DIFlagsConstants, DIScope
},
};
use itertools::Itertools;
use nac3parser::ast::{Stmt, StrRef, Location};
use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::thread;
pub mod builtin_fns;
pub mod concrete_type;
pub mod expr;
pub mod extern_fns;
mod generator;
pub mod irrt;
pub mod llvm_intrinsics;
pub mod numpy;
pub mod stmt;
pub mod types;
pub mod values;
#[cfg(test)]
mod test;
mod macros {
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
/// its first argument to provide Python source information to indicate the codegen location
/// causing the assertion.
macro_rules! codegen_unreachable {
($ctx:expr $(,)?) => {
std::unreachable!("unreachable code while processing {}", &$ctx.current_loc)
};
($ctx:expr, $($arg:tt)*) => {
std::unreachable!("unreachable code while processing {}: {}", &$ctx.current_loc, std::format!("{}", std::format_args!($($arg)+)))
};
}
pub(crate) use codegen_unreachable;
}
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
#[derive(Default)]
pub struct StaticValueStore {
@ -89,16 +64,6 @@ pub struct CodeGenLLVMOptions {
pub target: CodeGenTargetMachineOptions,
}
impl CodeGenLLVMOptions {
/// Creates a [`TargetMachine`] using the target options specified by this struct.
///
/// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine(&self) -> Option<TargetMachine> {
self.target.create_target_machine(self.opt_level)
}
}
/// Additional options for code generation for the target machine.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CodeGenTargetMachineOptions {
@ -115,6 +80,7 @@ pub struct CodeGenTargetMachineOptions {
}
impl CodeGenTargetMachineOptions {
/// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine.
/// Other options are set to defaults.
#[must_use]
@ -143,11 +109,13 @@ impl CodeGenTargetMachineOptions {
///
/// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine(&self, level: OptimizationLevel) -> Option<TargetMachine> {
pub fn create_target_machine(
&self,
level: OptimizationLevel,
) -> Option<TargetMachine> {
let triple = TargetTriple::create(self.triple.as_str());
let target = Target::from_triple(&triple).unwrap_or_else(|_| {
panic!("could not create target from target triple {}", self.triple)
});
let target = Target::from_triple(&triple)
.unwrap_or_else(|_| panic!("could not create target from target triple {}", self.triple));
target.create_target_machine(
&triple,
@ -155,7 +123,7 @@ impl CodeGenTargetMachineOptions {
self.features.as_str(),
level,
self.reloc_mode,
self.code_model,
self.code_model
)
}
}
@ -166,23 +134,24 @@ pub struct CodeGenContext<'ctx, 'a> {
/// The [Builder] instance for creating LLVM IR statements.
pub builder: Builder<'ctx>,
/// The [`DebugInfoBuilder`], [compilation unit information][DICompileUnit], and
/// The [DebugInfoBuilder], [compilation unit information][DICompileUnit], and
/// [scope information][DIScope] of this context.
pub debug_info: (DebugInfoBuilder<'ctx>, DICompileUnit<'ctx>, DIScope<'ctx>),
/// The module for which [this context][CodeGenContext] is generating into.
pub module: Module<'ctx>,
/// The [`TopLevelContext`] associated with [this context][CodeGenContext].
/// The [TopLevelContext] associated with [this context][CodeGenContext].
pub top_level: &'a TopLevelContext,
pub unifier: Unifier,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub static_value_store: Arc<Mutex<StaticValueStore>>,
/// A [`HashMap`] containing the mapping between the names of variables currently in-scope and
/// A [HashMap] containing the mapping between the names of variables currently in-scope and
/// its value information.
pub var_assignment: HashMap<StrRef, VarValue<'ctx>>,
///
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore,
pub calls: Arc<HashMap<CodeLocation, CallId>>,
@ -191,24 +160,24 @@ pub struct CodeGenContext<'ctx, 'a> {
/// Cache for constant strings.
pub const_strings: HashMap<String, BasicValueEnum<'ctx>>,
/// [`BasicBlock`] containing all `alloca` statements for the current function.
/// [BasicBlock] containing all `alloca` statements for the current function.
pub init_bb: BasicBlock<'ctx>,
pub exception_val: Option<PointerValue<'ctx>>,
/// The header and exit basic blocks of a loop in this context. See
/// <https://llvm.org/docs/LoopTerminology.html> for explanation of these terminology.
/// https://llvm.org/docs/LoopTerminology.html for explanation of these terminology.
pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
/// The target [`BasicBlock`] to jump to when performing stack unwind.
/// The target [BasicBlock] to jump to when performing stack unwind.
pub unwind_target: Option<BasicBlock<'ctx>>,
/// The target [`BasicBlock`] to jump to before returning from the function.
/// The target [BasicBlock] to jump to before returning from the function.
///
/// If this field is [None] when generating a return from a function, `ret` with no argument can
/// be emitted.
pub return_target: Option<BasicBlock<'ctx>>,
/// The [`PointerValue`] containing the return value of the function.
/// The [PointerValue] containing the return value of the function.
pub return_buffer: Option<PointerValue<'ctx>>,
// outer catch clauses
@ -217,7 +186,7 @@ pub struct CodeGenContext<'ctx, 'a> {
/// Whether `sret` is needed for the first parameter of the function.
///
/// See [`need_sret`].
/// See [need_sret].
pub need_sret: bool,
/// The current source location.
@ -225,6 +194,7 @@ pub struct CodeGenContext<'ctx, 'a> {
}
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
/// contains a [terminator statement][BasicBlock::get_terminator].
pub fn is_terminated(&self) -> bool {
@ -266,10 +236,11 @@ pub struct WorkerRegistry {
static_value_store: Arc<Mutex<StaticValueStore>>,
/// LLVM-related options for code generation.
pub llvm_options: CodeGenLLVMOptions,
llvm_options: CodeGenLLVMOptions,
}
impl WorkerRegistry {
/// Creates workers for this registry.
#[must_use]
pub fn create_workers<G: CodeGenerator + Send + 'static>(
@ -304,15 +275,9 @@ impl WorkerRegistry {
let registry = registry.clone();
let registry2 = registry.clone();
let f = f.clone();
let worker_thread_name =
format!("codegen-worker-{worker_id}", worker_id = generator.get_name());
let handle = thread::Builder::new()
.name(worker_thread_name)
.spawn(move || {
let handle = thread::spawn(move || {
registry.worker_thread(generator.as_mut(), &f);
})
.unwrap();
});
let handle = thread::spawn(move || {
if let Err(e) = handle.join() {
if let Some(e) = e.downcast_ref::<&'static str>() {
@ -369,10 +334,6 @@ impl WorkerRegistry {
let mut builder = context.create_builder();
let mut module = context.create_module(generator.get_name());
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag(
"Debug Info Version",
inkwell::module::FlagBehavior::Warning,
@ -396,20 +357,12 @@ impl WorkerRegistry {
errors.insert(e);
// create a new empty module just to continue codegen and collect errors
module = context.create_module(&format!("{}_recover", generator.get_name()));
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
}
}
*self.task_count.lock() -= 1;
self.wait_condvar.notify_all();
}
assert!(
errors.is_empty(),
"Codegen error: {}",
errors.into_iter().sorted().join("\n----------\n")
);
assert!(errors.is_empty(), "Codegen error: {}", errors.into_iter().sorted().join("\n----------\n"));
let result = module.verify();
if let Err(err) = result {
@ -422,20 +375,13 @@ impl WorkerRegistry {
.llvm_options
.target
.create_target_machine(self.llvm_options.opt_level)
.unwrap_or_else(|| {
panic!(
"could not create target machine from properties {:?}",
self.llvm_options.target
)
});
.unwrap_or_else(|| panic!("could not create target machine from properties {:?}", self.llvm_options.target));
let passes = format!("default<O{}>", self.llvm_options.opt_level as u32);
let result = module.run_passes(passes.as_str(), &target_machine, pass_options);
if let Err(err) = result {
panic!(
"Failed to run optimization for module `{}`: {}",
panic!("Failed to run optimization for module `{}`: {}",
module.get_name().to_str().unwrap(),
err.to_string()
);
err.to_string());
}
f.run(&module);
@ -461,14 +407,14 @@ pub struct CodeGenTask {
///
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
/// would be represented by an `i8`.
#[allow(clippy::too_many_arguments)]
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
fn get_llvm_type<'ctx>(
ctx: &'ctx Context,
module: &Module<'ctx>,
generator: &G,
generator: &mut dyn CodeGenerator,
unifier: &mut Unifier,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
primitives: &PrimitiveStore,
ty: Type,
) -> BasicTypeEnum<'ctx> {
use TypeEnum::*;
@ -478,51 +424,29 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum {
TObj { obj_id, fields, .. } => {
// check to avoid treating non-class primitives as classes
if PrimDef::contains_id(*obj_id) {
return match &*unifier.get_ty_immutable(ty) {
TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => {
get_llvm_type(
// check to avoid treating primitives other than Option as classes
if obj_id.0 <= 10 {
match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref())
{
(
TObj { obj_id, params, .. },
TObj { obj_id: opt_id, .. },
) if *obj_id == *opt_id => {
return get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
primitives,
*params.iter().next().unwrap().1,
)
.ptr_type(AddressSpace::default())
.into()
.into();
}
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let element_type = get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
*params.iter().next().unwrap().1,
);
ListType::new(generator, ctx, element_type).as_base_type().into()
_ => unreachable!("must be option type"),
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
}
_ => unreachable!(
"LLVM type for primitive {} is missing",
unifier.stringify(ty)
),
};
}
// a struct with fields in the order of declaration
let top_level_defs = top_level.definitions.read();
@ -538,7 +462,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let struct_type = ctx.opaque_struct_type(&name);
type_cache.insert(
unifier.get_representative(ty),
struct_type.ptr_type(AddressSpace::default()).into(),
struct_type.ptr_type(AddressSpace::default()).into()
);
let fields = fields_list
.iter()
@ -550,6 +474,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
unifier,
top_level,
type_cache,
primitives,
fields[&f.0].0,
)
})
@ -557,20 +482,31 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into()
};
return ty;
return ty
}
TTuple { ty, is_vararg_ctx } => {
TTuple { ty } => {
// a struct with fields in the order present in the tuple
assert!(!is_vararg_ctx, "Tuples in vararg context must be instantiated with the correct number of arguments before calling get_llvm_type");
let fields = ty
.iter()
.map(|ty| {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
)
})
.collect_vec();
ctx.struct_type(&fields, false).into()
}
TList { ty } => {
// a struct with an integer and a pointer to an array
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
);
let fields = [
element_type.ptr_type(AddressSpace::default()).into(),
generator.get_size_type(ctx).into(),
];
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
}
TVirtual { .. } => unimplemented!(),
_ => unreachable!("{}", ty_enum.get_type_name()),
};
@ -588,11 +524,10 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
/// ABI representation is that the in-memory representation must be at least byte-sized and must
/// be byte-aligned for the variable to be addressable in memory, whereas there is no such
/// restriction for ABI representations.
#[allow(clippy::too_many_arguments)]
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
fn get_llvm_abi_type<'ctx>(
ctx: &'ctx Context,
module: &Module<'ctx>,
generator: &G,
generator: &mut dyn CodeGenerator,
unifier: &mut Unifier,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -601,10 +536,10 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
) -> BasicTypeEnum<'ctx> {
// If the type is used in the definition of a function, return `i1` instead of `i8` for ABI
// consistency.
if unifier.unioned(ty, primitives.bool) {
return if unifier.unioned(ty, primitives.bool) {
ctx.bool_type().into()
} else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, primitives, ty)
}
}
@ -621,62 +556,23 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
match ty {
BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false,
BasicTypeEnum::FloatType(_) if maybe_large => false,
BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => {
ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false))
}
BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 =>
ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false)),
_ => true,
}
}
need_sret_impl(ty, true)
}
/// Returns the [`BasicTypeEnum`] representing a `va_list` struct for variadic arguments.
fn get_llvm_valist_type<'ctx>(ctx: &'ctx Context, triple: &TargetTriple) -> BasicTypeEnum<'ctx> {
let triple = TargetMachine::normalize_triple(triple);
let triple = triple.as_str().to_str().unwrap();
let arch = triple.split('-').next().unwrap();
let llvm_pi8 = ctx.i8_type().ptr_type(AddressSpace::default());
// Referenced from parseArch() in llvm/lib/Support/Triple.cpp
match arch {
"i386" | "i486" | "i586" | "i686" | "riscv32" => {
ctx.i8_type().ptr_type(AddressSpace::default()).into()
}
"amd64" | "x86_64" | "x86_64h" => {
let llvm_i32 = ctx.i32_type();
let va_list_tag = ctx.opaque_struct_type("struct.__va_list_tag");
va_list_tag.set_body(
&[llvm_i32.into(), llvm_i32.into(), llvm_pi8.into(), llvm_pi8.into()],
false,
);
va_list_tag.into()
}
"armv7" => {
let va_list = ctx.opaque_struct_type("struct.__va_list");
va_list.set_body(&[llvm_pi8.into()], false);
va_list.into()
}
triple => {
todo!("Unsupported platform for varargs: {triple}")
}
}
}
/// Implementation for generating LLVM IR for a function.
pub fn gen_func_impl<
'ctx,
G: CodeGenerator,
F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>,
>(
pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>> (
context: &'ctx Context,
generator: &mut G,
registry: &WorkerRegistry,
builder: Builder<'ctx>,
module: Module<'ctx>,
task: CodeGenTask,
codegen_function: F,
codegen_function: F
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
let top_level_ctx = registry.top_level_ctx.clone();
let static_value_store = registry.static_value_store.clone();
@ -684,7 +580,6 @@ pub fn gen_func_impl<
let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index];
(Unifier::from_shared_unifier(unifier), *primitives)
};
unifier.put_primitive_store(&primitives);
unifier.top_level = Some(top_level_ctx.clone());
let mut cache = HashMap::new();
@ -718,7 +613,6 @@ pub fn gen_func_impl<
str: unifier.get_representative(primitives.str),
exception: unifier.get_representative(primitives.exception),
option: unifier.get_representative(primitives.option),
..primitives
};
let mut type_cache: HashMap<_, _> = [
@ -740,10 +634,10 @@ pub fn gen_func_impl<
str_type.set_body(&fields, false);
str_type.into()
}
Some(t) => t.as_basic_type_enum(),
Some(t) => t.as_basic_type_enum()
}
}),
(primitives.range, RangeType::new(context).as_base_type().into()),
(primitives.range, context.i32_type().array_type(3).ptr_type(AddressSpace::default()).into()),
(primitives.exception, {
let name = "Exception";
if let Some(t) = module.get_struct_type(name) {
@ -757,7 +651,7 @@ pub fn gen_func_impl<
exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
}
}),
})
]
.iter()
.copied()
@ -765,7 +659,8 @@ pub fn gen_func_impl<
// NOTE: special handling of option cannot use this type cache since it contains type var,
// handled inside get_llvm_type instead
let ConcreteTypeEnum::TFunc { args, ret, .. } = task.store.get(task.signature) else {
let ConcreteTypeEnum::TFunc { args, ret, .. } =
task.store.get(task.signature) else {
unreachable!()
};
@ -775,7 +670,6 @@ pub fn gen_func_impl<
name: arg.name,
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
})
.collect_vec(),
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
@ -783,25 +677,13 @@ pub fn gen_func_impl<
let ret_type = if unifier.unioned(ret, primitives.none) {
None
} else {
Some(get_llvm_abi_type(
context,
&module,
generator,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
&primitives,
ret,
))
Some(get_llvm_abi_type(context, &module, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, &primitives, ret))
};
let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
let mut params = args
.iter()
.filter(|arg| !arg.is_vararg)
.map(|arg| {
debug_assert!(!arg.is_vararg);
get_llvm_abi_type(
context,
&module,
@ -820,12 +702,9 @@ pub fn gen_func_impl<
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
}
debug_assert!(matches!(args.iter().filter(|arg| arg.is_vararg).count(), 0..=1));
let vararg_arg = args.iter().find(|arg| arg.is_vararg);
let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, vararg_arg.is_some()),
_ => context.void_type().fn_type(&params, vararg_arg.is_some()),
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => context.void_type().fn_type(&params, false)
};
let symbol = &task.symbol_name;
@ -840,23 +719,18 @@ pub fn gen_func_impl<
fn_val.set_personality_function(personality);
}
if has_sret {
fn_val.add_attribute(
AttributeLoc::Param(0),
context.create_type_attribute(
Attribute::get_named_enum_kind_id("sret"),
ret_type.unwrap().as_any_type_enum(),
),
);
fn_val.add_attribute(AttributeLoc::Param(0),
context.create_type_attribute(Attribute::get_named_enum_kind_id("sret"),
ret_type.unwrap().as_any_type_enum()));
}
let init_bb = context.append_basic_block(fn_val, "init");
builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body");
// Store non-vararg argument values into local variables
let mut var_assignment = HashMap::new();
let offset = u32::from(has_sret);
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
for (n, arg) in args.iter().enumerate() {
let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let local_type = get_llvm_type(
context,
@ -865,10 +739,13 @@ pub fn gen_func_impl<
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
&primitives,
arg.ty,
);
let alloca =
builder.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string())).unwrap();
let alloca = builder.build_alloca(
local_type,
&format!("{}.addr", &arg.name.to_string()),
);
// Remap boolean parameters into i8
let param = if local_type.is_int_type() && param.is_int_value() {
@ -879,22 +756,19 @@ pub fn gen_func_impl<
bool_to_i8(&builder, context, param_val)
} else {
param_val
}
.into()
}.into()
} else {
param
};
builder.build_store(alloca, param).unwrap();
builder.build_store(alloca, param);
var_assignment.insert(arg.name, (alloca, None, 0));
}
// TODO: Save vararg parameters as list
let return_buffer = if has_sret {
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
} else {
fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret").unwrap())
fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret"))
};
let static_values = {
@ -906,7 +780,7 @@ pub fn gen_func_impl<
*static_val = Some(v);
}
builder.build_unconditional_branch(body_bb).unwrap();
builder.build_unconditional_branch(body_bb);
builder.position_at_end(body_bb);
let (dibuilder, compile_unit) = module.create_debug_info_builder(
@ -915,8 +789,11 @@ pub fn gen_func_impl<
/* filename */
&task
.body
.first()
.map_or_else(|| "<nac3_internal>".to_string(), |f| f.location.file.0.to_string()),
.get(0)
.map_or_else(
|| "<nac3_internal>".to_string(),
|f| f.location.file.0.to_string(),
),
/* directory */ "",
/* producer */ "NAC3",
/* is_optimized */ registry.llvm_options.opt_level != OptimizationLevel::None,
@ -942,7 +819,7 @@ pub fn gen_func_impl<
inkwell::debug_info::DIFlags::PUBLIC,
);
let (row, col) =
task.body.first().map_or_else(|| (0, 0), |b| (b.location.row, b.location.column));
task.body.get(0).map_or_else(|| (0, 0), |b| (b.location.row, b.location.column));
let func_scope: DISubprogram<'_> = dibuilder.create_function(
/* scope */ compile_unit.as_debug_info_scope(),
/* func name */ symbol,
@ -989,7 +866,7 @@ pub fn gen_func_impl<
row as u32,
col as u32,
func_scope.as_debug_info_scope(),
None,
None
);
code_gen_context.builder.set_current_debug_location(loc);
@ -997,7 +874,7 @@ pub fn gen_func_impl<
// after static analysis, only void functions can have no return at the end.
if !code_gen_context.is_terminated() {
code_gen_context.builder.build_return(None).unwrap();
code_gen_context.builder.build_return(None);
}
code_gen_context.builder.unset_current_debug_location();
@ -1039,14 +916,12 @@ fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntV
if bool_value.get_type().get_bit_width() == 1 {
bool_value
} else {
builder
.build_int_compare(
builder.build_int_compare(
IntPredicate::NE,
bool_value,
bool_value.get_type().const_zero(),
"tobool",
"tobool"
)
.unwrap()
}
}
@ -1054,23 +929,21 @@ fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntV
fn bool_to_i8<'ctx>(
builder: &Builder<'ctx>,
ctx: &'ctx Context,
bool_value: IntValue<'ctx>,
bool_value: IntValue<'ctx>
) -> IntValue<'ctx> {
let value_bits = bool_value.get_type().get_bit_width();
match value_bits {
8 => bool_value,
1 => builder.build_int_z_extend(bool_value, ctx.i8_type(), "frombool").unwrap(),
1 => builder.build_int_z_extend(bool_value, ctx.i8_type(), "frombool"),
_ => bool_to_i8(
builder,
ctx,
builder
.build_int_compare(
builder.build_int_compare(
IntPredicate::NE,
bool_value,
bool_value.get_type().const_zero(),
"",
""
)
.unwrap(),
),
}
}
@ -1096,26 +969,9 @@ fn gen_in_range_check<'ctx>(
stop: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
let sign = ctx
.builder
.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "")
.unwrap();
let lo = ctx
.builder
.build_select(sign, value, stop, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let hi = ctx
.builder
.build_select(sign, stop, value, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "");
let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value();
let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value();
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
}
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
/// passed to the variadic function.
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
format!("__{}_va_count", &arg_name).into()
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,37 +1,29 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
use crate::{
codegen::{
concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
},
};
use indexmap::IndexMap;
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel,
OptimizationLevel
};
use nac3parser::{
ast::{fold::Fold, FileName, StrRef},
ast::{fold::Fold, StrRef},
parser::parse_program,
};
use parking_lot::RwLock;
use super::{
concrete_type::ConcreteTypeStore,
types::{ListType, NDArrayType, ProxyType, RangeType},
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator,
DefaultCodeGenerator, WithCall, WorkerRegistry,
};
use crate::{
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
struct Resolver {
id_to_type: HashMap<StrRef, Type>,
@ -60,14 +52,13 @@ impl SymbolResolver for Resolver {
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.id_to_type.get(&str).copied().ok_or_else(|| format!("cannot find symbol `{str}`"))
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str))
}
fn get_symbol_value<'ctx>(
fn get_symbol_value<'ctx, 'a>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
_: &mut CodeGenContext<'ctx, 'a>,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
@ -76,8 +67,10 @@ impl SymbolResolver for Resolver {
self.id_to_def
.read()
.get(&id)
.copied()
.ok_or_else(|| HashSet::from([format!("cannot find symbol `{id}`")]))
.cloned()
.ok_or_else(|| HashSet::from([
format!("cannot find symbol `{}`", id),
]))
}
fn get_string_id(&self, _: &str) -> i32 {
@ -96,9 +89,9 @@ fn test_primitives() {
d = a if c == 1 else 0
return d
"};
let statements = parse_program(source, FileName::default()).unwrap();
let statements = parse_program(source, Default::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let composer: TopLevelComposer = Default::default();
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
@ -107,27 +100,17 @@ fn test_primitives() {
let resolver = Arc::new(Resolver {
id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()),
class_names: HashMap::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature {
args: vec![
FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "b".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg { name: "a".into(), ty: primitives.int32, default_value: None },
FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
],
ret: primitives.int32,
vars: VarMap::new(),
vars: HashMap::new(),
};
let mut store = ConcreteTypeStore::new();
@ -142,13 +125,12 @@ fn test_primitives() {
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> =
["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut identifiers: HashSet<_> = ["a".into(), "b".into()].iter().cloned().collect();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping: HashMap::default(),
variable_mapping: Default::default(),
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
@ -172,7 +154,7 @@ fn test_primitives() {
});
let task = CodeGenTask {
subst: Vec::default(),
subst: Default::default(),
symbol_name: "testing".into(),
body: Arc::new(statements),
unifier_index: 0,
@ -204,8 +186,6 @@ fn test_primitives() {
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
@ -245,7 +225,12 @@ fn test_primitives() {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
let (registry, handles) = WorkerRegistry::create_workers(
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}
@ -256,28 +241,23 @@ fn test_simple_call() {
a = foo(a)
return a * 2
"};
let statements_1 = parse_program(source_1, FileName::default()).unwrap();
let statements_1 = parse_program(source_1, Default::default()).unwrap();
let source_2 = indoc! { "
return a + 1
"};
let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let statements_2 = parse_program(source_2, Default::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let composer: TopLevelComposer = Default::default();
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone());
let signature = FunSignature {
args: vec![FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
}],
args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
ret: primitives.int32,
vars: VarMap::new(),
vars: HashMap::new(),
};
let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
let mut store = ConcreteTypeStore::new();
@ -301,7 +281,7 @@ fn test_simple_call() {
let resolver = Resolver {
id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()),
class_names: HashMap::default(),
class_names: Default::default(),
};
resolver.add_id_def("foo".into(), DefinitionId(foo_id));
let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>;
@ -322,13 +302,12 @@ fn test_simple_call() {
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> =
["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].iter().cloned().collect();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
unifier: &mut unifier,
variable_mapping: HashMap::default(),
variable_mapping: Default::default(),
primitives: &primitives,
virtual_checks: &mut virtual_checks,
calls: &mut calls,
@ -357,11 +336,11 @@ fn test_simple_call() {
&mut *top_level.definitions.read()[foo_id].write()
{
instance_to_stmt.insert(
String::new(),
"".to_string(),
FunInstance {
body: Arc::new(statements_2),
calls: Arc::new(inferencer.calls.clone()),
subst: IndexMap::default(),
subst: Default::default(),
unifier_id: 0,
},
);
@ -377,7 +356,7 @@ fn test_simple_call() {
});
let task = CodeGenTask {
subst: Vec::default(),
subst: Default::default(),
symbol_name: "testing".to_string(),
body: Arc::new(statements_1),
calls: Arc::new(calls1),
@ -391,8 +370,6 @@ fn test_simple_call() {
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {
@ -438,39 +415,12 @@ fn test_simple_call() {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
let (registry, handles) = WorkerRegistry::create_workers(
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}
#[test]
fn test_classes_list_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok());
}
#[test]
fn test_classes_range_type_new() {
let ctx = inkwell::context::Context::create();
let llvm_range = RangeType::new(&ctx);
assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok());
}
#[test]
fn test_classes_ndarray_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
}

View File

@ -1,192 +0,0 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
AddressSpace,
};
use super::ProxyType;
use crate::codegen::{
values::{ArraySliceValue, ListValue, ProxyValue},
CodeGenContext, CodeGenerator,
};
/// Proxy type for a `list` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct ListType<'ctx> {
ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
}
impl<'ctx> ListType<'ctx> {
/// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_list_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else {
return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}"));
};
if llvm_list_ty.count_fields() != 2 {
return Err(format!(
"Expected 2 fields in `list`, got {}",
llvm_list_ty.count_fields()
));
}
let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap();
let Ok(_) = PointerType::try_from(list_size_ty) else {
return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}"));
};
let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap();
let Ok(list_data_ty) = IntType::try_from(list_data_ty) else {
return Err(format!("Expected int type for `list.1`, got {list_data_ty}"));
};
if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `list.1`, got {}-bit int",
llvm_usize.get_bit_width(),
list_data_ty.get_bit_width()
));
}
Ok(())
}
/// Creates an LLVM type corresponding to the expected structure of a `List`.
#[must_use]
fn llvm_type(
ctx: &'ctx Context,
element_type: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> PointerType<'ctx> {
// struct List { data: T*, size: size_t }
let field_tys = [element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()];
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`ListType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
element_type: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize);
ListType::from_type(llvm_list, llvm_usize)
}
/// Creates an [`ListType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
ListType { ty: ptr_ty, llvm_usize }
}
/// Returns the type of the `size` field of this `list` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(1)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `list` type.
#[must_use]
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap()
}
}
impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
type Base = PointerType<'ctx>;
type Value = ListValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Value {
self.map_value(
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap(),
name,
)
}
fn new_array_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
Self::Value::from_pointer_value(value, self.llvm_usize, name)
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<ListType<'ctx>> for PointerType<'ctx> {
fn from(value: ListType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -1,64 +0,0 @@
use inkwell::{context::Context, types::BasicType, values::IntValue};
use super::{
values::{ArraySliceValue, ProxyValue},
{CodeGenContext, CodeGenerator},
};
pub use list::*;
pub use ndarray::*;
pub use range::*;
mod list;
mod ndarray;
mod range;
pub mod structure;
/// A LLVM type that is used to represent a corresponding type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> {
/// The LLVM type of which values of this type possess. This is usually a
/// [LLVM pointer type][PointerType] for any non-primitive types.
type Base: BasicType<'ctx>;
/// The type of values represented by this type.
type Value: ProxyValue<'ctx, Type = Self>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String>;
/// Checks whether `llvm_ty` can be represented by this [`ProxyType`].
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String>;
/// Creates a new value of this type.
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Value;
/// Creates a new array value of this type.
fn new_array_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx>;
/// Converts an existing value into a [`ProxyValue`] of this type.
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value;
/// Returns the [base type][Self::Base] of this proxy.
fn as_base_type(&self) -> Self::Base;
}

View File

@ -1,258 +0,0 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use super::{
structure::{StructField, StructFields},
ProxyType,
};
use crate::codegen::{
values::{ArraySliceValue, NDArrayValue, ProxyValue},
{CodeGenContext, CodeGenerator},
};
/// Proxy type for a `ndarray` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDArrayType<'ctx> {
ty: PointerType<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayStructFields<'ctx> {
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
};
if llvm_ndarray_ty.count_fields() != 3 {
return Err(format!(
"Expected 3 fields in `NDArray`, got {}",
llvm_ndarray_ty.count_fields()
));
}
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"));
};
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `ndarray.0`, got {}-bit int",
llvm_usize.get_bit_width(),
ndarray_ndims_ty.get_bit_width()
));
}
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
};
let ndarray_dims = ndarray_pdims.get_element_type();
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
));
};
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_dims.get_bit_width()
));
}
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
let Ok(ndarray_pdata) = PointerType::try_from(ndarray_data_ty) else {
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
};
let ndarray_data = ndarray_pdata.get_element_type();
let Ok(ndarray_data) = IntType::try_from(ndarray_data) else {
return Err(format!(
"Expected pointer-to-int type for `ndarray.2`, got pointer-to-{ndarray_data}"
));
};
if ndarray_data.get_bit_width() != 8 {
return Err(format!(
"Expected pointer-to-8-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
ndarray_data.get_bit_width()
));
}
Ok(())
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
fn fields(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize)
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(
&self,
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
// struct NDArray { num_dims: size_t, dims: size_t*, data: i8* }
//
// * data : Pointer to an array containing the array data
// * itemsize: The size of each NDArray elements in bytes
// * ndims : Number of dimensions in the array
// * shape : Pointer to an array containing the shape of the NDArray
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
let field_tys =
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`NDArrayType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
NDArrayType { ty: llvm_ndarray, dtype, llvm_usize }
}
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
#[must_use]
pub fn from_type(
ptr_ty: PointerType<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
NDArrayType { ty: ptr_ty, dtype, llvm_usize }
}
/// Returns the type of the `size` field of this `ndarray` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `ndarray` type.
#[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.dtype
}
}
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
type Base = PointerType<'ctx>;
type Value = NDArrayValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Value {
self.map_value(
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap(),
name,
)
}
fn new_array_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
NDArrayValue::from_pointer_value(value, self.dtype, self.llvm_usize, name)
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
fn from(value: NDArrayType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -1,159 +0,0 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
AddressSpace,
};
use super::ProxyType;
use crate::codegen::{
values::{ArraySliceValue, ProxyValue, RangeValue},
{CodeGenContext, CodeGenerator},
};
/// Proxy type for a `range` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct RangeType<'ctx> {
ty: PointerType<'ctx>,
}
impl<'ctx> RangeType<'ctx> {
/// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not.
pub fn is_representable(llvm_ty: PointerType<'ctx>) -> Result<(), String> {
let llvm_range_ty = llvm_ty.get_element_type();
let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else {
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"));
};
if llvm_range_ty.len() != 3 {
return Err(format!(
"Expected 3 elements for `range` type, got {}",
llvm_range_ty.len()
));
}
let llvm_range_elem_ty = llvm_range_ty.get_element_type();
let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else {
return Err(format!(
"Expected int type for `range` element type, got {llvm_range_elem_ty}"
));
};
if llvm_range_elem_ty.get_bit_width() != 32 {
return Err(format!(
"Expected 32-bit int type for `range` element type, got {}",
llvm_range_elem_ty.get_bit_width()
));
}
Ok(())
}
/// Creates an LLVM type corresponding to the expected structure of a `Range`.
#[must_use]
fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> {
// typedef int32_t Range[3];
let llvm_i32 = ctx.i32_type();
llvm_i32.array_type(3).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`RangeType`].
#[must_use]
pub fn new(ctx: &'ctx Context) -> Self {
let llvm_range = Self::llvm_type(ctx);
RangeType::from_type(llvm_range)
}
/// Creates an [`RangeType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty).is_ok());
RangeType { ty: ptr_ty }
}
/// Returns the type of all fields of this `range` type.
#[must_use]
pub fn value_type(&self) -> IntType<'ctx> {
self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type()
}
}
impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
type Base = PointerType<'ctx>;
type Value = RangeValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
_: &G,
_: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty)
}
fn new_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> Self::Value {
self.map_value(
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap(),
name,
)
}
fn new_array_value<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn map_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
RangeValue::from_pointer_value(value, name)
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<RangeType<'ctx>> for PointerType<'ctx> {
fn from(value: RangeType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -1,203 +0,0 @@
use std::marker::PhantomData;
use inkwell::{
context::AsContextRef,
types::{BasicTypeEnum, IntType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
};
use crate::codegen::CodeGenContext;
/// Trait indicating that the structure is a field-wise representation of an LLVM structure.
///
/// # Usage
///
/// For example, for a simple C-slice LLVM structure:
///
/// ```ignore
/// struct CSliceFields<'ctx> {
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
/// len: StructField<'ctx, IntValue<'ctx>>
/// }
/// ```
pub trait StructFields<'ctx>: Eq + Copy {
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self;
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition.
#[must_use]
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>;
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
/// in the type definition.
#[must_use]
fn iter(&self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)> {
self.to_vec().into_iter()
}
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition.
#[must_use]
fn into_vec(self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>
where
Self: Sized,
{
self.to_vec()
}
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
/// in the type definition.
#[must_use]
fn into_iter(self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)>
where
Self: Sized,
{
self.into_vec().into_iter()
}
}
/// A single field of an LLVM structure.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct StructField<'ctx, Value>
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
/// The index of this field within the structure.
index: u32,
/// The name of this field.
name: &'static str,
/// The type of this field.
ty: BasicTypeEnum<'ctx>,
/// Instance of [`PhantomData`] containing [`Value`], used to implement automatic downcasts.
_value_ty: PhantomData<Value>,
}
impl<'ctx, Value> StructField<'ctx, Value>
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
/// Creates an instance of [`StructField`].
///
/// * `idx_counter` - The instance of [`FieldIndexCounter`] used to track the current field
/// index.
/// * `name` - Name of the field.
/// * `ty` - The type of this field.
pub fn create(
idx_counter: &mut FieldIndexCounter,
name: &'static str,
ty: impl Into<BasicTypeEnum<'ctx>>,
) -> Self {
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
}
/// Creates an instance of [`StructField`] with a given index.
///
/// * `index` - The index of this field within its enclosing structure.
/// * `name` - Name of the field.
/// * `ty` - The type of this field.
pub fn create_at(index: u32, name: &'static str, ty: impl Into<BasicTypeEnum<'ctx>>) -> Self {
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
}
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
/// {idx...}, i32 {self.index}`.
pub fn ptr_by_array_gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
idx: &[IntValue<'ctx>],
) -> PointerValue<'ctx> {
unsafe {
ctx.builder.build_in_bounds_gep(
pobj,
&[idx, &[ctx.ctx.i32_type().const_int(u64::from(self.index), false)]].concat(),
"",
)
}
.unwrap()
}
/// Creates a pointer to this field in an arbitrary structure by performing the equivalent of
/// `getelementptr i32 0, i32 {self.index}`.
pub fn ptr_by_gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
obj_name: Option<&'ctx str>,
) -> PointerValue<'ctx> {
ctx.builder
.build_struct_gep(
pobj,
self.index,
&obj_name.map(|name| format!("{name}.{}.addr", self.name)).unwrap_or_default(),
)
.unwrap()
}
/// Gets the value of this field for a given `obj`.
#[must_use]
pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value {
obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap()
}
/// Sets the value of this field for a given `obj`.
pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) {
obj.set_field_at_index(self.index, value);
}
/// Gets the value of this field for a pointer-to-structure.
pub fn get(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
obj_name: Option<&'ctx str>,
) -> Value {
ctx.builder
.build_load(
self.ptr_by_gep(ctx, pobj, obj_name),
&obj_name.map(|name| format!("{name}.{}", self.name)).unwrap_or_default(),
)
.map_err(|_| ())
.and_then(|value| Value::try_from(value))
.unwrap()
}
/// Sets the value of this field for a pointer-to-structure.
pub fn set(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
value: Value,
obj_name: Option<&'ctx str>,
) {
ctx.builder.build_store(self.ptr_by_gep(ctx, pobj, obj_name), value).unwrap();
}
}
impl<'ctx, Value> From<StructField<'ctx, Value>> for (&'static str, BasicTypeEnum<'ctx>)
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
fn from(value: StructField<'ctx, Value>) -> Self {
(value.name, value.ty)
}
}
/// A counter that tracks the next index of a field using a monotonically increasing counter.
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
pub struct FieldIndexCounter(u32);
impl FieldIndexCounter {
/// Increments the number stored by this counter, returning the previous value.
///
/// Functionally equivalent to `i++` in C-based languages.
pub fn increment(&mut self) -> u32 {
let v = self.0;
self.0 += 1;
v
}
}

View File

@ -1,426 +0,0 @@
use inkwell::{
types::AnyTypeEnum,
values::{BasicValueEnum, IntValue, PointerValue},
IntPredicate,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
/// elements.
pub trait ArrayLikeValue<'ctx> {
/// Returns the element type of this array-like value.
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx>;
/// Returns the base pointer to the array.
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> PointerValue<'ctx>;
/// Returns the size of this array-like value.
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx>;
/// Returns a [`ArraySliceValue`] representing this value.
fn as_slice_value<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> ArraySliceValue<'ctx> {
ArraySliceValue::from_ptr_val(
self.base_ptr(ctx, generator),
self.size(ctx, generator),
None,
)
}
}
/// An array-like value that can be indexed by memory offset.
pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> {
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx>;
/// Returns the pointer to the data at the `idx`-th index.
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx>;
}
/// An array-like value that can have its array elements accessed as a [`BasicValueEnum`].
pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>:
ArrayLikeIndexer<'ctx, Index>
{
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn get_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
}
/// Returns the data at the `idx`-th index.
fn get<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset(ctx, generator, idx, name);
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
}
}
/// An array-like value that can have its array elements mutated as a [`BasicValueEnum`].
pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>:
ArrayLikeIndexer<'ctx, Index>
{
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn set_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: BasicValueEnum<'ctx>,
) {
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, None) };
ctx.builder.build_store(ptr, value).unwrap();
}
/// Sets the data at the `idx`-th index.
fn set<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: BasicValueEnum<'ctx>,
) {
let ptr = self.ptr_offset(ctx, generator, idx, None);
ctx.builder.build_store(ptr, value).unwrap();
}
}
/// An array-like value that can have its array elements accessed as an arbitrary type `T`.
pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>:
UntypedArrayLikeAccessor<'ctx, Index>
{
/// Casts an element from [`BasicValueEnum`] into `T`.
fn downcast_to_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> T;
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn get_typed_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> T {
let value = unsafe { self.get_unchecked(ctx, generator, idx, name) };
self.downcast_to_type(ctx, value)
}
/// Returns the data at the `idx`-th index.
fn get_typed<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> T {
let value = self.get(ctx, generator, idx, name);
self.downcast_to_type(ctx, value)
}
}
/// An array-like value that can have its array elements mutated as an arbitrary type `T`.
pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>:
UntypedArrayLikeMutator<'ctx, Index>
{
/// Casts an element from T into [`BasicValueEnum`].
fn upcast_from_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: T,
) -> BasicValueEnum<'ctx>;
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn set_typed_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: T,
) {
let value = self.upcast_from_type(ctx, value);
unsafe { self.set_unchecked(ctx, generator, idx, value) }
}
/// Sets the data at the `idx`-th index.
fn set_typed<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: T,
) {
let value = self.upcast_from_type(ctx, value);
self.set(ctx, generator, idx, value);
}
}
/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`.
type ValueDowncastFn<'ctx, T> =
Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> T>;
/// Type alias for a function that casts a `T` into a [`BasicValueEnum`].
type ValueUpcastFn<'ctx, T> = Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, T) -> BasicValueEnum<'ctx>>;
/// An adapter for constraining untyped array values as typed values.
pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>> {
adapted: Adapted,
downcast_fn: ValueDowncastFn<'ctx, T>,
upcast_fn: ValueUpcastFn<'ctx, T>,
}
impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeValue<'ctx>,
{
/// Creates a [`TypedArrayLikeAdapter`].
///
/// * `adapted` - The value to be adapted.
/// * `downcast_fn` - The function converting a [`BasicValueEnum`] into a `T`.
/// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`].
pub fn from(
adapted: Adapted,
downcast_fn: ValueDowncastFn<'ctx, T>,
upcast_fn: ValueUpcastFn<'ctx, T>,
) -> Self {
TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn }
}
}
impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeValue<'ctx>,
{
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.adapted.element_type(ctx, generator)
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> PointerValue<'ctx> {
self.adapted.base_ptr(ctx, generator)
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
self.adapted.size(ctx, generator)
}
}
impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeIndexer<'ctx, Index>,
{
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) }
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
self.adapted.ptr_offset(ctx, generator, idx, name)
}
}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
{
}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
{
}
impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
{
fn downcast_to_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> T {
(self.downcast_fn)(ctx, value)
}
}
impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
{
fn upcast_from_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: T,
) -> BasicValueEnum<'ctx> {
(self.upcast_fn)(ctx, value)
}
}
/// An LLVM value representing an array slice, consisting of a pointer to the data and the size of
/// the slice.
#[derive(Copy, Clone)]
pub struct ArraySliceValue<'ctx>(PointerValue<'ctx>, IntValue<'ctx>, Option<&'ctx str>);
impl<'ctx> ArraySliceValue<'ctx> {
/// Creates an [`ArraySliceValue`] from a [`PointerValue`] and its size.
#[must_use]
pub fn from_ptr_val(
ptr: PointerValue<'ctx>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Self {
ArraySliceValue(ptr, size, name)
}
}
impl<'ctx> From<ArraySliceValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ArraySliceValue<'ctx>) -> Self {
value.0
}
}
impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
self.0
}
fn size<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.1
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"list index out of range",
[None, None, None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {}

View File

@ -1,241 +0,0 @@
use inkwell::{
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::{
types::ListType,
{CodeGenContext, CodeGenerator},
};
/// Proxy type for accessing a `list` value in LLVM.
#[derive(Copy, Clone)]
pub struct ListValue<'ctx> {
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
ListType::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`ListValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
ListValue { value: ptr, llvm_usize, name }
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
}
/// Returns the pointer to the field storing the size of this `list`.
fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap();
}
/// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`.
///
/// If `size` is [None], the size stored in the field of this instance is used instead.
pub fn create_data(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: BasicTypeEnum<'ctx>,
size: Option<IntValue<'ctx>>,
) {
let size = size.unwrap_or_else(|| self.load_size(ctx, None));
let data = ctx
.builder
.build_select(
ctx.builder
.build_int_compare(IntPredicate::NE, size, self.llvm_usize.const_zero(), "")
.unwrap(),
ctx.builder.build_array_alloca(elem_ty, size, "").unwrap(),
elem_ty.ptr_type(AddressSpace::default()).const_zero(),
"",
)
.map(BasicValueEnum::into_pointer_value)
.unwrap();
self.store_data(ctx, data);
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
#[must_use]
pub fn data(&self) -> ListDataProxy<'ctx, '_> {
ListDataProxy(self)
}
/// Stores the `size` of this `list` into this instance.
pub fn store_size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
size: IntValue<'ctx>,
) {
debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx));
let psize = self.ptr_to_size(ctx);
ctx.builder.build_store(psize, size).unwrap();
}
/// Returns the size of this `list` as a value.
pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let psize = self.ptr_to_size(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.size")))
.unwrap_or_default();
ctx.builder
.build_load(psize, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
}
impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = ListType<'ctx>;
fn get_type(&self) -> Self::Type {
ListType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<ListValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ListValue<'ctx>) -> Self {
value.as_base_value()
}
}
/// Proxy type for accessing the `data` array of an `list` instance in LLVM.
#[derive(Copy, Clone)]
pub struct ListDataProxy<'ctx, 'a>(&'a ListValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.value.get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.pptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_size(ctx, None)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"list index out of range",
[None, None, None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {}

View File

@ -1,47 +0,0 @@
use inkwell::{context::Context, values::BasicValue};
use super::types::ProxyType;
use crate::codegen::CodeGenerator;
pub use array::*;
pub use list::*;
pub use ndarray::*;
pub use range::*;
mod array;
mod list;
mod ndarray;
mod range;
/// A LLVM type that is used to represent a non-primitive value in NAC3.
pub trait ProxyValue<'ctx>: Into<Self::Base> {
/// The type of LLVM values represented by this instance. This is usually the
/// [LLVM pointer type][PointerValue].
type Base: BasicValue<'ctx>;
/// The type of this value.
type Type: ProxyType<'ctx, Value = Self>;
/// Checks whether `value` can be represented by this [`ProxyValue`].
fn is_instance<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
value: impl BasicValue<'ctx>,
) -> Result<(), String> {
Self::Type::is_type(generator, ctx, value.as_basic_value_enum().get_type())
}
/// Checks whether `value` can be represented by this [`ProxyValue`].
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
value: Self::Base,
) -> Result<(), String> {
Self::is_instance(generator, ctx, value.as_basic_value_enum())
}
/// Returns the [type][ProxyType] of this value.
fn get_type(&self) -> Self::Type;
/// Returns the [base value][Self::Base] of this proxy.
fn as_base_value(&self) -> Self::Base;
}

View File

@ -1,523 +0,0 @@
use inkwell::{
types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
types::NDArrayType,
CodeGenContext, CodeGenerator,
};
/// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayValue<'ctx> {
value: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> NDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
NDArrayType::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
NDArrayValue { value: ptr, dtype, llvm_usize, name }
}
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.shape
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
ctx.builder.build_store(self.ptr_to_shape(ctx), dims).unwrap();
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_shape(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
NDArrayShapeProxy(self)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.data
.ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
let data = ctx
.builder
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap();
ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap();
}
/// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`.
pub fn create_data(
&self,
ctx: &CodeGenContext<'ctx, '_>,
elem_ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
) {
let itemsize =
ctx.builder.build_int_cast(elem_ty.size_of().unwrap(), size.get_type(), "").unwrap();
let nbytes = ctx.builder.build_int_mul(size, itemsize, "").unwrap();
// TODO: What about alignment?
self.store_data(
ctx,
ctx.builder.build_array_alloca(ctx.ctx.i8_type(), nbytes, "").unwrap(),
);
}
/// Returns a proxy object to the field storing the data of this `NDArray`.
#[must_use]
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self)
}
}
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = NDArrayType<'ctx>;
fn get_type(&self) -> Self::Type {
NDArrayType::from_type(self.as_base_value().get_type(), self.dtype, self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
fn from(value: NDArrayValue<'ctx>) -> Self {
value.as_base_value()
}
}
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_shape(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.dtype.as_any_type_enum()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None))
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let sizeof_elem = ctx
.builder
.build_int_truncate_or_bit_cast(
self.element_type(ctx, generator).size_of().unwrap(),
idx.get_type(),
"",
)
.unwrap();
let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap();
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[idx],
name.unwrap_or_default(),
)
.unwrap()
};
// Current implementation is transparent - The returned pointer type is
// already cast into the expected type, allowing for immediately
// load/store.
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let data_sz = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
// Current implementation is transparent - The returned pointer type is
// already cast into the expected type, allowing for immediately
// load/store.
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
indices: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_elem_ty = indices
.ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
.get_type()
.get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(
indices_elem_ty.get_bit_width(),
32,
"Expected list[int32] but got list[int{}]",
indices_elem_ty.get_bit_width()
);
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
let sizeof_elem = ctx
.builder
.build_int_truncate_or_bit_cast(
self.element_type(ctx, generator).size_of().unwrap(),
index.get_type(),
"",
)
.unwrap();
let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap();
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[index],
name.unwrap_or_default(),
)
.unwrap()
};
// TODO: Current implementation is transparent
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
indices: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.size(ctx, generator);
let nidx_leq_ndims = ctx
.builder
.build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "")
.unwrap();
ctx.make_assert(
generator,
nidx_leq_ndims,
"0:IndexError",
"invalid index to scalar variable",
[None, None, None],
ctx.current_loc,
);
let indices_len = indices.size(ctx, generator);
let ndarray_len = self.0.load_ndims(ctx);
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let (dim_idx, dim_sz) = unsafe {
(
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
self.0.shape().get_typed_unchecked(ctx, generator, &i, None),
)
};
let dim_idx = ctx
.builder
.build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "")
.unwrap();
let dim_lt =
ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap();
ctx.make_assert(
generator,
dim_lt,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(dim_idx), Some(dim_sz), None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) };
// TODO: Current implementation is transparent
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}

View File

@ -1,153 +0,0 @@
use inkwell::values::{BasicValueEnum, IntValue, PointerValue};
use super::ProxyValue;
use crate::codegen::{types::RangeType, CodeGenContext};
/// Proxy type for accessing a `range` value in LLVM.
#[derive(Copy, Clone)]
pub struct RangeValue<'ctx> {
value: PointerValue<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> RangeValue<'ctx> {
/// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance.
pub fn is_representable(value: PointerValue<'ctx>) -> Result<(), String> {
RangeType::is_representable(value.get_type())
}
/// Creates an [`RangeValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
debug_assert!(Self::is_representable(ptr).is_ok());
RangeValue { value: ptr, name }
}
fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(0, false)],
var_name.as_str(),
)
.unwrap()
}
}
fn ptr_to_end(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
var_name.as_str(),
)
.unwrap()
}
}
fn ptr_to_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, false)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the `start` value into this instance.
pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, start: IntValue<'ctx>) {
debug_assert_eq!(start.get_type().get_bit_width(), 32);
let pstart = self.ptr_to_start(ctx);
ctx.builder.build_store(pstart, start).unwrap();
}
/// Returns the `start` value of this `range`.
pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pstart = self.ptr_to_start(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.start")))
.unwrap_or_default();
ctx.builder
.build_load(pstart, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
/// Stores the `end` value into this instance.
pub fn store_end(&self, ctx: &CodeGenContext<'ctx, '_>, end: IntValue<'ctx>) {
debug_assert_eq!(end.get_type().get_bit_width(), 32);
let pend = self.ptr_to_end(ctx);
ctx.builder.build_store(pend, end).unwrap();
}
/// Returns the `end` value of this `range`.
pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pend = self.ptr_to_end(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.end")))
.unwrap_or_default();
ctx.builder.build_load(pend, var_name.as_str()).map(BasicValueEnum::into_int_value).unwrap()
}
/// Stores the `step` value into this instance.
pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, step: IntValue<'ctx>) {
debug_assert_eq!(step.get_type().get_bit_width(), 32);
let pstep = self.ptr_to_step(ctx);
ctx.builder.build_store(pstep, step).unwrap();
}
/// Returns the `step` value of this `range`.
pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pstep = self.ptr_to_step(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.step")))
.unwrap_or_default();
ctx.builder
.build_load(pstep, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
}
impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = RangeType<'ctx>;
fn get_type(&self) -> Self::Type {
RangeType::from_type(self.value.get_type())
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<RangeValue<'ctx>> for PointerValue<'ctx> {
fn from(value: RangeValue<'ctx>) -> Self {
value.as_base_value()
}
}

View File

@ -1,25 +1,7 @@
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
#![warn(clippy::pedantic)]
#![allow(
dead_code,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
// users of nac3core need to use the same version of these dependencies, so expose them as nac3core::*
pub use inkwell;
pub use nac3parser;
#![warn(clippy::all)]
#![allow(dead_code)]
pub mod codegen;
pub mod symbol_resolver;
pub mod toplevel;
pub mod typecheck;
extern crate self as nac3core;

View File

@ -1,24 +1,24 @@
use std::{
collections::{HashMap, HashSet},
fmt::{Debug, Display},
rc::Rc,
sync::Arc,
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use parking_lot::RwLock;
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use std::fmt::Debug;
use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display};
use std::rc::Rc;
use crate::typecheck::typedef::TypeEnum;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
codegen::CodeGenContext,
toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation},
};
use crate::{
codegen::CodeGenerator,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap},
typedef::{Type, Unifier},
},
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip};
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)]
pub enum SymbolValue {
@ -43,7 +43,7 @@ impl SymbolValue {
constant: &Constant,
expected_ty: Type,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
unifier: &mut Unifier
) -> Result<Self, String> {
match constant {
Constant::None => {
@ -66,30 +66,35 @@ impl SymbolValue {
} else {
Err(format!("Expected {expected_ty:?}, but got str"))
}
}
},
Constant::Int(i) => {
if unifier.unioned(expected_ty, primitives.int32) {
i32::try_from(*i).map(SymbolValue::I32).map_err(|e| e.to_string())
i32::try_from(*i)
.map(SymbolValue::I32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.int64) {
i64::try_from(*i).map(SymbolValue::I64).map_err(|e| e.to_string())
i64::try_from(*i)
.map(SymbolValue::I64)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint32) {
u32::try_from(*i).map(SymbolValue::U32).map_err(|e| e.to_string())
u32::try_from(*i)
.map(SymbolValue::U32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint64) {
u64::try_from(*i).map(SymbolValue::U64).map_err(|e| e.to_string())
u64::try_from(*i)
.map(SymbolValue::U64)
.map_err(|e| e.to_string())
} else {
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
}
}
Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else {
return Err(format!(
"Expected {:?}, but got Tuple",
expected_ty.get_type_name()
));
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
return Err(format!("Expected {:?}, but got Tuple", expected_ty.get_type_name()))
};
assert!(*is_vararg_ctx || ty.len() == t.len());
assert_eq!(ty.len(), t.len());
let elems = t
.iter()
@ -104,45 +109,7 @@ impl SymbolValue {
} else {
Err(format!("Expected {expected_ty:?}, but got float"))
}
}
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
///
/// * `constant` - The constant to create the value from.
pub fn from_constant_inferred(constant: &Constant) -> Result<Self, String> {
match constant {
Constant::None => Ok(SymbolValue::OptionNone),
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())),
Constant::Int(i) => {
let i = *i;
if i >= 0 {
i32::try_from(i)
.map(SymbolValue::I32)
.or_else(|_| i64::try_from(i).map(SymbolValue::I64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
} else {
u32::try_from(i)
.map(SymbolValue::U32)
.or_else(|_| u64::try_from(i).map(SymbolValue::U64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
}
}
Constant::Tuple(t) => {
let elems = t
.iter()
.map(Self::from_constant_inferred)
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => Ok(SymbolValue::Double(*f)),
},
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
@ -158,27 +125,28 @@ impl SymbolValue {
SymbolValue::Double(_) => primitives.float,
SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false })
let vs_tys = vs
.iter()
.map(|v| v.get_type(primitives, unifier))
.collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple {
ty: vs_tys,
})
}
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
}
}
/// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> TypeAnnotation {
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
match self {
SymbolValue::Bool(..)
| SymbolValue::Double(..)
| SymbolValue::I32(..)
| SymbolValue::I64(..)
| SymbolValue::U32(..)
| SymbolValue::U64(..)
| SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool),
SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float),
SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32),
SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64),
SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32),
SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64),
SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str),
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
@ -187,13 +155,13 @@ impl SymbolValue {
TypeAnnotation::Tuple(vs_tys)
}
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
id: primitives.option.get_obj_id(unifier),
params: Vec::default(),
},
SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
id: primitives.option.get_obj_id(unifier),
params: vec![ty],
}
}
@ -201,11 +169,7 @@ impl SymbolValue {
}
/// Returns the [`TypeEnum`] representing the data type of this value.
pub fn get_type_enum(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> Rc<TypeEnum> {
pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty)
}
@ -236,38 +200,6 @@ impl Display for SymbolValue {
}
}
impl TryFrom<SymbolValue> for u64 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
/// numeric or if the value cannot be converted into a `u64` without overflow.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::U32(v) => Ok(u64::from(v)),
SymbolValue::U64(v) => Ok(v),
_ => Err(()),
}
}
}
impl TryFrom<SymbolValue> for i128 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
/// numeric.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => Ok(i128::from(v)),
SymbolValue::I64(v) => Ok(i128::from(v)),
SymbolValue::U32(v) => Ok(i128::from(v)),
SymbolValue::U64(v) => Ok(i128::from(v)),
_ => Err(()),
}
}
}
pub trait StaticValue {
/// Returns a unique identifier for this value.
fn get_unique_identifier(&self) -> u64;
@ -300,10 +232,10 @@ pub trait StaticValue {
#[derive(Clone)]
pub enum ValueEnum<'ctx> {
/// [`ValueEnum`] representing a static value.
/// [ValueEnum] representing a static value.
Static(Arc<dyn StaticValue + Send + Sync>),
/// [`ValueEnum`] representing a dynamic value.
/// [ValueEnum] representing a dynamic value.
Dynamic(BasicValueEnum<'ctx>),
}
@ -338,6 +270,7 @@ impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
}
impl<'ctx> ValueEnum<'ctx> {
/// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
pub fn to_basic_value_enum<'a>(
self,
@ -369,7 +302,6 @@ pub trait SymbolResolver {
&self,
str: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>>;
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
@ -380,7 +312,7 @@ pub trait SymbolResolver {
&self,
_unifier: &mut Unifier,
_top_level_defs: &[Arc<RwLock<TopLevelDef>>],
_primitives: &PrimitiveStore,
_primitives: &PrimitiveStore
) -> Result<(), String> {
Ok(())
}
@ -393,12 +325,12 @@ thread_local! {
"float".into(),
"bool".into(),
"virtual".into(),
"list".into(),
"tuple".into(),
"str".into(),
"Exception".into(),
"uint32".into(),
"uint64".into(),
"Literal".into(),
];
}
@ -417,12 +349,12 @@ pub fn parse_type_annotation<T>(
let float_id = ids[2];
let bool_id = ids[3];
let virtual_id = ids[4];
let tuple_id = ids[5];
let str_id = ids[6];
let exn_id = ids[7];
let uint32_id = ids[8];
let uint64_id = ids[9];
let literal_id = ids[10];
let list_id = ids[5];
let tuple_id = ids[6];
let str_id = ids[7];
let exn_id = ids[8];
let uint32_id = ids[9];
let uint64_id = ids[10];
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
if *id == int32_id {
@ -447,29 +379,40 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"Unexpected number of type parameters: expected {} but got 0",
type_vars.len()
)]));
),
]))
}
let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
)
.collect();
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() }))
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields,
params: HashMap::default(),
}))
} else {
Err(HashSet::from([format!("Cannot use function name as type at {loc}")]))
Err(HashSet::from([
format!("Cannot use function name as type at {loc}"),
]))
}
} else {
let ty =
resolver.get_symbol_type(unifier, top_level_defs, primitives, *id).map_err(
|e| HashSet::from([format!("Unknown type annotation at {loc}: {e}")]),
)?;
let ty = resolver
.get_symbol_type(unifier, top_level_defs, primitives, *id)
.map_err(|e| HashSet::from([
format!("Unknown type annotation at {loc}: {e}"),
]))?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty)
} else {
Err(HashSet::from([format!("Unknown type annotation {id} at {loc}")]))
Err(HashSet::from([
format!("Unknown type annotation {id} at {loc}"),
]))
}
}
}
@ -479,6 +422,9 @@ pub fn parse_type_annotation<T>(
if *id == virtual_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else if *id == list_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
} else if *id == tuple_id {
if let Tuple { elts, .. } = &slice.node {
let ty = elts
@ -487,33 +433,12 @@ pub fn parse_type_annotation<T>(
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }))
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else {
Err(HashSet::from(["Expected multiple elements for tuple".into()]))
Err(HashSet::from([
"Expected multiple elements for tuple".into()
]))
}
} else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([format!(
"Expected literal in type argument for Literal at {}",
elt.location
)])),
}
};
let values = if let Tuple { elts, .. } = &slice.node {
elts.iter().map(&mut parse_literal).collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_literal(slice)?]
}
.into_iter()
.flatten()
.collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else {
let types = if let Tuple { elts, .. } = &slice.node {
elts.iter()
@ -529,13 +454,15 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"Unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
types.len()
)]));
),
]))
}
let mut subst = VarMap::new();
let mut subst = HashMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id
@ -557,7 +484,9 @@ pub fn parse_type_annotation<T>(
}));
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
} else {
Err(HashSet::from(["Cannot use function name as type".into()]))
Err(HashSet::from([
"Cannot use function name as type".into(),
]))
}
}
};
@ -568,13 +497,14 @@ pub fn parse_type_annotation<T>(
if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier)
} else {
Err(HashSet::from([format!("unsupported type expression at {}", expr.location)]))
Err(HashSet::from([
format!("unsupported type expression at {}", expr.location),
]))
}
}
Constant { value, .. } => SymbolValue::from_constant_inferred(value)
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
.map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])),
_ => Err(HashSet::from([
format!("unsupported type expression at {}", expr.location),
])),
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -6,38 +6,33 @@ use std::{
sync::Arc,
};
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use parking_lot::RwLock;
use nac3parser::ast::{self, Expr, Location, Stmt, StrRef};
use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier};
use crate::{
codegen::{CodeGenContext, CodeGenerator},
codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{
CallId, FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, TypeVarId, Unifier,
VarMap,
},
},
typecheck::{type_inferencer::CodeLocation, typedef::CallId},
};
use composer::*;
use type_annotation::*;
pub mod builtins;
pub mod composer;
pub mod helper;
pub mod numpy;
#[cfg(test)]
mod test;
pub mod type_annotation;
use inkwell::values::BasicValueEnum;
use itertools::{izip, Itertools};
use nac3parser::ast::{self, Location, Stmt, StrRef};
use parking_lot::RwLock;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize);
type GenCallCallback = dyn for<'ctx, 'a> Fn(
pub mod builtins;
pub mod composer;
pub mod helper;
pub mod type_annotation;
use composer::*;
use type_annotation::*;
#[cfg(test)]
mod test;
type GenCallCallback = Box<
dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>,
Option<(Type, ValueEnum<'ctx>)>,
(&FunSignature, DefinitionId),
@ -45,25 +40,19 @@ type GenCallCallback = dyn for<'ctx, 'a> Fn(
&mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String>
+ Send
+ Sync;
+ Sync,
>;
pub struct GenCall {
fp: Box<GenCallCallback>,
fp: GenCallCallback,
}
impl GenCall {
#[must_use]
pub fn new(fp: Box<GenCallCallback>) -> GenCall {
pub fn new(fp: GenCallCallback) -> GenCall {
GenCall { fp }
}
/// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
/// `reason`.
#[must_use]
pub fn create_dummy(reason: String) -> GenCall {
Self::new(Box::new(move |_, _, _, _, _| unreachable!("{reason}")))
}
pub fn run<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -86,7 +75,7 @@ impl Debug for GenCall {
pub struct FunInstance {
pub body: Arc<Vec<Stmt<Option<Type>>>>,
pub calls: Arc<HashMap<CodeLocation, CallId>>,
pub subst: VarMap,
pub subst: HashMap<u32, Type>,
pub unifier_id: usize,
}
@ -95,7 +84,7 @@ pub enum TopLevelDef {
Class {
/// Name for error messages and symbols.
name: StrRef,
/// Object ID used for [`TypeEnum`].
/// Object ID used for [TypeEnum].
object_id: DefinitionId,
/// type variables bounded to the class.
type_vars: Vec<Type>,
@ -103,10 +92,6 @@ pub enum TopLevelDef {
///
/// Name and type is mutable.
fields: Vec<(StrRef, Type, bool)>,
/// Class Attributes.
///
/// Name, type, value.
attributes: Vec<(StrRef, Type, ast::Constant)>,
/// Class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>,
/// Ancestor classes, including itself.
@ -126,7 +111,7 @@ pub enum TopLevelDef {
/// Function signature.
signature: Type,
/// Instantiated type variable IDs.
var_id: Vec<TypeVarId>,
var_id: Vec<u32>,
/// Function instance to symbol mapping
///
/// * Key: String representation of type variable values, sorted by variable ID in ascending
@ -148,25 +133,6 @@ pub enum TopLevelDef {
/// Definition location.
loc: Option<Location>,
},
Variable {
/// Qualified name of the global variable, should be unique globally.
name: String,
/// Simple name, the same as in method/function definition.
simple_name: StrRef,
/// Type of the global variable.
ty: Type,
/// The declared type of the global variable, or [`None`] if no type annotation is provided.
ty_decl: Option<Expr>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location.
loc: Option<Location>,
},
}
pub struct TopLevelContext {

View File

@ -1,84 +0,0 @@
use itertools::Itertools;
use super::helper::PrimDef;
use crate::typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
};
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
if dtype.is_none() && ndims.is_none() {
return ndarray;
}
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
if let Some(ndims) = ndims {
tvar_subst.insert(tvar_ids[1], ndims);
}
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2);
params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
}

View File

@ -5,7 +5,7 @@ expression: res_vec
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [22]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar11]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar11\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,9 +5,9 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [24]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [29]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
]

View File

@ -3,11 +3,11 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar10, typevar11]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar10\", \"typevar11\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\",\nvar_id: []\n}\n",
]

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [30]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [38]\n}\n",
]

View File

@ -1,23 +1,19 @@
use std::{collections::HashMap, sync::Arc};
use indoc::indoc;
use parking_lot::Mutex;
use test_case::test_case;
use nac3parser::{
ast::{fold::Fold, FileName},
parser::parse_program,
};
use super::{helper::PrimDef, DefinitionId, *};
use crate::{
codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::DefinitionId,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{into_var_map, Type, Unifier},
typedef::{Type, Unifier},
},
};
use indoc::indoc;
use nac3parser::{ast::fold::Fold, parser::parse_program};
use parking_lot::Mutex;
use std::{collections::HashMap, sync::Arc};
use test_case::test_case;
use super::*;
struct ResolverInternal {
id_to_type: Mutex<HashMap<StrRef, Type>>,
@ -56,25 +52,20 @@ impl SymbolResolver for Resolver {
.id_to_type
.lock()
.get(&str)
.copied()
.ok_or_else(|| format!("cannot find symbol `{str}`"))
.cloned()
.ok_or_else(|| format!("cannot find symbol `{}`", str))
}
fn get_symbol_value<'ctx>(
fn get_symbol_value<'ctx, 'a>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
_: &mut CodeGenContext<'ctx, 'a>,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0
.id_to_def
.lock()
.get(&id)
.copied()
self.0.id_to_def.lock().get(&id).cloned()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
}
@ -120,14 +111,13 @@ impl SymbolResolver for Resolver {
"register"
)]
fn test_simple_register(source: Vec<&str>) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut composer: TopLevelComposer = Default::default();
for s in source {
let ast = parse_program(s, FileName::default()).unwrap();
let ast = parse_program(s, Default::default()).unwrap();
let ast = ast[0].clone();
composer.register_top_level(ast, None, "", false).unwrap();
composer.register_top_level(ast, None, "".into(), false).unwrap();
}
}
@ -141,15 +131,14 @@ fn test_simple_register(source: Vec<&str>) {
"register"
)]
fn test_simple_register_without_constructor(source: &str) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let ast = parse_program(source, FileName::default()).unwrap();
let mut composer: TopLevelComposer = Default::default();
let ast = parse_program(source, Default::default()).unwrap();
let ast = ast[0].clone();
composer.register_top_level(ast, None, "", true).unwrap();
composer.register_top_level(ast, None, "".into(), true).unwrap();
}
#[test_case(
&[
vec![
indoc! {"
def fun(a: int32) -> int32:
return a
@ -163,36 +152,35 @@ fn test_simple_register_without_constructor(source: &str) {
return 3
"},
],
&[
vec![
"fn[[a:0], 0]",
"fn[[a:2], 4]",
"fn[[b:1], 0]",
],
&[
vec![
"fun",
"foo",
"f"
];
"function compose"
)]
fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&str>) {
let mut composer: TopLevelComposer = Default::default();
let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Mutex::default(),
id_to_type: Mutex::default(),
class_names: Mutex::default(),
id_to_def: Default::default(),
id_to_type: Default::default(),
class_names: Default::default(),
});
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
let ast = parse_program(s, FileName::default()).unwrap();
let ast = parse_program(s, Default::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) =
composer.register_top_level(ast, Some(resolver.clone()), "", false).unwrap();
composer.register_top_level(ast, Some(resolver.clone()), "".into(), false).unwrap();
internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
@ -218,7 +206,7 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
}
#[test_case(
&[
vec![
indoc! {"
class A():
a: int32
@ -251,11 +239,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&[];
vec![];
"simple class compose"
)]
#[test_case(
&[
vec![
indoc! {"
class Generic_A(Generic[V], B):
a: int64
@ -273,11 +261,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&[];
vec![];
"generic class"
)]
#[test_case(
&[
vec![
indoc! {"
def foo(a: list[int32], b: tuple[T, float]) -> A[B, bool]:
pass
@ -302,11 +290,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&[];
vec![];
"list tuple generic"
)]
#[test_case(
&[
vec![
indoc! {"
class A(Generic[T, V]):
a: A[float, bool]
@ -327,11 +315,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&[];
vec![];
"self1"
)]
#[test_case(
&[
vec![
indoc! {"
class A(Generic[T]):
a: int32
@ -361,11 +349,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&[];
vec![];
"inheritance_override"
)]
#[test_case(
&[
vec![
indoc! {"
class A(Generic[T]):
def __init__(self):
@ -374,11 +362,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&["application of type vars to generic class is not currently supported (at unknown:4:24)"];
vec!["application of type vars to generic class is not currently supported (at unknown:4:24)"];
"err no type var in generic app"
)]
#[test_case(
&[
vec![
indoc! {"
class A(B):
def __init__(self):
@ -390,11 +378,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&["cyclic inheritance detected"];
vec!["cyclic inheritance detected"];
"cyclic1"
)]
#[test_case(
&[
vec![
indoc! {"
class A(B[bool, int64]):
def __init__(self):
@ -411,30 +399,30 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"},
],
&["cyclic inheritance detected"];
vec!["cyclic inheritance detected"];
"cyclic2"
)]
#[test_case(
&[
vec![
indoc! {"
class A:
pass
"}
],
&["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"];
vec!["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"];
"simple pass in class"
)]
#[test_case(
&[indoc! {"
vec![indoc! {"
class A:
def __init__():
pass
"}],
&["__init__ method must have a `self` parameter (at unknown:2:5)"];
vec!["__init__ method must have a `self` parameter (at unknown:2:5)"];
"err no self_1"
)]
#[test_case(
&[
vec![
indoc! {"
class A(B, Generic[T], C):
def __init__(self):
@ -452,11 +440,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
"}
],
&["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"];
vec!["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"];
"err multiple inheritance"
)]
#[test_case(
&[
vec![
indoc! {"
class A(Generic[T]):
a: int32
@ -477,11 +465,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&["method fun has same name as ancestors' method, but incompatible type"];
vec!["method fun has same name as ancestors' method, but incompatible type"];
"err_incompatible_inheritance_method"
)]
#[test_case(
&[
vec![
indoc! {"
class A(Generic[T]):
a: int32
@ -503,11 +491,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&["field `a` has already declared in the ancestor classes"];
vec!["field `a` has already declared in the ancestor classes"];
"err_incompatible_inheritance_field"
)]
#[test_case(
&[
vec![
indoc! {"
class A:
def __init__(self):
@ -520,13 +508,12 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass
"}
],
&["duplicate definition of class `A` (at unknown:1:1)"];
vec!["duplicate definition of class `A` (at unknown:1:1)"];
"class same name"
)]
fn test_analyze(source: &[&str], res: &[&str]) {
fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
let print = false;
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut composer: TopLevelComposer = Default::default();
let internal_resolver = make_internal_resolver_with_tvar(
vec![
@ -541,15 +528,15 @@ fn test_analyze(source: &[&str], res: &[&str]) {
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
let ast = parse_program(s, FileName::default()).unwrap();
let ast = parse_program(s, Default::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
match composer.register_top_level(ast, Some(resolver.clone()), "".into(), false) {
Ok(x) => x,
Err(msg) => {
if print {
println!("{msg}");
println!("{}", msg);
} else {
assert_eq!(res[0], msg);
}
@ -595,7 +582,7 @@ fn test_analyze(source: &[&str], res: &[&str]) {
return fib(n - 1)
"}
],
&[];
vec![];
"simple function"
)]
#[test_case(
@ -628,7 +615,7 @@ fn test_analyze(source: &[&str], res: &[&str]) {
return a.fun() + 2
"}
],
&[];
vec![];
"simple class body"
)]
#[test_case(
@ -653,7 +640,7 @@ fn test_analyze(source: &[&str], res: &[&str]) {
return [a, b]
"}
],
&[];
vec![];
"type var fun"
)]
#[test_case(
@ -674,7 +661,7 @@ fn test_analyze(source: &[&str], res: &[&str]) {
return ret if self.b else self.fun(self.a)
"}
],
&[];
vec![];
"type var class"
)]
#[test_case(
@ -698,13 +685,12 @@ fn test_analyze(source: &[&str], res: &[&str]) {
self.b = True
"}
],
&[];
vec![];
"no_init_inst_check"
)]
fn test_inference(source: Vec<&str>, res: &[&str]) {
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
let print = true;
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut composer: TopLevelComposer = Default::default();
let internal_resolver = make_internal_resolver_with_tvar(
vec![
@ -726,15 +712,15 @@ fn test_inference(source: Vec<&str>, res: &[&str]) {
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source {
let ast = parse_program(s, FileName::default()).unwrap();
let ast = parse_program(s, Default::default()).unwrap();
let ast = ast[0].clone();
let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
match composer.register_top_level(ast, Some(resolver.clone()), "".into(), false) {
Ok(x) => x,
Err(msg) => {
if print {
println!("{msg}");
println!("{}", msg);
} else {
assert_eq!(res[0], msg);
}
@ -757,7 +743,9 @@ fn test_inference(source: Vec<&str>, res: &[&str]) {
} else {
// skip 5 to skip primitives
let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier };
for (def, _) in composer.definition_ast_list.iter().skip(composer.builtin_num) {
for (_i, (def, _)) in
composer.definition_ast_list.iter().skip(composer.builtin_num).enumerate()
{
let def = &*def.read();
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
@ -766,7 +754,7 @@ fn test_inference(source: Vec<&str>, res: &[&str]) {
name,
instance_to_stmt.len()
);
for inst in instance_to_stmt {
for inst in instance_to_stmt.iter() {
let ast = &inst.1.body;
for b in ast.iter() {
println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap());
@ -784,29 +772,22 @@ fn make_internal_resolver_with_tvar(
unifier: &mut Unifier,
print: bool,
) -> Arc<ResolverInternal> {
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let res: Arc<ResolverInternal> = ResolverInternal {
id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])),
id_to_def: Default::default(),
id_to_type: tvars
.into_iter()
.map(|(name, range)| {
(name, {
let tvar = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
if print {
println!("{}: {:?}, typevar{}", name, tvar.ty, tvar.id);
println!("{}: {:?}, typevar{}", name, ty, id);
}
tvar.ty
ty
})
})
.collect::<HashMap<_, _>>()
.into(),
class_names: Mutex::new(HashMap::from([("list".into(), list)])),
class_names: Default::default(),
}
.into();
if print {
@ -826,8 +807,8 @@ impl<'a> Fold<Option<Type>> for TypeToStringFolder<'a> {
Ok(if let Some(ty) = user {
self.unifier.internal_stringify(
ty,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut |id| format!("class{}", id.to_string()),
&mut |id| format!("typevar{}", id.to_string()),
&mut None,
)
} else {

View File

@ -1,12 +1,5 @@
use strum::IntoEnumIterator;
use nac3parser::ast::Constant;
use super::{
helper::{PrimDef, PrimDefDetails},
*,
};
use crate::{symbol_resolver::SymbolValue, typecheck::typedef::VarMap};
use crate::symbol_resolver::SymbolValue;
use super::*;
#[derive(Clone, Debug)]
pub enum TypeAnnotation {
@ -20,8 +13,17 @@ pub enum TypeAnnotation {
// can only be CustomClassKind
Virtual(Box<TypeAnnotation>),
TypeVar(Type),
/// A `Literal` allowing a subset of literals.
Literal(Vec<Constant>),
/// A constant used in the context of a const-generic variable.
Constant {
/// The non-type variable associated with this constant.
///
/// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the
/// const generic variable of which this constant is associated with.
ty: Type,
/// The constant value of this constant.
value: SymbolValue
},
List(Box<TypeAnnotation>),
Tuple(Vec<TypeAnnotation>),
}
@ -32,7 +34,9 @@ impl TypeAnnotation {
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
CustomClass { id, params } => {
let class_name = if let Some(ref top) = unifier.top_level {
if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() {
if let TopLevelDef::Class { name, .. } =
&*top.definitions.read()[id.0].read()
{
(*name).into()
} else {
unreachable!()
@ -40,26 +44,25 @@ impl TypeAnnotation {
} else {
format!("class_def_{}", id.0)
};
format!("{}{}", class_name, {
let param_list =
params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
format!(
"{}{}",
class_name,
{
let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
if param_list.is_empty() {
String::new()
} else {
format!("[{param_list}]")
}
})
}
Literal(values) => {
format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", "))
}
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
Tuple(types) => {
format!(
"tuple[{}]",
types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")
)
}
Constant { value, .. } => format!("Const({value})"),
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => {
format!("tuple[{}]", types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", "))
}
}
}
}
@ -70,18 +73,19 @@ impl TypeAnnotation {
/// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [`None`] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
pub fn parse_ast_to_type_annotation_kinds<T>(
resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>, S>,
locked: HashMap<DefinitionId, Vec<Type>>,
type_var: Option<Type>,
) -> Result<TypeAnnotation, HashSet<String>> {
let name_handle = |id: &StrRef,
unifier: &mut Unifier,
locked: HashMap<DefinitionId, Vec<Type>, S>| {
locked: HashMap<DefinitionId, Vec<Type>>| {
if id == &"int32".into() {
Ok(TypeAnnotation::Primitive(primitives.int32))
} else if id == &"int64".into() {
@ -97,7 +101,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
} else if id == &"str".into() {
Ok(TypeAnnotation::Primitive(primitives.str))
} else if id == &"Exception".into() {
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
Ok(TypeAnnotation::CustomClass { id: DefinitionId(7), params: Vec::default() })
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read();
@ -105,10 +109,12 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone()
} else {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"function cannot be used as a type (at {})",
expr.location
)]));
),
]))
}
} else {
locked.get(&obj_id).unwrap().clone()
@ -116,29 +122,29 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
};
// check param number here
if !type_vars.is_empty() {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"expect {} type variable parameter but got 0 (at {})",
type_vars.len(),
expr.location,
)]));
),
]))
}
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).ty;
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).0;
unifier.unify(var, ty).unwrap();
Ok(TypeAnnotation::TypeVar(ty))
} else {
Err(HashSet::from([format!(
"`{}` is not a valid type annotation (at {})",
id, expr.location
)]))
Err(HashSet::from([
format!("`{}` is not a valid type annotation (at {})", id, expr.location),
]))
}
} else {
Err(HashSet::from([format!(
"`{}` is not a valid type annotation (at {})",
id, expr.location
)]))
Err(HashSet::from([
format!("`{}` is not a valid type annotation (at {})", id, expr.location),
]))
}
};
@ -146,12 +152,12 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|id: &StrRef,
slice: &ast::Expr<T>,
unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>, S>| {
if ["virtual".into(), "Generic".into(), "tuple".into(), "Option".into()].contains(id) {
return Err(HashSet::from([format!(
"keywords cannot be class name (at {})",
expr.location
)]));
mut locked: HashMap<DefinitionId, Vec<Type>>| {
if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()].contains(id)
{
return Err(HashSet::from([
format!("keywords cannot be class name (at {})", expr.location),
]))
}
let obj_id = resolver.get_identifier_def(*id)?;
let type_vars = {
@ -174,16 +180,19 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
vec![slice]
};
if type_vars.len() != params_ast.len() {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"expect {} type parameters but got {} (at {})",
type_vars.len(),
params_ast.len(),
params_ast[0].location,
)]));
),
]))
}
let result = params_ast
.iter()
.map(|x| {
.enumerate()
.map(|(idx, x)| {
parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
@ -194,6 +203,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
locked.insert(obj_id, type_vars.clone());
locked.clone()
},
Some(type_vars[idx]),
)
})
.collect::<Result<Vec<_>, _>>()?;
@ -208,7 +218,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
"application of type vars to generic class is not currently supported (at {})",
params_ast[0].location
),
]));
]))
}
};
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
@ -229,6 +239,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
primitives,
slice.as_ref(),
locked,
None,
)?;
if !matches!(def, TypeAnnotation::CustomClass { .. }) {
unreachable!("must be concretized custom class kind in the virtual")
@ -236,6 +247,24 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
Ok(TypeAnnotation::Virtual(def.into()))
}
// list
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into())
} =>
{
let def_ann = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
None,
)?;
Ok(TypeAnnotation::List(def_ann.into()))
}
// option
ast::ExprKind::Subscript { value, slice, .. }
if {
@ -249,6 +278,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
primitives,
slice.as_ref(),
locked,
None,
)?;
let id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
@ -282,76 +312,54 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
primitives,
e,
locked.clone(),
None,
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(TypeAnnotation::Tuple(type_annotations))
}
// Literal
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
} =>
{
let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice()
} else {
std::slice::from_ref(slice.as_ref())
}
};
let type_annotations = tup_elts
.iter()
.map(|e| match &e.node {
ast::ExprKind::Constant { value, .. } => {
Ok(TypeAnnotation::Literal(vec![value.clone()]))
}
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
),
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flat_map(|type_ann| match type_ann {
TypeAnnotation::Literal(values) => values,
_ => unreachable!(),
})
.collect_vec();
if type_annotations.len() == 1 {
Ok(TypeAnnotation::Literal(type_annotations))
} else {
Err(HashSet::from([format!(
"multiple literal bounds are currently unsupported (at {})",
value.location
)]))
}
}
// custom class
ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node {
class_name_handle(id, slice, unifier, locked)
} else {
Err(HashSet::from([format!(
"unsupported expression type for class name (at {})",
value.location
)]))
Err(HashSet::from([
format!("unsupported expression type for class name (at {})", value.location)
]))
}
}
ast::ExprKind::Constant { value, .. } => Ok(TypeAnnotation::Literal(vec![value.clone()])),
ast::ExprKind::Constant { value, .. } => {
let type_var = type_var.expect("Expect type variable to be present");
_ => Err(HashSet::from([format!(
"unsupported expression for type annotation (at {})",
let ntv_ty_enum = unifier.get_ty_immutable(type_var);
let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else {
unreachable!()
};
let underlying_ty = underlying_ty[0];
let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier)
.map_err(|err| HashSet::from([err]))?;
if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) {
return Err(HashSet::from([
format!(
"expression {value} is not allowed for constant type annotation (at {})",
expr.location
)])),
),
]))
}
Ok(TypeAnnotation::Constant {
ty: type_var,
value,
})
}
_ => Err(HashSet::from([
format!("unsupported expression for type annotation (at {})", expr.location),
])),
}
}
@ -361,9 +369,8 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>,
subst_list: &mut Option<Vec<Type>>
) -> Result<Type, HashSet<String>> {
match ann {
TypeAnnotation::CustomClass { id: obj_id, params } => {
@ -374,11 +381,13 @@ pub fn get_type_from_type_annotation_kinds(
};
if type_vars.len() != params.len() {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
params.len()
)]));
),
]))
}
let param_ty = params
@ -387,54 +396,19 @@ pub fn get_type_from_type_annotation_kinds(
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
subst_list
)
})
.collect::<Result<Vec<_>, _>>()?;
let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) {
// Primitive TopLevelDefs do not contain all fields that are present in their Type
// counterparts, so directly perform subst on the Type instead.
let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else {
unreachable!()
};
let base_ty = get_ty_fn(primitives);
let params =
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) {
params.clone()
} else {
unreachable!()
};
unifier
.subst(
get_ty_fn(primitives),
&params
.iter()
.zip(param_ty)
.map(|(obj_tv, param)| (*obj_tv.0, param))
.collect(),
)
.unwrap_or(base_ty)
} else {
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
let mut result: HashMap<u32, Type> = HashMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
@ -443,13 +417,14 @@ pub fn get_type_from_type_annotation_kinds(
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
unifier.unify(temp.0, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
@ -458,30 +433,34 @@ pub fn get_type_from_type_annotation_kinds(
&mut None
),
*id
)]));
)
]))
}
}
TypeEnum::TVar {
id, range, name, loc, is_const_generic: true, ..
} => {
TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp =
unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
let temp = unifier.get_fresh_const_generic_var(
ty,
*name,
*loc,
);
unifier.unify(temp.0, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
),
]))
}
}
@ -490,7 +469,6 @@ pub fn get_type_from_type_annotation_kinds(
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
@ -509,53 +487,50 @@ pub fn get_type_from_type_annotation_kinds(
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
}
ty
};
Ok(ty)
}
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Literal(values) => {
let values = values
.iter()
.map(SymbolValue::from_constant_inferred)
.collect::<Result<Vec<_>, _>>()
.map_err(|err| HashSet::from([err]))?;
TypeAnnotation::Constant { ty, value, .. } => {
let ty_enum = unifier.get_ty(*ty);
let TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } = &*ty_enum else {
unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name());
};
let var = unifier.get_fresh_literal(values, None);
let ty = ntv_underlying_ty[0];
let var = unifier.get_fresh_constant(value.clone(), ty, *loc);
Ok(var)
}
TypeAnnotation::Virtual(ty) => {
let ty = get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
ty.as_ref(),
subst_list,
subst_list
)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
}
TypeAnnotation::List(ty) => {
let ty = get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
ty.as_ref(),
subst_list
)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
}
TypeAnnotation::Tuple(tys) => {
let tys = tys
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys, is_vararg_ctx: false }))
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))
}
}
}
@ -588,7 +563,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
let mut result: Vec<TypeAnnotation> = Vec::new();
match ann {
TypeAnnotation::TypeVar(..) => result.push(ann.clone()),
TypeAnnotation::Virtual(ann) => {
TypeAnnotation::Virtual(ann) | TypeAnnotation::List(ann) => {
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()));
}
TypeAnnotation::CustomClass { params, .. } => {
@ -601,7 +576,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
result.extend(get_type_var_contained_in_type_annotation(a));
}
}
TypeAnnotation::Primitive(..) | TypeAnnotation::Literal { .. } => {}
TypeAnnotation::Primitive(..) | TypeAnnotation::Constant { .. } => {}
}
result
}
@ -622,14 +597,14 @@ pub fn check_overload_type_annotation_compatible(
let (
TypeEnum::TVar { id: a, fields: None, .. },
TypeEnum::TVar { id: b, fields: None, .. },
) = (a, b)
else {
) = (a, b) else {
unreachable!("must be type var")
};
a == b
}
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b)) => {
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b))
| (TypeAnnotation::List(a), TypeAnnotation::List(b)) => {
check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier)
}

View File

@ -1,24 +1,16 @@
use std::{
collections::{HashMap, HashSet},
iter::once,
};
use crate::typecheck::typedef::TypeEnum;
use nac3parser::ast::{
self, Constant, Expr, ExprKind,
Operator::{LShift, RShift},
Stmt, StmtKind, StrRef,
};
use super::{
type_inferencer::{DeclarationSource, IdentifierInfo, Inferencer},
typedef::{Type, TypeEnum},
};
use crate::toplevel::helper::PrimDef;
use super::type_inferencer::Inferencer;
use super::typedef::Type;
use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef};
use std::{collections::HashSet, iter::once};
impl<'a> Inferencer<'a> {
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) {
Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)]))
Err(HashSet::from([
format!("Error at {}: cannot have value none", expr.location),
]))
} else {
Ok(())
}
@ -27,61 +19,45 @@ impl<'a> Inferencer<'a> {
fn check_pattern(
&mut self,
pattern: &Expr<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> {
match &pattern.node {
ExprKind::Name { id, .. } if id == &"none".into() => {
Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
}
ExprKind::Name { id, .. } if id == &"none".into() => Err(HashSet::from([
format!("cannot assign to a `none` (at {})", pattern.location),
])),
ExprKind::Name { id, .. } => {
// If `id` refers to a declared symbol, reject this assignment if it is used in the
// context of an (implicit) global variable
if let Some(id_info) = defined_identifiers.get(id) {
if matches!(
id_info.source,
DeclarationSource::Global { is_explicit: Some(false) }
) {
return Err(HashSet::from([format!(
"cannot access local variable '{id}' before it is declared (at {})",
pattern.location
)]));
}
}
if !defined_identifiers.contains_key(id) {
defined_identifiers.insert(*id, IdentifierInfo::default());
if !defined_identifiers.contains(id) {
defined_identifiers.insert(*id);
}
self.should_have_value(pattern)?;
Ok(())
}
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
ExprKind::Tuple { elts, .. } => {
for elt in elts {
self.check_pattern(elt, defined_identifiers)?;
self.should_have_value(elt)?;
}
Ok(())
}
ExprKind::Starred { value, .. } => {
self.check_pattern(value, defined_identifiers)?;
self.should_have_value(value)?;
Ok(())
}
ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?;
self.check_expr(slice, defined_identifiers)?;
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"Error at {}: cannot assign to tuple element",
value.location
)]));
),
]))
}
Ok(())
}
ExprKind::Constant { .. } => Err(HashSet::from([format!(
"cannot assign to a constant (at {})",
pattern.location
)])),
ExprKind::Constant { .. } => {
Err(HashSet::from([
format!("cannot assign to a constant (at {})", pattern.location),
]))
}
_ => self.check_expr(pattern, defined_identifiers),
}
}
@ -89,19 +65,18 @@ impl<'a> Inferencer<'a> {
fn check_expr(
&mut self,
expr: &Expr<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom {
if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. })
&& !ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::List.id())
&& !self.unifier.is_concrete(*ty, &self.function_data.bound_variables)
{
return Err(HashSet::from([format!(
if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
return Err(HashSet::from([
format!(
"expected concrete type at {} but got {}",
expr.location,
self.unifier.get_ty(*ty).get_type_name()
)]));
)
]))
}
}
match &expr.node {
@ -110,7 +85,7 @@ impl<'a> Inferencer<'a> {
return Ok(());
}
self.should_have_value(expr)?;
if !defined_identifiers.contains_key(id) {
if !defined_identifiers.contains(id) {
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
@ -118,28 +93,15 @@ impl<'a> Inferencer<'a> {
*id,
) {
Ok(_) => {
let is_global = self.is_id_global(*id);
defined_identifiers.insert(
*id,
IdentifierInfo {
source: match is_global {
Some(true) => {
DeclarationSource::Global { is_explicit: Some(false) }
}
Some(false) => {
DeclarationSource::Global { is_explicit: None }
}
None => DeclarationSource::Local,
},
},
);
self.defined_identifiers.insert(*id);
}
Err(e) => {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"type error at identifier `{}` ({}) at {}",
id, e, expr.location
)]))
)
]))
}
}
}
@ -165,13 +127,17 @@ impl<'a> Inferencer<'a> {
// Check whether a bitwise shift has a negative RHS constant value
if *op == LShift || *op == RShift {
if let ExprKind::Constant { value, .. } = &right.node {
let Constant::Int(rhs_val) = value else { unreachable!() };
let Constant::Int(rhs_val) = value else {
unreachable!()
};
if *rhs_val < 0 {
return Err(HashSet::from([format!(
return Err(HashSet::from([
format!(
"shift count is negative at {}",
right.location
)]));
),
]))
}
}
}
@ -206,7 +172,9 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.clone();
for arg in &args.args {
// TODO: should we check the types here?
defined_identifiers.entry(arg.node.arg).or_default();
if !defined_identifiers.contains(&arg.node.arg) {
defined_identifiers.insert(arg.node.arg);
}
}
self.check_expr(body, &mut defined_identifiers)?;
}
@ -240,36 +208,11 @@ impl<'a> Inferencer<'a> {
Ok(())
}
/// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
///
/// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
/// is freed when the function returns.
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
if cfg!(feature = "no-escape-analysis") {
true
} else {
match &*self.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { .. } => [
self.primitives.int32,
self.primitives.int64,
self.primitives.uint32,
self.primitives.uint64,
self.primitives.float,
self.primitives.bool,
]
.iter()
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty, .. } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false,
}
}
}
// check statements for proper identifier def-use and return on all paths
fn check_stmt(
&mut self,
stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> {
match &stmt.node {
StmtKind::For { target, iter, body, orelse, .. } => {
@ -295,11 +238,9 @@ impl<'a> Inferencer<'a> {
let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in body_identifiers.keys() {
if !defined_identifiers.contains_key(ident)
&& orelse_identifiers.contains_key(ident)
{
defined_identifiers.insert(*ident, IdentifierInfo::default());
for ident in &body_identifiers {
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
defined_identifiers.insert(*ident);
}
}
Ok(body_returned && orelse_returned)
@ -330,7 +271,7 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.clone();
let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node;
if let Some(name) = name {
defined_identifiers.insert(*name, IdentifierInfo::default());
defined_identifiers.insert(*name);
}
self.check_block(body, &mut defined_identifiers)?;
}
@ -361,30 +302,6 @@ impl<'a> Inferencer<'a> {
if let Some(value) = value {
self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?;
// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
// is freed when the function returns.
if let Some(ret_ty) = value.custom {
// Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually
// inferred and just generates an unconditional assertion
if matches!(
value.node,
ExprKind::Constant { value: Constant::Ellipsis, .. }
) {
return Ok(true);
}
if !self.check_return_value_ty(ret_ty) {
return Err(HashSet::from([
format!(
"return value of type {} must be a primitive or a tuple of primitives at {}",
self.unifier.stringify(ret_ty),
value.location,
),
]));
}
}
}
Ok(true)
}
@ -394,44 +311,6 @@ impl<'a> Inferencer<'a> {
}
Ok(true)
}
StmtKind::Global { names, .. } => {
for id in names {
if let Some(id_info) = defined_identifiers.get(id) {
if id_info.source == DeclarationSource::Local {
return Err(HashSet::from([format!(
"name '{id}' is referenced prior to global declaration at {}",
stmt.location,
)]));
}
continue;
}
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
self.primitives,
*id,
) {
Ok(_) => {
defined_identifiers.insert(
*id,
IdentifierInfo {
source: DeclarationSource::Global { is_explicit: Some(true) },
},
);
}
Err(e) => {
return Err(HashSet::from([format!(
"type error at identifier `{}` ({}) at {}",
id, e, stmt.location
)]))
}
}
}
Ok(false)
}
// break, raise, etc.
_ => Ok(false),
}
@ -440,12 +319,12 @@ impl<'a> Inferencer<'a> {
pub fn check_block(
&mut self,
block: &[Stmt<Option<Type>>],
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> {
let mut ret = false;
for stmt in block {
if ret {
eprintln!("warning: dead code at {}\n", stmt.location);
eprintln!("warning: dead code at {:?}\n", stmt.location);
}
if self.check_stmt(stmt, defined_identifiers)? {
ret = true;

View File

@ -1,151 +1,70 @@
use std::{cmp::max, collections::HashMap, rc::Rc};
use itertools::{iproduct, Itertools};
use strum::IntoEnumIterator;
use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop};
use super::{
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
};
use crate::{
symbol_resolver::SymbolValue,
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
};
use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::collections::HashMap;
use std::rc::Rc;
/// The variant of a binary operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinopVariant {
/// The normal variant.
/// For addition, it would be `+`.
Normal,
/// The "Augmented Assigning Operator" variant.
/// For addition, it would be `+=`.
AugAssign,
}
/// A binary operator with its variant.
#[derive(Debug, Clone, Copy)]
pub struct Binop {
/// The base [`Operator`] of this binary operator.
pub base: Operator,
/// The variant of this binary operator.
pub variant: BinopVariant,
}
impl Binop {
/// Make a [`Binop`] of the normal variant from an [`Operator`].
#[must_use]
pub fn normal(base: Operator) -> Self {
Binop { base, variant: BinopVariant::Normal }
}
/// Make a [`Binop`] of the aug assign variant from an [`Operator`].
#[must_use]
pub fn aug_assign(base: Operator) -> Self {
Binop { base, variant: BinopVariant::AugAssign }
}
}
/// Details about an operator (unary, binary, etc...) in Python
#[derive(Debug, Clone, Copy)]
pub struct OpInfo {
/// The method name of the binary operator.
/// For addition, this would be `__add__`, and `__iadd__` if
/// it is the augmented assigning variant.
pub method_name: &'static str,
/// The symbol of the binary operator.
/// For addition, this would be `+`, and `+=` if
/// it is the augmented assigning variant.
pub symbol: &'static str,
}
/// Helper macro to conveniently build an [`OpInfo`].
///
/// Example usage: `make_info("add", "+")` generates `OpInfo { name: "__add__", symbol: "+" }`
macro_rules! make_op_info {
($name:expr, $symbol:expr) => {
OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol }
};
}
pub trait HasOpInfo {
fn op_info(&self) -> OpInfo;
}
fn try_get_cmpop_info(op: Cmpop) -> Option<OpInfo> {
pub fn binop_name(op: &Operator) -> &'static str {
match op {
Cmpop::Lt => Some(make_op_info!("lt", "<")),
Cmpop::LtE => Some(make_op_info!("le", "<=")),
Cmpop::Gt => Some(make_op_info!("gt", ">")),
Cmpop::GtE => Some(make_op_info!("ge", ">=")),
Cmpop::Eq => Some(make_op_info!("eq", "==")),
Cmpop::NotEq => Some(make_op_info!("ne", "!=")),
_ => None,
Operator::Add => "__add__",
Operator::Sub => "__sub__",
Operator::Div => "__truediv__",
Operator::Mod => "__mod__",
Operator::Mult => "__mul__",
Operator::Pow => "__pow__",
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
}
}
impl OpInfo {
#[must_use]
pub fn supports_cmpop(op: Cmpop) -> bool {
try_get_cmpop_info(op).is_some()
pub fn binop_assign_name(op: &Operator) -> &'static str {
match op {
Operator::Add => "__iadd__",
Operator::Sub => "__isub__",
Operator::Div => "__itruediv__",
Operator::Mod => "__imod__",
Operator::Mult => "__imul__",
Operator::Pow => "__ipow__",
Operator::BitOr => "__ior__",
Operator::BitXor => "__ixor__",
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
}
impl HasOpInfo for Cmpop {
fn op_info(&self) -> OpInfo {
try_get_cmpop_info(*self).expect("{self:?} is not supported")
#[must_use]
pub fn unaryop_name(op: &Unaryop) -> &'static str {
match op {
Unaryop::UAdd => "__pos__",
Unaryop::USub => "__neg__",
Unaryop::Not => "__not__",
Unaryop::Invert => "__inv__",
}
}
impl HasOpInfo for Binop {
fn op_info(&self) -> OpInfo {
// Helper macro to generate both the normal variant [`OpInfo`] and the
// augmented assigning variant [`OpInfo`] for a binary operator conveniently.
macro_rules! info {
($name:literal, $symbol:literal) => {
(
make_op_info!($name, $symbol),
make_op_info!(concat!("i", $name), concat!($symbol, "=")),
)
};
}
let (normal_variant, aug_assign_variant) = match self.base {
Operator::Add => info!("add", "+"),
Operator::Sub => info!("sub", "-"),
Operator::Div => info!("truediv", "/"),
Operator::Mod => info!("mod", "%"),
Operator::Mult => info!("mul", "*"),
Operator::Pow => info!("pow", "**"),
Operator::BitOr => info!("or", "|"),
Operator::BitXor => info!("xor", "^"),
Operator::BitAnd => info!("and", "&"),
Operator::LShift => info!("lshift", "<<"),
Operator::RShift => info!("rshift", ">>"),
Operator::FloorDiv => info!("floordiv", "//"),
Operator::MatMult => info!("matmul", "@"),
};
match self.variant {
BinopVariant::Normal => normal_variant,
BinopVariant::AugAssign => aug_assign_variant,
}
}
}
impl HasOpInfo for Unaryop {
fn op_info(&self) -> OpInfo {
match self {
Unaryop::UAdd => make_op_info!("pos", "+"),
Unaryop::USub => make_op_info!("neg", "-"),
Unaryop::Not => make_op_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`.
Unaryop::Invert => make_op_info!("inv", "~"),
}
#[must_use]
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> {
match op {
Cmpop::Lt => Some("__lt__"),
Cmpop::LtE => Some("__le__"),
Cmpop::Gt => Some("__gt__"),
Cmpop::GtE => Some("__ge__"),
Cmpop::Eq => Some("__eq__"),
Cmpop::NotEq => Some("__ne__"),
_ => None,
}
}
@ -171,28 +90,40 @@ pub fn impl_binop(
_store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
ret_ty: Type,
ops: &[Operator],
) {
with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(tvar.ty, Some(tvar.id))
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(ty, Some(var_id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
vec![(var_id, other_ty)].into_iter().collect::<HashMap<_, _>>()
} else {
VarMap::new()
HashMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
let op = Binop { base: *base_op, variant };
fields.insert(op.op_info().method_name.into(), {
for op in ops {
fields.insert(binop_name(op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
})),
false,
)
});
fields.insert(binop_assign_name(op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
@ -201,7 +132,6 @@ pub fn impl_binop(
ty: other_ty,
default_value: None,
name: "other".into(),
is_vararg: false,
}],
})),
false,
@ -211,17 +141,15 @@ pub fn impl_binop(
});
}
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryop]) {
with_fields(unifier, ty, |unifier, fields| {
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops {
fields.insert(
op.op_info().method_name.into(),
unaryop_name(op).into(),
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: VarMap::new(),
vars: HashMap::new(),
args: vec![],
})),
false,
@ -233,40 +161,23 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
pub fn impl_cmpop(
unifier: &mut Unifier,
_store: &PrimitiveStore,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
other_ty: Type,
ops: &[Cmpop],
ret_ty: Option<Type>,
) {
with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(tvar.ty, Some(tvar.id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else {
VarMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops {
fields.insert(
op.op_info().method_name.into(),
comparison_name(op).unwrap().into(),
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
ret: store.bool,
vars: HashMap::new(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
is_vararg: false,
}],
})),
false,
@ -282,7 +193,7 @@ pub fn impl_basic_arithmetic(
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
ret_ty: Type,
) {
impl_binop(
unifier,
@ -300,7 +211,7 @@ pub fn impl_pow(
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]);
}
@ -312,32 +223,19 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
store,
ty,
&[ty],
Some(ty),
ty,
&[Operator::BitAnd, Operator::BitOr, Operator::BitXor],
);
}
/// `LShift`, `RShift`
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(
unifier,
store,
ty,
&[store.int32, store.uint32],
Some(ty),
&[Operator::LShift, Operator::RShift],
);
impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[Operator::LShift, Operator::RShift]);
}
/// `Div`
pub fn impl_div(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]);
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]);
}
/// `FloorDiv`
@ -346,7 +244,7 @@ pub fn impl_floordiv(
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]);
}
@ -357,340 +255,40 @@ pub fn impl_mod(
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
ret_ty: Type,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
}
/// [`Operator::MatMult`]
pub fn impl_matmul(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]);
}
/// `UAdd`, `USub`
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, ty, &[Unaryop::UAdd, Unaryop::USub]);
}
/// `Invert`
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]);
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, ty, &[Unaryop::Invert]);
}
/// `Not`
pub fn impl_not(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, store.bool, &[Unaryop::Not]);
}
/// `Lt`, `LtE`, `Gt`, `GtE`
pub fn impl_comparison(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
impl_cmpop(
unifier,
store,
ty,
other_ty,
&[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE],
ret_ty,
);
}
/// `Eq`, `NotEq`
pub fn impl_eq(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
}
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
pub fn typeof_ndarray_broadcast(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
left: Type,
right: Type,
) -> Result<Type, String> {
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
assert!(is_left_ndarray || is_right_ndarray);
if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let res_ndims = left_ty_ndims
.into_iter()
.cartesian_product(right_ty_ndims)
.map(|(left, right)| {
let left_val = u64::try_from(left).unwrap();
let right_val = u64::try_from(right).unwrap();
max(left_val, right_val)
})
.unique()
.map(SymbolValue::U64)
.collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
} else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty)
} else {
let (expected_ty, actual_ty) = if is_left_ndarray {
(ndarray_ty_dtype, scalar_ty)
} else {
(scalar_ty, ndarray_ty_dtype)
};
Err(format!(
"Expected right-hand side operand to be {}, got {}",
unifier.stringify(expected_ty),
unifier.stringify(actual_ty),
))
}
}
}
/// Returns the return type given a binary operator and its primitive operands.
pub fn typeof_binop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: Operator,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let op = Binop { base: op, variant: BinopVariant::Normal };
let is_left_list = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
let is_right_list = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(match op.base {
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
if is_left_list || is_right_list {
if ![Operator::Add, Operator::Mult].contains(&op.base) {
return Err(format!(
"Binary operator {} not supported for list",
op.op_info().symbol
));
}
if is_left_list {
lhs
} else {
rhs
}
} else if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None);
}
}
Operator::MatMult => {
// NOTE: NumPy matmul's LHS and RHS must both be ndarrays. Scalars are not allowed.
match (&*unifier.get_ty(lhs), &*unifier.get_ty(rhs)) {
(
TypeEnum::TObj { obj_id: lhs_obj_id, .. },
TypeEnum::TObj { obj_id: rhs_obj_id, .. },
) if *lhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap()
&& *rhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() =>
{
// LHS and RHS have valid types
}
_ => {
let lhs_str = unifier.stringify(lhs);
let rhs_str = unifier.stringify(rhs);
return Err(format!("ndarray.__matmul__ only accepts ndarray operands, but left operand has type {lhs_str}, and right operand has type {rhs_str}"));
}
}
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
};
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
};
match (lhs_ndims, rhs_ndims) {
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
(lhs, rhs) if lhs == 0 || rhs == 0 => {
return Err(format!(
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
u8::from(rhs == 0)
))
}
(lhs, rhs) => {
return Err(format!(
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
))
}
}
}
Operator::Div => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
primitives.float
} else {
return Ok(None);
}
}
Operator::Pow => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if [
primitives.int32,
primitives.int64,
primitives.uint32,
primitives.uint64,
primitives.float,
]
.into_iter()
.any(|ty| unifier.unioned(lhs, ty))
{
lhs
} else {
return Ok(None);
}
}
Operator::LShift | Operator::RShift => lhs,
Operator::BitOr | Operator::BitXor | Operator::BitAnd => {
if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None);
}
}
}))
}
pub fn typeof_unaryop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: Unaryop,
operand: Type,
) -> Result<Option<Type>, String> {
let operand_obj_id = operand.obj_id(unifier);
if op == Unaryop::Not
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap())
{
return Err(
"The truth value of an array with more than one element is ambiguous".to_string()
);
}
Ok(match op {
Unaryop::Not => match operand_obj_id {
Some(v) if v == PrimDef::NDArray.id() => Some(operand),
Some(_) => Some(primitives.bool),
_ => None,
},
Unaryop::Invert => {
if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand)
} else {
None
}
}
Unaryop::UAdd | Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
} else {
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
});
}
Some(operand)
} else if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand)
} else {
None
}
}
})
}
/// Returns the return type given a comparison operator and its primitive operands.
pub fn typeof_cmpop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
_op: Cmpop,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
} else if unifier.unioned(lhs, rhs) {
primitives.bool
} else {
return Ok(None);
}))
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]);
}
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
@ -701,81 +299,39 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
bool: bool_t,
uint32: uint32_t,
uint64: uint64_t,
str: str_t,
list: list_t,
ndarray: ndarray_t,
..
} = *store;
let size_t = store.usize();
/* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] {
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
impl_basic_arithmetic(unifier, store, t, &[t], t);
impl_pow(unifier, store, t, &[t], t);
impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
impl_div(unifier, store, t, &[t]);
impl_floordiv(unifier, store, t, &[t], t);
impl_mod(unifier, store, t, &[t], t);
impl_invert(unifier, store, t);
impl_not(unifier, store, t);
impl_comparison(unifier, store, t, t);
impl_eq(unifier, store, t);
}
for t in [int32_t, int64_t] {
impl_sign(unifier, store, t, Some(t));
impl_sign(unifier, store, t);
}
/* float ======== */
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
impl_div(unifier, store, float_t, &[float_t]);
impl_floordiv(unifier, store, float_t, &[float_t], float_t);
impl_mod(unifier, store, float_t, &[float_t], float_t);
impl_sign(unifier, store, float_t);
impl_not(unifier, store, float_t);
impl_comparison(unifier, store, float_t, float_t);
impl_eq(unifier, store, float_t);
/* bool ======== */
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
impl_invert(unifier, store, bool_t, Some(int32_t));
impl_not(unifier, store, bool_t, Some(bool_t));
impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* str ========= */
impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
/* list ======== */
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
impl_cmpop(unifier, store, list_t, &[list_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
/* ndarray ===== */
let ndarray_usized_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t =
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
impl_basic_arithmetic(
unifier,
store,
ndarray_t,
&[ndarray_unsized_t, ndarray_unsized_dtype_t],
None,
);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_not(unifier, store, bool_t);
impl_eq(unifier, store, bool_t);
}

View File

@ -1,45 +1,24 @@
use std::{collections::HashMap, fmt::Display};
use std::collections::HashMap;
use std::fmt::Display;
use itertools::Itertools;
use crate::typecheck::typedef::TypeEnum;
use nac3parser::ast::{Cmpop, Location, StrRef};
use super::{
magic_methods::{Binop, HasOpInfo},
typedef::{RecordKey, Type, TypeEnum, Unifier},
};
use super::typedef::{RecordKey, Type, Unifier};
use nac3parser::ast::{Location, StrRef};
#[derive(Debug, Clone)]
pub enum TypeErrorKind {
GotMultipleValues {
name: StrRef,
},
TooManyArguments {
expected_min_count: usize,
expected_max_count: usize,
got_count: usize,
},
MissingArgs {
missing_arg_names: Vec<StrRef>,
expected: usize,
got: usize,
},
MissingArgs(String),
UnknownArgName(StrRef),
IncorrectArgType {
name: StrRef,
expected: Type,
got: Type,
},
UnsupportedBinaryOpTypes {
operator: Binop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
UnsupportedComparsionOpTypes {
operator: Cmpop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
FieldUnificationError {
field: RecordKey,
types: (Type, Type),
@ -55,7 +34,6 @@ pub enum TypeErrorKind {
},
RequiresTypeAnn,
PolymorphicFunctionPointer,
NoSuchAttribute(RecordKey, Type),
}
#[derive(Debug, Clone)]
@ -99,49 +77,22 @@ impl<'a> Display for DisplayTypeError<'a> {
use TypeErrorKind::*;
let mut notes = Some(HashMap::new());
match &self.err.kind {
GotMultipleValues { name } => {
write!(f, "For multiple values for parameter {name}")
TooManyArguments { expected, got } => {
write!(f, "Too many arguments. Expected {expected} but got {got}")
}
TooManyArguments { expected_min_count, expected_max_count, got_count } => {
debug_assert!(expected_min_count <= expected_max_count);
if expected_min_count == expected_max_count {
let expected_count = expected_min_count; // or expected_max_count
write!(f, "Too many arguments. Expected {expected_count} but got {got_count}")
} else {
write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}")
}
}
MissingArgs { missing_arg_names } => {
let args = missing_arg_names.iter().join(", ");
MissingArgs(args) => {
write!(f, "Missing arguments: {args}")
}
UnsupportedBinaryOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnsupportedComparsionOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnknownArgName(name) => {
write!(f, "Unknown argument name: {name}")
}
IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}")
write!(
f,
"Incorrect argument type for {name}. Expected {expected}, but got {got}"
)
}
FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
@ -182,10 +133,9 @@ impl<'a> Display for DisplayTypeError<'a> {
}
result
}
(
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) if !is_vararg1 && !is_vararg2 && ty1.len() != ty2.len() => {
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 })
if ty1.len() != ty2.len() =>
{
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {t1} and {t2}")
@ -209,10 +159,6 @@ impl<'a> Display for DisplayTypeError<'a> {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` field/method does not exist")
}
NoSuchAttribute(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` is not a class attribute")
}
TupleIndexOutOfBounds { index, len } => {
write!(
f,

File diff suppressed because it is too large Load Diff

View File

@ -1,19 +1,15 @@
use std::iter::zip;
use indexmap::IndexMap;
use indoc::indoc;
use parking_lot::RwLock;
use test_case::test_case;
use nac3parser::{ast::FileName, parser::parse_program};
use super::super::{magic_methods::with_fields, typedef::*};
use super::*;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
codegen::CodeGenContext,
symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::{magic_methods::with_fields, typedef::*},
toplevel::{DefinitionId, TopLevelDef},
};
use indoc::indoc;
use std::iter::zip;
use nac3parser::parser::parse_program;
use parking_lot::RwLock;
use test_case::test_case;
struct Resolver {
id_to_type: HashMap<StrRef, Type>,
@ -36,22 +32,19 @@ impl SymbolResolver for Resolver {
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.id_to_type.get(&str).copied().ok_or_else(|| format!("cannot find symbol `{str}`"))
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str))
}
fn get_symbol_value<'ctx>(
fn get_symbol_value<'ctx, 'a>(
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
_: &mut CodeGenContext<'ctx, 'a>,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.id_to_def
.get(&id)
.copied()
self.id_to_def.get(&id).cloned()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
}
@ -80,86 +73,67 @@ impl TestEnvironment {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Int32.id(),
obj_id: DefinitionId(0),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32,
vars: VarMap::new(),
vars: HashMap::new(),
}));
fields.insert("__add__".into(), (add_ty, false));
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Int64.id(),
obj_id: DefinitionId(1),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Float.id(),
obj_id: DefinitionId(2),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Bool.id(),
obj_id: DefinitionId(3),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::None.id(),
obj_id: DefinitionId(4),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Range.id(),
obj_id: DefinitionId(5),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Str.id(),
obj_id: DefinitionId(6),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Exception.id(),
obj_id: DefinitionId(7),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::UInt32.id(),
obj_id: DefinitionId(8),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::UInt64.id(),
obj_id: DefinitionId(9),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(),
obj_id: DefinitionId(10),
fields: HashMap::new(),
params: VarMap::new(),
});
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
params: HashMap::new(),
});
let primitives = PrimitiveStore {
int32,
@ -173,14 +147,10 @@ impl TestEnvironment {
uint32,
uint64,
option,
list,
ndarray,
size_t: 64,
};
unifier.put_primitive_store(&primitives);
set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name: HashMap<_, _> = [
let id_to_name = [
(0, "int32".into()),
(1, "int64".into()),
(2, "float".into()),
@ -190,21 +160,23 @@ impl TestEnvironment {
(6, "str".into()),
(7, "exception".into()),
]
.into();
.iter()
.cloned()
.collect();
let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(),
id_to_def: HashMap::default(),
class_names: HashMap::default(),
id_to_def: Default::default(),
class_names: Default::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment {
top_level: TopLevelContext {
definitions: Arc::default(),
unifiers: Arc::default(),
definitions: Default::default(),
unifiers: Default::default(),
personality_symbol: None,
},
unifier,
@ -226,100 +198,70 @@ impl TestEnvironment {
let mut identifier_mapping = HashMap::new();
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Int32.id(),
obj_id: DefinitionId(0),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32,
vars: VarMap::new(),
vars: HashMap::new(),
}));
fields.insert("__add__".into(), (add_ty, false));
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Int64.id(),
obj_id: DefinitionId(1),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Float.id(),
obj_id: DefinitionId(2),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Bool.id(),
obj_id: DefinitionId(3),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::None.id(),
obj_id: DefinitionId(4),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Range.id(),
obj_id: DefinitionId(5),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Str.id(),
obj_id: DefinitionId(6),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Exception.id(),
obj_id: DefinitionId(7),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::UInt32.id(),
obj_id: DefinitionId(8),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::UInt64.id(),
obj_id: DefinitionId(9),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(),
obj_id: DefinitionId(10),
fields: HashMap::new(),
params: VarMap::new(),
});
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
});
identifier_mapping.insert("None".into(), none);
for (i, name) in [
"int32",
"int64",
"float",
"bool",
"none",
"range",
"str",
"Exception",
"uint32",
"uint64",
"Option",
"list",
"ndarray",
]
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
.iter()
.enumerate()
{
@ -327,11 +269,10 @@ impl TestEnvironment {
RwLock::new(TopLevelDef::Class {
name: (*name).into(),
object_id: DefinitionId(i),
type_vars: Vec::default(),
fields: Vec::default(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
type_vars: Default::default(),
fields: Default::default(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
constructor: None,
loc: None,
@ -339,7 +280,7 @@ impl TestEnvironment {
.into(),
);
}
let defs = 12;
let defs = 7;
let primitives = PrimitiveStore {
int32,
@ -353,29 +294,23 @@ impl TestEnvironment {
uint32,
uint64,
option,
list,
ndarray,
size_t: 64,
};
unifier.put_primitive_store(&primitives);
let tvar = unifier.get_dummy_var();
let (v0, id) = unifier.get_dummy_var();
let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 1),
fields: [("a".into(), (tvar.ty, true))].into(),
params: into_var_map([tvar]),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
name: "Foo".into(),
object_id: DefinitionId(defs + 1),
type_vars: vec![tvar.ty],
fields: [("a".into(), tvar.ty, true)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
type_vars: vec![v0],
fields: [("a".into(), v0, true)].into(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
constructor: None,
loc: None,
@ -388,29 +323,31 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: foo_ty,
vars: into_var_map([tvar]),
vars: [(id, v0)].iter().cloned().collect(),
})),
);
let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: int32,
vars: IndexMap::default(),
vars: Default::default(),
}));
let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 2),
fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))].into(),
params: IndexMap::default(),
fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))]
.iter()
.cloned()
.collect::<HashMap<_, _>>(),
params: Default::default(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
name: "Bar".into(),
object_id: DefinitionId(defs + 2),
type_vars: Vec::default(),
type_vars: Default::default(),
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
constructor: None,
loc: None,
@ -422,24 +359,26 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bar,
vars: IndexMap::default(),
vars: Default::default(),
})),
);
let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 3),
fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))].into(),
params: IndexMap::default(),
fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))]
.iter()
.cloned()
.collect::<HashMap<_, _>>(),
params: Default::default(),
});
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
name: "Bar2".into(),
object_id: DefinitionId(defs + 3),
type_vars: Vec::default(),
type_vars: Default::default(),
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
methods: Default::default(),
ancestors: Default::default(),
resolver: None,
constructor: None,
loc: None,
@ -451,10 +390,10 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: bar2,
vars: IndexMap::default(),
vars: Default::default(),
})),
);
let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into();
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect();
let id_to_name = [
"int32".into(),
@ -465,22 +404,18 @@ impl TestEnvironment {
"range".into(),
"str".into(),
"exception".into(),
"uint32".into(),
"uint64".into(),
"option".into(),
"list".into(),
"ndarray".into(),
"Foo".into(),
"Bar".into(),
"Bar2".into(),
]
.into_iter()
.iter()
.enumerate()
.map(|(a, b)| (a, *b))
.collect();
let top_level = TopLevelContext {
definitions: Arc::new(top_level_defs.into()),
unifiers: Arc::default(),
unifiers: Default::default(),
personality_symbol: None,
};
@ -491,7 +426,9 @@ impl TestEnvironment {
("Bar".into(), DefinitionId(defs + 2)),
("Bar2".into(), DefinitionId(defs + 3)),
]
.into(),
.iter()
.cloned()
.collect(),
class_names,
}) as Arc<dyn SymbolResolver + Send + Sync>;
@ -516,11 +453,11 @@ impl TestEnvironment {
top_level: &self.top_level,
function_data: &mut self.function_data,
unifier: &mut self.unifier,
variable_mapping: HashMap::default(),
variable_mapping: Default::default(),
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
defined_identifiers: HashMap::default(),
defined_identifiers: Default::default(),
in_handler: false,
}
}
@ -532,7 +469,7 @@ impl TestEnvironment {
c = 1.234
d = True
"},
&[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].into(),
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(),
&[]
; "primitives test")]
#[test_case(indoc! {"
@ -541,7 +478,7 @@ impl TestEnvironment {
c = 1.234
d = b(c)
"},
&[("a", "fn[[x:float, y:float], float]"), ("b", "fn[[x:float], float]"), ("c", "float"), ("d", "float")].into(),
[("a", "fn[[x:float, y:float], float]"), ("b", "fn[[x:float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(),
&[]
; "lambda test")]
#[test_case(indoc! {"
@ -550,7 +487,7 @@ impl TestEnvironment {
a = b
c = b(1)
"},
&[("a", "fn[[x:int32], int32]"), ("b", "fn[[x:int32], int32]"), ("c", "int32")].into(),
[("a", "fn[[x:int32], int32]"), ("b", "fn[[x:int32], int32]"), ("c", "int32")].iter().cloned().collect(),
&[]
; "lambda test 2")]
#[test_case(indoc! {"
@ -566,15 +503,15 @@ impl TestEnvironment {
b(123)
"},
&[("a", "fn[[x:bool], bool]"), ("b", "fn[[x:int32], int32]"), ("c", "bool"),
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].into(),
[("a", "fn[[x:bool], bool]"), ("b", "fn[[x:int32], int32]"), ("c", "bool"),
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(),
&[]
; "obj test")]
#[test_case(indoc! {"
a = [1, 2, 3]
b = [x + x for x in a]
"},
&[("a", "list[int32]"), ("b", "list[int32]")].into(),
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(),
&[]
; "listcomp test")]
#[test_case(indoc! {"
@ -582,26 +519,25 @@ impl TestEnvironment {
b = a.b()
a = virtual(Bar2())
"},
&[("a", "virtual[Bar]"), ("b", "int32")].into(),
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual test")]
#[test_case(indoc! {"
a = [virtual(Bar(), Bar), virtual(Bar2())]
b = [x.b() for x in a]
"},
&[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].into(),
[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(),
&[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual list test")]
fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
println!("source:\n{source}");
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
println!("source:\n{}", source);
let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashMap<_, _> =
env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect();
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, FileName::default()).unwrap();
inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source, Default::default()).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
@ -610,37 +546,37 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in &inferencer.variable_mapping {
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.internal_stringify(
*v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{v}"),
&mut |v| format!("v{}", v),
&mut None,
);
println!("{k}: {name}");
println!("{}: {}", k, name);
}
for (k, v) in mapping {
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.internal_stringify(
*ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{v}"),
&mut |v| format!("v{}", v),
&mut None,
);
assert_eq!(format!("{k}: {v}"), format!("{k}: {name}"));
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
for ((a, b, _), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
let a = inferencer.unifier.internal_stringify(
*a,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{v}"),
&mut |v| format!("v{}", v),
&mut None,
);
let b = inferencer.unifier.internal_stringify(
*b,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{v}"),
&mut |v| format!("v{}", v),
&mut None,
);
@ -659,14 +595,14 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
g = a // b
h = a % b
"},
&[("a", "int32"),
[("a", "int32"),
("b", "int32"),
("c", "int32"),
("d", "int32"),
("e", "int32"),
("f", "float"),
("g", "int32"),
("h", "int32")].into()
("h", "int32")].iter().cloned().collect()
; "int32")]
#[test_case(
indoc! {"
@ -682,7 +618,7 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
ii = 3
j = a ** b
"},
&[("a", "float"),
[("a", "float"),
("b", "float"),
("c", "float"),
("d", "float"),
@ -692,7 +628,7 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
("h", "float"),
("i", "float"),
("ii", "int32"),
("j", "float")].into()
("j", "float")].iter().cloned().collect()
; "float"
)]
#[test_case(
@ -710,7 +646,7 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
k = a < b
l = a != b
"},
&[("a", "int64"),
[("a", "int64"),
("b", "int64"),
("c", "int64"),
("d", "int64"),
@ -721,7 +657,7 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
("i", "bool"),
("j", "bool"),
("k", "bool"),
("l", "bool")].into()
("l", "bool")].iter().cloned().collect()
; "int64"
)]
#[test_case(
@ -732,23 +668,22 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
d = not a
e = a != b
"},
&[("a", "bool"),
[("a", "bool"),
("b", "bool"),
("c", "bool"),
("d", "bool"),
("e", "bool")].into()
("e", "bool")].iter().cloned().collect()
; "boolean"
)]
fn test_primitive_magic_methods(source: &str, mapping: &HashMap<&str, &str>) {
println!("source:\n{source}");
fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
println!("source:\n{}", source);
let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashMap<_, _> =
env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect();
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, FileName::default()).unwrap();
inferencer.defined_identifiers = defined_identifiers.clone();
let statements = parse_program(source, Default::default()).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
@ -757,23 +692,23 @@ fn test_primitive_magic_methods(source: &str, mapping: &HashMap<&str, &str>) {
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in &inferencer.variable_mapping {
for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.internal_stringify(
*v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{v}"),
&mut |v| format!("v{}", v),
&mut None,
);
println!("{k}: {name}");
println!("{}: {}", k, name);
}
for (k, v) in mapping {
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.internal_stringify(
*ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{v}"),
&mut |v| format!("v{}", v),
&mut None,
);
assert_eq!(format!("{k}: {v}"), format!("{k}: {name}"));
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,12 +1,10 @@
use std::collections::HashMap;
use super::super::magic_methods::with_fields;
use super::*;
use indoc::indoc;
use itertools::Itertools;
use std::collections::HashMap;
use test_case::test_case;
use super::*;
use crate::typecheck::magic_methods::with_fields;
impl Unifier {
/// Check whether two types are equal.
fn eq(&mut self, a: Type, b: Type) -> bool {
@ -30,32 +28,32 @@ impl Unifier {
TypeEnum::TVar { fields: Some(map1), .. },
TypeEnum::TVar { fields: Some(map2), .. },
) => self.map_eq2(map1, map2),
(
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: false },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: false },
) => {
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
}
(TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => self.eq(*ty1, *ty2),
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
self.eq(*ty1, *ty2)
}
(
TypeEnum::TObj { obj_id: id1, params: params1, .. },
TypeEnum::TObj { obj_id: id2, params: params2, .. },
) => id1 == id2 && self.map_eq(params1, params2),
// TLiteral, TCall and TFunc are not yet implemented
// TCall and TFunc are not yet implemented
_ => false,
}
}
fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool
where
K: std::hash::Hash + Eq + Clone,
{
if map1.len() != map2.len() {
return false;
}
for (k, v) in map1 {
if !map2.get(k).is_some_and(|v1| self.eq(*v, *v1)) {
for (k, v) in map1.iter() {
if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) {
return false;
}
}
@ -69,8 +67,8 @@ impl Unifier {
if map1.len() != map2.len() {
return false;
}
for (k, v) in map1 {
if !map2.get(k).is_some_and(|v1| self.eq(v.ty, v1.ty)) {
for (k, v) in map1.iter() {
if !map2.get(k).map(|v1| self.eq(v.ty, v1.ty)).unwrap_or(false) {
return false;
}
}
@ -93,7 +91,7 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
}),
);
type_mapping.insert(
@ -101,7 +99,7 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
}),
);
type_mapping.insert(
@ -109,25 +107,16 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2),
fields: HashMap::new(),
params: VarMap::new(),
params: HashMap::new(),
}),
);
let tvar = unifier.get_dummy_var();
let (v0, id) = unifier.get_dummy_var();
type_mapping.insert(
"Foo".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3),
fields: [("a".into(), (tvar.ty, true))].into(),
params: into_var_map([tvar]),
}),
);
let tvar = unifier.get_dummy_var();
type_mapping.insert(
"list".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([tvar]),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(),
}),
);
@ -140,40 +129,14 @@ impl TestEnvironment {
result.0
}
fn internal_parse<'b>(&mut self, typ: &'b str, mapping: &Mapping<String>) -> (Type, &'b str) {
fn internal_parse<'a, 'b>(
&'a mut self,
typ: &'b str,
mapping: &Mapping<String>,
) -> (Type, &'b str) {
// for testing only, so we can just panic when the input is malformed
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or(typ.len());
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or_else(|| typ.len());
match &typ[..end] {
"list" => {
let mut s = &typ[end..];
assert_eq!(&s[0..1], "[");
let mut ty = Vec::new();
while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
ty.push(result.0);
s = result.1;
}
assert_eq!(ty.len(), 1);
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*self.unifier.get_ty_immutable(self.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
(
self.unifier
.subst(
self.type_mapping["list"],
&into_var_map([TypeVar { id: list_elem_tvar.id, ty: ty[0] }]),
)
.unwrap(),
&s[1..],
)
}
"tuple" => {
let mut s = &typ[end..];
assert_eq!(&s[0..1], "[");
@ -183,7 +146,13 @@ impl TestEnvironment {
ty.push(result.0);
s = result.1;
}
(self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }), &s[1..])
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
}
"list" => {
assert_eq!(&typ[end..end + 1], "[");
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
assert_eq!(&s[0..1], "]");
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
}
"Record" => {
let mut s = &typ[end..];
@ -200,12 +169,12 @@ impl TestEnvironment {
}
x => {
let mut s = &typ[end..];
let ty = mapping.get(x).copied().unwrap_or_else(|| {
let ty = mapping.get(x).cloned().unwrap_or_else(|| {
// mapping should be type variables, type_mapping should be concrete types
// we should not resolve the type of type variables.
let mut ty = *self.type_mapping.get(x).unwrap();
let te = self.unifier.get_ty(ty);
if let TypeEnum::TObj { params, .. } = &*te {
if let TypeEnum::TObj { params, .. } = &*te.as_ref() {
if !params.is_empty() {
assert_eq!(&s[0..1], "[");
let mut p = Vec::new();
@ -217,7 +186,7 @@ impl TestEnvironment {
s = &s[1..];
ty = self
.unifier
.subst(ty, &params.keys().copied().zip(p).collect())
.subst(ty, &params.keys().cloned().zip(p.into_iter()).collect())
.unwrap_or(ty);
}
}
@ -281,12 +250,12 @@ fn test_unify(
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{i}"), v.ty);
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effect when we do type resolution, so freeze the types
// before doing unification.
let mut pairs = Vec::new();
for (a, b) in &perm {
for (a, b) in perm.iter() {
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
pairs.push((t1, t2));
@ -294,8 +263,8 @@ fn test_unify(
for (t1, t2) in pairs {
env.unifier.unify(t1, t2).unwrap();
}
for (a, b) in verify_pairs {
println!("{a} = {b}");
for (a, b) in verify_pairs.iter() {
println!("{} = {}", a, b);
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
println!("a = {}, b = {}", env.unifier.stringify(t1), env.unifier.stringify(t2));
@ -309,7 +278,7 @@ fn test_unify(
("v1", "tuple[int]"),
("v2", "list[int]"),
],
(("v1", "v2"), "Incompatible types: 11[0] and tuple[0]")
(("v1", "v2"), "Incompatible types: list[0] and tuple[0]")
; "type mismatch"
)]
#[test_case(2,
@ -333,7 +302,7 @@ fn test_unify(
("v1", "Record[a=float,b=int]"),
("v2", "Foo[v3]"),
],
(("v1", "v2"), "`3[typevar5]::b` field/method does not exist")
(("v1", "v2"), "`3[typevar4]::b` field/method does not exist")
; "record obj merge"
)]
/// Test cases for invalid unifications.
@ -346,12 +315,12 @@ fn test_invalid_unification(
let mut mapping = HashMap::new();
for i in 1..=variable_count {
let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{i}"), v.ty);
mapping.insert(format!("v{}", i), v.0);
}
// unification may have side effect when we do type resolution, so freeze the types
// before doing unification.
let mut pairs = Vec::new();
for (a, b) in unify_pairs {
for (a, b) in unify_pairs.iter() {
let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping);
pairs.push((t1, t2));
@ -373,12 +342,16 @@ fn test_recursive_subst() {
with_fields(&mut env.unifier, foo_id, |_unifier, fields| {
fields.insert("rec".into(), (foo_id, true));
});
let TypeEnum::TObj { params, .. } = &*foo_ty else { unreachable!() };
let TypeEnum::TObj { params, .. } = &*foo_ty else {
unreachable!()
};
let mapping = params.iter().map(|(id, _)| (*id, int)).collect();
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
let instantiated_ty = env.unifier.get_ty(instantiated);
let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { unreachable!() };
let TypeEnum::TObj { fields, .. } = &*instantiated_ty else {
unreachable!()
};
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
}
@ -390,27 +363,36 @@ fn test_virtual() {
let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: int,
vars: VarMap::new(),
vars: HashMap::new(),
}));
let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5),
fields: [("f".into(), (fun, false)), ("a".into(), (int, false))].into(),
params: VarMap::new(),
fields: [("f".into(), (fun, false)), ("a".into(), (int, false))]
.iter()
.cloned()
.collect::<HashMap<StrRef, _>>(),
params: HashMap::new(),
});
let v0 = env.unifier.get_dummy_var().ty;
let v1 = env.unifier.get_dummy_var().ty;
let v0 = env.unifier.get_dummy_var().0;
let v1 = env.unifier.get_dummy_var().0;
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
let c = env.unifier.add_record([("f".into(), RecordField::new(v1, false, None))].into());
let c = env
.unifier
.add_record([("f".into(), RecordField::new(v1, false, None))].iter().cloned().collect());
env.unifier.unify(a, b).unwrap();
env.unifier.unify(b, c).unwrap();
assert!(env.unifier.eq(v1, fun));
let d = env.unifier.add_record([("a".into(), RecordField::new(v1, true, None))].into());
let d = env
.unifier
.add_record([("a".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field/method does not exist".to_string()));
let d = env.unifier.add_record([("b".into(), RecordField::new(v1, true, None))].into());
let d = env
.unifier
.add_record([("b".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field/method does not exist".to_string()));
}
@ -423,132 +405,86 @@ fn test_typevar_range() {
let int_list = env.parse("list[int]", &HashMap::new());
let float_list = env.parse("list[float]", &HashMap::new());
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*env.unifier.get_ty_immutable(env.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
// unification between v and int
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
env.unifier.unify(int, v).unwrap();
// unification between v and list[int]
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
assert_eq!(
env.unify(int_list, v),
Err("Expected any one of these types: 0, 2, but got 11[0]".to_string())
Err("Expected any one of these types: 0, 2, but got list[0]".to_string())
);
// unification between v and float
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
assert_eq!(
env.unify(float, v),
Err("Expected any one of these types: 0, 2, but got 1".to_string())
);
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v1_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: v1 }]),
});
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 });
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0;
// unification between v and int
// where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0;
// unification between v and list[int]
// where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int_list, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0;
// unification between v and list[float]
// where v in (int, list[v1]), v1 in (int, bool)
println!("float_list: {}, v: {}", env.unifier.stringify(float_list), env.unifier.stringify(v));
assert_eq!(
env.unify(float_list, v),
Err("Expected any one of these types: 0, 11[typevar6], but got 11[1]\n\nNotes:\n typevar6 ∈ {0, 2}".to_string())
Err("Expected any one of these types: 0, list[typevar5], but got list[1]\n\nNotes:\n typevar5 ∈ {0, 2}".to_string())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
env.unifier.unify(a, b).unwrap();
env.unifier.unify(a, float).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
env.unifier.unify(a, b).unwrap();
assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into()));
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).ty;
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).0;
env.unifier.unify(a_list, b_list).unwrap();
let float_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: float }]),
});
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
env.unifier.unify(a_list, float_list).unwrap();
// previous unifications should not affect a and b
env.unifier.unify(a, int).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
let int_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: int }]),
});
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int });
assert_eq!(
env.unify(a_list, int_list),
Err("Incompatible types: 11[typevar23] and 11[0]\
\n\nNotes:\n typevar23 {1}"
.into())
Err("Incompatible types: list[typevar22] and list[0]\
\n\nNotes:\n typevar22 {1}".into())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_dummy_var().ty;
let a_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
let b = env.unifier.get_dummy_var().0;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0;
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
assert_eq!(
env.unify(b, boolean),
@ -559,29 +495,17 @@ fn test_typevar_range() {
#[test]
fn test_rigid_var() {
let mut env = TestEnvironment::new();
let a = env.unifier.get_fresh_rigid_var(None, None).ty;
let b = env.unifier.get_fresh_rigid_var(None, None).ty;
let x = env.unifier.get_dummy_var().ty;
let list_elem_tvar = env.unifier.get_fresh_var(Some("list_elem".into()), None);
let list_a = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let list_x = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: x }]),
});
let a = env.unifier.get_fresh_rigid_var(None, None).0;
let b = env.unifier.get_fresh_rigid_var(None, None).0;
let x = env.unifier.get_dummy_var().0;
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
let int = env.parse("int", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new());
assert_eq!(env.unify(a, b), Err("Incompatible types: typevar4 and typevar3".to_string()));
assert_eq!(env.unify(a, b), Err("Incompatible types: typevar3 and typevar2".to_string()));
env.unifier.unify(list_a, list_x).unwrap();
assert_eq!(
env.unify(list_x, list_int),
Err("Incompatible types: 11[typevar3] and 11[0]".to_string())
);
assert_eq!(env.unify(list_x, list_int), Err("Incompatible types: list[typevar2] and list[0]".to_string()));
env.unifier.replace_rigid_var(a, int);
env.unifier.unify(list_x, list_int).unwrap();
@ -595,26 +519,16 @@ fn test_instantiation() {
let float = env.parse("float", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new());
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*env.unifier.get_ty_immutable(env.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
let obj_map: HashMap<_, _> =
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect();
let obj_map: HashMap<_, _> = [(0usize, "int"), (1, "float"), (2, "bool"), (11, "list")].into();
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let list_v = env
.unifier
.subst(env.type_mapping["list"], &into_var_map([TypeVar { id: list_elem_tvar.id, ty: v }]))
.unwrap();
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
let t = env.unifier.get_dummy_var().ty;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2], is_vararg_ctx: false });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0;
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).0;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).0;
let t = env.unifier.get_dummy_var().0;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).0;
// t = TypeVar('t')
// v = TypeVar('v', int, bool)
// v1 = TypeVar('v1', 'list[v]', int)
@ -636,7 +550,7 @@ fn test_instantiation() {
tuple[int, list[bool], list[int]]
tuple[int, list[int], float]
tuple[int, list[int], list[int]]
v6"
v5"
}
.split('\n')
.collect_vec();
@ -645,8 +559,8 @@ fn test_instantiation() {
.map(|ty| {
env.unifier.internal_stringify(
*ty,
&mut |i| (*obj_map.get(&i).unwrap()).to_string(),
&mut |i| format!("v{i}"),
&mut |i| obj_map.get(&i).unwrap().to_string(),
&mut |i| format!("v{}", i),
&mut None,
)
})

View File

@ -16,10 +16,21 @@ pub struct UnificationTable<V> {
#[derive(Clone, Debug)]
enum Action<V> {
Parent { key: usize, original_parent: usize },
Value { key: usize, original_value: Option<V> },
Rank { key: usize, original_rank: u32 },
Marker { generation: u32 },
Parent {
key: usize,
original_parent: usize,
},
Value {
key: usize,
original_value: Option<V>,
},
Rank {
key: usize,
original_rank: u32,
},
Marker {
generation: u32,
}
}
impl<V> Default for UnificationTable<V> {
@ -30,13 +41,7 @@ impl<V> Default for UnificationTable<V> {
impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> {
UnificationTable {
parents: Vec::new(),
ranks: Vec::new(),
values: Vec::new(),
log: Vec::new(),
generation: 0,
}
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 }
}
pub fn new_key(&mut self, v: V) -> UnificationKey {
@ -120,10 +125,7 @@ impl<V> UnificationTable<V> {
pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot restoration error");
assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot restoration error"
);
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error");
for action in self.log.drain(log_len - 1..).rev() {
match action {
Action::Parent { key, original_parent } => {
@ -143,10 +145,7 @@ impl<V> UnificationTable<V> {
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot discard error");
assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot discard error"
);
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot discard error");
self.log.clear();
}
}
@ -160,23 +159,11 @@ where
.enumerate()
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
.collect();
UnificationTable {
parents: self.parents.clone(),
ranks: self.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 }
}
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
UnificationTable {
parents: table.parents.clone(),
ranks: table.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 }
}
}

View File

@ -32,6 +32,7 @@ pub struct DwarfReader<'a> {
}
impl<'a> DwarfReader<'a> {
pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader {
DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr }
}
@ -59,7 +60,7 @@ impl<'a> DwarfReader<'a> {
let mut byte: u8;
loop {
byte = self.read_u8();
result |= u64::from(byte & 0x7F) << shift;
result |= ((byte & 0x7F) as u64) << shift;
shift += 7;
if byte & 0x80 == 0 {
break;
@ -74,7 +75,7 @@ impl<'a> DwarfReader<'a> {
let mut byte: u8;
loop {
byte = self.read_u8();
result |= u64::from(byte & 0x7F) << shift;
result |= ((byte & 0x7F) as u64) << shift;
shift += 7;
if byte & 0x80 == 0 {
break;
@ -156,9 +157,10 @@ fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize,
}
match encoding & 0x0F {
DW_EH_PE_absptr | DW_EH_PE_udata4 => Ok(reader.read_u32() as usize),
DW_EH_PE_absptr => Ok(reader.read_u32() as usize),
DW_EH_PE_uleb128 => Ok(reader.read_uleb128() as usize),
DW_EH_PE_udata2 => Ok(reader.read_u16() as usize),
DW_EH_PE_udata4 => Ok(reader.read_u32() as usize),
DW_EH_PE_udata8 => Ok(reader.read_u64() as usize),
DW_EH_PE_sleb128 => Ok(reader.read_sleb128() as usize),
DW_EH_PE_sdata2 => Ok(reader.read_i16() as usize),
@ -168,7 +170,10 @@ fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize,
}
}
fn read_encoded_pointer_with_pc(reader: &mut DwarfReader, encoding: u8) -> Result<usize, ()> {
fn read_encoded_pointer_with_pc(
reader: &mut DwarfReader,
encoding: u8,
) -> Result<usize, ()> {
let entry_virt_addr = reader.virt_addr;
let mut result = read_encoded_pointer(reader, encoding)?;
@ -218,10 +223,11 @@ pub struct EH_Frame<'a> {
}
impl<'a> EH_Frame<'a> {
/// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF
/// file.
pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> EH_Frame {
EH_Frame { reader: DwarfReader::new(eh_frame_slice, eh_frame_addr) }
pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> Result<EH_Frame, ()> {
Ok(EH_Frame { reader: DwarfReader::new(eh_frame_slice, eh_frame_addr) })
}
/// Returns an [Iterator] over all Call Frame Information (CFI) records.
@ -229,7 +235,10 @@ impl<'a> EH_Frame<'a> {
let reader = DwarfReader::from_reader(&self.reader, true);
let len = reader.slice.len();
CFI_Records { reader, available: len }
CFI_Records {
reader,
available: len,
}
}
}
@ -238,7 +247,7 @@ impl<'a> EH_Frame<'a> {
/// From the [specification](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html):
///
/// > Each CFI record contains a Common Information Entry (CIE) record followed by 1 or more Frame
/// > Description Entry (FDE) records.
/// Description Entry (FDE) records.
pub struct CFI_Record<'a> {
// It refers to the augmentation data that corresponds to 'R' in the augmentation string
fde_pointer_encoding: u8,
@ -246,6 +255,7 @@ pub struct CFI_Record<'a> {
}
impl<'a> CFI_Record<'a> {
pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result<CFI_Record<'a>, ()> {
let length = cie_reader.read_u32();
let fde_reader = match length {
@ -254,7 +264,7 @@ impl<'a> CFI_Record<'a> {
// length == u32::MAX means that the length is only representable with 64 bits,
// which does not make sense in a system with 32-bit address.
0xFFFF_FFFF => unimplemented!(),
0xFFFFFFFF => unimplemented!(),
_ => {
let mut fde_reader = DwarfReader::from_reader(cie_reader, false);
@ -313,7 +323,10 @@ impl<'a> CFI_Record<'a> {
}
assert_ne!(fde_pointer_encoding, DW_EH_PE_omit);
Ok(CFI_Record { fde_pointer_encoding, fde_reader })
Ok(CFI_Record {
fde_pointer_encoding,
fde_reader,
})
}
/// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI
@ -327,7 +340,11 @@ impl<'a> CFI_Record<'a> {
let reader = self.get_fde_reader();
let len = reader.slice.len();
FDE_Records { pointer_encoding: self.fde_pointer_encoding, reader, available: len }
FDE_Records {
pointer_encoding: self.fde_pointer_encoding,
reader,
available: len,
}
}
}
@ -354,7 +371,7 @@ impl<'a> Iterator for CFI_Records<'a> {
let length = match length {
// eh_frame with 0-length means the CIE is terminated
0 => return None,
0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
0xFFFFFFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other,
} as usize;
@ -370,7 +387,7 @@ impl<'a> Iterator for CFI_Records<'a> {
// Skip this record if it is a FDE
if cie_ptr == 0 {
// Rewind back to the start of the CFI Record
return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap());
return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap())
}
}
}
@ -400,7 +417,7 @@ impl<'a> Iterator for FDE_Records<'a> {
let length = match self.reader.read_u32() {
// eh_frame with 0-length means the CIE is terminated
0 => return None,
0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
0xFFFFFFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other,
} as usize;
@ -431,6 +448,7 @@ pub struct EH_Frame_Hdr<'a> {
}
impl<'a> EH_Frame_Hdr<'a> {
/// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory.
///
/// Load address is not known at this point.
@ -446,9 +464,8 @@ impl<'a> EH_Frame_Hdr<'a> {
writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value
writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value
let eh_frame_offset = eh_frame_addr.wrapping_sub(
eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4),
);
let eh_frame_offset = eh_frame_addr
.wrapping_sub(eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4));
writer.write_u32(eh_frame_offset); // eh_frame_ptr
writer.write_u32(0); // `fde_count`, will be written in finalize_fde
@ -475,10 +492,7 @@ impl<'a> EH_Frame_Hdr<'a> {
self.fde_writer.write_u32(*init_loc);
self.fde_writer.write_u32(*addr);
}
LittleEndian::write_u32(
&mut self.fde_writer.slice[Self::fde_count_offset()..],
self.fdes.len() as u32,
);
LittleEndian::write_u32(&mut self.fde_writer.slice[Self::fde_count_offset()..], self.fdes.len() as u32);
}
pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize {
@ -490,7 +504,7 @@ impl<'a> EH_Frame_Hdr<'a> {
// The original length field should be able to hold the entire value.
// The device memory space is limited to 32-bits addresses anyway.
let entry_length = reader.read_u32();
if entry_length == 0 || entry_length == 0xFFFF_FFFF {
if entry_length == 0 || entry_length == 0xFFFFFFFF {
unimplemented!()
}
@ -501,7 +515,7 @@ impl<'a> EH_Frame_Hdr<'a> {
fde_count += 1;
}
reader.offset(entry_length - mem::size_of::<u32>() as u32);
reader.offset(entry_length - mem::size_of::<u32>() as u32)
}
12 + fde_count * 8

View File

@ -1,5 +1,5 @@
/* generated from elf.h with rust-bindgen and then manually altered */
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, clippy::pedantic)]
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code)]
pub const EI_NIDENT: usize = 16;
pub const EI_MAG0: usize = 0;

View File

@ -1,26 +1,10 @@
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
#![warn(clippy::pedantic)]
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::doc_markdown,
clippy::enum_glob_use,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::similar_names,
clippy::struct_field_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
use std::{collections::HashMap, mem, ptr, slice, str};
use byteorder::{ByteOrder, LittleEndian};
use dwarf::*;
use elf::*;
use std::collections::HashMap;
use std::{mem, ptr, slice, str};
extern crate byteorder;
use byteorder::{ByteOrder, LittleEndian};
mod dwarf;
mod elf;
@ -86,45 +70,45 @@ struct SectionRecord<'a> {
data: Vec<u8>,
}
fn read_unaligned<T: Copy>(data: &[u8], offset: usize) -> Option<T> {
fn read_unaligned<T: Copy>(data: &[u8], offset: usize) -> Result<T, ()> {
if data.len() < offset + mem::size_of::<T>() {
None
Err(())
} else {
let ptr = data.as_ptr().wrapping_add(offset).cast();
Some(unsafe { ptr::read_unaligned(ptr) })
let ptr = data.as_ptr().wrapping_add(offset) as *const T;
Ok(unsafe { ptr::read_unaligned(ptr) })
}
}
#[must_use]
pub fn get_ref_slice<T: Copy>(data: &[u8], offset: usize, len: usize) -> Option<&[T]> {
pub fn get_ref_slice<T: Copy>(data: &[u8], offset: usize, len: usize) -> Result<&[T], ()> {
if data.len() < offset + mem::size_of::<T>() * len {
None
Err(())
} else {
let ptr = data.as_ptr().wrapping_add(offset).cast();
Some(unsafe { slice::from_raw_parts(ptr, len) })
let ptr = data.as_ptr().wrapping_add(offset) as *const T;
Ok(unsafe { slice::from_raw_parts(ptr, len) })
}
}
fn from_struct_slice<T>(struct_vec: &[T]) -> Vec<u8> {
fn from_struct_vec<T>(struct_vec: Vec<T>) -> Vec<u8> {
let ptr = struct_vec.as_ptr();
unsafe { slice::from_raw_parts(ptr.cast(), mem::size_of_val(struct_vec)) }.to_vec()
unsafe { slice::from_raw_parts(ptr as *const u8, struct_vec.len() * mem::size_of::<T>()) }
.to_vec()
}
fn to_struct_slice<T>(bytes: &[u8]) -> &[T] {
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len() / mem::size_of::<T>()) }
unsafe { slice::from_raw_parts(bytes.as_ptr() as *const T, bytes.len() / mem::size_of::<T>()) }
}
fn to_struct_mut_slice<T>(bytes: &mut [u8]) -> &mut [T] {
unsafe {
slice::from_raw_parts_mut(bytes.as_mut_ptr().cast(), bytes.len() / mem::size_of::<T>())
slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut T, bytes.len() / mem::size_of::<T>())
}
}
fn elf_hash(name: &[u8]) -> u32 {
let mut h: u32 = 0;
for c in name {
h = (h << 4) + u32::from(*c);
let g = h & 0xf000_0000;
h = (h << 4) + *c as u32;
let g = h & 0xf0000000;
if g != 0 {
h ^= g >> 24;
h &= !g;
@ -218,26 +202,22 @@ impl<'a> Linker<'a> {
relocs: &[R],
target_section: Elf32_Word,
) -> Result<(), Error> {
type RelocateFn = dyn Fn(&mut [u8], Elf32_Word);
struct RelocInfo<'a, R> {
pub defined_val: bool,
pub indirect_reloc: Option<&'a R>,
pub pc_relative: bool,
pub relocate: Option<Box<RelocateFn>>,
}
for reloc in relocs {
let sym = match reloc.sym_info() as usize {
STN_UNDEF => None,
sym_index => {
Some(self.symtab.get(sym_index).ok_or("symbol out of bounds of symbol table")?)
}
sym_index => Some(
self.symtab
.get(sym_index)
.ok_or("symbol out of bounds of symbol table")?,
),
};
let resolve_symbol_addr =
|sym_option: Option<&Elf32_Sym>| -> Result<Elf32_Word, Error> {
let Some(sym) = sym_option else { return Ok(0) };
let sym = match sym_option {
Some(sym) => sym,
None => return Ok(0),
};
match sym.st_shndx {
SHN_UNDEF => Err(Error::Lookup("undefined symbol")),
@ -264,6 +244,13 @@ impl<'a> Linker<'a> {
.ok_or(Error::Parsing("Cannot find section with matching sh_index"))
};
struct RelocInfo<'a, R> {
pub defined_val: bool,
pub indirect_reloc: Option<&'a R>,
pub pc_relative: bool,
pub relocate: Option<Box<dyn Fn(&mut [u8], Elf32_Word)>>,
}
let classify = |reloc: &R, sym_option: Option<&Elf32_Sym>| -> Option<RelocInfo<R>> {
let defined_val = sym_option.map_or(true, |sym| {
sym.st_shndx != SHN_UNDEF || ELF32_ST_BIND(sym.st_info) == STB_LOCAL
@ -275,7 +262,7 @@ impl<'a> Linker<'a> {
indirect_reloc: None,
pc_relative: true,
relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32(target_word, value);
LittleEndian::write_u32(target_word, value)
})),
}),
@ -286,9 +273,9 @@ impl<'a> Linker<'a> {
relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32(
target_word,
(LittleEndian::read_u32(target_word) & 0x8000_0000)
| value & 0x7FFF_FFFF,
);
(LittleEndian::read_u32(target_word) & 0x80000000)
| value & 0x7FFFFFFF,
)
})),
}),
@ -310,8 +297,8 @@ impl<'a> Linker<'a> {
relocate: Some(Box::new(|target_word, value| {
let auipc_raw = LittleEndian::read_u32(target_word);
let auipc_insn =
(auipc_raw & 0xFFF) | ((value + 0x800) & 0xFFFF_F000);
LittleEndian::write_u32(target_word, auipc_insn);
(auipc_raw & 0xFFF) | ((value + 0x800) & 0xFFFFF000);
LittleEndian::write_u32(target_word, auipc_insn)
})),
})
}
@ -321,14 +308,15 @@ impl<'a> Linker<'a> {
indirect_reloc: None,
pc_relative: true,
relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32(target_word, value);
LittleEndian::write_u32(target_word, value)
})),
}),
R_RISCV_PCREL_LO12_I => {
let expected_offset = sym_option.map_or(0, |sym| sym.st_value);
let indirect_reloc =
relocs.iter().find(|reloc| reloc.offset() == expected_offset)?;
let indirect_reloc = relocs
.iter()
.find(|reloc| reloc.offset() == expected_offset)?;
Some(RelocInfo {
defined_val: {
let indirect_sym =
@ -342,14 +330,14 @@ impl<'a> Linker<'a> {
// Here, we convert to direct addressing
// GOT reloc (indirect) -> lw + addi
// PCREL reloc (direct) -> addi
let (lo_opcode, lo_funct3) = (0b001_0011, 0b000);
let (lo_opcode, lo_funct3) = (0b0010011, 0b000);
let addi_lw_raw = LittleEndian::read_u32(target_word);
let addi_insn = lo_opcode
| (addi_lw_raw & 0xF8F80)
| (lo_funct3 << 12)
| ((value & 0xFFF) << 20);
LittleEndian::write_u32(target_word, addi_insn);
LittleEndian::write_u32(target_word, addi_insn)
})),
})
}
@ -366,7 +354,10 @@ impl<'a> Linker<'a> {
indirect_reloc: None,
pc_relative: false,
relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32(target_word, value);
LittleEndian::write_u32(
target_word,
value,
)
})),
}),
@ -376,7 +367,7 @@ impl<'a> Linker<'a> {
pc_relative: false,
relocate: Some(Box::new(|target_word, value| {
let old_value = LittleEndian::read_u32(target_word);
LittleEndian::write_u32(target_word, old_value.wrapping_add(value));
LittleEndian::write_u32(target_word, old_value.wrapping_add(value))
})),
}),
@ -386,7 +377,7 @@ impl<'a> Linker<'a> {
pc_relative: false,
relocate: Some(Box::new(|target_word, value| {
let old_value = LittleEndian::read_u32(target_word);
LittleEndian::write_u32(target_word, old_value.wrapping_sub(value));
LittleEndian::write_u32(target_word, old_value.wrapping_sub(value))
})),
}),
@ -395,7 +386,10 @@ impl<'a> Linker<'a> {
indirect_reloc: None,
pc_relative: false,
relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u16(target_word, value as u16);
LittleEndian::write_u16(
target_word,
value as u16,
)
})),
}),
@ -408,7 +402,7 @@ impl<'a> Linker<'a> {
LittleEndian::write_u16(
target_word,
old_value.wrapping_add(value as u16),
);
)
})),
}),
@ -421,7 +415,7 @@ impl<'a> Linker<'a> {
LittleEndian::write_u16(
target_word,
old_value.wrapping_sub(value as u16),
);
)
})),
}),
@ -503,7 +497,7 @@ impl<'a> Linker<'a> {
if let Some(relocate) = reloc_info.relocate {
let target_word = &mut target_sec_image[reloc.offset() as usize..];
relocate(target_word, value);
relocate(target_word, value)
} else {
self.rela_dyn_relas.push(Elf32_Rela {
r_offset: rela_off,
@ -551,18 +545,16 @@ impl<'a> Linker<'a> {
let eh_frame_slice = eh_frame_rec.data.as_slice();
// Prepare a new buffer to dodge borrow check
let mut eh_frame_hdr_vec: Vec<u8> = vec![0; eh_frame_hdr_rec.shdr.sh_size as usize];
let eh_frame = EH_Frame::new(eh_frame_slice, eh_frame_rec.shdr.sh_offset);
let eh_frame = EH_Frame::new(eh_frame_slice, eh_frame_rec.shdr.sh_offset)
.map_err(|()| "cannot read EH frame")?;
let mut eh_frame_hdr = EH_Frame_Hdr::new(
eh_frame_hdr_vec.as_mut_slice(),
eh_frame_hdr_rec.shdr.sh_offset,
eh_frame_rec.shdr.sh_offset,
);
eh_frame.cfi_records().flat_map(|cfi| cfi.fde_records()).for_each(&mut |(
init_pos,
virt_addr,
)| {
eh_frame_hdr.add_fde(init_pos, virt_addr);
});
eh_frame.cfi_records()
.flat_map(|cfi| cfi.fde_records())
.for_each(&mut |(init_pos, virt_addr)| eh_frame_hdr.add_fde(init_pos, virt_addr));
// Sort FDE entries in .eh_frame_hdr
eh_frame_hdr.finalize_fde();
@ -576,114 +568,39 @@ impl<'a> Linker<'a> {
}
pub fn ld(data: &'a [u8]) -> Result<Vec<u8>, Error> {
fn allocate_rela_dyn<R: Relocatable>(
linker: &Linker,
relocs: &[R],
) -> Result<(usize, Vec<u32>), Error> {
let mut alloc_size = 0;
let mut rela_dyn_sym_indices = Vec::new();
for reloc in relocs {
if reloc.sym_info() as usize == STN_UNDEF {
continue;
}
let sym: &Elf32_Sym = linker
.symtab
.get(reloc.sym_info() as usize)
.ok_or("symbol out of bounds of symbol table")?;
match (linker.isa, reloc.type_info()) {
// Absolute address relocations
// A runtime relocation is needed to find the loading address
(Isa::CortexA9, R_ARM_ABS32) | (Isa::RiscV32, R_RISCV_32) => {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// Relative address relocations
// Relay the relocation to the runtime linker only if the symbol is not defined
(Isa::CortexA9, R_ARM_REL32 | R_ARM_PREL31 | R_ARM_TARGET2)
| (
Isa::RiscV32,
R_RISCV_CALL_PLT | R_RISCV_PCREL_HI20 | R_RISCV_GOT_HI20 | R_RISCV_32_PCREL
| R_RISCV_SET32 | R_RISCV_ADD32 | R_RISCV_SUB32 | R_RISCV_SET16
| R_RISCV_ADD16 | R_RISCV_SUB16 | R_RISCV_SET8 | R_RISCV_ADD8
| R_RISCV_SUB8 | R_RISCV_SET6 | R_RISCV_SUB6,
) => {
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// RISC-V: Lower 12-bits relocations
// If the upper 20-bits relocation cannot be resolved,
// this relocation will be relayed to the runtime linker.
(Isa::RiscV32, R_RISCV_PCREL_LO12_I) => {
// Find the HI20 relocation
let indirect_reloc = relocs
.iter()
.find(|reloc| reloc.offset() == sym.st_value)
.ok_or("malformatted LO12 relocation")?;
let indirect_sym = linker.symtab[indirect_reloc.sym_info() as usize];
if ELF32_ST_BIND(indirect_sym.st_info) == STB_GLOBAL
&& indirect_sym.st_shndx == SHN_UNDEF
{
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
_ => {
println!("Relocation type 0x{:X?} is not supported", reloc.type_info());
unimplemented!()
}
}
}
Ok((alloc_size, rela_dyn_sym_indices))
}
let Some(ehdr) = read_unaligned::<Elf32_Ehdr>(data, 0) else {
Err("cannot read ELF header")?
};
let ehdr = read_unaligned::<Elf32_Ehdr>(data, 0).map_err(|()| "cannot read ELF header")?;
let isa = match ehdr.e_machine {
EM_ARM => Isa::CortexA9,
EM_RISCV => Isa::RiscV32,
_ => return Err(Error::Parsing("unsupported architecture")),
};
let Some(shdrs) =
get_ref_slice::<Elf32_Shdr>(data, ehdr.e_shoff as usize, ehdr.e_shnum as usize)
else {
Err("cannot read section header table")?
};
let shdrs = get_ref_slice::<Elf32_Shdr>(data, ehdr.e_shoff as usize, ehdr.e_shnum as usize)
.map_err(|()| "cannot read section header table")?;
// Read .strtab
let strtab_shdr = shdrs[ehdr.e_shstrndx as usize];
let Some(strtab) =
let strtab =
get_ref_slice::<u8>(data, strtab_shdr.sh_offset as usize, strtab_shdr.sh_size as usize)
else {
Err("cannot read the string table from data")?
};
.map_err(|()| "cannot read the string table from data")?;
// Read .symtab
let symtab_shdr = shdrs
.iter()
.find(|shdr| shdr.sh_type as usize == SHT_SYMTAB)
.ok_or(Error::Parsing("cannot find the symbol table"))?;
let Some(symtab) = get_ref_slice::<Elf32_Sym>(
let symtab = get_ref_slice::<Elf32_Sym>(
data,
symtab_shdr.sh_offset as usize,
symtab_shdr.sh_size as usize / mem::size_of::<Elf32_Sym>(),
) else {
Err("cannot read the symbol table from data")?
};
)
.map_err(|()| "cannot read the symbol table from data")?;
// Section table for the .elf paired with the section name
// To be formalized incrementally
// Very hashmap-like structure, but the order matters, so it is a vector
let elf_shdrs = vec![SectionRecord {
let elf_shdrs = vec![
SectionRecord {
shdr: Elf32_Shdr {
sh_name: 0,
sh_type: 0,
@ -698,7 +615,8 @@ impl<'a> Linker<'a> {
},
name: "",
data: vec![0; 0],
}];
},
];
let elf_sh_data_off = mem::size_of::<Elf32_Ehdr>() + mem::size_of::<Elf32_Phdr>() * 5;
// Image of the linked dynamic library, to be formalized incrementally
@ -834,27 +752,21 @@ impl<'a> Linker<'a> {
($shdr: expr, $stmt: expr) => {
match $shdr.sh_type as usize {
SHT_RELA => {
let Some(relocs) = get_ref_slice::<Elf32_Rela>(
let relocs = get_ref_slice::<Elf32_Rela>(
data,
$shdr.sh_offset as usize,
$shdr.sh_size as usize / mem::size_of::<Elf32_Rela>(),
) else {
Err("cannot parse relocations")?
};
#[allow(clippy::redundant_closure_call)]
)
.map_err(|()| "cannot parse relocations")?;
$stmt(relocs)
}
SHT_REL => {
let Some(relocs) = get_ref_slice::<Elf32_Rel>(
let relocs = get_ref_slice::<Elf32_Rel>(
data,
$shdr.sh_offset as usize,
$shdr.sh_size as usize / mem::size_of::<Elf32_Rel>(),
) else {
Err("cannot parse relocations")?
};
#[allow(clippy::redundant_closure_call)]
)
.map_err(|()| "cannot parse relocations")?;
$stmt(relocs)
}
_ => unreachable!(),
@ -862,6 +774,84 @@ impl<'a> Linker<'a> {
};
}
fn allocate_rela_dyn<R: Relocatable>(
linker: &Linker,
relocs: &[R],
) -> Result<(usize, Vec<u32>), Error> {
let mut alloc_size = 0;
let mut rela_dyn_sym_indices = Vec::new();
for reloc in relocs {
if reloc.sym_info() as usize == STN_UNDEF {
continue;
}
let sym: &Elf32_Sym = linker
.symtab
.get(reloc.sym_info() as usize)
.ok_or("symbol out of bounds of symbol table")?;
match (linker.isa, reloc.type_info()) {
// Absolute address relocations
// A runtime relocation is needed to find the loading address
(Isa::CortexA9, R_ARM_ABS32) | (Isa::RiscV32, R_RISCV_32) => {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// Relative address relocations
// Relay the relocation to the runtime linker only if the symbol is not defined
(Isa::CortexA9, R_ARM_REL32)
| (Isa::CortexA9, R_ARM_PREL31)
| (Isa::CortexA9, R_ARM_TARGET2)
| (Isa::RiscV32, R_RISCV_CALL_PLT)
| (Isa::RiscV32, R_RISCV_PCREL_HI20)
| (Isa::RiscV32, R_RISCV_GOT_HI20)
| (Isa::RiscV32, R_RISCV_32_PCREL)
| (Isa::RiscV32, R_RISCV_SET32)
| (Isa::RiscV32, R_RISCV_ADD32)
| (Isa::RiscV32, R_RISCV_SUB32)
| (Isa::RiscV32, R_RISCV_SET16)
| (Isa::RiscV32, R_RISCV_ADD16)
| (Isa::RiscV32, R_RISCV_SUB16)
| (Isa::RiscV32, R_RISCV_SET8)
| (Isa::RiscV32, R_RISCV_ADD8)
| (Isa::RiscV32, R_RISCV_SUB8)
| (Isa::RiscV32, R_RISCV_SET6)
| (Isa::RiscV32, R_RISCV_SUB6) => {
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// RISC-V: Lower 12-bits relocations
// If the upper 20-bits relocation cannot be resolved,
// this relocation will be relayed to the runtime linker.
(Isa::RiscV32, R_RISCV_PCREL_LO12_I) => {
// Find the HI20 relocation
let indirect_reloc = relocs
.iter()
.find(|reloc| reloc.offset() == sym.st_value)
.ok_or("malformatted LO12 relocation")?;
let indirect_sym = linker.symtab[indirect_reloc.sym_info() as usize];
if ELF32_ST_BIND(indirect_sym.st_info) == STB_GLOBAL
&& indirect_sym.st_shndx == SHN_UNDEF
{
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
_ => {
println!("Relocation type 0x{:X?} is not supported", reloc.type_info());
unimplemented!()
}
}
}
Ok((alloc_size, rela_dyn_sym_indices))
}
for shdr in shdrs
.iter()
.filter(|shdr| shdr.sh_type as usize == SHT_REL || shdr.sh_type as usize == SHT_RELA)
@ -889,7 +879,7 @@ impl<'a> Linker<'a> {
}
// Avoid symbol duplication
rela_dyn_sym_indices.sort_unstable();
rela_dyn_sym_indices.sort();
rela_dyn_sym_indices.dedup();
if rela_dyn_size != 0 {
@ -1020,9 +1010,7 @@ impl<'a> Linker<'a> {
let mut hash_bucket: Vec<u32> = vec![0; dynsym.len()];
let mut hash_chain: Vec<u32> = vec![0; dynsym.len()];
for (sym_index, (str_start, str_end)) in
dynsym_names.iter().enumerate().take(dynsym.len()).skip(1)
{
for (sym_index, (str_start, str_end)) in dynsym_names.iter().enumerate().take(dynsym.len()).skip(1) {
let hash = elf_hash(&dynstr[*str_start..*str_end]);
let mut hash_index = hash as usize % hash_bucket.len();
@ -1074,7 +1062,7 @@ impl<'a> Linker<'a> {
sh_entsize: mem::size_of::<Elf32_Sym>() as Elf32_Word,
},
".dynsym",
from_struct_slice(&dynsym),
from_struct_vec(dynsym),
);
let hash_elf_index = linker.load_section(
&Elf32_Shdr {
@ -1090,7 +1078,7 @@ impl<'a> Linker<'a> {
sh_entsize: 4,
},
".hash",
from_struct_slice(&hash),
from_struct_vec(hash),
);
// Link .rela.dyn header to the .dynsym header
@ -1189,7 +1177,7 @@ impl<'a> Linker<'a> {
};
let dynamic_elf_index =
linker.load_section(&dynamic_shdr, ".dynamic", from_struct_slice(&dyn_entries));
linker.load_section(&dynamic_shdr, ".dynamic", from_struct_vec(dyn_entries));
let last_w_sec_elf_index = linker.elf_shdrs.len() - 1;
@ -1265,9 +1253,7 @@ impl<'a> Linker<'a> {
update_dynsym_record!(b"__bss_start", bss_offset, bss_elf_index as Elf32_Section);
update_dynsym_record!(b"_end", bss_offset, bss_elf_index as Elf32_Section);
} else {
for (bss_iter_index, &(bss_section_index, section_name)) in
bss_index_vec.iter().enumerate()
{
for (bss_iter_index, &(bss_section_index, section_name)) in bss_index_vec.iter().enumerate() {
let shdr = &shdrs[bss_section_index];
let bss_elf_index = linker.load_section(
shdr,
@ -1340,7 +1326,7 @@ impl<'a> Linker<'a> {
// Prepare a STRTAB to hold the names of section headers
// Fix the sh_name field of the section headers
let mut shstrtab = Vec::new();
for shdr_rec in &mut linker.elf_shdrs {
for shdr_rec in linker.elf_shdrs.iter_mut() {
let shstrtab_index = shstrtab.len();
shstrtab.extend(shdr_rec.name.as_bytes());
shstrtab.push(0);
@ -1381,17 +1367,20 @@ impl<'a> Linker<'a> {
let alignment = (4 - (linker.image.len() % 4)) % 4;
let sec_headers_offset = linker.image.len() + alignment;
linker.image.extend(vec![0; alignment]);
for rec in &linker.elf_shdrs {
for rec in linker.elf_shdrs.iter() {
let shdr = rec.shdr;
linker.image.extend(unsafe {
slice::from_raw_parts(ptr::addr_of!(shdr).cast(), mem::size_of::<Elf32_Shdr>())
slice::from_raw_parts(
&shdr as *const Elf32_Shdr as *const u8,
mem::size_of::<Elf32_Shdr>(),
)
});
}
// Update the PHDRs
let phdr_offset = mem::size_of::<Elf32_Ehdr>();
unsafe {
let phdr_ptr = linker.image.as_mut_ptr().add(phdr_offset).cast();
let phdr_ptr = linker.image.as_mut_ptr().add(phdr_offset) as *mut Elf32_Phdr;
let phdr_slice = slice::from_raw_parts_mut(phdr_ptr, 5);
// List of program headers:
// 1. ELF headers & program headers
@ -1468,7 +1457,7 @@ impl<'a> Linker<'a> {
}
// Update the EHDR
let ehdr_ptr = linker.image.as_mut_ptr().cast();
let ehdr_ptr = linker.image.as_mut_ptr() as *mut Elf32_Ehdr;
unsafe {
*ehdr_ptr = Elf32_Ehdr {
e_ident: ehdr.e_ident,

View File

@ -8,15 +8,15 @@ license = "MIT"
edition = "2021"
[build-dependencies]
lalrpop = "0.22"
lalrpop = "0.20"
[dependencies]
nac3ast = { path = "../nac3ast" }
lalrpop-util = "0.22"
lalrpop-util = "0.20"
log = "0.4"
unic-emoji-char = "0.9"
unic-ucd-ident = "0.9"
unicode_names2 = "1.3"
unicode_names2 = "1.0"
phf = { version = "0.11", features = ["macros"] }
ahash = "0.8"

View File

@ -1,17 +1,15 @@
use crate::{
ast::{Ident, Location},
error::*,
token::Tok,
};
use lalrpop_util::ParseError;
use nac3ast::*;
use crate::ast::Ident;
use crate::ast::Location;
use crate::token::Tok;
use crate::error::*;
pub fn make_config_comment(
com_loc: Location,
stmt_loc: Location,
nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident>,
nac3com_end: Option<Ident>
) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> {
if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() {
return Err(ParseError::User {
@ -19,25 +17,24 @@ pub fn make_config_comment(
location: com_loc,
error: LexicalErrorType::OtherError(
format!(
"config comment at top must have the same indentation with what it applies (comment at {com_loc}, statement at {stmt_loc})",
"config comment at top must have the same indentation with what it applies (comment at {}, statement at {})",
com_loc,
stmt_loc,
)
)
}
});
})
};
Ok(nac3com_above
Ok(
nac3com_above
.into_iter()
.map(|(com, _)| com)
.chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter()))
.collect())
.collect()
)
}
pub fn handle_small_stmt<U>(
stmts: &mut [Stmt<U>],
nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident>,
com_above_loc: Location,
) -> Result<(), ParseError<Location, Tok, LexicalError>> {
pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, Tok)>, nac3com_end: Option<Ident>, com_above_loc: Location) -> Result<(), ParseError<Location, Tok, LexicalError>> {
if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() {
return Err(ParseError::User {
error: LexicalError {
@ -50,12 +47,17 @@ pub fn handle_small_stmt<U>(
)
)
}
});
})
}
apply_config_comments(&mut stmts[0], nac3com_above.into_iter().map(|(com, _)| com).collect());
apply_config_comments(
&mut stmts[0],
nac3com_above
.into_iter()
.map(|(com, _)| com).collect()
);
apply_config_comments(
stmts.last_mut().unwrap(),
nac3com_end.map_or_else(Vec::new, |com| vec![com]),
nac3com_end.map_or_else(Vec::new, |com| vec![com])
);
Ok(())
}
@ -78,8 +80,6 @@ fn apply_config_comments<U>(stmt: &mut Stmt<U>, comments: Vec<Ident>) {
| StmtKind::Nonlocal { config_comment, .. }
| StmtKind::Assert { config_comment, .. } => config_comment.extend(comments),
_ => {
unreachable!("only small statements should call this function")
}
_ => { unreachable!("only small statements should call this function") }
}
}

Some files were not shown because too many files have changed in this diff Show More