forked from M-Labs/nac3
Compare commits
13 Commits
a763ea3b61
...
194fbc51ab
Author | SHA1 | Date | |
---|---|---|---|
194fbc51ab | |||
70d033da61 | |||
600a5c8679 | |||
22c4d25802 | |||
308edb8237 | |||
9848795dcc | |||
58222feed4 | |||
518f21d174 | |||
e8e49684bf | |||
b2900b4883 | |||
c6dade1394 | |||
7e3fcc0845 | |||
d3b4c60d7f |
32
.clang-format
Normal file
32
.clang-format
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
BasedOnStyle: LLVM
|
||||||
|
|
||||||
|
Language: Cpp
|
||||||
|
Standard: Cpp11
|
||||||
|
|
||||||
|
AccessModifierOffset: -1
|
||||||
|
AlignEscapedNewlines: Left
|
||||||
|
AlwaysBreakAfterReturnType: None
|
||||||
|
AlwaysBreakTemplateDeclarations: Yes
|
||||||
|
AllowAllParametersOfDeclarationOnNextLine: false
|
||||||
|
AllowShortFunctionsOnASingleLine: Inline
|
||||||
|
BinPackParameters: false
|
||||||
|
BreakBeforeBinaryOperators: NonAssignment
|
||||||
|
BreakBeforeTernaryOperators: true
|
||||||
|
BreakConstructorInitializers: AfterColon
|
||||||
|
BreakInheritanceList: AfterColon
|
||||||
|
ColumnLimit: 120
|
||||||
|
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
||||||
|
ContinuationIndentWidth: 4
|
||||||
|
DerivePointerAlignment: false
|
||||||
|
IndentCaseLabels: true
|
||||||
|
IndentPPDirectives: None
|
||||||
|
IndentWidth: 4
|
||||||
|
MaxEmptyLinesToKeep: 1
|
||||||
|
PointerAlignment: Left
|
||||||
|
ReflowComments: true
|
||||||
|
SortIncludes: false
|
||||||
|
SortUsingDeclarations: true
|
||||||
|
SpaceAfterTemplateKeyword: false
|
||||||
|
SpacesBeforeTrailingComments: 2
|
||||||
|
TabWidth: 4
|
||||||
|
UseTab: Never
|
@ -180,7 +180,9 @@
|
|||||||
clippy
|
clippy
|
||||||
pre-commit
|
pre-commit
|
||||||
rustfmt
|
rustfmt
|
||||||
|
rust-analyzer
|
||||||
];
|
];
|
||||||
|
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||||
shellHook =
|
shellHook =
|
||||||
''
|
''
|
||||||
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
|
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
|
||||||
|
BIN
nac3artiq/demo/dataset_db.mdb
Normal file
BIN
nac3artiq/demo/dataset_db.mdb
Normal file
Binary file not shown.
BIN
nac3artiq/demo/dataset_db.mdb-lock
Normal file
BIN
nac3artiq/demo/dataset_db.mdb-lock
Normal file
Binary file not shown.
@ -1,26 +1,87 @@
|
|||||||
from min_artiq import *
|
from min_artiq import *
|
||||||
|
from numpy import int32
|
||||||
|
|
||||||
|
|
||||||
|
# @nac3
|
||||||
|
# class A:
|
||||||
|
# a: int32
|
||||||
|
# core: KernelInvariant[Core]
|
||||||
|
|
||||||
|
# def __init__(self, a: int32):
|
||||||
|
# self.core = Core()
|
||||||
|
# self.a = a
|
||||||
|
|
||||||
|
# @kernel
|
||||||
|
# def output_all_fields(self):
|
||||||
|
# #print(self.a)
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# @kernel
|
||||||
|
# def set_a(self, a: int32):
|
||||||
|
# self.a = a
|
||||||
|
|
||||||
|
# @nac3
|
||||||
|
# class B(A):
|
||||||
|
# b: int32
|
||||||
|
|
||||||
|
# def __init__(self, b: int32):
|
||||||
|
# # A.__init__(self, b + 1)
|
||||||
|
# self.core = Core()
|
||||||
|
# self.a = b
|
||||||
|
# self.b = b
|
||||||
|
# self.set_b(b)
|
||||||
|
|
||||||
|
# @kernel
|
||||||
|
# def output_parent_fields(self):
|
||||||
|
# # A.output_all_fields(self)
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# @kernel
|
||||||
|
# def output_all_fields(self):
|
||||||
|
# # A.output_all_fields(self)
|
||||||
|
# pass
|
||||||
|
# #print(self.b)
|
||||||
|
|
||||||
|
# @kernel
|
||||||
|
# def set_b(self, b: int32):
|
||||||
|
# self.b = b
|
||||||
|
|
||||||
@nac3
|
@nac3
|
||||||
class Demo:
|
class C:
|
||||||
|
c: Kernel[int32]
|
||||||
|
a: Kernel[int32]
|
||||||
|
b: Kernel[int32]
|
||||||
core: KernelInvariant[Core]
|
core: KernelInvariant[Core]
|
||||||
led0: KernelInvariant[TTLOut]
|
|
||||||
led1: KernelInvariant[TTLOut]
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, c: int32):
|
||||||
|
# B.__init__(self, c + 1)
|
||||||
self.core = Core()
|
self.core = Core()
|
||||||
self.led0 = TTLOut(self.core, 18)
|
self.a = c
|
||||||
self.led1 = TTLOut(self.core, 19)
|
self.b = c
|
||||||
|
self.c = c
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def output_parent_fields(self):
|
||||||
|
# B.output_all_fields(self)
|
||||||
|
pass
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def output_all_fields(self):
|
||||||
|
# B.output_all_fields(self)
|
||||||
|
#print(self.c)
|
||||||
|
pass
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def set_c(self, c: int32):
|
||||||
|
self.c = c
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run(self):
|
def run(self):
|
||||||
self.core.reset()
|
self.output_all_fields()
|
||||||
while True:
|
# self.set_a(1)
|
||||||
with parallel:
|
# self.set_b(2)
|
||||||
self.led0.pulse(100.*ms)
|
self.set_c(3)
|
||||||
self.led1.pulse(100.*ms)
|
self.output_all_fields()
|
||||||
self.core.delay(100.*ms)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
Demo().run()
|
C(10).run()
|
||||||
|
BIN
nac3artiq/demo/module.elf
Normal file
BIN
nac3artiq/demo/module.elf
Normal file
Binary file not shown.
@ -557,6 +557,10 @@ impl Nac3 {
|
|||||||
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
|
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
// Process IRRT
|
||||||
|
let context = inkwell::context::Context::create();
|
||||||
|
let irrt = load_irrt(&context, resolver.as_ref());
|
||||||
|
|
||||||
let fun_signature =
|
let fun_signature =
|
||||||
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
|
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
|
||||||
let mut store = ConcreteTypeStore::new();
|
let mut store = ConcreteTypeStore::new();
|
||||||
@ -727,7 +731,7 @@ impl Nac3 {
|
|||||||
membuffer.lock().push(buffer);
|
membuffer.lock().push(buffer);
|
||||||
});
|
});
|
||||||
|
|
||||||
let context = inkwell::context::Context::create();
|
// Link all modules into `main`.
|
||||||
let buffers = membuffers.lock();
|
let buffers = membuffers.lock();
|
||||||
let main = context
|
let main = context
|
||||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
||||||
@ -756,8 +760,7 @@ impl Nac3 {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
main.link_in_module(load_irrt(&context))
|
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||||
.map_err(|err| CompileError::new_err(err.to_string()))?;
|
|
||||||
|
|
||||||
let mut function_iter = main.get_first_function();
|
let mut function_iter = main.get_first_function();
|
||||||
while let Some(func) = function_iter {
|
while let Some(func) = function_iter {
|
||||||
|
@ -8,37 +8,50 @@ use std::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
const FILE: &str = "src/codegen/irrt/irrt.cpp";
|
let out_dir = env::var("OUT_DIR").unwrap();
|
||||||
|
let out_dir = Path::new(&out_dir);
|
||||||
|
let irrt_dir = Path::new("irrt");
|
||||||
|
|
||||||
|
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
||||||
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
||||||
*/
|
*/
|
||||||
let flags: &[&str] = &[
|
let mut flags: Vec<&str> = vec![
|
||||||
"--target=wasm32",
|
"--target=wasm32",
|
||||||
FILE,
|
|
||||||
"-x",
|
"-x",
|
||||||
"c++",
|
"c++",
|
||||||
|
"-std=c++20",
|
||||||
"-fno-discard-value-names",
|
"-fno-discard-value-names",
|
||||||
"-fno-exceptions",
|
"-fno-exceptions",
|
||||||
"-fno-rtti",
|
"-fno-rtti",
|
||||||
match env::var("PROFILE").as_deref() {
|
|
||||||
Ok("debug") => "-O0",
|
|
||||||
Ok("release") => "-O3",
|
|
||||||
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
|
||||||
},
|
|
||||||
"-emit-llvm",
|
"-emit-llvm",
|
||||||
"-S",
|
"-S",
|
||||||
"-Wall",
|
"-Wall",
|
||||||
"-Wextra",
|
"-Wextra",
|
||||||
"-o",
|
"-o",
|
||||||
"-",
|
"-",
|
||||||
|
"-I",
|
||||||
|
irrt_dir.to_str().unwrap(),
|
||||||
|
irrt_cpp_path.to_str().unwrap(),
|
||||||
];
|
];
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed={FILE}");
|
match env::var("PROFILE").as_deref() {
|
||||||
let out_dir = env::var("OUT_DIR").unwrap();
|
Ok("debug") => {
|
||||||
let out_path = Path::new(&out_dir);
|
flags.push("-O0");
|
||||||
|
flags.push("-DIRRT_DEBUG_ASSERT");
|
||||||
|
}
|
||||||
|
Ok("release") => {
|
||||||
|
flags.push("-O3");
|
||||||
|
}
|
||||||
|
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
|
||||||
|
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
|
||||||
|
|
||||||
|
// Compile IRRT and capture the LLVM IR output
|
||||||
let output = Command::new("clang-irrt")
|
let output = Command::new("clang-irrt")
|
||||||
.args(flags)
|
.args(flags)
|
||||||
.output()
|
.output()
|
||||||
@ -52,7 +65,17 @@ fn main() {
|
|||||||
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
|
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
|
||||||
let mut filtered_output = String::with_capacity(output.len());
|
let mut filtered_output = String::with_capacity(output.len());
|
||||||
|
|
||||||
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap();
|
// Filter out irrelevant IR
|
||||||
|
//
|
||||||
|
// Regex:
|
||||||
|
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
|
||||||
|
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
|
||||||
|
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
|
||||||
|
// - `(?m:^@.+?=.+$)` captures global constants
|
||||||
|
let regex_filter = Regex::new(
|
||||||
|
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
for f in regex_filter.captures_iter(&output) {
|
for f in regex_filter.captures_iter(&output) {
|
||||||
assert_eq!(f.len(), 1);
|
assert_eq!(f.len(), 1);
|
||||||
filtered_output.push_str(&f[0]);
|
filtered_output.push_str(&f[0]);
|
||||||
@ -63,18 +86,22 @@ fn main() {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
.replace_all(&filtered_output, "");
|
.replace_all(&filtered_output, "");
|
||||||
|
|
||||||
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT");
|
// For debugging
|
||||||
if env::var("DEBUG_DUMP_IRRT").is_ok() {
|
// Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
|
||||||
let mut file = File::create(out_path.join("irrt.ll")).unwrap();
|
const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
|
||||||
|
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
|
||||||
|
if env::var(DEBUG_DUMP_IRRT).is_ok() {
|
||||||
|
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
|
||||||
file.write_all(output.as_bytes()).unwrap();
|
file.write_all(output.as_bytes()).unwrap();
|
||||||
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap();
|
|
||||||
|
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
|
||||||
file.write_all(filtered_output.as_bytes()).unwrap();
|
file.write_all(filtered_output.as_bytes()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut llvm_as = Command::new("llvm-as-irrt")
|
let mut llvm_as = Command::new("llvm-as-irrt")
|
||||||
.stdin(Stdio::piped())
|
.stdin(Stdio::piped())
|
||||||
.arg("-o")
|
.arg("-o")
|
||||||
.arg(out_path.join("irrt.bc"))
|
.arg(out_dir.join("irrt.bc"))
|
||||||
.spawn()
|
.spawn()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
|
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
|
||||||
|
6
nac3core/irrt/irrt.cpp
Normal file
6
nac3core/irrt/irrt.cpp
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
#include <irrt/exception.hpp>
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
#include <irrt/list.hpp>
|
||||||
|
#include <irrt/math.hpp>
|
||||||
|
#include <irrt/ndarray.hpp>
|
||||||
|
#include <irrt/slice.hpp>
|
9
nac3core/irrt/irrt/cslice.hpp
Normal file
9
nac3core/irrt/irrt/cslice.hpp
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
struct CSlice {
|
||||||
|
uint8_t* base;
|
||||||
|
SizeT len;
|
||||||
|
};
|
25
nac3core/irrt/irrt/debug.hpp
Normal file
25
nac3core/irrt/irrt/debug.hpp
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
// Set in nac3core/build.rs
|
||||||
|
#ifdef IRRT_DEBUG_ASSERT
|
||||||
|
#define IRRT_DEBUG_ASSERT_BOOL true
|
||||||
|
#else
|
||||||
|
#define IRRT_DEBUG_ASSERT_BOOL false
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
|
||||||
|
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
|
||||||
|
|
||||||
|
#define debug_assert_eq(SizeT, lhs, rhs) \
|
||||||
|
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
|
||||||
|
if ((lhs) != (rhs)) { \
|
||||||
|
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define debug_assert(SizeT, expr) \
|
||||||
|
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
|
||||||
|
if (!(expr)) { \
|
||||||
|
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
|
||||||
|
} \
|
||||||
|
}
|
82
nac3core/irrt/irrt/exception.hpp
Normal file
82
nac3core/irrt/irrt/exception.hpp
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/cslice.hpp>
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief The int type of ARTIQ exception IDs.
|
||||||
|
*/
|
||||||
|
typedef int32_t ExceptionId;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Set of exceptions C++ IRRT can use.
|
||||||
|
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
|
||||||
|
*/
|
||||||
|
extern "C" {
|
||||||
|
ExceptionId EXN_INDEX_ERROR;
|
||||||
|
ExceptionId EXN_VALUE_ERROR;
|
||||||
|
ExceptionId EXN_ASSERTION_ERROR;
|
||||||
|
ExceptionId EXN_TYPE_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Extern function to `__nac3_raise`
|
||||||
|
*
|
||||||
|
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
|
||||||
|
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
|
||||||
|
*/
|
||||||
|
extern "C" void __nac3_raise(void* err);
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/**
|
||||||
|
* @brief NAC3's Exception struct
|
||||||
|
*/
|
||||||
|
template<typename SizeT>
|
||||||
|
struct Exception {
|
||||||
|
ExceptionId id;
|
||||||
|
CSlice<SizeT> filename;
|
||||||
|
int32_t line;
|
||||||
|
int32_t column;
|
||||||
|
CSlice<SizeT> function;
|
||||||
|
CSlice<SizeT> msg;
|
||||||
|
int64_t params[3];
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr int64_t NO_PARAM = 0;
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
void _raise_exception_helper(ExceptionId id,
|
||||||
|
const char* filename,
|
||||||
|
int32_t line,
|
||||||
|
const char* function,
|
||||||
|
const char* msg,
|
||||||
|
int64_t param0,
|
||||||
|
int64_t param1,
|
||||||
|
int64_t param2) {
|
||||||
|
Exception<SizeT> e = {
|
||||||
|
.id = id,
|
||||||
|
.filename = {.base = reinterpret_cast<const uint8_t*>(filename), .len = __builtin_strlen(filename)},
|
||||||
|
.line = line,
|
||||||
|
.column = 0,
|
||||||
|
.function = {.base = reinterpret_cast<const uint8_t*>(function), .len = __builtin_strlen(function)},
|
||||||
|
.msg = {.base = reinterpret_cast<const uint8_t*>(msg), .len = __builtin_strlen(msg)},
|
||||||
|
};
|
||||||
|
e.params[0] = param0;
|
||||||
|
e.params[1] = param1;
|
||||||
|
e.params[2] = param2;
|
||||||
|
__nac3_raise(reinterpret_cast<void*>(&e));
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Raise an exception with location details (location in the IRRT source files).
|
||||||
|
* @param SizeT The runtime `size_t` type.
|
||||||
|
* @param id The ID of the exception to raise.
|
||||||
|
* @param msg A global constant C-string of the error message.
|
||||||
|
*
|
||||||
|
* `param0` to `param2` are optional format arguments of `msg`. They should be set to
|
||||||
|
* `NO_PARAM` to indicate they are unused.
|
||||||
|
*/
|
||||||
|
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
|
||||||
|
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
|
||||||
|
} // namespace
|
13
nac3core/irrt/irrt/int_types.hpp
Normal file
13
nac3core/irrt/irrt/int_types.hpp
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
using int8_t = _BitInt(8);
|
||||||
|
using uint8_t = unsigned _BitInt(8);
|
||||||
|
using int32_t = _BitInt(32);
|
||||||
|
using uint32_t = unsigned _BitInt(32);
|
||||||
|
using int64_t = _BitInt(64);
|
||||||
|
using uint64_t = unsigned _BitInt(64);
|
||||||
|
|
||||||
|
// NDArray indices are always `uint32_t`.
|
||||||
|
using NDIndex = uint32_t;
|
||||||
|
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
||||||
|
using SliceIndex = int32_t;
|
75
nac3core/irrt/irrt/list.hpp
Normal file
75
nac3core/irrt/irrt/list.hpp
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
#include <irrt/math_util.hpp>
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
// Handle list assignment and dropping part of the list when
|
||||||
|
// both dest_step and src_step are +1.
|
||||||
|
// - All the index must *not* be out-of-bound or negative,
|
||||||
|
// - The end index is *inclusive*,
|
||||||
|
// - The length of src and dest slice size should already
|
||||||
|
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
|
||||||
|
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||||
|
SliceIndex dest_end,
|
||||||
|
SliceIndex dest_step,
|
||||||
|
uint8_t* dest_arr,
|
||||||
|
SliceIndex dest_arr_len,
|
||||||
|
SliceIndex src_start,
|
||||||
|
SliceIndex src_end,
|
||||||
|
SliceIndex src_step,
|
||||||
|
uint8_t* src_arr,
|
||||||
|
SliceIndex src_arr_len,
|
||||||
|
const SliceIndex size) {
|
||||||
|
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
||||||
|
if (dest_arr_len == 0)
|
||||||
|
return dest_arr_len;
|
||||||
|
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
|
||||||
|
if (src_step == dest_step && dest_step == 1) {
|
||||||
|
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
||||||
|
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
||||||
|
if (src_len > 0) {
|
||||||
|
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
|
||||||
|
}
|
||||||
|
if (dest_len > 0) {
|
||||||
|
/* dropping */
|
||||||
|
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
|
||||||
|
(dest_arr_len - dest_end - 1) * size);
|
||||||
|
}
|
||||||
|
/* shrink size */
|
||||||
|
return dest_arr_len - (dest_len - src_len);
|
||||||
|
}
|
||||||
|
/* if two range overlaps, need alloca */
|
||||||
|
uint8_t need_alloca = (dest_arr == src_arr)
|
||||||
|
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|
||||||
|
|| max(src_start, src_end) < min(dest_start, dest_end));
|
||||||
|
if (need_alloca) {
|
||||||
|
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
|
||||||
|
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
||||||
|
src_arr = tmp;
|
||||||
|
}
|
||||||
|
SliceIndex src_ind = src_start;
|
||||||
|
SliceIndex dest_ind = dest_start;
|
||||||
|
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
|
||||||
|
/* for constant optimization */
|
||||||
|
if (size == 1) {
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
||||||
|
} else if (size == 4) {
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
||||||
|
} else if (size == 8) {
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
||||||
|
} else {
|
||||||
|
/* memcpy for var size, cannot overlap after previous alloca */
|
||||||
|
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/* only dest_step == 1 can we shrink the dest list. */
|
||||||
|
/* size should be ensured prior to calling this function */
|
||||||
|
if (dest_step == 1 && dest_end >= dest_start) {
|
||||||
|
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
|
||||||
|
(dest_arr_len - dest_end - 1) * size);
|
||||||
|
return dest_arr_len - (dest_end - dest_ind) - 1;
|
||||||
|
}
|
||||||
|
return dest_arr_len;
|
||||||
|
}
|
||||||
|
} // extern "C"
|
93
nac3core/irrt/irrt/math.hpp
Normal file
93
nac3core/irrt/irrt/math.hpp
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||||
|
// need to make sure `exp >= 0` before calling this function
|
||||||
|
template<typename T>
|
||||||
|
T __nac3_int_exp_impl(T base, T exp) {
|
||||||
|
T res = 1;
|
||||||
|
/* repeated squaring method */
|
||||||
|
do {
|
||||||
|
if (exp & 1) {
|
||||||
|
res *= base; /* for n odd */
|
||||||
|
}
|
||||||
|
exp >>= 1;
|
||||||
|
base *= base;
|
||||||
|
} while (exp);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
#define DEF_nac3_int_exp_(T) \
|
||||||
|
T __nac3_int_exp_##T(T base, T exp) { \
|
||||||
|
return __nac3_int_exp_impl(base, exp); \
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
// Putting semicolons here to make clang-format not reformat this into
|
||||||
|
// a stair shape.
|
||||||
|
DEF_nac3_int_exp_(int32_t);
|
||||||
|
DEF_nac3_int_exp_(int64_t);
|
||||||
|
DEF_nac3_int_exp_(uint32_t);
|
||||||
|
DEF_nac3_int_exp_(uint64_t);
|
||||||
|
|
||||||
|
int32_t __nac3_isinf(double x) {
|
||||||
|
return __builtin_isinf(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t __nac3_isnan(double x) {
|
||||||
|
return __builtin_isnan(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
double tgamma(double arg);
|
||||||
|
|
||||||
|
double __nac3_gamma(double z) {
|
||||||
|
// Handling for denormals
|
||||||
|
// | x | Python gamma(x) | C tgamma(x) |
|
||||||
|
// --- | ----------------- | --------------- | ----------- |
|
||||||
|
// (1) | nan | nan | nan |
|
||||||
|
// (2) | -inf | -inf | inf |
|
||||||
|
// (3) | inf | inf | inf |
|
||||||
|
// (4) | 0.0 | inf | inf |
|
||||||
|
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
||||||
|
|
||||||
|
// (1)-(3)
|
||||||
|
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
||||||
|
return z;
|
||||||
|
}
|
||||||
|
|
||||||
|
double v = tgamma(z);
|
||||||
|
|
||||||
|
// (4)-(5)
|
||||||
|
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
||||||
|
}
|
||||||
|
|
||||||
|
double lgamma(double arg);
|
||||||
|
|
||||||
|
double __nac3_gammaln(double x) {
|
||||||
|
// libm's handling of value overflows differs from scipy:
|
||||||
|
// - scipy: gammaln(-inf) -> -inf
|
||||||
|
// - libm : lgamma(-inf) -> inf
|
||||||
|
|
||||||
|
if (__builtin_isinf(x)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
return lgamma(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
double j0(double x);
|
||||||
|
|
||||||
|
double __nac3_j0(double x) {
|
||||||
|
// libm's handling of value overflows differs from scipy:
|
||||||
|
// - scipy: j0(inf) -> nan
|
||||||
|
// - libm : j0(inf) -> 0.0
|
||||||
|
|
||||||
|
if (__builtin_isinf(x)) {
|
||||||
|
return __builtin_nan("");
|
||||||
|
}
|
||||||
|
|
||||||
|
return j0(x);
|
||||||
|
}
|
||||||
|
}
|
13
nac3core/irrt/irrt/math_util.hpp
Normal file
13
nac3core/irrt/irrt/math_util.hpp
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template<typename T>
|
||||||
|
const T& max(const T& a, const T& b) {
|
||||||
|
return a > b ? a : b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
const T& min(const T& a, const T& b) {
|
||||||
|
return a > b ? b : a;
|
||||||
|
}
|
||||||
|
} // namespace
|
144
nac3core/irrt/irrt/ndarray.hpp
Normal file
144
nac3core/irrt/irrt/ndarray.hpp
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||||
|
__builtin_assume(end_idx <= list_len);
|
||||||
|
|
||||||
|
SizeT num_elems = 1;
|
||||||
|
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
||||||
|
SizeT val = list_data[i];
|
||||||
|
__builtin_assume(val > 0);
|
||||||
|
num_elems *= val;
|
||||||
|
}
|
||||||
|
return num_elems;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
|
||||||
|
SizeT stride = 1;
|
||||||
|
for (SizeT dim = 0; dim < num_dims; dim++) {
|
||||||
|
SizeT i = num_dims - dim - 1;
|
||||||
|
__builtin_assume(dims[i] > 0);
|
||||||
|
idxs[i] = (index / stride) % dims[i];
|
||||||
|
stride *= dims[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
|
||||||
|
SizeT idx = 0;
|
||||||
|
SizeT stride = 1;
|
||||||
|
for (SizeT i = 0; i < num_dims; ++i) {
|
||||||
|
SizeT ri = num_dims - i - 1;
|
||||||
|
if (ri < num_indices) {
|
||||||
|
idx += stride * indices[ri];
|
||||||
|
}
|
||||||
|
|
||||||
|
__builtin_assume(dims[i] > 0);
|
||||||
|
stride *= dims[ri];
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
|
||||||
|
SizeT lhs_ndims,
|
||||||
|
const SizeT* rhs_dims,
|
||||||
|
SizeT rhs_ndims,
|
||||||
|
SizeT* out_dims) {
|
||||||
|
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < max_ndims; ++i) {
|
||||||
|
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
||||||
|
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
||||||
|
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
||||||
|
|
||||||
|
if (lhs_dim_sz == nullptr) {
|
||||||
|
*out_dim = *rhs_dim_sz;
|
||||||
|
} else if (rhs_dim_sz == nullptr) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else if (*lhs_dim_sz == 1) {
|
||||||
|
*out_dim = *rhs_dim_sz;
|
||||||
|
} else if (*rhs_dim_sz == 1) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||||
|
*out_dim = *lhs_dim_sz;
|
||||||
|
} else {
|
||||||
|
__builtin_unreachable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename SizeT>
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
|
||||||
|
SizeT src_ndims,
|
||||||
|
const NDIndex* in_idx,
|
||||||
|
NDIndex* out_idx) {
|
||||||
|
for (SizeT i = 0; i < src_ndims; ++i) {
|
||||||
|
SizeT src_i = src_ndims - i - 1;
|
||||||
|
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
|
||||||
|
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t
|
||||||
|
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
|
||||||
|
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
|
||||||
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
|
||||||
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t
|
||||||
|
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
|
||||||
|
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t
|
||||||
|
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
|
||||||
|
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
|
||||||
|
uint32_t lhs_ndims,
|
||||||
|
const uint32_t* rhs_dims,
|
||||||
|
uint32_t rhs_ndims,
|
||||||
|
uint32_t* out_dims) {
|
||||||
|
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
|
||||||
|
uint64_t lhs_ndims,
|
||||||
|
const uint64_t* rhs_dims,
|
||||||
|
uint64_t rhs_ndims,
|
||||||
|
uint64_t* out_dims) {
|
||||||
|
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
|
||||||
|
uint32_t src_ndims,
|
||||||
|
const NDIndex* in_idx,
|
||||||
|
NDIndex* out_idx) {
|
||||||
|
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
|
||||||
|
uint64_t src_ndims,
|
||||||
|
const NDIndex* in_idx,
|
||||||
|
NDIndex* out_idx) {
|
||||||
|
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||||
|
}
|
||||||
|
}
|
28
nac3core/irrt/irrt/slice.hpp
Normal file
28
nac3core/irrt/irrt/slice.hpp
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <irrt/int_types.hpp>
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
||||||
|
if (i < 0) {
|
||||||
|
i = len + i;
|
||||||
|
}
|
||||||
|
if (i < 0) {
|
||||||
|
return 0;
|
||||||
|
} else if (i > len) {
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
|
||||||
|
SliceIndex diff = end - start;
|
||||||
|
if (diff > 0 && step > 0) {
|
||||||
|
return ((diff - 1) / step) + 1;
|
||||||
|
} else if (diff < 0 && step < 0) {
|
||||||
|
return ((diff + 1) / step) + 1;
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,414 +0,0 @@
|
|||||||
using int8_t = _BitInt(8);
|
|
||||||
using uint8_t = unsigned _BitInt(8);
|
|
||||||
using int32_t = _BitInt(32);
|
|
||||||
using uint32_t = unsigned _BitInt(32);
|
|
||||||
using int64_t = _BitInt(64);
|
|
||||||
using uint64_t = unsigned _BitInt(64);
|
|
||||||
|
|
||||||
// NDArray indices are always `uint32_t`.
|
|
||||||
using NDIndex = uint32_t;
|
|
||||||
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
|
||||||
using SliceIndex = int32_t;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
template <typename T>
|
|
||||||
const T& max(const T& a, const T& b) {
|
|
||||||
return a > b ? a : b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
const T& min(const T& a, const T& b) {
|
|
||||||
return a > b ? b : a;
|
|
||||||
}
|
|
||||||
|
|
||||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
|
||||||
// need to make sure `exp >= 0` before calling this function
|
|
||||||
template <typename T>
|
|
||||||
T __nac3_int_exp_impl(T base, T exp) {
|
|
||||||
T res = 1;
|
|
||||||
/* repeated squaring method */
|
|
||||||
do {
|
|
||||||
if (exp & 1) {
|
|
||||||
res *= base; /* for n odd */
|
|
||||||
}
|
|
||||||
exp >>= 1;
|
|
||||||
base *= base;
|
|
||||||
} while (exp);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
SizeT __nac3_ndarray_calc_size_impl(
|
|
||||||
const SizeT* list_data,
|
|
||||||
SizeT list_len,
|
|
||||||
SizeT begin_idx,
|
|
||||||
SizeT end_idx
|
|
||||||
) {
|
|
||||||
__builtin_assume(end_idx <= list_len);
|
|
||||||
|
|
||||||
SizeT num_elems = 1;
|
|
||||||
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
|
||||||
SizeT val = list_data[i];
|
|
||||||
__builtin_assume(val > 0);
|
|
||||||
num_elems *= val;
|
|
||||||
}
|
|
||||||
return num_elems;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
void __nac3_ndarray_calc_nd_indices_impl(
|
|
||||||
SizeT index,
|
|
||||||
const SizeT* dims,
|
|
||||||
SizeT num_dims,
|
|
||||||
NDIndex* idxs
|
|
||||||
) {
|
|
||||||
SizeT stride = 1;
|
|
||||||
for (SizeT dim = 0; dim < num_dims; dim++) {
|
|
||||||
SizeT i = num_dims - dim - 1;
|
|
||||||
__builtin_assume(dims[i] > 0);
|
|
||||||
idxs[i] = (index / stride) % dims[i];
|
|
||||||
stride *= dims[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
SizeT __nac3_ndarray_flatten_index_impl(
|
|
||||||
const SizeT* dims,
|
|
||||||
SizeT num_dims,
|
|
||||||
const NDIndex* indices,
|
|
||||||
SizeT num_indices
|
|
||||||
) {
|
|
||||||
SizeT idx = 0;
|
|
||||||
SizeT stride = 1;
|
|
||||||
for (SizeT i = 0; i < num_dims; ++i) {
|
|
||||||
SizeT ri = num_dims - i - 1;
|
|
||||||
if (ri < num_indices) {
|
|
||||||
idx += stride * indices[ri];
|
|
||||||
}
|
|
||||||
|
|
||||||
__builtin_assume(dims[i] > 0);
|
|
||||||
stride *= dims[ri];
|
|
||||||
}
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
void __nac3_ndarray_calc_broadcast_impl(
|
|
||||||
const SizeT* lhs_dims,
|
|
||||||
SizeT lhs_ndims,
|
|
||||||
const SizeT* rhs_dims,
|
|
||||||
SizeT rhs_ndims,
|
|
||||||
SizeT* out_dims
|
|
||||||
) {
|
|
||||||
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
|
||||||
|
|
||||||
for (SizeT i = 0; i < max_ndims; ++i) {
|
|
||||||
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
|
||||||
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
|
||||||
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
|
||||||
|
|
||||||
if (lhs_dim_sz == nullptr) {
|
|
||||||
*out_dim = *rhs_dim_sz;
|
|
||||||
} else if (rhs_dim_sz == nullptr) {
|
|
||||||
*out_dim = *lhs_dim_sz;
|
|
||||||
} else if (*lhs_dim_sz == 1) {
|
|
||||||
*out_dim = *rhs_dim_sz;
|
|
||||||
} else if (*rhs_dim_sz == 1) {
|
|
||||||
*out_dim = *lhs_dim_sz;
|
|
||||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
|
||||||
*out_dim = *lhs_dim_sz;
|
|
||||||
} else {
|
|
||||||
__builtin_unreachable();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename SizeT>
|
|
||||||
void __nac3_ndarray_calc_broadcast_idx_impl(
|
|
||||||
const SizeT* src_dims,
|
|
||||||
SizeT src_ndims,
|
|
||||||
const NDIndex* in_idx,
|
|
||||||
NDIndex* out_idx
|
|
||||||
) {
|
|
||||||
for (SizeT i = 0; i < src_ndims; ++i) {
|
|
||||||
SizeT src_i = src_ndims - i - 1;
|
|
||||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
#define DEF_nac3_int_exp_(T) \
|
|
||||||
T __nac3_int_exp_##T(T base, T exp) {\
|
|
||||||
return __nac3_int_exp_impl(base, exp);\
|
|
||||||
}
|
|
||||||
|
|
||||||
DEF_nac3_int_exp_(int32_t)
|
|
||||||
DEF_nac3_int_exp_(int64_t)
|
|
||||||
DEF_nac3_int_exp_(uint32_t)
|
|
||||||
DEF_nac3_int_exp_(uint64_t)
|
|
||||||
|
|
||||||
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
|
||||||
if (i < 0) {
|
|
||||||
i = len + i;
|
|
||||||
}
|
|
||||||
if (i < 0) {
|
|
||||||
return 0;
|
|
||||||
} else if (i > len) {
|
|
||||||
return len;
|
|
||||||
}
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
|
|
||||||
SliceIndex __nac3_range_slice_len(
|
|
||||||
const SliceIndex start,
|
|
||||||
const SliceIndex end,
|
|
||||||
const SliceIndex step
|
|
||||||
) {
|
|
||||||
SliceIndex diff = end - start;
|
|
||||||
if (diff > 0 && step > 0) {
|
|
||||||
return ((diff - 1) / step) + 1;
|
|
||||||
} else if (diff < 0 && step < 0) {
|
|
||||||
return ((diff + 1) / step) + 1;
|
|
||||||
} else {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle list assignment and dropping part of the list when
|
|
||||||
// both dest_step and src_step are +1.
|
|
||||||
// - All the index must *not* be out-of-bound or negative,
|
|
||||||
// - The end index is *inclusive*,
|
|
||||||
// - The length of src and dest slice size should already
|
|
||||||
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
|
|
||||||
SliceIndex __nac3_list_slice_assign_var_size(
|
|
||||||
SliceIndex dest_start,
|
|
||||||
SliceIndex dest_end,
|
|
||||||
SliceIndex dest_step,
|
|
||||||
uint8_t* dest_arr,
|
|
||||||
SliceIndex dest_arr_len,
|
|
||||||
SliceIndex src_start,
|
|
||||||
SliceIndex src_end,
|
|
||||||
SliceIndex src_step,
|
|
||||||
uint8_t* src_arr,
|
|
||||||
SliceIndex src_arr_len,
|
|
||||||
const SliceIndex size
|
|
||||||
) {
|
|
||||||
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
|
||||||
if (dest_arr_len == 0) return dest_arr_len;
|
|
||||||
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
|
|
||||||
if (src_step == dest_step && dest_step == 1) {
|
|
||||||
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
|
||||||
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
|
||||||
if (src_len > 0) {
|
|
||||||
__builtin_memmove(
|
|
||||||
dest_arr + dest_start * size,
|
|
||||||
src_arr + src_start * size,
|
|
||||||
src_len * size
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if (dest_len > 0) {
|
|
||||||
/* dropping */
|
|
||||||
__builtin_memmove(
|
|
||||||
dest_arr + (dest_start + src_len) * size,
|
|
||||||
dest_arr + (dest_end + 1) * size,
|
|
||||||
(dest_arr_len - dest_end - 1) * size
|
|
||||||
);
|
|
||||||
}
|
|
||||||
/* shrink size */
|
|
||||||
return dest_arr_len - (dest_len - src_len);
|
|
||||||
}
|
|
||||||
/* if two range overlaps, need alloca */
|
|
||||||
uint8_t need_alloca =
|
|
||||||
(dest_arr == src_arr)
|
|
||||||
&& !(
|
|
||||||
max(dest_start, dest_end) < min(src_start, src_end)
|
|
||||||
|| max(src_start, src_end) < min(dest_start, dest_end)
|
|
||||||
);
|
|
||||||
if (need_alloca) {
|
|
||||||
uint8_t* tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
|
|
||||||
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
|
||||||
src_arr = tmp;
|
|
||||||
}
|
|
||||||
SliceIndex src_ind = src_start;
|
|
||||||
SliceIndex dest_ind = dest_start;
|
|
||||||
for (;
|
|
||||||
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
|
|
||||||
src_ind += src_step, dest_ind += dest_step
|
|
||||||
) {
|
|
||||||
/* for constant optimization */
|
|
||||||
if (size == 1) {
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
|
||||||
} else if (size == 4) {
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
|
||||||
} else if (size == 8) {
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
|
||||||
} else {
|
|
||||||
/* memcpy for var size, cannot overlap after previous alloca */
|
|
||||||
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/* only dest_step == 1 can we shrink the dest list. */
|
|
||||||
/* size should be ensured prior to calling this function */
|
|
||||||
if (dest_step == 1 && dest_end >= dest_start) {
|
|
||||||
__builtin_memmove(
|
|
||||||
dest_arr + dest_ind * size,
|
|
||||||
dest_arr + (dest_end + 1) * size,
|
|
||||||
(dest_arr_len - dest_end - 1) * size
|
|
||||||
);
|
|
||||||
return dest_arr_len - (dest_end - dest_ind) - 1;
|
|
||||||
}
|
|
||||||
return dest_arr_len;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_isinf(double x) {
|
|
||||||
return __builtin_isinf(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t __nac3_isnan(double x) {
|
|
||||||
return __builtin_isnan(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
double tgamma(double arg);
|
|
||||||
|
|
||||||
double __nac3_gamma(double z) {
|
|
||||||
// Handling for denormals
|
|
||||||
// | x | Python gamma(x) | C tgamma(x) |
|
|
||||||
// --- | ----------------- | --------------- | ----------- |
|
|
||||||
// (1) | nan | nan | nan |
|
|
||||||
// (2) | -inf | -inf | inf |
|
|
||||||
// (3) | inf | inf | inf |
|
|
||||||
// (4) | 0.0 | inf | inf |
|
|
||||||
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
|
||||||
|
|
||||||
// (1)-(3)
|
|
||||||
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
|
||||||
return z;
|
|
||||||
}
|
|
||||||
|
|
||||||
double v = tgamma(z);
|
|
||||||
|
|
||||||
// (4)-(5)
|
|
||||||
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
|
||||||
}
|
|
||||||
|
|
||||||
double lgamma(double arg);
|
|
||||||
|
|
||||||
double __nac3_gammaln(double x) {
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: gammaln(-inf) -> -inf
|
|
||||||
// - libm : lgamma(-inf) -> inf
|
|
||||||
|
|
||||||
if (__builtin_isinf(x)) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
return lgamma(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
double j0(double x);
|
|
||||||
|
|
||||||
double __nac3_j0(double x) {
|
|
||||||
// libm's handling of value overflows differs from scipy:
|
|
||||||
// - scipy: j0(inf) -> nan
|
|
||||||
// - libm : j0(inf) -> 0.0
|
|
||||||
|
|
||||||
if (__builtin_isinf(x)) {
|
|
||||||
return __builtin_nan("");
|
|
||||||
}
|
|
||||||
|
|
||||||
return j0(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_calc_size(
|
|
||||||
const uint32_t* list_data,
|
|
||||||
uint32_t list_len,
|
|
||||||
uint32_t begin_idx,
|
|
||||||
uint32_t end_idx
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_calc_size64(
|
|
||||||
const uint64_t* list_data,
|
|
||||||
uint64_t list_len,
|
|
||||||
uint64_t begin_idx,
|
|
||||||
uint64_t end_idx
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_nd_indices(
|
|
||||||
uint32_t index,
|
|
||||||
const uint32_t* dims,
|
|
||||||
uint32_t num_dims,
|
|
||||||
NDIndex* idxs
|
|
||||||
) {
|
|
||||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_nd_indices64(
|
|
||||||
uint64_t index,
|
|
||||||
const uint64_t* dims,
|
|
||||||
uint64_t num_dims,
|
|
||||||
NDIndex* idxs
|
|
||||||
) {
|
|
||||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t __nac3_ndarray_flatten_index(
|
|
||||||
const uint32_t* dims,
|
|
||||||
uint32_t num_dims,
|
|
||||||
const NDIndex* indices,
|
|
||||||
uint32_t num_indices
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_flatten_index64(
|
|
||||||
const uint64_t* dims,
|
|
||||||
uint64_t num_dims,
|
|
||||||
const NDIndex* indices,
|
|
||||||
uint64_t num_indices
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast(
|
|
||||||
const uint32_t* lhs_dims,
|
|
||||||
uint32_t lhs_ndims,
|
|
||||||
const uint32_t* rhs_dims,
|
|
||||||
uint32_t rhs_ndims,
|
|
||||||
uint32_t* out_dims
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast64(
|
|
||||||
const uint64_t* lhs_dims,
|
|
||||||
uint64_t lhs_ndims,
|
|
||||||
const uint64_t* rhs_dims,
|
|
||||||
uint64_t rhs_ndims,
|
|
||||||
uint64_t* out_dims
|
|
||||||
) {
|
|
||||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast_idx(
|
|
||||||
const uint32_t* src_dims,
|
|
||||||
uint32_t src_ndims,
|
|
||||||
const NDIndex* in_idx,
|
|
||||||
NDIndex* out_idx
|
|
||||||
) {
|
|
||||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast_idx64(
|
|
||||||
const uint64_t* src_dims,
|
|
||||||
uint64_t src_ndims,
|
|
||||||
const NDIndex* in_idx,
|
|
||||||
NDIndex* out_idx
|
|
||||||
) {
|
|
||||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
|
||||||
}
|
|
||||||
} // extern "C"
|
|
@ -1,4 +1,4 @@
|
|||||||
use crate::typecheck::typedef::Type;
|
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
classes::{
|
classes::{
|
||||||
@ -16,14 +16,14 @@ use inkwell::{
|
|||||||
memory_buffer::MemoryBuffer,
|
memory_buffer::MemoryBuffer,
|
||||||
module::Module,
|
module::Module,
|
||||||
types::{BasicTypeEnum, IntType},
|
types::{BasicTypeEnum, IntType},
|
||||||
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
use itertools::Either;
|
use itertools::Either;
|
||||||
use nac3parser::ast::Expr;
|
use nac3parser::ast::Expr;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt(ctx: &Context) -> Module {
|
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
||||||
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
||||||
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
|
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
|
||||||
"irrt_bitcode_buffer",
|
"irrt_bitcode_buffer",
|
||||||
@ -39,6 +39,25 @@ pub fn load_irrt(ctx: &Context) -> Module {
|
|||||||
let function = irrt_mod.get_function(symbol).unwrap();
|
let function = irrt_mod.get_function(symbol).unwrap();
|
||||||
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
|
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
|
||||||
|
let exn_id_type = ctx.i32_type();
|
||||||
|
let errors = &[
|
||||||
|
("EXN_INDEX_ERROR", "0:IndexError"),
|
||||||
|
("EXN_VALUE_ERROR", "0:ValueError"),
|
||||||
|
("EXN_ASSERTION_ERROR", "0:AssertionError"),
|
||||||
|
("EXN_TYPE_ERROR", "0:TypeError"),
|
||||||
|
];
|
||||||
|
for (irrt_name, symbol_name) in errors {
|
||||||
|
let exn_id = symbol_resolver.get_string_id(symbol_name);
|
||||||
|
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
|
||||||
|
|
||||||
|
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
|
||||||
|
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
|
||||||
|
});
|
||||||
|
global.set_initializer(&exn_id);
|
||||||
|
}
|
||||||
|
|
||||||
irrt_mod
|
irrt_mod
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
use indexmap::IndexMap;
|
||||||
use nac3parser::ast::fold::Fold;
|
use nac3parser::ast::fold::Fold;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
@ -5,7 +6,7 @@ use crate::{
|
|||||||
codegen::{expr::get_subst_key, stmt::exn_constructor},
|
codegen::{expr::get_subst_key, stmt::exn_constructor},
|
||||||
symbol_resolver::SymbolValue,
|
symbol_resolver::SymbolValue,
|
||||||
typecheck::{
|
typecheck::{
|
||||||
type_inferencer::{FunctionData, Inferencer},
|
type_inferencer::{report_error, FunctionData, Inferencer},
|
||||||
typedef::{TypeVar, VarMap},
|
typedef::{TypeVar, VarMap},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -389,8 +390,15 @@ impl TopLevelComposer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn start_analysis(&mut self, inference: bool) -> Result<(), HashSet<String>> {
|
pub fn start_analysis(&mut self, inference: bool) -> Result<(), HashSet<String>> {
|
||||||
self.analyze_top_level_class_type_var()?;
|
let unifier = self.unifier.borrow_mut();
|
||||||
self.analyze_top_level_class_bases()?;
|
let primitives_store = &self.primitives_ty;
|
||||||
|
let def_list = &self.definition_ast_list;
|
||||||
|
|
||||||
|
// Step 1. Analyze type variables within class definitions
|
||||||
|
Self::analyze_top_level_class_type_var2(def_list, unifier, primitives_store, (&self.keyword_list, &self.core_config))?;
|
||||||
|
|
||||||
|
// self.analyze_top_level_class_type_var()?;
|
||||||
|
// self.analyze_top_level_class_bases()?;
|
||||||
self.analyze_top_level_class_fields_methods()?;
|
self.analyze_top_level_class_fields_methods()?;
|
||||||
self.analyze_top_level_function()?;
|
self.analyze_top_level_function()?;
|
||||||
if inference {
|
if inference {
|
||||||
@ -399,178 +407,70 @@ impl TopLevelComposer {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// step 1, analyze the type vars associated with top level class
|
fn analyze_bases(
|
||||||
fn analyze_top_level_class_type_var(&mut self) -> Result<(), HashSet<String>> {
|
class_def: &Arc<RwLock<TopLevelDef>>,
|
||||||
let def_list = &self.definition_ast_list;
|
class_ast: &Option<Stmt>,
|
||||||
let temp_def_list = self.extract_def_list();
|
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
|
||||||
let unifier = self.unifier.borrow_mut();
|
unifier: &mut Unifier,
|
||||||
let primitives_store = &self.primitives_ty;
|
primitives_store: &PrimitiveStore,
|
||||||
|
) -> Result<(), HashSet<String>> {
|
||||||
let mut analyze = |class_def: &Arc<RwLock<TopLevelDef>>, class_ast: &Option<Stmt>| {
|
let mut class_def = class_def.write();
|
||||||
// only deal with class def here
|
let (class_def_id, class_ancestors, class_bases_ast, class_type_vars, class_resolver) = {
|
||||||
let mut class_def = class_def.write();
|
let TopLevelDef::Class { object_id, ancestors, type_vars, resolver, .. } =
|
||||||
let (class_bases_ast, class_def_type_vars, class_resolver) = {
|
&mut *class_def
|
||||||
if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def {
|
else {
|
||||||
let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) =
|
unreachable!()
|
||||||
class_ast
|
|
||||||
else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
|
|
||||||
(bases, type_vars, resolver)
|
|
||||||
} else {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
let class_resolver = class_resolver.as_ref().unwrap();
|
let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = class_ast
|
||||||
let class_resolver = &**class_resolver;
|
else {
|
||||||
|
unreachable!()
|
||||||
let mut is_generic = false;
|
};
|
||||||
for b in class_bases_ast {
|
(object_id, ancestors, bases, type_vars, resolver.as_ref().unwrap().as_ref())
|
||||||
match &b.node {
|
|
||||||
// analyze typevars bounded to the class,
|
|
||||||
// only support things like `class A(Generic[T, V])`,
|
|
||||||
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
|
|
||||||
// i.e. only simple names are allowed in the subscript
|
|
||||||
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
|
|
||||||
ast::ExprKind::Subscript { value, slice, .. }
|
|
||||||
if {
|
|
||||||
matches!(
|
|
||||||
&value.node,
|
|
||||||
ast::ExprKind::Name { id, .. } if id == &"Generic".into()
|
|
||||||
)
|
|
||||||
} =>
|
|
||||||
{
|
|
||||||
if is_generic {
|
|
||||||
return Err(HashSet::from([format!(
|
|
||||||
"only single Generic[...] is allowed (at {})",
|
|
||||||
b.location
|
|
||||||
)]));
|
|
||||||
}
|
|
||||||
is_generic = true;
|
|
||||||
|
|
||||||
let type_var_list: Vec<&ast::Expr<()>>;
|
|
||||||
// if `class A(Generic[T, V, G])`
|
|
||||||
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
|
|
||||||
type_var_list = elts.iter().collect_vec();
|
|
||||||
// `class A(Generic[T])`
|
|
||||||
} else {
|
|
||||||
type_var_list = vec![&**slice];
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse the type vars
|
|
||||||
let type_vars = type_var_list
|
|
||||||
.into_iter()
|
|
||||||
.map(|e| {
|
|
||||||
class_resolver.parse_type_annotation(
|
|
||||||
&temp_def_list,
|
|
||||||
unifier,
|
|
||||||
primitives_store,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
|
||||||
|
|
||||||
// check if all are unique type vars
|
|
||||||
let all_unique_type_var = {
|
|
||||||
let mut occurred_type_var_id: HashSet<TypeVarId> = HashSet::new();
|
|
||||||
type_vars.iter().all(|x| {
|
|
||||||
let ty = unifier.get_ty(*x);
|
|
||||||
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
|
|
||||||
occurred_type_var_id.insert(*id)
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
};
|
|
||||||
if !all_unique_type_var {
|
|
||||||
return Err(HashSet::from([format!(
|
|
||||||
"duplicate type variable occurs (at {})",
|
|
||||||
slice.location
|
|
||||||
)]));
|
|
||||||
}
|
|
||||||
|
|
||||||
// add to TopLevelDef
|
|
||||||
class_def_type_vars.extend(type_vars);
|
|
||||||
}
|
|
||||||
|
|
||||||
// if others, do nothing in this function
|
|
||||||
_ => continue,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
};
|
};
|
||||||
let mut errors = HashSet::new();
|
|
||||||
for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) {
|
|
||||||
if class_ast.is_none() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if let Err(e) = analyze(class_def, class_ast) {
|
|
||||||
errors.extend(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !errors.is_empty() {
|
|
||||||
return Err(errors);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// step 2, base classes.
|
let mut is_generic = false;
|
||||||
/// now that the type vars of all classes are done, handle base classes and
|
let mut has_base = false;
|
||||||
/// put Self class into the ancestors list. We only allow single inheritance
|
// Check class bases for typevars
|
||||||
fn analyze_top_level_class_bases(&mut self) -> Result<(), HashSet<String>> {
|
for b in class_bases_ast {
|
||||||
if self.unifier.top_level.is_none() {
|
match &b.node {
|
||||||
let ctx = Arc::new(self.make_top_level_context());
|
// analyze typevars bounded to the class,
|
||||||
self.unifier.top_level = Some(ctx);
|
// only support things like `class A(Generic[T, V])`,
|
||||||
}
|
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
|
||||||
|
// i.e. only simple names are allowed in the subscript
|
||||||
|
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
|
||||||
|
ast::ExprKind::Subscript { value, slice, .. } if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Generic".into()) =>
|
||||||
|
{
|
||||||
|
if is_generic {
|
||||||
|
return report_error("only single Generic[...] is allowed", b.location);
|
||||||
|
}
|
||||||
|
is_generic = true;
|
||||||
|
|
||||||
let temp_def_list = self.extract_def_list();
|
let type_var_list: Vec<&ast::Expr<()>>;
|
||||||
let unifier = self.unifier.borrow_mut();
|
// if `class A(Generic[T, V, G])`
|
||||||
let primitive_types = self.primitives_ty;
|
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
|
||||||
|
type_var_list = elts.iter().collect_vec();
|
||||||
let mut get_direct_parents =
|
// `class A(Generic[T])`
|
||||||
|class_def: &Arc<RwLock<TopLevelDef>>, class_ast: &Option<Stmt>| {
|
|
||||||
let mut class_def = class_def.write();
|
|
||||||
let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = {
|
|
||||||
if let TopLevelDef::Class {
|
|
||||||
ancestors, resolver, object_id, type_vars, ..
|
|
||||||
} = &mut *class_def
|
|
||||||
{
|
|
||||||
let Some(ast::Located {
|
|
||||||
node: ast::StmtKind::ClassDef { bases, .. }, ..
|
|
||||||
}) = class_ast
|
|
||||||
else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
|
|
||||||
(object_id, bases, ancestors, resolver, type_vars)
|
|
||||||
} else {
|
} else {
|
||||||
return Ok(());
|
type_var_list = vec![&**slice];
|
||||||
}
|
}
|
||||||
};
|
|
||||||
let class_resolver = class_resolver.as_ref().unwrap();
|
|
||||||
let class_resolver = &**class_resolver;
|
|
||||||
|
|
||||||
let mut has_base = false;
|
let type_vars = type_var_list
|
||||||
for b in class_bases {
|
.into_iter()
|
||||||
// type vars have already been handled, so skip on `Generic[...]`
|
.map(|e| {
|
||||||
if matches!(
|
class_resolver.parse_type_annotation(
|
||||||
&b.node,
|
temp_def_list,
|
||||||
ast::ExprKind::Subscript { value, .. }
|
unifier,
|
||||||
if matches!(
|
primitives_store,
|
||||||
&value.node,
|
e,
|
||||||
ast::ExprKind::Name { id, .. } if id == &"Generic".into()
|
|
||||||
)
|
)
|
||||||
) {
|
})
|
||||||
continue;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
}
|
|
||||||
|
|
||||||
|
class_type_vars.extend(type_vars);
|
||||||
|
}
|
||||||
|
ast::ExprKind::Name { .. } => {
|
||||||
if has_base {
|
if has_base {
|
||||||
return Err(HashSet::from([format!(
|
return report_error("a class definition can only have at most one base class declaration and one generic declaration", b.location);
|
||||||
"a class definition can only have at most one base class \
|
|
||||||
declaration and one generic declaration (at {})",
|
|
||||||
b.location
|
|
||||||
)]));
|
|
||||||
}
|
}
|
||||||
has_base = true;
|
has_base = true;
|
||||||
|
|
||||||
@ -578,9 +478,9 @@ impl TopLevelComposer {
|
|||||||
// bast_ty if it is a CustomClassKind
|
// bast_ty if it is a CustomClassKind
|
||||||
let base_ty = parse_ast_to_type_annotation_kinds(
|
let base_ty = parse_ast_to_type_annotation_kinds(
|
||||||
class_resolver,
|
class_resolver,
|
||||||
&temp_def_list,
|
temp_def_list,
|
||||||
unifier,
|
unifier,
|
||||||
&primitive_types,
|
primitives_store,
|
||||||
b,
|
b,
|
||||||
vec![(*class_def_id, class_type_vars.clone())]
|
vec![(*class_def_id, class_type_vars.clone())]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@ -590,123 +490,127 @@ impl TopLevelComposer {
|
|||||||
if let TypeAnnotation::CustomClass { .. } = &base_ty {
|
if let TypeAnnotation::CustomClass { .. } = &base_ty {
|
||||||
class_ancestors.push(base_ty);
|
class_ancestors.push(base_ty);
|
||||||
} else {
|
} else {
|
||||||
return Err(HashSet::from([format!(
|
return report_error(
|
||||||
"class base declaration can only be custom class (at {})",
|
"class base declaration can only be custom class",
|
||||||
b.location,
|
b.location,
|
||||||
)]));
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
// TODO: Report Error here
|
||||||
};
|
_ => {
|
||||||
|
println!("Type was => {}", b.node.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// first, only push direct parent into the list
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn analyze_ancestors(
|
||||||
|
class_def: &Arc<RwLock<TopLevelDef>>,
|
||||||
|
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
|
||||||
|
) {
|
||||||
|
// Check if class has a direct parent
|
||||||
|
let mut class_def = class_def.write();
|
||||||
|
let TopLevelDef::Class { ancestors, type_vars, object_id, .. } = &mut *class_def else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let mut anc_set = HashMap::new();
|
||||||
|
|
||||||
|
if let Some(ancestor) = ancestors.first() {
|
||||||
|
let TypeAnnotation::CustomClass { id, .. } = ancestor else { unreachable!() };
|
||||||
|
let TopLevelDef::Class { ancestors: parent_ancestors, .. } =
|
||||||
|
&*temp_def_list[id.0].read()
|
||||||
|
else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
for anc in parent_ancestors.iter().skip(1) {
|
||||||
|
let TypeAnnotation::CustomClass { id, .. } = anc else { unreachable!() };
|
||||||
|
anc_set.insert(id, anc.clone());
|
||||||
|
}
|
||||||
|
ancestors.extend(anc_set.into_iter().map(|f| f.1).collect::<Vec<_>>());
|
||||||
|
}
|
||||||
|
|
||||||
|
ancestors.insert(0, make_self_type_annotation(type_vars.as_slice(), *object_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// step 1, analyze the type vars associated with top level class
|
||||||
|
fn analyze_top_level_class_type_var2(
|
||||||
|
def_list: &[DefAst],
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives_store: &PrimitiveStore,
|
||||||
|
core_info: (&HashSet<StrRef>, &ComposerConfig),
|
||||||
|
) -> Result<(), HashSet<String>> {
|
||||||
let mut errors = HashSet::new();
|
let mut errors = HashSet::new();
|
||||||
for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) {
|
let mut temp_def_list: Vec<Arc<RwLock<TopLevelDef>>> = Vec::default();
|
||||||
if class_ast.is_none() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if let Err(e) = get_direct_parents(class_def, class_ast) {
|
|
||||||
errors.extend(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !errors.is_empty() {
|
|
||||||
return Err(errors);
|
|
||||||
}
|
|
||||||
|
|
||||||
// second, get all ancestors
|
for (class_def, class_ast) in def_list {
|
||||||
let mut ancestors_store: HashMap<DefinitionId, Vec<TypeAnnotation>> = HashMap::default();
|
if class_ast.is_some() && matches!(&*class_def.read(), TopLevelDef::Class { .. }) {
|
||||||
let mut get_all_ancestors =
|
// Add type vars and direct parents
|
||||||
|class_def: &Arc<RwLock<TopLevelDef>>| -> Result<(), HashSet<String>> {
|
if let Err(e) = Self::analyze_bases(
|
||||||
let class_def = class_def.read();
|
class_def,
|
||||||
let (class_ancestors, class_id) = {
|
class_ast,
|
||||||
if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def {
|
&temp_def_list,
|
||||||
(ancestors, *object_id)
|
unifier,
|
||||||
} else {
|
primitives_store,
|
||||||
return Ok(());
|
) {
|
||||||
}
|
errors.extend(e);
|
||||||
};
|
|
||||||
ancestors_store.insert(
|
|
||||||
class_id,
|
|
||||||
// if class has direct parents, get all ancestors of its parents. Else just empty
|
|
||||||
if class_ancestors.is_empty() {
|
|
||||||
vec![]
|
|
||||||
} else {
|
|
||||||
Self::get_all_ancestors_helper(
|
|
||||||
&class_ancestors[0],
|
|
||||||
temp_def_list.as_slice(),
|
|
||||||
)?
|
|
||||||
},
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
};
|
|
||||||
for (class_def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) {
|
|
||||||
if ast.is_none() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if let Err(e) = get_all_ancestors(class_def) {
|
|
||||||
errors.extend(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !errors.is_empty() {
|
|
||||||
return Err(errors);
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert the ancestors to the def list
|
|
||||||
for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) {
|
|
||||||
if class_ast.is_none() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let mut class_def = class_def.write();
|
|
||||||
let (class_ancestors, class_id, class_type_vars) = {
|
|
||||||
if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = &mut *class_def
|
|
||||||
{
|
|
||||||
(ancestors, *object_id, type_vars)
|
|
||||||
} else {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
};
|
// Get class ancestors order matters here. Like python we will only consider classes to be correct if they are in same order
|
||||||
|
Self::analyze_ancestors(class_def, &temp_def_list);
|
||||||
|
|
||||||
let ans = ancestors_store.get_mut(&class_id).unwrap();
|
let mut type_var_to_concrete_def: HashMap<Type, TypeAnnotation> = HashMap::new();
|
||||||
class_ancestors.append(ans);
|
if let Err(e) = Self::analyze_single_class_methods_fields(
|
||||||
|
class_def,
|
||||||
|
&class_ast.as_ref().unwrap().node,
|
||||||
|
&temp_def_list,
|
||||||
|
unifier,
|
||||||
|
primitives_store,
|
||||||
|
&mut type_var_to_concrete_def,
|
||||||
|
core_info,
|
||||||
|
) {
|
||||||
|
errors.extend(e);
|
||||||
|
}
|
||||||
|
|
||||||
// insert self type annotation to the front of the vector to maintain the order
|
// special case classes that inherit from Exception
|
||||||
class_ancestors
|
let TopLevelDef::Class { ancestors: class_ancestors, loc, .. } = &*class_def.read()
|
||||||
.insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id));
|
else {
|
||||||
|
|
||||||
// special case classes that inherit from Exception
|
|
||||||
if class_ancestors
|
|
||||||
.iter()
|
|
||||||
.any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7))
|
|
||||||
{
|
|
||||||
// if inherited from Exception, the body should be a pass
|
|
||||||
let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node else {
|
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
if class_ancestors
|
||||||
for stmt in body {
|
.iter()
|
||||||
if matches!(
|
.any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7))
|
||||||
stmt.node,
|
{
|
||||||
ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }
|
// if inherited from Exception, the body should be a pass
|
||||||
) {
|
let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node
|
||||||
return Err(HashSet::from([
|
else {
|
||||||
"Classes inherited from exception should have no custom fields/methods"
|
unreachable!()
|
||||||
.into(),
|
};
|
||||||
]));
|
for stmt in body {
|
||||||
|
if matches!(
|
||||||
|
stmt.node,
|
||||||
|
ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }
|
||||||
|
) {
|
||||||
|
errors.extend(report_error("Classes inherited from exception should have no custom fields/methods", loc.unwrap()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// println!("Adding class_def of {name} to the temp_def_list with ID {}", object_id.0);
|
||||||
|
temp_def_list.push(class_def.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
// deal with ancestor of Exception object
|
// deal with ancestor of Exception object
|
||||||
let TopLevelDef::Class { name, ancestors, object_id, .. } =
|
let TopLevelDef::Class { name, ancestors, object_id, .. } = &mut *def_list[7].0.write()
|
||||||
&mut *self.definition_ast_list[7].0.write()
|
|
||||||
else {
|
else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(*name, "Exception".into());
|
assert_eq!(*name, "Exception".into());
|
||||||
ancestors.push(make_self_type_annotation(&[], *object_id));
|
ancestors.push(make_self_type_annotation(&[], *object_id));
|
||||||
|
|
||||||
|
if !errors.is_empty() {
|
||||||
|
return Err(errors);
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1199,126 +1103,81 @@ impl TopLevelComposer {
|
|||||||
let mut method_var_map = VarMap::new();
|
let mut method_var_map = VarMap::new();
|
||||||
|
|
||||||
let arg_types: Vec<FuncArg> = {
|
let arg_types: Vec<FuncArg> = {
|
||||||
// check method parameters cannot have same name
|
// Function arguments must have:
|
||||||
|
// 1) `self` as first argument
|
||||||
|
// 2) unique names
|
||||||
|
// 3) names different than keywords
|
||||||
|
match args.args.first() {
|
||||||
|
Some(id) if id.node.arg == "self".into() => {},
|
||||||
|
_ => return report_error("class method must have a `self` parameter", b.location),
|
||||||
|
}
|
||||||
let mut defined_parameter_name: HashSet<_> = HashSet::new();
|
let mut defined_parameter_name: HashSet<_> = HashSet::new();
|
||||||
let zelf: StrRef = "self".into();
|
for arg in args.args.iter().skip(1) {
|
||||||
for x in &args.args {
|
if !defined_parameter_name.insert(arg.node.arg) {
|
||||||
if !defined_parameter_name.insert(x.node.arg)
|
return report_error("class method must have a unique parameter names", b.location)
|
||||||
|| (keyword_list.contains(&x.node.arg) && x.node.arg != zelf)
|
}
|
||||||
{
|
if keyword_list.contains(&arg.node.arg) {
|
||||||
return Err(HashSet::from([
|
return report_error("parameter names should not be the same as the keywords", b.location)
|
||||||
format!("top level function must have unique parameter names \
|
|
||||||
and names should not be the same as the keywords (at {})",
|
|
||||||
x.location),
|
|
||||||
]))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if name == &"__init__".into() && !defined_parameter_name.contains(&zelf) {
|
// `self` must not be provided type annotation or default value
|
||||||
return Err(HashSet::from([
|
if args.args.len() == args.defaults.len() {
|
||||||
format!("__init__ method must have a `self` parameter (at {})", b.location),
|
return report_error("`self` cannot have a default value", b.location)
|
||||||
]))
|
|
||||||
}
|
}
|
||||||
if !defined_parameter_name.contains(&zelf) {
|
if args.args[0].node.annotation.is_some() {
|
||||||
return Err(HashSet::from([
|
return report_error("`self` cannot have a type annotation", b.location)
|
||||||
format!("class method must have a `self` parameter (at {})", b.location),
|
|
||||||
]))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
|
let no_defaults = args.args.len() - args.defaults.len() - 1;
|
||||||
let arg_with_default: Vec<(
|
for (idx, x) in itertools::enumerate(args.args.iter().skip(1)) {
|
||||||
&ast::Located<ast::ArgData<()>>,
|
let type_ann = {
|
||||||
Option<&ast::Expr>,
|
let Some(annotation_expr) = x.node.annotation.as_ref() else {return report_error(format!("type annotation needed for `{}`", x.node.arg).as_str(), x.location)};
|
||||||
)> = args
|
parse_ast_to_type_annotation_kinds(
|
||||||
.args
|
class_resolver,
|
||||||
.iter()
|
temp_def_list,
|
||||||
.rev()
|
unifier,
|
||||||
.zip(
|
primitives,
|
||||||
args.defaults
|
annotation_expr,
|
||||||
.iter()
|
vec![(class_id, class_type_vars_def.clone())]
|
||||||
.rev()
|
.into_iter()
|
||||||
.map(|x| -> Option<&ast::Expr> { Some(x) })
|
.collect::<HashMap<_, _>>(),
|
||||||
.chain(std::iter::repeat(None)),
|
)?
|
||||||
)
|
};
|
||||||
.collect_vec();
|
// find type vars within this method parameter type annotation
|
||||||
|
let type_vars_within = get_type_var_contained_in_type_annotation(&type_ann);
|
||||||
for (x, default) in arg_with_default.into_iter().rev() {
|
// handle the class type var and the method type var
|
||||||
let name = x.node.arg;
|
for type_var_within in type_vars_within {
|
||||||
if name != zelf {
|
let TypeAnnotation::TypeVar(ty) = type_var_within else {
|
||||||
let type_ann = {
|
unreachable!("must be type var annotation")
|
||||||
let annotation_expr = x
|
|
||||||
.node
|
|
||||||
.annotation
|
|
||||||
.as_ref()
|
|
||||||
.ok_or_else(|| HashSet::from([
|
|
||||||
format!(
|
|
||||||
"type annotation needed for `{}` at {}",
|
|
||||||
x.node.arg, x.location
|
|
||||||
),
|
|
||||||
]))?
|
|
||||||
.as_ref();
|
|
||||||
parse_ast_to_type_annotation_kinds(
|
|
||||||
class_resolver,
|
|
||||||
temp_def_list,
|
|
||||||
unifier,
|
|
||||||
primitives,
|
|
||||||
annotation_expr,
|
|
||||||
vec![(class_id, class_type_vars_def.clone())]
|
|
||||||
.into_iter()
|
|
||||||
.collect::<HashMap<_, _>>(),
|
|
||||||
)?
|
|
||||||
};
|
};
|
||||||
// find type vars within this method parameter type annotation
|
|
||||||
let type_vars_within =
|
|
||||||
get_type_var_contained_in_type_annotation(&type_ann);
|
|
||||||
// handle the class type var and the method type var
|
|
||||||
for type_var_within in type_vars_within {
|
|
||||||
let TypeAnnotation::TypeVar(ty) = type_var_within else {
|
|
||||||
unreachable!("must be type var annotation")
|
|
||||||
};
|
|
||||||
|
|
||||||
let id = Self::get_var_id(ty, unifier)?;
|
let id = Self::get_var_id(ty, unifier)?;
|
||||||
if let Some(prev_ty) = method_var_map.insert(id, ty) {
|
if let Some(prev_ty) = method_var_map.insert(id, ty) {
|
||||||
// if already in the list, make sure they are the same?
|
// if already in the list, make sure they are the same?
|
||||||
assert_eq!(prev_ty, ty);
|
assert_eq!(prev_ty, ty);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// finish handling type vars
|
|
||||||
let dummy_func_arg = FuncArg {
|
|
||||||
name,
|
|
||||||
ty: unifier.get_dummy_var().ty,
|
|
||||||
default_value: match default {
|
|
||||||
None => None,
|
|
||||||
Some(default) => {
|
|
||||||
if name == "self".into() {
|
|
||||||
return Err(HashSet::from([
|
|
||||||
format!("`self` parameter cannot take default value (at {})", x.location),
|
|
||||||
]));
|
|
||||||
}
|
|
||||||
Some({
|
|
||||||
let v = Self::parse_parameter_default_value(
|
|
||||||
default,
|
|
||||||
class_resolver,
|
|
||||||
)?;
|
|
||||||
Self::check_default_param_type(
|
|
||||||
&v, &type_ann, primitives, unifier,
|
|
||||||
)
|
|
||||||
.map_err(|err| HashSet::from([
|
|
||||||
format!("{} (at {})", err, x.location),
|
|
||||||
]))?;
|
|
||||||
v
|
|
||||||
})
|
|
||||||
}
|
|
||||||
},
|
|
||||||
is_vararg: false,
|
|
||||||
};
|
|
||||||
// push the dummy type and the type annotation
|
|
||||||
// into the list for later unification
|
|
||||||
type_var_to_concrete_def
|
|
||||||
.insert(dummy_func_arg.ty, type_ann.clone());
|
|
||||||
result.push(dummy_func_arg);
|
|
||||||
}
|
}
|
||||||
|
// finish handling type vars
|
||||||
|
let dummy_func_arg = FuncArg {
|
||||||
|
name: x.node.arg,
|
||||||
|
ty: unifier.get_dummy_var().ty,
|
||||||
|
default_value: if idx < no_defaults { None } else {
|
||||||
|
let default_idx = idx - no_defaults;
|
||||||
|
|
||||||
|
Some({
|
||||||
|
let v = Self::parse_parameter_default_value(&args.defaults[default_idx], class_resolver)?;
|
||||||
|
Self::check_default_param_type(&v, &type_ann, primitives, unifier).map_err(|err| report_error::<()>(err.as_str(), x.location).unwrap_err())?;
|
||||||
|
v
|
||||||
|
})
|
||||||
|
},
|
||||||
|
is_vararg: false,
|
||||||
|
};
|
||||||
|
// push the dummy type and the type annotation
|
||||||
|
// into the list for later unification
|
||||||
|
type_var_to_concrete_def
|
||||||
|
.insert(dummy_func_arg.ty, type_ann.clone());
|
||||||
|
result.push(dummy_func_arg);
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
};
|
};
|
||||||
@ -1440,23 +1299,13 @@ impl TopLevelComposer {
|
|||||||
match v {
|
match v {
|
||||||
ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {}
|
ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {}
|
||||||
_ => {
|
_ => {
|
||||||
return Err(HashSet::from([
|
return report_error("unsupported statement in class definition body", b.location)
|
||||||
format!(
|
|
||||||
"unsupported statement in class definition body (at {})",
|
|
||||||
b.location
|
|
||||||
),
|
|
||||||
]))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
class_attributes_def.push((*attr, dummy_field_type, v.clone()));
|
class_attributes_def.push((*attr, dummy_field_type, v.clone()));
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
return Err(HashSet::from([
|
return report_error("unsupported statement in class definition body", b.location)
|
||||||
format!(
|
|
||||||
"unsupported statement in class definition body (at {})",
|
|
||||||
b.location
|
|
||||||
),
|
|
||||||
]))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
annotation
|
annotation
|
||||||
@ -1482,43 +1331,22 @@ impl TopLevelComposer {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if !class_type_vars_def.contains(&t) {
|
if !class_type_vars_def.contains(&t) {
|
||||||
return Err(HashSet::from([
|
return report_error("class fields can only use type vars over which the class is generic", b.location)
|
||||||
format!(
|
|
||||||
"class fields can only use type \
|
|
||||||
vars over which the class is generic (at {})",
|
|
||||||
annotation.location
|
|
||||||
),
|
|
||||||
]))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
type_var_to_concrete_def.insert(dummy_field_type, parsed_annotation);
|
type_var_to_concrete_def.insert(dummy_field_type, parsed_annotation);
|
||||||
} else {
|
} else {
|
||||||
return Err(HashSet::from([
|
return report_error(format!("same class fields `{}` defined twice", attr).as_str(), target.location)
|
||||||
format!(
|
|
||||||
"same class fields `{}` defined twice (at {})",
|
|
||||||
attr, target.location
|
|
||||||
),
|
|
||||||
]))
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(HashSet::from([
|
return report_error("unsupported statement in class definition body", target.location)
|
||||||
format!(
|
|
||||||
"unsupported statement type in class definition body (at {})",
|
|
||||||
target.location
|
|
||||||
),
|
|
||||||
]))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::StmtKind::Assign { .. } // we don't class attributes
|
ast::StmtKind::Assign { .. } // we don't class attributes
|
||||||
| ast::StmtKind::Expr { value: _, .. } // typically a docstring; ignoring all expressions matches CPython behavior
|
| ast::StmtKind::Expr { value: _, .. } // typically a docstring; ignoring all expressions matches CPython behavior
|
||||||
| ast::StmtKind::Pass { .. } => {}
|
| ast::StmtKind::Pass { .. } => {}
|
||||||
_ => {
|
_ => {
|
||||||
return Err(HashSet::from([
|
return report_error("unsupported statement in class definition body", b.location)
|
||||||
format!(
|
|
||||||
"unsupported statement in class definition body (at {})",
|
|
||||||
b.location
|
|
||||||
),
|
|
||||||
]))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1570,44 +1398,27 @@ impl TopLevelComposer {
|
|||||||
|
|
||||||
// handle methods override
|
// handle methods override
|
||||||
// since we need to maintain the order, create a new list
|
// since we need to maintain the order, create a new list
|
||||||
let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = Vec::new();
|
let mut new_child_methods: IndexMap<StrRef, (Type, DefinitionId)> = methods.iter().map(|m| (m.0.clone(), (m.1.clone(), m.2.clone()))).collect();
|
||||||
let mut is_override: HashSet<StrRef> = HashSet::new();
|
// let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = methods.clone();
|
||||||
for (anc_method_name, anc_method_ty, anc_method_def_id) in methods {
|
for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
|
||||||
// find if there is a method with same name in the child class
|
if let Some((ty, _ ) ) = new_child_methods.insert(*class_method_name, (*class_method_ty, *class_method_defid)) {
|
||||||
let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id);
|
let ok = class_method_name == &"__init__".into()
|
||||||
for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
|
|| Self::check_overload_function_type(
|
||||||
if class_method_name == anc_method_name {
|
*class_method_ty,
|
||||||
// ignore and handle self
|
ty,
|
||||||
// if is __init__ method, no need to check return type
|
unifier,
|
||||||
let ok = class_method_name == &"__init__".into()
|
type_var_to_concrete_def,
|
||||||
|| Self::check_overload_function_type(
|
);
|
||||||
*class_method_ty,
|
if !ok {
|
||||||
*anc_method_ty,
|
return Err(HashSet::from([format!(
|
||||||
unifier,
|
"method {class_method_name} has same name as ancestors' method, but incompatible type"),
|
||||||
type_var_to_concrete_def,
|
]));
|
||||||
);
|
|
||||||
if !ok {
|
|
||||||
return Err(HashSet::from([format!(
|
|
||||||
"method {class_method_name} has same name as ancestors' method, but incompatible type"),
|
|
||||||
]));
|
|
||||||
}
|
|
||||||
// mark it as added
|
|
||||||
is_override.insert(*class_method_name);
|
|
||||||
to_be_added = (*class_method_name, *class_method_ty, *class_method_defid);
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
new_child_methods.push(to_be_added);
|
|
||||||
}
|
}
|
||||||
// add those that are not overriding method to the new_child_methods
|
|
||||||
for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
|
|
||||||
if !is_override.contains(class_method_name) {
|
|
||||||
new_child_methods.push((*class_method_name, *class_method_ty, *class_method_defid));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// use the new_child_methods to replace all the elements in `class_methods_def`
|
|
||||||
class_methods_def.clear();
|
class_methods_def.clear();
|
||||||
class_methods_def.extend(new_child_methods);
|
class_methods_def.extend(new_child_methods.iter().map(|f| (f.0.clone(), f.1.0, f.1.1)).collect_vec());
|
||||||
|
let is_override: HashSet<StrRef> = HashSet::new();
|
||||||
|
|
||||||
// handle class fields
|
// handle class fields
|
||||||
let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new();
|
let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new();
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
use crate::toplevel::helper::{PrimDef, PrimDefDetails};
|
use crate::toplevel::helper::{PrimDef, PrimDefDetails};
|
||||||
|
use crate::typecheck::type_inferencer::report_error;
|
||||||
use crate::typecheck::typedef::VarMap;
|
use crate::typecheck::typedef::VarMap;
|
||||||
use nac3parser::ast::Constant;
|
use nac3parser::ast::Constant;
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
@ -97,7 +98,13 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
|||||||
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
|
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
|
||||||
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
|
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
|
||||||
let type_vars = {
|
let type_vars = {
|
||||||
let def_read = top_level_defs[obj_id.0].try_read();
|
let Some(top_level_def) = top_level_defs.get(obj_id.0) else {
|
||||||
|
return report_error(
|
||||||
|
format!("Name Error undefined name {id}").as_str(),
|
||||||
|
expr.location,
|
||||||
|
);
|
||||||
|
};
|
||||||
|
let def_read = top_level_def.try_read();
|
||||||
if let Some(def_read) = def_read {
|
if let Some(def_read) = def_read {
|
||||||
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
|
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
|
||||||
type_vars.clone()
|
type_vars.clone()
|
||||||
|
@ -520,6 +520,23 @@ pub fn typeof_binop(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Operator::MatMult => {
|
Operator::MatMult => {
|
||||||
|
// NOTE: NumPy matmul's LHS and RHS must both be ndarrays. Scalars are not allowed.
|
||||||
|
match (&*unifier.get_ty(lhs), &*unifier.get_ty(rhs)) {
|
||||||
|
(
|
||||||
|
TypeEnum::TObj { obj_id: lhs_obj_id, .. },
|
||||||
|
TypeEnum::TObj { obj_id: rhs_obj_id, .. },
|
||||||
|
) if *lhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap()
|
||||||
|
&& *rhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// LHS and RHS have valid types
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let lhs_str = unifier.stringify(lhs);
|
||||||
|
let rhs_str = unifier.stringify(rhs);
|
||||||
|
return Err(format!("ndarray.__matmul__ only accepts ndarray operands, but left operand has type {lhs_str}, and right operand has type {rhs_str}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
||||||
TypeEnum::TLiteral { values, .. } => {
|
TypeEnum::TLiteral { values, .. } => {
|
||||||
|
@ -114,7 +114,7 @@ impl Fold<()> for NaiveFolder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn report_error<T>(msg: &str, location: Location) -> Result<T, InferenceError> {
|
pub fn report_error<T>(msg: &str, location: Location) -> Result<T, InferenceError> {
|
||||||
Err(HashSet::from([format!("{msg} at {location}")]))
|
Err(HashSet::from([format!("{msg} at {location}")]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
26
nac3standalone/demo/interpreted.log
Normal file
26
nac3standalone/demo/interpreted.log
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
1
|
||||||
|
0
|
||||||
|
2
|
||||||
|
1
|
||||||
|
2
|
||||||
|
False
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
@ -314,6 +314,15 @@ fn main() {
|
|||||||
let resolver =
|
let resolver =
|
||||||
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
|
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||||
|
|
||||||
|
let context = inkwell::context::Context::create();
|
||||||
|
|
||||||
|
// Process IRRT
|
||||||
|
let irrt = load_irrt(&context, resolver.as_ref());
|
||||||
|
if emit_llvm {
|
||||||
|
irrt.write_bitcode_to_path(Path::new("irrt.bc"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the Python script
|
||||||
let parser_result = parser::parse_program(&program, file_name.into()).unwrap();
|
let parser_result = parser::parse_program(&program, file_name.into()).unwrap();
|
||||||
|
|
||||||
for stmt in parser_result {
|
for stmt in parser_result {
|
||||||
@ -418,8 +427,8 @@ fn main() {
|
|||||||
registry.add_task(task);
|
registry.add_task(task);
|
||||||
registry.wait_tasks_complete(handles);
|
registry.wait_tasks_complete(handles);
|
||||||
|
|
||||||
|
// Link all modules together into `main`
|
||||||
let buffers = membuffers.lock();
|
let buffers = membuffers.lock();
|
||||||
let context = inkwell::context::Context::create();
|
|
||||||
let main = context
|
let main = context
|
||||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -439,12 +448,9 @@ fn main() {
|
|||||||
main.link_in_module(other).unwrap();
|
main.link_in_module(other).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let irrt = load_irrt(&context);
|
|
||||||
if emit_llvm {
|
|
||||||
irrt.write_bitcode_to_path(Path::new("irrt.bc"));
|
|
||||||
}
|
|
||||||
main.link_in_module(irrt).unwrap();
|
main.link_in_module(irrt).unwrap();
|
||||||
|
|
||||||
|
// Private all functions except "run"
|
||||||
let mut function_iter = main.get_first_function();
|
let mut function_iter = main.get_first_function();
|
||||||
while let Some(func) = function_iter {
|
while let Some(func) = function_iter {
|
||||||
if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != "run" {
|
if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != "run" {
|
||||||
@ -453,6 +459,7 @@ fn main() {
|
|||||||
function_iter = func.get_next_function();
|
function_iter = func.get_next_function();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Optimize `main`
|
||||||
let target_machine = llvm_options
|
let target_machine = llvm_options
|
||||||
.target
|
.target
|
||||||
.create_target_machine(llvm_options.opt_level)
|
.create_target_machine(llvm_options.opt_level)
|
||||||
@ -466,6 +473,7 @@ fn main() {
|
|||||||
panic!("Failed to run optimization for module `main`: {}", err.to_string());
|
panic!("Failed to run optimization for module `main`: {}", err.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write output
|
||||||
target_machine
|
target_machine
|
||||||
.write_to_file(&main, FileType::Object, Path::new("module.o"))
|
.write_to_file(&main, FileType::Object, Path::new("module.o"))
|
||||||
.expect("couldn't write module to file");
|
.expect("couldn't write module to file");
|
||||||
|
BIN
pyo3/nac3artiq.so
Executable file
BIN
pyo3/nac3artiq.so
Executable file
Binary file not shown.
Loading…
Reference in New Issue
Block a user