Compare commits

..

2 Commits

103 changed files with 5160 additions and 11721 deletions

View File

@ -1,32 +0,0 @@
BasedOnStyle: LLVM
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -1
AlignEscapedNewlines: Left
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: Yes
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Inline
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakInheritanceList: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ContinuationIndentWidth: 4
DerivePointerAlignment: false
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
MaxEmptyLinesToKeep: 1
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SortUsingDeclarations: true
SpaceAfterTemplateKeyword: false
SpacesBeforeTrailingComments: 2
TabWidth: 4
UseTab: Never

1
.gitignore vendored
View File

@ -1,4 +1,3 @@
__pycache__
/target
/nac3standalone/demo/linalg/target
nix/windows/msys2

View File

@ -8,17 +8,17 @@ repos:
hooks:
- id: nac3-cargo-fmt
name: nac3 cargo format
entry: nix
entry: cargo
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo fmt on the codebase.
args: [develop, -c, cargo, fmt, --all]
args: [fmt]
- id: nac3-cargo-clippy
name: nac3 cargo clippy
entry: nix
entry: cargo
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo clippy on the codebase.
args: [develop, -c, cargo, clippy, --tests]
args: [clippy, --tests]

507
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -2,11 +2,11 @@
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1731319897,
"narHash": "sha256-PbABj4tnbWFMfBp6OcUK5iGy1QY+/Z96ZcLpooIbuEI=",
"lastModified": 1718530797,
"narHash": "sha256-pup6cYwtgvzDpvpSCFh1TEUjw2zkNpk8iolbKnyFmmU=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "dc460ec76cbff0e66e269457d7b728432263166c",
"rev": "b60ebf54c15553b393d144357375ea956f89e9a9",
"type": "github"
},
"original": {

View File

@ -6,7 +6,6 @@
outputs = { self, nixpkgs }:
let
pkgs = import nixpkgs { system = "x86_64-linux"; };
pkgs32 = import nixpkgs { system = "i686-linux"; };
in rec {
packages.x86_64-linux = rec {
llvm-nac3 = pkgs.callPackage ./nix/llvm {};
@ -16,22 +15,6 @@
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
'';
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
name = "demo-linalg-stub";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
name = "demo-linalg-stub32";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
nac3artiq = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage rec {
name = "nac3artiq";
@ -41,7 +24,7 @@
lockFile = ./Cargo.lock;
};
passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase =
@ -49,9 +32,7 @@
echo "Checking nac3standalone demos..."
pushd nac3standalone/demo
patchShebangs .
export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
./check_demos.sh -i686
./check_demos.sh
popd
echo "Running Cargo tests..."
cargoCheckHook
@ -168,7 +149,7 @@
buildInputs = with pkgs; [
# build dependencies
packages.x86_64-linux.llvm-nac3
(pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos
llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt
cargo
rustc
@ -181,11 +162,6 @@
pre-commit
rustfmt
];
shellHook =
''
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
'';
};
devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2";

View File

@ -12,10 +12,15 @@ crate-type = ["cdylib"]
itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
parking_lot = "0.12"
tempfile = "3.13"
tempfile = "3.10"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" }
[dependencies.inkwell]
version = "0.4"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[features]
init-llvm-profile = []
no-escape-analysis = ["nac3core/no-escape-analysis"]

View File

@ -1,24 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

View File

@ -112,15 +112,10 @@ def extern(function):
register_function(function)
return function
def rpc(arg=None, flags={}):
"""Decorates a function or method to be executed on the host interpreter."""
if arg is None:
def inner_decorator(function):
return rpc(function, flags)
return inner_decorator
register_function(arg)
return arg
def rpc(function):
"""Decorates a function declaration defined by the core device runtime."""
register_function(function)
return function
def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device."""

View File

@ -1,26 +0,0 @@
from min_artiq import *
from numpy import ndarray, zeros as np_zeros
@nac3
class StrFail:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def hello(self, arg: str):
pass
@kernel
def consume_ndarray(self, arg: ndarray[str, 1]):
pass
def run(self):
self.hello("world")
self.consume_ndarray(np_zeros([10], dtype=str))
if __name__ == "__main__":
StrFail().run()

View File

@ -1,40 +0,0 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
attr2: int32 = 4
attr3: Kernel[int32]
@kernel
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
attr4: KernelInvariant[int32] = 16
def __init__(self):
self.core = Core()
@kernel
def run(self):
Demo.attr1 # Supported
# Demo.attr2 # Field not accessible on Kernel
# Demo.attr3 # Only attributes can be accessed in this way
# Demo.attr1 = 2 # Attributes are immutable
self.attr4 # Attributes can be accessed within class
obj = Demo()
obj.attr1 # Attributes can be accessed by class objects
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

File diff suppressed because it is too large Load Diff

View File

@ -2,9 +2,9 @@
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(rust_2024_compatibility)]
#![warn(clippy::pedantic)]
#![allow(
unsafe_op_in_unsafe_fn,
@ -16,65 +16,63 @@
clippy::wildcard_imports
)]
use std::{
collections::{HashMap, HashSet},
fs,
io::Write,
process::Command,
rc::Rc,
sync::Arc,
};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io::Write;
use std::process::Command;
use std::rc::Rc;
use std::sync::Arc;
use itertools::Itertools;
use parking_lot::{Mutex, RwLock};
use pyo3::{
create_exception, exceptions,
prelude::*,
types::{PyBytes, PyDict, PySet},
use inkwell::{
memory_buffer::MemoryBuffer,
module::{Linkage, Module},
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel,
};
use tempfile::{self, TempDir};
use itertools::Itertools;
use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{TypeEnum, Unifier, VarMap};
use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program,
};
use pyo3::create_exception;
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use parking_lot::{Mutex, RwLock};
use nac3core::{
codegen::{
concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, WithCall, WorkerRegistry,
},
inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::{FlagBehavior, Linkage, Module},
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel,
},
nac3parser::{
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
parser::parse_program,
},
codegen::irrt::load_irrt,
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
symbol_resolver::SymbolResolver,
toplevel::{
builtins::get_exn_constructor,
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef,
},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
typecheck::typedef::{FunSignature, FuncArg},
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
};
use nac3ld::Linker;
use codegen::{
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
use tempfile::{self, TempDir};
use crate::codegen::attributes_writeback;
use crate::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
};
use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver};
use timeline::TimeFns;
mod codegen;
mod symbol_resolver;
mod timeline;
use timeline::TimeFns;
#[derive(PartialEq, Clone, Copy)]
enum Isa {
Host,
@ -128,7 +126,7 @@ struct Nac3 {
isa: Isa,
time_fns: &'static (dyn TimeFns + Sync),
primitive: PrimitiveStore,
builtins: Vec<BuiltinFuncSpec>,
builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>,
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
primitive_ids: PrimitivePythonId,
working_directory: TempDir,
@ -195,8 +193,10 @@ impl Nac3 {
body.retain(|stmt| {
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) {
id == "kernel" || id == "portable" || id == "rpc"
if let ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "kernel"
|| id.to_string() == "portable"
|| id.to_string() == "rpc"
} else {
false
}
@ -209,8 +209,9 @@ impl Nac3 {
}
StmtKind::FunctionDef { ref decorator_list, .. } => {
decorator_list.iter().any(|decorator| {
if let Some(id) = decorator_id_string(decorator) {
id == "extern" || id == "kernel" || id == "portable" || id == "rpc"
if let ExprKind::Name { id, .. } = decorator.node {
let id = id.to_string();
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
} else {
false
}
@ -263,7 +264,7 @@ impl Nac3 {
arg_names.len(),
));
}
for (i, FuncArg { ty, default_value, name, .. }) in args.iter().enumerate() {
for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() {
let in_name = match arg_names.get(i) {
Some(n) => n,
None if default_value.is_none() => {
@ -299,64 +300,6 @@ impl Nac3 {
None
}
/// Returns a [`Vec`] of builtins that needs to be initialized during method compilation time.
fn get_lateinit_builtins() -> Vec<Box<BuiltinFuncCreator>> {
vec![
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"core_log".into(),
FunSignature {
args: vec![FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
}],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_core_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"rtio_log".into(),
FunSignature {
args: vec![
FuncArg {
name: "channel".into(),
ty: primitives.str,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
},
],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_rtio_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
]
}
fn compile_method<T>(
&self,
obj: &PyAny,
@ -369,7 +312,6 @@ impl Nac3 {
let size_t = self.isa.get_size_type();
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
self.builtins.clone(),
Self::get_lateinit_builtins(),
ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
size_t,
);
@ -446,6 +388,7 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
name_to_pyid: name_to_pyid.clone(),
module: module.clone(),
id_to_pyval: RwLock::default(),
@ -476,25 +419,9 @@ impl Nac3 {
match &stmt.node {
StmtKind::FunctionDef { decorator_list, .. } => {
if decorator_list
.iter()
.any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string()))
{
store_fun
.call1(
py,
(
def_id.0.into_py(py),
module.getattr(py, name.to_string().as_str()).unwrap(),
),
)
.unwrap();
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
rpc_ids.push((None, def_id, is_async));
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap();
rpc_ids.push((None, def_id));
}
}
StmtKind::ClassDef { name, body, .. } => {
@ -502,26 +429,19 @@ impl Nac3 {
let class_obj = module.getattr(py, class_name.as_str()).unwrap();
for stmt in body {
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
if decorator_list.iter().any(|decorator| {
decorator_id_string(decorator) == Some("rpc".to_string())
}) {
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) {
if name == &"__init__".into() {
return Err(CompileError::new_err(format!(
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
class_name, stmt.location
)));
}
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
}
}
}
}
_ => (),
_ => ()
}
let id = *name_to_pyid.get(&name).unwrap();
@ -560,6 +480,7 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
id_to_pyval: RwLock::default(),
id_to_primitive: RwLock::default(),
field_to_val: RwLock::default(),
@ -576,10 +497,6 @@ impl Nac3 {
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
.unwrap();
// Process IRRT
let context = Context::create();
let irrt = load_irrt(&context, resolver.as_ref());
let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
let mut store = ConcreteTypeStore::new();
@ -617,12 +534,13 @@ impl Nac3 {
let top_level = Arc::new(composer.make_top_level_context());
{
let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read();
for (class_data, id, is_async) in &rpc_ids {
for (class_data, id) in &rpc_ids {
let mut def = defs[id.0].write();
match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => {
*codegen_callback = Some(rpc_codegen_callback(*is_async));
*codegen_callback = Some(rpc_codegen.clone());
}
TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap();
@ -633,7 +551,7 @@ impl Nac3 {
if let TopLevelDef::Function { codegen_callback, .. } =
&mut *defs[id.0].write()
{
*codegen_callback = Some(rpc_codegen_callback(*is_async));
*codegen_callback = Some(rpc_codegen.clone());
store_fun
.call1(
py,
@ -648,11 +566,6 @@ impl Nac3 {
}
}
}
TopLevelDef::Variable { .. } => {
return Err(CompileError::new_err(String::from(
"Unsupported @rpc annotation on global variable",
)))
}
}
}
}
@ -673,12 +586,33 @@ impl Nac3 {
let task = CodeGenTask {
subst: Vec::default(),
symbol_name: "__modinit__".to_string(),
body: instance.body,
signature,
resolver: resolver.clone(),
store,
unifier_index: instance.unifier_id,
calls: instance.calls,
id: 0,
};
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature = store.from_signature(
&mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature = store.add_cty(signature);
let attributes_writeback_task = CodeGenTask {
subst: Vec::default(),
symbol_name: "attributes_writeback".to_string(),
body: Arc::new(Vec::default()),
signature,
resolver,
store,
unifier_index: instance.unifier_id,
calls: instance.calls,
calls: Arc::new(HashMap::default()),
id: 0,
};
@ -691,9 +625,7 @@ impl Nac3 {
let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer);
})));
let size_t = context
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let size_t = if self.isa == Isa::Host { 64 } else { 32 };
let num_threads = if is_multithreaded() { 4 } else { 1 };
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names
@ -702,27 +634,16 @@ impl Nac3 {
.collect();
let membuffer = membuffers.clone();
let mut has_return = false;
py.allow_threads(|| {
let (registry, handles) =
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns);
let context = Context::create();
let module = context.create_module("main");
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag(
"Debug Info Version",
FlagBehavior::Warning,
context.i32_type().const_int(3, false),
);
module.add_basic_value_flag(
"Dwarf Version",
FlagBehavior::Warning,
context.i32_type().const_int(4, false),
);
let mut generator =
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback");
let builder = context.create_builder();
let (_, module, _) = gen_func_impl(
&context,
@ -730,27 +651,9 @@ impl Nac3 {
&registry,
builder,
module,
task,
attributes_writeback_task,
|generator, ctx| {
assert_eq!(instance.body.len(), 1, "toplevel module should have 1 statement");
let StmtKind::Expr { value: ref expr, .. } = instance.body[0].node else {
unreachable!("toplevel statement must be an expression")
};
let ExprKind::Call { .. } = expr.node else {
unreachable!("toplevel expression must be a function call")
};
let return_obj =
generator.gen_expr(ctx, &expr)?.map(|value| (expr.custom.unwrap(), value));
has_return = return_obj.is_some();
registry.wait_tasks_complete(handles);
attributes_writeback(
ctx,
generator,
inner_resolver.as_ref(),
&host_attributes,
return_obj,
)
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes)
},
)
.unwrap();
@ -759,24 +662,37 @@ impl Nac3 {
membuffer.lock().push(buffer);
});
embedding_map.setattr("expects_return", has_return).unwrap();
// Link all modules into `main`.
let context = inkwell::context::Context::create();
let buffers = membuffers.lock();
let main = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(
&buffers.last().unwrap(),
"main",
))
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
.unwrap();
for buffer in buffers.iter().rev().skip(1) {
for buffer in buffers.iter().skip(1) {
let other = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
.unwrap();
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
}
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
let builder = context.create_builder();
let modinit_return = main
.get_function("__modinit__")
.unwrap()
.get_last_basic_block()
.unwrap()
.get_terminator()
.unwrap();
builder.position_before(&modinit_return);
builder
.build_call(
main.get_function("attributes_writeback").unwrap(),
&[],
"attributes_writeback",
)
.unwrap();
main.link_in_module(load_irrt(&context))
.map_err(|err| CompileError::new_err(err.to_string()))?;
let mut function_iter = main.get_first_function();
while let Some(func) = function_iter {
@ -862,41 +778,6 @@ impl Nac3 {
}
}
/// Retrieves the Name.id from a decorator, supports decorators with arguments.
fn decorator_id_string(decorator: &Located<ExprKind>) -> Option<String> {
if let ExprKind::Name { id, .. } = decorator.node {
// Bare decorator
return Some(id.to_string());
} else if let ExprKind::Call { func, .. } = &decorator.node {
// Decorators that are calls (e.g. "@rpc()") have Call for the node,
// need to extract the id from within.
if let ExprKind::Name { id, .. } = func.node {
return Some(id.to_string());
}
}
None
}
/// Retrieves flags from a decorator, if any.
fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
let mut flags = vec![];
if let ExprKind::Call { keywords, .. } = &decorator.node {
for keyword in keywords {
if keyword.node.arg != Some("flags".into()) {
continue;
}
if let ExprKind::Set { elts } = &keyword.node.value.node {
for elt in elts {
if let ExprKind::Constant { value, .. } = &elt.node {
flags.push(value.clone());
}
}
}
}
}
flags
}
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
let linker_args = vec![
"-shared".to_string(),
@ -966,7 +847,7 @@ impl Nac3 {
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
};
let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type());
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(isa.get_size_type()).0;
let builtins = vec![
(
"now_mu".into(),
@ -982,7 +863,6 @@ impl Nac3 {
name: "t".into(),
ty: primitive.int64,
default_value: None,
is_vararg: false,
}],
ret: primitive.none,
vars: VarMap::new(),
@ -1002,7 +882,6 @@ impl Nac3 {
name: "dt".into(),
ty: primitive.int64,
default_value: None,
is_vararg: false,
}],
ret: primitive.none,
vars: VarMap::new(),

View File

@ -1,30 +1,14 @@
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc,
},
use inkwell::{
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
};
use itertools::Itertools;
use parking_lot::RwLock;
use pyo3::{
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use nac3core::{
codegen::{
classes::{NDArrayType, ProxyType},
CodeGenContext, CodeGenerator,
},
inkwell::{
module::Linkage,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
},
nac3parser::ast::{self, StrRef},
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
@ -36,8 +20,21 @@ use nac3core::{
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap},
},
};
use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock};
use pyo3::{
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc,
},
};
use super::PrimitivePythonId;
use crate::PrimitivePythonId;
pub enum PrimitiveValue {
I32(i32),
@ -82,6 +79,7 @@ pub struct InnerResolver {
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>,
pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
pub class_names: Mutex<HashMap<StrRef, Type>>,
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
pub primitive_ids: PrimitivePythonId,
@ -135,8 +133,6 @@ impl StaticValue for PythonValue {
format!("{}_const", self.id).as_str(),
);
global.set_constant(true);
// Set linkage of global to private to avoid name collisions
global.set_linkage(Linkage::Private);
global.set_initializer(&ctx.ctx.const_struct(
&[ctx.ctx.i32_type().const_int(u64::from(id), false).into()],
false,
@ -167,7 +163,7 @@ impl StaticValue for PythonValue {
PrimitiveValue::Bool(val) => {
ctx.ctx.i8_type().const_int(u64::from(*val), false).into()
}
PrimitiveValue::Str(val) => ctx.gen_string(generator, val).into(),
PrimitiveValue::Str(val) => ctx.ctx.const_string(val.as_bytes(), true).into(),
});
}
if let Some(global) = ctx.module.get_global(&self.id.to_string()) {
@ -333,19 +329,8 @@ impl InnerResolver {
Ok(Ok((primitives.exception, true)))
} else if ty_id == self.primitive_ids.list {
// do not handle type var param and concrete check here
let list_tvar = if let TypeEnum::TObj { obj_id, params, .. } =
&*unifier.get_ty_immutable(primitives.list)
{
assert_eq!(*obj_id, PrimDef::List.id());
iter_type_vars(params).nth(0).unwrap()
} else {
unreachable!()
};
let var = unifier.get_dummy_var().ty;
let list = unifier
.subst(primitives.list, &into_var_map([TypeVar { id: list_tvar.id, ty: var }]))
.unwrap();
let list = unifier.add_ty(TypeEnum::TList { ty: var });
Ok(Ok((list, false)))
} else if ty_id == self.primitive_ids.ndarray {
// do not handle type var param and concrete check here
@ -355,7 +340,7 @@ impl InnerResolver {
Ok(Ok((ndarray, false)))
} else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![], is_vararg_ctx: false }), false)))
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false)))
} else if ty_id == self.primitive_ids.none {
@ -475,7 +460,7 @@ impl InnerResolver {
};
match &*unifier.get_ty(origin_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
TypeEnum::TList { .. } => {
if args.len() == 1 {
let ty = match self.get_pyty_obj_type(
py,
@ -492,21 +477,7 @@ impl InnerResolver {
"type list should take concrete parameters in typevar range".into(),
));
}
let list_tvar = if let TypeEnum::TObj { obj_id, params, .. } =
&*unifier.get_ty_immutable(primitives.list)
{
assert_eq!(*obj_id, PrimDef::List.id());
iter_type_vars(params).nth(0).unwrap()
} else {
unreachable!()
};
let list = unifier
.subst(
primitives.list,
&into_var_map([TypeVar { id: list_tvar.id, ty: ty.0 }]),
)
.unwrap();
Ok(Ok((list, true)))
Ok(Ok((unifier.add_ty(TypeEnum::TList { ty: ty.0 }), true)))
} else {
return Ok(Err(format!(
"type list needs exactly 1 type parameters, found {}",
@ -559,10 +530,7 @@ impl InnerResolver {
Err(err) => return Ok(Err(err)),
_ => return Ok(Err("tuple type needs at least 1 type parameters".to_string()))
};
Ok(Ok((
unifier.add_ty(TypeEnum::TTuple { ty: args, is_vararg_ctx: false }),
true,
)))
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: args }), true)))
}
TypeEnum::TObj { params, obj_id, .. } => {
let subst = {
@ -659,15 +627,12 @@ impl InnerResolver {
let pyid_to_def = self.pyid_to_def.read();
let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| {
defs.iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*rear_guard
if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() {
if object_id == def_id
&& constructor.is_some()
&& methods.iter().any(|(s, _, _)| s == &"__init__".into())
{
if object_id == def_id
&& constructor.is_some()
&& methods.iter().any(|(s, _, _)| s == &"__init__".into())
{
return *constructor;
}
return *constructor;
}
}
None
@ -699,38 +664,15 @@ impl InnerResolver {
primitives,
)? {
Ok(s) => s,
Err(e) => {
// Allow access to Class Attributes of Classes without having to initialize Objects
if self.pyid_to_def.read().contains_key(&py_obj_id) {
if let Some(def_id) = self.pyid_to_def.read().get(&py_obj_id).copied() {
let def = defs[def_id.0].read();
let TopLevelDef::Class { object_id, .. } = &*def else {
// only object is supported, functions are not supported
unreachable!("function type is not supported, should not be queried")
};
let ty = TypeEnum::TObj {
obj_id: *object_id,
params: VarMap::new(),
fields: HashMap::new(),
};
(unifier.add_ty(ty), true)
} else {
return Ok(Err(e));
}
} else {
return Ok(Err(e));
}
}
Err(e) => return Ok(Err(e)),
};
match (&*unifier.get_ty(extracted_ty), inst_check) {
// do the instantiation for these four types
(TypeEnum::TObj { obj_id, params, .. }, false) if *obj_id == PrimDef::List.id() => {
let ty = iter_type_vars(params).nth(0).unwrap().ty;
(TypeEnum::TList { ty }, false) => {
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
if len == 0 {
assert!(matches!(
&*unifier.get_ty(ty),
&*unifier.get_ty(*ty),
TypeEnum::TVar { fields: None, range, .. }
if range.is_empty()
));
@ -739,25 +681,8 @@ impl InnerResolver {
let actual_ty =
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
match actual_ty {
Ok(t) => match unifier.unify(ty, t) {
Ok(()) => {
let list_tvar = if let TypeEnum::TObj { obj_id, params, .. } =
&*unifier.get_ty_immutable(primitives.list)
{
assert_eq!(*obj_id, PrimDef::List.id());
iter_type_vars(params).nth(0).unwrap()
} else {
unreachable!()
};
let list = unifier
.subst(
primitives.list,
&into_var_map([TypeVar { id: list_tvar.id, ty }]),
)
.unwrap();
Ok(Ok(list))
}
Ok(t) => match unifier.unify(*ty, t) {
Ok(()) => Ok(Ok(unifier.add_ty(TypeEnum::TList { ty: *ty }))),
Err(e) => Ok(Err(format!(
"type error ({}) for the list",
e.to_display(unifier)
@ -804,9 +729,7 @@ impl InnerResolver {
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))
.collect();
let types = types?;
Ok(types.map(|types| {
unifier.add_ty(TypeEnum::TTuple { ty: types, is_vararg_ctx: false })
}))
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
}
// special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below
@ -981,7 +904,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
let val: String = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone()));
Ok(Some(ctx.gen_string(generator, val).into()))
Ok(Some(ctx.ctx.const_string(val.as_bytes(), true).into()))
} else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 {
let val: f64 = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val));
@ -994,21 +917,15 @@ impl InnerResolver {
}
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
let elem_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
iter_type_vars(params).nth(0).unwrap().ty
}
_ => unreachable!("must be list"),
};
let size_t = generator.get_size_type(ctx.ctx);
let ty = if len == 0
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
let elem_ty = if let TypeEnum::TList { ty } =
ctx.unifier.get_ty_immutable(expected_ty).as_ref()
{
// The default type for zero-length lists of unknown element type is size_t
size_t.into()
*ty
} else {
ctx.get_llvm_type(generator, elem_ty)
unreachable!("must be list")
};
let ty = ctx.get_llvm_type(generator, elem_ty);
let size_t = generator.get_size_type(ctx.ctx);
let arr_ty = ctx
.ctx
.struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false);
@ -1212,9 +1129,7 @@ impl InnerResolver {
Ok(Some(ndarray.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {
unreachable!()
};
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
let tup_tys = ty.iter();
let elements: &PyTuple = obj.downcast()?;
@ -1470,7 +1385,6 @@ impl SymbolResolver for Resolver {
&self,
id: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
let sym_value = {
let id_to_val = self.0.id_to_pyval.read();

View File

@ -1,12 +1,9 @@
use itertools::Either;
use nac3core::{
codegen::CodeGenContext,
inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
},
use inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
};
use itertools::Either;
use nac3core::codegen::CodeGenContext;
/// Functions for manipulating the timeline.
pub trait TimeFns {
@ -34,7 +31,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
@ -83,7 +80,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
@ -112,7 +109,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
@ -210,7 +207,7 @@ impl TimeFns for NowPinningTimeFns {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
@ -261,7 +258,7 @@ impl TimeFns for NowPinningTimeFns {
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_hiptr = ctx
.builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();

View File

@ -10,6 +10,7 @@ constant-optimization = ["fold"]
fold = []
[dependencies]
lazy_static = "1.4"
parking_lot = "0.12"
string-interner = "0.17"
fxhash = "0.2"

View File

@ -5,12 +5,14 @@ pub use crate::location::Location;
use fxhash::FxBuildHasher;
use parking_lot::{Mutex, MutexGuard};
use std::{cell::RefCell, collections::HashMap, fmt, sync::LazyLock};
use std::{cell::RefCell, collections::HashMap, fmt};
use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner};
pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>;
static INTERNER: LazyLock<Mutex<Interner>> =
LazyLock::new(|| Mutex::new(StringInterner::with_hasher(FxBuildHasher::default())));
lazy_static! {
static ref INTERNER: Mutex<Interner> =
Mutex::new(StringInterner::with_hasher(FxBuildHasher::default()));
}
thread_local! {
static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default();

View File

@ -2,9 +2,9 @@
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(rust_2024_compatibility)]
#![warn(clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
@ -14,6 +14,9 @@
clippy::wildcard_imports
)]
#[macro_use]
extern crate lazy_static;
mod ast_gen;
mod constant;
#[cfg(feature = "fold")]

View File

@ -4,23 +4,20 @@ version = "0.1.0"
authors = ["M-Labs"]
edition = "2021"
[features]
no-escape-analysis = []
[dependencies]
itertools = "0.13"
crossbeam = "0.8"
indexmap = "2.6"
indexmap = "2.2"
parking_lot = "0.12"
rayon = "1.10"
rayon = "1.8"
nac3parser = { path = "../nac3parser" }
strum = "0.26"
strum_macros = "0.26"
strum = "0.26.2"
strum_macros = "0.26.4"
[dependencies.inkwell]
version = "0.5"
version = "0.4"
default-features = false
features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[dev-dependencies]
test-case = "1.2.0"

View File

@ -1,3 +1,4 @@
use regex::Regex;
use std::{
env,
fs::File,
@ -6,53 +7,34 @@ use std::{
process::{Command, Stdio},
};
use regex::Regex;
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir);
let irrt_dir = Path::new("irrt");
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
const FILE: &str = "src/codegen/irrt/irrt.c";
/*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/
let mut flags: Vec<&str> = vec![
let flags: &[&str] = &[
"--target=wasm32",
"-x",
"c++",
"-std=c++20",
FILE,
"-fno-discard-value-names",
"-fno-exceptions",
"-fno-rtti",
match env::var("PROFILE").as_deref() {
Ok("debug") => "-O0",
Ok("release") => "-O3",
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
},
"-emit-llvm",
"-S",
"-Wall",
"-Wextra",
"-o",
"-",
"-I",
irrt_dir.to_str().unwrap(),
irrt_cpp_path.to_str().unwrap(),
];
match env::var("PROFILE").as_deref() {
Ok("debug") => {
flags.push("-O0");
flags.push("-DIRRT_DEBUG_ASSERT");
}
Ok("release") => {
flags.push("-O3");
}
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
}
println!("cargo:rerun-if-changed={FILE}");
let out_dir = env::var("OUT_DIR").unwrap();
let out_path = Path::new(&out_dir);
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
// Compile IRRT and capture the LLVM IR output
let output = Command::new("clang-irrt")
.args(flags)
.output()
@ -66,17 +48,7 @@ fn main() {
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len());
// Filter out irrelevant IR
//
// Regex:
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
// - `(?m:^@.+?=.+$)` captures global constants
let regex_filter = Regex::new(
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
)
.unwrap();
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap();
for f in regex_filter.captures_iter(&output) {
assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]);
@ -87,22 +59,18 @@ fn main() {
.unwrap()
.replace_all(&filtered_output, "");
// For debugging
// Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
if env::var(DEBUG_DUMP_IRRT).is_ok() {
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT");
if env::var("DEBUG_DUMP_IRRT").is_ok() {
let mut file = File::create(out_path.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap();
}
let mut llvm_as = Command::new("llvm-as-irrt")
.stdin(Stdio::piped())
.arg("-o")
.arg(out_dir.join("irrt.bc"))
.arg(out_path.join("irrt.bc"))
.spawn()
.unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();

View File

@ -1,6 +0,0 @@
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/list.hpp"
#include "irrt/math.hpp"
#include "irrt/ndarray.hpp"
#include "irrt/slice.hpp"

View File

@ -1,9 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
template<typename SizeT>
struct CSlice {
uint8_t* base;
SizeT len;
};

View File

@ -1,25 +0,0 @@
#pragma once
// Set in nac3core/build.rs
#ifdef IRRT_DEBUG_ASSERT
#define IRRT_DEBUG_ASSERT_BOOL true
#else
#define IRRT_DEBUG_ASSERT_BOOL false
#endif
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
#define debug_assert_eq(SizeT, lhs, rhs) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if ((lhs) != (rhs)) { \
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
} \
}
#define debug_assert(SizeT, expr) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if (!(expr)) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
} \
}

View File

@ -1,82 +0,0 @@
#pragma once
#include "irrt/cslice.hpp"
#include "irrt/int_types.hpp"
/**
* @brief The int type of ARTIQ exception IDs.
*/
typedef int32_t ExceptionId;
/*
* Set of exceptions C++ IRRT can use.
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
*/
extern "C" {
ExceptionId EXN_INDEX_ERROR;
ExceptionId EXN_VALUE_ERROR;
ExceptionId EXN_ASSERTION_ERROR;
ExceptionId EXN_TYPE_ERROR;
}
/**
* @brief Extern function to `__nac3_raise`
*
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/
extern "C" void __nac3_raise(void* err);
namespace {
/**
* @brief NAC3's Exception struct
*/
template<typename SizeT>
struct Exception {
ExceptionId id;
CSlice<SizeT> filename;
int32_t line;
int32_t column;
CSlice<SizeT> function;
CSlice<SizeT> msg;
int64_t params[3];
};
constexpr int64_t NO_PARAM = 0;
template<typename SizeT>
void _raise_exception_helper(ExceptionId id,
const char* filename,
int32_t line,
const char* function,
const char* msg,
int64_t param0,
int64_t param1,
int64_t param2) {
Exception<SizeT> e = {
.id = id,
.filename = {.base = reinterpret_cast<const uint8_t*>(filename), .len = __builtin_strlen(filename)},
.line = line,
.column = 0,
.function = {.base = reinterpret_cast<const uint8_t*>(function), .len = __builtin_strlen(function)},
.msg = {.base = reinterpret_cast<const uint8_t*>(msg), .len = __builtin_strlen(msg)},
};
e.params[0] = param0;
e.params[1] = param1;
e.params[2] = param2;
__nac3_raise(reinterpret_cast<void*>(&e));
__builtin_unreachable();
}
/**
* @brief Raise an exception with location details (location in the IRRT source files).
* @param SizeT The runtime `size_t` type.
* @param id The ID of the exception to raise.
* @param msg A global constant C-string of the error message.
*
* `param0` to `param2` are optional format arguments of `msg`. They should be set to
* `NO_PARAM` to indicate they are unused.
*/
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
} // namespace

View File

@ -1,22 +0,0 @@
#pragma once
#if __STDC_VERSION__ >= 202000
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#else
using int8_t = _ExtInt(8);
using uint8_t = unsigned _ExtInt(8);
using int32_t = _ExtInt(32);
using uint32_t = unsigned _ExtInt(32);
using int64_t = _ExtInt(64);
using uint64_t = unsigned _ExtInt(64);
#endif
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;

View File

@ -1,75 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/math_util.hpp"
extern "C" {
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
uint8_t* dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
uint8_t* src_arr,
SliceIndex src_arr_len,
const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0)
return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca = (dest_arr == src_arr)
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
} // extern "C"

View File

@ -1,93 +0,0 @@
#pragma once
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template<typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
} // namespace
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp); \
}
extern "C" {
// Putting semicolons here to make clang-format not reformat this into
// a stair shape.
DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t);
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
}

View File

@ -1,13 +0,0 @@
#pragma once
namespace {
template<typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template<typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
} // namespace

View File

@ -1,144 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
}

View File

@ -1,28 +0,0 @@
#pragma once
#include "irrt/int_types.hpp"
extern "C" {
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
}

View File

@ -1,102 +1,26 @@
use inkwell::{
types::BasicTypeEnum,
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
FloatPredicate, IntPredicate, OptimizationLevel,
};
use inkwell::types::BasicTypeEnum;
use inkwell::values::BasicValueEnum;
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools;
use super::{
classes::{
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
},
expr::destructure_range,
extern_fns, irrt,
irrt::calculate_len_for_slice_range,
llvm_intrinsics,
macros::codegen_unreachable,
numpy,
numpy::ndarray_elementwise_unaryop_impl,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use crate::{
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys},
typecheck::typedef::{Type, TypeEnum},
};
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::Type;
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
///
/// The generated message will contain the function name and the name of the unsupported type.
fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -> ! {
codegen_unreachable!(
ctx,
unreachable!(
"{fn_name}() not supported for '{}'",
tys.iter().map(|ty| format!("'{}'", ctx.unifier.stringify(*ty))).join(", "),
)
}
/// Invokes the `len` builtin function.
pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let range_ty = ctx.primitives.range;
let (arg_ty, arg) = n;
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
let (start, end, step) = destructure_range(ctx, arg);
calculate_len_for_slice_range(generator, ctx, start, end, step)
} else {
match &*ctx.unifier.get_ty_immutable(arg_ty) {
TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false),
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
let zero = llvm_i32.const_zero();
let len = ctx
.build_gep_and_load(
arg.into_pointer_value(),
&[zero, llvm_i32.const_int(1, false)],
None,
)
.into_int_value();
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_usize = generator.get_size_type(ctx.ctx);
let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "")
.unwrap(),
"0:TypeError",
"len() of unsized object",
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
}
_ => codegen_unreachable!(ctx),
}
})
}
/// Invokes the `int32` builtin function.
pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
@ -107,6 +31,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
@ -494,7 +419,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
llvm_intrinsics::call_float_rint(ctx, n, None).into()
llvm_intrinsics::call_float_roundeven(ctx, n, None).into()
}
BasicValueEnum::PointerValue(n)
@ -677,7 +602,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
ret_elem_ty,
None,
NDArrayValue::from_ptr_val(n, llvm_usize, None),
|generator, ctx, val| call_ceil(generator, ctx, (elem_ty, val), ret_elem_ty),
|generator, ctx, val| call_floor(generator, ctx, (elem_ty, val), ret_elem_ty),
)?;
ndarray.as_base_value().into()
@ -736,6 +661,90 @@ pub fn call_min<'ctx>(
}
}
/// Invokes the `np_min` builtin function.
pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_min";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.uint32,
ctx.primitives.int64,
ctx.primitives.uint64,
ctx.primitives.float,
]
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
a
}
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
.builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe {
let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
}
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
})
}
/// Invokes the `np_minimum` builtin function.
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
@ -794,7 +803,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
codegen_unreachable!(ctx)
unreachable!()
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -868,20 +877,18 @@ pub fn call_max<'ctx>(
}
}
/// Invokes the `np_max`, `np_min`, `np_argmax`, `np_argmin` functions
/// * `fn_name`: Can be one of `"np_argmin"`, `"np_argmax"`, `"np_max"`, `"np_min"`
pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
/// Invokes the `np_max` builtin function.
pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
fn_name: &str,
) -> Result<BasicValueEnum<'ctx>, String> {
debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name));
const FN_NAME: &str = "np_max";
let llvm_int64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
@ -895,12 +902,9 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
.iter()
.any(|ty| ctx.unifier.unioned(a_ty, *ty)));
match fn_name {
"np_argmin" | "np_argmax" => llvm_int64.const_zero().into(),
"np_max" | "np_min" => a,
_ => codegen_unreachable!(ctx),
}
a
}
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
@ -919,82 +923,41 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
generator,
n_sz_eqz,
"0:ValueError",
format!("zero-size array to reduction operation {fn_name}").as_str(),
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
let res_idx = generator.gen_var_alloc(ctx, llvm_int64.into(), None)?;
unsafe {
let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
ctx.builder.build_store(res_idx, llvm_int64.const_zero()).unwrap();
}
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_int64.const_int(1, false),
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
let result = match fn_name {
"np_argmin" | "np_min" => {
call_min(ctx, (elem_ty, accumulator), (elem_ty, elem))
}
"np_argmax" | "np_max" => {
call_max(ctx, (elem_ty, accumulator), (elem_ty, elem))
}
_ => codegen_unreachable!(ctx),
};
let updated_idx = match (accumulator, result) {
(BasicValueEnum::IntValue(m), BasicValueEnum::IntValue(n)) => ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::NE, m, n, "").unwrap(),
idx.into(),
cur_idx,
"",
)
.unwrap(),
(BasicValueEnum::FloatValue(m), BasicValueEnum::FloatValue(n)) => ctx
.builder
.build_select(
ctx.builder
.build_float_compare(FloatPredicate::ONE, m, n, "")
.unwrap(),
idx.into(),
cur_idx,
"",
)
.unwrap(),
_ => unsupported_type(ctx, fn_name, &[elem_ty, elem_ty]),
};
ctx.builder.build_store(res_idx, updated_idx).unwrap();
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(())
},
llvm_int64.const_int(1, false),
llvm_usize.const_int(1, false),
)?;
match fn_name {
"np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
"np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
_ => codegen_unreachable!(ctx),
}
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, fn_name, &[a_ty]),
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
})
}
@ -1056,7 +1019,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
codegen_unreachable!(ctx)
unreachable!()
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1086,9 +1049,9 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
/// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument.
/// * `fn_name`: The name of the function, only used when throwing an error with [`unsupported_type`]
/// * `get_ret_elem_type`: A function that takes in the input scalar [`Type`], and returns the function's return scalar [`Type`].
/// Return a constant [`Type`] here if the return type does not depend on the input type.
/// Return a constant [`Type`] here if the return type does not depend on the input type.
/// * `on_scalar`: The function that acts on the scalars of the input. Returns [`Option::None`]
/// if the scalar type & value are faulty and should panic with [`unsupported_type`].
/// if the scalar type & value are faulty and should panic with [`unsupported_type`].
fn helper_call_numpy_unary_elementwise<'ctx, OnScalarFn, RetElemFn, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -1199,9 +1162,9 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
/// * `$name:ident`: The identifier of the rust function to be generated.
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]
/// * `$get_ret_elem_type:expr`: To be passed to the `get_ret_elem_type` parameter of [`helper_call_numpy_unary_elementwise`].
/// But there is no need to make it a reference.
/// But there is no need to make it a reference.
/// * `$on_scalar:expr`: To be passed to the `on_scalar` parameter of [`helper_call_numpy_unary_elementwise`].
/// But there is no need to make it a reference.
/// But there is no need to make it a reference.
macro_rules! create_helper_call_numpy_unary_elementwise {
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_scalar:expr) => {
#[allow(clippy::redundant_closure_call)]
@ -1228,7 +1191,7 @@ macro_rules! create_helper_call_numpy_unary_elementwise {
/// * `$name:ident`: The identifier of the rust function to be generated.
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns
/// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`.
/// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`.
///
/// ```ignore
/// // Type of `$on_scalar:expr`
@ -1420,7 +1383,7 @@ create_helper_call_numpy_unary_elementwise_float_to_float!(
create_helper_call_numpy_unary_elementwise_float_to_float!(
call_numpy_rint,
"np_rint",
llvm_intrinsics::call_float_rint
llvm_intrinsics::call_float_roundeven
);
create_helper_call_numpy_unary_elementwise_float_to_float!(
@ -1496,7 +1459,7 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
codegen_unreachable!(ctx)
unreachable!()
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1563,7 +1526,7 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
codegen_unreachable!(ctx)
unreachable!()
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1630,7 +1593,7 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
codegen_unreachable!(ctx)
unreachable!()
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1697,7 +1660,7 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
codegen_unreachable!(ctx)
unreachable!()
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1820,7 +1783,7 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
codegen_unreachable!(ctx)
unreachable!()
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1887,7 +1850,7 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
} else if is_ndarray2 {
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
} else {
codegen_unreachable!(ctx)
unreachable!()
};
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
@ -1911,501 +1874,3 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]),
})
}
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
fn build_output_struct<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
out_matrices: Vec<BasicValueEnum<'ctx>>,
) -> PointerValue<'ctx> {
let field_ty =
out_matrices.iter().map(BasicValueEnum::get_type).collect::<Vec<BasicTypeEnum>>();
let out_ty = ctx.ctx.struct_type(&field_ty, false);
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
for (i, v) in out_matrices.into_iter().enumerate() {
unsafe {
let ptr = ctx
.builder
.build_in_bounds_gep(
out_ptr,
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(i as u64, false),
],
"",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
}
}
out_ptr
}
/// Invokes the `np_linalg_cholesky` linalg function
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_qr` linalg function
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_svd` linalg function
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_inv` linalg function
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_pinv` linalg function
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_lu` linalg function
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `np_linalg_matrix_power` linalg function
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
let llvm_usize = generator.get_size_type(ctx.ctx);
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
// Changing second parameter to a `NDArray` for uniformity in function call
let n2_array = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
unsafe {
n2_array.data().set_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
n2.as_basic_value_enum(),
);
};
let n2_array = n2_array.as_base_value().as_basic_value_enum();
let outdim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let outdim1 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
}
}
/// Invokes the `np_linalg_det` linalg function
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(_) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
// Changing second parameter to a `NDArray` for uniformity in function call
let out = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
let res =
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(res)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_schur` linalg function
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_schur";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}
/// Invokes the `sp_linalg_hessenberg` linalg function
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let dim0 = unsafe {
n1.dim_sizes()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
Ok(ctx
.builder
.build_load(out_ptr, "Hessenberg_decomposition_result")
.map(Into::into)
.unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
}
}

View File

@ -1,16 +1,17 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, ArrayType, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
values::{ArrayValue, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, IntPredicate,
};
use super::{
use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use inkwell::context::Context;
use inkwell::types::{ArrayType, BasicType, StructType};
use inkwell::values::{ArrayValue, BasicValue, StructValue};
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
/// A LLVM type that is used to represent a non-primitive type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> {
@ -712,25 +713,12 @@ impl<'ctx> ListValue<'ctx> {
/// If `size` is [None], the size stored in the field of this instance is used instead.
pub fn create_data(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
ctx: &CodeGenContext<'ctx, '_>,
elem_ty: BasicTypeEnum<'ctx>,
size: Option<IntValue<'ctx>>,
) {
let size = size.unwrap_or_else(|| self.load_size(ctx, None));
let data = ctx
.builder
.build_select(
ctx.builder
.build_int_compare(IntPredicate::NE, size, self.llvm_usize.const_zero(), "")
.unwrap(),
ctx.builder.build_array_alloca(elem_ty, size, "").unwrap(),
elem_ty.ptr_type(AddressSpace::default()).const_zero(),
"",
)
.map(BasicValueEnum::into_pointer_value)
.unwrap();
self.store_data(ctx, data);
self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "").unwrap());
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
@ -1249,13 +1237,11 @@ impl<'ctx> NDArrayType<'ctx> {
/// Returns the element type of this `ndarray` type.
#[must_use]
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(2)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap()
}
}
@ -1405,7 +1391,7 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
@ -1718,7 +1704,6 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {

View File

@ -1,9 +1,3 @@
use std::collections::HashMap;
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use crate::{
symbol_resolver::SymbolValue,
toplevel::DefinitionId,
@ -15,6 +9,10 @@ use crate::{
},
};
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use std::collections::HashMap;
pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>,
}
@ -27,7 +25,6 @@ pub struct ConcreteFuncArg {
pub name: StrRef,
pub ty: ConcreteType,
pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
}
#[derive(Clone, Debug)]
@ -49,7 +46,9 @@ pub enum ConcreteTypeEnum {
TPrimitive(Primitive),
TTuple {
ty: Vec<ConcreteType>,
is_vararg_ctx: bool,
},
TList {
ty: ConcreteType,
},
TObj {
obj_id: DefinitionId,
@ -106,16 +105,8 @@ impl ConcreteTypeStore {
.iter()
.map(|arg| ConcreteFuncArg {
name: arg.name,
ty: if arg.is_vararg {
let tuple_ty = unifier
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
} else {
self.from_unifier_type(unifier, primitives, arg.ty, cache)
},
ty: self.from_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
})
.collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
@ -170,12 +161,14 @@ impl ConcreteTypeStore {
cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum {
TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple {
TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple {
ty: ty
.iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(),
is_vararg_ctx: *is_vararg_ctx,
},
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
},
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id,
@ -261,13 +254,15 @@ impl ConcreteTypeStore {
*cache.get_mut(&cty).unwrap() = Some(ty);
return ty;
}
ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple {
ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple {
ty: ty
.iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(),
is_vararg_ctx: *is_vararg_ctx,
},
ConcreteTypeEnum::TList { ty } => {
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
@ -291,7 +286,6 @@ impl ConcreteTypeStore {
name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(),
is_vararg: false,
})
.collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache),

File diff suppressed because it is too large Load Diff

View File

@ -1,102 +1,517 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
};
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
use itertools::Either;
use super::CodeGenContext;
use crate::codegen::CodeGenContext;
/// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue`
///
/// Arguments:
/// * `unary/binary`: Whether the extern function requires one (unary) or two (binary) operands
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
///
/// Optional Arguments:
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function.
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly".
/// These will be used unless other attributes are specified
/// * `$(,$args:ident)*`: Operands of the extern function
/// The data type of these operands will be set to `FloatValue`
///
macro_rules! generate_extern_fn {
("unary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("unary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg $(,$attributes)*);
};
("binary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("binary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2 $(,$attributes)*);
};
($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => {
#[doc = concat!("Invokes the [`", stringify!($extern_fn), "`](https://en.cppreference.com/w/c/numeric/math/", stringify!($llvm_name), ") function." )]
pub fn $fn_name<'ctx>(
ctx: &CodeGenContext<'ctx, '_>
$(,$args: FloatValue<'ctx>)*,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = $extern_fn;
/// Invokes the [`tan`](https://en.cppreference.com/w/c/numeric/math/tan) function.
pub fn call_tan<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "tan";
let llvm_f64 = ctx.ctx.f64_type();
$(debug_assert_eq!($args.get_type(), llvm_f64);)*
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in [$($attributes),*] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
};
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
generate_extern_fn!("unary", call_tan, "tan");
generate_extern_fn!("unary", call_asin, "asin");
generate_extern_fn!("unary", call_acos, "acos");
generate_extern_fn!("unary", call_atan, "atan");
generate_extern_fn!("unary", call_sinh, "sinh");
generate_extern_fn!("unary", call_cosh, "cosh");
generate_extern_fn!("unary", call_tanh, "tanh");
generate_extern_fn!("unary", call_asinh, "asinh");
generate_extern_fn!("unary", call_acosh, "acosh");
generate_extern_fn!("unary", call_atanh, "atanh");
generate_extern_fn!("unary", call_expm1, "expm1");
generate_extern_fn!(
"unary",
call_cbrt,
"cbrt",
"mustprogress",
"nofree",
"nosync",
"nounwind",
"readonly",
"willreturn"
);
generate_extern_fn!("unary", call_erf, "erf", "nounwind");
generate_extern_fn!("unary", call_erfc, "erfc", "nounwind");
generate_extern_fn!("unary", call_j1, "j1", "nounwind");
/// Invokes the [`asin`](https://en.cppreference.com/w/c/numeric/math/asin) function.
pub fn call_asin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "asin";
generate_extern_fn!("binary", call_atan2, "atan2");
generate_extern_fn!("binary", call_hypot, "hypot", "nounwind");
generate_extern_fn!("binary", call_nextafter, "nextafter", "nounwind");
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`acos`](https://en.cppreference.com/w/c/numeric/math/acos) function.
pub fn call_acos<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "acos";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`atan`](https://en.cppreference.com/w/c/numeric/math/atan) function.
pub fn call_atan<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "atan";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`sinh`](https://en.cppreference.com/w/c/numeric/math/sinh) function.
pub fn call_sinh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "sinh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`cosh`](https://en.cppreference.com/w/c/numeric/math/cosh) function.
pub fn call_cosh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "cosh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`tanh`](https://en.cppreference.com/w/c/numeric/math/tanh) function.
pub fn call_tanh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "tanh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`asinh`](https://en.cppreference.com/w/c/numeric/math/asinh) function.
pub fn call_asinh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "asinh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`acosh`](https://en.cppreference.com/w/c/numeric/math/acosh) function.
pub fn call_acosh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "acosh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`atanh`](https://en.cppreference.com/w/c/numeric/math/atanh) function.
pub fn call_atanh<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "atanh";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`expm1`](https://en.cppreference.com/w/c/numeric/math/expm1) function.
pub fn call_expm1<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "expm1";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`cbrt`](https://en.cppreference.com/w/c/numeric/math/cbrt) function.
pub fn call_cbrt<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "cbrt";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nosync", "nounwind", "readonly", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`erf`](https://en.cppreference.com/w/c/numeric/math/erf) function.
pub fn call_erf<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "erf";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`erfc`](https://en.cppreference.com/w/c/numeric/math/erfc) function.
pub fn call_erfc<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "erfc";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`j1`](https://www.gnu.org/software/libc/manual/html_node/Special-Functions.html#index-j1)
/// function.
pub fn call_j1<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "j1";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
});
ctx.builder
.build_call(extern_fn, &[arg.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`atan2`](https://en.cppreference.com/w/c/numeric/math/atan2) function.
pub fn call_atan2<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
y: FloatValue<'ctx>,
x: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "atan2";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(y.get_type(), llvm_f64);
debug_assert_eq!(x.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[y.into(), x.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
pub fn call_ldexp<'ctx>(
@ -133,61 +548,66 @@ pub fn call_ldexp<'ctx>(
.unwrap()
}
/// Macro to generate `np_linalg` and `sp_linalg` functions
/// The function takes as input `NDArray` and returns ()
///
/// Arguments:
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
/// * (2/3/4): Number of `NDArray` that function takes as input
///
/// Note:
/// The operands and resulting `NDArray` are both passed as input to the funcion
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
/// The function changes the content of the output `NDArray` in-place
macro_rules! generate_linalg_extern_fn {
($fn_name:ident, $extern_fn:literal, 2) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
};
($fn_name:ident, $extern_fn:literal, 3) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $extern_fn:literal, 4) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
){
const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
/// Invokes the [`hypot`](https://en.cppreference.com/w/c/numeric/math/hypot) function.
pub fn call_hypot<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
x: FloatValue<'ctx>,
y: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "hypot";
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(x.get_type(), llvm_f64);
debug_assert_eq!(y.get_type(), llvm_f64);
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
}
};
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
});
ctx.builder
.build_call(extern_fn, &[x.into(), y.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
/// Invokes the [`nextafter`](https://en.cppreference.com/w/c/numeric/math/nextafter) function.
pub fn call_nextafter<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
from: FloatValue<'ctx>,
to: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "nextafter";
let llvm_f64 = ctx.ctx.f64_type();
debug_assert_eq!(from.get_type(), llvm_f64);
debug_assert_eq!(to.get_type(), llvm_f64);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_f64.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
});
ctx.builder
.build_call(extern_fn, &[from.into(), to.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,18 +1,16 @@
use crate::{
codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
use inkwell::{
context::Context,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
};
use nac3parser::ast::{Expr, Stmt, StrRef};
use super::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
pub trait CodeGenerator {
/// Return the module name for the code generator.
fn get_name(&self) -> &str;
@ -59,7 +57,6 @@ pub trait CodeGenerator {
/// - fun: Function signature, definition ID and the substitution key.
/// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method.
///
/// Note that this function should check if the function is generated in another thread (due to
/// possible race condition), see the default implementation for an example.
fn gen_func_instance<'ctx>(
@ -126,45 +123,11 @@ pub trait CodeGenerator {
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign(self, ctx, target, value, value_ty)
}
/// Generate code for an assignment expression where LHS is a `"target_list"`.
///
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
fn gen_assign_target_list<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign_target_list(self, ctx, targets, value, value_ty)
}
/// Generate code for an item assignment.
///
/// i.e., `target[key] = value`
fn gen_setitem<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_setitem(self, ctx, target, key, value, value_ty)
gen_assign(self, ctx, target, value)
}
/// Generate code for a while expression.

View File

@ -0,0 +1,389 @@
typedef _BitInt(8) int8_t;
typedef unsigned _BitInt(8) uint8_t;
typedef _BitInt(32) int32_t;
typedef unsigned _BitInt(32) uint32_t;
typedef _BitInt(64) int64_t;
typedef unsigned _BitInt(64) uint64_t;
# define MAX(a, b) (a > b ? a : b)
# define MIN(a, b) (a > b ? b : a)
# define NULL ((void *) 0)
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
T base, \
T exp \
) { \
T res = (T)1; \
/* repeated squaring method */ \
do { \
if (exp & 1) res *= base; /* for n odd */ \
exp >>= 1; \
base *= base; \
} while (exp); \
return res; \
} \
DEF_INT_EXP(int32_t)
DEF_INT_EXP(int64_t)
DEF_INT_EXP(uint32_t)
DEF_INT_EXP(uint64_t)
int32_t __nac3_slice_index_bound(int32_t i, const int32_t len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
int32_t __nac3_range_slice_len(const int32_t start, const int32_t end, const int32_t step) {
int32_t diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
int32_t __nac3_list_slice_assign_var_size(
int32_t dest_start,
int32_t dest_end,
int32_t dest_step,
uint8_t *dest_arr,
int32_t dest_arr_len,
int32_t src_start,
int32_t src_end,
int32_t src_step,
uint8_t *src_arr,
int32_t src_arr_len,
const int32_t size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const int32_t src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const int32_t dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
MAX(dest_start, dest_end) < MIN(src_start, src_end)
|| MAX(src_start, src_end) < MIN(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
int32_t src_ind = src_start;
int32_t dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
uint32_t __nac3_ndarray_calc_size(
const uint64_t *list_data,
uint32_t list_len,
uint32_t begin_idx,
uint32_t end_idx
) {
__builtin_assume(end_idx <= list_len);
uint32_t num_elems = 1;
for (uint32_t i = begin_idx; i < end_idx; ++i) {
uint64_t val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
uint64_t __nac3_ndarray_calc_size64(
const uint64_t *list_data,
uint64_t list_len,
uint64_t begin_idx,
uint64_t end_idx
) {
__builtin_assume(end_idx <= list_len);
uint64_t num_elems = 1;
for (uint64_t i = begin_idx; i < end_idx; ++i) {
uint64_t val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
void __nac3_ndarray_calc_nd_indices(
uint32_t index,
const uint32_t* dims,
uint32_t num_dims,
uint32_t* idxs
) {
uint32_t stride = 1;
for (uint32_t dim = 0; dim < num_dims; dim++) {
uint32_t i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
void __nac3_ndarray_calc_nd_indices64(
uint64_t index,
const uint64_t* dims,
uint64_t num_dims,
uint32_t* idxs
) {
uint64_t stride = 1;
for (uint64_t dim = 0; dim < num_dims; dim++) {
uint64_t i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (uint32_t) ((index / stride) % dims[i]);
stride *= dims[i];
}
}
uint32_t __nac3_ndarray_flatten_index(
const uint32_t* dims,
uint32_t num_dims,
const uint32_t* indices,
uint32_t num_indices
) {
uint32_t idx = 0;
uint32_t stride = 1;
for (uint32_t i = 0; i < num_dims; ++i) {
uint32_t ri = num_dims - i - 1;
if (ri < num_indices) {
idx += (stride * indices[ri]);
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
uint64_t __nac3_ndarray_flatten_index64(
const uint64_t* dims,
uint64_t num_dims,
const uint32_t* indices,
uint64_t num_indices
) {
uint64_t idx = 0;
uint64_t stride = 1;
for (uint64_t i = 0; i < num_dims; ++i) {
uint64_t ri = num_dims - i - 1;
if (ri < num_indices) {
idx += (stride * indices[ri]);
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
void __nac3_ndarray_calc_broadcast(
const uint32_t *lhs_dims,
uint32_t lhs_ndims,
const uint32_t *rhs_dims,
uint32_t rhs_ndims,
uint32_t *out_dims
) {
uint32_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (uint32_t i = 0; i < max_ndims; ++i) {
uint32_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
uint32_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
uint32_t *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == NULL) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == NULL) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
void __nac3_ndarray_calc_broadcast64(
const uint64_t *lhs_dims,
uint64_t lhs_ndims,
const uint64_t *rhs_dims,
uint64_t rhs_ndims,
uint64_t *out_dims
) {
uint64_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (uint64_t i = 0; i < max_ndims; ++i) {
uint64_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
uint64_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
uint64_t *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == NULL) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == NULL) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
void __nac3_ndarray_calc_broadcast_idx(
const uint32_t *src_dims,
uint32_t src_ndims,
const uint32_t *in_idx,
uint32_t *out_idx
) {
for (uint32_t i = 0; i < src_ndims; ++i) {
uint32_t src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
void __nac3_ndarray_calc_broadcast_idx64(
const uint64_t *src_dims,
uint64_t src_ndims,
const uint32_t *in_idx,
uint32_t *out_idx
) {
for (uint64_t i = 0; i < src_ndims; ++i) {
uint64_t src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : (uint32_t) in_idx[src_i];
}
}

View File

@ -1,30 +1,28 @@
use crate::typecheck::typedef::Type;
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics, CodeGenContext, CodeGenerator,
};
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
use inkwell::{
attributes::{Attribute, AttributeLoc},
context::Context,
memory_buffer::MemoryBuffer,
module::Module,
types::{BasicTypeEnum, IntType},
values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use nac3parser::ast::Expr;
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
#[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
pub fn load_irrt(ctx: &Context) -> Module {
let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
"irrt_bitcode_buffer",
@ -40,25 +38,6 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
let function = irrt_mod.get_function(symbol).unwrap();
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
}
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
}
irrt_mod
}
@ -76,7 +55,7 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
(64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t",
_ => codegen_unreachable!(ctx),
_ => unreachable!(),
};
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
@ -462,7 +441,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => codegen_unreachable!(ctx),
_ => unreachable!(),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
}
@ -589,8 +568,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension
/// respectively.
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
@ -601,16 +579,18 @@ where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
&[llvm_pi64.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
@ -642,7 +622,7 @@ where
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -658,7 +638,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
@ -727,7 +707,7 @@ where
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
@ -766,7 +746,7 @@ where
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
@ -795,7 +775,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
@ -820,7 +800,6 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
@ -915,7 +894,7 @@ pub fn call_ndarray_calc_broadcast_index<
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {

View File

@ -1,14 +1,12 @@
use inkwell::{
context::Context,
intrinsics::Intrinsic,
types::{AnyTypeEnum::IntType, FloatType},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
AddressSpace,
};
use crate::codegen::CodeGenContext;
use inkwell::context::Context;
use inkwell::intrinsics::Intrinsic;
use inkwell::types::AnyTypeEnum::IntType;
use inkwell::types::FloatType;
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use inkwell::AddressSpace;
use itertools::Either;
use super::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
@ -37,40 +35,6 @@ fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
unreachable!()
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_start";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_end";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic.
pub fn call_stacksave<'ctx>(
@ -98,30 +62,145 @@ pub fn call_stacksave<'ctx>(
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.stackrestore";
/*
SEE https://github.com/TheDan64/inkwell/issues/496
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
We want `llvm.stackrestore`, but the following would generate `llvm.stackrestore.p0i8`.
```ignore
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
.unwrap();
```
Temp workaround by manually declaring the intrinsic with the correct function name instead.
*/
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
.unwrap();
ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap();
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
pub fn call_int_abs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
const FN_NAME: &str = "llvm.abs";
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
let llvm_src_t = src.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_src_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[src.into(), is_int_min_poison.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.smax`](https://llvm.org/docs/LangRef.html#llvm-smax-intrinsic) intrinsic.
pub fn call_int_smax<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
const FN_NAME: &str = "llvm.smax";
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.smin`](https://llvm.org/docs/LangRef.html#llvm-smin-intrinsic) intrinsic.
pub fn call_int_smin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
const FN_NAME: &str = "llvm.smin";
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.umax`](https://llvm.org/docs/LangRef.html#llvm-umax-intrinsic) intrinsic.
pub fn call_int_umax<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
const FN_NAME: &str = "llvm.umax";
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.umin`](https://llvm.org/docs/LangRef.html#llvm-umin-intrinsic) intrinsic.
pub fn call_int_umin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
a: IntValue<'ctx>,
b: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
const FN_NAME: &str = "llvm.umin";
debug_assert_eq!(a.get_type().get_bit_width(), b.get_type().get_bit_width());
let llvm_int_t = a.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[a.into(), b.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic) intrinsic.
///
/// * `dest` - The pointer to the destination. Must be a pointer to an integer type.
@ -185,7 +264,7 @@ pub fn call_memcpy_generic<'ctx>(
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.build_bitcast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
@ -193,7 +272,7 @@ pub fn call_memcpy_generic<'ctx>(
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.build_bitcast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
@ -201,123 +280,28 @@ pub fn call_memcpy_generic<'ctx>(
call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
///
/// Arguments:
/// * `$ctx:ident`: Reference to the current Code Generation Context
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type).
/// Use `BasicValueEnum::into_int_value` for Integer return type and
/// `BasicValueEnum::into_float_value` for Float return type
/// * `$llvm_ty:ident`: Type of first operand
/// * `,($val:ident)*`: Comma separated list of operands
macro_rules! generate_llvm_intrinsic_fn_body {
($ctx:ident, $name:ident, $llvm_name:literal, $map_fn:expr, $llvm_ty:ident $(,$val:ident)*) => {{
const FN_NAME: &str = concat!("llvm.", $llvm_name);
let intrinsic_fn = Intrinsic::find(FN_NAME).and_then(|intrinsic| intrinsic.get_declaration(&$ctx.module, &[$llvm_ty.into()])).unwrap();
$ctx.builder.build_call(intrinsic_fn, &[$($val.into()),*], $name.unwrap_or_default()).map(CallSiteValue::try_as_basic_value).map(|v| v.map_left($map_fn)).map(Either::unwrap_left).unwrap()
}};
}
/// Macro to generate the llvm intrinsic function using [`generate_llvm_intrinsic_fn_body`].
///
/// Arguments:
/// * `float/int`: Indicates the return and argument type of the function
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function.
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
/// * `$val:ident`: The operand for unary operations
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations
macro_rules! generate_llvm_intrinsic_fn {
("float", $fn_name:ident, $llvm_name:literal, $val:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_ty = $val.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val)
}
};
("float", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: FloatValue<'ctx>,
$val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!($val1.get_type(), $val2.get_type());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val1, $val2)
}
};
("int", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: IntValue<'ctx>,
$val2: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!($val1.get_type().get_bit_width(), $val2.get_type().get_bit_width());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_int_value, llvm_ty, $val1, $val2)
}
};
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
pub fn call_int_abs<'ctx>(
/// Invokes the [`llvm.sqrt`](https://llvm.org/docs/LangRef.html#llvm-sqrt-intrinsic) intrinsic.
pub fn call_float_sqrt<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.sqrt";
let src_type = src.get_type();
generate_llvm_intrinsic_fn_body!(
ctx,
name,
"abs",
BasicValueEnum::into_int_value,
src_type,
src,
is_int_min_poison
)
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
generate_llvm_intrinsic_fn!("int", call_int_smax, "smax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_smin, "smin", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umax, "umax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umin, "umin", a, b);
generate_llvm_intrinsic_fn!("int", call_expect, "expect", val, expected_val);
generate_llvm_intrinsic_fn!("float", call_float_sqrt, "sqrt", val);
generate_llvm_intrinsic_fn!("float", call_float_sin, "sin", val);
generate_llvm_intrinsic_fn!("float", call_float_cos, "cos", val);
generate_llvm_intrinsic_fn!("float", call_float_pow, "pow", val, power);
generate_llvm_intrinsic_fn!("float", call_float_exp, "exp", val);
generate_llvm_intrinsic_fn!("float", call_float_exp2, "exp2", val);
generate_llvm_intrinsic_fn!("float", call_float_log, "log", val);
generate_llvm_intrinsic_fn!("float", call_float_log10, "log10", val);
generate_llvm_intrinsic_fn!("float", call_float_log2, "log2", val);
generate_llvm_intrinsic_fn!("float", call_float_fabs, "fabs", src);
generate_llvm_intrinsic_fn!("float", call_float_minnum, "minnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_maxnum, "maxnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_copysign, "copysign", mag, sgn);
generate_llvm_intrinsic_fn!("float", call_float_floor, "floor", val);
generate_llvm_intrinsic_fn!("float", call_float_ceil, "ceil", val);
generate_llvm_intrinsic_fn!("float", call_float_round, "round", val);
generate_llvm_intrinsic_fn!("float", call_float_rint, "rint", val);
/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic.
pub fn call_float_powi<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
@ -343,3 +327,393 @@ pub fn call_float_powi<'ctx>(
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.sin`](https://llvm.org/docs/LangRef.html#llvm-sin-intrinsic) intrinsic.
pub fn call_float_sin<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.sin";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.cos`](https://llvm.org/docs/LangRef.html#llvm-cos-intrinsic) intrinsic.
pub fn call_float_cos<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.cos";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.pow`](https://llvm.org/docs/LangRef.html#llvm-pow-intrinsic) intrinsic.
pub fn call_float_pow<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.pow";
debug_assert_eq!(val.get_type(), power.get_type());
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.exp`](https://llvm.org/docs/LangRef.html#llvm-exp-intrinsic) intrinsic.
pub fn call_float_exp<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.exp";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.exp2`](https://llvm.org/docs/LangRef.html#llvm-exp2-intrinsic) intrinsic.
pub fn call_float_exp2<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.exp2";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.log`](https://llvm.org/docs/LangRef.html#llvm-log-intrinsic) intrinsic.
pub fn call_float_log<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.log";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.log10`](https://llvm.org/docs/LangRef.html#llvm-log10-intrinsic) intrinsic.
pub fn call_float_log10<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.log10";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.log2`](https://llvm.org/docs/LangRef.html#llvm-log2-intrinsic) intrinsic.
pub fn call_float_log2<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.log2";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.fabs`](https://llvm.org/docs/LangRef.html#llvm-fabs-intrinsic) intrinsic.
pub fn call_float_fabs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.fabs";
let llvm_src_t = src.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_src_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[src.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.minnum`](https://llvm.org/docs/LangRef.html#llvm-minnum-intrinsic) intrinsic.
pub fn call_float_minnum<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val1: FloatValue<'ctx>,
val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.minnum";
debug_assert_eq!(val1.get_type(), val2.get_type());
let llvm_float_t = val1.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.maxnum`](https://llvm.org/docs/LangRef.html#llvm-maxnum-intrinsic) intrinsic.
pub fn call_float_maxnum<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val1: FloatValue<'ctx>,
val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.maxnum";
debug_assert_eq!(val1.get_type(), val2.get_type());
let llvm_float_t = val1.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val1.into(), val2.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.copysign`](https://llvm.org/docs/LangRef.html#llvm-copysign-intrinsic) intrinsic.
pub fn call_float_copysign<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
mag: FloatValue<'ctx>,
sgn: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.copysign";
debug_assert_eq!(mag.get_type(), sgn.get_type());
let llvm_float_t = mag.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[mag.into(), sgn.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.floor`](https://llvm.org/docs/LangRef.html#llvm-floor-intrinsic) intrinsic.
pub fn call_float_floor<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.floor";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.ceil`](https://llvm.org/docs/LangRef.html#llvm-ceil-intrinsic) intrinsic.
pub fn call_float_ceil<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.ceil";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.round`](https://llvm.org/docs/LangRef.html#llvm-round-intrinsic) intrinsic.
pub fn call_float_round<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.round";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the
/// [`llvm.roundeven`](https://llvm.org/docs/LangRef.html#llvm-roundeven-intrinsic) intrinsic.
pub fn call_float_roundeven<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.roundeven";
let llvm_float_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_float_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.expect`](https://llvm.org/docs/LangRef.html#llvm-expect-intrinsic) intrinsic.
pub fn call_expect<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: IntValue<'ctx>,
expected_val: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
const FN_NAME: &str = "llvm.expect";
debug_assert_eq!(val.get_type().get_bit_width(), expected_val.get_type().get_bit_width());
let llvm_int_t = val.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_int_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into(), expected_val.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,12 +1,12 @@
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
},
thread,
};
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{
attributes::{Attribute, AttributeLoc},
@ -24,21 +24,14 @@ use inkwell::{
AddressSpace, IntPredicate, OptimizationLevel,
};
use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use nac3parser::ast::{Location, Stmt, StrRef};
use crate::{
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
},
use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use classes::{ListType, NDArrayType, ProxyType, RangeType};
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
use std::thread;
pub mod builtin_fns;
pub mod classes;
@ -54,21 +47,8 @@ pub mod stmt;
#[cfg(test)]
mod test;
mod macros {
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
/// its first argument to provide Python source information to indicate the codegen location
/// causing the assertion.
macro_rules! codegen_unreachable {
($ctx:expr $(,)?) => {
std::unreachable!("unreachable code while processing {}", &$ctx.current_loc)
};
($ctx:expr, $($arg:tt)*) => {
std::unreachable!("unreachable code while processing {}: {}", &$ctx.current_loc, std::format!("{}", std::format_args!($($arg)+)))
};
}
pub(crate) use codegen_unreachable;
}
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
#[derive(Default)]
pub struct StaticValueStore {
@ -88,16 +68,6 @@ pub struct CodeGenLLVMOptions {
pub target: CodeGenTargetMachineOptions,
}
impl CodeGenLLVMOptions {
/// Creates a [`TargetMachine`] using the target options specified by this struct.
///
/// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine(&self) -> Option<TargetMachine> {
self.target.create_target_machine(self.opt_level)
}
}
/// Additional options for code generation for the target machine.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CodeGenTargetMachineOptions {
@ -368,10 +338,6 @@ impl WorkerRegistry {
let mut builder = context.create_builder();
let mut module = context.create_module(generator.get_name());
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag(
"Debug Info Version",
inkwell::module::FlagBehavior::Warning,
@ -395,10 +361,6 @@ impl WorkerRegistry {
errors.insert(e);
// create a new empty module just to continue codegen and collect errors
module = context.create_module(&format!("{}_recover", generator.get_name()));
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
}
}
*self.task_count.lock() -= 1;
@ -464,7 +426,7 @@ pub struct CodeGenTask {
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context,
module: &Module<'ctx>,
generator: &G,
generator: &mut G,
unifier: &mut Unifier,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -494,20 +456,6 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
.into()
}
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let element_type = get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
*params.iter().next().unwrap().1,
);
ListType::new(generator, ctx, element_type).as_base_type().into()
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let element_type = get_llvm_type(
@ -558,10 +506,8 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
};
return ty;
}
TTuple { ty, is_vararg_ctx } => {
TTuple { ty } => {
// a struct with fields in the order present in the tuple
assert!(!is_vararg_ctx, "Tuples in vararg context must be instantiated with the correct number of arguments before calling get_llvm_type");
let fields = ty
.iter()
.map(|ty| {
@ -570,6 +516,12 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
.collect_vec();
ctx.struct_type(&fields, false).into()
}
TList { ty } => {
let element_type =
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty);
ListType::new(generator, ctx, element_type).as_base_type().into()
}
TVirtual { .. } => unimplemented!(),
_ => unreachable!("{}", ty_enum.get_type_name()),
};
@ -591,7 +543,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context,
module: &Module<'ctx>,
generator: &G,
generator: &mut G,
unifier: &mut Unifier,
top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -600,11 +552,11 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
) -> BasicTypeEnum<'ctx> {
// If the type is used in the definition of a function, return `i1` instead of `i8` for ABI
// consistency.
if unifier.unioned(ty, primitives.bool) {
return if unifier.unioned(ty, primitives.bool) {
ctx.bool_type().into()
} else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
}
};
}
/// Whether `sret` is needed for a return value with type `ty`.
@ -629,40 +581,6 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
need_sret_impl(ty, true)
}
/// Returns the [`BasicTypeEnum`] representing a `va_list` struct for variadic arguments.
fn get_llvm_valist_type<'ctx>(ctx: &'ctx Context, triple: &TargetTriple) -> BasicTypeEnum<'ctx> {
let triple = TargetMachine::normalize_triple(triple);
let triple = triple.as_str().to_str().unwrap();
let arch = triple.split('-').next().unwrap();
let llvm_pi8 = ctx.i8_type().ptr_type(AddressSpace::default());
// Referenced from parseArch() in llvm/lib/Support/Triple.cpp
match arch {
"i386" | "i486" | "i586" | "i686" | "riscv32" => {
ctx.i8_type().ptr_type(AddressSpace::default()).into()
}
"amd64" | "x86_64" | "x86_64h" => {
let llvm_i32 = ctx.i32_type();
let va_list_tag = ctx.opaque_struct_type("struct.__va_list_tag");
va_list_tag.set_body(
&[llvm_i32.into(), llvm_i32.into(), llvm_pi8.into(), llvm_pi8.into()],
false,
);
va_list_tag.into()
}
"armv7" => {
let va_list = ctx.opaque_struct_type("struct.__va_list");
va_list.set_body(&[llvm_pi8.into()], false);
va_list.into()
}
triple => {
todo!("Unsupported platform for varargs: {triple}")
}
}
}
/// Implementation for generating LLVM IR for a function.
pub fn gen_func_impl<
'ctx,
@ -774,7 +692,6 @@ pub fn gen_func_impl<
name: arg.name,
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
})
.collect_vec(),
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
@ -797,10 +714,7 @@ pub fn gen_func_impl<
let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
let mut params = args
.iter()
.filter(|arg| !arg.is_vararg)
.map(|arg| {
debug_assert!(!arg.is_vararg);
get_llvm_abi_type(
context,
&module,
@ -819,12 +733,9 @@ pub fn gen_func_impl<
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
}
debug_assert!(matches!(args.iter().filter(|arg| arg.is_vararg).count(), 0..=1));
let vararg_arg = args.iter().find(|arg| arg.is_vararg);
let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, vararg_arg.is_some()),
_ => context.void_type().fn_type(&params, vararg_arg.is_some()),
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => context.void_type().fn_type(&params, false),
};
let symbol = &task.symbol_name;
@ -852,10 +763,9 @@ pub fn gen_func_impl<
builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body");
// Store non-vararg argument values into local variables
let mut var_assignment = HashMap::new();
let offset = u32::from(has_sret);
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
for (n, arg) in args.iter().enumerate() {
let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let local_type = get_llvm_type(
context,
@ -888,8 +798,6 @@ pub fn gen_func_impl<
var_assignment.insert(arg.name, (alloca, None, 0));
}
// TODO: Save vararg parameters as list
let return_buffer = if has_sret {
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
} else {
@ -1112,9 +1020,3 @@ fn gen_in_range_check<'ctx>(
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
}
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
/// passed to the variadic function.
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
format!("__{}_va_count", &arg_name).into()
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,37 +1,34 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use indexmap::IndexMap;
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel,
};
use nac3parser::{
ast::{fold::Fold, FileName, StrRef},
parser::parse_program,
};
use parking_lot::RwLock;
use super::{
classes::{ListType, NDArrayType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator,
DefaultCodeGenerator, WithCall, WorkerRegistry,
};
use crate::{
codegen::{
classes::{ListType, NDArrayType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore},
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
};
use indexmap::IndexMap;
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel,
};
use nac3parser::ast::FileName;
use nac3parser::{
ast::{fold::Fold, StrRef},
parser::parse_program,
};
use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
struct Resolver {
id_to_type: HashMap<StrRef, Type>,
@ -67,7 +64,6 @@ impl SymbolResolver for Resolver {
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
@ -98,7 +94,7 @@ fn test_primitives() {
"};
let statements = parse_program(source, FileName::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
@ -113,18 +109,8 @@ fn test_primitives() {
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature {
args: vec![
FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "b".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg { name: "a".into(), ty: primitives.int32, default_value: None },
FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
],
ret: primitives.int32,
vars: VarMap::new(),
@ -142,8 +128,7 @@ fn test_primitives() {
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> =
["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut identifiers: HashSet<_> = ["a".into(), "b".into()].into();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
@ -204,8 +189,6 @@ fn test_primitives() {
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
@ -263,19 +246,14 @@ fn test_simple_call() {
"};
let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone());
let signature = FunSignature {
args: vec![FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
}],
args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
ret: primitives.int32,
vars: VarMap::new(),
};
@ -322,8 +300,7 @@ fn test_simple_call() {
};
let mut virtual_checks = Vec::new();
let mut calls = HashMap::new();
let mut identifiers: HashMap<_, _> =
["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].into();
let mut inferencer = Inferencer {
top_level: &top_level,
function_data: &mut function_data,
@ -391,8 +368,6 @@ fn test_simple_call() {
let expected = indoc! {"
; ModuleID = 'test'
source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {

View File

@ -2,9 +2,9 @@
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(rust_2024_compatibility)]
#![warn(clippy::pedantic)]
#![allow(
dead_code,
@ -19,10 +19,6 @@
clippy::wildcard_imports
)]
// users of nac3core need to use the same version of these dependencies, so expose them as nac3core::*
pub use inkwell;
pub use nac3parser;
pub mod codegen;
pub mod symbol_resolver;
pub mod toplevel;

View File

@ -1,15 +1,7 @@
use std::{
collections::{HashMap, HashSet},
fmt::{Debug, Display},
rc::Rc,
sync::Arc,
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use parking_lot::RwLock;
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use std::fmt::Debug;
use std::rc::Rc;
use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display};
use crate::{
codegen::{CodeGenContext, CodeGenerator},
@ -19,6 +11,10 @@ use crate::{
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)]
pub enum SymbolValue {
@ -82,14 +78,14 @@ impl SymbolValue {
}
Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else {
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
return Err(format!(
"Expected {:?}, but got Tuple",
expected_ty.get_type_name()
));
};
assert!(*is_vararg_ctx || ty.len() == t.len());
assert_eq!(ty.len(), t.len());
let elems = t
.iter()
@ -159,7 +155,7 @@ impl SymbolValue {
SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false })
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
}
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
}
@ -369,7 +365,6 @@ pub trait SymbolResolver {
&self,
str: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>>;
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
@ -387,12 +382,13 @@ pub trait SymbolResolver {
}
thread_local! {
static IDENTIFIER_ID: [StrRef; 11] = [
static IDENTIFIER_ID: [StrRef; 12] = [
"int32".into(),
"int64".into(),
"float".into(),
"bool".into(),
"virtual".into(),
"list".into(),
"tuple".into(),
"str".into(),
"Exception".into(),
@ -417,12 +413,13 @@ pub fn parse_type_annotation<T>(
let float_id = ids[2];
let bool_id = ids[3];
let virtual_id = ids[4];
let tuple_id = ids[5];
let str_id = ids[6];
let exn_id = ids[7];
let uint32_id = ids[8];
let uint64_id = ids[9];
let literal_id = ids[10];
let list_id = ids[5];
let tuple_id = ids[6];
let str_id = ids[7];
let exn_id = ids[8];
let uint32_id = ids[9];
let uint64_id = ids[10];
let literal_id = ids[11];
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
if *id == int32_id {
@ -479,6 +476,9 @@ pub fn parse_type_annotation<T>(
if *id == virtual_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else if *id == list_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
} else if *id == tuple_id {
if let Tuple { elts, .. } = &slice.node {
let ty = elts
@ -487,7 +487,7 @@ pub fn parse_type_annotation<T>(
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }))
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else {
Err(HashSet::from(["Expected multiple elements for tuple".into()]))
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,20 +1,17 @@
use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, Mapping, TypeVarId, VarMap};
use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
use nac3parser::ast::{Constant, ExprKind, Location};
use super::{numpy::unpack_ndarray_var_tys, *};
use crate::{
symbol_resolver::SymbolValue,
typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap},
};
use super::*;
/// All primitive types and functions in nac3core.
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
pub enum PrimDef {
// Classes
Int32,
Int64,
Float,
@ -26,25 +23,17 @@ pub enum PrimDef {
UInt32,
UInt64,
Option,
List,
OptionIsSome,
OptionIsNone,
OptionUnwrap,
NDArray,
// Option methods
FunOptionIsSome,
FunOptionIsNone,
FunOptionUnwrap,
// Option-related functions
FunSome,
// NDArray methods
FunNDArrayCopy,
FunNDArrayFill,
// Range methods
FunRangeInit,
// NumPy factory functions
NDArrayCopy,
NDArrayFill,
FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunNpNDArray,
FunNpEmpty,
FunNpZeros,
@ -53,17 +42,26 @@ pub enum PrimDef {
FunNpArray,
FunNpEye,
FunNpIdentity,
// Miscellaneous NumPy & SciPy functions
FunRound,
FunRound64,
FunNpRound,
FunRange,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunNpFloor,
FunCeil,
FunCeil64,
FunNpCeil,
FunLen,
FunMin,
FunNpMin,
FunNpMinimum,
FunNpArgmin,
FunMax,
FunNpMax,
FunNpMaximum,
FunNpArgmax,
FunAbs,
FunNpIsNan,
FunNpIsInf,
FunNpSin,
@ -101,46 +99,15 @@ pub enum PrimDef {
FunNpLdExp,
FunNpHypot,
FunNpNextAfter,
FunNpTranspose,
FunNpReshape,
// Linalg functions
FunNpDot,
FunNpLinalgCholesky,
FunNpLinalgQr,
FunNpLinalgSvd,
FunNpLinalgInv,
FunNpLinalgPinv,
FunNpLinalgMatrixPower,
FunNpLinalgDet,
FunSpLinalgLu,
FunSpLinalgSchur,
FunSpLinalgHessenberg,
// Miscellaneous Python & NAC3 functions
FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunRound,
FunRound64,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunCeil,
FunCeil64,
FunLen,
FunMin,
FunMax,
FunAbs,
FunSome,
FunNpAny,
FunNpAll,
}
/// Associated details of a [`PrimDef`]
pub enum PrimDefDetails {
PrimFunction { name: &'static str, simple_name: &'static str },
PrimClass { name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type },
PrimClass { name: &'static str },
}
impl PrimDef {
@ -182,17 +149,15 @@ impl PrimDef {
#[must_use]
pub fn name(&self) -> &'static str {
match self.details() {
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name, .. } => {
name
}
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name,
}
}
/// Get the associated details of this [`PrimDef`]
#[must_use]
pub fn details(self) -> PrimDefDetails {
fn class(name: &'static str, get_ty_fn: fn(&PrimitiveStore) -> Type) -> PrimDefDetails {
PrimDefDetails::PrimClass { name, get_ty_fn }
fn class(name: &'static str) -> PrimDefDetails {
PrimDefDetails::PrimClass { name }
}
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
@ -200,37 +165,28 @@ impl PrimDef {
}
match self {
// Classes
PrimDef::Int32 => class("int32", |primitives| primitives.int32),
PrimDef::Int64 => class("int64", |primitives| primitives.int64),
PrimDef::Float => class("float", |primitives| primitives.float),
PrimDef::Bool => class("bool", |primitives| primitives.bool),
PrimDef::None => class("none", |primitives| primitives.none),
PrimDef::Range => class("range", |primitives| primitives.range),
PrimDef::Str => class("str", |primitives| primitives.str),
PrimDef::Exception => class("Exception", |primitives| primitives.exception),
PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32),
PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64),
PrimDef::Option => class("Option", |primitives| primitives.option),
PrimDef::List => class("list", |primitives| primitives.list),
PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray),
// Option methods
PrimDef::FunOptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::FunOptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::FunOptionUnwrap => fun("Option.unwrap", Some("unwrap")),
// Option-related functions
PrimDef::FunSome => fun("Some", None),
// NDArray methods
PrimDef::FunNDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::FunNDArrayFill => fun("ndarray.fill", Some("fill")),
// Range methods
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
// NumPy factory functions
PrimDef::Int32 => class("int32"),
PrimDef::Int64 => class("int64"),
PrimDef::Float => class("float"),
PrimDef::Bool => class("bool"),
PrimDef::None => class("none"),
PrimDef::Range => class("range"),
PrimDef::Str => class("str"),
PrimDef::Exception => class("Exception"),
PrimDef::UInt32 => class("uint32"),
PrimDef::UInt64 => class("uint64"),
PrimDef::Option => class("Option"),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::NDArray => class("ndarray"),
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunNpNDArray => fun("np_ndarray", None),
PrimDef::FunNpEmpty => fun("np_empty", None),
PrimDef::FunNpZeros => fun("np_zeros", None),
@ -239,17 +195,26 @@ impl PrimDef {
PrimDef::FunNpArray => fun("np_array", None),
PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None),
// Miscellaneous NumPy & SciPy functions
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRange => fun("range", None),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunNpFloor => fun("np_floor", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunNpCeil => fun("np_ceil", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunNpMin => fun("np_min", None),
PrimDef::FunNpMinimum => fun("np_minimum", None),
PrimDef::FunNpArgmin => fun("np_argmin", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunNpMax => fun("np_max", None),
PrimDef::FunNpMaximum => fun("np_maximum", None),
PrimDef::FunNpArgmax => fun("np_argmax", None),
PrimDef::FunAbs => fun("abs", None),
PrimDef::FunNpIsNan => fun("np_isnan", None),
PrimDef::FunNpIsInf => fun("np_isinf", None),
PrimDef::FunNpSin => fun("np_sin", None),
@ -287,40 +252,9 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpTranspose => fun("np_transpose", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
// Linalg functions
PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
PrimDef::FunNpLinalgQr => fun("np_linalg_qr", None),
PrimDef::FunNpLinalgSvd => fun("np_linalg_svd", None),
PrimDef::FunNpLinalgInv => fun("np_linalg_inv", None),
PrimDef::FunNpLinalgPinv => fun("np_linalg_pinv", None),
PrimDef::FunNpLinalgMatrixPower => fun("np_linalg_matrix_power", None),
PrimDef::FunNpLinalgDet => fun("np_linalg_det", None),
PrimDef::FunSpLinalgLu => fun("sp_linalg_lu", None),
PrimDef::FunSpLinalgSchur => fun("sp_linalg_schur", None),
PrimDef::FunSpLinalgHessenberg => fun("sp_linalg_hessenberg", None),
// Miscellaneous Python & NAC3 functions
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunAbs => fun("abs", None),
PrimDef::FunSome => fun("Some", None),
PrimDef::FunNpAny => fun("np_any", None),
PrimDef::FunNpAll => fun("np_all", None),
}
}
}
@ -389,9 +323,6 @@ impl TopLevelDef {
r
}
),
TopLevelDef::Variable { name, ty, .. } => {
format!("Variable {{ name: {name:?}, ty: {:?} }}", unifier.stringify(*ty),)
}
}
}
}
@ -427,13 +358,7 @@ impl TopLevelComposer {
});
let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Range.id(),
fields: [
("start".into(), (int32, true)),
("stop".into(), (int32, true)),
("step".into(), (int32, true)),
]
.into_iter()
.collect(),
fields: HashMap::new(),
params: VarMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
@ -474,9 +399,9 @@ impl TopLevelComposer {
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(),
fields: vec![
(PrimDef::FunOptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::FunOptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::FunOptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)),
(PrimDef::OptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::OptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::OptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)),
]
.into_iter()
.collect::<HashMap<_, _>>(),
@ -489,13 +414,6 @@ impl TopLevelComposer {
_ => unreachable!(),
};
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: Mapping::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
@ -510,7 +428,6 @@ impl TopLevelComposer {
name: "value".into(),
ty: ndarray_dtype_tvar.ty,
default_value: None,
is_vararg: false,
}],
ret: none,
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
@ -518,8 +435,8 @@ impl TopLevelComposer {
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: Mapping::from([
(PrimDef::FunNDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
(PrimDef::FunNDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
(PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
(PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
});
@ -538,7 +455,6 @@ impl TopLevelComposer {
str,
exception,
option,
list,
ndarray,
size_t,
};
@ -562,7 +478,6 @@ impl TopLevelComposer {
object_id: obj_id,
type_vars: Vec::default(),
fields: Vec::default(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
constructor,
@ -593,18 +508,6 @@ impl TopLevelComposer {
}
}
#[must_use]
pub fn make_top_level_variable_def(
name: String,
simple_name: StrRef,
ty: Type,
ty_decl: Option<Expr>,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
loc: Option<Location>,
) -> TopLevelDef {
TopLevelDef::Variable { name, simple_name, ty, ty_decl, resolver, loc }
}
#[must_use]
pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String {
class_name.push('.');
@ -750,16 +653,7 @@ impl TopLevelComposer {
)
}
/// This function returns the fields that have been initialized in the `__init__` function of a class
/// The function takes as input:
/// * `class_id`: The `object_id` of the class whose function is being evaluated (check `TopLevelDef::Class`)
/// * `definition_ast_list`: A list of ast definitions and statements defined in `TopLevelComposer`
/// * `stmts`: The body of function being parsed. Each statment is analyzed to check varaible initialization statements
pub fn get_all_assigned_field(
class_id: usize,
definition_ast_list: &Vec<DefAst>,
stmts: &[Stmt<()>],
) -> Result<HashSet<StrRef>, HashSet<String>> {
pub fn get_all_assigned_field(stmts: &[Stmt<()>]) -> Result<HashSet<StrRef>, HashSet<String>> {
let mut result = HashSet::new();
for s in stmts {
match &s.node {
@ -795,138 +689,30 @@ impl TopLevelComposer {
// TODO: do not check for For and While?
ast::StmtKind::For { body, orelse, .. }
| ast::StmtKind::While { body, orelse, .. } => {
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
body.as_slice(),
)?);
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?);
result.extend(Self::get_all_assigned_field(body.as_slice())?);
result.extend(Self::get_all_assigned_field(orelse.as_slice())?);
}
ast::StmtKind::If { body, orelse, .. } => {
let inited_for_sure = Self::get_all_assigned_field(
class_id,
definition_ast_list,
body.as_slice(),
)?
.intersection(&Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure);
}
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
let inited_for_sure = Self::get_all_assigned_field(
class_id,
definition_ast_list,
body.as_slice(),
)?
.intersection(&Self::get_all_assigned_field(
class_id,
definition_ast_list,
orelse.as_slice(),
)?)
.copied()
.collect::<HashSet<_>>();
let inited_for_sure = Self::get_all_assigned_field(body.as_slice())?
.intersection(&Self::get_all_assigned_field(orelse.as_slice())?)
.copied()
.collect::<HashSet<_>>();
result.extend(inited_for_sure);
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
finalbody.as_slice(),
)?);
result.extend(Self::get_all_assigned_field(finalbody.as_slice())?);
}
ast::StmtKind::With { body, .. } => {
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
body.as_slice(),
)?);
}
// Variables Initialized in function calls
ast::StmtKind::Expr { value, .. } => {
let ExprKind::Call { func, .. } = &value.node else {
continue;
};
let ExprKind::Attribute { value, attr, .. } = &func.node else {
continue;
};
let ExprKind::Name { id, .. } = &value.node else {
continue;
};
// Need to consider the two cases:
// Case 1) Call to class function i.e. id = `self`
// Case 2) Call to class ancestor function i.e. id = ancestor_name
// We leave checking whether function in case 2 belonged to class ancestor or not to type checker
//
// According to current handling of `self`, function definition are fixed and do not change regardless
// of which object is passed as `self` i.e. virtual polymorphism is not supported
// Therefore, we change class id for case 2 to reflect behavior of our compiler
let class_name = if *id == "self".into() {
let ast::StmtKind::ClassDef { name, .. } =
&definition_ast_list[class_id].1.as_ref().unwrap().node
else {
unreachable!()
};
name
} else {
id
};
let parent_method = definition_ast_list.iter().find_map(|def| {
let (
class_def,
Some(ast::Located {
node: ast::StmtKind::ClassDef { name, body, .. },
..
}),
) = &def
else {
return None;
};
let TopLevelDef::Class { object_id: class_id, .. } = &*class_def.read()
else {
unreachable!()
};
if name == class_name {
body.iter().find_map(|m| {
let ast::StmtKind::FunctionDef { name, body, .. } = &m.node else {
return None;
};
if *name == *attr {
return Some((body.clone(), class_id.0));
}
None
})
} else {
None
}
});
// If method body is none then method does not exist
if let Some((method_body, class_id)) = parent_method {
result.extend(Self::get_all_assigned_field(
class_id,
definition_ast_list,
method_body.as_slice(),
)?);
} else {
return Err(HashSet::from([format!(
"{}.{} not found in class {class_name} at {}",
*id, *attr, value.location
)]));
}
result.extend(Self::get_all_assigned_field(body.as_slice())?);
}
ast::StmtKind::Pass { .. }
| ast::StmtKind::Assert { .. }
| ast::StmtKind::AnnAssign { .. } => {}
| ast::StmtKind::Expr { .. } => {}
_ => {
unimplemented!()
@ -1105,9 +891,7 @@ pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
unpack_ndarray_var_tys(unifier, ty).0
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
arraylike_flatten_element_type(unifier, iter_type_vars(params).next().unwrap().ty)
}
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
_ => ty,
}
}
@ -1128,9 +912,7 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
u64::try_from(values[0].clone()).unwrap()
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
arraylike_get_ndims(unifier, iter_type_vars(params).next().unwrap().ty) + 1
}
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1,
_ => 0,
}
}

View File

@ -6,36 +6,36 @@ use std::{
sync::Arc,
};
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use parking_lot::RwLock;
use nac3parser::ast::{self, Expr, Location, Stmt, StrRef};
use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{
FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap,
};
use crate::{
codegen::{CodeGenContext, CodeGenerator},
codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{
CallId, FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, TypeVarId, Unifier,
VarMap,
},
type_inferencer::CodeLocation,
typedef::{CallId, TypeVarId},
},
};
use composer::*;
use type_annotation::*;
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use nac3parser::ast::{self, Location, Stmt, StrRef};
use parking_lot::RwLock;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize);
pub mod builtins;
pub mod composer;
pub mod helper;
pub mod numpy;
pub mod type_annotation;
use composer::*;
use type_annotation::*;
#[cfg(test)]
mod test;
pub mod type_annotation;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize);
type GenCallCallback = dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>,
@ -103,10 +103,6 @@ pub enum TopLevelDef {
///
/// Name and type is mutable.
fields: Vec<(StrRef, Type, bool)>,
/// Class Attributes.
///
/// Name, type, value.
attributes: Vec<(StrRef, Type, ast::Constant)>,
/// Class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>,
/// Ancestor classes, including itself.
@ -130,14 +126,14 @@ pub enum TopLevelDef {
/// Function instance to symbol mapping
///
/// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class.
/// order, including type variables associated with the class.
/// * Value: Function symbol name.
instance_to_symbol: HashMap<String, String>,
/// Function instances to annotated AST mapping
///
/// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type
/// variables.
/// order, including type variables associated with the class. Excluding rigid type
/// variables.
///
/// Rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, FunInstance>,
@ -148,25 +144,6 @@ pub enum TopLevelDef {
/// Definition location.
loc: Option<Location>,
},
Variable {
/// Qualified name of the global variable, should be unique globally.
name: String,
/// Simple name, the same as in method/function definition.
simple_name: StrRef,
/// Type of the global variable.
ty: Type,
/// The declared type of the global variable, or [`None`] if no type annotation is provided.
ty_decl: Option<Expr>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location.
loc: Option<Location>,
},
}
pub struct TopLevelContext {

View File

@ -1,17 +1,18 @@
use itertools::Itertools;
use super::helper::PrimDef;
use crate::typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
};
use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// specialized.
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
@ -24,9 +25,9 @@ pub fn make_ndarray_ty(
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,

View File

@ -5,7 +5,7 @@ expression: res_vec
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(239)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar228]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar228\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar227, typevar228]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar227\", \"typevar228\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
]

View File

@ -1,23 +1,20 @@
use std::{collections::HashMap, sync::Arc};
use indoc::indoc;
use parking_lot::Mutex;
use test_case::test_case;
use nac3parser::{
ast::{fold::Fold, FileName},
parser::parse_program,
};
use super::{helper::PrimDef, DefinitionId, *};
use crate::{
codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::DefinitionId,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{into_var_map, Type, Unifier},
typedef::{Type, Unifier},
},
};
use indoc::indoc;
use nac3parser::ast::FileName;
use nac3parser::{ast::fold::Fold, parser::parse_program};
use parking_lot::Mutex;
use std::{collections::HashMap, sync::Arc};
use test_case::test_case;
use super::*;
struct ResolverInternal {
id_to_type: Mutex<HashMap<StrRef, Type>>,
@ -64,7 +61,6 @@ impl SymbolResolver for Resolver {
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
@ -120,8 +116,7 @@ impl SymbolResolver for Resolver {
"register"
)]
fn test_simple_register(source: Vec<&str>) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
for s in source {
let ast = parse_program(s, FileName::default()).unwrap();
@ -141,8 +136,7 @@ fn test_simple_register(source: Vec<&str>) {
"register"
)]
fn test_simple_register_without_constructor(source: &str) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let ast = parse_program(source, FileName::default()).unwrap();
let ast = ast[0].clone();
composer.register_top_level(ast, None, "", true).unwrap();
@ -176,8 +170,7 @@ fn test_simple_register_without_constructor(source: &str) {
"function compose"
)]
fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Mutex::default(),
@ -525,8 +518,7 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
)]
fn test_analyze(source: &[&str], res: &[&str]) {
let print = false;
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar(
vec![
@ -703,8 +695,7 @@ fn test_analyze(source: &[&str], res: &[&str]) {
)]
fn test_inference(source: Vec<&str>, res: &[&str]) {
let print = true;
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar(
vec![
@ -784,15 +775,8 @@ fn make_internal_resolver_with_tvar(
unifier: &mut Unifier,
print: bool,
) -> Arc<ResolverInternal> {
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let res: Arc<ResolverInternal> = ResolverInternal {
id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])),
id_to_def: Mutex::default(),
id_to_type: tvars
.into_iter()
.map(|(name, range)| {
@ -806,7 +790,7 @@ fn make_internal_resolver_with_tvar(
})
.collect::<HashMap<_, _>>()
.into(),
class_names: Mutex::new(HashMap::from([("list".into(), list)])),
class_names: Mutex::default(),
}
.into();
if print {

View File

@ -1,13 +1,9 @@
use strum::IntoEnumIterator;
use super::*;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant;
use super::{
helper::{PrimDef, PrimDefDetails},
*,
};
use crate::{symbol_resolver::SymbolValue, typecheck::typedef::VarMap};
#[derive(Clone, Debug)]
pub enum TypeAnnotation {
Primitive(Type),
@ -22,6 +18,7 @@ pub enum TypeAnnotation {
TypeVar(Type),
/// A `Literal` allowing a subset of literals.
Literal(Vec<Constant>),
List(Box<TypeAnnotation>),
Tuple(Vec<TypeAnnotation>),
}
@ -54,6 +51,7 @@ impl TypeAnnotation {
format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", "))
}
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => {
format!(
"tuple[{}]",
@ -67,9 +65,9 @@ impl TypeAnnotation {
/// Parses an AST expression `expr` into a [`TypeAnnotation`].
///
/// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
/// generic variables associated with the definition.
/// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [`None`] when this function is invoked externally.
/// [`None`] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
@ -147,7 +145,9 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
slice: &ast::Expr<T>,
unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>, S>| {
if ["virtual".into(), "Generic".into(), "tuple".into(), "Option".into()].contains(id) {
if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()]
.contains(id)
{
return Err(HashSet::from([format!(
"keywords cannot be class name (at {})",
expr.location
@ -236,6 +236,23 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
Ok(TypeAnnotation::Virtual(def.into()))
}
// list
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into())
} =>
{
let def_ann = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
Ok(TypeAnnotation::List(def_ann.into()))
}
// option
ast::ExprKind::Subscript { value, slice, .. }
if {
@ -361,7 +378,6 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> {
@ -384,141 +400,99 @@ pub fn get_type_from_type_annotation_kinds(
let param_ty = params
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
})
.collect::<Result<Vec<_>, _>>()?;
let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) {
// Primitive TopLevelDefs do not contain all fields that are present in their Type
// counterparts, so directly perform subst on the Type instead.
let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else {
unreachable!()
};
let base_ty = get_ty_fn(primitives);
let params =
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) {
params.clone()
} else {
unreachable!()
};
unifier
.subst(
get_ty_fn(primitives),
&params
.iter()
.zip(param_ty)
.map(|(obj_tv, param)| (*obj_tv.0, param))
.collect(),
)
.unwrap_or(base_ty)
} else {
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_var_with_range(
range.as_slice(),
*name,
*loc,
);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
}
TypeEnum::TVar {
id, range, name, loc, is_const_generic: true, ..
} => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp =
unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
}
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
TypeEnum::TVar { id, range, name, loc, is_const_generic: true, .. } => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
}
}
ty
result
};
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
}
Ok(ty)
}
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
@ -536,26 +510,28 @@ pub fn get_type_from_type_annotation_kinds(
let ty = get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
ty.as_ref(),
subst_list,
)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
}
TypeAnnotation::List(ty) => {
let ty = get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
ty.as_ref(),
subst_list,
)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
}
TypeAnnotation::Tuple(tys) => {
let tys = tys
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys, is_vararg_ctx: false }))
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys }))
}
}
}
@ -588,7 +564,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
let mut result: Vec<TypeAnnotation> = Vec::new();
match ann {
TypeAnnotation::TypeVar(..) => result.push(ann.clone()),
TypeAnnotation::Virtual(ann) => {
TypeAnnotation::Virtual(ann) | TypeAnnotation::List(ann) => {
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()));
}
TypeAnnotation::CustomClass { params, .. } => {
@ -629,7 +605,8 @@ pub fn check_overload_type_annotation_compatible(
a == b
}
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b)) => {
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b))
| (TypeAnnotation::List(a), TypeAnnotation::List(b)) => {
check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier)
}

View File

@ -1,19 +1,13 @@
use std::{
collections::{HashMap, HashSet},
iter::once,
};
use crate::typecheck::typedef::TypeEnum;
use super::type_inferencer::Inferencer;
use super::typedef::Type;
use nac3parser::ast::{
self, Constant, Expr, ExprKind,
Operator::{LShift, RShift},
Stmt, StmtKind, StrRef,
};
use super::{
type_inferencer::{DeclarationSource, IdentifierInfo, Inferencer},
typedef::{Type, TypeEnum},
};
use crate::toplevel::helper::PrimDef;
use std::{collections::HashSet, iter::once};
impl<'a> Inferencer<'a> {
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
@ -27,45 +21,26 @@ impl<'a> Inferencer<'a> {
fn check_pattern(
&mut self,
pattern: &Expr<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> {
match &pattern.node {
ExprKind::Name { id, .. } if id == &"none".into() => {
Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
}
ExprKind::Name { id, .. } => {
// If `id` refers to a declared symbol, reject this assignment if it is used in the
// context of an (implicit) global variable
if let Some(id_info) = defined_identifiers.get(id) {
if matches!(
id_info.source,
DeclarationSource::Global { is_explicit: Some(false) }
) {
return Err(HashSet::from([format!(
"cannot access local variable '{id}' before it is declared (at {})",
pattern.location
)]));
}
}
if !defined_identifiers.contains_key(id) {
defined_identifiers.insert(*id, IdentifierInfo::default());
if !defined_identifiers.contains(id) {
defined_identifiers.insert(*id);
}
self.should_have_value(pattern)?;
Ok(())
}
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
ExprKind::Tuple { elts, .. } => {
for elt in elts {
self.check_pattern(elt, defined_identifiers)?;
self.should_have_value(elt)?;
}
Ok(())
}
ExprKind::Starred { value, .. } => {
self.check_pattern(value, defined_identifiers)?;
self.should_have_value(value)?;
Ok(())
}
ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?;
@ -89,12 +64,11 @@ impl<'a> Inferencer<'a> {
fn check_expr(
&mut self,
expr: &Expr<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom {
if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. })
&& !ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::List.id())
&& !self.unifier.is_concrete(*ty, &self.function_data.bound_variables)
{
return Err(HashSet::from([format!(
@ -110,7 +84,7 @@ impl<'a> Inferencer<'a> {
return Ok(());
}
self.should_have_value(expr)?;
if !defined_identifiers.contains_key(id) {
if !defined_identifiers.contains(id) {
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
@ -118,22 +92,7 @@ impl<'a> Inferencer<'a> {
*id,
) {
Ok(_) => {
let is_global = self.is_id_global(*id);
defined_identifiers.insert(
*id,
IdentifierInfo {
source: match is_global {
Some(true) => {
DeclarationSource::Global { is_explicit: Some(false) }
}
Some(false) => {
DeclarationSource::Global { is_explicit: None }
}
None => DeclarationSource::Local,
},
},
);
self.defined_identifiers.insert(*id);
}
Err(e) => {
return Err(HashSet::from([format!(
@ -206,7 +165,9 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.clone();
for arg in &args.args {
// TODO: should we check the types here?
defined_identifiers.entry(arg.node.arg).or_default();
if !defined_identifiers.contains(&arg.node.arg) {
defined_identifiers.insert(arg.node.arg);
}
}
self.check_expr(body, &mut defined_identifiers)?;
}
@ -245,23 +206,19 @@ impl<'a> Inferencer<'a> {
/// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
/// is freed when the function returns.
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
if cfg!(feature = "no-escape-analysis") {
true
} else {
match &*self.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { .. } => [
self.primitives.int32,
self.primitives.int64,
self.primitives.uint32,
self.primitives.uint64,
self.primitives.float,
self.primitives.bool,
]
.iter()
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty, .. } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false,
}
match &*self.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { .. } => [
self.primitives.int32,
self.primitives.int64,
self.primitives.uint32,
self.primitives.uint64,
self.primitives.float,
self.primitives.bool,
]
.iter()
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false,
}
}
@ -269,7 +226,7 @@ impl<'a> Inferencer<'a> {
fn check_stmt(
&mut self,
stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> {
match &stmt.node {
StmtKind::For { target, iter, body, orelse, .. } => {
@ -295,11 +252,9 @@ impl<'a> Inferencer<'a> {
let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in body_identifiers.keys() {
if !defined_identifiers.contains_key(ident)
&& orelse_identifiers.contains_key(ident)
{
defined_identifiers.insert(*ident, IdentifierInfo::default());
for ident in &body_identifiers {
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
defined_identifiers.insert(*ident);
}
}
Ok(body_returned && orelse_returned)
@ -330,7 +285,7 @@ impl<'a> Inferencer<'a> {
let mut defined_identifiers = defined_identifiers.clone();
let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node;
if let Some(name) = name {
defined_identifiers.insert(*name, IdentifierInfo::default());
defined_identifiers.insert(*name);
}
self.check_block(body, &mut defined_identifiers)?;
}
@ -394,44 +349,6 @@ impl<'a> Inferencer<'a> {
}
Ok(true)
}
StmtKind::Global { names, .. } => {
for id in names {
if let Some(id_info) = defined_identifiers.get(id) {
if id_info.source == DeclarationSource::Local {
return Err(HashSet::from([format!(
"name '{id}' is referenced prior to global declaration at {}",
stmt.location,
)]));
}
continue;
}
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
self.primitives,
*id,
) {
Ok(_) => {
defined_identifiers.insert(
*id,
IdentifierInfo {
source: DeclarationSource::Global { is_explicit: Some(true) },
},
);
}
Err(e) => {
return Err(HashSet::from([format!(
"type error at identifier `{}` ({}) at {}",
id, e, stmt.location
)]))
}
}
}
Ok(false)
}
// break, raise, etc.
_ => Ok(false),
}
@ -440,7 +357,7 @@ impl<'a> Inferencer<'a> {
pub fn check_block(
&mut self,
block: &[Stmt<Option<Type>>],
defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<bool, HashSet<String>> {
let mut ret = false;
for stmt in block {

View File

@ -1,154 +1,79 @@
use std::{cmp::max, collections::HashMap, rc::Rc};
use itertools::{iproduct, Itertools};
use strum::IntoEnumIterator;
use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop};
use super::{
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
};
use crate::{
symbol_resolver::SymbolValue,
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
};
use itertools::Itertools;
use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::cmp::max;
use std::collections::HashMap;
use std::rc::Rc;
use strum::IntoEnumIterator;
/// The variant of a binary operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinopVariant {
/// The normal variant.
/// For addition, it would be `+`.
Normal,
/// The "Augmented Assigning Operator" variant.
/// For addition, it would be `+=`.
AugAssign,
}
/// A binary operator with its variant.
#[derive(Debug, Clone, Copy)]
pub struct Binop {
/// The base [`Operator`] of this binary operator.
pub base: Operator,
/// The variant of this binary operator.
pub variant: BinopVariant,
}
impl Binop {
/// Make a [`Binop`] of the normal variant from an [`Operator`].
#[must_use]
pub fn normal(base: Operator) -> Self {
Binop { base, variant: BinopVariant::Normal }
}
/// Make a [`Binop`] of the aug assign variant from an [`Operator`].
#[must_use]
pub fn aug_assign(base: Operator) -> Self {
Binop { base, variant: BinopVariant::AugAssign }
}
}
/// Details about an operator (unary, binary, etc...) in Python
#[derive(Debug, Clone, Copy)]
pub struct OpInfo {
/// The method name of the binary operator.
/// For addition, this would be `__add__`, and `__iadd__` if
/// it is the augmented assigning variant.
pub method_name: &'static str,
/// The symbol of the binary operator.
/// For addition, this would be `+`, and `+=` if
/// it is the augmented assigning variant.
pub symbol: &'static str,
}
/// Helper macro to conveniently build an [`OpInfo`].
///
/// Example usage: `make_info("add", "+")` generates `OpInfo { name: "__add__", symbol: "+" }`
macro_rules! make_op_info {
($name:expr, $symbol:expr) => {
OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol }
};
}
pub trait HasOpInfo {
fn op_info(&self) -> OpInfo;
}
fn try_get_cmpop_info(op: Cmpop) -> Option<OpInfo> {
#[must_use]
pub fn binop_name(op: Operator) -> &'static str {
match op {
Cmpop::Lt => Some(make_op_info!("lt", "<")),
Cmpop::LtE => Some(make_op_info!("le", "<=")),
Cmpop::Gt => Some(make_op_info!("gt", ">")),
Cmpop::GtE => Some(make_op_info!("ge", ">=")),
Cmpop::Eq => Some(make_op_info!("eq", "==")),
Cmpop::NotEq => Some(make_op_info!("ne", "!=")),
Operator::Add => "__add__",
Operator::Sub => "__sub__",
Operator::Div => "__truediv__",
Operator::Mod => "__mod__",
Operator::Mult => "__mul__",
Operator::Pow => "__pow__",
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
}
}
#[must_use]
pub fn binop_assign_name(op: Operator) -> &'static str {
match op {
Operator::Add => "__iadd__",
Operator::Sub => "__isub__",
Operator::Div => "__itruediv__",
Operator::Mod => "__imod__",
Operator::Mult => "__imul__",
Operator::Pow => "__ipow__",
Operator::BitOr => "__ior__",
Operator::BitXor => "__ixor__",
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
}
#[must_use]
pub fn unaryop_name(op: Unaryop) -> &'static str {
match op {
Unaryop::UAdd => "__pos__",
Unaryop::USub => "__neg__",
Unaryop::Not => "__not__",
Unaryop::Invert => "__inv__",
}
}
#[must_use]
pub fn comparison_name(op: Cmpop) -> Option<&'static str> {
match op {
Cmpop::Lt => Some("__lt__"),
Cmpop::LtE => Some("__le__"),
Cmpop::Gt => Some("__gt__"),
Cmpop::GtE => Some("__ge__"),
Cmpop::Eq => Some("__eq__"),
Cmpop::NotEq => Some("__ne__"),
_ => None,
}
}
impl OpInfo {
#[must_use]
pub fn supports_cmpop(op: Cmpop) -> bool {
try_get_cmpop_info(op).is_some()
}
}
impl HasOpInfo for Cmpop {
fn op_info(&self) -> OpInfo {
try_get_cmpop_info(*self).expect("{self:?} is not supported")
}
}
impl HasOpInfo for Binop {
fn op_info(&self) -> OpInfo {
// Helper macro to generate both the normal variant [`OpInfo`] and the
// augmented assigning variant [`OpInfo`] for a binary operator conveniently.
macro_rules! info {
($name:literal, $symbol:literal) => {
(
make_op_info!($name, $symbol),
make_op_info!(concat!("i", $name), concat!($symbol, "=")),
)
};
}
let (normal_variant, aug_assign_variant) = match self.base {
Operator::Add => info!("add", "+"),
Operator::Sub => info!("sub", "-"),
Operator::Div => info!("truediv", "/"),
Operator::Mod => info!("mod", "%"),
Operator::Mult => info!("mul", "*"),
Operator::Pow => info!("pow", "**"),
Operator::BitOr => info!("or", "|"),
Operator::BitXor => info!("xor", "^"),
Operator::BitAnd => info!("and", "&"),
Operator::LShift => info!("lshift", "<<"),
Operator::RShift => info!("rshift", ">>"),
Operator::FloorDiv => info!("floordiv", "//"),
Operator::MatMult => info!("matmul", "@"),
};
match self.variant {
BinopVariant::Normal => normal_variant,
BinopVariant::AugAssign => aug_assign_variant,
}
}
}
impl HasOpInfo for Unaryop {
fn op_info(&self) -> OpInfo {
match self {
Unaryop::UAdd => make_op_info!("pos", "+"),
Unaryop::USub => make_op_info!("neg", "-"),
Unaryop::Not => make_op_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`.
Unaryop::Invert => make_op_info!("inv", "~"),
}
}
}
pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F)
where
F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>),
@ -190,9 +115,23 @@ pub fn impl_binop(
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
let op = Binop { base: *base_op, variant };
fields.insert(op.op_info().method_name.into(), {
for op in ops {
fields.insert(binop_name(*op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
})),
false,
)
});
fields.insert(binop_assign_name(*op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
@ -201,7 +140,6 @@ pub fn impl_binop(
ty: other_ty,
default_value: None,
name: "other".into(),
is_vararg: false,
}],
})),
false,
@ -217,7 +155,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
for op in ops {
fields.insert(
op.op_info().method_name.into(),
unaryop_name(*op).into(),
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
@ -257,7 +195,7 @@ pub fn impl_cmpop(
for op in ops {
fields.insert(
op.op_info().method_name.into(),
comparison_name(*op).unwrap().into(),
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
@ -266,7 +204,6 @@ pub fn impl_cmpop(
ty: other_ty,
default_value: None,
name: "other".into(),
is_vararg: false,
}],
})),
false,
@ -492,29 +429,12 @@ pub fn typeof_binop(
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let op = Binop { base: op, variant: BinopVariant::Normal };
let is_left_list = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
let is_right_list = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(match op.base {
Ok(Some(match op {
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
if is_left_list || is_right_list {
if ![Operator::Add, Operator::Mult].contains(&op.base) {
return Err(format!(
"Binary operator {} not supported for list",
op.op_info().symbol
));
}
if is_left_list {
lhs
} else {
rhs
}
} else if is_left_ndarray || is_right_ndarray {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
lhs
@ -524,23 +444,6 @@ pub fn typeof_binop(
}
Operator::MatMult => {
// NOTE: NumPy matmul's LHS and RHS must both be ndarrays. Scalars are not allowed.
match (&*unifier.get_ty(lhs), &*unifier.get_ty(rhs)) {
(
TypeEnum::TObj { obj_id: lhs_obj_id, .. },
TypeEnum::TObj { obj_id: rhs_obj_id, .. },
) if *lhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap()
&& *rhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() =>
{
// LHS and RHS have valid types
}
_ => {
let lhs_str = unifier.stringify(lhs);
let rhs_str = unifier.stringify(rhs);
return Err(format!("ndarray.__matmul__ only accepts ndarray operands, but left operand has type {lhs_str}, and right operand has type {rhs_str}"));
}
}
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
@ -701,8 +604,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
bool: bool_t,
uint32: uint32_t,
uint64: uint64_t,
str: str_t,
list: list_t,
ndarray: ndarray_t,
..
} = *store;
@ -747,14 +648,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* str ========= */
impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
/* list ======== */
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
impl_cmpop(unifier, store, list_t, &[list_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
/* ndarray ===== */
let ndarray_usized_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);

View File

@ -1,45 +1,24 @@
use std::{collections::HashMap, fmt::Display};
use std::collections::HashMap;
use std::fmt::Display;
use itertools::Itertools;
use crate::typecheck::typedef::TypeEnum;
use nac3parser::ast::{Cmpop, Location, StrRef};
use super::{
magic_methods::{Binop, HasOpInfo},
typedef::{RecordKey, Type, TypeEnum, Unifier},
};
use super::typedef::{RecordKey, Type, Unifier};
use nac3parser::ast::{Location, StrRef};
#[derive(Debug, Clone)]
pub enum TypeErrorKind {
GotMultipleValues {
name: StrRef,
},
TooManyArguments {
expected_min_count: usize,
expected_max_count: usize,
got_count: usize,
},
MissingArgs {
missing_arg_names: Vec<StrRef>,
expected: usize,
got: usize,
},
MissingArgs(String),
UnknownArgName(StrRef),
IncorrectArgType {
name: StrRef,
expected: Type,
got: Type,
},
UnsupportedBinaryOpTypes {
operator: Binop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
UnsupportedComparsionOpTypes {
operator: Cmpop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
FieldUnificationError {
field: RecordKey,
types: (Type, Type),
@ -55,7 +34,6 @@ pub enum TypeErrorKind {
},
RequiresTypeAnn,
PolymorphicFunctionPointer,
NoSuchAttribute(RecordKey, Type),
}
#[derive(Debug, Clone)]
@ -99,49 +77,19 @@ impl<'a> Display for DisplayTypeError<'a> {
use TypeErrorKind::*;
let mut notes = Some(HashMap::new());
match &self.err.kind {
GotMultipleValues { name } => {
write!(f, "For multiple values for parameter {name}")
TooManyArguments { expected, got } => {
write!(f, "Too many arguments. Expected {expected} but got {got}")
}
TooManyArguments { expected_min_count, expected_max_count, got_count } => {
debug_assert!(expected_min_count <= expected_max_count);
if expected_min_count == expected_max_count {
let expected_count = expected_min_count; // or expected_max_count
write!(f, "Too many arguments. Expected {expected_count} but got {got_count}")
} else {
write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}")
}
}
MissingArgs { missing_arg_names } => {
let args = missing_arg_names.iter().join(", ");
MissingArgs(args) => {
write!(f, "Missing arguments: {args}")
}
UnsupportedBinaryOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnsupportedComparsionOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnknownArgName(name) => {
write!(f, "Unknown argument name: {name}")
}
IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}")
write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}")
}
FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
@ -182,10 +130,9 @@ impl<'a> Display for DisplayTypeError<'a> {
}
result
}
(
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) if !is_vararg1 && !is_vararg2 && ty1.len() != ty2.len() => {
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 })
if ty1.len() != ty2.len() =>
{
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {t1} and {t2}")
@ -209,10 +156,6 @@ impl<'a> Display for DisplayTypeError<'a> {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` field/method does not exist")
}
NoSuchAttribute(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` is not a class attribute")
}
TupleIndexOutOfBounds { index, len } => {
write!(
f,

File diff suppressed because it is too large Load Diff

View File

@ -1,19 +1,17 @@
use std::iter::zip;
use indexmap::IndexMap;
use indoc::indoc;
use parking_lot::RwLock;
use test_case::test_case;
use nac3parser::{ast::FileName, parser::parse_program};
use super::super::{magic_methods::with_fields, typedef::*};
use super::*;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
codegen::CodeGenContext,
symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::{magic_methods::with_fields, typedef::*},
};
use indexmap::IndexMap;
use indoc::indoc;
use nac3parser::ast::FileName;
use nac3parser::parser::parse_program;
use parking_lot::RwLock;
use std::iter::zip;
use test_case::test_case;
struct Resolver {
id_to_type: HashMap<StrRef, Type>,
@ -43,7 +41,6 @@ impl SymbolResolver for Resolver {
&self,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
unimplemented!()
}
@ -86,12 +83,7 @@ impl TestEnvironment {
});
with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32,
vars: VarMap::new(),
}));
@ -147,12 +139,6 @@ impl TestEnvironment {
fields: HashMap::new(),
params: VarMap::new(),
});
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
@ -173,7 +159,6 @@ impl TestEnvironment {
uint32,
uint64,
option,
list,
ndarray,
size_t: 64,
};
@ -232,12 +217,7 @@ impl TestEnvironment {
});
with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32,
vars: VarMap::new(),
}));
@ -293,35 +273,15 @@ impl TestEnvironment {
fields: HashMap::new(),
params: VarMap::new(),
});
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
identifier_mapping.insert("None".into(), none);
for (i, name) in [
"int32",
"int64",
"float",
"bool",
"none",
"range",
"str",
"Exception",
"uint32",
"uint64",
"Option",
"list",
"ndarray",
]
.iter()
.enumerate()
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
.iter()
.enumerate()
{
top_level_defs.push(
RwLock::new(TopLevelDef::Class {
@ -329,7 +289,6 @@ impl TestEnvironment {
object_id: DefinitionId(i),
type_vars: Vec::default(),
fields: Vec::default(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -339,7 +298,7 @@ impl TestEnvironment {
.into(),
);
}
let defs = 12;
let defs = 7;
let primitives = PrimitiveStore {
int32,
@ -353,7 +312,6 @@ impl TestEnvironment {
uint32,
uint64,
option,
list,
ndarray,
size_t: 64,
};
@ -373,7 +331,6 @@ impl TestEnvironment {
object_id: DefinitionId(defs + 1),
type_vars: vec![tvar.ty],
fields: [("a".into(), tvar.ty, true)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -408,7 +365,6 @@ impl TestEnvironment {
object_id: DefinitionId(defs + 2),
type_vars: Vec::default(),
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -437,7 +393,6 @@ impl TestEnvironment {
object_id: DefinitionId(defs + 3),
type_vars: Vec::default(),
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -465,11 +420,6 @@ impl TestEnvironment {
"range".into(),
"str".into(),
"exception".into(),
"uint32".into(),
"uint64".into(),
"option".into(),
"list".into(),
"ndarray".into(),
"Foo".into(),
"Bar".into(),
"Bar2".into(),
@ -520,7 +470,7 @@ impl TestEnvironment {
primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls,
defined_identifiers: HashMap::default(),
defined_identifiers: HashSet::default(),
in_handler: false,
}
}
@ -596,9 +546,8 @@ fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &s
println!("source:\n{source}");
let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashMap<_, _> =
env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect();
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect();
defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, FileName::default()).unwrap();
@ -743,9 +692,8 @@ fn test_primitive_magic_methods(source: &str, mapping: &HashMap<&str, &str>) {
println!("source:\n{source}");
let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashMap<_, _> =
env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect();
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect();
defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, FileName::default()).unwrap();

View File

@ -1,28 +1,20 @@
use std::{
borrow::Cow,
cell::RefCell,
collections::{HashMap, HashSet},
fmt::{self, Display},
iter::{repeat, zip},
rc::Rc,
sync::{Arc, Mutex},
};
use indexmap::IndexMap;
use itertools::{repeat_n, Itertools};
use itertools::Itertools;
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::{self, Display};
use std::iter::zip;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet};
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
use nac3parser::ast::{Location, StrRef};
use super::{
magic_methods::{Binop, HasOpInfo, OpInfo},
type_error::{TypeError, TypeErrorKind},
type_inferencer::PrimitiveStore,
unification_table::{UnificationKey, UnificationTable},
};
use crate::{
symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, DefinitionId, TopLevelContext, TopLevelDef},
};
use super::type_error::{TypeError, TypeErrorKind};
use super::unification_table::{UnificationKey, UnificationTable};
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
use crate::typecheck::type_inferencer::PrimitiveStore;
#[cfg(test)]
mod test;
@ -81,28 +73,6 @@ pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ {
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
}
#[derive(Debug, Clone)]
pub enum OperatorInfo {
/// The call was written as an unary operation, e.g., `~a` or `not a`.
IsUnaryOp {
/// The [`Type`] of the `self` object
self_type: Type,
operator: Unaryop,
},
/// The call was written as a binary operation, e.g., `a + b` or `a += b`.
IsBinaryOp {
/// The [`Type`] of the `self` object
self_type: Type,
operator: Binop,
},
/// The call was written as a binary comparison operation, e.g., `a < b`.
IsComparisonOp {
/// The [`Type`] of the `self` object
self_type: Type,
operator: Cmpop,
},
}
#[derive(Clone)]
pub struct Call {
pub posargs: Vec<Type>,
@ -110,9 +80,6 @@ pub struct Call {
pub ret: Type,
pub fun: RefCell<Option<Type>>,
pub loc: Option<Location>,
/// Details about the associated Python user operator expression of this call, if any.
pub operator_info: Option<OperatorInfo>,
}
#[derive(Debug, Clone)]
@ -120,14 +87,6 @@ pub struct FuncArg {
pub name: StrRef,
pub ty: Type,
pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
}
impl FuncArg {
#[must_use]
pub fn is_required(&self) -> bool {
self.default_value.is_none()
}
}
#[derive(Debug, Clone)]
@ -239,12 +198,12 @@ pub enum TypeEnum {
TTuple {
/// The types of elements present in this tuple.
ty: Vec<Type>,
},
/// Whether this tuple is used in a vararg context.
///
/// If `true`, `ty` must only contain one type, and the tuple is assumed to contain any
/// number of `ty`-typed values.
is_vararg_ctx: bool,
/// A list type.
TList {
/// The type of elements present in this list.
ty: Type,
},
/// An object type.
@ -280,6 +239,7 @@ impl TypeEnum {
TypeEnum::TVar { .. } => "TVar",
TypeEnum::TLiteral { .. } => "TConstant",
TypeEnum::TTuple { .. } => "TTuple",
TypeEnum::TList { .. } => "TList",
TypeEnum::TObj { .. } => "TObj",
TypeEnum::TVirtual { .. } => "TVirtual",
TypeEnum::TCall { .. } => "TCall",
@ -515,31 +475,13 @@ impl Unifier {
)
}
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let tv = iter_type_vars(params).nth(0).unwrap();
let tv_id = if let TypeEnum::TVar { id, .. } =
self.unification_table.probe_value(tv.ty).as_ref()
{
*id
} else {
tv.id
};
self.get_instantiations(tv.ty).map(|ty_insts| {
ty_insts
.iter()
.map(|&ty_inst| {
self.subst(ty, &into_var_map([TypeVar { id: tv_id, ty: ty_inst }]))
.unwrap_or(ty)
})
.collect()
})
}
TypeEnum::TList { ty } => self
.get_instantiations(*ty)
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
}),
TypeEnum::TTuple { ty, is_vararg_ctx } => {
TypeEnum::TTuple { ty } => {
let tuples = ty
.iter()
.map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
@ -549,12 +491,7 @@ impl Unifier {
None
} else {
Some(
tuples
.into_iter()
.map(|ty| {
self.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: *is_vararg_ctx })
})
.collect(),
tuples.into_iter().map(|ty| self.add_ty(TypeEnum::TTuple { ty })).collect(),
)
}
}
@ -597,8 +534,10 @@ impl Unifier {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false,
TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty, .. } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TList { ty }
| TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
TObj { params: vars, .. } => {
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
}
@ -623,243 +562,69 @@ impl Unifier {
call: &Call,
b: Type,
signature: &FunSignature,
required: &[StrRef],
) -> Result<(), TypeError> {
/*
NOTE: scenarios to consider:
```python
def func1(x: int32, y: int32, z: int32 = 5): pass
# Normal scenarios
func1(23, 45) # OK, z has default
func1(23, 45, 67) # OK, z's default is overwritten
func1(x = 23, y = 45) # OK, user is using kwargs to set positional args
func1(y = 45, x = 23) # OK, kwargs order doesn't matter
# Error scenarios
func1() # ERROR: Missing arguments: x, y
func1(23) # ERROR: Missing arguments: y
func1(z = 23) # ERROR: Missing arguments: x, y
func1(x = 23) # ERROR: Missing arguments: y
func1(23, 45, x = 5) # ERROR: Got multiple values for x
func1(23, 45, x = 5, y = 6) # ERROR: Got multiple values for x (y too but Python does not report it)
func1(23, 45, 67, z = 89) # ERROR: Got multiple values for z
func1(23, 45, 67, 89) # ERROR: Function only takes from 2 to 3 positional arguments but 4 were given.
func1(23, 45, 67, w = 3) # ERROR: Got an unexpected keyword argument 'w'
# Error scenarios that do not need to be handled here.
func1(23, 45, z = 67, z = 89) # ERROR: Keyword argument repeated: z, the parser panics on this.
```
*/
struct ParamInfo<'a> {
/// Has this parameter been supplied with an argument already?
has_been_supplied: bool,
/// The corresponding [`FuncArg`] instance of this parameter (for fast table lookups)
param: &'a FuncArg,
}
let snapshot = self.unification_table.get_snapshot();
if self.snapshot.is_none() {
self.snapshot = Some(snapshot);
}
// Get details about the function signature/parameters.
let num_params = signature.args.len();
let is_vararg = signature.args.iter().any(|arg| arg.is_vararg);
// Force the type vars in `b` and `signature' to be up-to-date.
let b = self.instantiate_fun(b, signature);
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
// Get details about the input arguments
let Call { posargs, kwargs, ret, fun, loc, operator_info } = call;
let num_args = posargs.len() + kwargs.len();
// Now we check the arguments against the parameters,
// and depending on what `call_info` is, we might change how `unify_call()` behaves
// to improve user error messages when type checking fails.
match operator_info {
Some(OperatorInfo::IsBinaryOp { self_type, operator }) => {
// The call is written in the form of (say) `a + b`.
// Technically, it is `a.__add__(b)`, and they have the following constraints:
assert_eq!(posargs.len(), 1);
assert_eq!(kwargs.len(), 0);
assert_eq!(num_params, 1);
let other_type = posargs[0]; // the second operand
let expected_other_type = signature.args[0].ty;
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
if !ok {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::UnsupportedBinaryOpTypes {
operator: *operator,
lhs_type: *self_type,
rhs_type: other_type,
expected_rhs_type: expected_other_type,
},
*loc,
));
}
}
Some(OperatorInfo::IsComparisonOp { self_type, operator })
if OpInfo::supports_cmpop(*operator) // Otherwise that comparison operator is not supported.
=>
{
// The call is written in the form of (say) `a <= b`.
// Technically, it is `a.__le__(b)`, and they have the following constraints:
assert_eq!(posargs.len(), 1);
assert_eq!(kwargs.len(), 0);
assert_eq!(num_params, 1);
let other_type = posargs[0]; // the second operand
let expected_other_type = signature.args[0].ty;
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
if !ok {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::UnsupportedComparsionOpTypes {
operator: *operator,
lhs_type: *self_type,
rhs_type: other_type,
expected_rhs_type: expected_other_type,
},
*loc,
));
}
}
_ => {
// Handle [`CallInfo::IsNormalFunctionCall`] and other uninteresting variants
// of [`CallInfo`] (e.g, `CallInfo::IsUnaryOp` and unsupported comparison operators)
// Helper lambdas
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
if ok {
Ok(())
} else {
// Typecheck failed, throw an error.
self.restore_snapshot();
Err(TypeError::new(
TypeErrorKind::IncorrectArgType {
name: param_name,
expected: expected_arg_ty,
got: arg_ty,
},
*loc,
))
}
};
// Check for "too many arguments"
if !is_vararg && num_params < posargs.len() {
let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params;
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::TooManyArguments {
expected_min_count,
expected_max_count,
got_count: num_args,
},
*loc,
));
}
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
.args
.iter()
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
.collect();
// Now consume all positional arguments and typecheck them.
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
let param_info = param_info_by_name.get_mut(&param.name).unwrap();
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
if is_vararg {
debug_assert!(!signature.args.is_empty());
let vararg_args = posargs.iter().skip(signature.args.len());
let vararg_param = signature.args.last().unwrap();
for (&arg_ty, param) in zip(vararg_args, repeat(vararg_param)) {
// `param_info` for this argument would've already been marked as supplied
// during non-vararg posarg typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
}
// Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal".
let Some(param_info) = param_info_by_name.get_mut(&param_name) else {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::UnknownArgName(param_name),
*loc,
));
};
if param_info.has_been_supplied {
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
// is IMPOSSIBLE as the parser would have already failed.
// We only have to care about "got multiple values for XYZ"
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::GotMultipleValues { name: param_name },
*loc,
));
}
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
}
// After checking posargs and kwargs, check if there are any
// unsupplied required parameters, and throw an error if they exist.
let missing_arg_names = param_info_by_name
.values()
.filter(|param_info| {
param_info.param.is_required() && !param_info.has_been_supplied
})
.map(|param_info| param_info.param.name)
.collect_vec();
if !missing_arg_names.is_empty() {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::MissingArgs { missing_arg_names },
*loc,
));
}
// Finally, check the Call's return type
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
self.restore_snapshot();
if err.loc.is_none() {
err.loc = *loc;
}
err
})?;
let Call { posargs, kwargs, ret, fun, loc } = call;
let instantiated = self.instantiate_fun(b, signature);
let r = self.get_ty(instantiated);
let r = r.as_ref();
let TypeEnum::TFunc(signature) = r else { unreachable!() };
// we check to make sure that all required arguments (those without default
// arguments) are provided, and do not provide the same argument twice.
let mut required = required.to_vec();
let mut all_names: Vec<_> = signature.args.iter().map(|v| (v.name, v.ty)).rev().collect();
for (i, t) in posargs.iter().enumerate() {
if signature.args.len() <= i {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::TooManyArguments {
expected: signature.args.len(),
got: posargs.len() + kwargs.len(),
},
*loc,
));
}
required.pop();
let (name, expected) = all_names.pop().unwrap();
self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?;
}
*fun.borrow_mut() = Some(b);
for (k, t) in kwargs {
if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i);
}
let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
})?;
let (name, expected) = all_names.remove(i);
self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?;
}
if !required.is_empty() {
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::MissingArgs(required.iter().join(", ")),
*loc,
));
}
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
self.restore_snapshot();
if err.loc.is_none() {
err.loc = *loc;
}
err
})?;
*fun.borrow_mut() = Some(instantiated);
self.discard_snapshot(snapshot);
Ok(())
@ -990,10 +755,7 @@ impl Unifier {
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(
TVar { fields: Some(fields), range, is_const_generic: false, .. },
TTuple { ty, .. },
) => {
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TTuple { ty }) => {
let len = i32::try_from(ty.len()).unwrap();
for (k, v) in fields {
match *k {
@ -1014,18 +776,8 @@ impl Unifier {
self.unify_impl(v.ty, ty[ind as usize], false)
.map_err(|e| e.at(v.loc))?;
}
RecordKey::Str(s) => {
let tuple_fns = [
Cmpop::Eq.op_info().method_name,
Cmpop::NotEq.op_info().method_name,
];
if !tuple_fns.into_iter().any(|op| s.to_string() == op) {
return Err(TypeError::new(
TypeErrorKind::NoSuchField(*k, b),
v.loc,
));
}
RecordKey::Str(_) => {
return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
}
}
}
@ -1033,6 +785,22 @@ impl Unifier {
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
for (k, v) in fields {
match *k {
RecordKey::Int(_) => {
self.unify_impl(v.ty, *ty, false).map_err(|e| e.at(v.loc))?;
}
RecordKey::Str(_) => {
return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
}
}
}
let x = self.check_var_compatibility(b, range)?.unwrap_or(b);
self.unify_impl(x, b, false)?;
self.set_a_to_b(a, x);
}
(
TVar { id: id1, range: ty1, is_const_generic: true, .. },
TVar { id: id2, range: ty2, .. },
@ -1100,50 +868,24 @@ impl Unifier {
self.set_a_to_b(a, b);
}
(
TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) => {
// Rules for Tuples:
// - ty1: is_vararg && ty2: is_vararg -> ty1[0] == ty2[0]
// - ty1: is_vararg && ty2: !is_vararg -> type error (not enough info to infer the correct number of arguments)
// - ty1: !is_vararg && ty2: is_vararg -> ty1[..] == ty2[0]
// - ty1: !is_vararg && ty2: !is_vararg -> ty1.len() == ty2.len() && ty1[i] == ty2[i]
debug_assert!(!is_vararg1 || ty1.len() == 1);
debug_assert!(!is_vararg2 || ty2.len() == 1);
match (*is_vararg1, *is_vararg2) {
(true, true) => {
if self.unify_impl(ty1[0], ty2[0], false).is_err() {
return Self::incompatible_types(a, b);
}
}
(true, false) => return Self::incompatible_types(a, b),
(false, true) => {
for y in ty2 {
if self.unify_impl(ty1[0], *y, false).is_err() {
return Self::incompatible_types(a, b);
}
}
}
(false, false) => {
if ty1.len() != ty2.len() {
return Self::incompatible_types(a, b);
}
for (x, y) in ty1.iter().zip(ty2.iter()) {
if self.unify_impl(*x, *y, false).is_err() {
return Self::incompatible_types(a, b);
}
}
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) => {
if ty1.len() != ty2.len() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
}
for (x, y) in ty1.iter().zip(ty2.iter()) {
if self.unify_impl(*x, *y, false).is_err() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
}
}
self.set_a_to_b(a, b);
}
(TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => {
(TList { ty: ty1 }, TList { ty: ty2 }) => {
if self.unify_impl(*ty1, *ty2, false).is_err() {
return Err(TypeError::new(TypeErrorKind::IncompatibleTypes(a, b), None));
}
self.set_a_to_b(a, b);
}
(TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => {
for (k, field) in map {
match *k {
RecordKey::Str(s) => {
@ -1162,18 +904,10 @@ impl Unifier {
self.unify_impl(field.ty, ty, false).map_err(|v| v.at(field.loc))?;
}
RecordKey::Int(_) => {
// Allow expressions such as list[0]
if *obj_id == PrimDef::List.id() {
let ty = iter_type_vars(params).nth(0).unwrap().ty;
self.unify_impl(field.ty, ty, false)
.map_err(|e| e.at(field.loc))?;
} else {
return Err(TypeError::new(
TypeErrorKind::NoSuchField(*k, b),
field.loc,
));
}
return Err(TypeError::new(
TypeErrorKind::NoSuchField(*k, b),
field.loc,
))
}
}
}
@ -1256,10 +990,17 @@ impl Unifier {
self.unification_table.set_value(b, Rc::new(TCall(calls)));
}
(TCall(calls), TFunc(signature)) => {
let required: Vec<StrRef> = signature
.args
.iter()
.filter(|v| v.default_value.is_none())
.map(|v| v.name)
.rev()
.collect();
// we unify every calls to the function signature.
for c in calls {
let call = self.calls[c.0].clone();
self.unify_call(&call, b, signature)?;
self.unify_call(&call, b, signature, &required)?;
}
self.set_a_to_b(a, b);
}
@ -1383,22 +1124,13 @@ impl Unifier {
TypeEnum::TLiteral { values, .. } => {
format!("const({})", values.iter().map(|v| format!("{v:?}")).join(", "))
}
TypeEnum::TTuple { ty, is_vararg_ctx } => {
if *is_vararg_ctx {
debug_assert_eq!(ty.len(), 1);
let field = self.internal_stringify(
*ty.iter().next().unwrap(),
obj_to_name,
var_to_name,
notes,
);
format!("tuple[*{field}]")
} else {
let mut fields = ty
.iter()
.map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
format!("tuple[{}]", fields.join(", "))
}
TypeEnum::TTuple { ty } => {
let mut fields =
ty.iter().map(|v| self.internal_stringify(*v, obj_to_name, var_to_name, notes));
format!("tuple[{}]", fields.join(", "))
}
TypeEnum::TList { ty } => {
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
}
TypeEnum::TVirtual { ty } => {
format!(
@ -1423,21 +1155,17 @@ impl Unifier {
.args
.iter()
.map(|arg| {
let vararg_prefix = if arg.is_vararg { "*" } else { "" };
if let Some(dv) = &arg.default_value {
format!(
"{}:{}{}={}",
"{}:{}={}",
arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes),
dv
)
} else {
format!(
"{}:{}{}",
"{}:{}",
arg.name,
vararg_prefix,
self.internal_stringify(arg.ty, obj_to_name, var_to_name, notes)
)
}
@ -1523,7 +1251,7 @@ impl Unifier {
match &*ty {
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
TypeEnum::TTuple { ty, is_vararg_ctx } => {
TypeEnum::TTuple { ty } => {
let mut new_ty = Cow::from(ty);
for (i, t) in ty.iter().enumerate() {
if let Some(t1) = self.subst_impl(*t, mapping, cache) {
@ -1531,14 +1259,14 @@ impl Unifier {
}
}
if matches!(new_ty, Cow::Owned(_)) {
Some(self.add_ty(TypeEnum::TTuple {
ty: new_ty.into_owned(),
is_vararg_ctx: *is_vararg_ctx,
}))
Some(self.add_ty(TypeEnum::TTuple { ty: new_ty.into_owned() }))
} else {
None
}
}
TypeEnum::TList { ty } => {
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
}
TypeEnum::TVirtual { ty } => self
.subst_impl(*ty, mapping, cache)
.map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })),
@ -1549,7 +1277,6 @@ impl Unifier {
// This is also used to prevent infinite substitution...
let need_subst = params.values().any(|v| {
let ty = self.unification_table.probe_value(*v);
// TODO(Derppening): #444
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
mapping.contains_key(id)
} else {
@ -1694,55 +1421,20 @@ impl Unifier {
}
}
(TVar { range, .. }, _) => self.check_var_compatibility(b, range).or(Err(())),
(
TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) => {
if *is_vararg1 && *is_vararg2 {
let isect_ty = self.get_intersection(ty1[0], ty2[0])?;
Ok(isect_ty.map(|ty| self.add_ty(TTuple { ty: vec![ty], is_vararg_ctx: true })))
(TTuple { ty: ty1 }, TTuple { ty: ty2 }) if ty1.len() == ty2.len() => {
let ty: Vec<_> = zip(ty1.iter(), ty2.iter())
.map(|(a, b)| self.get_intersection(*a, *b))
.try_collect()?;
if ty.iter().any(Option::is_some) {
Ok(Some(self.add_ty(TTuple {
ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
})))
} else {
let zip_iter: Box<dyn Iterator<Item = (&Type, &Type)>> =
match (*is_vararg1, *is_vararg2) {
(true, _) => Box::new(repeat_n(&ty1[0], ty2.len()).zip(ty2.iter())),
(_, false) => Box::new(ty1.iter().zip(repeat_n(&ty2[0], ty1.len()))),
_ => {
if ty1.len() != ty2.len() {
return Err(());
}
Box::new(ty1.iter().zip(ty2.iter()))
}
};
let ty: Vec<_> =
zip_iter.map(|(a, b)| self.get_intersection(*a, *b)).try_collect()?;
Ok(if ty.iter().any(Option::is_some) {
Some(self.add_ty(TTuple {
ty: zip(ty, ty1.iter()).map(|(a, b)| a.unwrap_or(*b)).collect(),
is_vararg_ctx: false,
}))
} else {
None
})
Ok(None)
}
}
// TODO(Derppening): #444
(
TObj { obj_id: id1, fields, params: params1 },
TObj { obj_id: id2, params: params2, .. },
) if *id1 == PrimDef::List.id() && *id2 == PrimDef::List.id() => {
let tv_id = iter_type_vars(params1).nth(0).unwrap().id;
let ty1 = iter_type_vars(params1).nth(0).unwrap().ty;
let ty2 = iter_type_vars(params2).nth(0).unwrap().ty;
Ok(self.get_intersection(ty1, ty2)?.map(|ty| {
self.add_ty(TObj {
obj_id: *id1,
fields: fields.clone(),
params: into_var_map([TypeVar { id: tv_id, ty }]),
})
}))
(TList { ty: ty1 }, TList { ty: ty2 }) => {
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
}
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))

View File

@ -1,12 +1,10 @@
use std::collections::HashMap;
use super::super::magic_methods::with_fields;
use super::*;
use indoc::indoc;
use itertools::Itertools;
use std::collections::HashMap;
use test_case::test_case;
use super::*;
use crate::typecheck::magic_methods::with_fields;
impl Unifier {
/// Check whether two types are equal.
fn eq(&mut self, a: Type, b: Type) -> bool {
@ -30,14 +28,14 @@ impl Unifier {
TypeEnum::TVar { fields: Some(map1), .. },
TypeEnum::TVar { fields: Some(map2), .. },
) => self.map_eq2(map1, map2),
(
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: false },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: false },
) => {
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => {
ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
}
(TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => self.eq(*ty1, *ty2),
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
self.eq(*ty1, *ty2)
}
(
TypeEnum::TObj { obj_id: id1, params: params1, .. },
TypeEnum::TObj { obj_id: id2, params: params2, .. },
@ -121,15 +119,6 @@ impl TestEnvironment {
params: into_var_map([tvar]),
}),
);
let tvar = unifier.get_dummy_var();
type_mapping.insert(
"list".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([tvar]),
}),
);
TestEnvironment { unifier, type_mapping }
}
@ -144,36 +133,6 @@ impl TestEnvironment {
// for testing only, so we can just panic when the input is malformed
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or(typ.len());
match &typ[..end] {
"list" => {
let mut s = &typ[end..];
assert_eq!(&s[0..1], "[");
let mut ty = Vec::new();
while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
ty.push(result.0);
s = result.1;
}
assert_eq!(ty.len(), 1);
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*self.unifier.get_ty_immutable(self.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
(
self.unifier
.subst(
self.type_mapping["list"],
&into_var_map([TypeVar { id: list_elem_tvar.id, ty: ty[0] }]),
)
.unwrap(),
&s[1..],
)
}
"tuple" => {
let mut s = &typ[end..];
assert_eq!(&s[0..1], "[");
@ -183,7 +142,13 @@ impl TestEnvironment {
ty.push(result.0);
s = result.1;
}
(self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }), &s[1..])
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
}
"list" => {
assert_eq!(&typ[end..=end], "[");
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
assert_eq!(&s[0..1], "]");
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
}
"Record" => {
let mut s = &typ[end..];
@ -309,7 +274,7 @@ fn test_unify(
("v1", "tuple[int]"),
("v2", "list[int]"),
],
(("v1", "v2"), "Incompatible types: 11[0] and tuple[0]")
(("v1", "v2"), "Incompatible types: list[0] and tuple[0]")
; "type mismatch"
)]
#[test_case(2,
@ -333,7 +298,7 @@ fn test_unify(
("v1", "Record[a=float,b=int]"),
("v2", "Foo[v3]"),
],
(("v1", "v2"), "`3[typevar5]::b` field/method does not exist")
(("v1", "v2"), "`3[typevar4]::b` field/method does not exist")
; "record obj merge"
)]
/// Test cases for invalid unifications.
@ -423,14 +388,6 @@ fn test_typevar_range() {
let int_list = env.parse("list[int]", &HashMap::new());
let float_list = env.parse("list[float]", &HashMap::new());
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*env.unifier.get_ty_immutable(env.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
// unification between v and int
// where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
@ -441,7 +398,7 @@ fn test_typevar_range() {
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
assert_eq!(
env.unify(int_list, v),
Err("Expected any one of these types: 0, 2, but got 11[0]".to_string())
Err("Expected any one of these types: 0, 2, but got list[0]".to_string())
);
// unification between v and float
@ -453,11 +410,7 @@ fn test_typevar_range() {
);
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v1_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: v1 }]),
});
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 });
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and int
// where v in (int, list[v1]), v1 in (int, bool)
@ -471,10 +424,9 @@ fn test_typevar_range() {
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and list[float]
// where v in (int, list[v1]), v1 in (int, bool)
println!("float_list: {}, v: {}", env.unifier.stringify(float_list), env.unifier.stringify(v));
assert_eq!(
env.unify(float_list, v),
Err("Expected any one of these types: 0, 11[typevar6], but got 11[1]\n\nNotes:\n typevar6 ∈ {0, 2}".to_string())
Err("Expected any one of these types: 0, list[typevar5], but got list[1]\n\nNotes:\n typevar5 ∈ {0, 2}".to_string())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
@ -489,66 +441,34 @@ fn test_typevar_range() {
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).ty;
env.unifier.unify(a_list, b_list).unwrap();
let float_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: float }]),
});
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float });
env.unifier.unify(a_list, float_list).unwrap();
// previous unifications should not affect a and b
env.unifier.unify(a, int).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
let int_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: int }]),
});
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int });
assert_eq!(
env.unify(a_list, int_list),
Err("Incompatible types: 11[typevar23] and 11[0]\
\n\nNotes:\n typevar23 {1}"
Err("Incompatible types: list[typevar22] and list[0]\
\n\nNotes:\n typevar22 {1}"
.into())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_dummy_var().ty;
let a_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a });
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b });
env.unifier.unify(a_list, b_list).unwrap();
assert_eq!(
env.unify(b, boolean),
@ -562,25 +482,16 @@ fn test_rigid_var() {
let a = env.unifier.get_fresh_rigid_var(None, None).ty;
let b = env.unifier.get_fresh_rigid_var(None, None).ty;
let x = env.unifier.get_dummy_var().ty;
let list_elem_tvar = env.unifier.get_fresh_var(Some("list_elem".into()), None);
let list_a = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let list_x = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: x }]),
});
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a });
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x });
let int = env.parse("int", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new());
assert_eq!(env.unify(a, b), Err("Incompatible types: typevar4 and typevar3".to_string()));
assert_eq!(env.unify(a, b), Err("Incompatible types: typevar3 and typevar2".to_string()));
env.unifier.unify(list_a, list_x).unwrap();
assert_eq!(
env.unify(list_x, list_int),
Err("Incompatible types: 11[typevar3] and 11[0]".to_string())
Err("Incompatible types: list[typevar2] and list[0]".to_string())
);
env.unifier.replace_rigid_var(a, int);
@ -595,25 +506,14 @@ fn test_instantiation() {
let float = env.parse("float", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new());
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*env.unifier.get_ty_immutable(env.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
let obj_map: HashMap<_, _> = [(0usize, "int"), (1, "float"), (2, "bool"), (11, "list")].into();
let obj_map: HashMap<_, _> = [(0usize, "int"), (1, "float"), (2, "bool")].into();
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let list_v = env
.unifier
.subst(env.type_mapping["list"], &into_var_map([TypeVar { id: list_elem_tvar.id, ty: v }]))
.unwrap();
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
let t = env.unifier.get_dummy_var().ty;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2], is_vararg_ctx: false });
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
// t = TypeVar('t')
// v = TypeVar('v', int, bool)
@ -636,7 +536,7 @@ fn test_instantiation() {
tuple[int, list[bool], list[int]]
tuple[int, list[int], float]
tuple[int, list[int], list[int]]
v6"
v5"
}
.split('\n')
.collect_vec();

View File

@ -238,7 +238,7 @@ impl<'a> EH_Frame<'a> {
/// From the [specification](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html):
///
/// > Each CFI record contains a Common Information Entry (CIE) record followed by 1 or more Frame
/// > Description Entry (FDE) records.
/// Description Entry (FDE) records.
pub struct CFI_Record<'a> {
// It refers to the augmentation data that corresponds to 'R' in the augmentation string
fde_pointer_encoding: u8,

View File

@ -2,9 +2,9 @@
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(rust_2024_compatibility)]
#![warn(clippy::pedantic)]
#![allow(
clippy::cast_possible_truncation,
@ -21,12 +21,13 @@
clippy::wildcard_imports
)]
use std::{collections::HashMap, mem, ptr, slice, str};
use byteorder::{ByteOrder, LittleEndian};
use dwarf::*;
use elf::*;
use std::collections::HashMap;
use std::{mem, ptr, slice, str};
extern crate byteorder;
use byteorder::{ByteOrder, LittleEndian};
mod dwarf;
mod elf;

View File

@ -8,15 +8,15 @@ license = "MIT"
edition = "2021"
[build-dependencies]
lalrpop = "0.22"
lalrpop = "0.20"
[dependencies]
nac3ast = { path = "../nac3ast" }
lalrpop-util = "0.22"
lalrpop-util = "0.20"
log = "0.4"
unic-emoji-char = "0.9"
unic-ucd-ident = "0.9"
unicode_names2 = "1.3"
unicode_names2 = "1.2"
phf = { version = "0.11", features = ["macros"] }
ahash = "0.8"

View File

@ -1,10 +1,8 @@
use crate::{
ast::{Ident, Location},
error::*,
token::Tok,
};
use crate::ast::Ident;
use crate::ast::Location;
use crate::error::*;
use crate::token::Tok;
use lalrpop_util::ParseError;
use nac3ast::*;
pub fn make_config_comment(

View File

@ -1,11 +1,12 @@
//! Define internal parse error types
//! The goal is to provide a matching and a safe error API, maksing errors from LALR
use std::error::Error;
use std::fmt;
use lalrpop_util::ParseError as LalrpopError;
use crate::{ast::Location, token::Tok};
use crate::ast::Location;
use crate::token::Tok;
use std::error::Error;
use std::fmt;
/// Represents an error during lexical scanning.
#[derive(Debug, PartialEq)]

View File

@ -1,11 +1,12 @@
use std::{iter, mem, str};
use std::iter;
use std::mem;
use std::str;
use crate::ast::{Constant, ConversionFlag, Expr, ExprKind, Location};
use crate::error::{FStringError, FStringErrorType, ParseError};
use crate::parser::parse_expression;
use self::FStringErrorType::*;
use crate::{
ast::{Constant, ConversionFlag, Expr, ExprKind, Location},
error::{FStringError, FStringErrorType, ParseError},
parser::parse_expression,
};
struct FStringParser<'a> {
chars: iter::Peekable<str::Chars<'a>>,

View File

@ -1,11 +1,8 @@
use ahash::RandomState;
use std::collections::HashSet;
use ahash::RandomState;
use crate::{
ast,
error::{LexicalError, LexicalErrorType},
};
use crate::ast;
use crate::error::{LexicalError, LexicalErrorType};
pub struct ArgumentList {
pub args: Vec<ast::Expr>,

View File

@ -1,16 +1,16 @@
//! This module takes care of lexing python source text.
//!
//! This means source code is translated into separate tokens.
use std::{char, cmp::Ordering, num::IntErrorKind, str::FromStr};
use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start};
pub use super::token::Tok;
use crate::{
ast::{FileName, Location},
error::{LexicalError, LexicalErrorType},
};
use crate::ast::{FileName, Location};
use crate::error::{LexicalError, LexicalErrorType};
use std::char;
use std::cmp::Ordering;
use std::num::IntErrorKind;
use std::str::FromStr;
use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start};
#[derive(Clone, Copy, PartialEq, Debug, Default)]
struct IndentationLevel {

View File

@ -19,9 +19,9 @@
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(rust_2024_compatibility)]
#![warn(clippy::pedantic)]
#![allow(
clippy::enum_glob_use,
@ -49,11 +49,11 @@ lalrpop_mod!(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
unused,
clippy::all,
clippy::pedantic
)]
#[warn(rust_2024_compatibility)]
python
);
pub mod config_comment_helper;

View File

@ -5,16 +5,14 @@
//! parse a whole program, a single statement, or a single
//! expression.
use nac3ast::Location;
use std::iter;
use nac3ast::Location;
use crate::ast::{self, FileName};
use crate::error::ParseError;
use crate::lexer;
pub use crate::mode::Mode;
use crate::{
ast::{self, FileName},
error::ParseError,
lexer, python,
};
use crate::python;
/*
* Parse python code.

View File

@ -1,8 +1,7 @@
//! Different token definitions.
//! Loosely based on token.h from CPython source:
use std::fmt::{self, Write};
use crate::ast;
use std::fmt::{self, Write};
/// Python source code can be tokenized in a sequence of these tokens.
#[derive(Clone, Debug, PartialEq)]

View File

@ -4,13 +4,16 @@ version = "0.1.0"
authors = ["M-Labs"]
edition = "2021"
[features]
no-escape-analysis = ["nac3core/no-escape-analysis"]
[dependencies]
parking_lot = "0.12"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" }
[dependencies.clap]
version = "4.5"
features = ["derive"]
[dependencies.inkwell]
version = "0.4"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]

View File

@ -3,66 +3,23 @@
set -e
if [ -z "$1" ]; then
echo "No argument supplied"
exit 1
echo "Requires at least one argument"
exit 1
fi
declare -a nac3args
while [ $# -gt 1 ]; do
case "$1" in
--help)
echo "Usage: check_demo.sh [--debug] [-i686] -- [NAC3ARGS...] demo"
exit
;;
--debug)
debug=1
;;
-i686)
i686=1
;;
--)
shift
break
;;
*)
echo "Unrecognized argument \"$1\""
exit 1
;;
esac
shift
done
while [ $# -gt 1 ]; do
nac3args+=("$1")
shift
done
demo="$1"
echo "### Checking $demo..."
echo ">>>>>> Running $demo with the Python interpreter"
echo -n "Checking $demo... "
./interpret_demo.py "$demo" > interpreted.log
./run_demo.sh --out run.log "${nac3args[@]}" "$demo"
./run_demo.sh --lli --out run_lli.log "${nac3args[@]}" "$demo"
diff -Nau interpreted.log run.log
diff -Nau interpreted.log run_lli.log
echo "ok"
if [ -n "$i686" ]; then
echo "...... Trying NAC3's 32-bit code generator output"
if [ -n "$debug" ]; then
./run_demo.sh --debug -i686 --out run_32.log -- "${nac3args[@]}" "$demo"
else
./run_demo.sh -i686 --out run_32.log -- "${nac3args[@]}" "$demo"
fi
diff -Nau interpreted.log run_32.log
fi
echo "...... Trying NAC3's 64-bit code generator output"
if [ -n "$debug" ]; then
./run_demo.sh --debug --out run_64.log -- "${nac3args[@]}" "$demo"
else
./run_demo.sh --out run_64.log -- "${nac3args[@]}" "$demo"
fi
diff -Nau interpreted.log run_64.log
echo "...... OK"
rm -f interpreted.log \
run_32.log run_64.log
rm -f interpreted.log run.log run_lli.log

View File

@ -2,11 +2,6 @@
set -e
if [ "$1" == "--help" ]; then
echo "Usage: check_demos.sh [CHECKARGS...] [--] [NAC3ARGS...]"
exit
fi
count=0
for demo in src/*.py; do
./check_demo.sh "$@" "$demo"

View File

@ -6,12 +6,14 @@
#include <stdlib.h>
#include <string.h>
#define usize size_t
double dbl_nan(void) {
return NAN;
return NAN;
}
double dbl_inf(void) {
return INFINITY;
return INFINITY;
}
void output_bool(bool x) {
@ -19,19 +21,19 @@ void output_bool(bool x) {
}
void output_int32(int32_t x) {
printf("%" PRId32 "\n", x);
printf("%"PRId32"\n", x);
}
void output_int64(int64_t x) {
printf("%" PRId64 "\n", x);
printf("%"PRId64"\n", x);
}
void output_uint32(uint32_t x) {
printf("%" PRIu32 "\n", x);
printf("%"PRIu32"\n", x);
}
void output_uint64(uint64_t x) {
printf("%" PRIu64 "\n", x);
printf("%"PRIu64"\n", x);
}
void output_float64(double x) {
@ -42,17 +44,8 @@ void output_float64(double x) {
}
}
void output_range(int32_t range[3]) {
printf("range(");
printf("%d, %d", range[0], range[1]);
if (range[2] != 1) {
printf(", %d", range[2]);
}
puts(")");
}
void output_asciiart(int32_t x) {
static const char* chars = " .,-:;i+hHM$*#@ ";
static const char *chars = " .,-:;i+hHM$*#@ ";
if (x < 0) {
putchar('\n');
} else {
@ -61,15 +54,15 @@ void output_asciiart(int32_t x) {
}
struct cslice {
void* data;
size_t len;
void *data;
usize len;
};
void output_int32_list(struct cslice* slice) {
const int32_t* data = (int32_t*)slice->data;
void output_int32_list(struct cslice *slice) {
const int32_t *data = (int32_t *) slice->data;
putchar('[');
for (size_t i = 0; i < slice->len; ++i) {
for (usize i = 0; i < slice->len; ++i) {
if (i == slice->len - 1) {
printf("%d", data[i]);
} else {
@ -80,23 +73,19 @@ void output_int32_list(struct cslice* slice) {
putchar('\n');
}
void output_str(struct cslice* slice) {
const char* data = (const char*)slice->data;
void output_str(struct cslice *slice) {
const char *data = (const char *) slice->data;
for (size_t i = 0; i < slice->len; ++i) {
for (usize i = 0; i < slice->len; ++i) {
putchar(data[i]);
}
}
void output_strln(struct cslice* slice) {
output_str(slice);
putchar('\n');
}
uint64_t dbg_stack_address(__attribute__((unused)) struct cslice* slice) {
uint64_t dbg_stack_address(__attribute__((unused)) struct cslice *slice) {
int i;
void* ptr = (void*)&i;
return (uintptr_t)ptr;
void *ptr = (void *) &i;
return (uintptr_t) ptr;
}
uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) {
@ -105,26 +94,8 @@ uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t
__builtin_unreachable();
}
// See `struct Exception<'a>` in
// https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs
struct Exception {
uint32_t id;
struct cslice file;
uint32_t line;
uint32_t column;
struct cslice function;
struct cslice message;
int64_t param[3];
};
uint32_t __nac3_raise(struct Exception* e) {
printf("__nac3_raise called. Exception details:\n");
printf(" ID: %" PRIu32 "\n", e->id);
printf(" Location: %*s:%" PRIu32 ":%" PRIu32 "\n", (int)e->file.len, (const char*)e->file.data, e->line,
e->column);
printf(" Function: %*s\n", (int)e->function.len, (const char*)e->function.data);
printf(" Message: \"%*s\"\n", (int)e->message.len, (const char*)e->message.data);
printf(" Params: {0}=%" PRId64 ", {1}=%" PRId64 ", {2}=%" PRId64 "\n", e->param[0], e->param[1], e->param[2]);
uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) {
printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
exit(101);
__builtin_unreachable();
}

View File

@ -6,7 +6,6 @@ import importlib.machinery
import math
import numpy as np
import numpy.typing as npt
import scipy as sp
import pathlib
from numpy import int32, int64, uint32, uint64
@ -108,9 +107,6 @@ def patch(module):
def output_float(x):
print("%f" % x)
def output_strln(x):
print(x, end='')
def dbg_stack_address(_):
return 0
@ -124,8 +120,6 @@ def patch(module):
return output_asciiart
elif name == "output_float64":
return output_float
elif name == "output_str":
return output_strln
elif name in {
"output_bool",
"output_int32",
@ -133,8 +127,7 @@ def patch(module):
"output_int32_list",
"output_uint32",
"output_uint64",
"output_strln",
"output_range",
"output_str",
}:
return print
elif name == "dbg_stack_address":
@ -168,7 +161,7 @@ def patch(module):
module.ceil64 = _ceil
module.np_ceil = np.ceil
# NumPy NDArray factory functions
# NumPy ndarray functions
module.ndarray = NDArray
module.np_ndarray = np.ndarray
module.np_empty = np.empty
@ -184,10 +177,8 @@ def patch(module):
module.np_isinf = np.isinf
module.np_min = np.min
module.np_minimum = np.minimum
module.np_argmin = np.argmin
module.np_max = np.max
module.np_maximum = np.maximum
module.np_argmax = np.argmax
module.np_sin = np.sin
module.np_cos = np.cos
module.np_exp = np.exp
@ -218,10 +209,8 @@ def patch(module):
module.np_ldexp = np.ldexp
module.np_hypot = np.hypot
module.np_nextafter = np.nextafter
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# SciPy Math functions
# SciPy Math Functions
module.sp_spec_erf = special.erf
module.sp_spec_erfc = special.erfc
module.sp_spec_gamma = special.gamma
@ -229,19 +218,16 @@ def patch(module):
module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1
# Linalg functions
module.np_dot = np.dot
module.np_linalg_cholesky = np.linalg.cholesky
module.np_linalg_qr = np.linalg.qr
module.np_linalg_svd = np.linalg.svd
module.np_linalg_inv = np.linalg.inv
module.np_linalg_pinv = np.linalg.pinv
module.np_linalg_matrix_power = np.linalg.matrix_power
module.np_linalg_det = np.linalg.det
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
module.sp_linalg_schur = sp.linalg.schur
module.sp_linalg_hessenberg = lambda x: sp.linalg.hessenberg(x, True)
# NumPy NDArray Functions
module.np_ndarray = np.ndarray
module.np_empty = np.empty
module.np_zeros = np.zeros
module.np_ones = np.ones
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
module.np_any = np.any
module.np_all = np.all
def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename)

View File

@ -1,114 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "linalg"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"

View File

@ -1,13 +0,0 @@
[package]
name = "linalg"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["staticlib"]
[dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0"
[workspace]

View File

@ -1,406 +0,0 @@
// Uses `nalgebra` crate to invoke `np_linalg` and `sp_linalg` functions
// When converting between `nalgebra::Matrix` and `NDArray` following considerations are necessary
//
// * Both `nalgebra::Matrix` and `NDArray` require their content to be stored in row-major order
// * `NDArray` data pointer can be directly read and converted to `nalgebra::Matrix` (row and column number must be known)
// * `nalgebra::Matrix::as_slice` returns the content of matrix in column-major order and initial data needs to be transposed before storing it in `NDArray` data pointer
use core::slice;
use nalgebra::DMatrix;
fn report_error(
error_name: &str,
fn_name: &str,
file_name: &str,
line_num: u32,
col_num: u32,
err_msg: &str,
) -> ! {
panic!(
"Exception {} from {} in {}:{}:{}, message: {}",
error_name, fn_name, file_name, line_num, col_num, err_msg
);
}
pub struct InputMatrix {
pub ndims: usize,
pub dims: *const usize,
pub data: *mut f64,
}
impl InputMatrix {
fn get_dims(&mut self) -> Vec<usize> {
let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) };
dims.to_vec()
}
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix1.cholesky();
match result {
Some(res) => {
out_slice.copy_from_slice(res.unpack().transpose().as_slice());
}
None => {
report_error(
"LinAlgError",
"np_linalg_cholesky",
file!(),
line!(),
column!(),
"Matrix is not positive definite",
);
}
};
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_qr(
mat1: *mut InputMatrix,
out_q: *mut InputMatrix,
out_r: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
let out_r = out_r.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outq_dim = (*out_q).get_dims();
let outr_dim = (*out_r).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, outq_dim[0] * outq_dim[1]) };
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r.data, outr_dim[0] * outr_dim[1]) };
// Refer to https://github.com/dimforge/nalgebra/issues/735
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let res = matrix1.qr();
let (q, r) = res.unpack();
// Uses different algo need to match numpy
out_q_slice.copy_from_slice(q.transpose().as_slice());
out_r_slice.copy_from_slice(r.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_svd(
mat1: *mut InputMatrix,
outu: *mut InputMatrix,
outs: *mut InputMatrix,
outvh: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let outu = outu.as_mut().unwrap();
let outs = outs.as_mut().unwrap();
let outvh = outvh.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outu_dim = (*outu).get_dims();
let outs_dim = (*outs).get_dims();
let outvh_dim = (*outvh).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) };
let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) };
let out_vh_slice =
unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix.svd(true, true);
out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice());
out_s_slice.copy_from_slice(result.singular_values.as_slice());
out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
let inv = matrix.try_inverse().unwrap();
out_slice.copy_from_slice(inv.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let svd = matrix.svd(true, true);
let inv = svd.pseudo_inverse(1e-15);
match inv {
Ok(m) => {
out_slice.copy_from_slice(m.transpose().as_slice());
}
Err(err_msg) => {
report_error("LinAlgError", "np_linalg_pinv", file!(), line!(), column!(), err_msg);
}
}
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matrix_power(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D", mat1.ndims);
report_error("ValueError", "np_linalg_matrix_power", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let power = unsafe { slice::from_raw_parts_mut(mat2.data, 1) };
let power = power[0];
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let abs_pow = power.abs();
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let mut result = matrix1.pow(abs_pow as u32);
if power < 0.0 {
if !result.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
result = result.try_inverse().unwrap();
}
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_det(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "np_linalg_det", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, 1) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_square() {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
out_slice[0] = matrix.determinant();
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_lu(
mat1: *mut InputMatrix,
out_l: *mut InputMatrix,
out_u: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_l = out_l.as_mut().unwrap();
let out_u = out_u.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outl_dim = (*out_l).get_dims();
let outu_dim = (*out_u).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l.data, outl_dim[0] * outl_dim[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u.data, outu_dim[0] * outu_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (_, l, u) = matrix.lu().unpack();
out_l_slice.copy_from_slice(l.transpose().as_slice());
out_u_slice.copy_from_slice(u.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_schur(
mat1: *mut InputMatrix,
out_t: *mut InputMatrix,
out_z: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_t = out_t.as_mut().unwrap();
let out_z = out_z.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let out_t_dim = (*out_t).get_dims();
let out_z_dim = (*out_z).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t.data, out_t_dim[0] * out_t_dim[1]) };
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z.data, out_z_dim[0] * out_z_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (z, t) = matrix.schur().unpack();
out_t_slice.copy_from_slice(t.transpose().as_slice());
out_z_slice.copy_from_slice(z.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_hessenberg(
mat1: *mut InputMatrix,
out_h: *mut InputMatrix,
out_q: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_h = out_h.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}D input", mat1.ndims);
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {} != {}", dim1[0], dim1[1]);
report_error("LinAlgError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let out_h_dim = (*out_h).get_dims();
let out_q_dim = (*out_q).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h.data, out_h_dim[0] * out_h_dim[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, out_q_dim[0] * out_q_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (q, h) = matrix.hessenberg().unpack();
out_h_slice.copy_from_slice(h.transpose().as_slice());
out_q_slice.copy_from_slice(q.transpose().as_slice());
}

View File

@ -2,9 +2,6 @@
set -e
: "${DEMO_LINALG_STUB:=linalg/target/release/liblinalg.a}"
: "${DEMO_LINALG_STUB32:=linalg/target/i686-unknown-linux-gnu/release/liblinalg.a}"
if [ -z "$1" ]; then
echo "No argument supplied"
exit 1
@ -14,26 +11,25 @@ declare -a nac3args
while [ $# -ge 1 ]; do
case "$1" in
--help)
echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--debug] [-i686] -- [NAC3ARGS...] demo"
echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--lli] [--debug] -- [NAC3ARGS...]"
exit
;;
--out)
shift
outfile="$1"
;;
--lli)
use_lli=1
;;
--debug)
debug=1
;;
-i686)
i686=1
;;
--)
shift
break
;;
*)
echo "Unrecognized argument \"$1\""
exit 1
break
;;
esac
shift
@ -54,19 +50,29 @@ else
fi
rm -f ./*.o ./*.bc demo
if [ -z "$i686" ]; then
if [ -z "$use_lli" ]; then
$nac3standalone "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -o demo module.o demo.o $DEMO_LINALG_STUB -lm -Wl,--no-warn-search-mismatch
else
$nac3standalone --triple i686-unknown-linux-gnu --target-features +sse2 "${nac3args[@]}"
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
clang -m32 -o demo module.o demo.o $DEMO_LINALG_STUB32 -lm -Wl,--no-warn-search-mismatch
fi
if [ -z "$outfile" ]; then
./demo
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -lm -o demo module.o demo.o
if [ -z "$outfile" ]; then
./demo
else
./demo > "$outfile"
fi
else
./demo > "$outfile"
$nac3standalone --emit-llvm "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -emit-llvm -o demo.bc demo.c
shopt -s nullglob
llvm-link -o nac3out.bc module*.bc main.bc
shopt -u nullglob
if [ -z "$outfile" ]; then
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc
else
lli --extra-module demo.bc --extra-module irrt.bc nac3out.bc > "$outfile"
fi
fi

View File

@ -1,76 +0,0 @@
@extern
def output_int32(x: int32):
...
@extern
def output_bool(x: bool):
...
def example1():
x, *ys, z = (1, 2, 3, 4, 5)
output_int32(x)
output_int32(len(ys))
output_int32(ys[0])
output_int32(ys[1])
output_int32(ys[2])
output_int32(z)
def example2():
x, y, *zs = (1, 2, 3, 4, 5)
output_int32(x)
output_int32(y)
output_int32(len(zs))
output_int32(zs[0])
output_int32(zs[1])
output_int32(zs[2])
def example3():
*xs, y, z = (1, 2, 3, 4, 5)
output_int32(len(xs))
output_int32(xs[0])
output_int32(xs[1])
output_int32(xs[2])
output_int32(y)
output_int32(z)
def example4():
*xs, y, z = (4, 5)
output_int32(len(xs))
output_int32(y)
output_int32(z)
def example5():
# Example from: https://docs.python.org/3/reference/simple_stmts.html#assignment-statements
x = [0, 1]
i = 0
i, x[i] = 1, 2 # i is updated, then x[i] is updated
output_int32(i)
output_int32(x[0])
output_int32(x[1])
class A:
value: int32
def __init__(self):
self.value = 1000
def example6():
ws = [88, 7, 8]
a = A()
x, [y, *ys, a.value], ws[0], (ws[0],) = 1, (2, False, 4, 5), 99, (6,)
output_int32(x)
output_int32(y)
output_bool(ys[0])
output_int32(ys[1])
output_int32(a.value)
output_int32(ws[0])
output_int32(ws[1])
output_int32(ws[2])
def run() -> int32:
example1()
example2()
example3()
example4()
example5()
example6()
return 0

View File

@ -7,7 +7,7 @@ def output_int64(x: int64):
...
@extern
def output_strln(x: str):
def output_str(x: str):
...
@ -33,7 +33,7 @@ class A:
class Initless:
def foo(self):
output_strln("hello")
output_str("hello")
def run() -> int32:
a = A(10)

View File

@ -22,10 +22,6 @@ def output_uint64(x: uint64):
def output_float64(x: float):
...
@extern
def output_range(x: range):
...
@extern
def output_int32_list(x: list[int32]):
...
@ -38,10 +34,6 @@ def output_asciiart(x: int32):
def output_str(x: str):
...
@extern
def output_strln(x: str):
...
def test_output_bool():
output_bool(True)
output_bool(False)
@ -67,15 +59,6 @@ def test_output_float64():
output_float64(16.25)
output_float64(-16.25)
def test_output_range():
r = range(1, 100, 5)
output_int32(r.start)
output_int32(r.stop)
output_int32(r.step)
output_range(range(10))
output_range(range(1, 10))
output_range(range(1, 10, 2))
def test_output_asciiart():
for i in range(17):
output_asciiart(i)
@ -85,8 +68,7 @@ def test_output_int32_list():
output_int32_list([0, 1, 3, 5, 10])
def test_output_str_family():
output_str("hello")
output_strln(" world")
output_str("hello world")
def run() -> int32:
test_output_bool()
@ -95,7 +77,6 @@ def run() -> int32:
test_output_uint32()
test_output_uint64()
test_output_float64()
test_output_range()
test_output_asciiart()
test_output_int32_list()
test_output_str_family()

View File

@ -1,31 +0,0 @@
@extern
def output_int32(x: int32):
...
@extern
def output_int64(x: int64):
...
X: int32 = 0
Y = int64(1)
def f():
global X, Y
X = 1
Y = int64(2)
def run() -> int32:
global X, Y
output_int32(X)
output_int64(Y)
f()
output_int32(X)
output_int64(Y)
X = 0
Y = int64(0)
output_int32(X)
output_int64(Y)
return 0

View File

@ -10,58 +10,23 @@ class A:
def __init__(self, a: int32):
self.a = a
def output_all_fields(self):
output_int32(self.a)
def f1(self):
self.f2()
def set_a(self, a: int32):
self.a = a
def f2(self):
output_int32(self.a)
class B(A):
b: int32
def __init__(self, b: int32):
A.__init__(self, b + 1)
self.set_b(b)
def output_parent_fields(self):
A.output_all_fields(self)
def output_all_fields(self):
A.output_all_fields(self)
output_int32(self.b)
def set_b(self, b: int32):
self.a = b + 1
self.b = b
class C(B):
c: int32
def __init__(self, c: int32):
B.__init__(self, c + 1)
self.c = c
def output_parent_fields(self):
B.output_all_fields(self)
def output_all_fields(self):
B.output_all_fields(self)
output_int32(self.c)
def set_c(self, c: int32):
self.c = c
def run() -> int32:
ccc = C(10)
ccc.output_all_fields()
ccc.set_a(1)
ccc.set_b(2)
ccc.set_c(3)
ccc.output_all_fields()
bbb = B(10)
bbb.set_a(9)
bbb.set_b(8)
bbb.output_all_fields()
ccc.output_all_fields()
aaa = A(5)
bbb = B(2)
aaa.f1()
bbb.f1()
return 0

View File

@ -1,7 +1,3 @@
@extern
def output_bool(x: bool):
...
@extern
def output_int32_list(x: list[int32]):
...
@ -34,32 +30,6 @@ def run() -> int32:
get_list_slice()
list_slice_assignment()
output_int32_list([1, 2, 3] + [4, 5, 6])
output_int32_list([1, 2, 3] * 3)
output_bool([] == [])
output_bool([0] == [])
output_bool([0] == [0])
output_bool([0, 1] == [0])
output_bool([0, 1] == [0, 1])
output_bool([] != [])
output_bool([0] != [])
output_bool([0] != [0])
output_bool([0] != [0, 1])
output_bool([0, 1] != [0, 1])
output_bool([] == [] == [])
output_bool([0] == [0] == [0])
output_bool([0, 1] == [0] == [0, 1])
output_bool([0, 1] == [0, 1] == [0])
output_bool([0] == [0, 1] == [0, 1])
output_bool([0, 1] == [0, 1] == [0, 1])
output_bool([] != [] != [])
output_bool([0] != [0] != [0])
output_bool([0, 1] != [0] != [0, 1])
output_bool([0, 1] != [0, 1] != [0])
output_bool([0] != [0, 1] != [0, 1])
output_bool([0, 1] != [0, 1] != [0, 1])
return 0
def get_list_slice():

View File

@ -23,12 +23,11 @@ def run() -> int32:
output_int32(x)
output_str(" * ")
output_float64(n / x)
output_str("\n")
except: # Assume this is intended to catch x == 0
break
else:
# loop fell through without finding a factor
output_int32(n)
output_str(" is a prime number\n")
output_str(" is a prime number")
return 0

View File

@ -37,7 +37,7 @@ def test_round64():
output_int64(round64(x))
def test_np_round():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan(), 0.0, -0.0, 1.6, 1.4, -1.4, -1.6]:
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(np_round(x))
def test_np_isnan():

View File

@ -71,65 +71,28 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
pass
def test_ndarray_ctor():
n: ndarray[float, Literal[1]] = np_ndarray([1])
consume_ndarray_1(n)
def test_ndarray_empty():
n1: ndarray[float, 1] = np_empty([1])
consume_ndarray_1(n1)
n2: ndarray[float, 1] = np_empty(10)
consume_ndarray_1(n2)
n3: ndarray[float, 1] = np_empty((2,))
consume_ndarray_1(n3)
n4: ndarray[float, 2] = np_empty((4, 4))
consume_ndarray_2(n4)
dim4 = (5, 2)
n5: ndarray[float, 2] = np_empty(dim4)
consume_ndarray_2(n5)
n: ndarray[float, 1] = np_empty([1])
consume_ndarray_1(n)
def test_ndarray_zeros():
n1: ndarray[float, 1] = np_zeros([1])
output_ndarray_float_1(n1)
k = 3 + int32(n1[0]) # to test variable shape inputs
n2: ndarray[float, 1] = np_zeros(k * k)
output_ndarray_float_1(n2)
n3: ndarray[float, 1] = np_zeros((k * 2,))
output_ndarray_float_1(n3)
dim4 = (3, 2 * k)
n4: ndarray[float, 2] = np_zeros(dim4)
output_ndarray_float_2(n4)
n: ndarray[float, 1] = np_zeros([1])
output_ndarray_float_1(n)
def test_ndarray_ones():
n: ndarray[float, 1] = np_ones([1])
output_ndarray_float_1(n)
dim = (1,)
n_tup: ndarray[float, 1] = np_ones(dim)
output_ndarray_float_1(n_tup)
def test_ndarray_full():
n_float: ndarray[float, 1] = np_full([1], 2.0)
output_ndarray_float_1(n_float)
n_i32: ndarray[int32, 1] = np_full([1], 2)
output_ndarray_int32_1(n_i32)
dim = (1,)
n_float_tup: ndarray[float, 1] = np_full(dim, 2.0)
output_ndarray_float_1(n_float_tup)
n_i32_tup: ndarray[int32, 1] = np_full(dim, 2)
output_ndarray_int32_1(n_i32_tup)
def test_ndarray_eye():
n: ndarray[float, 2] = np_eye(2)
output_ndarray_float_2(n)
@ -877,13 +840,6 @@ def test_ndarray_minimum_broadcast_rhs_scalar():
output_ndarray_float_2(min_x_zeros)
output_ndarray_float_2(min_x_ones)
def test_ndarray_argmin():
x = np_array([[1., 2.], [3., 4.]])
y = np_argmin(x)
output_ndarray_float_2(x)
output_int64(y)
def test_ndarray_max():
x = np_identity(2)
y = np_max(x)
@ -927,13 +883,6 @@ def test_ndarray_maximum_broadcast_rhs_scalar():
output_ndarray_float_2(max_x_zeros)
output_ndarray_float_2(max_x_ones)
def test_ndarray_argmax():
x = np_array([[1., 2.], [3., 4.]])
y = np_argmax(x)
output_ndarray_float_2(x)
output_int64(y)
def test_ndarray_abs():
x = np_identity(2)
y = abs(x)
@ -1439,141 +1388,47 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_transpose():
x: ndarray[float, 2] = np_array([[1., 2., 3.], [4., 5., 6.]])
y = np_transpose(x)
z = np_transpose(y)
def test_ndarray_any():
x1 = np_identity(5)
y1 = np_any(x1)
output_ndarray_float_2(x1)
output_bool(y1)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
x2 = np_identity(1)
y2 = np_any(x2)
output_ndarray_float_2(x2)
output_bool(y2)
def test_ndarray_reshape():
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x = np_reshape(w, (1, 2, 1, -1))
y = np_reshape(x, [2, -1])
z = np_reshape(y, 10)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_any(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x1: ndarray[int32, 1] = np_array([1, 2, 3, 4])
x2: ndarray[int32, 2] = np_reshape(x1, (2, 2))
x4 = np_zeros([3, 5])
y4 = np_any(x4)
output_ndarray_float_2(x4)
output_bool(y4)
output_ndarray_float_1(w)
output_ndarray_float_2(y)
output_ndarray_float_1(z)
def test_ndarray_all():
x1 = np_identity(5)
y1 = np_all(x1)
output_ndarray_float_2(x1)
output_bool(y1)
def test_ndarray_dot():
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
z1 = np_dot(x1, y1)
x2 = np_identity(1)
y2 = np_all(x2)
output_ndarray_float_2(x2)
output_bool(y2)
x2: ndarray[int32, 1] = np_array([5, 1, 4, 2])
y2: ndarray[int32, 1] = np_array([5, 1, 6, 6])
z2 = np_dot(x2, y2)
x3: ndarray[bool, 1] = np_array([True, True, True, True])
y3: ndarray[bool, 1] = np_array([True, True, True, True])
z3 = np_dot(x3, y3)
z4 = np_dot(2, 3)
z5 = np_dot(2., 3.)
z6 = np_dot(True, False)
output_float64(z1)
output_int32(z2)
output_bool(z3)
output_int32(z4)
output_float64(z5)
output_bool(z6)
def test_ndarray_cholesky():
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
y = np_linalg_cholesky(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_qr():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y, z = np_linalg_qr(x)
output_ndarray_float_2(x)
# QR Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = y @ z
output_ndarray_float_2(a)
def test_ndarray_linalg_inv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_inv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pinv():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
y = np_linalg_pinv(x)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_matrix_power():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_matrix_power(x, -9)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_det():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
y = np_linalg_det(x)
output_ndarray_float_2(x)
output_float64(y)
def test_ndarray_schur():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
t, z = sp_linalg_schur(x)
output_ndarray_float_2(x)
# Schur Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = (z @ t) @ np_linalg_inv(z)
output_ndarray_float_2(a)
def test_ndarray_hessenberg():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 5.0, 8.5]])
h, q = sp_linalg_hessenberg(x)
output_ndarray_float_2(x)
# Hessenberg Factorization is not unique and gives different results in scipy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = (q @ h) @ np_linalg_inv(q)
output_ndarray_float_2(a)
def test_ndarray_lu():
x: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5]])
l, u = sp_linalg_lu(x)
output_ndarray_float_2(x)
output_ndarray_float_2(l)
output_ndarray_float_2(u)
def test_ndarray_svd():
w: ndarray[float, 2] = np_array([[-5.0, -1.0, 2.0], [-1.0, 4.0, 7.5], [-1.0, 8.0, -8.5]])
x, y, z = np_linalg_svd(w)
output_ndarray_float_2(w)
# SVD Factorization is not unique and gives different results in numpy and nalgebra
# Reverting the decomposition to compare the initial arrays
a = x @ z
output_ndarray_float_2(a)
output_ndarray_float_1(y)
x3 = np_array([[1.0, 2.0], [3.0, 4.0]])
y3 = np_all(x3)
output_ndarray_float_2(x3)
output_bool(y3)
x4 = np_zeros([3, 5])
y4 = np_all(x4)
output_ndarray_float_2(x4)
output_bool(y4)
def run() -> int32:
test_ndarray_ctor()
@ -1679,19 +1534,16 @@ def run() -> int32:
test_ndarray_round()
test_ndarray_floor()
test_ndarray_ceil()
test_ndarray_min()
test_ndarray_minimum()
test_ndarray_minimum_broadcast()
test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_argmin()
test_ndarray_max()
test_ndarray_maximum()
test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_argmax()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()
@ -1754,18 +1606,8 @@ def run() -> int32:
test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_transpose()
test_ndarray_reshape()
test_ndarray_dot()
test_ndarray_cholesky()
test_ndarray_qr()
test_ndarray_svd()
test_ndarray_linalg_inv()
test_ndarray_pinv()
test_ndarray_matrix_power()
test_ndarray_det()
test_ndarray_lu()
test_ndarray_schur()
test_ndarray_hessenberg()
test_ndarray_any()
test_ndarray_all()
return 0

View File

@ -1,30 +0,0 @@
@extern
def output_bool(x: bool):
...
def str_eq():
output_bool("" == "")
output_bool("a" == "")
output_bool("a" == "b")
output_bool("b" == "a")
output_bool("a" == "a")
output_bool("test string" == "test string")
output_bool("test string1" == "test string2")
def str_ne():
output_bool("" != "")
output_bool("a" != "")
output_bool("a" != "b")
output_bool("b" != "a")
output_bool("a" != "a")
output_bool("test string" != "test string")
output_bool("test string1" != "test string2")
def run() -> int32:
str_eq()
str_ne()
return 0

View File

@ -1,7 +1,3 @@
@extern
def output_bool(b: bool):
...
@extern
def output_int32_list(x: list[int32]):
...
@ -17,41 +13,6 @@ class A:
self.a = a
self.b = b
def test_tuple_eq():
# 0-len
output_bool(() == ())
# 1-len
output_bool((1,) == ())
output_bool(() == (1,))
output_bool((1,) == (1,))
output_bool((1,) == (2,))
# # 2-len
output_bool((1, 2) == ())
output_bool(() == (1, 2))
output_bool((1,) == (1, 2))
output_bool((1, 2) == (1,))
output_bool((2, 2) == (1, 2))
output_bool((1, 2) == (2, 2))
def test_tuple_ne():
# 0-len
output_bool(() != ())
# 1-len
output_bool((1,) != ())
output_bool(() != (1,))
output_bool((1,) != (1,))
output_bool((1,) != (2,))
# 2-len
output_bool((1, 2) != ())
output_bool(() != (1, 2))
output_bool((1,) != (1, 2))
output_bool((1, 2) != (1,))
output_bool((2, 2) != (1, 2))
output_bool((1, 2) != (2, 2))
def run() -> int32:
data = [0, 1, 2, 3]
@ -65,14 +26,4 @@ def run() -> int32:
output_int32(tl[0][1])
output_int32(tl[1])
output_int32(len(()))
output_int32(len((1,)))
output_int32(len((1, 2)))
output_int32(len((1, 2, 3)))
output_int32(len((1, 2, 3, 4)))
output_int32(len((1, 2, 3, 4, 5)))
test_tuple_eq()
test_tuple_ne()
return 0

View File

@ -1,11 +0,0 @@
def f(*args: int32):
pass
def run() -> int32:
f()
f(1)
f(1, 2)
f(1, 2, 3)
return 0

View File

@ -1,14 +1,5 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use parking_lot::{Mutex, RwLock};
use nac3core::{
codegen::{CodeGenContext, CodeGenerator},
inkwell::{module::Linkage, values::BasicValue},
nac3parser::ast::{self, StrRef},
codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, SymbolValue, ValueEnum},
toplevel::{DefinitionId, TopLevelDef},
typecheck::{
@ -16,10 +7,15 @@ use nac3core::{
typedef::{Type, Unifier},
},
};
use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::{collections::HashMap, sync::Arc};
pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<StrRef, Type>>,
pub id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
pub class_names: Mutex<HashMap<StrRef, Type>>,
pub module_globals: Mutex<HashMap<StrRef, SymbolValue>>,
pub str_store: Mutex<HashMap<String, i32>>,
}
@ -50,51 +46,20 @@ impl SymbolResolver for Resolver {
fn get_symbol_type(
&self,
unifier: &mut Unifier,
_: &mut Unifier,
_: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore,
_: &PrimitiveStore,
str: StrRef,
) -> Result<Type, String> {
self.0
.id_to_type
.lock()
.get(&str)
.copied()
.or_else(|| {
self.0
.module_globals
.lock()
.get(&str)
.cloned()
.map(|v| v.get_type(primitives, unifier))
})
.ok_or(format!("cannot get type of {str}"))
self.0.id_to_type.lock().get(&str).copied().ok_or(format!("cannot get type of {str}"))
}
fn get_symbol_value<'ctx>(
&self,
str: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
_: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>> {
self.0.module_globals.lock().get(&str).cloned().map(|v| {
ctx.module
.get_global(&str.to_string())
.unwrap_or_else(|| {
let ty = v.get_type(&ctx.primitives, &mut ctx.unifier);
let init_val = ctx.gen_symbol_val(generator, &v, ty);
let llvm_ty = init_val.get_type();
let global = ctx.module.add_global(llvm_ty, None, &str.to_string());
global.set_linkage(Linkage::LinkOnceAny);
global.set_initializer(&init_val);
global
})
.as_basic_value_enum()
.into()
})
unimplemented!()
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {

View File

@ -2,36 +2,27 @@
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(rust_2024_compatibility)]
#![warn(clippy::pedantic)]
#![allow(clippy::too_many_lines, clippy::wildcard_imports)]
use std::{
collections::{HashMap, HashSet},
fs,
num::NonZeroUsize,
path::Path,
sync::Arc,
};
use clap::Parser;
use inkwell::{
memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
OptimizationLevel,
};
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use nac3core::{
codegen::{
concrete_type::ConcreteTypeStore, irrt::load_irrt, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
inkwell::{
memory_buffer::MemoryBuffer, module::Linkage, passes::PassBuilderOptions,
support::is_multithreaded, targets::*, OptimizationLevel,
},
nac3parser::{
ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
parser,
},
symbol_resolver::SymbolResolver,
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
@ -44,10 +35,13 @@ use nac3core::{
typedef::{FunSignature, Type, Unifier, VarMap},
},
};
use basic_symbol_resolver::*;
use nac3parser::{
ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
parser,
};
mod basic_symbol_resolver;
use basic_symbol_resolver::*;
/// Command-line argument parser definition.
#[derive(Parser)]
@ -119,9 +113,7 @@ fn handle_typevar_definition(
x,
HashMap::new(),
)?;
get_type_from_type_annotation_kinds(
def_list, unifier, primitives, &ty, &mut None,
)
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)
})
.collect::<Result<Vec<_>, _>>()?;
let loc = func.location;
@ -160,7 +152,7 @@ fn handle_typevar_definition(
HashMap::new(),
)?;
let constraint =
get_type_from_type_annotation_kinds(def_list, unifier, primitives, &ty, &mut None)?;
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?;
let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)
@ -175,49 +167,46 @@ fn handle_typevar_definition(
fn handle_assignment_pattern(
targets: &[Expr],
value: &Expr,
resolver: Arc<dyn SymbolResolver + Send + Sync>,
resolver: &(dyn SymbolResolver + Send + Sync),
internal_resolver: &ResolverInternal,
composer: &mut TopLevelComposer,
def_list: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives: &PrimitiveStore,
) -> Result<(), String> {
if targets.len() == 1 {
let target = &targets[0];
match &target.node {
match &targets[0].node {
ExprKind::Name { id, .. } => {
let def_list = composer.extract_def_list();
let unifier = &mut composer.unifier;
let primitives = &composer.primitives_ty;
if let Ok(var) =
handle_typevar_definition(value, &*resolver, &def_list, unifier, primitives)
handle_typevar_definition(value, resolver, def_list, unifier, primitives)
{
internal_resolver.add_id_type(*id, var);
Ok(())
} else if let Ok(val) = parse_parameter_default_value(value, &*resolver) {
} else if let Ok(val) = parse_parameter_default_value(value, resolver) {
internal_resolver.add_module_global(*id, val);
let (name, def_id, _) = composer
.register_top_level_var(
*id,
None,
Some(resolver.clone()),
"__main__",
target.location,
)
.unwrap();
internal_resolver.add_id_def(name, def_id);
Ok(())
} else {
Err(format!("fails to evaluate this expression `{:?}` as a constant or generic parameter at {}",
target.node,
target.location,
targets[0].node,
targets[0].location,
))
}
}
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
handle_assignment_pattern(elts, value, resolver, internal_resolver, composer)?;
handle_assignment_pattern(
elts,
value,
resolver,
internal_resolver,
def_list,
unifier,
primitives,
)?;
Ok(())
}
_ => Err(format!("assignment to {target:?} is not supported at {}", target.location)),
_ => Err(format!(
"assignment to {:?} is not supported at {}",
targets[0], targets[0].location
)),
}
} else {
match &value.node {
@ -227,9 +216,11 @@ fn handle_assignment_pattern(
handle_assignment_pattern(
std::slice::from_ref(tar),
val,
resolver.clone(),
resolver,
internal_resolver,
composer,
def_list,
unifier,
primitives,
)?;
}
Ok(())
@ -247,40 +238,9 @@ fn handle_assignment_pattern(
}
}
fn handle_global_var(
target: &Expr,
value: Option<&Expr>,
resolver: &Arc<dyn SymbolResolver + Send + Sync>,
internal_resolver: &ResolverInternal,
composer: &mut TopLevelComposer,
) -> Result<(), String> {
let ExprKind::Name { id, .. } = target.node else {
return Err(format!(
"global variable declaration must be an identifier (at {})",
target.location,
));
};
let Some(value) = value else {
return Err(format!("global variable `{id}` must be initialized in its definition"));
};
if let Ok(val) = parse_parameter_default_value(value, &**resolver) {
internal_resolver.add_module_global(id, val);
let (name, def_id, _) = composer
.register_top_level_var(id, None, Some(resolver.clone()), "__main__", target.location)
.unwrap();
internal_resolver.add_id_def(name, def_id);
Ok(())
} else {
Err(format!(
"failed to evaluate this expression `{:?}` as a constant at {}",
target.node, target.location,
))
}
}
fn main() {
const SIZE_T: u32 = usize::BITS;
let cli = CommandLineArgs::parse();
let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
cli;
@ -313,36 +273,22 @@ fn main() {
_ => OptimizationLevel::Aggressive,
};
let target_machine_options = CodeGenTargetMachineOptions {
triple,
cpu: mcpu,
features: target_features,
reloc_mode: RelocMode::PIC,
..host_target_machine
};
let target_machine = target_machine_options
.create_target_machine(opt_level)
.expect("couldn't create target machine");
let context = nac3core::inkwell::context::Context::create();
let size_t =
context.ptr_sized_int_type(&target_machine.get_target_data(), None).get_bit_width();
let program = match fs::read_to_string(file_name.clone()) {
Ok(program) => program,
Err(err) => {
panic!("Cannot open input file: {err}");
println!("Cannot open input file: {err}");
return;
}
};
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(size_t).0;
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(SIZE_T).0;
let (mut composer, builtins_def, builtins_ty) =
TopLevelComposer::new(vec![], vec![], ComposerConfig::default(), size_t);
TopLevelComposer::new(vec![], ComposerConfig::default(), SIZE_T);
let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
id_to_type: builtins_ty.into(),
id_to_def: builtins_def.into(),
class_names: Mutex::default(),
module_globals: Mutex::default(),
str_store: Mutex::default(),
}
@ -350,41 +296,27 @@ fn main() {
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
// Process IRRT
let irrt = load_irrt(&context, resolver.as_ref());
if emit_llvm {
irrt.write_bitcode_to_path(Path::new("irrt.bc"));
}
// Process the Python script
let parser_result = parser::parse_program(&program, file_name.into()).unwrap();
for stmt in parser_result {
match &stmt.node {
StmtKind::Assign { targets, value, .. } => {
let def_list = composer.extract_def_list();
let unifier = &mut composer.unifier;
let primitives = &composer.primitives_ty;
if let Err(err) = handle_assignment_pattern(
targets,
value,
resolver.clone(),
resolver.as_ref(),
internal_resolver.as_ref(),
&mut composer,
&def_list,
unifier,
primitives,
) {
panic!("{err}");
eprintln!("{err}");
return;
}
}
StmtKind::AnnAssign { target, value, .. } => {
if let Err(err) = handle_global_var(
target,
value.as_ref().map(Box::as_ref),
&resolver,
internal_resolver.as_ref(),
&mut composer,
) {
panic!("{err}");
}
}
// allow (and ignore) "from __future__ import annotations"
StmtKind::ImportFrom { module, names, .. }
if module == &Some("__future__".into())
@ -408,19 +340,7 @@ fn main() {
let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache);
let signature = store.add_cty(signature);
if let Err(errors) = composer.start_analysis(true) {
let error_count = errors.len();
eprintln!("{error_count} error(s) occurred during top level analysis.");
for (error_i, error) in errors.iter().enumerate() {
let error_num = error_i + 1;
eprintln!("=========== ERROR {error_num}/{error_count} ============");
eprintln!("{error}");
}
eprintln!("==================================");
panic!("top level analysis failed");
}
composer.start_analysis(true).unwrap();
let top_level = Arc::new(composer.make_top_level_context());
@ -439,7 +359,16 @@ fn main() {
instance_to_stmt[""].clone()
};
let llvm_options = CodeGenLLVMOptions { opt_level, target: target_machine_options };
let llvm_options = CodeGenLLVMOptions {
opt_level,
target: CodeGenTargetMachineOptions {
triple,
cpu: mcpu,
features: target_features,
reloc_mode: RelocMode::PIC,
..host_target_machine
},
};
let task = CodeGenTask {
subst: Vec::default(),
@ -462,14 +391,14 @@ fn main() {
membuffer.lock().push(buffer);
})));
let threads = (0..threads)
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), size_t)))
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), SIZE_T)))
.collect();
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
// Link all modules together into `main`
let buffers = membuffers.lock();
let context = inkwell::context::Context::create();
let main = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
.unwrap();
@ -489,18 +418,25 @@ fn main() {
main.link_in_module(other).unwrap();
}
let irrt = load_irrt(&context);
if emit_llvm {
irrt.write_bitcode_to_path(Path::new("irrt.bc"));
}
main.link_in_module(irrt).unwrap();
// Private all functions except "run"
let mut function_iter = main.get_first_function();
while let Some(func) = function_iter {
if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != "run" {
func.set_linkage(Linkage::Private);
func.set_linkage(inkwell::module::Linkage::Private);
}
function_iter = func.get_next_function();
}
// Optimize `main`
let target_machine = llvm_options
.target
.create_target_machine(llvm_options.opt_level)
.expect("couldn't create target machine");
let pass_options = PassBuilderOptions::create();
pass_options.set_merge_functions(true);
let passes = format!("default<O{}>", opt_level as u32);
@ -509,7 +445,6 @@ fn main() {
panic!("Failed to run optimization for module `main`: {}", err.to_string());
}
// Write output
target_machine
.write_to_file(&main, FileType::Object, Path::new("module.o"))
.expect("couldn't write module to file");

View File

@ -21,6 +21,6 @@ build() {
}
package() {
mkdir -p $pkgdir/clang64/lib/python3.12/site-packages
cp ${srcdir}/nac3artiq.pyd $pkgdir/clang64/lib/python3.12/site-packages
mkdir -p $pkgdir/clang64/lib/python3.11/site-packages
cp ${srcdir}/nac3artiq.pyd $pkgdir/clang64/lib/python3.11/site-packages
}

View File

@ -21,10 +21,10 @@ let
text =
''
implementation=CPython
version=3.12
version=3.11
shared=true
abi3=false
lib_name=python3.12
lib_name=python3.11
lib_dir=${msys2-env}/clang64/lib
pointer_width=64
build_flags=WITH_THREAD

Some files were not shown because too many files have changed in this diff Show More